notify/security_psk.go
starainrt 09d972c7b7
feat(notify): 重构通信内核并补齐 stream/bulk/record/transfer 能力
- 引入 LogicalConn/TransportConn 分层,ClientConn 保留兼容适配层
  - 新增 Stream、Bulk、RecordStream 三条数据面能力及对应控制路径
  - 完成 transfer/file 传输内核与状态快照、诊断能力
  - 补齐 reconnect、inbound dispatcher、modern psk 等基础模块
  - 增加大规模回归、并发与基准测试覆盖
  - 更新依赖库
2026-04-15 15:24:36 +08:00

382 lines
11 KiB
Go

package notify
import (
"bytes"
"crypto/aes"
"crypto/cipher"
cryptorand "crypto/rand"
"encoding/binary"
"errors"
"log"
"sync"
"sync/atomic"
"b612.me/starcrypto"
)
// Errors and wire-format constants of the modern PSK transport profile.
var (
	// errModernPSKSecretEmpty is returned when the shared secret or the
	// transport key is empty.
	errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty")
	// errModernPSKPayload is logged when an incoming payload is shorter than
	// the header or does not start with modernPSKMagic.
	errModernPSKPayload = errors.New("invalid modern psk payload")
	// errModernPSKRequired is returned by validateSecurityConfiguration when
	// the security-ready check is enabled but no transport key is set.
	errModernPSKRequired = errors.New("modern psk is required: call UseModernPSKClient/UseModernPSKServer or set a transport key before Connect/Listen")

	// modernPSKMagic prefixes every sealed payload.
	modernPSKMagic = []byte("NPS1")
	// defaultModernPSKSalt is the Argon2id salt used when none is configured.
	defaultModernPSKSalt = []byte("b612.me/notify/psk/aes-gcm/v1")
	// defaultModernPSKAAD is the AES-GCM additional data used when none is
	// configured.
	defaultModernPSKAAD = []byte("b612.me/notify/psk-envelope/v1")
)

// modernPSKNonceSize is the length of the AES-GCM nonce carried after the
// magic: a fixed per-runtime prefix followed by a 64-bit big-endian counter.
const modernPSKNonceSize = 12
// Optional fast-path encoders. Each receives the transport key and returns a
// complete sealed payload or an error; in this file they are produced by
// buildModernPSKTransportBundle.
type (
	// transportFastStreamEncoder seals a stream data frame for (dataID, seq, payload).
	transportFastStreamEncoder func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error)
	// transportFastBulkEncoder seals a bulk data frame for (dataID, seq, payload).
	transportFastBulkEncoder func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error)
	// transportFastPlainEncoder seals a plainLen-byte plaintext that fill
	// writes directly into the output buffer.
	transportFastPlainEncoder func(secretKey []byte, plainLen int, fill func([]byte) error) ([]byte, error)
)

