// Copyright 2025 Sun Yimin. All rights reserved. // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. //go:build go1.24 package mldsa import ( "crypto/sha256" "crypto/sha3" "crypto/sha512" "encoding/asn1" "errors" "hash" "github.com/emmansun/gmsm/sm3" ) var ( // Digest Algorithms OIDDigestAlgorithmSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 1} OIDDigestAlgorithmSHA512 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 3} OIDDigestAlgorithmSHA3_256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 8} OIDDigestAlgorithmSHA3_384 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 9} OIDDigestAlgorithmSHA3_512 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 10} OIDDigestAlgorithmSHAKE128 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 11} OIDDigestAlgorithmSHAKE256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 12} OIDDigestAlgorithmSM3 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 401} ) var ErrUnsupportedDigestAlgorithm = errors.New("mldsa: unsupported digest algorithm") type xofHashAdapter struct { *sha3.SHAKE size int } func (h *xofHashAdapter) Write(p []byte) (n int, err error) { return h.SHAKE.Write(p) } func (h *xofHashAdapter) Reset() { h.SHAKE.Reset() } func (h *xofHashAdapter) Size() int { return h.size } func (x *xofHashAdapter) BlockSize() int { return x.SHAKE.BlockSize() } func (x *xofHashAdapter) Sum(b []byte) []byte { buf := make([]byte, x.size) x.Read(buf) return append(b, buf...) } func getHashByOID(oid asn1.ObjectIdentifier) (hash.Hash, error) { switch { case oid.Equal(OIDDigestAlgorithmSHA256): return sha256.New(), nil case oid.Equal(OIDDigestAlgorithmSHA512): return sha512.New(), nil case oid.Equal(OIDDigestAlgorithmSHA3_256): return sha3.New256(), nil case oid.Equal(OIDDigestAlgorithmSHA3_384): return sha3.New384(), nil case oid.Equal(OIDDigestAlgorithmSHA3_512): return sha3.New512(), nil case oid.Equal(OIDDigestAlgorithmSHAKE128): return &xofHashAdapter{sha3.NewSHAKE128(), 32}, nil case oid.Equal(OIDDigestAlgorithmSHAKE256): return &xofHashAdapter{sha3.NewSHAKE256(), 64}, nil case oid.Equal(OIDDigestAlgorithmSM3): return sm3.New(), nil default: return nil, ErrUnsupportedDigestAlgorithm } } func preHash(oid asn1.ObjectIdentifier, data []byte) ([]byte, error) { h, err := getHashByOID(oid) if err != nil { return nil, err } h.Write(data) oidBytes, _ := asn1.Marshal(oid) return h.Sum(oidBytes), nil }