package notify import ( "bytes" "crypto/aes" "crypto/cipher" cryptorand "crypto/rand" "encoding/binary" "errors" "log" "sync" "sync/atomic" "b612.me/starcrypto" ) var ( errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty") errModernPSKPayload = errors.New("invalid modern psk payload") errModernPSKRequired = errors.New("transport security is required: call UseModernPSKClient/UseModernPSKServer, UsePSKOverExternalTransportClient/Server, or set a transport key before Connect/Listen") errModernPSKForwardSecrecyUnsupported = errors.New("forward secrecy is unsupported for external transport protection") ) var ( modernPSKMagic = []byte("NPS1") defaultModernPSKSalt = []byte("b612.me/notify/psk/aes-gcm/v1") defaultModernPSKAAD = []byte("b612.me/notify/psk-envelope/v1") ) const modernPSKNonceSize = 12 type transportFastStreamEncoder func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) type transportFastBulkEncoder func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) type transportFastPlainEncoder func(secretKey []byte, plainLen int, fill func([]byte) error) ([]byte, error) type modernPSKTransportBundle struct { msgEn func([]byte, []byte) []byte msgDe func([]byte, []byte) []byte fastStreamEncode transportFastStreamEncoder fastBulkEncode transportFastBulkEncoder fastPlainEncode transportFastPlainEncoder } var modernPSKPayloadPool sync.Pool // ModernPSKOptions configures the modern PSK transport profile. // // The current profile derives a 32-byte transport key with Argon2id and uses // AES-GCM with a per-codec nonce prefix plus a per-message counter. type ModernPSKOptions struct { Salt []byte AAD []byte Argon2Params starcrypto.Argon2Params RequireForwardSecrecy bool } // DefaultModernPSKOptions returns the recommended settings for the current // PSK transport profile. func DefaultModernPSKOptions() ModernPSKOptions { return ModernPSKOptions{ Salt: bytes.Clone(defaultModernPSKSalt), AAD: bytes.Clone(defaultModernPSKAAD), Argon2Params: starcrypto.DefaultArgon2idParams(), } } func defaultModernPSKCodecs() (func([]byte, []byte) []byte, func([]byte, []byte) []byte) { bundle := defaultModernPSKTransportBundle() return bundle.msgEn, bundle.msgDe } func defaultModernPSKTransportBundle() modernPSKTransportBundle { return buildModernPSKTransportBundle(defaultModernPSKAAD) } // UseModernPSKClient configures a client to use the modern PSK transport // profile. // // It disables the legacy RSA key-exchange path, derives a transport key with // Argon2id, and switches message protection to AES-GCM. Configure it before // calling Connect/ConnectTimeout. func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error { managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged) if err != nil { return err } if client, ok := c.(*ClientCommon); ok { client.configureClientSecurityProfiles(AuthPSK, ProtectionManaged, managed, managed, opts != nil && opts.RequireForwardSecrecy) return nil } c.SetSecretKey(managed.secretKey) c.SetMsgEn(managed.msgEn) c.SetMsgDe(managed.msgDe) c.SetSkipExchangeKey(true) return nil } // UseModernPSKServer configures a server to use the modern PSK transport // profile for newly accepted connections. // // It derives a transport key with Argon2id and switches message protection to // AES-GCM. Configure it before calling Listen. func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error { managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged) if err != nil { return err } if server, ok := s.(*ServerCommon); ok { server.configureServerSecurityProfiles(AuthPSK, ProtectionManaged, managed, managed, opts != nil && opts.RequireForwardSecrecy) return nil } s.SetSecretKey(managed.secretKey) s.SetDefaultCommEncode(managed.msgEn) s.SetDefaultCommDecode(managed.msgDe) return nil } // UsePSKOverExternalTransportClient authenticates bootstrap with PSK and then // trusts the external channel for steady-state payload protection. func UsePSKOverExternalTransportClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error { if opts != nil && opts.RequireForwardSecrecy { return errModernPSKForwardSecrecyUnsupported } bootstrap, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged) if err != nil { return err } steady := buildExternalProtectionProfile(bootstrap.secretKey) if client, ok := c.(*ClientCommon); ok { client.configureClientSecurityProfiles(AuthPSK, ProtectionExternal, bootstrap, steady, false) return nil } c.SetSecretKey(bootstrap.secretKey) c.SetMsgEn(bootstrap.msgEn) c.SetMsgDe(bootstrap.msgDe) c.SetSkipExchangeKey(true) return nil } // UsePSKOverExternalTransportServer authenticates bootstrap with PSK and then // trusts the external channel for steady-state payload protection. func UsePSKOverExternalTransportServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error { if opts != nil && opts.RequireForwardSecrecy { return errModernPSKForwardSecrecyUnsupported } bootstrap, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged) if err != nil { return err } steady := buildExternalProtectionProfile(bootstrap.secretKey) if server, ok := s.(*ServerCommon); ok { server.configureServerSecurityProfiles(AuthPSK, ProtectionExternal, bootstrap, steady, false) return nil } s.SetSecretKey(bootstrap.secretKey) s.SetDefaultCommEncode(bootstrap.msgEn) s.SetDefaultCommDecode(bootstrap.msgDe) return nil } // UseNestedSecurityClient keeps notify transport protection enabled even when // the physical connection is already protected by an outer trusted channel. func UseNestedSecurityClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error { managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionNested) if err != nil { return err } if client, ok := c.(*ClientCommon); ok { client.configureClientSecurityProfiles(AuthPSK, ProtectionNested, managed, managed, opts != nil && opts.RequireForwardSecrecy) return nil } c.SetSecretKey(managed.secretKey) c.SetMsgEn(managed.msgEn) c.SetMsgDe(managed.msgDe) c.SetSkipExchangeKey(true) return nil } // UseNestedSecurityServer keeps notify transport protection enabled even when // the physical connection is already protected by an outer trusted channel. func UseNestedSecurityServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error { managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionNested) if err != nil { return err } if server, ok := s.(*ServerCommon); ok { server.configureServerSecurityProfiles(AuthPSK, ProtectionNested, managed, managed, opts != nil && opts.RequireForwardSecrecy) return nil } s.SetSecretKey(managed.secretKey) s.SetDefaultCommEncode(managed.msgEn) s.SetDefaultCommDecode(managed.msgDe) return nil } // UseLegacySecurityClient restores the legacy RSA key-exchange plus AES-CFB // transport profile. // // It is kept only as an explicit fallback path for existing deployments. func UseLegacySecurityClient(c Client) { if client, ok := c.(*ClientCommon); ok { client.clearClientSecurityProfiles() client.setClientTransportProtectionProfile(transportProtectionProfile{ mode: ProtectionManaged, msgEn: defaultMsgEn, msgDe: defaultMsgDe, secretKey: bytes.Clone(defaultAesKey), }) client.securityReadyCheck = false client.skipKeyExchange = false client.handshakeRsaPubKey = bytes.Clone(defaultRsaPubKey) return } c.SetSecretKey(bytes.Clone(defaultAesKey)) c.SetMsgEn(defaultMsgEn) c.SetMsgDe(defaultMsgDe) c.SetSkipExchangeKey(false) c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey)) } // UseLegacySecurityServer restores the legacy RSA key-exchange plus AES-CFB // transport profile for newly accepted connections. // // It is kept only as an explicit fallback path for existing deployments. func UseLegacySecurityServer(s Server) { if server, ok := s.(*ServerCommon); ok { server.clearServerSecurityProfiles() server.setServerDefaultTransportProtectionProfile(transportProtectionProfile{ mode: ProtectionManaged, msgEn: defaultMsgEn, msgDe: defaultMsgDe, secretKey: bytes.Clone(defaultAesKey), }) server.securityReadyCheck = false server.handshakeRsaKey = bytes.Clone(defaultRsaKey) return } s.SetSecretKey(bytes.Clone(defaultAesKey)) s.SetDefaultCommEncode(defaultMsgEn) s.SetDefaultCommDecode(defaultMsgDe) s.SetRsaPrivKey(bytes.Clone(defaultRsaKey)) } func deriveModernPSKKey(sharedSecret []byte, opts *ModernPSKOptions) ([]byte, []byte, error) { if len(sharedSecret) == 0 { return nil, nil, errModernPSKSecretEmpty } cfg := normalizeModernPSKOptions(opts) key, err := starcrypto.DeriveArgon2idKey(string(sharedSecret), cfg.Salt, cfg.Argon2Params) if err != nil { return nil, nil, err } return key, cfg.AAD, nil } func deriveModernPSKProtectionProfile(sharedSecret []byte, opts *ModernPSKOptions, mode ProtectionMode) (transportProtectionProfile, error) { key, aad, err := deriveModernPSKKey(sharedSecret, opts) if err != nil { return transportProtectionProfile{}, err } transport := buildModernPSKTransportBundle(aad) runtime, err := newModernPSKCodecRuntime(key, aad) if err != nil { return transportProtectionProfile{}, err } return newTransportProtectionProfile(mode, transport, runtime, key), nil } func buildExternalProtectionProfile(secretKey []byte) transportProtectionProfile { return newTransportProtectionProfile(ProtectionExternal, buildExternalTransportBundle(), nil, secretKey) } func deriveModernPSKSessionProtectionProfile(base transportProtectionProfile, sessionKey []byte, sessionID []byte) (transportProtectionProfile, error) { aad := bytes.Clone(defaultModernPSKAAD) if base.runtime != nil && len(base.runtime.aad) != 0 { aad = bytes.Clone(base.runtime.aad) } runtime, err := newModernPSKCodecRuntime(sessionKey, aad) if err != nil { return transportProtectionProfile{}, err } profile := newTransportProtectionProfile(base.mode, buildModernPSKTransportBundle(aad), runtime, sessionKey) profile.keyMode = peerAttachKeyModeECDHE profile.sessionID = cloneTransportSessionID(sessionID) profile.forwardSecrecy = true profile.forwardSecrecyFallback = false return profile, nil } func normalizeModernPSKOptions(opts *ModernPSKOptions) ModernPSKOptions { cfg := DefaultModernPSKOptions() if opts == nil { return cfg } if len(opts.Salt) > 0 { cfg.Salt = bytes.Clone(opts.Salt) } if opts.AAD != nil { cfg.AAD = bytes.Clone(opts.AAD) } if opts.Argon2Params.Time != 0 && opts.Argon2Params.Memory != 0 && opts.Argon2Params.Threads != 0 && opts.Argon2Params.KeyLen != 0 { cfg.Argon2Params = opts.Argon2Params } return cfg } func buildModernPSKCodecs(aad []byte) (func([]byte, []byte) []byte, func([]byte, []byte) []byte) { bundle := buildModernPSKTransportBundle(aad) return bundle.msgEn, bundle.msgDe } func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle { aadCopy := bytes.Clone(aad) cache := newModernPSKCodecCache(aadCopy) msgEn := func(key []byte, plain []byte) []byte { runtime, err := cache.runtimeForKey(key) if err != nil { log.Print(err) return nil } out, err := runtime.sealPlainPayload(plain) if err != nil { log.Print(err) return nil } return out } msgDe := func(key []byte, encrypted []byte) []byte { headerLen := len(modernPSKMagic) + modernPSKNonceSize if len(encrypted) < headerLen { log.Print(errModernPSKPayload) return nil } if !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) { log.Print(errModernPSKPayload) return nil } runtime, err := cache.runtimeForKey(key) if err != nil { log.Print(err) return nil } plain, err := runtime.openPayload(encrypted) if err != nil { log.Print(err) return nil } return plain } fastStreamEncode := func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { runtime, err := cache.runtimeForKey(key) if err != nil { return nil, err } return runtime.sealStreamFastPayload(dataID, seq, payload) } fastBulkEncode := func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { runtime, err := cache.runtimeForKey(key) if err != nil { return nil, err } return runtime.sealBulkFastPayload(dataID, seq, payload) } fastPlainEncode := func(key []byte, plainLen int, fill func([]byte) error) ([]byte, error) { runtime, err := cache.runtimeForKey(key) if err != nil { return nil, err } return runtime.sealFilledPayload(plainLen, fill) } return modernPSKTransportBundle{ msgEn: msgEn, msgDe: msgDe, fastStreamEncode: fastStreamEncode, fastBulkEncode: fastBulkEncode, fastPlainEncode: fastPlainEncode, } } func (c *ClientCommon) validateSecurityConfiguration() error { if c.securityReadyCheck && len(c.SecretKey) == 0 { return errModernPSKRequired } return nil } func (s *ServerCommon) validateSecurityConfiguration() error { if s.securityReadyCheck && len(s.SecretKey) == 0 { return errModernPSKRequired } return nil } type modernPSKCodecCache struct { mu sync.Mutex aad []byte key []byte runtime *modernPSKCodecRuntime } type modernPSKCodecRuntime struct { aead cipher.AEAD key []byte aad []byte prefix [modernPSKNonceSize - 8]byte seq atomic.Uint64 } func newModernPSKCodecCache(aad []byte) *modernPSKCodecCache { return &modernPSKCodecCache{aad: bytes.Clone(aad)} } func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime, error) { if c == nil { return nil, errModernPSKSecretEmpty } c.mu.Lock() defer c.mu.Unlock() if c.runtime != nil && bytes.Equal(c.key, key) { return c.runtime, nil } runtime, err := newModernPSKCodecRuntime(key, c.aad) if err != nil { return nil, err } c.key = bytes.Clone(key) c.runtime = runtime return runtime, nil } func newModernPSKCodecRuntime(key []byte, aad []byte) (*modernPSKCodecRuntime, error) { if len(key) == 0 { return nil, errModernPSKSecretEmpty } block, err := aes.NewCipher(key) if err != nil { return nil, err } aead, err := cipher.NewGCM(block) if err != nil { return nil, err } runtime := &modernPSKCodecRuntime{ aead: aead, key: bytes.Clone(key), aad: bytes.Clone(aad), } if _, err := cryptorand.Read(runtime.prefix[:]); err != nil { return nil, err } return runtime, nil } func (r *modernPSKCodecRuntime) fork() (*modernPSKCodecRuntime, error) { if r == nil { return nil, errModernPSKSecretEmpty } return newModernPSKCodecRuntime(r.key, r.aad) } func (r *modernPSKCodecRuntime) nextNonce() [modernPSKNonceSize]byte { var nonce [modernPSKNonceSize]byte if r == nil { return nonce } copy(nonce[:len(r.prefix)], r.prefix[:]) binary.BigEndian.PutUint64(nonce[len(r.prefix):], r.seq.Add(1)) return nonce } func (r *modernPSKCodecRuntime) sealStreamFastPayload(dataID uint64, seq uint64, payload []byte) ([]byte, error) { return r.sealFilledPayload(streamFastPayloadHeaderLen+len(payload), func(frame []byte) error { if err := encodeStreamFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { return err } copy(frame[streamFastPayloadHeaderLen:], payload) return nil }) } func (r *modernPSKCodecRuntime) sealBulkFastPayload(dataID uint64, seq uint64, payload []byte) ([]byte, error) { if r == nil { return nil, errTransportPayloadEncryptFailed } return r.sealFilledPayload(bulkFastPayloadHeaderLen+len(payload), func(frame []byte) error { if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { return err } copy(frame[bulkFastPayloadHeaderLen:], payload) return nil }) } func (r *modernPSKCodecRuntime) sealPlainPayload(plain []byte) ([]byte, error) { return r.sealFilledPayload(len(plain), func(dst []byte) error { copy(dst, plain) return nil }) } func (r *modernPSKCodecRuntime) sealFilledPayload(plainLen int, fill func([]byte) error) ([]byte, error) { if r == nil { return nil, errTransportPayloadEncryptFailed } if plainLen < 0 { return nil, errTransportPayloadEncryptFailed } nonce := r.nextNonce() headerLen := len(modernPSKMagic) + modernPSKNonceSize out := make([]byte, headerLen+plainLen+r.aead.Overhead()) sealed, err := r.sealInto(out, headerLen, nonce, plainLen, fill) if err != nil { return nil, err } return out[:headerLen+len(sealed)], nil } func (r *modernPSKCodecRuntime) sealFilledPayloadPooled(plainLen int, fill func([]byte) error) ([]byte, func(), error) { if r == nil { return nil, nil, errTransportPayloadEncryptFailed } if plainLen < 0 { return nil, nil, errTransportPayloadEncryptFailed } nonce := r.nextNonce() headerLen := len(modernPSKMagic) + modernPSKNonceSize totalLen := headerLen + plainLen + r.aead.Overhead() out := getModernPSKPayloadBuffer(totalLen) sealed, err := r.sealInto(out, headerLen, nonce, plainLen, fill) if err != nil { putModernPSKPayloadBuffer(out) return nil, nil, err } return out[:headerLen+len(sealed)], func() { putModernPSKPayloadBuffer(out) }, nil } func (r *modernPSKCodecRuntime) sealInto(out []byte, headerLen int, nonce [modernPSKNonceSize]byte, plainLen int, fill func([]byte) error) ([]byte, error) { copy(out[:len(modernPSKMagic)], modernPSKMagic) copy(out[len(modernPSKMagic):headerLen], nonce[:]) frame := out[headerLen : headerLen+plainLen] if fill != nil { if err := fill(frame); err != nil { return nil, err } } return r.aead.Seal(frame[:0], nonce[:], frame, r.aad), nil } func (r *modernPSKCodecRuntime) openPayload(encrypted []byte) ([]byte, error) { if r == nil { return nil, errTransportPayloadDecryptFailed } headerLen := len(modernPSKMagic) + modernPSKNonceSize if len(encrypted) < headerLen { return nil, errModernPSKPayload } if !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) { return nil, errModernPSKPayload } nonce := encrypted[len(modernPSKMagic):headerLen] ciphertext := encrypted[headerLen:] return r.aead.Open(make([]byte, 0, len(ciphertext)), nonce, ciphertext, r.aad) } func (r *modernPSKCodecRuntime) openPayloadPooled(encrypted []byte, release func()) ([]byte, func(), error) { if r == nil { if release != nil { release() } return nil, nil, errTransportPayloadDecryptFailed } headerLen := len(modernPSKMagic) + modernPSKNonceSize if len(encrypted) < headerLen { if release != nil { release() } return nil, nil, errModernPSKPayload } if !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) { if release != nil { release() } return nil, nil, errModernPSKPayload } nonce := encrypted[len(modernPSKMagic):headerLen] ciphertext := encrypted[headerLen:] plain, err := r.aead.Open(ciphertext[:0], nonce, ciphertext, r.aad) if err != nil { if release != nil { release() } return nil, nil, err } return plain, release, nil } func (r *modernPSKCodecRuntime) openPayloadOwnedPooled(encrypted []byte) ([]byte, func(), error) { if r == nil { return nil, nil, errTransportPayloadDecryptFailed } headerLen := len(modernPSKMagic) + modernPSKNonceSize if len(encrypted) < headerLen { return nil, nil, errModernPSKPayload } if !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) { return nil, nil, errModernPSKPayload } nonce := encrypted[len(modernPSKMagic):headerLen] ciphertext := encrypted[headerLen:] plainLen := len(ciphertext) - r.aead.Overhead() if plainLen < 0 { return nil, nil, errModernPSKPayload } out := getModernPSKPayloadBuffer(plainLen) plain, err := r.aead.Open(out[:0], nonce, ciphertext, r.aad) if err != nil { putModernPSKPayloadBuffer(out) return nil, nil, err } return plain, func() { putModernPSKPayloadBuffer(out) }, nil } func getModernPSKPayloadBuffer(size int) []byte { if size <= 0 { return nil } if pooled, ok := modernPSKPayloadPool.Get().([]byte); ok && cap(pooled) >= size { return pooled[:size] } return make([]byte, size) } func putModernPSKPayloadBuffer(buf []byte) { if cap(buf) == 0 || cap(buf) > 32*1024*1024 { return } modernPSKPayloadPool.Put(buf[:0]) }