From 791ac780f02b504beea5bd7c5122467e65ab86e8 Mon Sep 17 00:00:00 2001 From: V2Ray Date: Fri, 11 Sep 2015 17:27:36 +0200 Subject: [PATCH] Refactor vmess internal struct for better readability --- io/vmess/vmess.go | 205 +++++++++++------------------------------ io/vmess/vmess_test.go | 60 ++++-------- log/log.go | 45 +++++++++ net/vmess/vmessin.go | 24 +++-- net/vmess/vmessout.go | 30 ++---- vid.go | 5 +- vpoint.go | 7 -- 7 files changed, 147 insertions(+), 229 deletions(-) create mode 100644 log/log.go diff --git a/io/vmess/vmess.go b/io/vmess/vmess.go index b67f315e4..16dca5efc 100644 --- a/io/vmess/vmess.go +++ b/io/vmess/vmess.go @@ -11,7 +11,6 @@ import ( "io" _ "log" mrand "math/rand" - "net" "github.com/v2ray/v2ray-core" v2io "github.com/v2ray/v2ray-core/io" @@ -33,127 +32,15 @@ var ( // VMessRequest implements the request message of VMess protocol. It only contains // the header of a request message. The data part will be handled by conection // handler directly, in favor of data streaming. -// 1 Version -// 16 UserHash -// 16 Request IV -// 16 Request Key -// 4 Response Header -// 1 Command -// 2 Port -// 1 Address Type -// 256 Target Address -type VMessRequest [312]byte - -func (r *VMessRequest) Version() byte { - return r[0] -} - -func (r *VMessRequest) SetVersion(version byte) *VMessRequest { - r[0] = version - return r -} - -func (r *VMessRequest) UserHash() []byte { - return r[1:17] -} - -func (r *VMessRequest) RequestIV() []byte { - return r[17:33] -} - -func (r *VMessRequest) RequestKey() []byte { - return r[33:49] -} - -func (r *VMessRequest) ResponseHeader() []byte { - return r[49:53] -} - -func (r *VMessRequest) Command() byte { - return r[53] -} - -func (r *VMessRequest) SetCommand(command byte) *VMessRequest { - r[53] = command - return r -} - -func (r *VMessRequest) Port() uint16 { - return binary.BigEndian.Uint16(r.portBytes()) -} - -func (r *VMessRequest) portBytes() []byte { - return r[54:56] -} - -func (r *VMessRequest) SetPort(port uint16) *VMessRequest { - binary.BigEndian.PutUint16(r.portBytes(), port) - return r -} - -func (r *VMessRequest) targetAddressType() byte { - return r[56] -} - -func (r *VMessRequest) Destination() v2net.VAddress { - switch r.targetAddressType() { - case addrTypeIPv4: - fallthrough - case addrTypeIPv6: - return v2net.IPAddress(r.targetAddressBytes(), r.Port()) - case addrTypeDomain: - return v2net.DomainAddress(r.TargetAddress(), r.Port()) - default: - panic("Unpexected address type") - } -} - -func (r *VMessRequest) TargetAddress() string { - switch r.targetAddressType() { - case addrTypeIPv4: - return net.IP(r[57:61]).String() - case addrTypeIPv6: - return net.IP(r[57:73]).String() - case addrTypeDomain: - domainLength := int(r[57]) - return string(r[58 : 58+domainLength]) - default: - panic("Unexpected address type") - } -} - -func (r *VMessRequest) targetAddressBytes() []byte { - switch r.targetAddressType() { - case addrTypeIPv4: - return r[57:61] - case addrTypeIPv6: - return r[57:73] - case addrTypeDomain: - domainLength := int(r[57]) - return r[57 : 58+domainLength] - default: - panic("Unexpected address type") - } -} - -func (r *VMessRequest) SetIPv4(ipv4 []byte) *VMessRequest { - r[56] = addrTypeIPv4 - copy(r[57:], ipv4) - return r -} - -func (r *VMessRequest) SetIPv6(ipv6 []byte) *VMessRequest { - r[56] = addrTypeIPv6 - copy(r[57:], ipv6) - return r -} - -func (r *VMessRequest) SetDomain(domain string) *VMessRequest { - r[56] = addrTypeDomain - r[57] = byte(len(domain)) - copy(r[58:], []byte(domain)) - return r +type VMessRequest struct { + Version byte + UserId core.VID + RequestIV [16]byte + RequestKey [16]byte + ResponseHeader [4]byte + Command byte + Address v2net.VAddress } type VMessRequestReader struct { @@ -169,26 +56,30 @@ func NewVMessRequestReader(vUserSet *core.VUserSet) *VMessRequestReader { func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { request := new(VMessRequest) - nBytes, err := reader.Read(request[0:17] /* version + user hash */) + buffer := make([]byte, 256) + nBytes, err := reader.Read(buffer[0:1]) if err != nil { return nil, err } - if nBytes != 17 { - err = fmt.Errorf("Unexpected length of header %d", nBytes) + // TODO: verify version number + request.Version = buffer[0] + + nBytes, err = reader.Read(buffer[:len(request.UserId)]) + if err != nil { return nil, err } - // TODO: verify version number - userId, valid := r.vUserSet.IsValidUserId(request.UserHash()) + + userId, valid := r.vUserSet.IsValidUserId(buffer[:nBytes]) if !valid { return nil, ErrorInvalidUser } + request.UserId = *userId decryptor, err := NewDecryptionReader(reader, userId.Hash([]byte("PWD")), make([]byte, blockSize)) if err != nil { return nil, err } - buffer := make([]byte, 300) nBytes, err = decryptor.Read(buffer[0:1]) if err != nil { return nil, err @@ -204,15 +95,15 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { } // TODO: check number of bytes returned - _, err = decryptor.Read(request.RequestIV()) + _, err = decryptor.Read(request.RequestIV[:]) if err != nil { return nil, err } - _, err = decryptor.Read(request.RequestKey()) + _, err = decryptor.Read(request.RequestKey[:]) if err != nil { return nil, err } - _, err = decryptor.Read(request.ResponseHeader()) + _, err = decryptor.Read(request.ResponseHeader[:]) if err != nil { return nil, err } @@ -220,13 +111,13 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { if err != nil { return nil, err } - request.SetCommand(buffer[0]) + request.Command = buffer[0] _, err = decryptor.Read(buffer[0:2]) if err != nil { return nil, err } - request.SetPort(binary.BigEndian.Uint16(buffer[0:2])) + port := binary.BigEndian.Uint16(buffer[0:2]) _, err = decryptor.Read(buffer[0:1]) if err != nil { @@ -238,13 +129,13 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { if err != nil { return nil, err } - request.SetIPv4(buffer[1:5]) + request.Address = v2net.IPAddress(buffer[1:5], port) case addrTypeIPv6: _, err = decryptor.Read(buffer[1:17]) if err != nil { return nil, err } - request.SetIPv6(buffer[1:17]) + request.Address = v2net.IPAddress(buffer[1:17], port) case addrTypeDomain: _, err = decryptor.Read(buffer[1:2]) if err != nil { @@ -255,7 +146,7 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { if err != nil { return nil, err } - request.SetDomain(string(buffer[2 : 2+domainLength])) + request.Address = v2net.DomainAddress(string(buffer[2:2+domainLength]), port) } _, err = decryptor.Read(buffer[0:1]) if err != nil { @@ -271,19 +162,17 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { } type VMessRequestWriter struct { - vUserSet *core.VUserSet } -func NewVMessRequestWriter(vUserSet *core.VUserSet) *VMessRequestWriter { +func NewVMessRequestWriter() *VMessRequestWriter { writer := new(VMessRequestWriter) - writer.vUserSet = vUserSet return writer } func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) error { buffer := make([]byte, 0, 300) - buffer = append(buffer, request.Version()) - buffer = append(buffer, request.UserHash()...) + buffer = append(buffer, request.Version) + buffer = append(buffer, request.UserId.Hash([]byte("ASK"))...) encryptionBegin := len(buffer) @@ -296,13 +185,27 @@ func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) erro buffer = append(buffer, byte(randomLength)) buffer = append(buffer, randomContent...) - buffer = append(buffer, request.RequestIV()...) - buffer = append(buffer, request.RequestKey()...) - buffer = append(buffer, request.ResponseHeader()...) - buffer = append(buffer, request.Command()) - buffer = append(buffer, request.portBytes()...) - buffer = append(buffer, request.targetAddressType()) - buffer = append(buffer, request.targetAddressBytes()...) + buffer = append(buffer, request.RequestIV[:]...) + buffer = append(buffer, request.RequestKey[:]...) + buffer = append(buffer, request.ResponseHeader[:]...) + buffer = append(buffer, request.Command) + + portBytes := make([]byte, 2) + binary.BigEndian.PutUint16(portBytes, request.Address.Port) + buffer = append(buffer, portBytes...) + + switch { + case request.Address.IsIPv4(): + buffer = append(buffer, addrTypeIPv4) + buffer = append(buffer, request.Address.IP...) + case request.Address.IsIPv6(): + buffer = append(buffer, addrTypeIPv6) + buffer = append(buffer, request.Address.IP...) + case request.Address.IsDomain(): + buffer = append(buffer, addrTypeDomain) + buffer = append(buffer, byte(len(request.Address.Domain))) + buffer = append(buffer, []byte(request.Address.Domain)...) + } paddingLength := blockSize - 1 - (len(buffer)-encryptionBegin)%blockSize if paddingLength == 0 { @@ -317,11 +220,7 @@ func (w *VMessRequestWriter) Write(writer io.Writer, request *VMessRequest) erro buffer = append(buffer, paddingBuffer...) encryptionEnd := len(buffer) - userId, valid := w.vUserSet.IsValidUserId(request.UserHash()) - if !valid { - return ErrorInvalidUser - } - aesCipher, err := aes.NewCipher(userId.Hash([]byte("PWD"))) + aesCipher, err := aes.NewCipher(request.UserId.Hash([]byte("PWD"))) if err != nil { return err } @@ -344,6 +243,6 @@ type VMessResponse [4]byte func NewVMessResponse(request *VMessRequest) *VMessResponse { response := new(VMessResponse) - copy(response[:], request.ResponseHeader()) + copy(response[:], request.ResponseHeader[:]) return response } diff --git a/io/vmess/vmess_test.go b/io/vmess/vmess_test.go index 91af1651e..057c7ae6d 100644 --- a/io/vmess/vmess_test.go +++ b/io/vmess/vmess_test.go @@ -6,9 +6,13 @@ import ( "testing" "github.com/v2ray/v2ray-core" + v2net "github.com/v2ray/v2ray-core/net" + "github.com/v2ray/v2ray-core/testing/unit" ) func TestVMessSerialization(t *testing.T) { + assert := unit.Assert(t) + userId, err := core.UUIDToVID("2b2966ac-16aa-4fbf-8d81-c5f172a3da51") if err != nil { t.Fatal(err) @@ -18,31 +22,29 @@ func TestVMessSerialization(t *testing.T) { userSet.AddUser(core.VUser{userId}) request := new(VMessRequest) - request.SetVersion(byte(0x01)) - userHash := userId.Hash([]byte("ASK")) - copy(request.UserHash(), userHash) + request.Version = byte(0x01) + request.UserId = userId - _, err = rand.Read(request.RequestIV()) + _, err = rand.Read(request.RequestIV[:]) if err != nil { t.Fatal(err) } - _, err = rand.Read(request.RequestKey()) + _, err = rand.Read(request.RequestKey[:]) if err != nil { t.Fatal(err) } - _, err = rand.Read(request.ResponseHeader()) + _, err = rand.Read(request.ResponseHeader[:]) if err != nil { t.Fatal(err) } - request.SetCommand(byte(0x01)) - request.SetPort(80) - request.SetDomain("v2ray.com") + request.Command = byte(0x01) + request.Address = v2net.DomainAddress("v2ray.com", 80) buffer := bytes.NewBuffer(make([]byte, 0, 300)) - requestWriter := NewVMessRequestWriter(userSet) + requestWriter := NewVMessRequestWriter() err = requestWriter.Write(buffer, request) if err != nil { t.Fatal(err) @@ -54,35 +56,11 @@ func TestVMessSerialization(t *testing.T) { t.Fatal(err) } - if actualRequest.Version() != byte(0x01) { - t.Errorf("Expected Version 1, but got %d", actualRequest.Version()) - } - - if !bytes.Equal(request.UserHash(), actualRequest.UserHash()) { - t.Errorf("Expected user hash %v, but got %v", request.UserHash(), actualRequest.UserHash()) - } - - if !bytes.Equal(request.RequestIV(), actualRequest.RequestIV()) { - t.Errorf("Expected request IV %v, but got %v", request.RequestIV(), actualRequest.RequestIV()) - } - - if !bytes.Equal(request.RequestKey(), actualRequest.RequestKey()) { - t.Errorf("Expected request Key %v, but got %v", request.RequestKey(), actualRequest.RequestKey()) - } - - if !bytes.Equal(request.ResponseHeader(), actualRequest.ResponseHeader()) { - t.Errorf("Expected response header %v, but got %v", request.ResponseHeader(), actualRequest.ResponseHeader()) - } - - if actualRequest.Command() != byte(0x01) { - t.Errorf("Expected command 1, but got %d", actualRequest.Command()) - } - - if actualRequest.Port() != 80 { - t.Errorf("Expected port 80, but got %d", actualRequest.Port()) - } - - if actualRequest.TargetAddress() != "v2ray.com" { - t.Errorf("Expected target address v2ray.com, but got %s", actualRequest.TargetAddress()) - } + assert.Byte(actualRequest.Version).Named("Version").Equals(byte(0x01)) + assert.Bytes(actualRequest.UserId[:]).Named("UserId").Equals(request.UserId[:]) + assert.Bytes(actualRequest.RequestIV[:]).Named("RequestIV").Equals(request.RequestIV[:]) + assert.Bytes(actualRequest.RequestKey[:]).Named("RequestKey").Equals(request.RequestKey[:]) + assert.Bytes(actualRequest.ResponseHeader[:]).Named("ResponseHeader").Equals(request.ResponseHeader[:]) + assert.Byte(actualRequest.Command).Named("Command").Equals(request.Command) + assert.String(actualRequest.Address.String()).Named("Address").Equals(request.Address.String()) } diff --git a/log/log.go b/log/log.go new file mode 100644 index 000000000..7c5a7fbb0 --- /dev/null +++ b/log/log.go @@ -0,0 +1,45 @@ +package log + +import ( + "errors" + "fmt" + "log" +) + +const ( + DebugLevel = LogLevel(0) + InfoLevel = LogLevel(1) + WarningLevel = LogLevel(2) + ErrorLevel = LogLevel(3) +) + +var logLevel = WarningLevel + +type LogLevel int + +func SetLogLevel(level LogLevel) { + logLevel = level +} + +func writeLog(data string, level LogLevel) { + if level < logLevel { + return + } + log.Print(data) +} + +func Info(format string, v ...interface{}) { + data := fmt.Sprintf(format, v) + writeLog("[Info]"+data, InfoLevel) +} + +func Warning(format string, v ...interface{}) { + data := fmt.Sprintf(format, v) + writeLog("[Warning]"+data, WarningLevel) +} + +func Error(format string, v ...interface{}) error { + data := fmt.Sprintf(format, v) + writeLog("[Error]"+data, ErrorLevel) + return errors.New(data) +} diff --git a/net/vmess/vmessin.go b/net/vmess/vmessin.go index 8ad01ceab..3447c1482 100644 --- a/net/vmess/vmessin.go +++ b/net/vmess/vmessin.go @@ -12,12 +12,14 @@ import ( type VMessInboundHandler struct { vPoint *core.VPoint + clients *core.VUserSet accepting bool } -func NewVMessInboundHandler(vp *core.VPoint) *VMessInboundHandler { +func NewVMessInboundHandler(vp *core.VPoint, clients *core.VUserSet) *VMessInboundHandler { handler := new(VMessInboundHandler) handler.vPoint = vp + handler.clients = clients return handler } @@ -45,7 +47,7 @@ func (handler *VMessInboundHandler) AcceptConnections(listener net.Listener) err func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error { defer connection.Close() - reader := vmessio.NewVMessRequestReader(handler.vPoint.UserSet) + reader := vmessio.NewVMessRequestReader(handler.clients) request, err := reader.Read(connection) if err != nil { @@ -55,8 +57,8 @@ func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error response := vmessio.NewVMessResponse(request) connection.Write(response[:]) - requestKey := request.RequestKey() - requestIV := request.RequestIV() + requestKey := request.RequestKey[:] + requestIV := request.RequestIV[:] responseKey := md5.Sum(requestKey) responseIV := md5.Sum(requestIV) @@ -70,7 +72,7 @@ func (handler *VMessInboundHandler) HandleConnection(connection net.Conn) error return err } - ray := handler.vPoint.NewInboundConnectionAccepted(request.Destination()) + ray := handler.vPoint.NewInboundConnectionAccepted(request.Address) input := ray.InboundInput() output := ray.InboundOutput() finish := make(chan bool, 2) @@ -112,8 +114,18 @@ func (handler *VMessInboundHandler) waitForFinish(finish <-chan bool) { } type VMessInboundHandlerFactory struct { + allowedClients *core.VUserSet +} + +func NewVMessInboundHandlerFactory(clients []core.VUser) *VMessInboundHandlerFactory { + factory := new(VMessInboundHandlerFactory) + factory.allowedClients = core.NewVUserSet() + for _, user := range clients { + factory.allowedClients.AddUser(user) + } + return factory } func (factory *VMessInboundHandlerFactory) Create(vp *core.VPoint) *VMessInboundHandler { - return NewVMessInboundHandler(vp) + return NewVMessInboundHandler(vp, factory.allowedClients) } diff --git a/net/vmess/vmessout.go b/net/vmess/vmessout.go index 51742902a..49c07fc9c 100644 --- a/net/vmess/vmessout.go +++ b/net/vmess/vmessout.go @@ -45,23 +45,13 @@ func (handler *VMessOutboundHandler) Start(ray core.OutboundVRay) error { vNextAddress, vNextUser := handler.pickVNext() request := new(vmessio.VMessRequest) - request.SetVersion(vmessio.Version) - copy(request.UserHash(), vNextUser.Id.Hash([]byte("ASK"))) - rand.Read(request.RequestIV()) - rand.Read(request.RequestKey()) - rand.Read(request.ResponseHeader()) - request.SetCommand(byte(0x01)) - request.SetPort(handler.dest.Port) - - address := handler.dest - switch { - case address.IsIPv4(): - request.SetIPv4(address.IP) - case address.IsIPv6(): - request.SetIPv6(address.IP) - case address.IsDomain(): - request.SetDomain(address.Domain) - } + request.Version = vmessio.Version + request.UserId = vNextUser.Id + rand.Read(request.RequestIV[:]) + rand.Read(request.RequestKey[:]) + rand.Read(request.ResponseHeader[:]) + request.Command = byte(0x01) + request.Address = handler.dest conn, err := net.Dial("tcp", vNextAddress.String()) if err != nil { @@ -69,11 +59,11 @@ func (handler *VMessOutboundHandler) Start(ray core.OutboundVRay) error { } defer conn.Close() - requestWriter := vmessio.NewVMessRequestWriter(handler.vPoint.UserSet) + requestWriter := vmessio.NewVMessRequestWriter() requestWriter.Write(conn, request) - requestKey := request.RequestKey() - requestIV := request.RequestIV() + requestKey := request.RequestKey[:] + requestIV := request.RequestIV[:] responseKey := md5.Sum(requestKey) responseIV := md5.Sum(requestIV) diff --git a/vid.go b/vid.go index 0a5e959c4..f14879a09 100644 --- a/vid.go +++ b/vid.go @@ -3,7 +3,8 @@ package core import ( "crypto/md5" "encoding/hex" - "fmt" + + "github.com/v2ray/v2ray-core/log" ) // The ID of en entity, in the form of an UUID. @@ -23,7 +24,7 @@ var byteGroups = []int{8, 4, 4, 4, 12} func UUIDToVID(uuid string) (v VID, err error) { text := []byte(uuid) if len(text) < 32 { - err = fmt.Errorf("uuid: invalid UUID string: %s", text) + err = log.Error("uuid: invalid UUID string: %s", text) return } diff --git a/vpoint.go b/vpoint.go index b6f818551..de5695c94 100644 --- a/vpoint.go +++ b/vpoint.go @@ -9,7 +9,6 @@ import ( // VPoint is an single server in V2Ray system. type VPoint struct { Config VConfig - UserSet *VUserSet ichFactory InboundConnectionHandlerFactory ochFactory OutboundConnectionHandlerFactory } @@ -19,12 +18,6 @@ type VPoint struct { func NewVPoint(config *VConfig, ichFactory InboundConnectionHandlerFactory, ochFactory OutboundConnectionHandlerFactory) (*VPoint, error) { var vpoint = new(VPoint) vpoint.Config = *config - vpoint.UserSet = NewVUserSet() - - for _, user := range vpoint.Config.AllowedClients { - vpoint.UserSet.AddUser(user) - } - vpoint.ichFactory = ichFactory vpoint.ochFactory = ochFactory