notify/security_psk.go
starainrt 98ef9e7fcc
feat(transport): 完成安全架构拆分并收口 stream/bulk 传输优化
- 新增 managed/external/nested 三种传输保护模式
  - 新增 peer attach 显式认证、抗重放、channel binding 和可选前向保密协商
  - 明确单连接注入与可重拨连接源的语义边界
  - 禁止 ConnectByConn 场景下 dedicated bulk 走 sidecar,auto 模式自动回退 shared
  - 修正 dedicated attach 在 bootstrap/steady profile 切换下的处理逻辑
  - 优化 shared bulk super-batch 与批量 framed write 路径
  - 降低 stream/bulk fast path 的复制和分发损耗
  - 补齐 benchmark、回归测试、运行时快照和 README 文档
2026-04-20 16:35:44 +08:00

644 lines
20 KiB
Go

package notify
import (
"bytes"
"crypto/aes"
"crypto/cipher"
cryptorand "crypto/rand"
"encoding/binary"
"errors"
"log"
"sync"
"sync/atomic"
"b612.me/starcrypto"
)
// Sentinel errors and wire/derivation defaults of the modern PSK profile.
var (
	// errModernPSKSecretEmpty reports an empty shared secret or key (it is
	// also used for a nil codec cache/runtime receiver).
	errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty")
	// errModernPSKPayload reports a sealed payload that is too short or does
	// not start with modernPSKMagic.
	errModernPSKPayload = errors.New("invalid modern psk payload")
	// errModernPSKRequired is returned by validateSecurityConfiguration when
	// the security-ready check is on but no transport key is configured.
	errModernPSKRequired = errors.New("transport security is required: call UseModernPSKClient/UseModernPSKServer, UsePSKOverExternalTransportClient/Server, or set a transport key before Connect/Listen")
	// errModernPSKForwardSecrecyUnsupported is returned when forward secrecy
	// is requested together with external transport protection.
	errModernPSKForwardSecrecyUnsupported = errors.New("forward secrecy is unsupported for external transport protection")

	// modernPSKMagic prefixes every sealed payload.
	modernPSKMagic = []byte("NPS1")
	// defaultModernPSKSalt is the Argon2id salt used when none is supplied.
	defaultModernPSKSalt = []byte("b612.me/notify/psk/aes-gcm/v1")
	// defaultModernPSKAAD is the AES-GCM additional data used when none is supplied.
	defaultModernPSKAAD = []byte("b612.me/notify/psk-envelope/v1")
)

// modernPSKNonceSize is the AES-GCM nonce length stored in every payload
// header: a random prefix followed by an 8-byte big-endian counter.
const modernPSKNonceSize = 12
// Fast-path encoders: each seals a frame under secretKey in one allocation.
type (
	// transportFastStreamEncoder seals a stream data frame.
	transportFastStreamEncoder func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error)
	// transportFastBulkEncoder seals a bulk data frame.
	transportFastBulkEncoder func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error)
	// transportFastPlainEncoder seals plainLen bytes written by fill directly
	// into the output buffer.
	transportFastPlainEncoder func(secretKey []byte, plainLen int, fill func([]byte) error) ([]byte, error)
)

// modernPSKTransportBundle groups the message codecs and fast-path encoders
// that share one codec cache (see buildModernPSKTransportBundle).
type modernPSKTransportBundle struct {
	msgEn            func([]byte, []byte) []byte
	msgDe            func([]byte, []byte) []byte
	fastStreamEncode transportFastStreamEncoder
	fastBulkEncode   transportFastBulkEncoder
	fastPlainEncode  transportFastPlainEncoder
}

