diff --git a/common/buf/buffer.go b/common/buf/buffer.go index afc86eb82..2d6705e6b 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -76,8 +76,8 @@ func (b *Buffer) Bytes() []byte { // Reset resets the content of the Buffer with a supplier. func (b *Buffer) Reset(writer Supplier) error { - b.start = 0 nBytes, err := writer(b.v) + b.start = 0 b.end = nBytes return err } diff --git a/common/buf/io.go b/common/buf/io.go index f0024b48d..37d52e4a8 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -40,6 +40,13 @@ func ReadFullFrom(reader io.Reader, size int) Supplier { } } +// ReadAtLeastFrom create a Supplier to read at least size bytes from the given io.Reader. +func ReadAtLeastFrom(reader io.Reader, size int) Supplier { + return func(b []byte) (int, error) { + return io.ReadAtLeast(reader, b, size) + } +} + func copyInternal(timer signal.ActivityTimer, reader Reader, writer Writer) error { for { buffer, err := reader.Read() diff --git a/common/crypto/auth.go b/common/crypto/auth.go index aa68fe422..9e0ebfe1b 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -65,6 +65,7 @@ type AuthenticationReader struct { buffer *buf.Buffer reader io.Reader sizeParser ChunkSizeDecoder + size int } const ( @@ -77,56 +78,98 @@ func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, re buffer: buf.NewLocal(readerBufferSize), reader: reader, sizeParser: sizeParser, + size: -1, } } -func (r *AuthenticationReader) readChunk() error { - if err := r.buffer.Reset(buf.ReadFullFrom(r.reader, r.sizeParser.SizeBytes())); err != nil { - return err +func (r *AuthenticationReader) readSize() error { + if r.size >= 0 { + return nil } - size, err := r.sizeParser.Decode(r.buffer.Bytes()) + + sizeBytes := r.sizeParser.SizeBytes() + if r.buffer.Len() < sizeBytes { + r.buffer.Reset(buf.ReadFrom(r.buffer)) + delta := sizeBytes - r.buffer.Len() + if err := r.buffer.AppendSupplier(buf.ReadAtLeastFrom(r.reader, delta)); err != nil { + return err + } + } + size, err := r.sizeParser.Decode(r.buffer.BytesTo(sizeBytes)) if err != nil { return err } - if size > readerBufferSize { - return newError("size too large ", size).AtWarning() - } - - if int(size) == r.auth.Overhead() { - return io.EOF - } - - if err := r.buffer.Reset(buf.ReadFullFrom(r.reader, int(size))); err != nil { - return err - } - - b, err := r.auth.Open(r.buffer.BytesTo(0), r.buffer.Bytes()) - if err != nil { - return err - } - r.buffer.Slice(0, len(b)) + r.size = int(size) + r.buffer.SliceFrom(sizeBytes) return nil } -func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) { - if r.buffer.IsEmpty() { - if err := r.readChunk(); err != nil { +func (r *AuthenticationReader) readChunk(waitForData bool) ([]byte, error) { + if err := r.readSize(); err != nil { + return nil, err + } + if r.size > readerBufferSize-r.sizeParser.SizeBytes() { + return nil, newError("size too large ", r.size).AtWarning() + } + + if r.size == r.auth.Overhead() { + return nil, io.EOF + } + + if r.buffer.Len() < r.size { + if !waitForData { + return nil, io.ErrNoProgress + } + r.buffer.Reset(buf.ReadFrom(r.buffer)) + + delta := r.size - r.buffer.Len() + if err := r.buffer.AppendSupplier(buf.ReadAtLeastFrom(r.reader, delta)); err != nil { return nil, err } } - mb := buf.NewMultiBuffer() - for !r.buffer.IsEmpty() { - b := buf.New() - b.AppendSupplier(buf.ReadFrom(r.buffer)) - mb.Append(b) + b, err := r.auth.Open(r.buffer.BytesTo(0), r.buffer.BytesTo(r.size)) + if err != nil { + return nil, err } + r.buffer.SliceFrom(r.size) + r.size = -1 + return b, nil +} + +func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) { + b, err := r.readChunk(true) + if err != nil { + return nil, err + } + + mb := buf.NewMultiBuffer() + + appendBytes := func(b []byte) { + for len(b) > 0 { + buffer := buf.New() + n, _ := buffer.Write(b) + b = b[n:] + mb.Append(buffer) + } + } + appendBytes(b) + + for r.buffer.Len() >= r.sizeParser.SizeBytes() { + b, err := r.readChunk(false) + if err != nil { + break + } + appendBytes(b) + } + return mb, nil } type AuthenticationWriter struct { auth Authenticator - buffer []byte + payload []byte + buffer *buf.Buffer writer io.Writer sizeParser ChunkSizeEncoder } @@ -134,37 +177,51 @@ type AuthenticationWriter struct { func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer) *AuthenticationWriter { return &AuthenticationWriter{ auth: auth, - buffer: make([]byte, 32*1024), + payload: make([]byte, 1024), + buffer: buf.NewLocal(readerBufferSize), writer: writer, sizeParser: sizeParser, } } -func (w *AuthenticationWriter) writeInternal(b []byte) error { - sizeBytes := w.sizeParser.SizeBytes() - cipherChunk, err := w.auth.Seal(w.buffer[sizeBytes:sizeBytes], b) - if err != nil { - return err - } +func (w *AuthenticationWriter) append(b []byte) { + encryptedSize := len(b) + w.auth.Overhead() - w.sizeParser.Encode(uint16(len(cipherChunk)), w.buffer[:0]) - _, err = w.writer.Write(w.buffer[:sizeBytes+len(cipherChunk)]) + w.buffer.AppendSupplier(func(bb []byte) (int, error) { + w.sizeParser.Encode(uint16(encryptedSize), bb[:0]) + return w.sizeParser.SizeBytes(), nil + }) + + w.buffer.AppendSupplier(func(bb []byte) (int, error) { + w.auth.Seal(bb[:0], b) + return encryptedSize, nil + }) +} + +func (w *AuthenticationWriter) flush() error { + _, err := w.writer.Write(w.buffer.Bytes()) + w.buffer.Clear() return err } func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error { defer mb.Release() - const StartIndex = 17 * 1024 for { - payloadLen, _ := mb.Read(w.buffer[StartIndex:]) - err := w.writeInternal(w.buffer[StartIndex : StartIndex+payloadLen]) - if err != nil { - return err + n, _ := mb.Read(w.payload) + w.append(w.payload[:n]) + if w.buffer.Len() > readerBufferSize-2*1024 { + if err := w.flush(); err != nil { + return err + } } if mb.IsEmpty() { break } } + + if !w.buffer.IsEmpty() { + return w.flush() + } return nil } diff --git a/common/crypto/auth_test.go b/common/crypto/auth_test.go index 9992b4016..501d6a39a 100644 --- a/common/crypto/auth_test.go +++ b/common/crypto/auth_test.go @@ -23,13 +23,13 @@ func TestAuthenticationReaderWriter(t *testing.T) { aead, err := cipher.NewGCM(block) assert.Error(err).IsNil() - rawPayload := make([]byte, 8192) + rawPayload := make([]byte, 8192*10) rand.Read(rawPayload) - payload := buf.NewLocal(8192) + payload := buf.NewLocal(8192 * 10) payload.Append(rawPayload) - cache := buf.NewLocal(16 * 1024) + cache := buf.NewLocal(160 * 1024) iv := make([]byte, 12) rand.Read(iv) @@ -42,7 +42,7 @@ func TestAuthenticationReaderWriter(t *testing.T) { }, PlainChunkSizeParser{}, cache) assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil() - assert.Int(cache.Len()).Equals(8210) + assert.Int(cache.Len()).Equals(83360) assert.Error(writer.Write(buf.NewMultiBuffer())).IsNil() assert.Error(err).IsNil() @@ -54,11 +54,16 @@ func TestAuthenticationReaderWriter(t *testing.T) { AdditionalDataGenerator: &NoOpBytesGenerator{}, }, PlainChunkSizeParser{}, cache) - mb, err := reader.Read() - assert.Error(err).IsNil() - assert.Int(mb.Len()).Equals(len(rawPayload)) + mb := buf.NewMultiBuffer() - mbContent := make([]byte, 8192) + for mb.Len() < len(rawPayload) { + mb2, err := reader.Read() + assert.Error(err).IsNil() + + mb.AppendMulti(mb2) + } + + mbContent := make([]byte, 8192*10) mb.Read(mbContent) assert.Bytes(mbContent).Equals(rawPayload) diff --git a/testing/assert/bytes.go b/testing/assert/bytes.go index 3f813557c..6559735f3 100644 --- a/testing/assert/bytes.go +++ b/testing/assert/bytes.go @@ -30,7 +30,7 @@ func (subject *BytesSubject) Equals(expectation []byte) { } for idx, b := range expectation { if subject.value[idx] != b { - subject.FailWithMessage(fmt.Sprint("Bytes are different: ", b, " vs ", subject.value[idx])) + subject.FailWithMessage(fmt.Sprint("Bytes are different: ", b, " vs ", subject.value[idx], " at pos ", idx)) return } }