mldsa: refactor the implementation of key and sign/verify

This commit is contained in:
Sun Yimin 2025-09-12 18:23:23 +08:00 committed by GitHub
parent b294ea7388
commit fd2eedf24b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 373 additions and 335 deletions

View File

@ -109,11 +109,26 @@ type PrivateKey44 struct {
s1 [l44]ringElement // private secret of size L with short coefficients (-4..4) or (-2..2) s1 [l44]ringElement // private secret of size L with short coefficients (-4..4) or (-2..2)
s2 [k44]ringElement // private secret of size K with short coefficients (-4..4) or (-2..2) s2 [k44]ringElement // private secret of size K with short coefficients (-4..4) or (-2..2)
t0 [k44]ringElement // the Polynomial encoding of the 13 LSB of each coefficient of the uncompressed public key polynomial t. This is saved as part of the private key. t0 [k44]ringElement // the Polynomial encoding of the 13 LSB of each coefficient of the uncompressed public key polynomial t. This is saved as part of the private key.
t1 [k44]ringElement // the Polynomial encoding of the 10 MSB of each coefficient of the uncompressed public key polynomial t. This is saved as part of the public key.
s1NTTCache [l44]nttElement s1NTTCache [l44]nttElement
s2NTTCache [k44]nttElement s2NTTCache [k44]nttElement
t0NTTCache [k44]nttElement t0NTTCache [k44]nttElement
a [k44 * l44]nttElement // a is generated and stored in NTT representation a [k44 * l44]nttElement // a is generated and stored in NTT representation
nttOnce sync.Once nttOnce sync.Once
t1Once sync.Once
}
// PublicKey returns the public key corresponding to the private key.
// Although we can derive the public key from the private key,
// but we do NOT need to derive it at most of the time.
func (sk *PrivateKey44) PublicKey() crypto.PublicKey {
sk.ensureT1()
return &PublicKey44{
rho: sk.rho,
t1: sk.t1,
tr: sk.tr,
a: sk.a,
}
} }
func (sk *PrivateKey44) ensureNTT() { func (sk *PrivateKey44) ensureNTT() {
@ -130,11 +145,36 @@ func (sk *PrivateKey44) ensureNTT() {
}) })
} }
func (sk *PrivateKey44) ensureT1() {
sk.ensureNTT()
sk.t1Once.Do(func() {
// t = NTT_inv(A' * NTT(s1)) + s2
s1NTT := sk.s1NTTCache
A := sk.a
s2 := sk.s2
var nttT [k44]nttElement
for i := range nttT {
for j := range s1NTT {
nttT[i] = polyAdd(nttT[i], nttMul(s1NTT[j], A[i*l44+j]))
}
}
var t [k44]ringElement
t1 := &sk.t1
for i := range nttT {
t[i] = polyAdd(inverseNTT(nttT[i]), s2[i])
// compress t
for j := range n {
t1[i][j], _ = power2Round(t[i][j])
}
}
})
}
// A Key44 is the key pair for the ML-DSA-44 signature scheme. // A Key44 is the key pair for the ML-DSA-44 signature scheme.
type Key44 struct { type Key44 struct {
PrivateKey44 PrivateKey44
xi [32]byte // input seed xi [32]byte // input seed
t1 [k44]ringElement // the Polynomial encoding of the 10 MSB of each coefficient of the uncompressed public key polynomial t. This is saved as part of the public key.
} }
// A PublicKey44 is the public key for the ML-DSA-44 signature scheme. // A PublicKey44 is the public key for the ML-DSA-44 signature scheme.
@ -158,12 +198,21 @@ func (sk *Key44) PublicKey() *PublicKey44 {
} }
} }
// Seed returns a byte slice of the secret key's seed value.
func (sk *Key44) Seed() []byte {
var b [SeedSize]byte
copy(b[:], sk.xi[:])
return b[:]
}
func (pk *PublicKey44) Equal(x crypto.PublicKey) bool { func (pk *PublicKey44) Equal(x crypto.PublicKey) bool {
xx, ok := x.(*PublicKey44) xx, ok := x.(*PublicKey44)
if !ok { if !ok {
return false return false
} }
return pk.rho == xx.rho && pk.t1 == xx.t1 b1 := pk.Bytes()
b2 := xx.Bytes()
return subtle.ConstantTimeCompare(b1, b2) == 1
} }
// Bytes converts the PublicKey44 instance into a byte slice. // Bytes converts the PublicKey44 instance into a byte slice.
@ -194,15 +243,6 @@ func (pk *PublicKey44) ensureNTT() {
}) })
} }
// Bytes returns the byte representation of the PrivateKey44.
// It copies the internal seed (xi) into a fixed-size byte array
// and returns it as a slice.
func (sk *Key44) Bytes() []byte {
var b [SeedSize]byte
copy(b[:], sk.xi[:])
return b[:]
}
// Bytes converts the PrivateKey44 instance into a byte slice. // Bytes converts the PrivateKey44 instance into a byte slice.
// See FIPS 204, Algorithm 24, skEncode() // See FIPS 204, Algorithm 24, skEncode()
func (sk *PrivateKey44) Bytes() []byte { func (sk *PrivateKey44) Bytes() []byte {
@ -231,8 +271,9 @@ func (sk *PrivateKey44) Equal(x any) bool {
if !ok { if !ok {
return false return false
} }
return sk.rho == xx.rho && sk.k == xx.k && sk.tr == xx.tr && b1 := sk.Bytes()
sk.s1 == xx.s1 && sk.s2 == xx.s2 && sk.t0 == xx.t0 b2 := xx.Bytes()
return subtle.ConstantTimeCompare(b1, b2) == 1
} }
// GenerateKey44 generates a new Key44 (ML-DSA-44) using the provided random source. // GenerateKey44 generates a new Key44 (ML-DSA-44) using the provided random source.
@ -284,17 +325,17 @@ func dsaKeyGen44(sk *Key44, xi *[32]byte) {
s1 := &sk.s1 s1 := &sk.s1
s2 := &sk.s2 s2 := &sk.s2
// Algorithm 33, ExpandS // Algorithm 33, ExpandS
for s := byte(0); s < l44; s++ { for s := range byte(l44) {
s1[s] = rejBoundedPoly(rho1, eta2, 0, s) s1[s] = rejBoundedPoly(rho1, eta2, 0, s)
} }
for r := byte(0); r < k44; r++ { for r := range byte(k44) {
s2[r] = rejBoundedPoly(rho1, eta2, 0, r+l44) s2[r] = rejBoundedPoly(rho1, eta2, 0, r+l44)
} }
// Using rho generate A' = A in NTT form // Using rho generate A' = A in NTT form
A := &sk.a A := &sk.a
// Algorithm 32, ExpandA // Algorithm 32, ExpandA
for r := byte(0); r < k44; r++ { for r := range byte(k44) {
for s := byte(0); s < l44; s++ { for s := byte(0); s < l44; s++ {
A[r*l44+s] = rejNTTPoly(rho, s, r) A[r*l44+s] = rejNTTPoly(rho, s, r)
} }
@ -355,8 +396,8 @@ func parsePublicKey44(pk *PublicKey44, b []byte) (*PublicKey44, error) {
A := &pk.a A := &pk.a
rho := pk.rho[:] rho := pk.rho[:]
// Algorithm 32, ExpandA // Algorithm 32, ExpandA
for r := byte(0); r < k44; r++ { for r := range byte(k44) {
for s := byte(0); s < l44; s++ { for s := range byte(l44) {
A[r*l44+s] = rejNTTPoly(rho, s, r) A[r*l44+s] = rejNTTPoly(rho, s, r)
} }
} }
@ -404,32 +445,42 @@ func parsePrivateKey44(sk *PrivateKey44, b []byte) (*PrivateKey44, error) {
A := &sk.a A := &sk.a
rho := sk.rho[:] rho := sk.rho[:]
// Algorithm 32, ExpandA // Algorithm 32, ExpandA
for r := byte(0); r < k44; r++ { for r := range byte(k44) {
for s := byte(0); s < l44; s++ { for s := range byte(l44) {
A[r*l44+s] = rejNTTPoly(rho, s, r) A[r*l44+s] = rejNTTPoly(rho, s, r)
} }
} }
return sk, nil return sk, nil
} }
// Sign generates a digital signature for the given message and context using the private key. // Sign signs the provided digest using the private key. It is a wrapper around SignMessage.
// It uses a random seed generated from the provided random source. // It satisfies the crypto.Signer interface.
func (sk *PrivateKey44) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
return sk.SignMessage(rand, digest, opts)
}
// SignMessage signs a message with the private key.
// It satisfies the crypto.MessageSigner interface.
// //
// Parameters: // The function supports pre-hashing the message by providing a hash OID in the options.
// - rand: An io.Reader used to generate a random seed for signing. // Context data can also be provided, but is limited to 255 bytes.
// - message: The message to be signed. Must not be empty. func (sk *PrivateKey44) SignMessage(rand io.Reader, message []byte, opts crypto.SignerOpts) ([]byte, error) {
// - context: An optional context for domain separation. Must not exceed 255 bytes. var (
// context []byte
// Returns: hashOID asn1.ObjectIdentifier
// - A byte slice containing the generated signature. indicator byte = 0
// - An error if the message is empty, the context is too long, or if there is an issue )
// reading from the random source. if opts, ok := opts.(*Options); ok {
// context = opts.Context
// Note: hashOID = opts.PrehashOID
// - The function uses SHAKE256 from the SHA-3 family for hashing. }
// - The signing process involves generating a unique seed and a hash-based if len(hashOID) != 0 {
// message digest (mu) before delegating to the internal signing function. var err error
func (sk *PrivateKey44) Sign(rand io.Reader, message, context []byte) ([]byte, error) { if message, err = preHash(hashOID, message); err != nil {
return nil, err
}
indicator = 1
}
if len(message) == 0 { if len(message) == 0 {
return nil, errors.New("mldsa: empty message") return nil, errors.New("mldsa: empty message")
} }
@ -442,7 +493,7 @@ func (sk *PrivateKey44) Sign(rand io.Reader, message, context []byte) ([]byte, e
} }
H := sha3.NewSHAKE256() H := sha3.NewSHAKE256()
H.Write(sk.tr[:]) H.Write(sk.tr[:])
H.Write([]byte{0, byte(len(context))}) H.Write([]byte{indicator, byte(len(context))})
if len(context) > 0 { if len(context) > 0 {
H.Write(context) H.Write(context)
} }
@ -453,39 +504,6 @@ func (sk *PrivateKey44) Sign(rand io.Reader, message, context []byte) ([]byte, e
return sk.signInternal(seed[:], mu[:]) return sk.signInternal(seed[:], mu[:])
} }
// SignWithPreHash generates a digital signature for the given message
// using the private key and additional context. It uses a given hashing algorithm
// from the OID to pre-hash the message before signing.
// It is similar to Sign but allows for pre-hashing the message.
func (sk *PrivateKey44) SignWithPreHash(rand io.Reader, message, context []byte, oid asn1.ObjectIdentifier) ([]byte, error) {
if len(message) == 0 {
return nil, errors.New("mldsa: empty message")
}
if len(context) > 255 {
return nil, errors.New("mldsa: context too long")
}
preHashValue, err := preHash(oid, message)
if err != nil {
return nil, err
}
var seed [SeedSize]byte
if _, err := io.ReadFull(rand, seed[:]); err != nil {
return nil, err
}
H := sha3.NewSHAKE256()
H.Write(sk.tr[:])
H.Write([]byte{1, byte(len(context))})
if len(context) > 0 {
H.Write(context)
}
H.Write(preHashValue)
var mu [64]byte
H.Read(mu[:])
return sk.signInternal(seed[:], mu[:])
}
// See FIPS 204, Algorithm 7 ML-DSA.Sign_internal() // See FIPS 204, Algorithm 7 ML-DSA.Sign_internal()
func (sk *PrivateKey44) signInternal(seed, mu []byte) ([]byte, error) { func (sk *PrivateKey44) signInternal(seed, mu []byte) ([]byte, error) {
var rho2 [64 + 2]byte var rho2 [64 + 2]byte
@ -596,9 +614,25 @@ func (sk *PrivateKey44) signInternal(seed, mu []byte) ([]byte, error) {
} }
} }
// Verify checks the validity of a given signature for a message and context // VerifyWithOptions verifies a signature against a message using the public key with additional options.
// using the public key. func (pk *PublicKey44) VerifyWithOptions(sig []byte, message []byte, opts crypto.SignerOpts) bool {
func (pk *PublicKey44) Verify(sig []byte, message, context []byte) bool { var (
context []byte
hashOID asn1.ObjectIdentifier
indicator byte = 0
)
if opts, ok := opts.(*Options); ok {
context = opts.Context
hashOID = opts.PrehashOID
}
if len(hashOID) != 0 {
var err error
if message, err = preHash(hashOID, message); err != nil {
return false
}
indicator = 1
}
if len(message) == 0 { if len(message) == 0 {
return false return false
} }
@ -610,7 +644,7 @@ func (pk *PublicKey44) Verify(sig []byte, message, context []byte) bool {
} }
H := sha3.NewSHAKE256() H := sha3.NewSHAKE256()
H.Write(pk.tr[:]) H.Write(pk.tr[:])
H.Write([]byte{0, byte(len(context))}) H.Write([]byte{indicator, byte(len(context))})
if len(context) > 0 { if len(context) > 0 {
H.Write(context) H.Write(context)
} }
@ -621,35 +655,6 @@ func (pk *PublicKey44) Verify(sig []byte, message, context []byte) bool {
return pk.verifyInternal(sig, mu[:]) return pk.verifyInternal(sig, mu[:])
} }
// VerifyWithPreHash verifies a signature using a message and additional context.
// It uses a given hashing algorithm from the OID to pre-hash the message before verifying.
func (pk *PublicKey44) VerifyWithPreHash(sig []byte, message, context []byte, oid asn1.ObjectIdentifier) bool {
if len(message) == 0 {
return false
}
if len(context) > 255 {
return false
}
if len(sig) != sigEncodedLen44 {
return false
}
preHashValue, err := preHash(oid, message)
if err != nil {
return false
}
H := sha3.NewSHAKE256()
H.Write(pk.tr[:])
H.Write([]byte{1, byte(len(context))})
if len(context) > 0 {
H.Write(context)
}
H.Write(preHashValue)
var mu [64]byte
H.Read(mu[:])
return pk.verifyInternal(sig, mu[:])
}
// See FIPS 204, Algorithm 8 ML-DSA.Verify_internal() // See FIPS 204, Algorithm 8 ML-DSA.Verify_internal()
func (pk *PublicKey44) verifyInternal(sig, mu []byte) bool { func (pk *PublicKey44) verifyInternal(sig, mu []byte) bool {
// Decode the signature // Decode the signature

View File

@ -70,6 +70,10 @@ func TestKeyGen44(t *testing.T) {
if !priv.Equal(priv2) { if !priv.Equal(priv2) {
t.Errorf("Private key not equal: got %x, want %x", privBytes, priv2.Bytes()) t.Errorf("Private key not equal: got %x, want %x", privBytes, priv2.Bytes())
} }
pub3 := priv2.PublicKey()
if !pub.Equal(pub3) {
t.Errorf("Public key from private key not equal")
}
} }
} }
@ -232,7 +236,7 @@ func TestSignWithPreHash44(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("NewPrivateKey44 failed: %v", err) t.Fatalf("NewPrivateKey44 failed: %v", err)
} }
sig2, err := priv.SignWithPreHash(zeroReader, msg, context, c.oid) sig2, err := priv.Sign(zeroReader, msg, &Options{context, c.oid})
if err != nil { if err != nil {
t.Fatalf("failed to sign: %v", err) t.Fatalf("failed to sign: %v", err)
} }
@ -249,7 +253,7 @@ func TestSignWithPreHash44(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("NewPublicKey44 failed: %v", err) t.Fatalf("NewPublicKey44 failed: %v", err)
} }
if !pub.VerifyWithPreHash(sig, msg, context, c.oid) { if !pub.VerifyWithOptions(sig, msg, &Options{context, c.oid}) {
t.Error("signature verification failed") t.Error("signature verification failed")
} }
} }
@ -294,7 +298,7 @@ func TestVerify44(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("NewPublicKey44 failed: %v", err) t.Fatalf("NewPublicKey44 failed: %v", err)
} }
if pub.Verify(sig, msg, ctx) != c.passed { if pub.VerifyWithOptions(sig, msg, &Options{Context: ctx}) != c.passed {
t.Errorf("Verify failed") t.Errorf("Verify failed")
} }
} }
@ -351,10 +355,11 @@ func BenchmarkVerify44(b *testing.B) {
if err != nil { if err != nil {
b.Fatalf("NewPublicKey44 failed: %v", err) b.Fatalf("NewPublicKey44 failed: %v", err)
} }
opts := &Options{Context: ctx}
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
if !pub.Verify(sig, msg, ctx) { if !pub.VerifyWithOptions(sig, msg, opts) {
b.Errorf("Verify failed") b.Errorf("Verify failed")
} }
} }