// modernPSKPayloadPool recycles payload buffers for the *Pooled seal/open paths.
var modernPSKPayloadPool sync.Pool
// ModernPSKOptions configures the modern PSK transport profile.
//
// The current profile derives the transport key from the shared secret with
// Argon2id and protects payloads with AES-GCM, using a per-codec random nonce
// prefix plus a per-message counter. Unset Salt, AAD and Argon2Params fall
// back to the values of DefaultModernPSKOptions.
type ModernPSKOptions struct {
	// Salt is the Argon2id salt; an empty Salt selects the built-in default.
	Salt []byte
	// AAD is bound into every AES-GCM seal/open. Only a nil AAD selects the
	// built-in default; a non-nil empty slice is used as-is (no AAD).
	AAD []byte
	// Argon2Params is honored only when Time, Memory, Threads and KeyLen are
	// all non-zero; otherwise starcrypto.DefaultArgon2idParams() is used.
	// KeyLen must be a valid AES key size (16, 24 or 32 bytes) or building
	// the codec fails.
	Argon2Params starcrypto.Argon2Params
	// RequireForwardSecrecy is passed to the built-in client/server security
	// configuration (presumably it requires an ephemeral session key to be
	// negotiated; confirm in configure*SecurityProfiles). The
	// UsePSKOverExternalTransport* helpers reject it with an error.
	RequireForwardSecrecy bool
}
// DefaultModernPSKOptions returns the recommended settings for the current
// PSK transport profile: private copies of the built-in salt and AAD, and
// starcrypto's default Argon2id parameters. RequireForwardSecrecy is false.
func DefaultModernPSKOptions() ModernPSKOptions {
	var defaults ModernPSKOptions
	defaults.Salt = bytes.Clone(defaultModernPSKSalt)
	defaults.AAD = bytes.Clone(defaultModernPSKAAD)
	defaults.Argon2Params = starcrypto.DefaultArgon2idParams()
	return defaults
}
// defaultModernPSKCodecs returns the message encoder/decoder pair of a fresh
// transport bundle bound to the default AAD.
func defaultModernPSKCodecs() (func([]byte, []byte) []byte, func([]byte, []byte) []byte) {
	b := defaultModernPSKTransportBundle()
	encode, decode := b.msgEn, b.msgDe
	return encode, decode
}
// defaultModernPSKTransportBundle builds a transport bundle bound to the
// package default AAD.
func defaultModernPSKTransportBundle() modernPSKTransportBundle {
	aad := defaultModernPSKAAD
	return buildModernPSKTransportBundle(aad)
}
// UseModernPSKClient configures a client to use the modern PSK transport
// profile.
//
// It disables the legacy RSA key-exchange path, derives a transport key with
// Argon2id, and switches message protection to AES-GCM. Configure it before
// calling Connect/ConnectTimeout.
//
// Clients other than *ClientCommon can only receive the static PSK codecs
// through their setters. For them a request for forward secrecy
// (opts.RequireForwardSecrecy) is rejected with an error rather than
// silently dropped.
func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
	requireFS := opts != nil && opts.RequireForwardSecrecy
	client, native := c.(*ClientCommon)
	if !native && requireFS {
		// The setter fallback below installs static codecs only; continuing
		// would hand the caller a weaker configuration than requested.
		return errors.New("forward secrecy requires a *ClientCommon client")
	}
	managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
	if err != nil {
		return err
	}
	if native {
		client.configureClientSecurityProfiles(AuthPSK, ProtectionManaged, managed, managed, requireFS)
		return nil
	}
	c.SetSecretKey(managed.secretKey)
	c.SetMsgEn(managed.msgEn)
	c.SetMsgDe(managed.msgDe)
	c.SetSkipExchangeKey(true)
	return nil
}
// UseModernPSKServer configures a server to use the modern PSK transport
// profile for newly accepted connections.
//
// It derives a transport key with Argon2id and switches message protection to
// AES-GCM. Configure it before calling Listen.
//
// Servers other than *ServerCommon can only receive the static PSK codecs
// through their setters. For them a request for forward secrecy
// (opts.RequireForwardSecrecy) is rejected with an error rather than
// silently dropped.
func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
	requireFS := opts != nil && opts.RequireForwardSecrecy
	server, native := s.(*ServerCommon)
	if !native && requireFS {
		// The setter fallback below installs static codecs only; continuing
		// would hand the caller a weaker configuration than requested.
		return errors.New("forward secrecy requires a *ServerCommon server")
	}
	managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
	if err != nil {
		return err
	}
	if native {
		server.configureServerSecurityProfiles(AuthPSK, ProtectionManaged, managed, managed, requireFS)
		return nil
	}
	s.SetSecretKey(managed.secretKey)
	s.SetDefaultCommEncode(managed.msgEn)
	s.SetDefaultCommDecode(managed.msgDe)
	return nil
}
// UsePSKOverExternalTransportClient authenticates bootstrap with PSK and then
// trusts the external channel for steady-state payload protection.
//
// Forward secrecy cannot be combined with this mode and is rejected. Clients
// other than *ClientCommon only receive the PSK bootstrap codecs through
// their setters; no external steady-state profile is installed for them.
func UsePSKOverExternalTransportClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
	if opts != nil && opts.RequireForwardSecrecy {
		return errModernPSKForwardSecrecyUnsupported
	}
	bootstrap, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
	if err != nil {
		return err
	}
	steady := buildExternalProtectionProfile(bootstrap.secretKey)
	client, native := c.(*ClientCommon)
	if !native {
		c.SetSecretKey(bootstrap.secretKey)
		c.SetMsgEn(bootstrap.msgEn)
		c.SetMsgDe(bootstrap.msgDe)
		c.SetSkipExchangeKey(true)
		return nil
	}
	client.configureClientSecurityProfiles(AuthPSK, ProtectionExternal, bootstrap, steady, false)
	return nil
}
// UsePSKOverExternalTransportServer authenticates bootstrap with PSK and then
// trusts the external channel for steady-state payload protection.
//
// Forward secrecy cannot be combined with this mode and is rejected. Servers
// other than *ServerCommon only receive the PSK bootstrap codecs through
// their setters; no external steady-state profile is installed for them.
func UsePSKOverExternalTransportServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
	if opts != nil && opts.RequireForwardSecrecy {
		return errModernPSKForwardSecrecyUnsupported
	}
	bootstrap, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
	if err != nil {
		return err
	}
	steady := buildExternalProtectionProfile(bootstrap.secretKey)
	server, native := s.(*ServerCommon)
	if !native {
		s.SetSecretKey(bootstrap.secretKey)
		s.SetDefaultCommEncode(bootstrap.msgEn)
		s.SetDefaultCommDecode(bootstrap.msgDe)
		return nil
	}
	server.configureServerSecurityProfiles(AuthPSK, ProtectionExternal, bootstrap, steady, false)
	return nil
}
// UseNestedSecurityClient keeps notify transport protection enabled even when
// the physical connection is already protected by an outer trusted channel.
//
// Clients other than *ClientCommon can only receive the static PSK codecs
// through their setters. For them a request for forward secrecy
// (opts.RequireForwardSecrecy) is rejected with an error rather than
// silently dropped.
func UseNestedSecurityClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
	requireFS := opts != nil && opts.RequireForwardSecrecy
	client, native := c.(*ClientCommon)
	if !native && requireFS {
		// The setter fallback below installs static codecs only; continuing
		// would hand the caller a weaker configuration than requested.
		return errors.New("forward secrecy requires a *ClientCommon client")
	}
	managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionNested)
	if err != nil {
		return err
	}
	if native {
		client.configureClientSecurityProfiles(AuthPSK, ProtectionNested, managed, managed, requireFS)
		return nil
	}
	c.SetSecretKey(managed.secretKey)
	c.SetMsgEn(managed.msgEn)
	c.SetMsgDe(managed.msgDe)
	c.SetSkipExchangeKey(true)
	return nil
}
// UseNestedSecurityServer keeps notify transport protection enabled even when
// the physical connection is already protected by an outer trusted channel.
//
// Servers other than *ServerCommon can only receive the static PSK codecs
// through their setters. For them a request for forward secrecy
// (opts.RequireForwardSecrecy) is rejected with an error rather than
// silently dropped.
func UseNestedSecurityServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
	requireFS := opts != nil && opts.RequireForwardSecrecy
	server, native := s.(*ServerCommon)
	if !native && requireFS {
		// The setter fallback below installs static codecs only; continuing
		// would hand the caller a weaker configuration than requested.
		return errors.New("forward secrecy requires a *ServerCommon server")
	}
	managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionNested)
	if err != nil {
		return err
	}
	if native {
		server.configureServerSecurityProfiles(AuthPSK, ProtectionNested, managed, managed, requireFS)
		return nil
	}
	s.SetSecretKey(managed.secretKey)
	s.SetDefaultCommEncode(managed.msgEn)
	s.SetDefaultCommDecode(managed.msgDe)
	return nil
}
// UseLegacySecurityClient restores the legacy RSA key-exchange plus AES-CFB
// transport profile.
//
// For *ClientCommon it clears any configured security profiles, installs a
// managed profile built from the package defaults, disables the
// security-ready check and re-enables key exchange with the default RSA
// public key. Other Client implementations receive the same defaults through
// their setters.
//
// It is kept only as an explicit fallback path for existing deployments.
func UseLegacySecurityClient(c Client) {
	client, native := c.(*ClientCommon)
	if !native {
		c.SetSecretKey(bytes.Clone(defaultAesKey))
		c.SetMsgEn(defaultMsgEn)
		c.SetMsgDe(defaultMsgDe)
		c.SetSkipExchangeKey(false)
		c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey))
		return
	}
	client.clearClientSecurityProfiles()
	legacy := transportProtectionProfile{
		mode:      ProtectionManaged,
		msgEn:     defaultMsgEn,
		msgDe:     defaultMsgDe,
		secretKey: bytes.Clone(defaultAesKey),
	}
	client.setClientTransportProtectionProfile(legacy)
	client.securityReadyCheck = false
	client.skipKeyExchange = false
	client.handshakeRsaPubKey = bytes.Clone(defaultRsaPubKey)
}
// UseLegacySecurityServer restores the legacy RSA key-exchange plus AES-CFB
// transport profile for newly accepted connections.
//
// For *ServerCommon it clears any configured security profiles, installs a
// managed default profile built from the package defaults, disables the
// security-ready check and restores the default RSA private key. Other Server
// implementations receive the same defaults through their setters.
//
// It is kept only as an explicit fallback path for existing deployments.
func UseLegacySecurityServer(s Server) {
	server, native := s.(*ServerCommon)
	if !native {
		s.SetSecretKey(bytes.Clone(defaultAesKey))
		s.SetDefaultCommEncode(defaultMsgEn)
		s.SetDefaultCommDecode(defaultMsgDe)
		s.SetRsaPrivKey(bytes.Clone(defaultRsaKey))
		return
	}
	server.clearServerSecurityProfiles()
	legacy := transportProtectionProfile{
		mode:      ProtectionManaged,
		msgEn:     defaultMsgEn,
		msgDe:     defaultMsgDe,
		secretKey: bytes.Clone(defaultAesKey),
	}
	server.setServerDefaultTransportProtectionProfile(legacy)
	server.securityReadyCheck = false
	server.handshakeRsaKey = bytes.Clone(defaultRsaKey)
}
// deriveModernPSKKey stretches sharedSecret into the transport key with
// Argon2id, using the salt and parameters of the normalized opts. It also
// returns the normalized AAD so that callers bind the same AAD into their
// codecs. An empty sharedSecret yields errModernPSKSecretEmpty.
func deriveModernPSKKey(sharedSecret []byte, opts *ModernPSKOptions) ([]byte, []byte, error) {
	if len(sharedSecret) < 1 {
		return nil, nil, errModernPSKSecretEmpty
	}
	norm := normalizeModernPSKOptions(opts)
	derived, err := starcrypto.DeriveArgon2idKey(string(sharedSecret), norm.Salt, norm.Argon2Params)
	if err != nil {
		return nil, nil, err
	}
	return derived, norm.AAD, nil
}
// deriveModernPSKProtectionProfile derives the PSK transport key and wraps it
// in a protection profile for mode. The profile carries both a transport
// bundle (codecs keyed at call time) and a codec runtime bound to the key.
func deriveModernPSKProtectionProfile(sharedSecret []byte, opts *ModernPSKOptions, mode ProtectionMode) (transportProtectionProfile, error) {
	var none transportProtectionProfile
	key, aad, err := deriveModernPSKKey(sharedSecret, opts)
	if err != nil {
		return none, err
	}
	bundle := buildModernPSKTransportBundle(aad)
	codec, err := newModernPSKCodecRuntime(key, aad)
	if err != nil {
		return none, err
	}
	return newTransportProtectionProfile(mode, bundle, codec, key), nil
}
// buildExternalProtectionProfile returns an external-mode profile that has
// no codec runtime and keeps secretKey for reference.
func buildExternalProtectionProfile(secretKey []byte) transportProtectionProfile {
	bundle := buildExternalTransportBundle()
	return newTransportProtectionProfile(ProtectionExternal, bundle, nil, secretKey)
}
// deriveModernPSKSessionProtectionProfile builds a forward-secret profile
// keyed by sessionKey. It inherits base's mode and AAD (falling back to the
// default AAD when base has no runtime or an empty AAD) and marks the result
// as ECDHE-keyed with the given session ID.
func deriveModernPSKSessionProtectionProfile(base transportProtectionProfile, sessionKey []byte, sessionID []byte) (transportProtectionProfile, error) {
	source := defaultModernPSKAAD
	if rt := base.runtime; rt != nil && len(rt.aad) > 0 {
		source = rt.aad
	}
	aad := bytes.Clone(source)
	codec, err := newModernPSKCodecRuntime(sessionKey, aad)
	if err != nil {
		return transportProtectionProfile{}, err
	}
	session := newTransportProtectionProfile(base.mode, buildModernPSKTransportBundle(aad), codec, sessionKey)
	session.keyMode = peerAttachKeyModeECDHE
	session.sessionID = cloneTransportSessionID(sessionID)
	session.forwardSecrecy = true
	session.forwardSecrecyFallback = false
	return session, nil
}
// normalizeModernPSKOptions returns a self-contained copy of opts in which
// unset fields are replaced by the values of DefaultModernPSKOptions:
//
//   - Salt: an empty salt selects the default salt.
//   - AAD: only a nil AAD selects the default; a non-nil empty slice is kept
//     and means "no AAD".
//   - Argon2Params: used only when Time, Memory, Threads and KeyLen are all
//     non-zero; otherwise the defaults apply.
//
// Slices are cloned so later mutation of opts cannot affect derived keys.
// RequireForwardSecrecy is carried over unchanged, so the normalized copy is
// never weaker than the input. A nil opts yields the defaults.
func normalizeModernPSKOptions(opts *ModernPSKOptions) ModernPSKOptions {
	cfg := DefaultModernPSKOptions()
	if opts == nil {
		return cfg
	}
	if len(opts.Salt) > 0 {
		cfg.Salt = bytes.Clone(opts.Salt)
	}
	if opts.AAD != nil {
		cfg.AAD = bytes.Clone(opts.AAD)
	}
	if p := opts.Argon2Params; p.Time != 0 && p.Memory != 0 && p.Threads != 0 && p.KeyLen != 0 {
		cfg.Argon2Params = p
	}
	cfg.RequireForwardSecrecy = opts.RequireForwardSecrecy
	return cfg
}
// buildModernPSKCodecs returns only the message encoder/decoder pair of a
// transport bundle bound to aad.
func buildModernPSKCodecs(aad []byte) (func([]byte, []byte) []byte, func([]byte, []byte) []byte) {
	b := buildModernPSKTransportBundle(aad)
	encode, decode := b.msgEn, b.msgDe
	return encode, decode
}
// buildModernPSKTransportBundle returns codecs that resolve their AES-GCM
// runtime from the key passed on each call. All closures share one
// single-entry codec cache, which holds a private copy of aad.
//
// msgEn/msgDe have no error return. They log the failure and return nil.
// msgDe also rejects payloads that are too short or lack the magic prefix
// before consulting the cache. The fast-path encoders return errors directly.
func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle {
	cache := newModernPSKCodecCache(bytes.Clone(aad))
	var bundle modernPSKTransportBundle

	bundle.msgEn = func(key []byte, plain []byte) []byte {
		rt, err := cache.runtimeForKey(key)
		if err == nil {
			var sealed []byte
			if sealed, err = rt.sealPlainPayload(plain); err == nil {
				return sealed
			}
		}
		log.Print(err)
		return nil
	}

	bundle.msgDe = func(key []byte, encrypted []byte) []byte {
		headerLen := len(modernPSKMagic) + modernPSKNonceSize
		if len(encrypted) < headerLen || !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) {
			log.Print(errModernPSKPayload)
			return nil
		}
		rt, err := cache.runtimeForKey(key)
		if err == nil {
			var opened []byte
			if opened, err = rt.openPayload(encrypted); err == nil {
				return opened
			}
		}
		log.Print(err)
		return nil
	}

	bundle.fastStreamEncode = func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
		rt, err := cache.runtimeForKey(key)
		if err != nil {
			return nil, err
		}
		return rt.sealStreamFastPayload(dataID, seq, payload)
	}

	bundle.fastBulkEncode = func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
		rt, err := cache.runtimeForKey(key)
		if err != nil {
			return nil, err
		}
		return rt.sealBulkFastPayload(dataID, seq, payload)
	}

	bundle.fastPlainEncode = func(key []byte, plainLen int, fill func([]byte) error) ([]byte, error) {
		rt, err := cache.runtimeForKey(key)
		if err != nil {
			return nil, err
		}
		return rt.sealFilledPayload(plainLen, fill)
	}

	return bundle
}
// validateSecurityConfiguration reports errModernPSKRequired when the
// security-ready check is enabled but no secret key has been configured.
func (c *ClientCommon) validateSecurityConfiguration() error {
	if !c.securityReadyCheck || len(c.SecretKey) > 0 {
		return nil
	}
	return errModernPSKRequired
}
// validateSecurityConfiguration reports errModernPSKRequired when the
// security-ready check is enabled but no secret key has been configured.
func (s *ServerCommon) validateSecurityConfiguration() error {
	if !s.securityReadyCheck || len(s.SecretKey) > 0 {
		return nil
	}
	return errModernPSKRequired
}
// modernPSKCodecCache remembers the codec runtime for the most recently used
// key. Repeated calls with the same key reuse one AES-GCM instance and one
// nonce sequence. A different key replaces the cached entry.
type modernPSKCodecCache struct {
	mu sync.Mutex // guards key and runtime in runtimeForKey
	aad []byte // AAD for every runtime this cache builds; set by newModernPSKCodecCache
	key []byte // private copy of the key the cached runtime was built for
	runtime *modernPSKCodecRuntime // cached runtime for key; nil until first use
}
// modernPSKCodecRuntime seals and opens modern PSK payloads for one key.
// Sealed payload layout: modernPSKMagic | nonce (modernPSKNonceSize bytes) |
// AES-GCM ciphertext with tag.
type modernPSKCodecRuntime struct {
	aead cipher.AEAD // AES-GCM instance built from key
	key []byte // private copy of the key; fork builds a sibling from it
	aad []byte // additional data passed to every Seal/Open
	prefix [modernPSKNonceSize - 8]byte // random per-runtime nonce prefix
	seq atomic.Uint64 // message counter, advanced atomically by nextNonce
}
// newModernPSKCodecCache returns an empty cache that keeps its own copy of aad.
func newModernPSKCodecCache(aad []byte) *modernPSKCodecCache {
	cache := new(modernPSKCodecCache)
	cache.aad = bytes.Clone(aad)
	return cache
}
// runtimeForKey returns the cached runtime when key matches the cached key.
// Otherwise it builds a new runtime, caches it together with a copy of key,
// and returns it. The lookup and replacement happen under c.mu. A nil cache
// yields errModernPSKSecretEmpty.
func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime, error) {
	if c == nil {
		return nil, errModernPSKSecretEmpty
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	if hit := c.runtime; hit != nil && bytes.Equal(c.key, key) {
		return hit, nil
	}
	fresh, err := newModernPSKCodecRuntime(key, c.aad)
	if err != nil {
		return nil, err
	}
	c.key, c.runtime = bytes.Clone(key), fresh
	return fresh, nil
}
// newModernPSKCodecRuntime builds an AES-GCM codec runtime for key. key must
// be a valid AES key (16, 24 or 32 bytes). key and aad are cloned.
//
// Nonces are a random 4-byte prefix followed by a 64-bit counter. Many
// runtimes share one key: every connection using the same PSK-derived key,
// every fork, and every cache refresh. If all counters started at the same
// value, a collision of the 4-byte prefixes would repeat (key, nonce) pairs,
// which is catastrophic for AES-GCM. Such a collision is likely after about
// 2^16 runtimes. The counter's starting value is therefore random as well.
// The wire format is unchanged because the receiver reads the nonce from the
// payload header.
func newModernPSKCodecRuntime(key []byte, aad []byte) (*modernPSKCodecRuntime, error) {
	if len(key) == 0 {
		return nil, errModernPSKSecretEmpty
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}
	runtime := &modernPSKCodecRuntime{
		aead: aead,
		key:  bytes.Clone(key),
		aad:  bytes.Clone(aad),
	}
	// One draw supplies both the nonce prefix and the counter's start value.
	var seed [modernPSKNonceSize]byte
	if _, err := cryptorand.Read(seed[:]); err != nil {
		return nil, err
	}
	n := copy(runtime.prefix[:], seed[:])
	runtime.seq.Store(binary.BigEndian.Uint64(seed[n:]))
	return runtime, nil
}
// fork returns an independent runtime for the same key and AAD, with its own
// random nonce prefix and its own counter. A nil receiver yields
// errModernPSKSecretEmpty.
func (r *modernPSKCodecRuntime) fork() (*modernPSKCodecRuntime, error) {
	if r != nil {
		return newModernPSKCodecRuntime(r.key, r.aad)
	}
	return nil, errModernPSKSecretEmpty
}
// nextNonce returns the next AES-GCM nonce: the runtime's prefix followed by
// the atomically incremented counter in big-endian order. A nil runtime
// yields the all-zero nonce.
func (r *modernPSKCodecRuntime) nextNonce() (nonce [modernPSKNonceSize]byte) {
	if r == nil {
		return nonce
	}
	split := copy(nonce[:], r.prefix[:])
	binary.BigEndian.PutUint64(nonce[split:], r.seq.Add(1))
	return nonce
}
// sealStreamFastPayload seals a stream fast-path frame (header followed by
// payload) directly into the output buffer.
func (r *modernPSKCodecRuntime) sealStreamFastPayload(dataID uint64, seq uint64, payload []byte) ([]byte, error) {
	frameLen := streamFastPayloadHeaderLen + len(payload)
	writeFrame := func(frame []byte) error {
		if err := encodeStreamFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil {
			return err
		}
		copy(frame[streamFastPayloadHeaderLen:], payload)
		return nil
	}
	return r.sealFilledPayload(frameLen, writeFrame)
}
// sealBulkFastPayload seals a bulk fast-path frame (header followed by
// payload) directly into the output buffer. A nil runtime yields
// errTransportPayloadEncryptFailed.
func (r *modernPSKCodecRuntime) sealBulkFastPayload(dataID uint64, seq uint64, payload []byte) ([]byte, error) {
	if r == nil {
		return nil, errTransportPayloadEncryptFailed
	}
	frameLen := bulkFastPayloadHeaderLen + len(payload)
	writeFrame := func(frame []byte) error {
		if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil {
			return err
		}
		copy(frame[bulkFastPayloadHeaderLen:], payload)
		return nil
	}
	return r.sealFilledPayload(frameLen, writeFrame)
}
// sealPlainPayload seals a copy of plain into a newly allocated payload.
func (r *modernPSKCodecRuntime) sealPlainPayload(plain []byte) ([]byte, error) {
	copyPlain := func(dst []byte) error {
		copy(dst, plain)
		return nil
	}
	return r.sealFilledPayload(len(plain), copyPlain)
}
// sealFilledPayload allocates header+plaintext+tag space, lets fill write the
// plaintext in place, and seals it. It fails with
// errTransportPayloadEncryptFailed on a nil runtime or a negative plainLen.
// A nonce is consumed even if fill fails.
func (r *modernPSKCodecRuntime) sealFilledPayload(plainLen int, fill func([]byte) error) ([]byte, error) {
	if r == nil || plainLen < 0 {
		return nil, errTransportPayloadEncryptFailed
	}
	nonce := r.nextNonce()
	hdr := len(modernPSKMagic) + modernPSKNonceSize
	buf := make([]byte, hdr+plainLen+r.aead.Overhead())
	sealed, err := r.sealInto(buf, hdr, nonce, plainLen, fill)
	if err != nil {
		return nil, err
	}
	return buf[:hdr+len(sealed)], nil
}
// sealFilledPayloadPooled behaves like sealFilledPayload but takes its buffer
// from modernPSKPayloadPool. On success it also returns a release func that
// gives the buffer back; the sealed slice must not be used after calling it.
// On error the buffer is returned to the pool immediately.
func (r *modernPSKCodecRuntime) sealFilledPayloadPooled(plainLen int, fill func([]byte) error) ([]byte, func(), error) {
	if r == nil || plainLen < 0 {
		return nil, nil, errTransportPayloadEncryptFailed
	}
	nonce := r.nextNonce()
	hdr := len(modernPSKMagic) + modernPSKNonceSize
	buf := getModernPSKPayloadBuffer(hdr + plainLen + r.aead.Overhead())
	release := func() { putModernPSKPayloadBuffer(buf) }
	sealed, err := r.sealInto(buf, hdr, nonce, plainLen, fill)
	if err != nil {
		release()
		return nil, nil, err
	}
	return buf[:hdr+len(sealed)], release, nil
}
// sealInto writes magic and nonce into out[:headerLen], lets fill write the
// plaintext into out[headerLen:headerLen+plainLen] (when fill is non-nil),
// and seals that region in place. It returns the ciphertext+tag slice, which
// starts at out[headerLen]. out must have room for the tag.
func (r *modernPSKCodecRuntime) sealInto(out []byte, headerLen int, nonce [modernPSKNonceSize]byte, plainLen int, fill func([]byte) error) ([]byte, error) {
	magicLen := len(modernPSKMagic)
	copy(out[:magicLen], modernPSKMagic)
	copy(out[magicLen:headerLen], nonce[:])
	plain := out[headerLen : headerLen+plainLen]
	if fill != nil {
		if err := fill(plain); err != nil {
			return nil, err
		}
	}
	// plain[:0] as dst reuses the plaintext storage for the ciphertext.
	return r.aead.Seal(plain[:0], nonce[:], plain, r.aad), nil
}
// openPayload validates the header and decrypts encrypted into a freshly
// allocated slice; encrypted itself is left untouched. A nil runtime yields
// errTransportPayloadDecryptFailed and a malformed header
// errModernPSKPayload.
func (r *modernPSKCodecRuntime) openPayload(encrypted []byte) ([]byte, error) {
	if r == nil {
		return nil, errTransportPayloadDecryptFailed
	}
	magicLen := len(modernPSKMagic)
	headerLen := magicLen + modernPSKNonceSize
	if len(encrypted) < headerLen || !bytes.Equal(encrypted[:magicLen], modernPSKMagic) {
		return nil, errModernPSKPayload
	}
	nonce, ciphertext := encrypted[magicLen:headerLen], encrypted[headerLen:]
	dst := make([]byte, 0, len(ciphertext))
	return r.aead.Open(dst, nonce, ciphertext, r.aad)
}
// openPayloadPooled decrypts encrypted in place. The returned plaintext
// aliases the caller's buffer. release (which may be nil) frees that buffer:
// on any error it is called here; on success it is handed back so the caller
// invokes it once done with the plaintext.
func (r *modernPSKCodecRuntime) openPayloadPooled(encrypted []byte, release func()) ([]byte, func(), error) {
	fail := func(err error) ([]byte, func(), error) {
		if release != nil {
			release()
		}
		return nil, nil, err
	}
	if r == nil {
		return fail(errTransportPayloadDecryptFailed)
	}
	magicLen := len(modernPSKMagic)
	headerLen := magicLen + modernPSKNonceSize
	if len(encrypted) < headerLen || !bytes.Equal(encrypted[:magicLen], modernPSKMagic) {
		return fail(errModernPSKPayload)
	}
	ciphertext := encrypted[headerLen:]
	plain, err := r.aead.Open(ciphertext[:0], encrypted[magicLen:headerLen], ciphertext, r.aad)
	if err != nil {
		return fail(err)
	}
	return plain, release, nil
}
// openPayloadOwnedPooled decrypts encrypted into a buffer taken from
// modernPSKPayloadPool, leaving encrypted untouched. On success it returns a
// release func that gives the buffer back; the plaintext must not be used
// after calling it. Payloads shorter than header plus tag yield
// errModernPSKPayload.
func (r *modernPSKCodecRuntime) openPayloadOwnedPooled(encrypted []byte) ([]byte, func(), error) {
	if r == nil {
		return nil, nil, errTransportPayloadDecryptFailed
	}
	magicLen := len(modernPSKMagic)
	headerLen := magicLen + modernPSKNonceSize
	if len(encrypted) < headerLen || !bytes.Equal(encrypted[:magicLen], modernPSKMagic) {
		return nil, nil, errModernPSKPayload
	}
	ciphertext := encrypted[headerLen:]
	plainLen := len(ciphertext) - r.aead.Overhead()
	if plainLen < 0 {
		return nil, nil, errModernPSKPayload
	}
	buf := getModernPSKPayloadBuffer(plainLen)
	release := func() { putModernPSKPayloadBuffer(buf) }
	plain, err := r.aead.Open(buf[:0], encrypted[magicLen:headerLen], ciphertext, r.aad)
	if err != nil {
		release()
		return nil, nil, err
	}
	return plain, release, nil
}
// getModernPSKPayloadBuffer returns a slice of length size, reusing a pooled
// buffer when one with enough capacity is available. A pooled buffer that is
// too small is dropped. Non-positive sizes yield nil. Reused buffers may hold
// stale bytes, so callers must overwrite the whole slice.
func getModernPSKPayloadBuffer(size int) []byte {
	if size < 1 {
		return nil
	}
	buf, ok := modernPSKPayloadPool.Get().([]byte)
	if !ok || cap(buf) < size {
		return make([]byte, size)
	}
	return buf[:size]
}
// putModernPSKPayloadBuffer returns buf to the pool, truncated to length 0.
// Zero-capacity buffers and buffers above 32 MiB are discarded so the pool
// does not pin oversized allocations.
func putModernPSKPayloadBuffer(buf []byte) {
	const maxPooledCap = 32 * 1024 * 1024
	if c := cap(buf); c > 0 && c <= maxPooledCap {
		modernPSKPayloadPool.Put(buf[:0])
	}
}