1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-21 09:36:34 -05:00

Decryption reader for decoding vmess message

This commit is contained in:
V2Ray 2015-09-08 15:39:32 +02:00
parent 265e6e4dbd
commit 69bcce0b0d
2 changed files with 140 additions and 0 deletions

View File

@ -0,0 +1,73 @@
package vmess
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"fmt"
"io"
)
const (
blockSize = 16
)
type DecryptionReader struct {
cipher cipher.Block
reader io.Reader
buffer *bytes.Buffer
}
func NewDecryptionReader(reader io.Reader, key []byte) (*DecryptionReader, error) {
decryptionReader := new(DecryptionReader)
cipher, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
decryptionReader.cipher = cipher
decryptionReader.reader = reader
decryptionReader.buffer = bytes.NewBuffer(make([]byte, 0, 2 * blockSize))
return decryptionReader, nil
}
func (reader *DecryptionReader) readBlock() error {
buffer := make([]byte, blockSize)
nBytes, err := reader.reader.Read(buffer)
if err != nil {
return err
}
if nBytes < blockSize {
return fmt.Errorf("Expected to read %d bytes, but got %d bytes", blockSize, nBytes)
}
reader.cipher.Decrypt(buffer, buffer)
reader.buffer.Write(buffer)
return nil
}
func (reader *DecryptionReader) Read(p []byte) (int, error) {
if reader.buffer.Len() == 0 {
err := reader.readBlock()
if err != nil {
return 0, err
}
}
nBytes, err := reader.buffer.Read(p)
if err != nil {
return nBytes, err
}
if nBytes < len(p) {
err = reader.readBlock()
if err != nil {
return nBytes, err
}
moreBytes, err := reader.buffer.Read(p[nBytes:])
if err != nil {
return nBytes, err
}
nBytes += moreBytes
if nBytes != len(p) {
return nBytes, fmt.Errorf("Unable to read %d bytes", len(p))
}
}
return nBytes, err
}

View File

@ -0,0 +1,67 @@
package vmess
import (
"bytes"
"crypto/aes"
"crypto/rand"
mrand "math/rand"
"testing"
)
func randomBytes(p []byte, t *testing.T) {
nBytes, err := rand.Read(p)
if err != nil {
t.Fatal(err)
}
if nBytes != len(p) {
t.Error("Unable to generate %d bytes of random buffer", len(p))
}
}
func TestNormalReading(t *testing.T) {
testSize := 256
plaintext := make([]byte, testSize)
randomBytes(plaintext, t)
keySize := 16
key := make([]byte, keySize)
randomBytes(key, t)
cipher, err := aes.NewCipher(key)
if err != nil {
t.Fatal(err)
}
ciphertext := make([]byte, testSize)
for encryptSize := 0; encryptSize < testSize; encryptSize += blockSize {
cipher.Encrypt(ciphertext[encryptSize:], plaintext[encryptSize:])
}
ciphertextcopy := make([]byte, testSize)
copy(ciphertextcopy, ciphertext)
reader, err := NewDecryptionReader(bytes.NewReader(ciphertextcopy), key)
if err != nil {
t.Fatal(err)
}
readtext := make([]byte, testSize)
readSize := 0
for readSize < testSize {
nBytes := mrand.Intn(16) + 1
if nBytes > testSize - readSize {
nBytes = testSize - readSize
}
bytesRead, err := reader.Read(readtext[readSize:readSize + nBytes])
if err != nil {
t.Fatal(err)
}
if bytesRead != nBytes {
t.Errorf("Expected to read %d bytes, but only read %d bytes", nBytes, bytesRead)
}
readSize += nBytes
}
if ! bytes.Equal(readtext, plaintext) {
t.Errorf("Expected plaintext %v, but got %v", plaintext, readtext)
}
}