From 51b26c071d5a6c32fe0457e53e9833789f0db983 Mon Sep 17 00:00:00 2001 From: Emman Date: Fri, 29 Apr 2022 12:09:04 +0800 Subject: [PATCH] separate aes/sm4 ni implementation --- sm4/cipher_asm.go | 93 ++++++++++++++++---------- sm4/sm4_gcm_asm.go | 35 +++++----- sm4/sm4_gcm_test.go | 8 +-- sm4/sm4ni_gcm_asm.go | 152 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 232 insertions(+), 56 deletions(-) create mode 100644 sm4/sm4ni_gcm_asm.go diff --git a/sm4/cipher_asm.go b/sm4/cipher_asm.go index 908d437..b05775a 100644 --- a/sm4/cipher_asm.go +++ b/sm4/cipher_asm.go @@ -15,6 +15,11 @@ var supportsAES = cpu.X86.HasAES || cpu.ARM64.HasAES var supportsGFMUL = cpu.X86.HasPCLMULQDQ || cpu.ARM64.HasPMULL var useAVX2 = cpu.X86.HasAVX2 && cpu.X86.HasBMI2 +const ( + INST_AES int = iota + INST_SM4 +) + //go:noescape func encryptBlocksAsm(xk *uint32, dst, src []byte, inst int) @@ -30,46 +35,68 @@ type sm4CipherAsm struct { blocksSize int } +type sm4CipherNI struct { + sm4Cipher +} + +func newCipherNI(key []byte) (cipher.Block, error) { + c := &sm4CipherNI{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}} + expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0], INST_SM4) + if supportsGFMUL { + return &sm4CipherNIGCM{c}, nil + } + return c, nil +} + +func (c *sm4CipherNI) Encrypt(dst, src []byte) { + if len(src) < BlockSize { + panic("sm4: input not full block") + } + if len(dst) < BlockSize { + panic("sm4: output not full block") + } + if subtle.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { + panic("sm4: invalid buffer overlap") + } + encryptBlockAsm(&c.enc[0], &dst[0], &src[0], INST_SM4) +} + +func (c *sm4CipherNI) Decrypt(dst, src []byte) { + if len(src) < BlockSize { + panic("sm4: input not full block") + } + if len(dst) < BlockSize { + panic("sm4: output not full block") + } + if subtle.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { + panic("sm4: invalid buffer overlap") + } + encryptBlockAsm(&c.dec[0], &dst[0], &src[0], INST_SM4) +} + func 
newCipher(key []byte) (cipher.Block, error) { - if !(supportsAES || supportSM4) { + if supportSM4 { + return newCipherNI(key) + } + + if !supportsAES { return newCipherGeneric(key) } + blocks := 4 if useAVX2 { blocks = 8 } - c := sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, blocks, blocks * BlockSize} - if supportSM4 { - expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0], 1) - } else { - expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0], 0) - } - if (supportsAES || supportSM4) && supportsGFMUL { + c := &sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, blocks, blocks * BlockSize} + expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0], INST_AES) + if supportsGFMUL { return &sm4CipherGCM{c}, nil } - return &c, nil + return c, nil } -func (c *sm4CipherAsm) BlockSize() int { return BlockSize } - func (c *sm4CipherAsm) Concurrency() int { return c.batchBlocks } -func encryptBlockAsmInst(xk *uint32, dst, src *byte) { - if supportSM4 { - encryptBlockAsm(xk, dst, src, 1) - } else { - encryptBlockAsm(xk, dst, src, 0) - } -} - -func encryptBlocksAsmInst(xk *uint32, dst, src []byte) { - if supportSM4 { - encryptBlocksAsm(xk, dst, src, 1) - } else { - encryptBlocksAsm(xk, dst, src, 0) - } -} - func (c *sm4CipherAsm) Encrypt(dst, src []byte) { if len(src) < BlockSize { panic("sm4: input not full block") @@ -80,7 +107,7 @@ func (c *sm4CipherAsm) Encrypt(dst, src []byte) { if subtle.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { panic("sm4: invalid buffer overlap") } - encryptBlockAsmInst(&c.enc[0], &dst[0], &src[0]) + encryptBlockAsm(&c.enc[0], &dst[0], &src[0], INST_AES) } func (c *sm4CipherAsm) EncryptBlocks(dst, src []byte) { @@ -93,7 +120,7 @@ func (c *sm4CipherAsm) EncryptBlocks(dst, src []byte) { if subtle.InexactOverlap(dst[:c.blocksSize], src[:c.blocksSize]) { panic("sm4: invalid buffer overlap") } - encryptBlocksAsmInst(&c.enc[0], dst, src) + encryptBlocksAsm(&c.enc[0], dst, src, INST_AES) } func (c *sm4CipherAsm) 
Decrypt(dst, src []byte) { @@ -106,7 +133,7 @@ func (c *sm4CipherAsm) Decrypt(dst, src []byte) { if subtle.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { panic("sm4: invalid buffer overlap") } - encryptBlockAsmInst(&c.dec[0], &dst[0], &src[0]) + encryptBlockAsm(&c.dec[0], &dst[0], &src[0], INST_AES) } func (c *sm4CipherAsm) DecryptBlocks(dst, src []byte) { @@ -119,16 +146,16 @@ func (c *sm4CipherAsm) DecryptBlocks(dst, src []byte) { if subtle.InexactOverlap(dst[:c.blocksSize], src[:c.blocksSize]) { panic("sm4: invalid buffer overlap") } - encryptBlocksAsmInst(&c.dec[0], dst, src) + encryptBlocksAsm(&c.dec[0], dst, src, INST_AES) } // expandKey is used by BenchmarkExpand to ensure that the asm implementation // of key expansion is used for the benchmark when it is available. func expandKey(key []byte, enc, dec []uint32) { if supportSM4 { - expandKeyAsm(&key[0], &ck[0], &enc[0], &dec[0], 1) + expandKeyAsm(&key[0], &ck[0], &enc[0], &dec[0], INST_SM4) } else if supportsAES { - expandKeyAsm(&key[0], &ck[0], &enc[0], &dec[0], 0) + expandKeyAsm(&key[0], &ck[0], &enc[0], &dec[0], INST_AES) } else { expandKeyGo(key, enc, dec) } diff --git a/sm4/sm4_gcm_asm.go b/sm4/sm4_gcm_asm.go index ccba10b..74c5481 100644 --- a/sm4/sm4_gcm_asm.go +++ b/sm4/sm4_gcm_asm.go @@ -12,9 +12,9 @@ import ( // sm4CipherGCM implements crypto/cipher.gcmAble so that crypto/cipher.NewGCM // will use the optimised implementation in this file when possible. Instances -// of this type only exist when hasGCMAsm returns true. +// of this type only exist when hasGCMAsm and hasAES return true. type sm4CipherGCM struct { - sm4CipherAsm + *sm4CipherAsm } // Assert that sm4CipherGCM implements the gcmAble interface. 
@@ -29,31 +29,22 @@ func gcmSm4Enc(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, rk [] //go:noescape func gcmSm4Dec(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, rk []uint32) -//go:noescape -func gcmSm4niEnc(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, rk []uint32) - -//go:noescape -func gcmSm4niDec(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, rk []uint32) - //go:noescape func gcmSm4Data(productTable *[256]byte, data []byte, T *[16]byte) //go:noescape func gcmSm4Finish(productTable *[256]byte, tagMask, T *[16]byte, pLen, dLen uint64) -type gcmAsm struct { - gcm - bytesProductTable [256]byte -} - +// gcmSm4InitInst is used for testing func gcmSm4InitInst(productTable *[256]byte, rk []uint32) { if supportSM4 { - gcmSm4Init(productTable, rk, 1) + gcmSm4Init(productTable, rk, INST_SM4) } else { - gcmSm4Init(productTable, rk, 0) + gcmSm4Init(productTable, rk, INST_AES) } } +// gcmSm4EncInst is used for testing func gcmSm4EncInst(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, rk []uint32) { if supportSM4 { gcmSm4niEnc(productTable, dst, src, ctr, T, rk) @@ -62,6 +53,7 @@ func gcmSm4EncInst(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, r } } +// gcmSm4DecInst is used for testing func gcmSm4DecInst(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, rk []uint32) { if supportSM4 { gcmSm4niDec(productTable, dst, src, ctr, T, rk) @@ -70,14 +62,19 @@ func gcmSm4DecInst(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, r } } +type gcmAsm struct { + gcm + bytesProductTable [256]byte +} + // NewGCM returns the SM4 cipher wrapped in Galois Counter Mode. This is only // called by crypto/cipher.NewGCM via the gcmAble interface. 
func (c *sm4CipherGCM) NewGCM(nonceSize, tagSize int) (cipher.AEAD, error) { g := &gcmAsm{} - g.cipher = &c.sm4CipherAsm + g.cipher = c.sm4CipherAsm g.nonceSize = nonceSize g.tagSize = tagSize - gcmSm4InitInst(&g.bytesProductTable, g.cipher.enc) + gcmSm4Init(&g.bytesProductTable, g.cipher.enc, INST_AES) return g, nil } @@ -122,7 +119,7 @@ func (g *gcmAsm) Seal(dst, nonce, plaintext, data []byte) []byte { } if len(plaintext) > 0 { - gcmSm4EncInst(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc) + gcmSm4Enc(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc) } gcmSm4Finish(&g.bytesProductTable, &tagMask, &tagOut, uint64(len(plaintext)), uint64(len(data))) copy(out[len(plaintext):], tagOut[:]) @@ -175,7 +172,7 @@ func (g *gcmAsm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { panic("cipher: invalid buffer overlap") } if len(ciphertext) > 0 { - gcmSm4DecInst(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc) + gcmSm4Dec(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc) } gcmSm4Finish(&g.bytesProductTable, &tagMask, &expectedTag, uint64(len(ciphertext)), uint64(len(data))) diff --git a/sm4/sm4_gcm_test.go b/sm4/sm4_gcm_test.go index cc5b6bc..36e30bc 100644 --- a/sm4/sm4_gcm_test.go +++ b/sm4/sm4_gcm_test.go @@ -11,11 +11,11 @@ import ( func genPrecomputeTable() *gcmAsm { key := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10} - c := sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, 4, 64} + c := &sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, 4, 64} expandKey(key, c.enc, c.dec) c1 := &sm4CipherGCM{c} g := &gcmAsm{} - g.cipher = &c1.sm4CipherAsm + g.cipher = c1.sm4CipherAsm gcmSm4InitInst(&g.bytesProductTable, g.cipher.enc) return g } @@ -145,11 +145,11 @@ func TestBothDataPlaintext(t *testing.T) { func createGcm() *gcmAsm { key := []byte{0x01, 0x23, 0x45, 0x67, 
0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10} - c := sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, 4, 64} + c := &sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, 4, 64} expandKey(key, c.enc, c.dec) c1 := &sm4CipherGCM{c} g := &gcmAsm{} - g.cipher = &c1.sm4CipherAsm + g.cipher = c1.sm4CipherAsm g.tagSize = 16 gcmSm4InitInst(&g.bytesProductTable, g.cipher.enc) return g diff --git a/sm4/sm4ni_gcm_asm.go b/sm4/sm4ni_gcm_asm.go new file mode 100644 index 0000000..7832378 --- /dev/null +++ b/sm4/sm4ni_gcm_asm.go @@ -0,0 +1,152 @@ +//go:build amd64 || arm64 +// +build amd64 arm64 + +package sm4 + +import ( + "crypto/cipher" + goSubtle "crypto/subtle" + + "github.com/emmansun/gmsm/internal/subtle" +) + +//go:noescape +func gcmSm4niEnc(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, rk []uint32) + +//go:noescape +func gcmSm4niDec(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, rk []uint32) + +// sm4CipherNIGCM implements crypto/cipher.gcmAble so that crypto/cipher.NewGCM +// will use the optimised implementation in this file when possible. Instances +// of this type only exist when hasGCMAsm and hasSM4 return true. +type sm4CipherNIGCM struct { + *sm4CipherNI +} + +// Assert that sm4CipherNIGCM implements the gcmAble interface. +var _ gcmAble = (*sm4CipherNIGCM)(nil) + +type gcmNI struct { + cipher *sm4CipherNI + nonceSize int + tagSize int + bytesProductTable [256]byte +} + +func (g *gcmNI) NonceSize() int { + return g.nonceSize +} + +func (g *gcmNI) Overhead() int { + return g.tagSize +} + +// NewGCM returns the SM4 cipher wrapped in Galois Counter Mode. This is only +// called by crypto/cipher.NewGCM via the gcmAble interface. 
+func (c *sm4CipherNIGCM) NewGCM(nonceSize, tagSize int) (cipher.AEAD, error) { + g := &gcmNI{} + g.cipher = c.sm4CipherNI + g.nonceSize = nonceSize + g.tagSize = tagSize + gcmSm4Init(&g.bytesProductTable, g.cipher.enc, INST_SM4) + return g, nil +} + +// Seal encrypts and authenticates plaintext. See the cipher.AEAD interface for +// details. +func (g *gcmNI) Seal(dst, nonce, plaintext, data []byte) []byte { + if len(nonce) != g.nonceSize { + panic("cipher: incorrect nonce length given to GCM") + } + if uint64(len(plaintext)) > ((1<<32)-2)*BlockSize { + panic("cipher: message too large for GCM") + } + + var counter, tagMask [gcmBlockSize]byte + + if len(nonce) == gcmStandardNonceSize { + // Init counter to nonce||1 + copy(counter[:], nonce) + counter[gcmBlockSize-1] = 1 + } else { + // Otherwise counter = GHASH(nonce) + gcmSm4Data(&g.bytesProductTable, nonce, &counter) + gcmSm4Finish(&g.bytesProductTable, &tagMask, &counter, uint64(len(nonce)), uint64(0)) + } + + g.cipher.Encrypt(tagMask[:], counter[:]) + + var tagOut [gcmTagSize]byte + gcmSm4Data(&g.bytesProductTable, data, &tagOut) + + ret, out := subtle.SliceForAppend(dst, len(plaintext)+g.tagSize) + if subtle.InexactOverlap(out[:len(plaintext)], plaintext) { + panic("cipher: invalid buffer overlap") + } + + if len(plaintext) > 0 { + gcmSm4niEnc(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc) + } + gcmSm4Finish(&g.bytesProductTable, &tagMask, &tagOut, uint64(len(plaintext)), uint64(len(data))) + copy(out[len(plaintext):], tagOut[:]) + + return ret +} + +// Open authenticates and decrypts ciphertext. See the cipher.AEAD interface +// for details. +func (g *gcmNI) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { + if len(nonce) != g.nonceSize { + panic("cipher: incorrect nonce length given to GCM") + } + // Sanity check to prevent the authentication from always succeeding if an implementation + // leaves tagSize uninitialized, for example. 
+ if g.tagSize < gcmMinimumTagSize { + panic("cipher: incorrect GCM tag size") + } + + if len(ciphertext) < g.tagSize { + return nil, errOpen + } + if uint64(len(ciphertext)) > ((1<<32)-2)*uint64(BlockSize)+uint64(g.tagSize) { + return nil, errOpen + } + + tag := ciphertext[len(ciphertext)-g.tagSize:] + ciphertext = ciphertext[:len(ciphertext)-g.tagSize] + + // See GCM spec, section 7.1. + var counter, tagMask [gcmBlockSize]byte + + if len(nonce) == gcmStandardNonceSize { + // Init counter to nonce||1 + copy(counter[:], nonce) + counter[gcmBlockSize-1] = 1 + } else { + // Otherwise counter = GHASH(nonce) + gcmSm4Data(&g.bytesProductTable, nonce, &counter) + gcmSm4Finish(&g.bytesProductTable, &tagMask, &counter, uint64(len(nonce)), uint64(0)) + } + + g.cipher.Encrypt(tagMask[:], counter[:]) + + var expectedTag [gcmTagSize]byte + gcmSm4Data(&g.bytesProductTable, data, &expectedTag) + + ret, out := subtle.SliceForAppend(dst, len(ciphertext)) + if subtle.InexactOverlap(out, ciphertext) { + panic("cipher: invalid buffer overlap") + } + if len(ciphertext) > 0 { + gcmSm4niDec(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc) + } + gcmSm4Finish(&g.bytesProductTable, &tagMask, &expectedTag, uint64(len(ciphertext)), uint64(len(data))) + + if goSubtle.ConstantTimeCompare(expectedTag[:g.tagSize], tag) != 1 { + for i := range out { + out[i] = 0 + } + return nil, errOpen + } + return ret, nil +}