package sm3 import ( "bytes" "encoding" "encoding/base64" "fmt" "hash" "io" "testing" ) type sm3Test struct { out string in string halfState string // marshaled hash state after first half of in written, used by TestGoldenMarshal } var golden = []sm3Test{ {"66c7f0f462eeedd9d1f2d46bdc10e4e24167c4875cf2f7a2297da02b8f4ba8e0", "abc", "sm3\x03s\x80\x16oI\x14\xb2\xb9\x17$B\xd7ڊ\x06\x00\xa9o0\xbc\x1618\xaa\xe3\x8d\xeeM\xb0\xfb\x0eNa\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, {"debe9ff92275b8a138604889c18e5a4d6fdb70e5387e5765293dcba39c0c5732", "abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd", "sm3\x03s\x80\x16oI\x14\xb2\xb9\x17$B\xd7ڊ\x06\x00\xa9o0\xbc\x1618\xaa\xe3\x8d\xeeM\xb0\xfb\x0eNabcdabcdabcdabcdabcdabcdabcdabcd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 "}, } func TestGolden(t *testing.T) { for i := 0; i < len(golden); i++ { g := golden[i] h := Sum([]byte(g.in)) s := fmt.Sprintf("%x", h) fmt.Printf("%s\n", base64.StdEncoding.EncodeToString(h[:])) if s != g.out { t.Fatalf("SM3 function: sm3(%s) = %s want %s", g.in, s, g.out) } c := New() for j := 0; j < 3; j++ { if j < 2 { io.WriteString(c, g.in) } else { io.WriteString(c, g.in[0:len(g.in)/2]) c.Sum(nil) io.WriteString(c, g.in[len(g.in)/2:]) } s := fmt.Sprintf("%x", c.Sum(nil)) if s != g.out { t.Fatalf("sm3[%d](%s) = %s want %s", j, g.in, s, g.out) } c.Reset() } } } func TestGoldenMarshal(t *testing.T) { tests := []struct { name string newHash func() hash.Hash gold []sm3Test }{ {"", New, golden}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { for _, g := range tt.gold { h := tt.newHash() h2 := tt.newHash() io.WriteString(h, g.in[:len(g.in)/2]) state, err := h.(encoding.BinaryMarshaler).MarshalBinary() if err != nil { t.Errorf("could not marshal: %v", err) continue } if string(state) != g.halfState { t.Errorf("sm3%s(%q) state = %q, want %q", tt.name, g.in, state, g.halfState) continue } if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil { t.Errorf("could not unmarshal: %v", err) continue } io.WriteString(h, g.in[len(g.in)/2:]) io.WriteString(h2, g.in[len(g.in)/2:]) if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) { t.Errorf("sm3%s(%q) = 0x%x != marshaled 0x%x", tt.name, g.in, actual, actual2) } } }) } } func TestSize(t *testing.T) { c := New() if got := c.Size(); got != Size { t.Errorf("Size = %d; want %d", got, Size) } } func TestBlockSize(t *testing.T) { c := New() if got := c.BlockSize(); got != BlockSize { t.Errorf("BlockSize = %d want %d", got, BlockSize) } } var bench = New() var buf = make([]byte, 8192) func benchmarkSize(b *testing.B, size int) { b.SetBytes(int64(size)) sum := make([]byte, bench.Size()) for i := 0; i < b.N; i++ { bench.Reset() bench.Write(buf[:size]) bench.Sum(sum[:0]) } } func BenchmarkHash8Bytes(b *testing.B) { benchmarkSize(b, 8) } func BenchmarkHash1K(b *testing.B) { benchmarkSize(b, 1024) } func BenchmarkHash8K(b *testing.B) { benchmarkSize(b, 8192) }