- 引入 LogicalConn/TransportConn 分层,ClientConn 保留兼容适配层 - 新增 Stream、Bulk、RecordStream 三条数据面能力及对应控制路径 - 完成 transfer/file 传输内核与状态快照、诊断能力 - 补齐 reconnect、inbound dispatcher、modern psk 等基础模块 - 增加大规模回归、并发与基准测试覆盖 - 更新依赖库
382 lines
11 KiB
Go
382 lines
11 KiB
Go
package notify
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
cryptorand "crypto/rand"
|
|
"encoding/binary"
|
|
"errors"
|
|
"log"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"b612.me/starcrypto"
|
|
)
|
|
|
|
// Sentinel errors used by the modern PSK transport profile.
var (
	// errModernPSKSecretEmpty reports that an empty shared secret or
	// transport key was supplied.
	errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty")

	// errModernPSKPayload reports an envelope that is too short or does not
	// start with the expected magic bytes.
	errModernPSKPayload = errors.New("invalid modern psk payload")

	// errModernPSKRequired is returned by the security configuration checks
	// when no secret key has been configured.
	errModernPSKRequired = errors.New("modern psk is required: call UseModernPSKClient/UseModernPSKServer or set a transport key before Connect/Listen")
)
|
|
|
|
// Byte constants of the modern PSK envelope and key derivation defaults.
var (
	// modernPSKMagic is the marker written at the start of every sealed
	// envelope and checked before decryption.
	modernPSKMagic = []byte("NPS1")

	// defaultModernPSKSalt is the key-derivation salt used when
	// ModernPSKOptions.Salt is empty.
	defaultModernPSKSalt = []byte("b612.me/notify/psk/aes-gcm/v1")

	// defaultModernPSKAAD is the additional authenticated data used when
	// ModernPSKOptions.AAD is nil.
	defaultModernPSKAAD = []byte("b612.me/notify/psk-envelope/v1")
)
|
|
|
|
// modernPSKNonceSize is the length in bytes of the nonce stored in each
// envelope right after the magic marker. The last 8 bytes hold a big-endian
// message counter; the remaining leading bytes are a per-runtime prefix.
const modernPSKNonceSize = 12
|
|
|
|
// Fast-path encoder signatures carried by a transport bundle.
type (
	// transportFastStreamEncoder builds and seals one stream data frame for
	// the given data ID and sequence number.
	transportFastStreamEncoder func(secretKey []byte, dataID, seq uint64, payload []byte) ([]byte, error)

	// transportFastBulkEncoder builds and seals one bulk data frame for the
	// given data ID and sequence number.
	transportFastBulkEncoder func(secretKey []byte, dataID, seq uint64, payload []byte) ([]byte, error)

	// transportFastPlainEncoder seals a plaintext of plainLen bytes that the
	// fill callback writes into the buffer it is handed.
	transportFastPlainEncoder func(secretKey []byte, plainLen int, fill func([]byte) error) ([]byte, error)
)
|
|
|
|
type modernPSKTransportBundle struct {
|
|
msgEn func([]byte, []byte) []byte
|
|
msgDe func([]byte, []byte) []byte
|
|
fastStreamEncode transportFastStreamEncoder
|
|
fastBulkEncode transportFastBulkEncoder
|
|
fastPlainEncode transportFastPlainEncoder
|
|
}
|
|
|
|
// ModernPSKOptions configures the modern PSK transport profile.
|
|
//
|
|
// The current profile derives a 32-byte transport key with Argon2id and uses
|
|
// AES-GCM with a per-codec nonce prefix plus a per-message counter.
|
|
type ModernPSKOptions struct {
|
|
Salt []byte
|
|
AAD []byte
|
|
Argon2Params starcrypto.Argon2Params
|
|
}
|
|
|
|
// DefaultModernPSKOptions returns the recommended settings for the current
|
|
// PSK transport profile.
|
|
func DefaultModernPSKOptions() ModernPSKOptions {
|
|
return ModernPSKOptions{
|
|
Salt: bytes.Clone(defaultModernPSKSalt),
|
|
AAD: bytes.Clone(defaultModernPSKAAD),
|
|
Argon2Params: starcrypto.DefaultArgon2idParams(),
|
|
}
|
|
}
|
|
|
|
func defaultModernPSKCodecs() (func([]byte, []byte) []byte, func([]byte, []byte) []byte) {
|
|
bundle := defaultModernPSKTransportBundle()
|
|
return bundle.msgEn, bundle.msgDe
|
|
}
|
|
|
|
func defaultModernPSKTransportBundle() modernPSKTransportBundle {
|
|
return buildModernPSKTransportBundle(defaultModernPSKAAD)
|
|
}
|
|
|
|
// UseModernPSKClient configures a client to use the modern PSK transport
|
|
// profile.
|
|
//
|
|
// It disables the legacy RSA key-exchange path, derives a transport key with
|
|
// Argon2id, and switches message protection to AES-GCM. Configure it before
|
|
// calling Connect/ConnectTimeout.
|
|
func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
|
|
key, aad, err := deriveModernPSKKey(sharedSecret, opts)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
transport := buildModernPSKTransportBundle(aad)
|
|
c.SetSecretKey(key)
|
|
c.SetMsgEn(transport.msgEn)
|
|
c.SetMsgDe(transport.msgDe)
|
|
if client, ok := c.(*ClientCommon); ok {
|
|
client.fastStreamEncode = transport.fastStreamEncode
|
|
client.fastBulkEncode = transport.fastBulkEncode
|
|
client.fastPlainEncode = transport.fastPlainEncode
|
|
}
|
|
c.SetSkipExchangeKey(true)
|
|
return nil
|
|
}
|
|
|
|
// UseModernPSKServer configures a server to use the modern PSK transport
|
|
// profile for newly accepted connections.
|
|
//
|
|
// It derives a transport key with Argon2id and switches message protection to
|
|
// AES-GCM. Configure it before calling Listen.
|
|
func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
|
|
key, aad, err := deriveModernPSKKey(sharedSecret, opts)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
transport := buildModernPSKTransportBundle(aad)
|
|
s.SetSecretKey(key)
|
|
s.SetDefaultCommEncode(transport.msgEn)
|
|
s.SetDefaultCommDecode(transport.msgDe)
|
|
if server, ok := s.(*ServerCommon); ok {
|
|
server.defaultFastStreamEncode = transport.fastStreamEncode
|
|
server.defaultFastBulkEncode = transport.fastBulkEncode
|
|
server.defaultFastPlainEncode = transport.fastPlainEncode
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UseLegacySecurityClient restores the legacy RSA key-exchange plus AES-CFB
|
|
// transport profile.
|
|
//
|
|
// It is kept only as an explicit fallback path for existing deployments.
|
|
func UseLegacySecurityClient(c Client) {
|
|
c.SetSecretKey(bytes.Clone(defaultAesKey))
|
|
c.SetMsgEn(defaultMsgEn)
|
|
c.SetMsgDe(defaultMsgDe)
|
|
if client, ok := c.(*ClientCommon); ok {
|
|
client.fastStreamEncode = nil
|
|
client.fastBulkEncode = nil
|
|
client.fastPlainEncode = nil
|
|
}
|
|
c.SetSkipExchangeKey(false)
|
|
c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey))
|
|
}
|
|
|
|
// UseLegacySecurityServer restores the legacy RSA key-exchange plus AES-CFB
|
|
// transport profile for newly accepted connections.
|
|
//
|
|
// It is kept only as an explicit fallback path for existing deployments.
|
|
func UseLegacySecurityServer(s Server) {
|
|
s.SetSecretKey(bytes.Clone(defaultAesKey))
|
|
s.SetDefaultCommEncode(defaultMsgEn)
|
|
s.SetDefaultCommDecode(defaultMsgDe)
|
|
if server, ok := s.(*ServerCommon); ok {
|
|
server.defaultFastStreamEncode = nil
|
|
server.defaultFastBulkEncode = nil
|
|
server.defaultFastPlainEncode = nil
|
|
}
|
|
s.SetRsaPrivKey(bytes.Clone(defaultRsaKey))
|
|
}
|
|
|
|
func deriveModernPSKKey(sharedSecret []byte, opts *ModernPSKOptions) ([]byte, []byte, error) {
|
|
if len(sharedSecret) == 0 {
|
|
return nil, nil, errModernPSKSecretEmpty
|
|
}
|
|
cfg := normalizeModernPSKOptions(opts)
|
|
key, err := starcrypto.DeriveArgon2idKey(string(sharedSecret), cfg.Salt, cfg.Argon2Params)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return key, cfg.AAD, nil
|
|
}
|
|
|
|
func normalizeModernPSKOptions(opts *ModernPSKOptions) ModernPSKOptions {
|
|
cfg := DefaultModernPSKOptions()
|
|
if opts == nil {
|
|
return cfg
|
|
}
|
|
if len(opts.Salt) > 0 {
|
|
cfg.Salt = bytes.Clone(opts.Salt)
|
|
}
|
|
if opts.AAD != nil {
|
|
cfg.AAD = bytes.Clone(opts.AAD)
|
|
}
|
|
if opts.Argon2Params.Time != 0 && opts.Argon2Params.Memory != 0 &&
|
|
opts.Argon2Params.Threads != 0 && opts.Argon2Params.KeyLen != 0 {
|
|
cfg.Argon2Params = opts.Argon2Params
|
|
}
|
|
return cfg
|
|
}
|
|
|
|
func buildModernPSKCodecs(aad []byte) (func([]byte, []byte) []byte, func([]byte, []byte) []byte) {
|
|
bundle := buildModernPSKTransportBundle(aad)
|
|
return bundle.msgEn, bundle.msgDe
|
|
}
|
|
|
|
func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle {
|
|
aadCopy := bytes.Clone(aad)
|
|
cache := &modernPSKCodecCache{}
|
|
msgEn := func(key []byte, plain []byte) []byte {
|
|
runtime, err := cache.runtimeForKey(key)
|
|
if err != nil {
|
|
log.Print(err)
|
|
return nil
|
|
}
|
|
out, err := runtime.sealPlainPayload(aadCopy, plain)
|
|
if err != nil {
|
|
log.Print(err)
|
|
return nil
|
|
}
|
|
return out
|
|
}
|
|
msgDe := func(key []byte, encrypted []byte) []byte {
|
|
headerLen := len(modernPSKMagic) + modernPSKNonceSize
|
|
if len(encrypted) < headerLen {
|
|
log.Print(errModernPSKPayload)
|
|
return nil
|
|
}
|
|
if !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) {
|
|
log.Print(errModernPSKPayload)
|
|
return nil
|
|
}
|
|
runtime, err := cache.runtimeForKey(key)
|
|
if err != nil {
|
|
log.Print(err)
|
|
return nil
|
|
}
|
|
nonce := encrypted[len(modernPSKMagic):headerLen]
|
|
ciphertext := encrypted[headerLen:]
|
|
plain, err := runtime.aead.Open(make([]byte, 0, len(ciphertext)), nonce, ciphertext, aadCopy)
|
|
if err != nil {
|
|
log.Print(err)
|
|
return nil
|
|
}
|
|
return plain
|
|
}
|
|
fastStreamEncode := func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
|
runtime, err := cache.runtimeForKey(key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return runtime.sealStreamFastPayload(aadCopy, dataID, seq, payload)
|
|
}
|
|
fastBulkEncode := func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
|
runtime, err := cache.runtimeForKey(key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return runtime.sealBulkFastPayload(aadCopy, dataID, seq, payload)
|
|
}
|
|
fastPlainEncode := func(key []byte, plainLen int, fill func([]byte) error) ([]byte, error) {
|
|
runtime, err := cache.runtimeForKey(key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return runtime.sealFilledPayload(aadCopy, plainLen, fill)
|
|
}
|
|
return modernPSKTransportBundle{
|
|
msgEn: msgEn,
|
|
msgDe: msgDe,
|
|
fastStreamEncode: fastStreamEncode,
|
|
fastBulkEncode: fastBulkEncode,
|
|
fastPlainEncode: fastPlainEncode,
|
|
}
|
|
}
|
|
|
|
func (c *ClientCommon) validateSecurityConfiguration() error {
|
|
if c.securityReadyCheck && len(c.SecretKey) == 0 {
|
|
return errModernPSKRequired
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *ServerCommon) validateSecurityConfiguration() error {
|
|
if s.securityReadyCheck && len(s.SecretKey) == 0 {
|
|
return errModernPSKRequired
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type modernPSKCodecCache struct {
|
|
mu sync.Mutex
|
|
key []byte
|
|
runtime *modernPSKCodecRuntime
|
|
}
|
|
|
|
type modernPSKCodecRuntime struct {
|
|
aead cipher.AEAD
|
|
prefix [modernPSKNonceSize - 8]byte
|
|
seq atomic.Uint64
|
|
}
|
|
|
|
func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime, error) {
|
|
if c == nil {
|
|
return nil, errModernPSKSecretEmpty
|
|
}
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.runtime != nil && bytes.Equal(c.key, key) {
|
|
return c.runtime, nil
|
|
}
|
|
runtime, err := newModernPSKCodecRuntime(key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.key = bytes.Clone(key)
|
|
c.runtime = runtime
|
|
return runtime, nil
|
|
}
|
|
|
|
func newModernPSKCodecRuntime(key []byte) (*modernPSKCodecRuntime, error) {
|
|
if len(key) == 0 {
|
|
return nil, errModernPSKSecretEmpty
|
|
}
|
|
block, err := aes.NewCipher(key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
aead, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
runtime := &modernPSKCodecRuntime{
|
|
aead: aead,
|
|
}
|
|
if _, err := cryptorand.Read(runtime.prefix[:]); err != nil {
|
|
return nil, err
|
|
}
|
|
return runtime, nil
|
|
}
|
|
|
|
func (r *modernPSKCodecRuntime) nextNonce() [modernPSKNonceSize]byte {
|
|
var nonce [modernPSKNonceSize]byte
|
|
if r == nil {
|
|
return nonce
|
|
}
|
|
copy(nonce[:len(r.prefix)], r.prefix[:])
|
|
binary.BigEndian.PutUint64(nonce[len(r.prefix):], r.seq.Add(1))
|
|
return nonce
|
|
}
|
|
|
|
func (r *modernPSKCodecRuntime) sealStreamFastPayload(aad []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
|
return r.sealFilledPayload(aad, streamFastPayloadHeaderLen+len(payload), func(frame []byte) error {
|
|
if err := encodeStreamFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil {
|
|
return err
|
|
}
|
|
copy(frame[streamFastPayloadHeaderLen:], payload)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (r *modernPSKCodecRuntime) sealBulkFastPayload(aad []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
|
if r == nil {
|
|
return nil, errTransportPayloadEncryptFailed
|
|
}
|
|
return r.sealFilledPayload(aad, bulkFastPayloadHeaderLen+len(payload), func(frame []byte) error {
|
|
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil {
|
|
return err
|
|
}
|
|
copy(frame[bulkFastPayloadHeaderLen:], payload)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (r *modernPSKCodecRuntime) sealPlainPayload(aad []byte, plain []byte) ([]byte, error) {
|
|
return r.sealFilledPayload(aad, len(plain), func(dst []byte) error {
|
|
copy(dst, plain)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (r *modernPSKCodecRuntime) sealFilledPayload(aad []byte, plainLen int, fill func([]byte) error) ([]byte, error) {
|
|
if r == nil {
|
|
return nil, errTransportPayloadEncryptFailed
|
|
}
|
|
if plainLen < 0 {
|
|
return nil, errTransportPayloadEncryptFailed
|
|
}
|
|
nonce := r.nextNonce()
|
|
headerLen := len(modernPSKMagic) + modernPSKNonceSize
|
|
out := make([]byte, headerLen+plainLen+r.aead.Overhead())
|
|
copy(out[:len(modernPSKMagic)], modernPSKMagic)
|
|
copy(out[len(modernPSKMagic):headerLen], nonce[:])
|
|
frame := out[headerLen : headerLen+plainLen]
|
|
if fill != nil {
|
|
if err := fill(frame); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
sealed := r.aead.Seal(frame[:0], nonce[:], frame, aad)
|
|
return out[:headerLen+len(sealed)], nil
|
|
}
|