package symm import ( "bytes" "crypto/cipher" "encoding/binary" "errors" "io" ) const ccmStreamMagic = "SCC1" var ErrInvalidCCMStreamChunk = errors.New("invalid ccm stream chunk") func encryptCCMChunk(aead cipher.AEAD, plain, nonce, aad []byte, chunkIndex uint64) []byte { chunkNonce := deriveChunkNonce(nonce, chunkIndex) return aead.Seal(nil, chunkNonce, plain, aad) } func decryptCCMChunk(aead cipher.AEAD, ciphertext, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { chunkNonce := deriveChunkNonce(nonce, chunkIndex) return aead.Open(nil, chunkNonce, ciphertext, aad) } func encryptCCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error { if _, err := dst.Write([]byte(ccmStreamMagic)); err != nil { return err } buf := make([]byte, gcmStreamChunkSize) lenBuf := make([]byte, 4) var chunkIndex uint64 for { n, err := src.Read(buf) if n > 0 { sealed := encryptCCMChunk(aead, buf[:n], nonce, aad, chunkIndex) binary.BigEndian.PutUint32(lenBuf, uint32(len(sealed))) if _, werr := dst.Write(lenBuf); werr != nil { return werr } if _, werr := dst.Write(sealed); werr != nil { return werr } chunkIndex++ } if err != nil { if err == io.EOF { return nil } return err } } } func decryptCCMChunkedOrLegacyStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error { header := make([]byte, len(ccmStreamMagic)) n, err := io.ReadFull(src, header) if err != nil { if err == io.EOF { return nil } if err != io.ErrUnexpectedEOF { return err } return decryptCCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header[:n]), src), aead, nonce, aad) } if string(header) != ccmStreamMagic { return decryptCCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header), src), aead, nonce, aad) } return decryptCCMChunkedStream(dst, src, aead, nonce, aad) } func decryptCCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error { lenBuf := make([]byte, 4) maxChunkLen := uint32(gcmStreamChunkSize + aead.Overhead()) var chunkIndex uint64 for { _, err := io.ReadFull(src, lenBuf) if err != nil { if err == io.EOF { return nil } if err == io.ErrUnexpectedEOF { return io.ErrUnexpectedEOF } return err } chunkLen := binary.BigEndian.Uint32(lenBuf) if chunkLen < uint32(aead.Overhead()) || chunkLen > maxChunkLen { return ErrInvalidCCMStreamChunk } chunk := make([]byte, chunkLen) if _, err := io.ReadFull(src, chunk); err != nil { if err == io.ErrUnexpectedEOF { return io.ErrUnexpectedEOF } return err } plain, err := decryptCCMChunk(aead, chunk, nonce, aad, chunkIndex) if err != nil { return err } if _, err := dst.Write(plain); err != nil { return err } chunkIndex++ } } func decryptCCMLegacyBuffered(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error { enc, err := io.ReadAll(src) if err != nil { return err } plain, err := aead.Open(nil, nonce, enc, aad) if err != nil { return err } _, err = dst.Write(plain) return err }