mirror of
https://github.com/emmansun/gmsm.git
synced 2025-09-16 03:43:50 +08:00
413 lines
12 KiB
Go
413 lines
12 KiB
Go
![]() |
// Copyright 2025 Sun Yimin. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
// Code generated by generate.go. DO NOT EDIT.
|
|||
|
|
|||
|
//go:build go1.24
|
|||
|
|
|||
|
package mlkem
|
|||
|
|
|||
|
import (
|
|||
|
"bytes"
|
|||
|
"crypto/sha3"
|
|||
|
"crypto/subtle"
|
|||
|
"errors"
|
|||
|
"io"
|
|||
|
)
|
|||
|
|
|||
|
// A DecapsulationKey512 is the secret key used to decapsulate a shared key from a
|
|||
|
// ciphertext. It includes various precomputed values.
|
|||
|
type DecapsulationKey512 struct {
|
|||
|
d [32]byte // decapsulation key seed
|
|||
|
z [32]byte // implicit rejection sampling seed
|
|||
|
|
|||
|
ρ [32]byte // rho, sampleNTT seed for A, stored for the encapsulation key
|
|||
|
h [32]byte // H(ek), stored for ML-KEM.Decaps_internal
|
|||
|
|
|||
|
encryptionKey512
|
|||
|
decryptionKey512
|
|||
|
}
|
|||
|
|
|||
|
// Seed returns the decapsulation key as a 64-byte seed in the "d || z" form.
|
|||
|
//
|
|||
|
// The decapsulation key must be kept secret.
|
|||
|
func (dk *DecapsulationKey512) Seed() []byte {
|
|||
|
var b [SeedSize]byte
|
|||
|
copy(b[:], dk.d[:])
|
|||
|
copy(b[32:], dk.z[:])
|
|||
|
return b[:]
|
|||
|
}
|
|||
|
|
|||
|
// Bytes returns the decapsulation key as a byte slice
|
|||
|
// using the full expanded NIST encoding.
|
|||
|
func (dk *DecapsulationKey512) Bytes() []byte {
|
|||
|
b := make([]byte, 0, DecapsulationKeySize512)
|
|||
|
|
|||
|
// ByteEncode₁₂(s)
|
|||
|
for i := range dk.s {
|
|||
|
b = polyByteEncode(b, dk.s[i])
|
|||
|
}
|
|||
|
|
|||
|
// ByteEncode₁₂(t) || ρ
|
|||
|
for i := range dk.t {
|
|||
|
b = polyByteEncode(b, dk.t[i])
|
|||
|
}
|
|||
|
b = append(b, dk.ρ[:]...)
|
|||
|
|
|||
|
// H(ek) || z
|
|||
|
b = append(b, dk.h[:]...)
|
|||
|
b = append(b, dk.z[:]...)
|
|||
|
|
|||
|
return b
|
|||
|
}
|
|||
|
|
|||
|
// EncapsulationKey returns the public encapsulation key necessary to produce
|
|||
|
// ciphertexts.
|
|||
|
func (dk *DecapsulationKey512) EncapsulationKey() *EncapsulationKey512 {
|
|||
|
return &EncapsulationKey512{
|
|||
|
ρ: dk.ρ,
|
|||
|
h: dk.h,
|
|||
|
encryptionKey512: dk.encryptionKey512,
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// An EncapsulationKey512 is the public key used to produce ciphertexts to be
|
|||
|
// decapsulated by the corresponding [DecapsulationKey512].
|
|||
|
type EncapsulationKey512 struct {
|
|||
|
ρ [32]byte // sampleNTT seed for A
|
|||
|
h [32]byte // H(ek)
|
|||
|
encryptionKey512
|
|||
|
}
|
|||
|
|
|||
|
// Bytes returns the encapsulation key as a byte slice.
|
|||
|
func (ek *EncapsulationKey512) Bytes() []byte {
|
|||
|
// The actual logic is in a separate function to outline this allocation.
|
|||
|
b := make([]byte, 0, EncapsulationKeySize512)
|
|||
|
return ek.bytes(b)
|
|||
|
}
|
|||
|
|
|||
|
func (ek *EncapsulationKey512) bytes(b []byte) []byte {
|
|||
|
for i := range ek.t {
|
|||
|
b = polyByteEncode(b, ek.t[i])
|
|||
|
}
|
|||
|
b = append(b, ek.ρ[:]...)
|
|||
|
return b
|
|||
|
}
|
|||
|
|
|||
|
// encryptionKey512 is the parsed and expanded form of a PKE encryption key.
|
|||
|
type encryptionKey512 struct {
|
|||
|
t [k512]nttElement // ByteDecode₁₂(ek[:384k])
|
|||
|
a [k512 * k512]nttElement // A[i*k+j] = sampleNTT(ρ, j, i)
|
|||
|
}
|
|||
|
|
|||
|
// decryptionKey512 is the parsed and expanded form of a PKE decryption key.
|
|||
|
type decryptionKey512 struct {
|
|||
|
s [k512]nttElement // ByteDecode₁₂(dk[:decryptionKey512Size])
|
|||
|
}
|
|||
|
|
|||
|
// GenerateKey512 generates a new decapsulation key. The decapsulation key must be kept secret.
|
|||
|
// See FIPS 203, Algorithm 19.
|
|||
|
func GenerateKey512(rand io.Reader) (*DecapsulationKey512, error) {
|
|||
|
// The actual logic is in a separate function to outline this allocation.
|
|||
|
dk := &DecapsulationKey512{}
|
|||
|
return generateKey512(dk, rand)
|
|||
|
}
|
|||
|
|
|||
|
func generateKey512(dk *DecapsulationKey512, rand io.Reader) (*DecapsulationKey512, error) {
|
|||
|
var d [32]byte
|
|||
|
if _, err := io.ReadFull(rand, d[:]); err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
var z [32]byte
|
|||
|
if _, err := io.ReadFull(rand, z[:]); err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
kemKeyGen512(dk, &d, &z)
|
|||
|
return dk, nil
|
|||
|
}
|
|||
|
|
|||
|
// NewDecapsulationKeyFromSeed512 parses a decapsulation key from a 64-byte
|
|||
|
// seed in the "d || z" form. The seed must be uniformly random.
|
|||
|
func NewDecapsulationKeyFromSeed512(seed []byte) (*DecapsulationKey512, error) {
|
|||
|
// The actual logic is in a separate function to outline this allocation.
|
|||
|
dk := &DecapsulationKey512{}
|
|||
|
return newKeyFromSeed512(dk, seed)
|
|||
|
}
|
|||
|
|
|||
|
func newKeyFromSeed512(dk *DecapsulationKey512, seed []byte) (*DecapsulationKey512, error) {
|
|||
|
if len(seed) != SeedSize {
|
|||
|
return nil, errors.New("mlkem: invalid seed length")
|
|||
|
}
|
|||
|
d := (*[32]byte)(seed[:32])
|
|||
|
z := (*[32]byte)(seed[32:])
|
|||
|
kemKeyGen512(dk, d, z)
|
|||
|
|
|||
|
return dk, nil
|
|||
|
}
|
|||
|
|
|||
|
// NewDecapsulationKey512 parses a decapsulation key from its expanded NIST format.
|
|||
|
func NewDecapsulationKey512(b []byte) (*DecapsulationKey512, error) {
|
|||
|
if len(b) != DecapsulationKeySize512 {
|
|||
|
return nil, errors.New("mlkem: invalid decapsulation key length")
|
|||
|
}
|
|||
|
|
|||
|
dk := &DecapsulationKey512{}
|
|||
|
for i := range dk.s {
|
|||
|
var err error
|
|||
|
dk.s[i], err = polyByteDecode[nttElement](b[:encodingSize12])
|
|||
|
if err != nil {
|
|||
|
return nil, errors.New("mlkem: invalid secret key encoding")
|
|||
|
}
|
|||
|
b = b[encodingSize12:]
|
|||
|
}
|
|||
|
|
|||
|
ek, err := NewEncapsulationKey512(b[:EncapsulationKeySize512])
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
dk.ρ = ek.ρ
|
|||
|
dk.h = ek.h
|
|||
|
dk.encryptionKey512 = ek.encryptionKey512
|
|||
|
b = b[EncapsulationKeySize512:]
|
|||
|
|
|||
|
if !bytes.Equal(dk.h[:], b[:32]) {
|
|||
|
return nil, errors.New("mlkem: inconsistent H(ek) in encoded bytes")
|
|||
|
}
|
|||
|
|
|||
|
copy(dk.z[:], b[32:])
|
|||
|
|
|||
|
return dk, nil
|
|||
|
}
|
|||
|
|
|||
|
// kemKeyGen512 generates a decapsulation key.
|
|||
|
//
|
|||
|
// It implements ML-KEM.KeyGen_internal according to FIPS 203, Algorithm 16, and
|
|||
|
// K-PKE.KeyGen according to FIPS 203, Algorithm 13. The two are merged to save
|
|||
|
// copies and allocations.
|
|||
|
func kemKeyGen512(dk *DecapsulationKey512, d, z *[32]byte) {
|
|||
|
dk.d = *d
|
|||
|
dk.z = *z
|
|||
|
|
|||
|
g := sha3.New512()
|
|||
|
g.Write(d[:])
|
|||
|
g.Write([]byte{k512}) // Module dimension as a domain separator.
|
|||
|
G := g.Sum(make([]byte, 0, 64))
|
|||
|
ρ, σ := G[:32], G[32:] // rho, sigma
|
|||
|
dk.ρ = [32]byte(ρ)
|
|||
|
|
|||
|
A := &dk.a
|
|||
|
for i := range byte(k512) {
|
|||
|
for j := range byte(k512) {
|
|||
|
A[i*k512+j] = sampleNTT(ρ, j, i)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
var N byte
|
|||
|
s := &dk.s
|
|||
|
for i := range s {
|
|||
|
s[i] = ntt(samplePolyCBD(σ, N, η1_512))
|
|||
|
N++
|
|||
|
}
|
|||
|
e := make([]nttElement, k512)
|
|||
|
for i := range e {
|
|||
|
e[i] = ntt(samplePolyCBD(σ, N, η1_512))
|
|||
|
N++
|
|||
|
}
|
|||
|
|
|||
|
t := &dk.t
|
|||
|
for i := range t { // t = A ◦ s + e
|
|||
|
t[i] = e[i]
|
|||
|
for j := range s {
|
|||
|
t[i] = polyAdd(t[i], nttMul(A[i*k512+j], s[j]))
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
H := sha3.New256()
|
|||
|
ek := dk.EncapsulationKey().Bytes()
|
|||
|
H.Write(ek)
|
|||
|
H.Sum(dk.h[:0])
|
|||
|
}
|
|||
|
|
|||
|
// Encapsulate generates a shared key and an associated ciphertext from an
|
|||
|
// encapsulation key.
|
|||
|
//
|
|||
|
// The shared key must be kept secret. See FIPS 203, Algorithm 20.
|
|||
|
func (ek *EncapsulationKey512) Encapsulate(rand io.Reader) (sharedKey, ciphertext []byte, err error) {
|
|||
|
// The actual logic is in a separate function to outline this allocation.
|
|||
|
var cc [CiphertextSize512]byte
|
|||
|
return ek.encapsulate(&cc, rand)
|
|||
|
}
|
|||
|
|
|||
|
func (ek *EncapsulationKey512) encapsulate(cc *[CiphertextSize512]byte, rand io.Reader) (sharedKey, ciphertext []byte, err error) {
|
|||
|
var m [messageSize]byte
|
|||
|
if _, err := io.ReadFull(rand, m[:]); err != nil {
|
|||
|
return nil, nil, err
|
|||
|
}
|
|||
|
sharedKey, ciphertext = kemEncaps512(cc, ek, &m)
|
|||
|
return sharedKey, ciphertext, nil
|
|||
|
}
|
|||
|
|
|||
|
// EncapsulateInternal is a derandomized version of Encapsulate, exclusively for
|
|||
|
// use in tests.
|
|||
|
func (ek *EncapsulationKey512) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
|
|||
|
cc := &[CiphertextSize512]byte{}
|
|||
|
return kemEncaps512(cc, ek, m)
|
|||
|
}
|
|||
|
|
|||
|
// kemEncaps512 generates a shared key and an associated ciphertext.
|
|||
|
//
|
|||
|
// It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
|
|||
|
func kemEncaps512(cc *[CiphertextSize512]byte, ek *EncapsulationKey512, m *[messageSize]byte) (K, c []byte) {
|
|||
|
g := sha3.New512()
|
|||
|
g.Write(m[:])
|
|||
|
g.Write(ek.h[:])
|
|||
|
G := g.Sum(nil)
|
|||
|
K, r := G[:SharedKeySize], G[SharedKeySize:]
|
|||
|
c = pkeEncrypt512(cc, &ek.encryptionKey512, m, r)
|
|||
|
return K, c
|
|||
|
}
|
|||
|
|
|||
|
// NewEncapsulationKey512 parses an encapsulation key from its encoded form.
|
|||
|
// If the encapsulation key is not valid, NewEncapsulationKey512 returns an error.
|
|||
|
func NewEncapsulationKey512(encapsulationKey []byte) (*EncapsulationKey512, error) {
|
|||
|
// The actual logic is in a separate function to outline this allocation.
|
|||
|
ek := &EncapsulationKey512{}
|
|||
|
return parseEK512(ek, encapsulationKey)
|
|||
|
}
|
|||
|
|
|||
|
// parseEK512 parses an encryption key from its encoded form.
|
|||
|
//
|
|||
|
// It implements the initial stages of K-PKE.Encrypt according to FIPS 203,
|
|||
|
// Algorithm 14.
|
|||
|
func parseEK512(ek *EncapsulationKey512, ekPKE []byte) (*EncapsulationKey512, error) {
|
|||
|
if len(ekPKE) != EncapsulationKeySize512 {
|
|||
|
return nil, errors.New("mlkem: invalid encapsulation key length")
|
|||
|
}
|
|||
|
|
|||
|
h := sha3.New256()
|
|||
|
h.Write(ekPKE)
|
|||
|
h.Sum(ek.h[:0])
|
|||
|
|
|||
|
for i := range ek.t {
|
|||
|
var err error
|
|||
|
ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
ekPKE = ekPKE[encodingSize12:]
|
|||
|
}
|
|||
|
copy(ek.ρ[:], ekPKE)
|
|||
|
|
|||
|
for i := range byte(k512) {
|
|||
|
for j := range byte(k512) {
|
|||
|
ek.a[i*k512+j] = sampleNTT(ek.ρ[:], j, i)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
return ek, nil
|
|||
|
}
|
|||
|
|
|||
|
// pkeEncrypt512 encrypt a plaintext message.
|
|||
|
//
|
|||
|
// It implements K-PKE.Encrypt according to FIPS 203, Algorithm 14, although the
|
|||
|
// computation of t and AT is done in parseEK512.
|
|||
|
func pkeEncrypt512(cc *[CiphertextSize512]byte, ex *encryptionKey512, m *[messageSize]byte, rnd []byte) []byte {
|
|||
|
var N byte
|
|||
|
r, e1 := make([]nttElement, k512), make([]ringElement, k512)
|
|||
|
for i := range r {
|
|||
|
r[i] = ntt(samplePolyCBD(rnd, N, η1_512))
|
|||
|
N++
|
|||
|
}
|
|||
|
for i := range e1 {
|
|||
|
e1[i] = samplePolyCBD(rnd, N, η2_512)
|
|||
|
N++
|
|||
|
}
|
|||
|
e2 := samplePolyCBD(rnd, N, η2_512)
|
|||
|
|
|||
|
u := make([]ringElement, k512) // NTT⁻¹(AT ◦ r) + e1
|
|||
|
for i := range u {
|
|||
|
u[i] = e1[i]
|
|||
|
for j := range r {
|
|||
|
// Note that i and j are inverted, as we need the transposed of A.
|
|||
|
u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k512+i], r[j])))
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
μ := ringDecodeAndDecompress1(m)
|
|||
|
|
|||
|
var vNTT nttElement // t⊺ ◦ r
|
|||
|
for i := range ex.t {
|
|||
|
vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
|
|||
|
}
|
|||
|
v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)
|
|||
|
|
|||
|
c := cc[:0]
|
|||
|
for _, f := range u {
|
|||
|
c = ringCompressAndEncode10(c, f)
|
|||
|
}
|
|||
|
c = ringCompressAndEncode4(c, v)
|
|||
|
|
|||
|
return c
|
|||
|
}
|
|||
|
|
|||
|
// Decapsulate generates a shared key from a ciphertext and a decapsulation key.
|
|||
|
// If the ciphertext is not valid, Decapsulate returns an error.
|
|||
|
//
|
|||
|
// The shared key must be kept secret.
|
|||
|
func (dk *DecapsulationKey512) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
|
|||
|
if len(ciphertext) != CiphertextSize512 {
|
|||
|
return nil, errors.New("mlkem: invalid ciphertext length")
|
|||
|
}
|
|||
|
c := (*[CiphertextSize512]byte)(ciphertext)
|
|||
|
// Note that the hash check (step 3 of the decapsulation input check from
|
|||
|
// FIPS 203, Section 7.3) is foregone as a DecapsulationKey is always
|
|||
|
// validly generated by ML-KEM.KeyGen_internal.
|
|||
|
return kemDecaps512(dk, c), nil
|
|||
|
}
|
|||
|
|
|||
|
// kemDecaps512 produces a shared key from a ciphertext.
|
|||
|
//
|
|||
|
// It implements ML-KEM.Decaps_internal according to FIPS 203, Algorithm 18.
|
|||
|
func kemDecaps512(dk *DecapsulationKey512, c *[CiphertextSize512]byte) (K []byte) {
|
|||
|
m := pkeDecrypt512(&dk.decryptionKey512, c)
|
|||
|
g := sha3.New512()
|
|||
|
g.Write(m[:])
|
|||
|
g.Write(dk.h[:])
|
|||
|
G := g.Sum(make([]byte, 0, 64))
|
|||
|
Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
|
|||
|
J := sha3.NewSHAKE256()
|
|||
|
J.Write(dk.z[:])
|
|||
|
J.Write(c[:])
|
|||
|
Kout := make([]byte, SharedKeySize)
|
|||
|
J.Read(Kout)
|
|||
|
var cc [CiphertextSize512]byte
|
|||
|
c1 := pkeEncrypt512(&cc, &dk.encryptionKey512, (*[32]byte)(m), r)
|
|||
|
|
|||
|
subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
|
|||
|
return Kout
|
|||
|
}
|
|||
|
|
|||
|
// pkeDecrypt512 decrypts a ciphertext.
|
|||
|
//
|
|||
|
// It implements K-PKE.Decrypt according to FIPS 203, Algorithm 15,
|
|||
|
// although s is retained from kemKeyGen512.
|
|||
|
func pkeDecrypt512(dx *decryptionKey512, c *[CiphertextSize512]byte) []byte {
|
|||
|
u := make([]ringElement, k512)
|
|||
|
for i := range u {
|
|||
|
b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)])
|
|||
|
u[i] = ringDecodeAndDecompress10(b)
|
|||
|
}
|
|||
|
|
|||
|
b := (*[encodingSize4]byte)(c[encodingSize10*k512:])
|
|||
|
v := ringDecodeAndDecompress4(b)
|
|||
|
|
|||
|
var mask nttElement // s⊺ ◦ NTT(u)
|
|||
|
for i := range dx.s {
|
|||
|
mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
|
|||
|
}
|
|||
|
w := polySub(v, inverseNTT(mask))
|
|||
|
|
|||
|
return ringCompressAndEncode1(nil, w)
|
|||
|
}
|