internal/zuc: eea supports encoding.BinaryMarshaler & encoding.BinaryUnmarshaler interfaces

This commit is contained in:
Sun Yimin 2025-09-25 11:01:49 +08:00 committed by GitHub
parent a23ee40008
commit 5591ee6602
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 236 additions and 3 deletions

View File

@ -2,6 +2,7 @@ package zuc
import (
"crypto/subtle"
"errors"
"github.com/emmansun/gmsm/internal/alias"
"github.com/emmansun/gmsm/internal/byteorder"
@ -26,6 +27,20 @@ type eea struct {
bucketSize int // size of the state bucket, 0 means no bucket
}
const (
magic = "zuceea"
stateSize = (16 + 6) * 4 // zucState32 size in bytes
minMarshaledSize = len(magic) + stateSize + 8 + 4*3
)
// NewEmptyCipher creates and returns a new empty ZUC-EEA cipher instance.
// This function initializes an empty eea struct that can be used for
// unmarshaling a previously saved state using the UnmarshalBinary method.
// The returned cipher instance is not ready for encryption or decryption.
func NewEmptyCipher() *eea {
return new(eea)
}
// NewCipher creates a stream cipher based on key and iv aguments.
// The key must be 16 bytes long and iv must be 16 bytes long for zuc 128;
// or the key must be 32 bytes long and iv must be 23 bytes long for zuc 256;
@ -57,6 +72,114 @@ func NewCipherWithBucketSize(key, iv []byte, bucketSize int) (*eea, error) {
return c, nil
}
func appendState(b []byte, e *zucState32) []byte {
for i := range 16 {
b = byteorder.BEAppendUint32(b, e.lfsr[i])
}
b = byteorder.BEAppendUint32(b, e.r1)
b = byteorder.BEAppendUint32(b, e.r2)
b = byteorder.BEAppendUint32(b, e.x0)
b = byteorder.BEAppendUint32(b, e.x1)
b = byteorder.BEAppendUint32(b, e.x2)
b = byteorder.BEAppendUint32(b, e.x3)
return b
}
func (e *eea) MarshalBinary() ([]byte, error) {
return e.AppendBinary(make([]byte, 0, minMarshaledSize))
}
func (e *eea) AppendBinary(b []byte) ([]byte, error) {
b = append(b, magic...)
b = appendState(b, &e.zucState32)
b = byteorder.BEAppendUint32(b, uint32(e.xLen))
b = byteorder.BEAppendUint64(b, e.used)
b = byteorder.BEAppendUint32(b, uint32(e.stateIndex))
b = byteorder.BEAppendUint32(b, uint32(e.bucketSize))
if e.xLen > 0 {
b = append(b, e.x[:e.xLen]...)
}
for _, state := range e.states {
b = appendState(b, state)
}
return b, nil
}
func unmarshalState(b []byte, e *zucState32) []byte {
for i := range 16 {
b, e.lfsr[i] = consumeUint32(b)
}
b, e.r1 = consumeUint32(b)
b, e.r2 = consumeUint32(b)
b, e.x0 = consumeUint32(b)
b, e.x1 = consumeUint32(b)
b, e.x2 = consumeUint32(b)
b, e.x3 = consumeUint32(b)
return b
}
func UnmarshalCipher(b []byte) (*eea, error) {
var e eea
if err := e.UnmarshalBinary(b); err != nil {
return nil, err
}
return &e, nil
}
func (e *eea) UnmarshalBinary(b []byte) error {
if len(b) < len(magic) || (string(b[:len(magic)]) != magic) {
return errors.New("zuc: invalid eea state identifier")
}
if len(b) < minMarshaledSize {
return errors.New("zuc: invalid eea state size")
}
b = b[len(magic):]
b = unmarshalState(b, &e.zucState32)
var tmpUint32 uint32
b, tmpUint32 = consumeUint32(b)
e.xLen = int(tmpUint32)
b, e.used = consumeUint64(b)
b, tmpUint32 = consumeUint32(b)
e.stateIndex = int(tmpUint32)
b, tmpUint32 = consumeUint32(b)
e.bucketSize = int(tmpUint32)
if e.xLen < 0 || e.xLen > RoundBytes {
return errors.New("zuc: invalid eea remaining bytes length")
}
if e.xLen > 0 {
if len(b) < e.xLen {
return errors.New("zuc: invalid eea remaining bytes")
}
copy(e.x[:e.xLen], b[:e.xLen])
b = b[e.xLen:]
}
statesCount := len(b) / stateSize
if len(b)%stateSize != 0 {
return errors.New("zuc: invalid eea states size")
}
for range statesCount {
var state zucState32
b = unmarshalState(b, &state)
e.states = append(e.states, &state)
}
if e.stateIndex >= len(e.states) {
return errors.New("zuc: invalid eea state index")
}
return nil
}
func consumeUint64(b []byte) ([]byte, uint64) {
return b[8:], byteorder.BEUint64(b)
}
func consumeUint32(b []byte) ([]byte, uint32) {
return b[4:], byteorder.BEUint32(b)
}
// reference GB/T 33133.2-2021 A.2
func construcIV4EEA(count, bearer, direction uint32) []byte {
iv := make([]byte, 16)

View File

@ -3,6 +3,7 @@ package zuc
import (
"bytes"
"crypto/cipher"
"encoding"
"encoding/hex"
"testing"
@ -113,9 +114,11 @@ func TestXORStreamAt(t *testing.T) {
t.Errorf("expected=%x, result=%x\n", expected[32:64], dst[32:64])
}
}
data, _ := c.MarshalBinary()
c2, _ := UnmarshalCipher(data)
for i := 1; i < 4; i++ {
c.XORKeyStreamAt(dst[:i], src[:i], 0)
c.XORKeyStreamAt(dst[32:64], src[32:64], 32)
c2.XORKeyStreamAt(dst[32:64], src[32:64], 32)
if !bytes.Equal(dst[32:64], expected[32:64]) {
t.Errorf("expected=%x, result=%x\n", expected[32:64], dst[32:64])
}
@ -128,8 +131,10 @@ func TestXORStreamAt(t *testing.T) {
if !bytes.Equal(dst[3:16], expected[3:16]) {
t.Errorf("expected=%x, result=%x\n", expected[3:16], dst[3:16])
}
data, _ := c.MarshalBinary()
c2, _ := UnmarshalCipher(data)
c.XORKeyStreamAt(dst[:1], src[:1], 0)
c.XORKeyStreamAt(dst[4:16], src[4:16], 4)
c2.XORKeyStreamAt(dst[4:16], src[4:16], 4)
if !bytes.Equal(dst[4:16], expected[4:16]) {
t.Errorf("expected=%x, result=%x\n", expected[3:16], dst[3:16])
}
@ -215,7 +220,7 @@ func TestEEAXORKeyStreamAtWithBucketSize(t *testing.T) {
src := make([]byte, 10000)
expected := make([]byte, 10000)
dst := make([]byte, 10000)
stateCount := 1 + (10000 + RoundBytes -1) / RoundBytes
stateCount := 1 + (10000+RoundBytes-1)/RoundBytes
noBucketCipher.XORKeyStream(expected, src)
t.Run("Make sure the cached states are used once backward", func(t *testing.T) {
@ -293,6 +298,111 @@ func TestEEAXORKeyStreamAtWithBucketSize(t *testing.T) {
})
}
func TestMarshalUnmarshalBinary(t *testing.T) {
key := bytes.Repeat([]byte{0x11}, 16)
iv := bytes.Repeat([]byte{0x22}, 16)
c, err := NewCipher(key, iv)
if err != nil {
t.Fatalf("NewCipher failed: %v", err)
}
// Marshal and Unmarshal should round-trip
data, err := c.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary failed: %v", err)
}
var c2 encoding.BinaryMarshaler
if c2, err = UnmarshalCipher(data); err != nil {
t.Fatalf("UnmarshalBinary failed: %v", err)
}
// Marshal again and compare
data2, err := c2.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary (after unmarshal) failed: %v", err)
}
if !bytes.Equal(data, data2) {
t.Errorf("MarshalBinary output mismatch after round-trip")
}
}
func TestUnmarshalBinary_InvalidMagic(t *testing.T) {
key := bytes.Repeat([]byte{0x11}, 16)
iv := bytes.Repeat([]byte{0x22}, 16)
c, _ := NewCipher(key, iv)
data, _ := c.MarshalBinary()
data[0] ^= 0xFF // corrupt magic
_, err := UnmarshalCipher(data)
if err == nil || err.Error() != "zuc: invalid eea state identifier" {
t.Errorf("expected invalid eea state identifier error, got %v", err)
}
}
func TestUnmarshalBinary_ShortData(t *testing.T) {
_, err := UnmarshalCipher([]byte("zuceea"))
if err == nil || err.Error() != "zuc: invalid eea state size" {
t.Errorf("expected invalid eea state size error, got %v", err)
}
}
func TestUnmarshalBinary_InvalidXLen(t *testing.T) {
key := bytes.Repeat([]byte{0x11}, 16)
iv := bytes.Repeat([]byte{0x22}, 16)
c, _ := NewCipher(key, iv)
data, _ := c.MarshalBinary()
// corrupt xLen to an invalid value (e.g. 9999)
xLenOffset := len(magic) + stateSize
copy(data[xLenOffset:], bytes.Repeat([]byte{0xFF}, 4))
_, err := UnmarshalCipher(data)
if err == nil || err.Error() != "zuc: invalid eea remaining bytes length" {
t.Errorf("expected invalid eea remaining bytes length error, got %v", err)
}
}
func TestUnmarshalBinary_InvalidStatesSize(t *testing.T) {
key := bytes.Repeat([]byte{0x11}, 16)
iv := bytes.Repeat([]byte{0x22}, 16)
c, _ := NewCipher(key, iv)
data, _ := c.MarshalBinary()
// Truncate data to make states size not a multiple of stateSize
data = append(data, 0x00)
_, err := UnmarshalCipher(data)
if err == nil || err.Error() != "zuc: invalid eea states size" {
t.Errorf("expected invalid eea states size error, got %v", err)
}
}
func TestUnmarshalBinary_InvalidRemainingBytes(t *testing.T) {
key := bytes.Repeat([]byte{0x11}, 16)
iv := bytes.Repeat([]byte{0x22}, 16)
c, err := NewCipher(key, iv)
if err != nil {
t.Fatalf("NewCipher failed: %v", err)
}
data, err := c.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary failed: %v", err)
}
// Modify xLen to a valid value > 0
xLenOffset := len(magic) + stateSize
data[xLenOffset+0] = 0
data[xLenOffset+1] = 0
data[xLenOffset+2] = 0
data[xLenOffset+3] = 8 // xLen = 8
// Truncate data so remaining bytes < xLen
truncated := data[:minMarshaledSize + 4]
c2 := NewEmptyCipher()
err = c2.UnmarshalBinary(truncated)
if err == nil || err.Error() != "zuc: invalid eea remaining bytes" {
t.Errorf("expected error 'zuc: invalid eea remaining bytes', got %v", err)
}
}
func benchmarkStream(b *testing.B, buf []byte) {
b.SetBytes(int64(len(buf)))