diff --git a/internal/sm2ec/p256_asm_loong64.s b/internal/sm2ec/p256_asm_loong64.s index d4ba580..03f5eda 100644 --- a/internal/sm2ec/p256_asm_loong64.s +++ b/internal/sm2ec/p256_asm_loong64.s @@ -335,6 +335,311 @@ TEXT ·p256NegCond(SB),NOSPLIT,$0 RET +/* ---------------------------------------*/ +// func p256Sqr(res, in *p256Element, n int) +TEXT ·p256Sqr(SB),NOSPLIT,$0 + MOVV res+0(FP), res_ptr + MOVV in+8(FP), x_ptr + MOVV n+16(FP), y_ptr + + MOVV (8*0)(x_ptr), x0 + MOVV (8*1)(x_ptr), x1 + MOVV (8*2)(x_ptr), x2 + MOVV (8*3)(x_ptr), x3 + +sqrLoop: + SUBV $1, y_ptr + CALL sm2P256SqrInternal<>(SB) + MOVV y0, x0 + MOVV y1, x1 + MOVV y2, x2 + MOVV y3, x3 + BNE y_ptr, sqrLoop + + MOVV y0, (8*0)(res_ptr) + MOVV y1, (8*1)(res_ptr) + MOVV y2, (8*2)(res_ptr) + MOVV y3, (8*3)(res_ptr) + + RET +/* ---------------------------------------*/ +// (y3, y2, y1, y0) = (x3, x2, x1, x0) ^ 2 +TEXT sm2P256SqrInternal<>(SB),NOSPLIT,$0 + // x[1:] * x[0] + MULV x0, x1, acc1 + MULHVU x0, x1, acc2 + + MULV x0, x2, t0 + // ADDS t0, acc2 + ADDV t0, acc2, acc2 + SGTU t0, acc2, t1 + MULHVU x0, x2, acc3 + + MULV x0, x3, t0 + // ADCS t0, acc3 + ADDV t1, acc3, acc3 // no carry + ADDV t0, acc3, acc3 + SGTU t0, acc3, t1 + MULHVU x0, x3, acc4 + ADDV t1, acc4, acc4 // no carry + + // x[2:] * x[1] + MULV x1, x2, t0 + // ADDS t0, acc3 + ADDV t0, acc3, acc3 + SGTU t0, acc3, t2 + MULHVU x1, x2, t1 + // ADCS t1, acc4 + ADDV t1, acc4, acc4 + SGTU t1, acc4, t3 + ADDV t2, acc4, acc4 + SGTU t2, acc4, t4 + OR t3, t4, acc5 + + MULV x1, x3, t0 + // ADCS t0, acc4 + ADDV t0, acc4, acc4 + SGTU t0, acc4, t2 + MULHVU x1, x3, t1 + // ADC t1, acc5 + ADDV t1, acc5, acc5 // no carry + + // x[3] * x[2] + MULV x2, x3, t0 + // ADDS t0, acc5 + ADDV t0, acc5, acc5 + SGTU t0, acc5, t2 + MULHVU x2, x3, t1 + // ADC t1, acc6 + ADDV t1, t2, acc6 // no carry + + // *2 + SRLV $63, acc1, t0 + SLLV $1, acc1, acc1 + SRLV $63, acc2, t1 + ALSLV $1, t0, acc2, acc2 + SRLV $63, acc3, t2 + ALSLV $1, t1, acc3, acc3 + SRLV $63, acc4, t3 + ALSLV $1, t2, acc4, acc4 + SRLV $63, acc5, t4 + ALSLV $1, t3, acc5, acc5 + SRLV $63, acc6, acc7 + ALSLV $1, t4, acc6, acc6 + + // Missing products + MULV x0, x0, acc0 + MULHVU x0, x0, t0 + // ADDS t0, acc1 + ADDV t0, acc1, acc1 + SGTU t0, acc1, t1 + MULV x1, x1, t0 + // ADCS t0, acc2 + ADDV t0, t1, t1 // no carry + ADDV t1, acc2, acc2 + SGTU t1, acc2, t2 + MULHVU x1, x1, t0 + // ADCS t0, acc3 + ADDV t0, t2, t2 // no carry + ADDV t2, acc3, acc3 + SGTU t2, acc3, t1 + MULV x2, x2, t0 + // ADCS t0, acc4 + ADDV t0, t1, t1 // no carry + ADDV t1, acc4, acc4 + SGTU t1, acc4, t2 + MULHVU x2, x2, t0 + // ADCS t0, acc5 + ADDV t0, t2, t2 // no carry + ADDV t2, acc5, acc5 + SGTU t2, acc5, t1 + MULV x3, x3, t0 + // ADCS t0, acc6 + ADDV t0, t1, t1 // no carry + ADDV t1, acc6, acc6 + SGTU t1, acc6, t2 + MULHVU x3, x3, t0 + // ADC t0, acc7 + ADDV t0, t2, t2 // no carry + ADDV t2, acc7, acc7 // (acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7) is the result + + // First reduction step + SLLV $32, acc0, t0 + SRLV $32, acc0, t1 + + // SUBS t0, acc1 + SGTU t0, acc1, t2 + SUBV t0, acc1, acc1 + // SBCS t1, acc2 + ADDV t2, t1, t2 // no carry + SGTU t2, acc2, t3 + SUBV t2, acc2, acc2 + // SBCS t0, acc3 + ADDV t3, t0, t3 // no carry + SGTU t3, acc3, t2 + SUBV t3, acc3, acc3 + // SBC t1, acc0 + ADDV t2, t1, t2 // no carry + SUBV t2, acc0, y0 // no borrow + + // ADDS acc0, acc1, acc1 + ADDV acc0, acc1, acc1 + SGTU acc0, acc1, t0 + // ADCS $0, acc2 + ADDV t0, acc2, acc2 + SGTU t0, acc2, t1 + // ADCS $0, acc3 + ADDV t1, acc3, acc3 + SGTU t1, acc3, t0 + // ADC $0, y0, acc0 + ADDV t0, y0, acc0 + + // Second reduction step + SLLV $32, acc1, t0 + SRLV $32, acc1, t1 + + // SUBS t0, acc2 + SGTU t0, acc2, t2 + SUBV t0, acc2, acc2 + // SBCS t1, acc3 + ADDV t2, t1, t3 // no carry + SGTU t3, acc3, t2 + SUBV t3, acc3, acc3 + // SBCS t0, acc0 + ADDV t2, t0, t2 // no carry + SGTU t2, acc0, t3 + SUBV t2, acc0, acc0 + // SBC t1, acc1 + ADDV t3, t1, t2 // no carry + SUBV t2, acc1, y0 // no borrow + + // ADDS acc1, acc2 + ADDV acc1, acc2, acc2 + SGTU acc1, acc2, t0 + // ADCS $0, acc3 + ADDV t0, acc3, acc3 + SGTU t0, acc3, t1 + // ADCS $0, acc0 + ADDV t1, acc0, acc0 + SGTU t1, acc0, t0 + // ADC $0, y0, acc1 + ADDV t0, y0, acc1 + + // Third reduction step + SLLV $32, acc2, t0 + SRLV $32, acc2, t1 + + // SUBS t0, acc3 + SGTU t0, acc3, t2 + SUBV t0, acc3, acc3 + // SBCS t1, acc0 + ADDV t2, t1, t3 // no carry + SGTU t3, acc0, t2 + SUBV t3, acc0, acc0 + // SBCS t0, acc1 + ADDV t2, t0, t2 // no carry + SGTU t2, acc1, t3 + SUBV t2, acc1, acc1 + // SBC t1, acc2 + ADDV t3, t1, t2 // no carry + SUBV t2, acc2, y0 // no borrow + + // ADDS acc2, acc3 + ADDV acc2, acc3, acc3 + SGTU acc2, acc3, t0 + // ADCS $0, acc0 + ADDV t0, acc0, acc0 + SGTU t0, acc0, t1 + // ADCS $0, acc1 + ADDV t1, acc1, acc1 + SGTU t1, acc1, t0 + // ADC $0, y0, acc2 + ADDV t0, y0, acc2 + + // Last reduction step + SLLV $32, acc3, t0 + SRLV $32, acc3, t1 + + // SUBS t0, acc0 + SGTU t0, acc0, t2 + SUBV t0, acc0, acc0 + // SBCS t1, acc1 + ADDV t2, t1, t3 // no carry + SGTU t3, acc1, t2 + SUBV t3, acc1, acc1 + // SBCS t0, acc2 + ADDV t2, t0, t2 // no carry + SGTU t2, acc2, t3 + SUBV t2, acc2, acc2 + // SBC t1, acc3 + ADDV t3, t1, t2 // no carry + SUBV t2, acc3, y0 // no borrow + + // ADDS acc3, acc0 + ADDV acc3, acc0, acc0 + SGTU acc3, acc0, t0 + // ADCS $0, acc1 + ADDV t0, acc1, acc1 + SGTU t0, acc1, t1 + // ADCS $0, acc2 + ADDV t1, acc2, acc2 + SGTU t1, acc2, t0 + // ADC $0, y0, acc3 + ADDV t0, y0, acc3 + + // Add bits [511:256] of the sqr result + ADDV acc4, acc0, y0 + SGTU acc4, y0, t0 + ADDV acc5, acc1, y1 + SGTU acc5, y1, t1 + ADDV t0, y1, y1 + SGTU t0, y1, t2 + OR t1, t2, t0 + ADDV acc6, acc2, y2 + SGTU acc6, y2, t1 + ADDV t0, y2, y2 + SGTU t0, y2, t2 + OR t1, t2, t0 + ADDV acc7, acc3, y3 + SGTU acc7, y3, t1 + ADDV t0, y3, y3 + SGTU t0, y3, t2 + OR t1, t2, t0 + + // Final reduction + ADDV $1, y0, acc4 + SGTU y0, acc4, t1 + MOVV p256one<>+0X08(SB), t2 + ADDV t2, t1, t1 // no carry + ADDV y1, t1, acc5 + SGTU y1, acc5, t3 + ADDV t3, y2, acc6 + SGTU y2, acc6, t4 + ADDV $1, t2, t2 + ADDV t4, t2, t2 // no carry + ADDV y3, t2, acc7 + SGTU y3, acc7, t4 + OR t0, t4, t0 + + MASKNEZ t0, y0, y0 + MASKEQZ t0, acc4, acc4 + OR acc4, y0 + + MASKNEZ t0, y1, y1 + MASKEQZ t0, acc5, acc5 + OR acc5, y1 + + MASKNEZ t0, y2, y2 + MASKEQZ t0, acc6, acc6 + OR acc6, y2 + + MASKNEZ t0, y3, y3 + MASKEQZ t0, acc7, acc7 + OR acc7, y3 + + RET + +/* ---------------------------------------*/ // (y3, y2, y1, y0) = (x3, x2, x1, x0) * (y3, y2, y1, y0) TEXT sm2P256MulInternal<>(SB),NOSPLIT,$0 // y[0] * x diff --git a/internal/sm2ec/sm2p256_asm_loong64.go b/internal/sm2ec/sm2p256_asm_loong64.go index cf27239..255dc6b 100644 --- a/internal/sm2ec/sm2p256_asm_loong64.go +++ b/internal/sm2ec/sm2p256_asm_loong64.go @@ -37,3 +37,8 @@ func p256NegCond(val *p256Element, cond int) // //go:noescape func p256Mul(res, in1, in2 *p256Element) + +// Montgomery square, repeated n times (n >= 1). +// +//go:noescape +func p256Sqr(res, in *p256Element, n int) diff --git a/internal/sm2ec/sm2p256_asm_loong64_test.go b/internal/sm2ec/sm2p256_asm_loong64_test.go index 0c3358a..cf6bca8 100644 --- a/internal/sm2ec/sm2p256_asm_loong64_test.go +++ b/internal/sm2ec/sm2p256_asm_loong64_test.go @@ -180,3 +180,53 @@ func TestFuzzyP256Mul(t *testing.T) { p256MulTest(t, x, y, p, r) } } + +func p256SqrTest(t *testing.T, x, p, r *big.Int) { + x1 := new(big.Int).Mul(x, r) + x1 = x1.Mod(x1, p) + ax := new(p256Element) + res := new(p256Element) + res2 := new(p256Element) + one := new(p256Element) + one[0] = 1 + fromBig(ax, x1) + p256Sqr(res2, ax, 1) + p256Mul(res, res2, one) + resInt := toBigInt(res) + + expected := new(big.Int).Mul(x, x) + expected = expected.Mod(expected, p) + if resInt.Cmp(expected) != 0 { + t.FailNow() + } +} + +func TestP256SqrPMinus1(t *testing.T) { + p, _ := new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", 16) + r, _ := new(big.Int).SetString("10000000000000000000000000000000000000000000000000000000000000000", 16) + pMinus1 := new(big.Int).Sub(p, big.NewInt(1)) + p256SqrTest(t, pMinus1, p, r) +} + +func TestFuzzyP256Sqr(t *testing.T) { + p, _ := new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", 16) + r, _ := new(big.Int).SetString("10000000000000000000000000000000000000000000000000000000000000000", 16) + var scalar1 [32]byte + var timeout *time.Timer + + if testing.Short() { + timeout = time.NewTimer(10 * time.Millisecond) + } else { + timeout = time.NewTimer(2 * time.Second) + } + for { + select { + case <-timeout.C: + return + default: + } + io.ReadFull(rand.Reader, scalar1[:]) + x := new(big.Int).SetBytes(scalar1[:]) + p256SqrTest(t, x, p, r) + } +}