View File

@ -25,11 +25,26 @@ type PrivateKey65 struct {
s1 [l65]ringElement // private secret of size L with short coefficients (-4..4) or (-2..2) s1 [l65]ringElement // private secret of size L with short coefficients (-4..4) or (-2..2)
s2 [k65]ringElement // private secret of size K with short coefficients (-4..4) or (-2..2) s2 [k65]ringElement // private secret of size K with short coefficients (-4..4) or (-2..2)
t0 [k65]ringElement // the Polynomial encoding of the 13 LSB of each coefficient of the uncompressed public key polynomial t. This is saved as part of the private key. t0 [k65]ringElement // the Polynomial encoding of the 13 LSB of each coefficient of the uncompressed public key polynomial t. This is saved as part of the private key.
t1 [k65]ringElement // the Polynomial encoding of the 10 MSB of each coefficient of the uncompressed public key polynomial t. This is saved as part of the public key.
s1NTTCache [l65]nttElement s1NTTCache [l65]nttElement
s2NTTCache [k65]nttElement s2NTTCache [k65]nttElement
t0NTTCache [k65]nttElement t0NTTCache [k65]nttElement
a [k65 * l65]nttElement // a is generated and stored in NTT representation a [k65 * l65]nttElement // a is generated and stored in NTT representation
nttOnce sync.Once nttOnce sync.Once
t1Once sync.Once
}
// PublicKey returns the public key corresponding to the private key.
// Although we can derive the public key from the private key,
// but we do NOT need to derive it at most of the time.
func (sk *PrivateKey65) PublicKey() crypto.PublicKey {
sk.ensureT1()
return &PublicKey65{
rho: sk.rho,
t1: sk.t1,
tr: sk.tr,
a: sk.a,
}
} }
func (sk *PrivateKey65) ensureNTT() { func (sk *PrivateKey65) ensureNTT() {
@ -46,11 +61,36 @@ func (sk *PrivateKey65) ensureNTT() {
}) })
} }
func (sk *PrivateKey65) ensureT1() {
sk.ensureNTT()
sk.t1Once.Do(func() {
// t = NTT_inv(A' * NTT(s1)) + s2
s1NTT := sk.s1NTTCache
A := sk.a
s2 := sk.s2
var nttT [k65]nttElement
for i := range nttT {
for j := range s1NTT {
nttT[i] = polyAdd(nttT[i], nttMul(s1NTT[j], A[i*l65+j]))
}
}
var t [k65]ringElement
t1 := &sk.t1
for i := range nttT {
t[i] = polyAdd(inverseNTT(nttT[i]), s2[i])
// compress t
for j := range n {
t1[i][j], _ = power2Round(t[i][j])
}
}
})
}
// A Key65 is the key pair for the ML-DSA-65 signature scheme. // A Key65 is the key pair for the ML-DSA-65 signature scheme.
type Key65 struct { type Key65 struct {
PrivateKey65 PrivateKey65
xi [32]byte // input seed xi [32]byte // input seed
t1 [k65]ringElement // the Polynomial encoding of the 10 MSB of each coefficient of the uncompressed public key polynomial t. This is saved as part of the public key.
} }
// A PublicKey65 is the public key for the ML-DSA-65 signature scheme. // A PublicKey65 is the public key for the ML-DSA-65 signature scheme.
@ -74,12 +114,21 @@ func (sk *Key65) PublicKey() *PublicKey65 {
} }
} }
// Seed returns a byte slice of the secret key's seed value.
func (sk *Key65) Seed() []byte {
var b [SeedSize]byte
copy(b[:], sk.xi[:])
return b[:]
}
func (pk *PublicKey65) Equal(x crypto.PublicKey) bool { func (pk *PublicKey65) Equal(x crypto.PublicKey) bool {
xx, ok := x.(*PublicKey65) xx, ok := x.(*PublicKey65)
if !ok { if !ok {
return false return false
} }
return pk.rho == xx.rho && pk.t1 == xx.t1 b1 := pk.Bytes()
b2 := xx.Bytes()
return subtle.ConstantTimeCompare(b1, b2) == 1
} }
// Bytes converts the PublicKey65 instance into a byte slice. // Bytes converts the PublicKey65 instance into a byte slice.
@ -110,15 +159,6 @@ func (pk *PublicKey65) ensureNTT() {
}) })
} }
// Bytes returns the byte representation of the PrivateKey65.
// It copies the internal seed (xi) into a fixed-size byte array
// and returns it as a slice.
func (sk *Key65) Bytes() []byte {
var b [SeedSize]byte
copy(b[:], sk.xi[:])
return b[:]
}
// Bytes converts the PrivateKey65 instance into a byte slice. // Bytes converts the PrivateKey65 instance into a byte slice.
// See FIPS 204, Algorithm 24, skEncode() // See FIPS 204, Algorithm 24, skEncode()
func (sk *PrivateKey65) Bytes() []byte { func (sk *PrivateKey65) Bytes() []byte {
@ -147,8 +187,9 @@ func (sk *PrivateKey65) Equal(x any) bool {
if !ok { if !ok {
return false return false
} }
return sk.rho == xx.rho && sk.k == xx.k && sk.tr == xx.tr && b1 := sk.Bytes()
sk.s1 == xx.s1 && sk.s2 == xx.s2 && sk.t0 == xx.t0 b2 := xx.Bytes()
return subtle.ConstantTimeCompare(b1, b2) == 1
} }
// GenerateKey65 generates a new Key65 (ML-DSA-65) using the provided random source. // GenerateKey65 generates a new Key65 (ML-DSA-65) using the provided random source.
@ -188,8 +229,7 @@ func dsaKeyGen65(sk *Key65, xi *[32]byte) {
sk.xi = *xi sk.xi = *xi
H := sha3.NewSHAKE256() H := sha3.NewSHAKE256()
H.Write(xi[:]) H.Write(xi[:])
H.Write([]byte{k65}) H.Write([]byte{k65, l65})
H.Write([]byte{l65})
K := make([]byte, 128) K := make([]byte, 128)
H.Read(K) H.Read(K)
rho, rho1 := K[:32], K[32:96] rho, rho1 := K[:32], K[32:96]
@ -201,17 +241,17 @@ func dsaKeyGen65(sk *Key65, xi *[32]byte) {
s1 := &sk.s1 s1 := &sk.s1
s2 := &sk.s2 s2 := &sk.s2
// Algorithm 33, ExpandS // Algorithm 33, ExpandS
for s := byte(0); s < l65; s++ { for s := range byte(l65) {
s1[s] = rejBoundedPoly(rho1, eta4, 0, s) s1[s] = rejBoundedPoly(rho1, eta4, 0, s)
} }
for r := byte(0); r < k65; r++ { for r := range byte(k65) {
s2[r] = rejBoundedPoly(rho1, eta4, 0, r+l65) s2[r] = rejBoundedPoly(rho1, eta4, 0, r+l65)
} }
// Using rho generate A' = A in NTT form // Using rho generate A' = A in NTT form
A := &sk.a A := &sk.a
// Algorithm 32, ExpandA // Algorithm 32, ExpandA
for r := byte(0); r < k65; r++ { for r := range byte(k65) {
for s := byte(0); s < l65; s++ { for s := byte(0); s < l65; s++ {
A[r*l65+s] = rejNTTPoly(rho, s, r) A[r*l65+s] = rejNTTPoly(rho, s, r)
} }
@ -272,8 +312,8 @@ func parsePublicKey65(pk *PublicKey65, b []byte) (*PublicKey65, error) {
A := &pk.a A := &pk.a
rho := pk.rho[:] rho := pk.rho[:]
// Algorithm 32, ExpandA // Algorithm 32, ExpandA
for r := byte(0); r < k65; r++ { for r := range byte(k65) {
for s := byte(0); s < l65; s++ { for s := range byte(l65) {
A[r*l65+s] = rejNTTPoly(rho, s, r) A[r*l65+s] = rejNTTPoly(rho, s, r)
} }
} }
@ -321,32 +361,42 @@ func parsePrivateKey65(sk *PrivateKey65, b []byte) (*PrivateKey65, error) {
A := &sk.a A := &sk.a
rho := sk.rho[:] rho := sk.rho[:]
// Algorithm 32, ExpandA // Algorithm 32, ExpandA
for r := byte(0); r < k65; r++ { for r := range byte(k65) {
for s := byte(0); s < l65; s++ { for s := range byte(l65) {
A[r*l65+s] = rejNTTPoly(rho, s, r) A[r*l65+s] = rejNTTPoly(rho, s, r)
} }
} }
return sk, nil return sk, nil
} }
// Sign generates a digital signature for the given message and context using the private key. // Sign signs the provided digest using the private key. It is a wrapper around SignMessage.
// It uses a random seed generated from the provided random source. // It satisfies the crypto.Signer interface.
func (sk *PrivateKey65) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
return sk.SignMessage(rand, digest, opts)
}
// SignMessage signs a message with the private key.
// It satisfies the crypto.MessageSigner interface.
// //
// Parameters: // The function supports pre-hashing the message by providing a hash OID in the options.
// - rand: An io.Reader used to generate a random seed for signing. // Context data can also be provided, but is limited to 255 bytes.
// - message: The message to be signed. Must not be empty. func (sk *PrivateKey65) SignMessage(rand io.Reader, message []byte, opts crypto.SignerOpts) ([]byte, error) {
// - context: An optional context for domain separation. Must not exceed 255 bytes. var (
// context []byte
// Returns: hashOID asn1.ObjectIdentifier
// - A byte slice containing the generated signature. indicator byte = 0
// - An error if the message is empty, the context is too long, or if there is an issue )
// reading from the random source. if opts, ok := opts.(*Options); ok {
// context = opts.Context
// Note: hashOID = opts.PrehashOID
// - The function uses SHAKE256 from the SHA-3 family for hashing. }
// - The signing process involves generating a unique seed and a hash-based if len(hashOID) != 0 {
// message digest (mu) before delegating to the internal signing function. var err error
func (sk *PrivateKey65) Sign(rand io.Reader, message, context []byte) ([]byte, error) { if message, err = preHash(hashOID, message); err != nil {
return nil, err
}
indicator = 1
}
if len(message) == 0 { if len(message) == 0 {
return nil, errors.New("mldsa: empty message") return nil, errors.New("mldsa: empty message")
} }
@ -359,7 +409,7 @@ func (sk *PrivateKey65) Sign(rand io.Reader, message, context []byte) ([]byte, e
} }
H := sha3.NewSHAKE256() H := sha3.NewSHAKE256()
H.Write(sk.tr[:]) H.Write(sk.tr[:])
H.Write([]byte{0, byte(len(context))}) H.Write([]byte{indicator, byte(len(context))})
if len(context) > 0 { if len(context) > 0 {
H.Write(context) H.Write(context)
} }
@ -370,39 +420,6 @@ func (sk *PrivateKey65) Sign(rand io.Reader, message, context []byte) ([]byte, e
return sk.signInternal(seed[:], mu[:]) return sk.signInternal(seed[:], mu[:])
} }
// SignWithPreHash generates a digital signature for the given message
// using the private key and additional context. It uses a given hashing algorithm
// from the OID to pre-hash the message before signing.
// It is similar to Sign but allows for pre-hashing the message.
func (sk *PrivateKey65) SignWithPreHash(rand io.Reader, message, context []byte, oid asn1.ObjectIdentifier) ([]byte, error) {
if len(message) == 0 {
return nil, errors.New("mldsa: empty message")
}
if len(context) > 255 {
return nil, errors.New("mldsa: context too long")
}
preHashValue, err := preHash(oid, message)
if err != nil {
return nil, err
}
var seed [SeedSize]byte
if _, err := io.ReadFull(rand, seed[:]); err != nil {
return nil, err
}
H := sha3.NewSHAKE256()
H.Write(sk.tr[:])
H.Write([]byte{1, byte(len(context))})
if len(context) > 0 {
H.Write(context)
}
H.Write(preHashValue)
var mu [64]byte
H.Read(mu[:])
return sk.signInternal(seed[:], mu[:])
}
// See FIPS 204, Algorithm 7 ML-DSA.Sign_internal() // See FIPS 204, Algorithm 7 ML-DSA.Sign_internal()
func (sk *PrivateKey65) signInternal(seed, mu []byte) ([]byte, error) { func (sk *PrivateKey65) signInternal(seed, mu []byte) ([]byte, error) {
var rho2 [64 + 2]byte var rho2 [64 + 2]byte
@ -418,7 +435,7 @@ func (sk *PrivateKey65) signInternal(seed, mu []byte) ([]byte, error) {
r0NormThreshold := int(gamma2QMinus1Div32 - beta65) r0NormThreshold := int(gamma2QMinus1Div32 - beta65)
// rejection sampling loop // rejection sampling loop
for kappa := 0; ; kappa = kappa + l65 { for kappa := 0; ; kappa += l65 {
// expand mask // expand mask
var ( var (
y [l65]ringElement y [l65]ringElement
@ -513,9 +530,25 @@ func (sk *PrivateKey65) signInternal(seed, mu []byte) ([]byte, error) {
} }
} }
// Verify checks the validity of a given signature for a message and context // VerifyWithOptions verifies a signature against a message using the public key with additional options.
// using the public key. func (pk *PublicKey65) VerifyWithOptions(sig []byte, message []byte, opts crypto.SignerOpts) bool {
func (pk *PublicKey65) Verify(sig []byte, message, context []byte) bool { var (
context []byte
hashOID asn1.ObjectIdentifier
indicator byte = 0
)
if opts, ok := opts.(*Options); ok {
context = opts.Context
hashOID = opts.PrehashOID
}
if len(hashOID) != 0 {
var err error
if message, err = preHash(hashOID, message); err != nil {
return false
}
indicator = 1
}
if len(message) == 0 { if len(message) == 0 {
return false return false
} }
@ -527,7 +560,7 @@ func (pk *PublicKey65) Verify(sig []byte, message, context []byte) bool {
} }
H := sha3.NewSHAKE256() H := sha3.NewSHAKE256()
H.Write(pk.tr[:]) H.Write(pk.tr[:])
H.Write([]byte{0, byte(len(context))}) H.Write([]byte{indicator, byte(len(context))})
if len(context) > 0 { if len(context) > 0 {
H.Write(context) H.Write(context)
} }
@ -538,35 +571,6 @@ func (pk *PublicKey65) Verify(sig []byte, message, context []byte) bool {
return pk.verifyInternal(sig, mu[:]) return pk.verifyInternal(sig, mu[:])
} }
// VerifyWithPreHash verifies a signature using a message and additional context.
// It uses a given hashing algorithm from the OID to pre-hash the message before verifying.
func (pk *PublicKey65) VerifyWithPreHash(sig []byte, message, context []byte, oid asn1.ObjectIdentifier) bool {
if len(message) == 0 {
return false
}
if len(context) > 255 {
return false
}
if len(sig) != sigEncodedLen65 {
return false
}
preHashValue, err := preHash(oid, message)
if err != nil {
return false
}
H := sha3.NewSHAKE256()
H.Write(pk.tr[:])
H.Write([]byte{1, byte(len(context))})
if len(context) > 0 {
H.Write(context)
}
H.Write(preHashValue)
var mu [64]byte
H.Read(mu[:])
return pk.verifyInternal(sig, mu[:])
}
// See FIPS 204, Algorithm 8 ML-DSA.Verify_internal() // See FIPS 204, Algorithm 8 ML-DSA.Verify_internal()
func (pk *PublicKey65) verifyInternal(sig, mu []byte) bool { func (pk *PublicKey65) verifyInternal(sig, mu []byte) bool {
// Decode the signature // Decode the signature
@ -622,5 +626,5 @@ func (pk *PublicKey65) verifyInternal(sig, mu []byte) bool {
var cTilde1 [lambda192 / 4]byte var cTilde1 [lambda192 / 4]byte
H.Read(cTilde1[:]) H.Read(cTilde1[:])
return subtle.ConstantTimeLessOrEq(int(gamma1TwoPower19-beta65), zNorm) == 0 && return subtle.ConstantTimeLessOrEq(int(gamma1TwoPower19-beta65), zNorm) == 0 &&
subtle.ConstantTimeCompare(cTilde[:], cTilde1[:]) == 1 subtle.ConstantTimeCompare(cTilde, cTilde1[:]) == 1
} }

View File

@ -70,6 +70,10 @@ func TestKeyGen65(t *testing.T) {
if !priv.Equal(priv2) { if !priv.Equal(priv2) {
t.Errorf("Private key not equal: got %x, want %x", privBytes, priv2.Bytes()) t.Errorf("Private key not equal: got %x, want %x", privBytes, priv2.Bytes())
} }
pub3 := priv2.PublicKey()
if !pub.Equal(pub3) {
t.Errorf("Public key from private key not equal")
}
} }
} }
@ -232,7 +236,7 @@ func TestSignWithPreHash65(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("NewPrivateKey65 failed: %v", err) t.Fatalf("NewPrivateKey65 failed: %v", err)
} }
sig2, err := priv.SignWithPreHash(zeroReader, msg, context, c.oid) sig2, err := priv.Sign(zeroReader, msg, &Options{context, c.oid})
if err != nil { if err != nil {
t.Fatalf("failed to sign: %v", err) t.Fatalf("failed to sign: %v", err)
} }
@ -249,7 +253,7 @@ func TestSignWithPreHash65(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("NewPublicKey65 failed: %v", err) t.Fatalf("NewPublicKey65 failed: %v", err)
} }
if !pub.VerifyWithPreHash(sig, msg, context, c.oid) { if !pub.VerifyWithOptions(sig, msg, &Options{context, c.oid}) {
t.Error("signature verification failed") t.Error("signature verification failed")
} }
} }
@ -294,7 +298,7 @@ func TestVerify65(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("NewPublicKey65 failed: %v", err) t.Fatalf("NewPublicKey65 failed: %v", err)
} }
if pub.Verify(sig, msg, ctx) != c.passed { if pub.VerifyWithOptions(sig, msg, &Options{Context: ctx}) != c.passed {
t.Errorf("Verify failed") t.Errorf("Verify failed")
} }
} }
@ -341,10 +345,11 @@ func BenchmarkVerify65(b *testing.B) {
if err != nil { if err != nil {
b.Fatalf("NewPublicKey65 failed: %v", err) b.Fatalf("NewPublicKey65 failed: %v", err)
} }
opts := &Options{Context: ctx}
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
if !pub.Verify(sig, msg, ctx) { if !pub.VerifyWithOptions(sig, msg, opts) {
b.Errorf("Verify failed") b.Errorf("Verify failed")
} }
} }

