starcrypto/symm/ccm_stream.go

128 lines
3.1 KiB
Go

package symm
import (
"bytes"
"crypto/cipher"
"encoding/binary"
"errors"
"io"
)
// ccmStreamMagic is the marker written at the very start of a chunked CCM
// stream. The decoder compares the first len(ccmStreamMagic) bytes of its
// input against it to decide between the chunked and the legacy format.
const ccmStreamMagic = "SCC1"

// ErrInvalidCCMStreamChunk is returned while decoding a chunked CCM stream
// when a chunk's length prefix is smaller than the AEAD overhead or larger
// than the maximum sealed chunk size.
var ErrInvalidCCMStreamChunk = errors.New("invalid ccm stream chunk")
// encryptCCMChunk seals one plaintext chunk and returns the sealed bytes in
// a freshly allocated slice.
//
// The nonce actually used is obtained from deriveChunkNonce(nonce,
// chunkIndex); aad is passed to the AEAD unchanged as additional data.
func encryptCCMChunk(aead cipher.AEAD, plain, nonce, aad []byte, chunkIndex uint64) []byte {
	return aead.Seal(nil, deriveChunkNonce(nonce, chunkIndex), plain, aad)
}
// decryptCCMChunk opens one sealed chunk and returns its plaintext in a
// freshly allocated slice.
//
// The nonce actually used is obtained from deriveChunkNonce(nonce,
// chunkIndex), and aad is passed to the AEAD unchanged as additional data.
// Whatever error aead.Open reports is returned as-is.
func decryptCCMChunk(aead cipher.AEAD, ciphertext, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
	return aead.Open(nil, deriveChunkNonce(nonce, chunkIndex), ciphertext, aad)
}
// encryptCCMChunkedStream encrypts src into dst using the chunked CCM stream
// format: the ccmStreamMagic marker followed by zero or more records, each
// consisting of a 4-byte big-endian length and that many bytes of sealed
// chunk.
//
// Chunks are sealed by encryptCCMChunk with a zero-based, monotonically
// increasing chunk index and the same aad. The read buffer is filled
// completely before a chunk is sealed, so every chunk except possibly the
// last carries gcmStreamChunkSize bytes of plaintext even when src returns
// short reads. An empty src produces only the magic marker.
//
// Write errors are returned immediately. A read error other than io.EOF is
// returned after any bytes read before it have been sealed and written.
func encryptCCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
	if _, err := dst.Write([]byte(ccmStreamMagic)); err != nil {
		return err
	}

	buf := make([]byte, gcmStreamChunkSize)
	lenBuf := make([]byte, 4)
	var chunkIndex uint64
	for {
		n, err := fillCCMChunk(src, buf)
		if n > 0 {
			sealed := encryptCCMChunk(aead, buf[:n], nonce, aad, chunkIndex)
			binary.BigEndian.PutUint32(lenBuf, uint32(len(sealed)))
			if _, werr := dst.Write(lenBuf); werr != nil {
				return werr
			}
			if _, werr := dst.Write(sealed); werr != nil {
				return werr
			}
			chunkIndex++
		}
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
	}
}

