gmsm/mldsa/mldsa87.go
2025-05-07 10:05:13 +08:00

501 lines
13 KiB
Go

// Copyright 2025 Sun Yimin. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
//go:build go1.24
package mldsa
import (
"crypto"
"crypto/sha3"
"crypto/subtle"
"errors"
"io"
)
// A PrivateKey87 is the expanded private key for the ML-DSA-87 signature
// scheme (the components of FIPS 204 skEncode/skDecode), together with the
// matrix A expanded from rho so that signing does not have to re-derive it.
type PrivateKey87 struct {
	// rho is the public random seed from which the matrix A is expanded.
	rho [32]byte
	// k is the private random seed K that is absorbed into every signature.
	k [32]byte
	// tr caches H(pk, 64), the hash of the encoded public key.
	tr [64]byte
	// s1 is the length-l87 secret vector. ML-DSA-87 uses eta = 2, so its
	// coefficients lie in [-2, 2].
	s1 [l87]ringElement
	// s2 is the length-k87 secret vector, also with coefficients in [-2, 2].
	s2 [k87]ringElement
	// t0 holds the 13 low-order bits of each coefficient of the uncompressed
	// public key polynomial vector t. It is stored only in the private key.
	t0 [k87]ringElement
	// a is the k87 x l87 matrix A in NTT representation, stored row-major:
	// row r, column s lives at index r*l87+s.
	a [k87 * l87]nttElement
}
// A Key87 is the key pair for the ML-DSA-87 signature scheme.
//
// It embeds the expanded private key and additionally keeps the seed the
// pair was derived from and the public vector t1. This lets both the compact
// seed form (Bytes) and the public key (PublicKey) be produced from it.
type Key87 struct {
	PrivateKey87
	// xi is the 32-byte input seed given to key generation; Bytes returns a copy.
	xi [32]byte
	// t1 holds the 10 high-order bits of each coefficient of the uncompressed
	// public key polynomial vector t. It is part of the encoded public key.
	t1 [k87]ringElement
}
// A PublicKey87 is the public key for the ML-DSA-87 signature scheme,
// together with values derived from it that verification needs.
type PublicKey87 struct {
	// rho is the public seed from which the matrix A is expanded.
	rho [32]byte
	// t1 holds the 10 high-order bits of each coefficient of the public key
	// polynomial vector t.
	t1 [k87]ringElement
	// tr caches H(pkEncode(rho, t1), 64). Verify absorbs it as the first
	// input when computing the message representative mu.
	tr [64]byte
	// a is the k87 x l87 matrix A in NTT representation, stored row-major:
	// row r, column s lives at index r*l87+s.
	a [k87 * l87]nttElement
}
// PublicKey returns the ML-DSA-87 public key belonging to the key pair sk.
//
// The result holds its own copies of rho, t1, tr and the expanded matrix a.
// Later changes to sk are therefore not reflected in it.
func (sk *Key87) PublicKey() *PublicKey87 {
	pk := new(PublicKey87)
	pk.rho = sk.rho
	pk.t1 = sk.t1
	pk.tr = sk.tr
	pk.a = sk.a
	return pk
}
// Equal reports whether x is a *PublicKey87 with the same rho and t1 as pk.
//
// Only the encoded components are compared. tr and a are not compared
// directly; within this file they are always computed from rho and t1.
func (pk *PublicKey87) Equal(x crypto.PublicKey) bool {
	other, ok := x.(*PublicKey87)
	return ok && other.rho == pk.rho && other.t1 == pk.t1
}
// Bytes returns the PublicKeySize87-byte encoding rho || SimpleBitPack(t1),
// with each t1 coefficient packed into 10 bits.
// See FIPS 204, Algorithm 22, pkEncode().
func (pk *PublicKey87) Bytes() []byte {
	// The buffer is created here, not in bytes, so this allocation is
	// outlined from the encoding logic.
	return pk.bytes(make([]byte, 0, PublicKeySize87))
}
// bytes appends the pkEncode encoding of pk (rho followed by every t1
// polynomial packed with 10 bits per coefficient) to b and returns the
// extended slice.
func (pk *PublicKey87) bytes(b []byte) []byte {
	out := append(b, pk.rho[:]...)
	for i := range pk.t1 {
		out = simpleBitPack10Bits(out, pk.t1[i])
	}
	return out
}
// Bytes returns the SeedSize-byte seed xi from which the Key87 was derived.
// This is the compact private-key form accepted by NewKey87.
//
// It differs from PrivateKey87.Bytes, which returns the expanded skEncode
// encoding. The returned slice is a fresh copy; modifying it does not
// affect sk.
func (sk *Key87) Bytes() []byte {
	seed := make([]byte, SeedSize)
	copy(seed, sk.xi[:])
	return seed
}
// Bytes returns the PrivateKeySize87-byte expanded encoding of the private
// key: rho || K || tr || s1 || s2 || t0.
// See FIPS 204, Algorithm 24, skEncode().
func (sk *PrivateKey87) Bytes() []byte {
	return sk.bytes(make([]byte, 0, PrivateKeySize87))
}
// bytes appends the skEncode encoding rho || K || tr || s1 || s2 || t0 of sk
// to b and returns the extended slice.
func (sk *PrivateKey87) bytes(b []byte) []byte {
	out := append(b, sk.rho[:]...)
	out = append(out, sk.k[:]...)
	out = append(out, sk.tr[:]...)
	// s1 and s2 share the same signed packing (bitPackSigned2).
	for i := range sk.s1 {
		out = bitPackSigned2(out, sk.s1[i])
	}
	for i := range sk.s2 {
		out = bitPackSigned2(out, sk.s2[i])
	}
	// t0 uses the 13-bit signed packing.
	for i := range sk.t0 {
		out = bitPackSigned4096(out, sk.t0[i])
	}
	return out
}
// Equal reports whether x is a *PrivateKey87 holding the same key material
// as sk: rho, K, tr, s1, s2 and t0. The cached matrix a is derived from rho
// and is not compared.
//
// The comparison runs in constant time with respect to the key contents.
// The previous early-exit == comparison could reveal through timing where
// two secret keys first differ.
func (sk *PrivateKey87) Equal(x any) bool {
	xx, ok := x.(*PrivateKey87)
	if !ok {
		return false
	}
	eq := subtle.ConstantTimeCompare(sk.rho[:], xx.rho[:]) &
		subtle.ConstantTimeCompare(sk.k[:], xx.k[:]) &
		subtle.ConstantTimeCompare(sk.tr[:], xx.tr[:])

	// OR together the XOR of every coefficient pair. diff is zero exactly when
	// all coefficients match, and no branch depends on coefficient values.
	var diff fieldElement
	for i := range sk.s1 {
		for j := range sk.s1[i] {
			diff |= sk.s1[i][j] ^ xx.s1[i][j]
		}
	}
	for i := range sk.s2 {
		for j := range sk.s2[i] {
			diff |= sk.s2[i][j] ^ xx.s2[i][j]
		}
	}
	for i := range sk.t0 {
		for j := range sk.t0[i] {
			diff |= sk.t0[i][j] ^ xx.t0[i][j]
		}
	}
	return eq == 1 && diff == 0
}
// GenerateKey87 reads a fresh SeedSize-byte seed from rand and derives an
// ML-DSA-87 key pair from it (FIPS 204, Algorithm 1, ML-DSA.KeyGen).
// Any error from reading rand is returned unchanged.
func GenerateKey87(rand io.Reader) (*Key87, error) {
	// The key is allocated here, separately from the generation logic, so
	// that this allocation is outlined.
	return generateKey87(new(Key87), rand)
}
// generateKey87 fills sk with a key pair derived from SeedSize random bytes
// read from rand. It returns sk, or nil and the read error.
func generateKey87(sk *Key87, rand io.Reader) (*Key87, error) {
	var xi [SeedSize]byte
	_, err := io.ReadFull(rand, xi[:])
	if err != nil {
		return nil, err
	}
	dsaKeyGen87(sk, &xi)
	return sk, nil
}
// NewKey87 deterministically derives an ML-DSA-87 key pair from a
// SeedSize-byte seed, such as the value returned by Key87.Bytes.
// It returns an error if seed has the wrong length.
func NewKey87(seed []byte) (*Key87, error) {
	// The key is allocated here, separately from the derivation logic, so
	// that this allocation is outlined.
	return newPrivateKey87FromSeed(new(Key87), seed)
}
// newPrivateKey87FromSeed validates the seed length and, if it is SeedSize,
// derives the key pair into sk and returns it.
func newPrivateKey87FromSeed(sk *Key87, seed []byte) (*Key87, error) {
	if len(seed) == SeedSize {
		dsaKeyGen87(sk, (*[32]byte)(seed))
		return sk, nil
	}
	return nil, errors.New("mldsa: invalid seed length")
}
// dsaKeyGen87 deterministically derives an ML-DSA-87 key pair from the
// 32-byte seed xi and stores every component in sk
// (FIPS 204, Algorithm 6, ML-DSA.KeyGen_internal).
//
// On return, sk.xi, rho, k, s1, s2, a, t0, t1 and tr are all populated.
// tr is H(pkEncode(rho, t1), 64).
func dsaKeyGen87(sk *Key87, xi *[32]byte) {
	sk.xi = *xi
	// (rho, rho', K) = H(xi || k || l, 128), with k and l each encoded as one byte.
	H := sha3.NewSHAKE256()
	H.Write(xi[:])
	H.Write([]byte{k87})
	H.Write([]byte{l87})
	K := make([]byte, 128)
	H.Read(K)
	// Split the 128-byte output: rho = [0,32), rho' = [32,96), K = [96,128).
	rho, rho1 := K[:32], K[32:96]
	K = K[96:]
	sk.rho = [32]byte(rho)
	sk.k = [32]byte(K)
	s1 := &sk.s1
	s2 := &sk.s2
	// Algorithm 33, ExpandS. The counter passed for s2 continues after the
	// l87 values used for s1.
	for s := byte(0); s < l87; s++ {
		s1[s] = rejBoundedPoly(rho1, eta2, 0, s)
	}
	for r := byte(0); r < k87; r++ {
		s2[r] = rejBoundedPoly(rho1, eta2, 0, r+l87)
	}
	// Using rho generate A' = A in NTT form
	A := &sk.a
	// Algorithm 32, ExpandA. A is stored row-major: row r, column s at r*l87+s.
	for r := byte(0); r < k87; r++ {
		for s := byte(0); s < l87; s++ {
			A[r*l87+s] = rejNTTPoly(rho, s, r)
		}
	}
	// t = NTT_inv(A' * NTT(s1)) + s2
	var s1NTT [l87]nttElement
	var nttT [k87]nttElement
	for i := range s1 {
		s1NTT[i] = ntt(s1[i])
	}
	for i := range nttT {
		for j := range s1NTT {
			nttT[i] = polyAdd(nttT[i], nttMul(s1NTT[j], A[i*l87+j]))
		}
	}
	var t [k87]ringElement
	t0 := &sk.t0
	t1 := &sk.t1
	for i := range nttT {
		t[i] = polyAdd(inverseNTT(nttT[i]), s2[i])
		// compress t: power2Round splits each coefficient into the high part
		// (t1, public) and the low part (t0, private).
		for j := range n {
			t1[i][j], t0[i][j] = power2Round(t[i][j])
		}
	}
	// tr = H(pkEncode(rho, t1), 64), reusing the SHAKE256 instance after Reset.
	// PublicKey87.Bytes encodes only rho and t1, so the not-yet-computed
	// sk.tr that PublicKey copies does not affect ek.
	H.Reset()
	ek := sk.PublicKey().Bytes()
	H.Write(ek)
	H.Read(sk.tr[:])
}
// NewPublicKey87 decodes an ML-DSA-87 public key from its PublicKeySize87-byte
// encoding and precomputes tr and the expanded matrix A.
// It returns an error if b has the wrong length.
// See FIPS 204, Algorithm 23 pkDecode().
func NewPublicKey87(b []byte) (*PublicKey87, error) {
	// The key is allocated here, separately from the parsing logic, so that
	// this allocation is outlined.
	return parsePublicKey87(new(PublicKey87), b)
}
// parsePublicKey87 decodes the encoding b into pk (FIPS 204, Algorithm 23
// pkDecode()). It also caches tr = H(b, 64) and expands the matrix A from rho.
func parsePublicKey87(pk *PublicKey87, b []byte) (*PublicKey87, error) {
	if len(b) != PublicKeySize87 {
		return nil, errors.New("mldsa: invalid public key length")
	}

	// tr is the hash of the complete encoding.
	h := sha3.NewSHAKE256()
	h.Write(b)
	h.Read(pk.tr[:])

	// Layout: rho (32 bytes) followed by k87 packed t1 polynomials.
	rho, packedT1 := b[:32], b[32:]
	copy(pk.rho[:], rho)
	for i := range pk.t1 {
		simpleBitUnpack10Bits(packedT1[i*encodingSize10:], &pk.t1[i])
	}

	// Algorithm 32, ExpandA: entry idx holds A[r][s] with r = idx / l87 and
	// s = idx % l87 (row-major).
	for idx := range pk.a {
		r, s := byte(idx/l87), byte(idx%l87)
		pk.a[idx] = rejNTTPoly(pk.rho[:], s, r)
	}
	return pk, nil
}
// NewPrivateKey87 decodes an ML-DSA-87 private key from its
// PrivateKeySize87-byte expanded encoding and expands the matrix A from rho.
// It returns an error if b has the wrong length or cannot be decoded.
// See FIPS 204, Algorithm 25 skDecode().
func NewPrivateKey87(b []byte) (*PrivateKey87, error) {
	// The key is allocated here, separately from the parsing logic, so that
	// this allocation is outlined.
	return parsePrivateKey87(new(PrivateKey87), b)
}
// parsePrivateKey87 decodes the skEncode encoding b into sk (FIPS 204,
// Algorithm 25 skDecode()) and expands the matrix A from the decoded rho.
// It returns an error if b has the wrong length or if bitUnpackSigned2
// rejects any s1/s2 polynomial.
func parsePrivateKey87(sk *PrivateKey87, b []byte) (*PrivateKey87, error) {
	if len(b) != PrivateKeySize87 {
		return nil, errors.New("mldsa: invalid private key length")
	}

	// Fixed-size prefix: rho (32) || K (32) || tr (64).
	copy(sk.rho[:], b[0:32])
	copy(sk.k[:], b[32:64])
	copy(sk.tr[:], b[64:128])
	packed := b[128:]

	// s1 (l87 polynomials) is immediately followed by s2 (k87 polynomials),
	// each occupying encodingSize3 bytes.
	for idx := 0; idx < l87+k87; idx++ {
		poly, err := bitUnpackSigned2(packed[idx*encodingSize3:])
		if err != nil {
			return nil, err
		}
		if idx < l87 {
			sk.s1[idx] = poly
		} else {
			sk.s2[idx-l87] = poly
		}
	}
	packed = packed[(l87+k87)*encodingSize3:]

	// t0: k87 polynomials of encodingSize13 bytes each.
	for i := range sk.t0 {
		bitUnpackSigned4096(packed[i*encodingSize13:], &sk.t0[i])
	}

	// Algorithm 32, ExpandA: entry idx holds A[r][s] with r = idx / l87 and
	// s = idx % l87 (row-major).
	for idx := range sk.a {
		r, s := byte(idx/l87), byte(idx%l87)
		sk.a[idx] = rejNTTPoly(sk.rho[:], s, r)
	}
	return sk, nil
}
// Sign returns an ML-DSA-87 signature over message bound to context
// (FIPS 204, Algorithm 2, ML-DSA.Sign). context may be empty and must be at
// most 255 bytes. SeedSize bytes of per-signature randomness are read from
// rand; a read error is returned unchanged.
//
// NOTE(review): empty messages are rejected here. FIPS 204 does not appear to
// require this; confirm the restriction is intentional.
func (sk *PrivateKey87) Sign(rand io.Reader, message, context []byte) ([]byte, error) {
	switch {
	case len(message) == 0:
		return nil, errors.New("mldsa: empty message")
	case len(context) > 255:
		return nil, errors.New("mldsa: context too long")
	}

	var rnd [SeedSize]byte
	if _, err := io.ReadFull(rand, rnd[:]); err != nil {
		return nil, err
	}

	// mu = H(tr || M', 64), where M' = 0 || len(context) || context || message.
	// Writing an empty context to SHAKE is a no-op.
	var mu [64]byte
	h := sha3.NewSHAKE256()
	h.Write(sk.tr[:])
	h.Write([]byte{0, byte(len(context))})
	h.Write(context)
	h.Write(message)
	h.Read(mu[:])
	return sk.signInternal(rnd[:], mu[:])
}
// signInternal implements the core of ML-DSA-87 signing for a precomputed
// message representative mu (FIPS 204, Algorithm 7, ML-DSA.Sign_internal).
// seed is the per-signature randomness rnd; it is hashed with K and mu into
// rho'' = H(K || rnd || mu, 64).
//
// It loops (rejection sampling) until a candidate passes every norm and
// hint-count check, then returns the encoded signature c~ || z || h.
// The error result is always nil.
func (sk *PrivateKey87) signInternal(seed, mu []byte) ([]byte, error) {
	// NTT forms of the secret vectors are reused by every loop iteration.
	var s1NTT [l87]nttElement
	var s2NTT [k87]nttElement
	var t0NTT [k87]nttElement
	for i := range s1NTT {
		s1NTT[i] = ntt(sk.s1[i])
	}
	for i := range s2NTT {
		s2NTT[i] = ntt(sk.s2[i])
	}
	for i := range t0NTT {
		t0NTT[i] = ntt(sk.t0[i])
	}

	// rho2 = rho'' || two counter bytes. The counter bytes are filled in for
	// each mask polynomial below.
	var rho2 [64 + 2]byte
	H := sha3.NewSHAKE256()
	H.Write(sk.k[:])
	H.Write(seed[:])
	H.Write(mu[:])
	H.Read(rho2[:64])
	A := &sk.a

	// rejection sampling loop
	for kappa := 0; ; kappa = kappa + l87 {
		// expand mask (little-endian 16-bit counter kappa+i), and transform
		// each y[i] once. NTT(y) depends only on y, so computing it inside the
		// per-row loop below would repeat each transform k87 times.
		var y [l87]ringElement
		var yNTT [l87]nttElement
		for i := range l87 {
			index := kappa + i
			rho2[64] = byte(index)
			rho2[65] = byte(index >> 8)
			y[i] = expandMask(rho2[:], gamma1TwoPower19)
			yNTT[i] = ntt(y[i])
		}

		// compute w = NTT^-1(A * NTT(y)) and w1 = HighBits(w)
		var w, w1 [k87]ringElement
		for i := range w {
			var wNTT nttElement
			for j := range yNTT {
				wNTT = polyAdd(wNTT, nttMul(yNTT[j], A[i*l87+j]))
			}
			w[i] = inverseNTT(wNTT)
			for j := range w[i] {
				w1[i][j] = fieldElement(compressHighBits(w[i][j], gamma2QMinus1Div32))
			}
		}

		// commitment hash c~ = H(mu || w1Encode(w1), lambda/4)
		var cTilde [lambda256 / 4]byte
		var w1Encoded [encodingSize4]byte
		H.Reset()
		H.Write(mu[:])
		for i := range k87 {
			// Packs into w1Encoded's own backing array via the zero-length slice.
			simpleBitPack4Bits(w1Encoded[:0], w1[i])
			H.Write(w1Encoded[:])
		}
		H.Read(cTilde[:])

		// verifier's challenge
		cNTT := ntt(sampleInBall(cTilde[:], tau60))
		var cs1 [l87]ringElement
		var cs2 [k87]ringElement
		var z [l87]ringElement
		var r0 [k87][n]int32
		// compute <<cs1>> and z = <<cs1>> + y
		for i := range l87 {
			cs1[i] = inverseNTT(nttMul(cNTT, s1NTT[i]))
			z[i] = polyAdd(cs1[i], y[i])
		}
		// compute <<cs2>> and r0 = LowBits(w - <<cs2>>)
		for i := range k87 {
			cs2[i] = inverseNTT(nttMul(cNTT, s2NTT[i]))
			for j := range cs2[i] {
				_, r0[i][j] = decompose(fieldSub(w[i][j], cs2[i][j]), gamma2QMinus1Div32)
			}
		}
		zNorm := vectorInfinityNorm(z[:], 0)
		r0Norm := vectorInfinityNormSigned(r0[:], 0)
		// if zNorm >= gamma1 - beta || r0Norm >= gamma2 - beta, then continue
		if subtle.ConstantTimeLessOrEq(int(gamma1TwoPower19-beta87), zNorm)|subtle.ConstantTimeLessOrEq(int(gamma2QMinus1Div32-beta87), r0Norm) == 1 {
			continue
		}

		// compute <<ct0>> and its infinity norm
		var ct0 [k87]ringElement
		for i := range k87 {
			ct0[i] = inverseNTT(nttMul(cNTT, t0NTT[i]))
		}
		ct0Norm := vectorInfinityNorm(ct0[:], 0)

		// make hint
		var hints [k87]ringElement
		vectorMakeHint(ct0[:], cs2[:], w[:], hints[:], gamma2QMinus1Div32)
		// if the number of 1s in the hint is greater than omega or the infinity norm of <<ct0>> >= gamma2, then continue
		if (subtle.ConstantTimeLessOrEq(int(omega75+1), vectorCountOnes(hints[:])) | subtle.ConstantTimeLessOrEq(gamma2QMinus1Div32, ct0Norm)) == 1 {
			continue
		}

		// signature encoding: c~ || BitPack(z) || HintBitPack(h)
		sig := make([]byte, 0, sigEncodedLen87)
		sig = append(sig, cTilde[:]...)
		for i := range l87 {
			sig = bitPackSignedTwoPower19(sig, z[i])
		}
		return hintBitPack(sig, hints[:], omega75), nil
	}
}
// Verify reports whether sig is a valid ML-DSA-87 signature of message under
// pk for the given context (FIPS 204, Algorithm 3, ML-DSA.Verify).
// It returns false for an empty message, a context longer than 255 bytes, or
// a signature whose length is not sigEncodedLen87.
func (pk *PublicKey87) Verify(sig []byte, message, context []byte) bool {
	if len(message) == 0 || len(context) > 255 || len(sig) != sigEncodedLen87 {
		return false
	}

	// mu = H(tr || M', 64), where M' = 0 || len(context) || context || message.
	// Writing an empty context to SHAKE is a no-op.
	var mu [64]byte
	h := sha3.NewSHAKE256()
	h.Write(pk.tr[:])
	h.Write([]byte{0, byte(len(context))})
	h.Write(context)
	h.Write(message)
	h.Read(mu[:])
	return pk.verifyInternal(sig, mu[:])
}
// verifyInternal checks an encoded signature c~ || z || h against the message
// representative mu (FIPS 204, Algorithm 8, ML-DSA.Verify_internal).
// sig must be exactly sigEncodedLen87 bytes long; Verify checks this first.
func (pk *PublicKey87) verifyInternal(sig, mu []byte) bool {
	// Decode the signature
	cTilde := sig[:lambda256/4]
	sig = sig[lambda256/4:]
	var z [l87]ringElement
	for i := range l87 {
		bitUnpackSignedTwoPower19(sig, &z[i])
		sig = sig[encodingSize20:]
	}
	zNorm := vectorInfinityNorm(z[:], 0)
	var hints [k87]ringElement
	if !hintBitUnpack(sig, hints[:], omega75) {
		return false
	}

	// verifier's challenge
	cNTT := ntt(sampleInBall(cTilde[:], tau60))

	// tNTT[i] = NTT(t1[i] * 2^d) * NTT(c). t is a copy, so pk.t1 is unchanged.
	var tNTT [k87]nttElement
	t := pk.t1
	for i := range k87 {
		for j := range t[i] {
			t[i][j] <<= d
		}
		tNTT[i] = nttMul(ntt(t[i]), cNTT)
	}

	// NTT(z) is needed by every row of A * NTT(z). Transform each z[j] once
	// instead of once per row (which previously repeated it k87 times).
	var zNTT [l87]nttElement
	for j := range z {
		zNTT[j] = ntt(z[j])
	}

	// w'Approx = NTT^-1(A * NTT(z) - NTT(c) * NTT(t1 * 2^d))
	var w1, wApprox [k87]ringElement
	for i := range k87 {
		var az nttElement
		for j := range zNTT {
			az = polyAdd(az, nttMul(zNTT[j], pk.a[i*l87+j]))
		}
		wApprox[i] = inverseNTT(polySub(az, tNTT[i]))
	}

	// c~' = H(mu || w1Encode(UseHint(h, w'Approx)), lambda/4)
	H := sha3.NewSHAKE256()
	H.Write(mu[:])
	var w1Encoded [encodingSize4]byte
	for i := range k87 {
		for j := range wApprox[i] {
			w1[i][j] = useHint(hints[i][j], wApprox[i][j], gamma2QMinus1Div32)
		}
		// Packs into w1Encoded's own backing array via the zero-length slice.
		simpleBitPack4Bits(w1Encoded[:0], w1[i])
		H.Write(w1Encoded[:])
	}
	var cTilde1 [lambda256 / 4]byte
	H.Read(cTilde1[:])

	// Accept iff ||z||inf < gamma1 - beta and c~ == c~'.
	return subtle.ConstantTimeLessOrEq(int(gamma1TwoPower19-beta87), zNorm) == 0 &&
		subtle.ConstantTimeCompare(cTilde[:], cTilde1[:]) == 1
}