package notify import ( "bytes" "crypto/ecdh" "crypto/hmac" cryptorand "crypto/rand" "crypto/sha256" "encoding/binary" "errors" ) const ( peerAttachECDHEPublicKeySize = 32 peerAttachSessionIDSize = 16 peerAttachKeyModeStatic = "psk-static" peerAttachKeyModeECDHE = "psk-ecdhe" transportKeyModeExternal = "external" ) var errPeerAttachForwardSecrecyInvalid = errors.New("peer attach forward secrecy is invalid") type peerAttachRequestState struct { forwardSecrecy *peerAttachForwardSecrecyClientState } type peerAttachForwardSecrecyClientState struct { privateKey *ecdh.PrivateKey publicKey []byte } type peerAttachResponseVerifyResult struct { authFallback bool steadyProfile transportProtectionProfile } func newPeerAttachForwardSecrecyClientState() (*peerAttachForwardSecrecyClientState, error) { curve := ecdh.X25519() privateKey, err := curve.GenerateKey(cryptorand.Reader) if err != nil { return nil, err } publicKey := privateKey.PublicKey().Bytes() if len(publicKey) != peerAttachECDHEPublicKeySize { return nil, errPeerAttachForwardSecrecyInvalid } return &peerAttachForwardSecrecyClientState{ privateKey: privateKey, publicKey: bytes.Clone(publicKey), }, nil } func derivePeerAttachForwardSecrecyTransportProfile(base transportProtectionProfile, bootstrapKey []byte, localPrivateKey *ecdh.PrivateKey, peerPublicKey []byte, req peerAttachRequest, resp peerAttachResponse) (transportProtectionProfile, error) { if len(bootstrapKey) == 0 || localPrivateKey == nil || len(peerPublicKey) != peerAttachECDHEPublicKeySize { return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid } curve := ecdh.X25519() publicKey, err := curve.NewPublicKey(peerPublicKey) if err != nil { return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid } sharedSecret, err := localPrivateKey.ECDH(publicKey) if err != nil { return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid } transcriptHash := peerAttachForwardSecrecyTranscriptHash(req, resp) ikm := make([]byte, 0, len(sharedSecret)+len(transcriptHash)) ikm = append(ikm, sharedSecret...) ikm = append(ikm, transcriptHash...) prk := hkdfExtractSHA256(bootstrapKey, ikm) sessionKey := hkdfExpandSHA256(prk, []byte("notify/transport/session/v1"), 32) sessionID := hkdfExpandSHA256(prk, []byte("notify/session-id/v1"), peerAttachSessionIDSize) return deriveModernPSKSessionProtectionProfile(base, sessionKey, sessionID) } func peerAttachForwardSecrecyTranscriptHash(req peerAttachRequest, resp peerAttachResponse) []byte { buf := make([]byte, 0, 256+len(req.PeerID)+len(resp.PeerID)+len(resp.Error)+len(req.ClientECDHEPublicKey)+len(resp.ServerECDHEPublicKey)) buf = appendPeerAttachTranscriptString(buf, "notify/peer-attach/forward-secrecy/v1") buf = binary.BigEndian.AppendUint64(buf, req.Features) buf = appendPeerAttachTranscriptString(buf, req.PeerID) buf = appendPeerAttachTranscriptBytes(buf, req.ClientNonce) buf = appendPeerAttachTranscriptBytes(buf, req.ClientECDHEPublicKey) buf = binary.BigEndian.AppendUint64(buf, resp.Features) buf = appendPeerAttachTranscriptString(buf, resp.PeerID) buf = appendPeerAttachTranscriptBool(buf, resp.Accepted) buf = appendPeerAttachTranscriptBool(buf, resp.Reused) buf = appendPeerAttachTranscriptString(buf, resp.Error) buf = appendPeerAttachTranscriptBytes(buf, resp.ServerNonce) buf = appendPeerAttachTranscriptString(buf, resp.KeyMode) buf = appendPeerAttachTranscriptBytes(buf, resp.ServerECDHEPublicKey) sum := sha256.Sum256(buf) return sum[:] } func appendPeerAttachTranscriptBytes(dst []byte, data []byte) []byte { dst = binary.BigEndian.AppendUint32(dst, uint32(len(data))) return append(dst, data...) } func appendPeerAttachTranscriptString(dst []byte, value string) []byte { return appendPeerAttachTranscriptBytes(dst, []byte(value)) } func appendPeerAttachTranscriptBool(dst []byte, value bool) []byte { if value { return append(dst, 1) } return append(dst, 0) } func hkdfExtractSHA256(salt []byte, ikm []byte) []byte { mac := hmac.New(sha256.New, salt) _, _ = mac.Write(ikm) return mac.Sum(nil) } func hkdfExpandSHA256(prk []byte, info []byte, size int) []byte { if size <= 0 { return nil } out := make([]byte, 0, size) var block []byte for counter := byte(1); len(out) < size; counter++ { mac := hmac.New(sha256.New, prk) if len(block) != 0 { _, _ = mac.Write(block) } _, _ = mac.Write(info) _, _ = mac.Write([]byte{counter}) block = mac.Sum(nil) remaining := size - len(out) if remaining > len(block) { remaining = len(block) } out = append(out, block[:remaining]...) } return out }