// fillCCMChunk reads from src until buf is full or src reports an error. It
// returns the number of bytes placed in buf together with the error exactly
// as src returned it (io.EOF included). io.ReadFull is deliberately not used
// here: it would turn a short final chunk into io.ErrUnexpectedEOF, which
// could not be told apart from a source that genuinely fails with that error.
func fillCCMChunk(src io.Reader, buf []byte) (int, error) {
	filled := 0
	for filled < len(buf) {
		n, err := src.Read(buf[filled:])
		filled += n
		if err != nil {
			return filled, err
		}
	}
	return filled, nil
}
// decryptCCMChunkedOrLegacyStream decrypts src into dst, choosing the format
// by looking at the first len(ccmStreamMagic) bytes of the input:
//
//   - input that starts with ccmStreamMagic is handed to
//     decryptCCMChunkedStream, with the marker already consumed;
//   - any other input, including input shorter than the marker, is handed to
//     decryptCCMLegacyBuffered with the bytes read so far replayed in front
//     of the rest of src;
//   - completely empty input returns nil without writing anything.
//
// A read error other than io.EOF / io.ErrUnexpectedEOF while reading the
// prefix is returned unchanged.
//
// NOTE(review): the empty-input case succeeds without any authentication
// taking place; confirm that callers rely on this before changing it.
func decryptCCMChunkedOrLegacyStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
	prefix := make([]byte, len(ccmStreamMagic))
	got, readErr := io.ReadFull(src, prefix)

	switch readErr {
	case nil:
		if string(prefix) == ccmStreamMagic {
			return decryptCCMChunkedStream(dst, src, aead, nonce, aad)
		}
	case io.EOF:
		// Nothing at all was read.
		return nil
	case io.ErrUnexpectedEOF:
		// Fewer bytes than the marker: cannot be chunked, try legacy below.
	default:
		return readErr
	}

	// Legacy format: put back the bytes consumed while sniffing the marker.
	replay := io.MultiReader(bytes.NewReader(prefix[:got]), src)
	return decryptCCMLegacyBuffered(dst, replay, aead, nonce, aad)
}
// decryptCCMChunkedStream decrypts the record section of a chunked CCM
// stream (the bytes that follow the magic marker) from src and writes the
// recovered plaintext to dst, chunk by chunk.
//
// Each record is a 4-byte big-endian length followed by that many bytes of
// sealed chunk. Chunks are opened by decryptCCMChunk with a zero-based,
// monotonically increasing chunk index and the same aad.
//
// It returns:
//   - nil when src ends cleanly on a record boundary;
//   - io.ErrUnexpectedEOF when src ends inside a length prefix or inside a
//     chunk body;
//   - ErrInvalidCCMStreamChunk when a length prefix is smaller than
//     aead.Overhead() or larger than gcmStreamChunkSize + aead.Overhead();
//   - the error from decryptCCMChunk when a chunk fails to open;
//   - any other read or write error unchanged.
//
// NOTE(review): plaintext of earlier chunks has already been written to dst
// when a later chunk fails, and a stream that is cut exactly on a record
// boundary is accepted as complete.
func decryptCCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
	minChunkLen := uint32(aead.Overhead())
	maxChunkLen := uint32(gcmStreamChunkSize + aead.Overhead())

	lenBuf := make([]byte, 4)
	var buf []byte // reused across chunks; grown on demand, never beyond maxChunkLen
	var chunkIndex uint64
	for {
		if _, err := io.ReadFull(src, lenBuf); err != nil {
			if err == io.EOF {
				// No more records.
				return nil
			}
			return err
		}

		chunkLen := binary.BigEndian.Uint32(lenBuf)
		if chunkLen < minChunkLen || chunkLen > maxChunkLen {
			return ErrInvalidCCMStreamChunk
		}

		if cap(buf) < int(chunkLen) {
			buf = make([]byte, chunkLen)
		}
		chunk := buf[:chunkLen]
		if _, err := io.ReadFull(src, chunk); err != nil {
			if err == io.EOF {
				// A length prefix was read but no body followed: the stream
				// is truncated, which must not look like a clean end.
				return io.ErrUnexpectedEOF
			}
			return err
		}

		plain, err := decryptCCMChunk(aead, chunk, nonce, aad, chunkIndex)
		if err != nil {
			return err
		}
		if _, err := dst.Write(plain); err != nil {
			return err
		}
		chunkIndex++
	}
}
// decryptCCMLegacyBuffered decrypts a legacy (non-chunked) ciphertext: it
// reads all of src into memory, opens it as a single AEAD message with the
// given nonce and aad, and writes the resulting plaintext to dst.
//
// The first error encountered while reading, opening or writing is returned
// unchanged. Nothing is written to dst unless the message opens
// successfully.
func decryptCCMLegacyBuffered(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
	sealed, readErr := io.ReadAll(src)
	if readErr != nil {
		return readErr
	}

	opened, openErr := aead.Open(nil, nonce, sealed, aad)
	if openErr != nil {
		return openErr
	}

	if _, writeErr := dst.Write(opened); writeErr != nil {
		return writeErr
	}
	return nil
}