diff --git a/common/crypto/io.go b/common/crypto/io.go index 887532556..f5a79b288 100644 --- a/common/crypto/io.go +++ b/common/crypto/io.go @@ -3,21 +3,26 @@ package crypto import ( "crypto/cipher" "io" + + "github.com/v2ray/v2ray-core/common" ) -type cryptionReader struct { +type CryptionReader struct { stream cipher.Stream reader io.Reader } -func NewCryptionReader(stream cipher.Stream, reader io.Reader) io.Reader { - return &cryptionReader{ +func NewCryptionReader(stream cipher.Stream, reader io.Reader) *CryptionReader { + return &CryptionReader{ stream: stream, reader: reader, } } -func (this *cryptionReader) Read(data []byte) (int, error) { +func (this *CryptionReader) Read(data []byte) (int, error) { + if this.reader == nil { + return 0, common.ErrorAlreadyReleased + } nBytes, err := this.reader.Read(data) if nBytes > 0 { this.stream.XORKeyStream(data[:nBytes], data[:nBytes]) @@ -25,19 +30,32 @@ func (this *cryptionReader) Read(data []byte) (int, error) { return nBytes, err } -type cryptionWriter struct { +func (this *CryptionReader) Release() { + this.reader = nil + this.stream = nil +} + +type CryptionWriter struct { stream cipher.Stream writer io.Writer } -func NewCryptionWriter(stream cipher.Stream, writer io.Writer) io.Writer { - return &cryptionWriter{ +func NewCryptionWriter(stream cipher.Stream, writer io.Writer) *CryptionWriter { + return &CryptionWriter{ stream: stream, writer: writer, } } -func (this *cryptionWriter) Write(data []byte) (int, error) { +func (this *CryptionWriter) Write(data []byte) (int, error) { + if this.writer == nil { + return 0, common.ErrorAlreadyReleased + } this.stream.XORKeyStream(data, data) return this.writer.Write(data) } + +func (this *CryptionWriter) Release() { + this.writer = nil + this.stream = nil +}