slhdsa: reduce slice make times and supplement comments

This commit is contained in:
Sun Yimin 2025-05-22 14:21:24 +08:00 committed by GitHub
parent 44b9419aa7
commit c467b22fb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 76 additions and 51 deletions

View File

@ -38,17 +38,16 @@ func (sk *PrivateKey) Sign(message, context, addRand []byte) ([]byte, error) {
// See FIPS 205 Algorithm 19 slh_sign_internal
func (sk *PrivateKey) signInternal(msgPrefix, message, addRand []byte) ([]byte, error) {
signatureStart := make([]byte, sk.params.sigLen)
adrs := sk.addressCreator()
signatureHead := make([]byte, sk.params.sigLen)
// generate randomizer
if len(addRand) == 0 {
// substitute addRand with sk.PublicKey.seed for the deterministic variant
addRand = sk.PublicKey.seed[:sk.params.n]
}
sk.h.prfMsg(sk, addRand, msgPrefix, message, signatureStart)
R := signatureStart[:sk.params.n]
signature := signatureStart[sk.params.n:]
sk.h.prfMsg(sk, addRand, msgPrefix, message, signatureHead)
R := signatureHead[:sk.params.n]
signature := signatureHead[sk.params.n:]
// compute message digest
var digest [MAX_M]byte
@ -65,6 +64,9 @@ func (sk *PrivateKey) signInternal(msgPrefix, message, addRand []byte) ([]byte,
remaining = remaining[treeIdxLen:]
leafIdx := uint32(toInt(remaining[:leafIdxLen]) & sk.params.leafIdxMask())
// The address adrs must have the layer address set to zero (since the XMSS tree that signs a FORS key is always at layer 0),
// the tree address set to the index of the WOTS+ key within the XMSS tree that signs the FORS key.
adrs := sk.addressCreator()
adrs.setTreeAddress(treeIdx)
adrs.setTypeAndClear(AddressTypeFORSTree)
adrs.setKeyPairAddress(leafIdx)
@ -77,7 +79,7 @@ func (sk *PrivateKey) signInternal(msgPrefix, message, addRand []byte) ([]byte,
// generate ht signature and append to the SLH-DSA signature
sk.htSign(pkFors[:sk.params.n], treeIdx, leafIdx, signature)
return signatureStart, nil
return signatureHead, nil
}
// Verify verifies a pure SLH-DSA signature for the given message.

View File

@ -6,11 +6,14 @@
package slhdsa
// forsSign generates a FORS signature.
// forsSign generates a FORS signature. It signs a k*a-bits message digest md.
// In addition, it takes PrivateKey.seed and PublicKey.seed from the sk and and an address as input.
// The sigFors is a FORS signature of size n*k*(a+1) as result.
//
// See FIPS 205 Algorithm 16 fors_sign
func (sk *PrivateKey) forsSign(md []byte, adrs adrsOperations, sigFors []byte) {
var indices [MAX_K]uint32
// split md into k a-bits values
// split md into k a-bits values, eatch of which is interpreted as an integer between 0 and 2^a-1.
base2b(md, sk.params.a, indices[:sk.params.k])
twoPowerA := uint32(1 << sk.params.a)
@ -21,20 +24,22 @@ func (sk *PrivateKey) forsSign(md []byte, adrs adrsOperations, sigFors []byte) {
sk.forsGenPrivateKey(nodeID+treeIDTimeTwoPowerA, adrs, sigFors)
sigFors = sigFors[sk.params.n:]
// compute auth path
treeOffset := treeIDTimeTwoPowerA
for layer := range sk.params.a {
for j := range sk.params.a {
s := nodeID ^ 1
sk.forsNode(s+treeOffset, layer, adrs, sigFors)
sk.forsNode(s+treeOffset, j, adrs, sigFors)
nodeID >>= 1
treeOffset >>= 1
sigFors = sigFors[sk.params.n:]
}
treeIDTimeTwoPowerA += twoPowerA
treeIDTimeTwoPowerA += twoPowerA // same as treeIDTimeTwoPowerA = treeID*twoPowerA
}
}
// forsPkFromSig computes a FORS public key from a FORS signature.
//
// See FIPS 205 Algorithm 17 fors_pkFromSig
func (pk *PublicKey) forsPkFromSig(md, signature []byte, adrs adrsOperations, out []byte) []byte {
var indices [MAX_K]uint32
@ -43,7 +48,6 @@ func (pk *PublicKey) forsPkFromSig(md, signature []byte, adrs adrsOperations, ou
twoPowerA := uint32(1 << pk.params.a)
var treeIDTimeTwoPowerA uint32
// TODO: use array to avoid heap allocation?
root := make([]byte, pk.params.n*pk.params.k)
rootPt := root
for treeID := range pk.params.k {
@ -78,20 +82,25 @@ func (pk *PublicKey) forsPkFromSig(md, signature []byte, adrs adrsOperations, ou
forspkADRS.clone(adrs)
forspkADRS.setTypeAndClear(AddressTypeFORSRoots)
forspkADRS.copyKeyPairAddress(adrs)
// compute FORS public key
pk.h.t(pk, forspkADRS, root, out)
clear(root)
return signature
}
// forsNode computes the root of a Merkle subtree of FORS public values.
//
// See FIPS 205 Algorithm 15 fors_node
func (sk *PrivateKey) forsNode(nodeID, layer uint32, adrs adrsOperations, out []byte) {
if layer == 0 {
// If the subtree consists of a signle leaf node, then it simply returns a hash of the node's
// private n-byte string.
sk.forsGenPrivateKey(nodeID, adrs, out)
adrs.setTreeHeight(0)
adrs.setTreeIndex(nodeID)
sk.h.f(&sk.PublicKey, adrs, out, out)
} else {
// otherwise, it computes the roots of the left subtree and right subtree
// and hashs them togeter.
var lnode, rnode [MAX_N]byte
sk.forsNode(nodeID*2, layer-1, adrs, lnode[:])
sk.forsNode(nodeID*2+1, layer-1, adrs, rnode[:])
@ -102,17 +111,19 @@ func (sk *PrivateKey) forsNode(nodeID, layer uint32, adrs adrsOperations, out []
}
// forsGenPrivateKey generates a FORS private key value.
//
// See FIPS 205 Algorithm 14 fors_skGen
func (sk *PrivateKey) forsGenPrivateKey(i uint32, adrs adrsOperations, out []byte) {
func (sk *PrivateKey) forsGenPrivateKey(idx uint32, adrs adrsOperations, out []byte) {
skADRS := sk.addressCreator()
skADRS.clone(adrs)
skADRS.setTypeAndClear(AddressTypeFORSPRF)
skADRS.copyKeyPairAddress(adrs)
skADRS.setTreeIndex(i)
skADRS.setTreeIndex(idx)
sk.h.prf(sk, skADRS, out)
}
// base2b computes the base-2^b representation of the input byte array.
//
// See FIPS 205 Algorithm 4 base_2^b
func base2b(in []byte, base uint32, out []uint32) {
var (

View File

@ -9,6 +9,7 @@ package slhdsa
import "crypto/subtle"
// htSign generates a hypertree signature.
//
// See FIPS 205 Algorithm 12 ht_sign
func (sk *PrivateKey) htSign(pkFors []byte, treeIdx uint64, leafIdx uint32, signature []byte) {
adrs := sk.addressCreator()
@ -19,13 +20,14 @@ func (sk *PrivateKey) htSign(pkFors []byte, treeIdx uint64, leafIdx uint32, sign
var rootBuf [MAX_N]byte
root := rootBuf[:sk.params.n]
copy(root, pkFors)
tmpBuf := make([]byte, sk.params.n*sk.params.len)
for j := range sk.params.d {
adrs.setLayerAddress(j)
adrs.setTreeAddress(treeIdx)
sk.xmssSign(root, leafIdx, adrs, signature)
sk.xmssSign(root, tmpBuf, leafIdx, adrs, signature)
if j < sk.params.d-1 {
sk.xmssPkFromSig(leafIdx, signature, root, adrs, root)
sk.xmssPkFromSig(leafIdx, signature, root, tmpBuf, adrs, root)
// hm least significant bits of treeIdx
leafIdx = uint32(treeIdx & mask)
// remove least significant hm bits from treeIdx
@ -36,6 +38,7 @@ func (sk *PrivateKey) htSign(pkFors []byte, treeIdx uint64, leafIdx uint32, sign
}
// htVerify verifies a hypertree signature.
//
// See FIPS 205 Algorithm 13 ht_verify
func (pk *PublicKey) htVerify(pkFors []byte, signature []byte, treeIdx uint64, leafIdx uint32) bool {
adrs := pk.addressCreator()
@ -46,10 +49,11 @@ func (pk *PublicKey) htVerify(pkFors []byte, signature []byte, treeIdx uint64, l
var rootBuf [MAX_N]byte
root := rootBuf[:pk.params.n]
copy(root, pkFors)
tmpBuf := make([]byte, pk.params.n*pk.params.len)
for j := range pk.params.d {
adrs.setLayerAddress(j)
adrs.setTreeAddress(treeIdx)
pk.xmssPkFromSig(leafIdx, signature, root, adrs, root)
pk.xmssPkFromSig(leafIdx, signature, root, tmpBuf, adrs, root)
// hm least significant bits of treeIdx
leafIdx = uint32(treeIdx & mask)
// remove least significant hm bits from treeIdx

View File

@ -101,10 +101,7 @@ func GenerateKey(rand io.Reader, params *params) (*PrivateKey, error) {
if _, err := io.ReadFull(rand, priv.PublicKey.seed[:params.n]); err != nil {
return nil, err
}
adrs := priv.addressCreator()
adrs.setLayerAddress(params.d - 1)
priv.xmssNode(priv.root[:], 0, params.hm, adrs)
return priv, nil
return generateKeyInernal(priv.seed[:], priv.prf[:], priv.PublicKey.seed[:], params)
}
// NewPrivateKey creates a new PrivateKey instance from the provided priv.seed||priv.prf||pub.seed||pub.root and parameters.
@ -115,17 +112,10 @@ func NewPrivateKey(bytes []byte, params *params) (*PrivateKey, error) {
if len(bytes) != 4*int(params.n) {
return nil, errors.New("slhdsa: invalid key length")
}
priv := &PrivateKey{}
if err := initKey(params, &priv.PublicKey); err != nil {
priv, err := generateKeyInernal(bytes[:params.n], bytes[params.n:2*params.n], bytes[2*params.n:3*params.n], params)
if err != nil {
return nil, err
}
copy(priv.seed[:], bytes[:params.n])
copy(priv.prf[:], bytes[params.n:2*params.n])
copy(priv.PublicKey.seed[:], bytes[2*params.n:3*params.n])
adrs := priv.addressCreator()
adrs.setLayerAddress(params.d - 1)
priv.xmssNode(priv.root[:], 0, params.hm, adrs)
if subtle.ConstantTimeCompare(priv.root[:params.n], bytes[3*params.n:]) != 1 {
return nil, errors.New("slhdsa: invalid key")
}
@ -160,7 +150,8 @@ func generateKeyInernal(skSeed, skPRF, pkSeed []byte, params *params) (*PrivateK
copy(priv.PublicKey.seed[:], pkSeed)
adrs := priv.addressCreator()
adrs.setLayerAddress(params.d - 1)
priv.xmssNode(priv.root[:], 0, params.hm, adrs)
tmpBuf := make([]byte, params.n*params.len)
priv.xmssNode(priv.root[:], tmpBuf, 0, params.hm, adrs)
return priv, nil
}

View File

@ -6,7 +6,9 @@
package slhdsa
// Chaining function used in WOTS
// Chaining function used in WOTS, it takes an n-byte inout and integer start and steps as input
// and returns the result of iterating a hash function F on the inout steps times, starting from start.
//
// See FIPS 205 Algorithm 5 wots_chain
func (pk *PublicKey) wotsChain(inout []byte, start, steps byte, addr adrsOperations) {
for i := start; i < start+steps; i++ {
@ -16,18 +18,20 @@ func (pk *PublicKey) wotsChain(inout []byte, start, steps byte, addr adrsOperati
}
// wotsPkGen generates a WOTS public key.
//
// See FIPS 205 Algorithm 6 wots_pkGen
func (sk *PrivateKey) wotsPkGen(out []byte, addr adrsOperations) {
func (sk *PrivateKey) wotsPkGen(out, tmpBuf []byte, addr adrsOperations) {
skADRS := sk.addressCreator()
skADRS.clone(addr)
skADRS.setTypeAndClear(AddressTypeWOTSPRF)
skADRS.copyKeyPairAddress(addr)
// TODO: use array to avoid heap allocation?
tmpBuf := make([]byte, sk.params.n*sk.params.len)
tmp := tmpBuf
for i := uint32(0); i < sk.params.len; i++ {
// compute [len] public values
for i := range sk.params.len {
// compute secret value for chain i
skADRS.setChainAddress(i)
sk.h.prf(sk, skADRS, tmp)
// compute public value for chain i
addr.setChainAddress(i)
sk.wotsChain(tmp, 0, 15, addr) // w = 16
tmp = tmp[sk.params.n:]
@ -36,11 +40,12 @@ func (sk *PrivateKey) wotsPkGen(out []byte, addr adrsOperations) {
wotspkADRS.clone(addr)
wotspkADRS.setTypeAndClear(AddressTypeWOTSPK)
wotspkADRS.copyKeyPairAddress(addr)
// compress public key
sk.h.t(&sk.PublicKey, wotspkADRS, tmpBuf, out)
clear(tmpBuf)
}
// wotsSign generates a WOTS signature on an n-byte message.
//
// See FIPS 205 Algorithm 10 wots_sign
func (sk *PrivateKey) wotsSign(m []byte, adrs adrsOperations, sigWots []byte) {
var msgAndCsum [MAX_WOTS_LEN]byte
@ -58,6 +63,7 @@ func (sk *PrivateKey) wotsSign(m []byte, adrs adrsOperations, sigWots []byte) {
msgAndCsum[len1+1] = byte(csum>>4) & 0x0F
msgAndCsum[len1+2] = byte(csum) & 0x0F
// copy address to create key generation key address
skADRS := sk.addressCreator()
skADRS.clone(adrs)
skADRS.setTypeAndClear(AddressTypeWOTSPRF)
@ -65,16 +71,19 @@ func (sk *PrivateKey) wotsSign(m []byte, adrs adrsOperations, sigWots []byte) {
for i := range sk.params.len {
skADRS.setChainAddress(i)
// compute chain i secret value
sk.h.prf(sk, skADRS, sigWots)
adrs.setChainAddress(i)
// compute chain i signature value
sk.wotsChain(sigWots, 0, msgAndCsum[i], adrs)
sigWots = sigWots[sk.params.n:]
}
}
// wotsPkFromSig computes a WOTS public key from a message and its signature
//
// See FIPS 205 Algorithm 8 wots_pkFromSig
func (pk *PublicKey) wotsPkFromSig(signature, m []byte, adrs adrsOperations, out []byte) {
func (pk *PublicKey) wotsPkFromSig(signature, m, tmpBuf []byte, adrs adrsOperations, out []byte) {
var msgAndCsum [MAX_WOTS_LEN]byte
// convert message to base w=16
bytes2nibbles(m, msgAndCsum[:])
@ -91,7 +100,6 @@ func (pk *PublicKey) wotsPkFromSig(signature, m []byte, adrs adrsOperations, out
msgAndCsum[len1+1] = byte(csum>>4) & 0x0F
msgAndCsum[len1+2] = byte(csum) & 0x0F
tmpBuf := make([]byte, pk.params.n*pk.params.len)
copy(tmpBuf, signature)
tmp := tmpBuf
for i := range pk.params.len {
@ -99,12 +107,13 @@ func (pk *PublicKey) wotsPkFromSig(signature, m []byte, adrs adrsOperations, out
pk.wotsChain(tmp, msgAndCsum[i], 15-msgAndCsum[i], adrs)
tmp = tmp[pk.params.n:]
}
// copy address to create WOTS+ public key address
wotspkADRS := pk.addressCreator()
wotspkADRS.clone(adrs)
wotspkADRS.setTypeAndClear(AddressTypeWOTSPK)
wotspkADRS.copyKeyPairAddress(adrs)
// compress public key
pk.h.t(pk, wotspkADRS, tmpBuf, out)
clear(tmpBuf)
}
func bytes2nibbles(in, out []byte) {

View File

@ -7,16 +7,19 @@
package slhdsa
// xmssNode computes the root of a Merkle subtree of WOTS public keys.
//
// See FIPS 205 Algorithm 9 xmss_node
func (sk *PrivateKey) xmssNode(out []byte, i, z uint32, adrs adrsOperations) {
func (sk *PrivateKey) xmssNode(out, tmpBuf []byte, i, z uint32, adrs adrsOperations) {
if z == 0 { // height 0
// if the subtree is the root of a subtree, then it simply returns the value of the node's WORTS+ public key
adrs.setTypeAndClear(AddressTypeWOTSHash)
adrs.setKeyPairAddress(i)
sk.wotsPkGen(out, adrs)
sk.wotsPkGen(out, tmpBuf, adrs)
} else {
// otherwise, it computes the root of the subtree by hashing the two child nodes
var lnode, rnode [MAX_N]byte
sk.xmssNode(lnode[:], 2*i, z-1, adrs)
sk.xmssNode(rnode[:], 2*i+1, z-1, adrs)
sk.xmssNode(lnode[:], tmpBuf, 2*i, z-1, adrs)
sk.xmssNode(rnode[:], tmpBuf, 2*i+1, z-1, adrs)
adrs.setTypeAndClear(AddressTypeTree)
adrs.setTreeHeight(z)
adrs.setTreeIndex(i)
@ -24,31 +27,36 @@ func (sk *PrivateKey) xmssNode(out []byte, i, z uint32, adrs adrsOperations) {
}
}
// xmssSign generates an XMSS signature.
// xmssSign generates an XMSS signature on an n-byte message pkFors by
// creating an authentication path and signing pkFors with the appropriate WORTS+ key.
//
// See FIPS 205 Algorithm 10 xmss_sign
func (sk *PrivateKey) xmssSign(pkFors []byte, leafIdx uint32, adrs adrsOperations, signature []byte) {
func (sk *PrivateKey) xmssSign(pkFors, tmpBuf []byte, leafIdx uint32, adrs adrsOperations, signature []byte) {
// build auth path, the auth path consists of the sibling nodes of each node that is on the path from the WOTS+ key used to the root
authStart := sk.params.n * sk.params.len
authPath := signature[authStart:]
leafIdxCopy := leafIdx
for j := range sk.params.hm {
sk.xmssNode(authPath, leafIdx^1, j, adrs)
sk.xmssNode(authPath, tmpBuf, leafIdx^1, j, adrs)
authPath = authPath[sk.params.n:]
leafIdx >>= 1
}
// compute WOTS+ signature
adrs.setTypeAndClear(AddressTypeWOTSHash)
adrs.setKeyPairAddress(leafIdxCopy)
sk.wotsSign(pkFors, adrs, signature)
}
// xmssPkFromSig computes an XMSS public key from an XMSS signature.
//
// See FIPS 205 Algorithm 11 xmss_pkFromSig
func (pk *PublicKey) xmssPkFromSig(leafIdx uint32, signature, m []byte, adrs adrsOperations, out []byte) {
func (pk *PublicKey) xmssPkFromSig(leafIdx uint32, signature, m, tmpBuf []byte, adrs adrsOperations, out []byte) {
// compute WOTS pk from WOTS signature
adrs.setTypeAndClear(AddressTypeWOTSHash)
adrs.setKeyPairAddress(leafIdx)
pk.wotsPkFromSig(signature, m, adrs, out)
pk.wotsPkFromSig(signature, m, tmpBuf, adrs, out)
// compute root from WOTS pk and AUTH
// compute root from WOTS pk and AUTH path
adrs.setTypeAndClear(AddressTypeTree)
signature = signature[pk.params.len*pk.params.n:] // auth path
for k := range pk.params.hm {