diff --git a/id.go b/id.go index 6d09a6a46..72685814a 100644 --- a/id.go +++ b/id.go @@ -4,7 +4,6 @@ import ( "crypto/hmac" "crypto/md5" "encoding/hex" - "hash" mrand "math/rand" "time" @@ -19,7 +18,7 @@ const ( type ID struct { String string Bytes []byte - hasher hash.Hash + cmdKey []byte } func NewID(id string) (ID, error) { @@ -27,8 +26,13 @@ func NewID(id string) (ID, error) { if err != nil { return ID{}, log.Error("Failed to parse id %s", id) } - hasher := hmac.New(md5.New, idBytes) - return ID{id, idBytes, hasher}, nil + + md5hash := md5.New() + md5hash.Write(idBytes) + md5hash.Write([]byte("c48619fe-8f02-49e0-b9e9-edf763e17e21")) + cmdKey := md5.Sum(nil) + + return ID{id, idBytes, cmdKey[:]}, nil } func (v ID) TimeRangeHash(rangeSec int) []byte { @@ -54,10 +58,13 @@ func (v ID) TimeHash(timeSec int64) []byte { } func (v ID) Hash(data []byte) []byte { - v.hasher.Write(data) - hash := v.hasher.Sum(nil) - v.hasher.Reset() - return hash + hasher := hmac.New(md5.New, v.Bytes) + hasher.Write(data) + return hasher.Sum(nil) +} + +func (v ID) CmdKey() []byte { + return v.cmdKey } var byteGroups = []int{8, 4, 4, 4, 12} diff --git a/io/vmess/vmess.go b/io/vmess/vmess.go index 1aea3d58e..df73706a2 100644 --- a/io/vmess/vmess.go +++ b/io/vmess/vmess.go @@ -13,6 +13,7 @@ import ( "github.com/v2ray/v2ray-core" v2io "github.com/v2ray/v2ray-core/io" + "github.com/v2ray/v2ray-core/log" v2net "github.com/v2ray/v2ray-core/net" ) @@ -24,12 +25,11 @@ const ( Version = byte(0x01) blockSize = 16 - - CryptoMessage = "c48619fe-8f02-49e0-b9e9-edf763e17e21" ) var ( - ErrorInvalidUser = errors.New("Invalid User") + ErrorInvalidUser = errors.New("Invalid User") + ErrorInvalidVerion = errors.New("Invalid Version") emptyIV = make([]byte, blockSize) ) @@ -62,17 +62,13 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { request := new(VMessRequest) buffer := make([]byte, 256) - nBytes, err := reader.Read(buffer[0:1]) - if err != nil { - return nil, err - } - // TODO: verify version number - request.Version = buffer[0] - nBytes, err = reader.Read(buffer[:core.IDBytesLen]) + nBytes, err := reader.Read(buffer[:core.IDBytesLen]) if err != nil { return nil, err } + + log.Debug("Read user hash: %v", buffer[:nBytes]) userId, valid := r.vUserSet.GetUser(buffer[:nBytes]) if !valid { @@ -80,7 +76,7 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { } request.UserId = *userId - aesCipher, err := aes.NewCipher(userId.Hash([]byte(CryptoMessage))) + aesCipher, err := aes.NewCipher(userId.CmdKey()) if err != nil { return nil, err } @@ -105,6 +101,17 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { return nil, err } + nBytes, err = decryptor.Read(buffer[0:1]) + if err != nil { + return nil, err + } + + request.Version = buffer[0] + if request.Version != Version { + log.Error("Unknown VMess version %d", request.Version) + return nil, ErrorInvalidVerion + } + // TODO: check number of bytes returned _, err = decryptor.Read(request.RequestIV[:]) if err != nil { @@ -182,8 +189,10 @@ func NewVMessRequestWriter() *VMessRequestWriter { func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) error { buffer := make([]byte, 0, 300) - buffer = append(buffer, request.Version) - buffer = append(buffer, request.UserId.TimeRangeHash(30)...) + userHash := request.UserId.TimeRangeHash(30) + + log.Debug("Writing userhash: %v", userHash) + buffer = append(buffer, userHash...) encryptionBegin := len(buffer) @@ -196,6 +205,7 @@ func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) erro buffer = append(buffer, byte(randomLength)) buffer = append(buffer, randomContent...) + buffer = append(buffer, request.Version) buffer = append(buffer, request.RequestIV[:]...) buffer = append(buffer, request.RequestKey[:]...) buffer = append(buffer, request.ResponseHeader[:]...) @@ -231,7 +241,7 @@ func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) erro buffer = append(buffer, paddingBuffer...) encryptionEnd := len(buffer) - aesCipher, err := aes.NewCipher(request.UserId.Hash([]byte(CryptoMessage))) + aesCipher, err := aes.NewCipher(request.UserId.CmdKey()) if err != nil { return err } diff --git a/io/vmess/vmess_test.go b/io/vmess/vmess_test.go index 1d1900da5..1c1b0aef3 100644 --- a/io/vmess/vmess_test.go +++ b/io/vmess/vmess_test.go @@ -3,6 +3,7 @@ package vmess import ( "bytes" "crypto/rand" + "io/ioutil" "testing" "github.com/v2ray/v2ray-core" @@ -51,7 +52,7 @@ func TestVMessSerialization(t *testing.T) { t.Fatal(err) } - userSet.UserHashes[string(buffer.Bytes()[1:17])] = 0 + userSet.UserHashes[string(buffer.Bytes()[:16])] = 0 requestReader := NewVMessRequestReader(&userSet) actualRequest, err := requestReader.Read(buffer) @@ -67,3 +68,25 @@ func TestVMessSerialization(t *testing.T) { assert.Byte(actualRequest.Command).Named("Command").Equals(request.Command) assert.String(actualRequest.Address.String()).Named("Address").Equals(request.Address.String()) } + +func BenchmarkVMessRequestWriting(b *testing.B) { + userId, _ := core.NewID("2b2966ac-16aa-4fbf-8d81-c5f172a3da51") + userSet := mocks.MockUserSet{[]core.ID{}, make(map[string]int)} + userSet.AddUser(core.User{userId}) + + request := new(VMessRequest) + request.Version = byte(0x01) + request.UserId = userId + + rand.Read(request.RequestIV[:]) + rand.Read(request.RequestKey[:]) + rand.Read(request.ResponseHeader[:]) + + request.Command = byte(0x01) + request.Address = v2net.DomainAddress("v2ray.com", 80) + + requestWriter := NewVMessRequestWriter() + for i := 0; i < b.N; i++ { + requestWriter.Write(ioutil.Discard, request) + } +} diff --git a/net/vmess/vmessin.go b/net/vmess/vmessin.go index 09d0bbf51..22e4a451c 100644 --- a/net/vmess/vmessin.go +++ b/net/vmess/vmessin.go @@ -55,6 +55,7 @@ func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error request, err := reader.Read(connection) if err != nil { + log.Debug("Failed to parse VMess request: %v", err) return err } log.Debug("Received request for %s", request.Address.String()) diff --git a/net/vmess/vmessout.go b/net/vmess/vmessout.go index e4ef7ebb6..9f8533665 100644 --- a/net/vmess/vmessout.go +++ b/net/vmess/vmessout.go @@ -74,11 +74,15 @@ func (handler *VMessOutboundHandler) startCommunicate(request *vmessio.VMessRequ return err } defer conn.Close() + + input := ray.OutboundInput() + output := ray.OutboundOutput() requestWriter := vmessio.NewVMessRequestWriter() err = requestWriter.Write(conn, request) if err != nil { log.Error("Failed to write VMess request: %v", err) + close(output) return err } @@ -90,6 +94,7 @@ func (handler *VMessOutboundHandler) startCommunicate(request *vmessio.VMessRequ response := vmessio.VMessResponse{} nBytes, err := conn.Read(response[:]) if err != nil { + close(output) log.Error("Failed to read VMess response (%d bytes): %v", nBytes, err) return err } @@ -98,17 +103,17 @@ func (handler *VMessOutboundHandler) startCommunicate(request *vmessio.VMessRequ encryptRequestWriter, err := v2io.NewAesEncryptWriter(requestKey, requestIV, conn) if err != nil { + close(output) log.Error("Failed to create encrypt writer: %v", err) return err } decryptResponseReader, err := v2io.NewAesDecryptReader(responseKey[:], responseIV[:], conn) if err != nil { + close(output) log.Error("Failed to create decrypt reader: %v", err) return err } - input := ray.OutboundInput() - output := ray.OutboundOutput() readFinish := make(chan bool) writeFinish := make(chan bool)