SM2 WWMM
Sun Yimin edited this page 2024-02-23 10:18:37 +08:00

MFMM = Montgomery-Friendly Modulus Montgomery Multiplication

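First, the primes p of both NIST P-256 and SM2 (256-bit) are Montgomery-friendly moduli: the lowest 64-bit word of p is p0 = 2^64 - 1, so p ≡ -1 (mod 2^64) and -p^(-1) mod 2^64 = 1. Consequently, in each word-level reduction step the multiplier m = T0 * (-p^(-1)) mod 2^64 equals T0, the lowest word of the current intermediate value, so computing it needs no extra multiplication.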
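The Montgomery-friendly property can be checked numerically. The following is a small Go sketch using `math/big` (the helper `montConst` is a hypothetical name for illustration only, not part of gmsm); it computes p0 and k0 = -p^(-1) mod 2^64 for both primes. Because k0 = 1, the reduction multiplier is the lowest word itself, which is why the steps below use acc0 directly.

```go
package main

import (
	"fmt"
	"math/big"
)

// Field primes as big-endian hex strings.
const (
	p256Hex = "ffffffff00000001" + "0000000000000000" + "00000000ffffffff" + "ffffffffffffffff"
	sm2Hex  = "fffffffeffffffff" + "ffffffffffffffff" + "ffffffff00000000" + "ffffffffffffffff"
)

// montConst returns the lowest 64-bit word p0 of p and the per-word
// Montgomery constant k0 = -p^(-1) mod 2^64 (which depends only on p0).
func montConst(hex string) (p0, k0 uint64) {
	p, _ := new(big.Int).SetString(hex, 16)
	r := new(big.Int).Lsh(big.NewInt(1), 64) // 2^64
	low := new(big.Int).Mod(p, r)            // p0
	inv := new(big.Int).ModInverse(low, r)   // p0^(-1) mod 2^64 (p0 is odd)
	neg := new(big.Int).Sub(r, inv)          // -p^(-1) mod 2^64
	neg.Mod(neg, r)
	return low.Uint64(), neg.Uint64()
}

func main() {
	for _, h := range []string{p256Hex, sm2Hex} {
		p0, k0 := montConst(h)
		fmt.Printf("p0 = %#x, k0 = %d\n", p0, k0) // p0 = 0xffffffffffffffff, k0 = 1
	}
}
```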

Input:
X and Y are both in Montgomery representation.
X and Y are each represented as four 64-bit words:
X = X3 * 2^192 + X2 * 2^128 + X1 * 2^64 + X0
Y = Y3 * 2^192 + Y2 * 2^128 + Y1 * 2^64 + Y0
0 <= X, Y < p

Output:
X * Y * 2^(-256) mod p

acc0, acc1, acc2, acc3, acc4, acc5 are 64-bit registers.

Step 1: compute X * Y0.

The result is tmp = acc4 * 2^256 + acc3 * 2^192 + acc2 * 2^128 + acc1 * 2^64 + acc0.
The products of X with the higher words of Y (Y1, Y2, Y3) are all multiples of 2^64, so (X * Y) mod 2^64 = acc0.
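Step 1 is an ordinary 4-by-1-word schoolbook multiplication. A minimal Go sketch using `math/bits` (the name `mulWord` is hypothetical); words are little-endian, so the returned acc[0] is acc0 = (X * Y) mod 2^64:

```go
package main

import (
	"fmt"
	"math/bits"
)

// mulWord computes x * y0 for a 4-word little-endian x and returns the
// 5-word product [acc0, acc1, acc2, acc3, acc4].
func mulWord(x [4]uint64, y0 uint64) [5]uint64 {
	var acc [5]uint64
	var carry uint64
	for i := 0; i < 4; i++ {
		hi, lo := bits.Mul64(x[i], y0)
		var c uint64
		acc[i], c = bits.Add64(lo, carry, 0)
		carry = hi + c // cannot overflow: hi <= 2^64 - 2
	}
	acc[4] = carry
	return acc
}

func main() {
	x := [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)} // 2^256 - 1
	fmt.Println(mulWord(x, 2)) // [18446744073709551614 18446744073709551615 18446744073709551615 18446744073709551615 1]
}
```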

Step 2 (first reduction step): compute (tmp + acc0 * p) / 2^64.

Here p = p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + p0; for both NIST P-256 and SM2, p0 = 2^64 - 1.
Expanding (tmp + acc0 * p) / 2^64:
= (acc4 * 2^256 + acc3 * 2^192 + acc2 * 2^128 + acc1 * 2 ^ 64 + acc0 + acc0 * ( p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + 2^64 - 1)) / 2^64
= acc4 * 2^192 + acc3 * 2^128 + acc2 * 2^64 + acc1 + acc0 * p3 * 2^128 + acc0*p2*2^64+acc0*p1+acc0
= acc4 * 2^192 + (acc3 + acc0 * p3) * 2^128 + (acc2+acc0*p2) * 2^64 + acc0*p1 + acc1 + acc0

(carry1, acc1) = acc0 + acc1 + acc0 * p1
(carry2, acc2) = carry1 + acc2 + acc0 * p2
(carry3, acc3) = carry2 + acc3 + acc0 * p3
(carry4, acc4) = carry3 + acc4
acc5 = carry4

After carry propagation, the result is tmp = acc5 * 2^256 + acc4 * 2^192 + acc3 * 2^128 + acc2 * 2^64 + acc1.

Let H = high64(acc0 * 2^32) (the part beyond 64 bits) and L = low64(acc0 * 2^32).
NIST P-256:
p = 0xffffffff00000001 0000000000000000 00000000ffffffff ffffffffffffffff
   = p3 * 2^192 + p1 * 2^64 + 2^64 - 1
p * acc0 = acc0 * p3 * 2^192 + acc0 * p1 * 2^64 + acc0 * 2^64 - acc0
(tmp + acc0 * p) / 2^64 = acc0 * p3 * 2^128 + acc0 * p1 + acc0 + acc4 * 2^192 + acc3 * 2^128 + acc2 * 2^64 + acc1
    =acc4 * 2^192 + (acc0 * p3 + acc3) * 2^128 + acc2 * 2^64 + acc0 + acc1 + acc0* (2^32 - 1)
    =acc4 * 2^192 + (acc0 * p3 + acc3) * 2^128 + acc2 * 2^64 + acc1 + acc0* 2^32 
    =acc4 * 2^192 + (acc0 * p3 + acc3) * 2^128 + (acc2 + H(acc0* 2^32))* 2^64 + acc1 + L(acc0* 2^32) 


amd64 assembly:
MOVQ acc0, AX
MOVQ acc0, t1
SHLQ $32, acc0         // L(acc0 * 2^32) 
MULQ p256const1<>(SB)  // (DX, AX) = acc0 * p3, DX holds the high 64 bits
SHRQ $32, t1           // t1 = H(acc0 * 2^32)
ADDQ acc0, acc1        // (carry1, acc1) = acc1 + L(acc0 * 2^32)
ADCQ t1, acc2          // (carry2, acc2) = carry1 + acc2 + H(acc0 * 2^32)
ADCQ AX, acc3          // (carry3, acc3) = carry2 + acc3 + L(acc0 * p3)
ADCQ DX, acc4          // (carry4, acc4) = carry3 + acc4 + H(acc0 * p3)
ADCQ $0, acc5          // acc5 = carry4
XORQ acc0, acc0        // acc0 = 0
The result is held in five 64-bit registers (acc5, acc4, acc3, acc2, acc1).

arm64 assembly:
ADDS	acc0<<32, acc1, acc1  // (carry1, acc1) = acc1 + L(acc0 * 2^32)
LSR	$32, acc0, t0         // t0 = H(acc0 * 2^32)
MUL	acc0, const1, t1      // t1 = L(acc0 * p3)
UMULH	acc0, const1, acc0    // acc0 = H(acc0 * p3)
ADCS	t0, acc2              // (carry2, acc2) = carry1 + acc2 + H(acc0 * 2^32)
ADCS	t1, acc3              // (carry3, acc3) = carry2 + acc3 + L(acc0 * p3)
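ADC	$0, acc0              // acc0 = carry3 + H(acc0 * p3)
Unlike amd64, which holds the first reduction result in (acc5, acc4, acc3, acc2, acc1), the arm64 implementation represents it as ((acc0, acc4), acc3, acc2, acc1): the value is (acc4, acc3, acc2, acc1) + (acc0, 0, 0, 0), i.e. acc0 still has to be added at the acc4 position. The result likewise occupies five 64-bit registers.
Note that ZR in arm64 denotes the zero register; ADC	$0, ZR, acc5 means acc5 = carry + 0 + 0.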

SM2 curve:
p = 0x fffffffeffffffff ffffffffffffffff ffffffff00000000 ffffffffffffffff
   =  (2^64 - 2^32  - 1) * 2^192 + (2^64 - 1) * 2^128 + (2^64 - 2^32) * 2^64 + (2^64 - 1)
   =  (2^64 - 2^32  - 1) * 2^192 + (2^64 - 1) * 2^128 + (2^64 - 2^32 + 1) * 2^64  - 1
   =  (2^64 - 2^32  - 1) * 2^192 + (2^64 ) * 2^128 + ( - 2^32 + 1) * 2^64  - 1
   =  (2^64 - 2^32 ) * 2^192 +  ( - 2^32 + 1) * 2^64  - 1
   =  2^256 + (-2^32) * 2^192 + (1-2^32)*2^64 - 1

p = p3 * 2^192 + p2*2^128 + p1 * 2^64 + 2^64 - 1
(tmp + acc0 * p) / 2^64 = acc4 * 2^192 + (acc3 + acc0*p3) * 2^128 + (acc2 + acc0*p2) * 2^64 + acc1 + acc0*p1 + acc0

amd64 assembly:
MOVQ p256p<>+0x08(SB), AX   // AX = p1
MULQ acc0                   // (DX, AX) = acc0 * p1
ADDQ acc0, acc1             // (carry1, acc1) = acc0 + acc1
ADCQ $0, DX                 // DX = carry1 + H(acc0 * p1)
ADDQ AX, acc1               // (carry2, acc1) = acc0 + acc1 + L(acc0*p1)
ADCQ $0, DX                 // DX = DX + carry2
MOVQ DX, t1                 // t1 = H(acc0 * p1) + carry1 + carry2
MOVQ p256p<>+0x010(SB), AX  // AX = p2
MULQ acc0                   // (DX, AX) = acc0 * p2
ADDQ t1, acc2               // (carry3, acc2) = t1 + acc2
ADCQ $0, DX                 // DX = carry3 + H(acc0 * p2)
ADDQ AX, acc2               // (carry4, acc2) = L(acc0 * p2) + L(t1 + acc2)
ADCQ $0, DX                 // DX = DX + carry4
MOVQ DX, t1                 // t1 = H(acc0 * p2) + carry3 + carry4
MOVQ p256p<>+0x018(SB), AX  // AX = p3
MULQ acc0                   // (DX, AX) = acc0 * p3
ADDQ t1, acc3               // (carry5, acc3) = t1 + acc3
ADCQ $0, DX                 // DX = carry5 + H(acc0 * p3)
ADDQ AX, acc3               // (carry6, acc3) = L(acc0 * p3) + L(t1 + acc3)
ADCQ DX, acc4               // (carry7, acc4) = acc4 + DX + carry6
ADCQ $0, acc5               // acc5 = carry7
XORQ acc0, acc0             // acc0 = 0

arm64 assembly:
MUL	const1, acc0, t0     // t0 = L(acc0*p1)
ADDS	t0, acc1, acc1       // (carry1, acc1) = acc1 + L(acc0*p1)
UMULH	const1, acc0, y0     // y0 = H(acc0*p1); MUL/UMULH do not change the flags

MUL	const2, acc0, t0     // t0 = L(acc0*p2)
ADCS	t0, acc2, acc2       // (carry2, acc2) = carry1 + acc2 + L(acc0*p2)
UMULH	const2, acc0, hlp0   // hlp0 = H(acc0*p2)

MUL	const3, acc0, t0     // t0 = L(acc0*p3)
ADCS	t0, acc3, acc3       // (carry3, acc3) = carry2 + acc3 + L(acc0*p3)

UMULH	const3, acc0, hlp1   // hlp1 = H(acc0*p3); in practice hlp1 cannot be used here, because p256PointAddAsm uses this register globally
ADC	$0, acc4             // acc4 = carry3 + acc4

ADDS	acc0, acc1, acc1     // (carry4, acc1) = acc0 + acc1 + L(acc0*p1)
ADCS	y0, acc2, acc2       // (carry5, acc2) = carry4 + acc2 + L(acc0*p2) + H(acc0*p1)
ADCS	hlp0, acc3, acc3     // (carry6, acc3) = carry5 + acc3 + L(acc0*p3) + H(acc0*p2)
ADC	$0, hlp1, acc0       // acc0 = carry6 + H(acc0*p3)

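I have no arm64 machine at hand, so the code could only be verified through Travis CI, which was slow and inefficient; switching to arm64-graviton made this much better.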

======
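An attempt to replace the multiplications with additions and subtractions carried latent risk: the carry/borrow handling became too complex, so that implementation was rolled back.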
p*acc0 = acc0*2^256 -(acc0*2^32)*2^192 + (acc0 - acc0*2^32)*2^64 - acc0
(tmp + acc0 * p) / 2^64 = (acc4 * 2^256 + acc3 * 2^192 + acc2 * 2^128 + acc1 * 2^64 + acc0 + acc0*2^256 - (acc0*2^32)*2^192 + (acc0 - acc0*2^32)*2^64 - acc0) / 2^64
      = (acc4+acc0)*2^192 + (acc3  - acc0*2^32) * 2^128 + acc2 * 2^64 + (acc1 + acc0 - acc0*2^32)
      = (acc4+acc0)*2^192 + (acc3  - acc0*2^32) * 2^128 + (acc2 - H(acc0*2^32)) * 2^64 + (acc1 + acc0 - L(acc0*2^32))

(carry1, acc1) = acc0 + acc1
acc2 = carry1 + acc2       // may carry out: yes, when acc2 = 0xffffffffffffffff
(carry2, acc1) = acc1 - L
acc2 = acc2 - H - carry2   // may borrow: yes, when acc0 is large enough and acc2 small enough

(carry3, acc3) = acc0 + acc3
t1 = acc0 + carry3         // can this carry out?
(carry4, acc3) = acc3 - L
t1 = t1 - H - carry4       // can this go below 0? No
(carry5, acc3) = acc3 - acc0
t1 = t1 - carry5           // can this go below 0? No

(carry6, acc4) = acc4 + t1
acc5 = carry6
======

The algorithm finally adopted (essentially one round of additions followed by one round of subtractions) is nicely symmetric:

   acc4,         acc3,         acc2,        acc1
 + acc0,         0,            0,           acc0
 - H(acc0*2^32)  L(acc0*2^32)  H(acc0*2^32) L(acc0*2^32)

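MOVQ acc0, AX
MOVQ acc0, DX
SHLQ $32, AX           // AX = L(acc0*2^32)
SHRQ $32, DX           // DX = H(acc0*2^32)

ADDQ acc0, acc1        // addition round: (acc4, acc3, acc2, acc1) += (acc0, 0, 0, acc0)
ADCQ $0, acc2
ADCQ $0, acc3
ADCQ acc0, acc4
ADCQ $0, acc5          // acc5 = final carry
SUBQ AX, acc1          // subtraction round: (acc4, acc3, acc2, acc1) -= (H, L, H, L)
SBBQ DX, acc2
SBBQ AX, acc3
SBBQ DX, acc4
SBBQ $0, acc5          // propagate the final borrow into acc5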

Step 3: compute X * Y1 and add it to tmp.

tmp = tmp + X * Y1, accumulated one 64-bit word at a time:

tmp = tmp + X0*Y1
tmp = tmp + X1*Y1 * 2^64
tmp = tmp + X2*Y1 * 2^128
tmp = tmp + X3*Y1 * 2^192

(carry1, acc1) = acc1 + X0 * Y1
(carry2, acc2) = acc2 + carry1 + X1 * Y1
(carry3, acc3) = acc3 + carry2 + X2 * Y1
(carry4, acc4) = acc4 + carry3 + X3 * Y1
(carry5, acc5) = acc5 + carry4
acc0 = carry5

After this, tmp = acc0 * 2^320 + acc5 * 2^256 + acc4 * 2^192 + acc3 * 2^128 + acc2 * 2^64 + acc1.

Step 4 (second reduction step)

Compute (tmp + acc1 * p) / 2^64, where p = p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + p0; for both NIST P-256 and SM2, p0 = 2^64 - 1.
Expanding (tmp + acc1 * p) / 2^64:
= (acc0*2^320 + acc5 * 2^256 + acc4 * 2^192 + acc3 * 2^128 + acc2 * 2 ^ 64 + acc1 + acc1 * ( p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + 2^64 - 1)) / 2^64
= acc0*2^256 + acc5 * 2^192 + (acc4 + acc1*p3)*2^128 + (acc3 + acc1*p2)*2^64 + acc1*p1+ acc2 + acc1

(carry1, acc2) = acc1 + acc2 + acc1 * p1
(carry2, acc3) = carry1 + acc3 + acc1 * p2
(carry3, acc4) = carry2 + acc4 + acc1 * p3
(carry4, acc5) = carry3 + acc5
acc0 = acc0 + carry4

After carry propagation, the result is tmp = acc0 * 2^256 + acc5 * 2^192 + acc4 * 2^128 + acc3 * 2^64 + acc2.

Step 5: compute X * Y2 and add it to tmp.

tmp = tmp + X * Y2, accumulated one 64-bit word at a time:
tmp = tmp + X0*Y2
tmp = tmp + X1*Y2 * 2^64
tmp = tmp + X2*Y2 * 2^128
tmp = tmp + X3*Y2 * 2^192

(carry1, acc2) = acc2 + X0 * Y2
(carry2, acc3) = acc3 + carry1 + X1 * Y2
(carry3, acc4) = acc4 + carry2 + X2 * Y2
(carry4, acc5) = acc5 + carry3 + X3 * Y2
(carry5, acc0) = acc0 + carry4
acc1 = carry5

After this, tmp = acc1 * 2^320 + acc0 * 2^256 + acc5 * 2^192 + acc4 * 2^128 + acc3 * 2^64 + acc2.

Step 6 (third reduction step)

Compute (tmp + acc2 * p) / 2^64, where p = p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + p0; for both NIST P-256 and SM2, p0 = 2^64 - 1.
Expanding (tmp + acc2 * p) / 2^64:
=(acc1*2^320 + acc0 * 2^256 + acc5 * 2^192 + acc4 * 2^128 + acc3 * 2 ^ 64 + acc2 + acc2 * ( p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + 2^64 - 1) ) / 2^64
=acc1*2^256 + acc0*2^192 + (acc5+acc2*p3)*2^128 + (acc4+acc2*p2)*2^64 + acc2 * p1 + acc3 + acc2

(carry1, acc3) = acc2 + acc3 + acc2 * p1
(carry2, acc4) = carry1 + acc4 + acc2 * p2
(carry3, acc5) = carry2 + acc5 + acc2 * p3
(carry4, acc0) = carry3 + acc0
acc1 = acc1 + carry4

After carry propagation, the result is tmp = acc1 * 2^256 + acc0 * 2^192 + acc5 * 2^128 + acc4 * 2^64 + acc3.

Step 7: compute X * Y3 and add it to tmp.

tmp = tmp + X * Y3, accumulated one 64-bit word at a time:
tmp = tmp + X0*Y3
tmp = tmp + X1*Y3 * 2^64
tmp = tmp + X2*Y3 * 2^128
tmp = tmp + X3*Y3 * 2^192

(carry1, acc3) = acc3 + X0 * Y3
(carry2, acc4) = acc4 + carry1 + X1 * Y3
(carry3, acc5) = acc5 + carry2 + X2 * Y3
(carry4, acc0) = acc0 + carry3 + X3 * Y3
(carry5, acc1) = acc1 + carry4
acc2 = carry5

After this, tmp = acc2 * 2^320 + acc1 * 2^256 + acc0 * 2^192 + acc5 * 2^128 + acc4 * 2^64 + acc3.

Step 8 (last reduction step)

Compute (tmp + acc3 * p) / 2^64, where p = p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + p0; for both NIST P-256 and SM2, p0 = 2^64 - 1.
Expanding (tmp + acc3 * p) / 2^64:
=(acc2*2^320 + acc1 * 2^256 + acc0 * 2^192 + acc5 * 2^128 + acc4 * 2 ^ 64 + acc3 + acc3 * ( p3 * 2^192 + p2 * 2^128 + p1 * 2^64 + 2^64 - 1) ) / 2^64
=acc2*2^256 + acc1*2^192 + (acc0+acc3*p3)*2^128 + (acc5+acc3*p2)*2^64 + acc3 * p1 + acc4 + acc3

(carry1, acc4) = acc3 + acc4 + acc3 * p1
(carry2, acc5) = carry1 + acc5 + acc3 * p2
(carry3, acc0) = carry2 + acc0 + acc3 * p3
(carry4, acc1) = carry3 + acc1
acc2 = acc2 + carry4

T = (acc2, acc1, acc0, acc5, acc4)

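Step 9: if T >= p, return T - p; otherwise return T. One conditional subtraction is enough: with X, Y < p and M < 2^256 the combined reduction multiplier, T = (X * Y + M * p) / 2^256 < (p^2 + 2^256 * p) / 2^256 < 2p.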

Benchmark results on AWS arm64 Graviton2:

go test -v -short -bench . -run=^$ ./...
goos: linux
goarch: arm64
pkg: github.com/emmansun/gmsm/sm2
BenchmarkLessThan32_P256
BenchmarkLessThan32_P256-2      	    3698	    279225 ns/op
BenchmarkLessThan32_P256SM2
BenchmarkLessThan32_P256SM2-2   	    4602	    258525 ns/op
BenchmarkMoreThan32_P256
BenchmarkMoreThan32_P256-2      	    4365	    274304 ns/op
BenchmarkMoreThan32_P256SM2
BenchmarkMoreThan32_P256SM2-2   	    4550	    263296 ns/op
PASS
ok  	github.com/emmansun/gmsm/sm2	4.753s