// modernPSKTransportBundle groups the message encoder/decoder and the
// fast-path encoders that share one codec configuration.
type modernPSKTransportBundle struct {
	msgEn            func([]byte, []byte) []byte
	msgDe            func([]byte, []byte) []byte
	fastStreamEncode transportFastStreamEncoder
	fastBulkEncode   transportFastBulkEncoder
	fastPlainEncode  transportFastPlainEncoder
}
// ModernPSKOptions configures the modern PSK transport profile.
//
// The profile derives the transport key from the shared secret with Argon2id
// and protects messages with AES-GCM. Each nonce is a per-codec random prefix
// plus a per-message counter. A nil *ModernPSKOptions selects
// DefaultModernPSKOptions. Zero-valued fields fall back to the defaults as
// described on each field.
type ModernPSKOptions struct {
	// Salt is the Argon2id salt. An empty Salt selects the default salt.
	Salt []byte
	// AAD is the AES-GCM additional authenticated data. A nil AAD selects the
	// default. A non-nil empty AAD is used as-is, meaning no additional data.
	AAD []byte
	// Argon2Params are the Argon2id parameters. They replace the defaults
	// from starcrypto.DefaultArgon2idParams only as a whole, and only when
	// Time, Memory, Threads and KeyLen are all non-zero. The derived key is
	// used as an AES key, so KeyLen must be 16, 24 or 32.
	Argon2Params starcrypto.Argon2Params
}
// DefaultModernPSKOptions returns the recommended settings for the modern PSK
// transport profile. These are the built-in salt and AAD, plus starcrypto's
// default Argon2id parameters. The salt and AAD are returned as fresh copies,
// so callers may modify them freely.
func DefaultModernPSKOptions() ModernPSKOptions {
	var opts ModernPSKOptions
	opts.Salt = bytes.Clone(defaultModernPSKSalt)
	opts.AAD = bytes.Clone(defaultModernPSKAAD)
	opts.Argon2Params = starcrypto.DefaultArgon2idParams()
	return opts
}
func defaultModernPSKCodecs() (func([]byte, []byte) []byte, func([]byte, []byte) []byte) {
bundle := defaultModernPSKTransportBundle()
return bundle.msgEn, bundle.msgDe
}
// defaultModernPSKTransportBundle builds a modern PSK transport bundle whose
// payloads are authenticated with the default AAD.
func defaultModernPSKTransportBundle() modernPSKTransportBundle {
	aad := defaultModernPSKAAD
	return buildModernPSKTransportBundle(aad)
}
// UseModernPSKClient switches c to the modern PSK transport profile.
//
// The transport key is derived from sharedSecret with Argon2id. opts may be
// nil to use DefaultModernPSKOptions. Messages are sealed with AES-GCM and the
// legacy RSA key exchange is skipped. When c is a *ClientCommon, its fast
// stream/bulk/plain encoders are installed as well. Call this before
// Connect/ConnectTimeout. If key derivation fails, c is left unchanged and the
// error is returned.
func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
	key, aad, err := deriveModernPSKKey(sharedSecret, opts)
	if err != nil {
		return err
	}
	bundle := buildModernPSKTransportBundle(aad)
	c.SetSecretKey(key)
	c.SetMsgEn(bundle.msgEn)
	c.SetMsgDe(bundle.msgDe)
	if common, isCommon := c.(*ClientCommon); isCommon {
		common.fastStreamEncode, common.fastBulkEncode, common.fastPlainEncode =
			bundle.fastStreamEncode, bundle.fastBulkEncode, bundle.fastPlainEncode
	}
	c.SetSkipExchangeKey(true)
	return nil
}
// UseModernPSKServer switches s to the modern PSK transport profile for
// connections accepted afterwards.
//
// The transport key is derived from sharedSecret with Argon2id. opts may be
// nil to use DefaultModernPSKOptions. The default communication codecs seal
// messages with AES-GCM. When s is a *ServerCommon, its default fast
// stream/bulk/plain encoders are installed as well. Call this before Listen.
// If key derivation fails, s is left unchanged and the error is returned.
func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
	key, aad, err := deriveModernPSKKey(sharedSecret, opts)
	if err != nil {
		return err
	}
	bundle := buildModernPSKTransportBundle(aad)
	s.SetSecretKey(key)
	s.SetDefaultCommEncode(bundle.msgEn)
	s.SetDefaultCommDecode(bundle.msgDe)
	if common, isCommon := s.(*ServerCommon); isCommon {
		common.defaultFastStreamEncode, common.defaultFastBulkEncode, common.defaultFastPlainEncode =
			bundle.fastStreamEncode, bundle.fastBulkEncode, bundle.fastPlainEncode
	}
	return nil
}
// UseLegacySecurityClient puts c back on the legacy transport profile: RSA key
// exchange plus AES-CFB.
//
// It restores copies of the built-in AES key and RSA public key, along with the
// legacy message codecs. It re-enables the key exchange and removes any fast
// encoders. It exists only as an explicit fallback for existing deployments.
func UseLegacySecurityClient(c Client) {
	c.SetSecretKey(bytes.Clone(defaultAesKey))
	c.SetMsgEn(defaultMsgEn)
	c.SetMsgDe(defaultMsgDe)
	if common, isCommon := c.(*ClientCommon); isCommon {
		// Only the modern profile supplies fast encoders; drop any that
		// UseModernPSKClient installed.
		common.fastStreamEncode, common.fastBulkEncode, common.fastPlainEncode = nil, nil, nil
	}
	c.SetSkipExchangeKey(false)
	c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey))
}
// UseLegacySecurityServer puts s back on the legacy transport profile (RSA key
// exchange plus AES-CFB) for connections accepted afterwards.
//
// It restores copies of the built-in AES key and RSA private key, along with
// the legacy default codecs, and removes any default fast encoders. It exists
// only as an explicit fallback for existing deployments.
func UseLegacySecurityServer(s Server) {
	s.SetSecretKey(bytes.Clone(defaultAesKey))
	s.SetDefaultCommEncode(defaultMsgEn)
	s.SetDefaultCommDecode(defaultMsgDe)
	if common, isCommon := s.(*ServerCommon); isCommon {
		// Only the modern profile supplies fast encoders; drop any that
		// UseModernPSKServer installed.
		common.defaultFastStreamEncode, common.defaultFastBulkEncode, common.defaultFastPlainEncode = nil, nil, nil
	}
	s.SetRsaPrivKey(bytes.Clone(defaultRsaKey))
}
// deriveModernPSKKey derives the AES-GCM transport key from sharedSecret and
// resolves the AAD for the modern PSK profile.
//
// sharedSecret must be non-empty, otherwise errModernPSKSecretEmpty is
// returned. opts may be nil to use DefaultModernPSKOptions; partially filled
// options are merged by normalizeModernPSKOptions. The returned key and AAD
// are fresh slices that do not alias the caller's input.
//
// The derived key is checked against the AES key sizes here. A misconfigured
// Argon2Params.KeyLen is therefore reported to the caller of
// UseModernPSKClient/UseModernPSKServer. Otherwise the configuration would be
// accepted, and every later encode/decode would fail inside aes.NewCipher.
func deriveModernPSKKey(sharedSecret []byte, opts *ModernPSKOptions) ([]byte, []byte, error) {
	if len(sharedSecret) == 0 {
		return nil, nil, errModernPSKSecretEmpty
	}
	cfg := normalizeModernPSKOptions(opts)
	key, err := starcrypto.DeriveArgon2idKey(string(sharedSecret), cfg.Salt, cfg.Argon2Params)
	if err != nil {
		return nil, nil, err
	}
	// The key is handed to aes.NewCipher, which only accepts AES-128/192/256.
	switch len(key) {
	case 16, 24, 32:
	default:
		return nil, nil, aes.KeySizeError(len(key))
	}
	return key, cfg.AAD, nil
}
// normalizeModernPSKOptions merges opts over DefaultModernPSKOptions. The
// result never aliases the caller's slices.
//
//   - Salt replaces the default only when it is non-empty.
//   - AAD replaces the default whenever it is non-nil. A non-nil empty AAD
//     therefore yields no additional data.
//   - Argon2Params replace the defaults only as a whole, and only when Time,
//     Memory, Threads and KeyLen are all non-zero.
func normalizeModernPSKOptions(opts *ModernPSKOptions) ModernPSKOptions {
	merged := DefaultModernPSKOptions()
	if opts == nil {
		return merged
	}
	if len(opts.Salt) != 0 {
		merged.Salt = bytes.Clone(opts.Salt)
	}
	if opts.AAD != nil {
		merged.AAD = bytes.Clone(opts.AAD)
	}
	if p := opts.Argon2Params; p.Time != 0 && p.Memory != 0 && p.Threads != 0 && p.KeyLen != 0 {
		merged.Argon2Params = p
	}
	return merged
}
// buildModernPSKCodecs returns only the message encoder and decoder of a
// modern PSK transport bundle authenticated with aad.
func buildModernPSKCodecs(aad []byte) (func([]byte, []byte) []byte, func([]byte, []byte) []byte) {
	b := buildModernPSKTransportBundle(aad)
	en, de := b.msgEn, b.msgDe
	return en, de
}
// buildModernPSKTransportBundle creates the modern PSK codecs for one
// transport configuration.
//
// All returned functions share a single modernPSKCodecCache. They therefore
// reuse one AES-GCM runtime, and one nonce sequence, for as long as they are
// called with the same key. Every sealed payload is laid out as
//
//	modernPSKMagic || nonce (modernPSKNonceSize bytes) || ciphertext+tag
//
// and is authenticated with a private copy of aad. The message codecs
// (msgEn/msgDe) log failures with log.Print and return nil. The fast-path
// encoders return the error instead.
func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle {
	sealAAD := bytes.Clone(aad)
	shared := &modernPSKCodecCache{}

	// fail logs err and yields the nil result the message codecs use to
	// signal failure.
	fail := func(err error) []byte {
		log.Print(err)
		return nil
	}

	var bundle modernPSKTransportBundle
	bundle.msgEn = func(key []byte, plain []byte) []byte {
		rt, err := shared.runtimeForKey(key)
		if err != nil {
			return fail(err)
		}
		sealed, err := rt.sealPlainPayload(sealAAD, plain)
		if err != nil {
			return fail(err)
		}
		return sealed
	}
	bundle.msgDe = func(key []byte, encrypted []byte) []byte {
		magicLen := len(modernPSKMagic)
		headerLen := magicLen + modernPSKNonceSize
		// The length check short-circuits before the magic slice is taken.
		if len(encrypted) < headerLen || !bytes.Equal(encrypted[:magicLen], modernPSKMagic) {
			return fail(errModernPSKPayload)
		}
		rt, err := shared.runtimeForKey(key)
		if err != nil {
			return fail(err)
		}
		nonce, ciphertext := encrypted[magicLen:headerLen], encrypted[headerLen:]
		plain, err := rt.aead.Open(make([]byte, 0, len(ciphertext)), nonce, ciphertext, sealAAD)
		if err != nil {
			return fail(err)
		}
		return plain
	}
	bundle.fastStreamEncode = func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
		rt, err := shared.runtimeForKey(key)
		if err != nil {
			return nil, err
		}
		return rt.sealStreamFastPayload(sealAAD, dataID, seq, payload)
	}
	bundle.fastBulkEncode = func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
		rt, err := shared.runtimeForKey(key)
		if err != nil {
			return nil, err
		}
		return rt.sealBulkFastPayload(sealAAD, dataID, seq, payload)
	}
	bundle.fastPlainEncode = func(key []byte, plainLen int, fill func([]byte) error) ([]byte, error) {
		rt, err := shared.runtimeForKey(key)
		if err != nil {
			return nil, err
		}
		return rt.sealFilledPayload(sealAAD, plainLen, fill)
	}
	return bundle
}
// validateSecurityConfiguration reports errModernPSKRequired when the
// security-ready check is enabled but the client has no transport key.
func (c *ClientCommon) validateSecurityConfiguration() error {
	if !c.securityReadyCheck || len(c.SecretKey) != 0 {
		return nil
	}
	return errModernPSKRequired
}
// validateSecurityConfiguration reports errModernPSKRequired when the
// security-ready check is enabled but the server has no transport key.
func (s *ServerCommon) validateSecurityConfiguration() error {
	if !s.securityReadyCheck || len(s.SecretKey) != 0 {
		return nil
	}
	return errModernPSKRequired
}
type (
	// modernPSKCodecCache remembers the codec runtime for the most recently
	// used key. mu guards key and runtime; a different key replaces both.
	modernPSKCodecCache struct {
		mu      sync.Mutex
		key     []byte
		runtime *modernPSKCodecRuntime
	}

	// modernPSKCodecRuntime holds the AES-GCM instance for one key together
	// with its nonce state. The state is a random prefix chosen at construction
	// and a per-message counter that is incremented atomically.
	modernPSKCodecRuntime struct {
		aead   cipher.AEAD
		prefix [modernPSKNonceSize - 8]byte
		seq    atomic.Uint64
	}
)
// runtimeForKey returns the cached runtime when key equals the cached key.
// Otherwise it builds a fresh runtime with new nonce state. It caches that
// runtime together with a private copy of key, and returns it. A nil cache
// reports errModernPSKSecretEmpty. Calls are serialized by c.mu.
func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime, error) {
	if c == nil {
		return nil, errModernPSKSecretEmpty
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	if cached := c.runtime; cached != nil && bytes.Equal(c.key, key) {
		return cached, nil
	}
	fresh, err := newModernPSKCodecRuntime(key)
	if err != nil {
		return nil, err
	}
	c.key, c.runtime = bytes.Clone(key), fresh
	return fresh, nil
}
// newModernPSKCodecRuntime builds the AES-GCM runtime for key and seeds its
// nonce state from crypto/rand. An empty key yields errModernPSKSecretEmpty.
// Invalid AES key sizes yield the error from aes.NewCipher.
//
// Each nonce is a prefix of modernPSKNonceSize-8 random bytes, fixed for the
// runtime, followed by a 64-bit counter. nextNonce increments that counter
// before every seal.
//
// Many runtimes encrypt under the same key. Every peer derives the same key
// from the shared secret, and modernPSKCodecCache builds a new runtime
// whenever the key changes. If two runtimes drew the same 32-bit prefix and
// both counted up from zero, they would emit identical nonces from their
// first message on. AES-GCM nonce reuse under one key breaks both
// confidentiality and authentication. The counter therefore starts at a
// random 64-bit offset, so a collision would additionally need the two
// counter windows to overlap, which is negligible.
//
// Receivers read the nonce from the payload, so the wire format is unchanged.
func newModernPSKCodecRuntime(key []byte) (*modernPSKCodecRuntime, error) {
	if len(key) == 0 {
		return nil, errModernPSKSecretEmpty
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}
	runtime := &modernPSKCodecRuntime{aead: aead}
	// One read supplies both the fixed prefix and the counter's start value.
	var seed [modernPSKNonceSize]byte
	if _, err := cryptorand.Read(seed[:]); err != nil {
		return nil, err
	}
	n := copy(runtime.prefix[:], seed[:])
	runtime.seq.Store(binary.BigEndian.Uint64(seed[n:]))
	return runtime, nil
}
// nextNonce returns the next AES-GCM nonce: the runtime's fixed prefix
// followed by the big-endian value of the incremented sequence counter. The
// counter is bumped atomically, so concurrent callers get distinct nonces. A
// nil runtime yields an all-zero nonce.
func (r *modernPSKCodecRuntime) nextNonce() [modernPSKNonceSize]byte {
	var n [modernPSKNonceSize]byte
	if r != nil {
		p := copy(n[:], r.prefix[:])
		binary.BigEndian.PutUint64(n[p:], r.seq.Add(1))
	}
	return n
}
// sealStreamFastPayload seals one stream data frame as a single AES-GCM
// message. The frame is the stream fast header for (dataID, seq, len(payload))
// followed by payload.
func (r *modernPSKCodecRuntime) sealStreamFastPayload(aad []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
	frameLen := streamFastPayloadHeaderLen + len(payload)
	writeFrame := func(frame []byte) error {
		err := encodeStreamFastDataFrameHeader(frame, dataID, seq, len(payload))
		if err == nil {
			copy(frame[streamFastPayloadHeaderLen:], payload)
		}
		return err
	}
	return r.sealFilledPayload(aad, frameLen, writeFrame)
}
// sealBulkFastPayload seals one bulk data frame as a single AES-GCM message.
// The frame is the bulk fast header for (dataID, seq, len(payload)) followed
// by payload. A nil runtime reports errTransportPayloadEncryptFailed.
func (r *modernPSKCodecRuntime) sealBulkFastPayload(aad []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
	if r == nil {
		return nil, errTransportPayloadEncryptFailed
	}
	frameLen := bulkFastPayloadHeaderLen + len(payload)
	writeFrame := func(frame []byte) error {
		err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload))
		if err == nil {
			copy(frame[bulkFastPayloadHeaderLen:], payload)
		}
		return err
	}
	return r.sealFilledPayload(aad, frameLen, writeFrame)
}
// sealPlainPayload seals a copy of plain. The caller's slice is never
// modified.
func (r *modernPSKCodecRuntime) sealPlainPayload(aad []byte, plain []byte) ([]byte, error) {
	copyPlain := func(dst []byte) error {
		copy(dst, plain)
		return nil
	}
	return r.sealFilledPayload(aad, len(plain), copyPlain)
}
// sealFilledPayload allocates one output buffer laid out as
// magic || nonce || ciphertext+tag. It lets fill write the plainLen-byte
// plaintext directly into the ciphertext region, then encrypts that region
// in place.
//
// A nil runtime or a negative plainLen reports
// errTransportPayloadEncryptFailed. An error from fill is returned unchanged;
// the nonce has already been consumed at that point. A nil fill leaves the
// plaintext zeroed.
func (r *modernPSKCodecRuntime) sealFilledPayload(aad []byte, plainLen int, fill func([]byte) error) ([]byte, error) {
	if r == nil || plainLen < 0 {
		return nil, errTransportPayloadEncryptFailed
	}
	nonce := r.nextNonce()
	magicLen := len(modernPSKMagic)
	bodyStart := magicLen + modernPSKNonceSize
	buf := make([]byte, bodyStart+plainLen+r.aead.Overhead())
	copy(buf, modernPSKMagic)
	copy(buf[magicLen:bodyStart], nonce[:])
	plaintext := buf[bodyStart : bodyStart+plainLen]
	if fill != nil {
		if err := fill(plaintext); err != nil {
			return nil, err
		}
	}
	// plaintext's capacity runs to the end of buf, leaving exactly Overhead()
	// spare bytes, so Seal writes ciphertext and tag in place without
	// reallocating.
	ciphertext := r.aead.Seal(plaintext[:0], nonce[:], plaintext, aad)
	return buf[:bodyStart+len(ciphertext)], nil
}