diff --git a/io/vmess/decryptionreader.go b/io/vmess/decryptionreader.go new file mode 100644 index 000000000..016a45bf4 --- /dev/null +++ b/io/vmess/decryptionreader.go @@ -0,0 +1,73 @@ +package vmess + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "fmt" + "io" +) + +const ( + blockSize = 16 +) + +type DecryptionReader struct { + cipher cipher.Block + reader io.Reader + buffer *bytes.Buffer +} + +func NewDecryptionReader(reader io.Reader, key []byte) (*DecryptionReader, error) { + decryptionReader := new(DecryptionReader) + cipher, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + decryptionReader.cipher = cipher + decryptionReader.reader = reader + decryptionReader.buffer = bytes.NewBuffer(make([]byte, 0, 2 * blockSize)) + return decryptionReader, nil +} + +func (reader *DecryptionReader) readBlock() error { + buffer := make([]byte, blockSize) + nBytes, err := reader.reader.Read(buffer) + if err != nil { + return err + } + if nBytes < blockSize { + return fmt.Errorf("Expected to read %d bytes, but got %d bytes", blockSize, nBytes) + } + reader.cipher.Decrypt(buffer, buffer) + reader.buffer.Write(buffer) + return nil +} + +func (reader *DecryptionReader) Read(p []byte) (int, error) { + if reader.buffer.Len() == 0 { + err := reader.readBlock() + if err != nil { + return 0, err + } + } + nBytes, err := reader.buffer.Read(p) + if err != nil { + return nBytes, err + } + if nBytes < len(p) { + err = reader.readBlock() + if err != nil { + return nBytes, err + } + moreBytes, err := reader.buffer.Read(p[nBytes:]) + if err != nil { + return nBytes, err + } + nBytes += moreBytes + if nBytes != len(p) { + return nBytes, fmt.Errorf("Unable to read %d bytes", len(p)) + } + } + return nBytes, err +} diff --git a/io/vmess/decryptionreader_test.go b/io/vmess/decryptionreader_test.go new file mode 100644 index 000000000..b61bac5fb --- /dev/null +++ b/io/vmess/decryptionreader_test.go @@ -0,0 +1,67 @@ +package vmess + +import ( + "bytes" + "crypto/aes" + "crypto/rand" + mrand "math/rand" + "testing" +) + +func randomBytes(p []byte, t *testing.T) { + nBytes, err := rand.Read(p) + if err != nil { + t.Fatal(err) + } + if nBytes != len(p) { + t.Error("Unable to generate %d bytes of random buffer", len(p)) + } +} + +func TestNormalReading(t *testing.T) { + testSize := 256 + plaintext := make([]byte, testSize) + randomBytes(plaintext, t) + + keySize := 16 + key := make([]byte, keySize) + randomBytes(key, t) + + cipher, err := aes.NewCipher(key) + if err != nil { + t.Fatal(err) + } + + ciphertext := make([]byte, testSize) + for encryptSize := 0; encryptSize < testSize; encryptSize += blockSize { + cipher.Encrypt(ciphertext[encryptSize:], plaintext[encryptSize:]) + } + + ciphertextcopy := make([]byte, testSize) + copy(ciphertextcopy, ciphertext) + + reader, err := NewDecryptionReader(bytes.NewReader(ciphertextcopy), key) + if err != nil { + t.Fatal(err) + } + + readtext := make([]byte, testSize) + readSize := 0 + for readSize < testSize { + nBytes := mrand.Intn(16) + 1 + if nBytes > testSize - readSize { + nBytes = testSize - readSize + } + bytesRead, err := reader.Read(readtext[readSize:readSize + nBytes]) + if err != nil { + t.Fatal(err) + } + if bytesRead != nBytes { + t.Errorf("Expected to read %d bytes, but only read %d bytes", nBytes, bytesRead) + } + readSize += nBytes + } + if ! bytes.Equal(readtext, plaintext) { + t.Errorf("Expected plaintext %v, but got %v", plaintext, readtext) + } +} \ No newline at end of file