diff --git a/internal/zuc/eea.go b/internal/zuc/eea.go index a0821e7..1715288 100644 --- a/internal/zuc/eea.go +++ b/internal/zuc/eea.go @@ -2,6 +2,7 @@ package zuc import ( "crypto/subtle" + "errors" "github.com/emmansun/gmsm/internal/alias" "github.com/emmansun/gmsm/internal/byteorder" @@ -26,6 +27,20 @@ type eea struct { bucketSize int // size of the state bucket, 0 means no bucket } +const ( + magic = "zuceea" + stateSize = (16 + 6) * 4 // zucState32 size in bytes + minMarshaledSize = len(magic) + stateSize + 8 + 4*3 +) + +// NewEmptyCipher creates and returns a new empty ZUC-EEA cipher instance. +// This function initializes an empty eea struct that can be used for +// unmarshaling a previously saved state using the UnmarshalBinary method. +// The returned cipher instance is not ready for encryption or decryption. +func NewEmptyCipher() *eea { + return new(eea) +} + // NewCipher creates a stream cipher based on key and iv aguments. // The key must be 16 bytes long and iv must be 16 bytes long for zuc 128; // or the key must be 32 bytes long and iv must be 23 bytes long for zuc 256; @@ -57,6 +72,114 @@ func NewCipherWithBucketSize(key, iv []byte, bucketSize int) (*eea, error) { return c, nil } +func appendState(b []byte, e *zucState32) []byte { + for i := range 16 { + b = byteorder.BEAppendUint32(b, e.lfsr[i]) + } + b = byteorder.BEAppendUint32(b, e.r1) + b = byteorder.BEAppendUint32(b, e.r2) + b = byteorder.BEAppendUint32(b, e.x0) + b = byteorder.BEAppendUint32(b, e.x1) + b = byteorder.BEAppendUint32(b, e.x2) + b = byteorder.BEAppendUint32(b, e.x3) + + return b +} + +func (e *eea) MarshalBinary() ([]byte, error) { + return e.AppendBinary(make([]byte, 0, minMarshaledSize)) +} + +func (e *eea) AppendBinary(b []byte) ([]byte, error) { + b = append(b, magic...) + b = appendState(b, &e.zucState32) + b = byteorder.BEAppendUint32(b, uint32(e.xLen)) + b = byteorder.BEAppendUint64(b, e.used) + b = byteorder.BEAppendUint32(b, uint32(e.stateIndex)) + b = byteorder.BEAppendUint32(b, uint32(e.bucketSize)) + if e.xLen > 0 { + b = append(b, e.x[:e.xLen]...) + } + for _, state := range e.states { + b = appendState(b, state) + } + return b, nil +} + +func unmarshalState(b []byte, e *zucState32) []byte { + for i := range 16 { + b, e.lfsr[i] = consumeUint32(b) + } + b, e.r1 = consumeUint32(b) + b, e.r2 = consumeUint32(b) + b, e.x0 = consumeUint32(b) + b, e.x1 = consumeUint32(b) + b, e.x2 = consumeUint32(b) + b, e.x3 = consumeUint32(b) + return b +} + +func UnmarshalCipher(b []byte) (*eea, error) { + var e eea + if err := e.UnmarshalBinary(b); err != nil { + return nil, err + } + return &e, nil +} + +func (e *eea) UnmarshalBinary(b []byte) error { + if len(b) < len(magic) || (string(b[:len(magic)]) != magic) { + return errors.New("zuc: invalid eea state identifier") + } + if len(b) < minMarshaledSize { + return errors.New("zuc: invalid eea state size") + } + b = b[len(magic):] + b = unmarshalState(b, &e.zucState32) + var tmpUint32 uint32 + b, tmpUint32 = consumeUint32(b) + e.xLen = int(tmpUint32) + b, e.used = consumeUint64(b) + b, tmpUint32 = consumeUint32(b) + e.stateIndex = int(tmpUint32) + b, tmpUint32 = consumeUint32(b) + e.bucketSize = int(tmpUint32) + if e.xLen < 0 || e.xLen > RoundBytes { + return errors.New("zuc: invalid eea remaining bytes length") + } + if e.xLen > 0 { + if len(b) < e.xLen { + return errors.New("zuc: invalid eea remaining bytes") + } + copy(e.x[:e.xLen], b[:e.xLen]) + b = b[e.xLen:] + } + statesCount := len(b) / stateSize + if len(b)%stateSize != 0 { + return errors.New("zuc: invalid eea states size") + } + + for range statesCount { + var state zucState32 + b = unmarshalState(b, &state) + e.states = append(e.states, &state) + } + + if e.stateIndex >= len(e.states) { + return errors.New("zuc: invalid eea state index") + } + + return nil +} + +func consumeUint64(b []byte) ([]byte, uint64) { + return b[8:], byteorder.BEUint64(b) +} + +func consumeUint32(b []byte) ([]byte, uint32) { + return b[4:], byteorder.BEUint32(b) +} + // reference GB/T 33133.2-2021 A.2 func construcIV4EEA(count, bearer, direction uint32) []byte { iv := make([]byte, 16) diff --git a/internal/zuc/eea_test.go b/internal/zuc/eea_test.go index f95ac3a..1921c0b 100644 --- a/internal/zuc/eea_test.go +++ b/internal/zuc/eea_test.go @@ -3,6 +3,7 @@ package zuc import ( "bytes" "crypto/cipher" + "encoding" "encoding/hex" "testing" @@ -113,9 +114,11 @@ func TestXORStreamAt(t *testing.T) { t.Errorf("expected=%x, result=%x\n", expected[32:64], dst[32:64]) } } + data, _ := c.MarshalBinary() + c2, _ := UnmarshalCipher(data) for i := 1; i < 4; i++ { c.XORKeyStreamAt(dst[:i], src[:i], 0) - c.XORKeyStreamAt(dst[32:64], src[32:64], 32) + c2.XORKeyStreamAt(dst[32:64], src[32:64], 32) if !bytes.Equal(dst[32:64], expected[32:64]) { t.Errorf("expected=%x, result=%x\n", expected[32:64], dst[32:64]) } @@ -128,8 +131,10 @@ func TestXORStreamAt(t *testing.T) { if !bytes.Equal(dst[3:16], expected[3:16]) { t.Errorf("expected=%x, result=%x\n", expected[3:16], dst[3:16]) } + data, _ := c.MarshalBinary() + c2, _ := UnmarshalCipher(data) c.XORKeyStreamAt(dst[:1], src[:1], 0) - c.XORKeyStreamAt(dst[4:16], src[4:16], 4) + c2.XORKeyStreamAt(dst[4:16], src[4:16], 4) if !bytes.Equal(dst[4:16], expected[4:16]) { t.Errorf("expected=%x, result=%x\n", expected[3:16], dst[3:16]) } @@ -215,7 +220,7 @@ func TestEEAXORKeyStreamAtWithBucketSize(t *testing.T) { src := make([]byte, 10000) expected := make([]byte, 10000) dst := make([]byte, 10000) - stateCount := 1 + (10000 + RoundBytes -1) / RoundBytes + stateCount := 1 + (10000+RoundBytes-1)/RoundBytes noBucketCipher.XORKeyStream(expected, src) t.Run("Make sure the cached states are used once backward", func(t *testing.T) { @@ -293,6 +298,111 @@ func TestEEAXORKeyStreamAtWithBucketSize(t *testing.T) { }) } +func TestMarshalUnmarshalBinary(t *testing.T) { + key := bytes.Repeat([]byte{0x11}, 16) + iv := bytes.Repeat([]byte{0x22}, 16) + c, err := NewCipher(key, iv) + if err != nil { + t.Fatalf("NewCipher failed: %v", err) + } + + // Marshal and Unmarshal should round-trip + data, err := c.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary failed: %v", err) + } + + var c2 encoding.BinaryMarshaler + if c2, err = UnmarshalCipher(data); err != nil { + t.Fatalf("UnmarshalBinary failed: %v", err) + } + + // Marshal again and compare + data2, err := c2.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary (after unmarshal) failed: %v", err) + } + if !bytes.Equal(data, data2) { + t.Errorf("MarshalBinary output mismatch after round-trip") + } +} + +func TestUnmarshalBinary_InvalidMagic(t *testing.T) { + key := bytes.Repeat([]byte{0x11}, 16) + iv := bytes.Repeat([]byte{0x22}, 16) + c, _ := NewCipher(key, iv) + data, _ := c.MarshalBinary() + data[0] ^= 0xFF // corrupt magic + + _, err := UnmarshalCipher(data) + if err == nil || err.Error() != "zuc: invalid eea state identifier" { + t.Errorf("expected invalid eea state identifier error, got %v", err) + } +} + +func TestUnmarshalBinary_ShortData(t *testing.T) { + _, err := UnmarshalCipher([]byte("zuceea")) + if err == nil || err.Error() != "zuc: invalid eea state size" { + t.Errorf("expected invalid eea state size error, got %v", err) + } +} + +func TestUnmarshalBinary_InvalidXLen(t *testing.T) { + key := bytes.Repeat([]byte{0x11}, 16) + iv := bytes.Repeat([]byte{0x22}, 16) + c, _ := NewCipher(key, iv) + data, _ := c.MarshalBinary() + // corrupt xLen to an invalid value (e.g. 9999) + xLenOffset := len(magic) + stateSize + copy(data[xLenOffset:], bytes.Repeat([]byte{0xFF}, 4)) + _, err := UnmarshalCipher(data) + if err == nil || err.Error() != "zuc: invalid eea remaining bytes length" { + t.Errorf("expected invalid eea remaining bytes length error, got %v", err) + } +} + +func TestUnmarshalBinary_InvalidStatesSize(t *testing.T) { + key := bytes.Repeat([]byte{0x11}, 16) + iv := bytes.Repeat([]byte{0x22}, 16) + c, _ := NewCipher(key, iv) + data, _ := c.MarshalBinary() + // Truncate data to make states size not a multiple of stateSize + data = append(data, 0x00) + _, err := UnmarshalCipher(data) + if err == nil || err.Error() != "zuc: invalid eea states size" { + t.Errorf("expected invalid eea states size error, got %v", err) + } +} + +func TestUnmarshalBinary_InvalidRemainingBytes(t *testing.T) { + key := bytes.Repeat([]byte{0x11}, 16) + iv := bytes.Repeat([]byte{0x22}, 16) + c, err := NewCipher(key, iv) + if err != nil { + t.Fatalf("NewCipher failed: %v", err) + } + data, err := c.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary failed: %v", err) + } + + // Modify xLen to a valid value > 0 + xLenOffset := len(magic) + stateSize + data[xLenOffset+0] = 0 + data[xLenOffset+1] = 0 + data[xLenOffset+2] = 0 + data[xLenOffset+3] = 8 // xLen = 8 + + // Truncate data so remaining bytes < xLen + truncated := data[:minMarshaledSize + 4] + + c2 := NewEmptyCipher() + err = c2.UnmarshalBinary(truncated) + if err == nil || err.Error() != "zuc: invalid eea remaining bytes" { + t.Errorf("expected error 'zuc: invalid eea remaining bytes', got %v", err) + } +} + func benchmarkStream(b *testing.B, buf []byte) { b.SetBytes(int64(len(buf)))