package notify import ( "bytes" "crypto/aes" "crypto/cipher" cryptorand "crypto/rand" "encoding/binary" "errors" "log" "sync" "sync/atomic" "b612.me/starcrypto" ) var ( errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty") errModernPSKPayload = errors.New("invalid modern psk payload") errModernPSKRequired = errors.New("modern psk is required: call UseModernPSKClient/UseModernPSKServer or set a transport key before Connect/Listen") ) var ( modernPSKMagic = []byte("NPS1") defaultModernPSKSalt = []byte("b612.me/notify/psk/aes-gcm/v1") defaultModernPSKAAD = []byte("b612.me/notify/psk-envelope/v1") ) const modernPSKNonceSize = 12 type transportFastStreamEncoder func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) type transportFastBulkEncoder func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) type transportFastPlainEncoder func(secretKey []byte, plainLen int, fill func([]byte) error) ([]byte, error) type modernPSKTransportBundle struct { msgEn func([]byte, []byte) []byte msgDe func([]byte, []byte) []byte fastStreamEncode transportFastStreamEncoder fastBulkEncode transportFastBulkEncoder fastPlainEncode transportFastPlainEncoder } // ModernPSKOptions configures the modern PSK transport profile. // // The current profile derives a 32-byte transport key with Argon2id and uses // AES-GCM with a per-codec nonce prefix plus a per-message counter. type ModernPSKOptions struct { Salt []byte AAD []byte Argon2Params starcrypto.Argon2Params } // DefaultModernPSKOptions returns the recommended settings for the current // PSK transport profile. func DefaultModernPSKOptions() ModernPSKOptions { return ModernPSKOptions{ Salt: bytes.Clone(defaultModernPSKSalt), AAD: bytes.Clone(defaultModernPSKAAD), Argon2Params: starcrypto.DefaultArgon2idParams(), } } func defaultModernPSKCodecs() (func([]byte, []byte) []byte, func([]byte, []byte) []byte) { bundle := defaultModernPSKTransportBundle() return bundle.msgEn, bundle.msgDe } func defaultModernPSKTransportBundle() modernPSKTransportBundle { return buildModernPSKTransportBundle(defaultModernPSKAAD) } // UseModernPSKClient configures a client to use the modern PSK transport // profile. // // It disables the legacy RSA key-exchange path, derives a transport key with // Argon2id, and switches message protection to AES-GCM. Configure it before // calling Connect/ConnectTimeout. func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error { key, aad, err := deriveModernPSKKey(sharedSecret, opts) if err != nil { return err } transport := buildModernPSKTransportBundle(aad) c.SetSecretKey(key) c.SetMsgEn(transport.msgEn) c.SetMsgDe(transport.msgDe) if client, ok := c.(*ClientCommon); ok { client.fastStreamEncode = transport.fastStreamEncode client.fastBulkEncode = transport.fastBulkEncode client.fastPlainEncode = transport.fastPlainEncode } c.SetSkipExchangeKey(true) return nil } // UseModernPSKServer configures a server to use the modern PSK transport // profile for newly accepted connections. // // It derives a transport key with Argon2id and switches message protection to // AES-GCM. Configure it before calling Listen. func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error { key, aad, err := deriveModernPSKKey(sharedSecret, opts) if err != nil { return err } transport := buildModernPSKTransportBundle(aad) s.SetSecretKey(key) s.SetDefaultCommEncode(transport.msgEn) s.SetDefaultCommDecode(transport.msgDe) if server, ok := s.(*ServerCommon); ok { server.defaultFastStreamEncode = transport.fastStreamEncode server.defaultFastBulkEncode = transport.fastBulkEncode server.defaultFastPlainEncode = transport.fastPlainEncode } return nil } // UseLegacySecurityClient restores the legacy RSA key-exchange plus AES-CFB // transport profile. // // It is kept only as an explicit fallback path for existing deployments. func UseLegacySecurityClient(c Client) { c.SetSecretKey(bytes.Clone(defaultAesKey)) c.SetMsgEn(defaultMsgEn) c.SetMsgDe(defaultMsgDe) if client, ok := c.(*ClientCommon); ok { client.fastStreamEncode = nil client.fastBulkEncode = nil client.fastPlainEncode = nil } c.SetSkipExchangeKey(false) c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey)) } // UseLegacySecurityServer restores the legacy RSA key-exchange plus AES-CFB // transport profile for newly accepted connections. // // It is kept only as an explicit fallback path for existing deployments. func UseLegacySecurityServer(s Server) { s.SetSecretKey(bytes.Clone(defaultAesKey)) s.SetDefaultCommEncode(defaultMsgEn) s.SetDefaultCommDecode(defaultMsgDe) if server, ok := s.(*ServerCommon); ok { server.defaultFastStreamEncode = nil server.defaultFastBulkEncode = nil server.defaultFastPlainEncode = nil } s.SetRsaPrivKey(bytes.Clone(defaultRsaKey)) } func deriveModernPSKKey(sharedSecret []byte, opts *ModernPSKOptions) ([]byte, []byte, error) { if len(sharedSecret) == 0 { return nil, nil, errModernPSKSecretEmpty } cfg := normalizeModernPSKOptions(opts) key, err := starcrypto.DeriveArgon2idKey(string(sharedSecret), cfg.Salt, cfg.Argon2Params) if err != nil { return nil, nil, err } return key, cfg.AAD, nil } func normalizeModernPSKOptions(opts *ModernPSKOptions) ModernPSKOptions { cfg := DefaultModernPSKOptions() if opts == nil { return cfg } if len(opts.Salt) > 0 { cfg.Salt = bytes.Clone(opts.Salt) } if opts.AAD != nil { cfg.AAD = bytes.Clone(opts.AAD) } if opts.Argon2Params.Time != 0 && opts.Argon2Params.Memory != 0 && opts.Argon2Params.Threads != 0 && opts.Argon2Params.KeyLen != 0 { cfg.Argon2Params = opts.Argon2Params } return cfg } func buildModernPSKCodecs(aad []byte) (func([]byte, []byte) []byte, func([]byte, []byte) []byte) { bundle := buildModernPSKTransportBundle(aad) return bundle.msgEn, bundle.msgDe } func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle { aadCopy := bytes.Clone(aad) cache := &modernPSKCodecCache{} msgEn := func(key []byte, plain []byte) []byte { runtime, err := cache.runtimeForKey(key) if err != nil { log.Print(err) return nil } out, err := runtime.sealPlainPayload(aadCopy, plain) if err != nil { log.Print(err) return nil } return out } msgDe := func(key []byte, encrypted []byte) []byte { headerLen := len(modernPSKMagic) + modernPSKNonceSize if len(encrypted) < headerLen { log.Print(errModernPSKPayload) return nil } if !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) { log.Print(errModernPSKPayload) return nil } runtime, err := cache.runtimeForKey(key) if err != nil { log.Print(err) return nil } nonce := encrypted[len(modernPSKMagic):headerLen] ciphertext := encrypted[headerLen:] plain, err := runtime.aead.Open(make([]byte, 0, len(ciphertext)), nonce, ciphertext, aadCopy) if err != nil { log.Print(err) return nil } return plain } fastStreamEncode := func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { runtime, err := cache.runtimeForKey(key) if err != nil { return nil, err } return runtime.sealStreamFastPayload(aadCopy, dataID, seq, payload) } fastBulkEncode := func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { runtime, err := cache.runtimeForKey(key) if err != nil { return nil, err } return runtime.sealBulkFastPayload(aadCopy, dataID, seq, payload) } fastPlainEncode := func(key []byte, plainLen int, fill func([]byte) error) ([]byte, error) { runtime, err := cache.runtimeForKey(key) if err != nil { return nil, err } return runtime.sealFilledPayload(aadCopy, plainLen, fill) } return modernPSKTransportBundle{ msgEn: msgEn, msgDe: msgDe, fastStreamEncode: fastStreamEncode, fastBulkEncode: fastBulkEncode, fastPlainEncode: fastPlainEncode, } } func (c *ClientCommon) validateSecurityConfiguration() error { if c.securityReadyCheck && len(c.SecretKey) == 0 { return errModernPSKRequired } return nil } func (s *ServerCommon) validateSecurityConfiguration() error { if s.securityReadyCheck && len(s.SecretKey) == 0 { return errModernPSKRequired } return nil } type modernPSKCodecCache struct { mu sync.Mutex key []byte runtime *modernPSKCodecRuntime } type modernPSKCodecRuntime struct { aead cipher.AEAD prefix [modernPSKNonceSize - 8]byte seq atomic.Uint64 } func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime, error) { if c == nil { return nil, errModernPSKSecretEmpty } c.mu.Lock() defer c.mu.Unlock() if c.runtime != nil && bytes.Equal(c.key, key) { return c.runtime, nil } runtime, err := newModernPSKCodecRuntime(key) if err != nil { return nil, err } c.key = bytes.Clone(key) c.runtime = runtime return runtime, nil } func newModernPSKCodecRuntime(key []byte) (*modernPSKCodecRuntime, error) { if len(key) == 0 { return nil, errModernPSKSecretEmpty } block, err := aes.NewCipher(key) if err != nil { return nil, err } aead, err := cipher.NewGCM(block) if err != nil { return nil, err } runtime := &modernPSKCodecRuntime{ aead: aead, } if _, err := cryptorand.Read(runtime.prefix[:]); err != nil { return nil, err } return runtime, nil } func (r *modernPSKCodecRuntime) nextNonce() [modernPSKNonceSize]byte { var nonce [modernPSKNonceSize]byte if r == nil { return nonce } copy(nonce[:len(r.prefix)], r.prefix[:]) binary.BigEndian.PutUint64(nonce[len(r.prefix):], r.seq.Add(1)) return nonce } func (r *modernPSKCodecRuntime) sealStreamFastPayload(aad []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { return r.sealFilledPayload(aad, streamFastPayloadHeaderLen+len(payload), func(frame []byte) error { if err := encodeStreamFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { return err } copy(frame[streamFastPayloadHeaderLen:], payload) return nil }) } func (r *modernPSKCodecRuntime) sealBulkFastPayload(aad []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { if r == nil { return nil, errTransportPayloadEncryptFailed } return r.sealFilledPayload(aad, bulkFastPayloadHeaderLen+len(payload), func(frame []byte) error { if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { return err } copy(frame[bulkFastPayloadHeaderLen:], payload) return nil }) } func (r *modernPSKCodecRuntime) sealPlainPayload(aad []byte, plain []byte) ([]byte, error) { return r.sealFilledPayload(aad, len(plain), func(dst []byte) error { copy(dst, plain) return nil }) } func (r *modernPSKCodecRuntime) sealFilledPayload(aad []byte, plainLen int, fill func([]byte) error) ([]byte, error) { if r == nil { return nil, errTransportPayloadEncryptFailed } if plainLen < 0 { return nil, errTransportPayloadEncryptFailed } nonce := r.nextNonce() headerLen := len(modernPSKMagic) + modernPSKNonceSize out := make([]byte, headerLen+plainLen+r.aead.Overhead()) copy(out[:len(modernPSKMagic)], modernPSKMagic) copy(out[len(modernPSKMagic):headerLen], nonce[:]) frame := out[headerLen : headerLen+plainLen] if fill != nil { if err := fill(frame); err != nil { return nil, err } } sealed := r.aead.Seal(frame[:0], nonce[:], frame, aad) return out[:headerLen+len(sealed)], nil }