View File

@ -25,11 +25,26 @@ type PrivateKey87 struct {
s1 [l87]ringElement // private secret of size L with short coefficients (-4..4) or (-2..2) s1 [l87]ringElement // private secret of size L with short coefficients (-4..4) or (-2..2)
s2 [k87]ringElement // private secret of size K with short coefficients (-4..4) or (-2..2) s2 [k87]ringElement // private secret of size K with short coefficients (-4..4) or (-2..2)
t0 [k87]ringElement // the Polynomial encoding of the 13 LSB of each coefficient of the uncompressed public key polynomial t. This is saved as part of the private key. t0 [k87]ringElement // the Polynomial encoding of the 13 LSB of each coefficient of the uncompressed public key polynomial t. This is saved as part of the private key.
t1 [k87]ringElement // the Polynomial encoding of the 10 MSB of each coefficient of the uncompressed public key polynomial t. This is saved as part of the public key.
s1NTTCache [l87]nttElement s1NTTCache [l87]nttElement
s2NTTCache [k87]nttElement s2NTTCache [k87]nttElement
t0NTTCache [k87]nttElement t0NTTCache [k87]nttElement
a [k87 * l87]nttElement // a is generated and stored in NTT representation a [k87 * l87]nttElement // a is generated and stored in NTT representation
nttOnce sync.Once nttOnce sync.Once
t1Once sync.Once
}
// PublicKey returns the public key corresponding to the private key.
// Although we can derive the public key from the private key,
// but we do NOT need to derive it at most of the time.
func (sk *PrivateKey87) PublicKey() crypto.PublicKey {
sk.ensureT1()
return &PublicKey87{
rho: sk.rho,
t1: sk.t1,
tr: sk.tr,
a: sk.a,
}
} }
func (sk *PrivateKey87) ensureNTT() { func (sk *PrivateKey87) ensureNTT() {
@ -46,11 +61,36 @@ func (sk *PrivateKey87) ensureNTT() {
}) })
} }
func (sk *PrivateKey87) ensureT1() {
sk.ensureNTT()
sk.t1Once.Do(func() {
// t = NTT_inv(A' * NTT(s1)) + s2
s1NTT := sk.s1NTTCache
A := sk.a
s2 := sk.s2
var nttT [k87]nttElement
for i := range nttT {
for j := range s1NTT {
nttT[i] = polyAdd(nttT[i], nttMul(s1NTT[j], A[i*l87+j]))
}
}
var t [k87]ringElement
t1 := &sk.t1
for i := range nttT {
t[i] = polyAdd(inverseNTT(nttT[i]), s2[i])
// compress t
for j := range n {
t1[i][j], _ = power2Round(t[i][j])
}
}
})
}
// A Key87 is the key pair for the ML-DSA-87 signature scheme. // A Key87 is the key pair for the ML-DSA-87 signature scheme.
type Key87 struct { type Key87 struct {
PrivateKey87 PrivateKey87
xi [32]byte // input seed xi [32]byte // input seed
t1 [k87]ringElement // the Polynomial encoding of the 10 MSB of each coefficient of the uncompressed public key polynomial t. This is saved as part of the public key.
} }
// A PublicKey87 is the public key for the ML-DSA-87 signature scheme. // A PublicKey87 is the public key for the ML-DSA-87 signature scheme.
@ -74,12 +114,21 @@ func (sk *Key87) PublicKey() *PublicKey87 {
} }
} }
// Seed returns a byte slice of the secret key's seed value.
func (sk *Key87) Seed() []byte {
var b [SeedSize]byte
copy(b[:], sk.xi[:])
return b[:]
}
func (pk *PublicKey87) Equal(x crypto.PublicKey) bool { func (pk *PublicKey87) Equal(x crypto.PublicKey) bool {
xx, ok := x.(*PublicKey87) xx, ok := x.(*PublicKey87)
if !ok { if !ok {
return false return false
} }
return pk.rho == xx.rho && pk.t1 == xx.t1 b1 := pk.Bytes()
b2 := xx.Bytes()
return subtle.ConstantTimeCompare(b1, b2) == 1
} }
// Bytes converts the PublicKey87 instance into a byte slice. // Bytes converts the PublicKey87 instance into a byte slice.
@ -110,15 +159,6 @@ func (pk *PublicKey87) ensureNTT() {
}) })
} }
// Bytes returns the byte representation of the PrivateKey87.
// It copies the internal seed (xi) into a fixed-size byte array
// and returns it as a slice.
func (sk *Key87) Bytes() []byte {
var b [SeedSize]byte
copy(b[:], sk.xi[:])
return b[:]
}
// Bytes converts the PrivateKey87 instance into a byte slice. // Bytes converts the PrivateKey87 instance into a byte slice.
// See FIPS 204, Algorithm 24, skEncode() // See FIPS 204, Algorithm 24, skEncode()
func (sk *PrivateKey87) Bytes() []byte { func (sk *PrivateKey87) Bytes() []byte {
@ -147,8 +187,9 @@ func (sk *PrivateKey87) Equal(x any) bool {
if !ok { if !ok {
return false return false
} }
return sk.rho == xx.rho && sk.k == xx.k && sk.tr == xx.tr && b1 := sk.Bytes()
sk.s1 == xx.s1 && sk.s2 == xx.s2 && sk.t0 == xx.t0 b2 := xx.Bytes()
return subtle.ConstantTimeCompare(b1, b2) == 1
} }
// GenerateKey87 generates a new Key87 (ML-DSA-87) using the provided random source. // GenerateKey87 generates a new Key87 (ML-DSA-87) using the provided random source.
@ -188,8 +229,7 @@ func dsaKeyGen87(sk *Key87, xi *[32]byte) {
sk.xi = *xi sk.xi = *xi
H := sha3.NewSHAKE256() H := sha3.NewSHAKE256()
H.Write(xi[:]) H.Write(xi[:])
H.Write([]byte{k87}) H.Write([]byte{k87, l87})
H.Write([]byte{l87})
K := make([]byte, 128) K := make([]byte, 128)
H.Read(K) H.Read(K)
rho, rho1 := K[:32], K[32:96] rho, rho1 := K[:32], K[32:96]
@ -201,17 +241,17 @@ func dsaKeyGen87(sk *Key87, xi *[32]byte) {
s1 := &sk.s1 s1 := &sk.s1
s2 := &sk.s2 s2 := &sk.s2
// Algorithm 33, ExpandS // Algorithm 33, ExpandS
for s := byte(0); s < l87; s++ { for s := range byte(l87) {
s1[s] = rejBoundedPoly(rho1, eta2, 0, s) s1[s] = rejBoundedPoly(rho1, eta2, 0, s)
} }
for r := byte(0); r < k87; r++ { for r := range byte(k87) {
s2[r] = rejBoundedPoly(rho1, eta2, 0, r+l87) s2[r] = rejBoundedPoly(rho1, eta2, 0, r+l87)
} }
// Using rho generate A' = A in NTT form // Using rho generate A' = A in NTT form
A := &sk.a A := &sk.a
// Algorithm 32, ExpandA // Algorithm 32, ExpandA
for r := byte(0); r < k87; r++ { for r := range byte(k87) {
for s := byte(0); s < l87; s++ { for s := byte(0); s < l87; s++ {
A[r*l87+s] = rejNTTPoly(rho, s, r) A[r*l87+s] = rejNTTPoly(rho, s, r)
} }
@ -272,8 +312,8 @@ func parsePublicKey87(pk *PublicKey87, b []byte) (*PublicKey87, error) {
A := &pk.a A := &pk.a
rho := pk.rho[:] rho := pk.rho[:]
// Algorithm 32, ExpandA // Algorithm 32, ExpandA
for r := byte(0); r < k87; r++ { for r := range byte(k87) {
for s := byte(0); s < l87; s++ { for s := range byte(l87) {
A[r*l87+s] = rejNTTPoly(rho, s, r) A[r*l87+s] = rejNTTPoly(rho, s, r)
} }
} }
@ -321,32 +361,42 @@ func parsePrivateKey87(sk *PrivateKey87, b []byte) (*PrivateKey87, error) {
A := &sk.a A := &sk.a
rho := sk.rho[:] rho := sk.rho[:]
// Algorithm 32, ExpandA // Algorithm 32, ExpandA
for r := byte(0); r < k87; r++ { for r := range byte(k87) {
for s := byte(0); s < l87; s++ { for s := range byte(l87) {
A[r*l87+s] = rejNTTPoly(rho, s, r) A[r*l87+s] = rejNTTPoly(rho, s, r)
} }
} }
return sk, nil return sk, nil
} }
// Sign generates a digital signature for the given message and context using the private key. // Sign signs the provided digest using the private key. It is a wrapper around SignMessage.
// It uses a random seed generated from the provided random source. // It satisfies the crypto.Signer interface.
func (sk *PrivateKey87) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
return sk.SignMessage(rand, digest, opts)
}
// SignMessage signs a message with the private key.
// It satisfies the crypto.MessageSigner interface.
// //
// Parameters: // The function supports pre-hashing the message by providing a hash OID in the options.
// - rand: An io.Reader used to generate a random seed for signing. // Context data can also be provided, but is limited to 255 bytes.
// - message: The message to be signed. Must not be empty. func (sk *PrivateKey87) SignMessage(rand io.Reader, message []byte, opts crypto.SignerOpts) ([]byte, error) {
// - context: An optional context for domain separation. Must not exceed 255 bytes. var (
// context []byte
// Returns: hashOID asn1.ObjectIdentifier
// - A byte slice containing the generated signature. indicator byte = 0
// - An error if the message is empty, the context is too long, or if there is an issue )
// reading from the random source. if opts, ok := opts.(*Options); ok {
// context = opts.Context
// Note: hashOID = opts.PrehashOID
// - The function uses SHAKE256 from the SHA-3 family for hashing. }
// - The signing process involves generating a unique seed and a hash-based if len(hashOID) != 0 {
// message digest (mu) before delegating to the internal signing function. var err error
func (sk *PrivateKey87) Sign(rand io.Reader, message, context []byte) ([]byte, error) { if message, err = preHash(hashOID, message); err != nil {
return nil, err
}
indicator = 1
}
if len(message) == 0 { if len(message) == 0 {
return nil, errors.New("mldsa: empty message") return nil, errors.New("mldsa: empty message")
} }
@ -359,7 +409,7 @@ func (sk *PrivateKey87) Sign(rand io.Reader, message, context []byte) ([]byte, e
} }
H := sha3.NewSHAKE256() H := sha3.NewSHAKE256()
H.Write(sk.tr[:]) H.Write(sk.tr[:])
H.Write([]byte{0, byte(len(context))}) H.Write([]byte{indicator, byte(len(context))})
if len(context) > 0 { if len(context) > 0 {
H.Write(context) H.Write(context)
} }
@ -370,39 +420,6 @@ func (sk *PrivateKey87) Sign(rand io.Reader, message, context []byte) ([]byte, e
return sk.signInternal(seed[:], mu[:]) return sk.signInternal(seed[:], mu[:])
} }
// SignWithPreHash generates a digital signature for the given message
// using the private key and additional context. It uses a given hashing algorithm
// from the OID to pre-hash the message before signing.
// It is similar to Sign but allows for pre-hashing the message.
func (sk *PrivateKey87) SignWithPreHash(rand io.Reader, message, context []byte, oid asn1.ObjectIdentifier) ([]byte, error) {
if len(message) == 0 {
return nil, errors.New("mldsa: empty message")
}
if len(context) > 255 {
return nil, errors.New("mldsa: context too long")
}
preHashValue, err := preHash(oid, message)
if err != nil {
return nil, err
}
var seed [SeedSize]byte
if _, err := io.ReadFull(rand, seed[:]); err != nil {
return nil, err
}
H := sha3.NewSHAKE256()
H.Write(sk.tr[:])
H.Write([]byte{1, byte(len(context))})
if len(context) > 0 {
H.Write(context)
}
H.Write(preHashValue)
var mu [64]byte
H.Read(mu[:])
return sk.signInternal(seed[:], mu[:])
}
// See FIPS 204, Algorithm 7 ML-DSA.Sign_internal() // See FIPS 204, Algorithm 7 ML-DSA.Sign_internal()
func (sk *PrivateKey87) signInternal(seed, mu []byte) ([]byte, error) { func (sk *PrivateKey87) signInternal(seed, mu []byte) ([]byte, error) {
var rho2 [64 + 2]byte var rho2 [64 + 2]byte
@ -418,7 +435,7 @@ func (sk *PrivateKey87) signInternal(seed, mu []byte) ([]byte, error) {
r0NormThreshold := int(gamma2QMinus1Div32 - beta87) r0NormThreshold := int(gamma2QMinus1Div32 - beta87)
// rejection sampling loop // rejection sampling loop
for kappa := 0; ; kappa = kappa + l87 { for kappa := 0; ; kappa += l87 {
// expand mask // expand mask
var ( var (
y [l87]ringElement y [l87]ringElement
@ -513,9 +530,25 @@ func (sk *PrivateKey87) signInternal(seed, mu []byte) ([]byte, error) {
} }
} }
// Verify checks the validity of a given signature for a message and context // VerifyWithOptions verifies a signature against a message using the public key with additional options.
// using the public key. func (pk *PublicKey87) VerifyWithOptions(sig []byte, message []byte, opts crypto.SignerOpts) bool {
func (pk *PublicKey87) Verify(sig []byte, message, context []byte) bool { var (
context []byte
hashOID asn1.ObjectIdentifier
indicator byte = 0
)
if opts, ok := opts.(*Options); ok {
context = opts.Context
hashOID = opts.PrehashOID
}
if len(hashOID) != 0 {
var err error
if message, err = preHash(hashOID, message); err != nil {
return false
}
indicator = 1
}
if len(message) == 0 { if len(message) == 0 {
return false return false
} }
@ -527,7 +560,7 @@ func (pk *PublicKey87) Verify(sig []byte, message, context []byte) bool {
} }
H := sha3.NewSHAKE256() H := sha3.NewSHAKE256()
H.Write(pk.tr[:]) H.Write(pk.tr[:])
H.Write([]byte{0, byte(len(context))}) H.Write([]byte{indicator, byte(len(context))})
if len(context) > 0 { if len(context) > 0 {
H.Write(context) H.Write(context)
} }
@ -538,35 +571,6 @@ func (pk *PublicKey87) Verify(sig []byte, message, context []byte) bool {
return pk.verifyInternal(sig, mu[:]) return pk.verifyInternal(sig, mu[:])
} }
// VerifyWithPreHash verifies a signature using a message and additional context.
// It uses a given hashing algorithm from the OID to pre-hash the message before verifying.
func (pk *PublicKey87) VerifyWithPreHash(sig []byte, message, context []byte, oid asn1.ObjectIdentifier) bool {
if len(message) == 0 {
return false
}
if len(context) > 255 {
return false
}
if len(sig) != sigEncodedLen87 {
return false
}
preHashValue, err := preHash(oid, message)
if err != nil {
return false
}
H := sha3.NewSHAKE256()
H.Write(pk.tr[:])
H.Write([]byte{1, byte(len(context))})
if len(context) > 0 {
H.Write(context)
}
H.Write(preHashValue)
var mu [64]byte
H.Read(mu[:])
return pk.verifyInternal(sig, mu[:])
}
// See FIPS 204, Algorithm 8 ML-DSA.Verify_internal() // See FIPS 204, Algorithm 8 ML-DSA.Verify_internal()
func (pk *PublicKey87) verifyInternal(sig, mu []byte) bool { func (pk *PublicKey87) verifyInternal(sig, mu []byte) bool {
// Decode the signature // Decode the signature
@ -622,5 +626,5 @@ func (pk *PublicKey87) verifyInternal(sig, mu []byte) bool {
var cTilde1 [lambda256 / 4]byte var cTilde1 [lambda256 / 4]byte
H.Read(cTilde1[:]) H.Read(cTilde1[:])
return subtle.ConstantTimeLessOrEq(int(gamma1TwoPower19-beta87), zNorm) == 0 && return subtle.ConstantTimeLessOrEq(int(gamma1TwoPower19-beta87), zNorm) == 0 &&
subtle.ConstantTimeCompare(cTilde[:], cTilde1[:]) == 1 subtle.ConstantTimeCompare(cTilde, cTilde1[:]) == 1
} }

View File

@ -70,6 +70,10 @@ func TestKeyGen87(t *testing.T) {
if !priv.Equal(priv2) { if !priv.Equal(priv2) {
t.Errorf("Private key not equal: got %x, want %x", privBytes, priv2.Bytes()) t.Errorf("Private key not equal: got %x, want %x", privBytes, priv2.Bytes())
} }
pub3 := priv2.PublicKey()
if !pub.Equal(pub3) {
t.Errorf("Public key from private key not equal")
}
} }
} }
@ -184,7 +188,7 @@ func TestSignWithPreHash87(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("NewPrivateKey87 failed: %v", err) t.Fatalf("NewPrivateKey87 failed: %v", err)
} }
sig2, err := priv.SignWithPreHash(zeroReader, msg, context, c.oid) sig2, err := priv.Sign(zeroReader, msg, &Options{context, c.oid})
if err != nil { if err != nil {
t.Fatalf("failed to sign: %v", err) t.Fatalf("failed to sign: %v", err)
} }
@ -201,7 +205,7 @@ func TestSignWithPreHash87(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("NewPublicKey87 failed: %v", err) t.Fatalf("NewPublicKey87 failed: %v", err)
} }
if !pub.VerifyWithPreHash(sig, msg, context, c.oid) { if !pub.VerifyWithOptions(sig, msg, &Options{context, c.oid}) {
t.Error("signature verification failed") t.Error("signature verification failed")
} }
} }
@ -254,7 +258,7 @@ func TestVerify87(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("NewPublicKey87 failed: %v", err) t.Fatalf("NewPublicKey87 failed: %v", err)
} }
if pub.Verify(sig, msg, ctx) != c.passed { if pub.VerifyWithOptions(sig, msg, &Options{Context: ctx}) != c.passed {
t.Errorf("Verify failed") t.Errorf("Verify failed")
} }
} }
@ -301,10 +305,11 @@ func BenchmarkVerify87(b *testing.B) {
if err != nil { if err != nil {
b.Fatalf("NewPublicKey87 failed: %v", err) b.Fatalf("NewPublicKey87 failed: %v", err)
} }
opts := &Options{Context: ctx}
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
if !pub.Verify(sig, msg, ctx) { if !pub.VerifyWithOptions(sig, msg, opts) {
b.Errorf("Verify failed") b.Errorf("Verify failed")
} }
} }

View File

@ -7,6 +7,7 @@
package mldsa package mldsa
import ( import (
"crypto"
"crypto/sha256" "crypto/sha256"
"crypto/sha3" "crypto/sha3"
"crypto/sha512" "crypto/sha512"
@ -90,3 +91,12 @@ func preHash(oid asn1.ObjectIdentifier, data []byte) ([]byte, error) {
oidBytes, _ := asn1.Marshal(oid) oidBytes, _ := asn1.Marshal(oid)
return h.Sum(oidBytes), nil return h.Sum(oidBytes), nil
} }
type Options struct {
Context []byte
PrehashOID asn1.ObjectIdentifier
}
func (opts *Options) HashFunc() crypto.Hash {
return crypto.Hash(0)
}