diff --git a/sm3/sm3block.go b/sm3/sm3block.go index fe53519..d32e470 100644 --- a/sm3/sm3block.go +++ b/sm3/sm3block.go @@ -7,13 +7,6 @@ var _T = []uint32{ 0x7a879d8a, } -func t(j uint8) uint32 { - if j < 16 { - return _T[0] - } - return _T[1] -} - func p0(x uint32) uint32 { return x ^ bits.RotateLeft32(x, 9) ^ bits.RotateLeft32(x, 17) } @@ -22,17 +15,11 @@ func p1(x uint32) uint32 { return x ^ bits.RotateLeft32(x, 15) ^ bits.RotateLeft32(x, 23) } -func ff(j uint8, x, y, z uint32) uint32 { - if j < 16 { - return x ^ y ^ z - } +func ff(x, y, z uint32) uint32 { return (x & y) | (x & z) | (y & z) } -func gg(j uint8, x, y, z uint32) uint32 { - if j < 16 { - return x ^ y ^ z - } +func gg(x, y, z uint32) uint32 { return (x & y) | (^x & z) } @@ -40,19 +27,51 @@ func block(dig *digest, p []byte) { var w [68]uint32 h0, h1, h2, h3, h4, h5, h6, h7 := dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4], dig.h[5], dig.h[6], dig.h[7] for len(p) >= chunk { - for i := 0; i < 16; i++ { + for i := 0; i < 4; i++ { j := i * 4 w[i] = uint32(p[j])<<24 | uint32(p[j+1])<<16 | uint32(p[j+2])<<8 | uint32(p[j+3]) } - for i := 16; i < 68; i++ { - w[i] = p1(w[i-16]^w[i-9]^bits.RotateLeft32(w[i-3], 15)) ^ bits.RotateLeft32(w[i-13], 7) ^ w[i-6] - } a, b, c, d, e, f, g, h := h0, h1, h2, h3, h4, h5, h6, h7 - for i := 0; i < 64; i++ { - ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(t(uint8(i)), i), 7) + for i := 0; i < 12; i++ { + j := (i + 4) * 4 + w[i+4] = uint32(p[j])<<24 | uint32(p[j+1])<<16 | uint32(p[j+2])<<8 | uint32(p[j+3]) + ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T[0], i), 7) ss2 := ss1 ^ bits.RotateLeft32(a, 12) - tt1 := ff(uint8(i), a, b, c) + d + ss2 + (w[i] ^ w[i+4]) - tt2 := gg(uint8(i), e, f, g) + h + ss1 + w[i] + tt1 := a ^ b ^ c + d + ss2 + (w[i] ^ w[i+4]) + tt2 := e ^ f ^ g + h + ss1 + w[i] + d = c + c = bits.RotateLeft32(b, 9) + b = a + a = tt1 + h = g + g = bits.RotateLeft32(f, 19) + f = e + e = p0(tt2) + } + + for i := 12; i < 16; i++ { + w[i+4] = p1(w[i-12]^w[i-5]^bits.RotateLeft32(w[i+1], 15)) ^ bits.RotateLeft32(w[i-9], 7) ^ w[i-2] + ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T[0], i), 7) + ss2 := ss1 ^ bits.RotateLeft32(a, 12) + tt1 := a ^ b ^ c + d + ss2 + (w[i] ^ w[i+4]) + tt2 := e ^ f ^ g + h + ss1 + w[i] + d = c + c = bits.RotateLeft32(b, 9) + b = a + a = tt1 + h = g + g = bits.RotateLeft32(f, 19) + f = e + e = p0(tt2) + } + + for i := 16; i < 64; i++ { + w[i+4] = p1(w[i-12]^w[i-5]^bits.RotateLeft32(w[i+1], 15)) ^ bits.RotateLeft32(w[i-9], 7) ^ w[i-2] + ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T[1], i), 7) + ss2 := ss1 ^ bits.RotateLeft32(a, 12) + tt1 := ff(a, b, c) + d + ss2 + (w[i] ^ w[i+4]) + tt2 := gg(e, f, g) + h + ss1 + w[i] + d = c c = bits.RotateLeft32(b, 9) b = a