diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index 5c1f23a43..319b3adaf 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -65,11 +65,11 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol } port := net.PortFromBytes(buffer.BytesRange(2, 4)) address := net.IPAddress(buffer.BytesRange(4, 8)) - if _, err := readUntilNull(reader); /* user id */ err != nil { + if _, err := ReadUntilNull(reader); /* user id */ err != nil { return nil, err } if address.IP()[0] == 0x00 { - domain, err := readUntilNull(reader) + domain, err := ReadUntilNull(reader) if err != nil { return nil, newError("failed to read domain for socks 4a").Base(err) } @@ -113,7 +113,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol } if expectedAuth == authPassword { - username, password, err := readUsernamePassword(reader) + username, password, err := ReadUsernamePassword(reader) if err != nil { return nil, newError("failed to read username and password for authentication").Base(err) } @@ -183,7 +183,13 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol return nil, newError("unknown Socks version: ", version) } -func readUsernamePassword(reader io.Reader) (string, string, error) { +// ReadUsernamePassword reads Socks 5 username/password message from the given reader. +// +----+------+----------+------+----------+ +// |VER | ULEN | UNAME | PLEN | PASSWD | +// +----+------+----------+------+----------+ +// | 1 | 1 | 1 to 255 | 1 | 1 to 255 | +// +----+------+----------+------+----------+ +func ReadUsernamePassword(reader io.Reader) (string, string, error) { buffer := buf.New() defer buffer.Release() @@ -212,19 +218,21 @@ func readUsernamePassword(reader io.Reader) (string, string, error) { return username, password, nil } -func readUntilNull(reader io.Reader) (string, error) { - var b [256]byte - size := 0 +// ReadUntilNull reads content from given reader, until a null (0x00) byte. +func ReadUntilNull(reader io.Reader) (string, error) { + b := buf.New() + defer b.Release() + for { - _, err := reader.Read(b[size : size+1]) + _, err := b.ReadFullFrom(reader, 1) if err != nil { return "", err } - if b[size] == 0x00 { - return string(b[:size]), nil + if b.Byte(b.Len()-1) == 0x00 { + b.Resize(0, b.Len()-1) + return b.String(), nil } - size++ - if size == 256 { + if b.IsFull() { return "", newError("buffer overrun") } } diff --git a/proxy/socks/protocol_test.go b/proxy/socks/protocol_test.go index 73f517f75..21c24c400 100644 --- a/proxy/socks/protocol_test.go +++ b/proxy/socks/protocol_test.go @@ -1,6 +1,7 @@ package socks_test import ( + "bytes" "testing" "v2ray.com/core/common/buf" @@ -33,3 +34,76 @@ func TestUDPEncoding(t *testing.T) { assert(err, IsNil) assert(decodedPayload[0].Bytes(), Equals, content) } + +func TestReadUsernamePassword(t *testing.T) { + testCases := []struct { + Input []byte + Username string + Password string + Error bool + }{ + { + Input: []byte{0x05, 0x01, 'a', 0x02, 'b', 'c'}, + Username: "a", + Password: "bc", + }, + { + Input: []byte{0x05, 0x18, 'a', 0x02, 'b', 'c'}, + Error: true, + }, + } + + for _, testCase := range testCases { + reader := bytes.NewReader(testCase.Input) + username, password, err := ReadUsernamePassword(reader) + if testCase.Error { + if err == nil { + t.Error("for input: ", testCase.Input, " expect error, but actually nil") + } + } else { + if err != nil { + t.Error("for input: ", testCase.Input, " expect no error, but actually ", err.Error()) + } + if testCase.Username != username { + t.Error("for input: ", testCase.Input, " expect username ", testCase.Username, " but actually ", username) + } + if testCase.Password != password { + t.Error("for input: ", testCase.Input, " expect passowrd ", testCase.Password, " but actually ", password) + } + } + } +} + +func TestReadUntilNull(t *testing.T) { + testCases := []struct { + Input []byte + Output string + Error bool + }{ + { + Input: []byte{'a', 'b', 0x00}, + Output: "ab", + }, + { + Input: []byte{'a'}, + Error: true, + }, + } + + for _, testCase := range testCases { + reader := bytes.NewReader(testCase.Input) + value, err := ReadUntilNull(reader) + if testCase.Error { + if err == nil { + t.Error("for input: ", testCase.Input, " expect error, but actually nil") + } + } else { + if err != nil { + t.Error("for input: ", testCase.Input, " expect no error, but actually ", err.Error()) + } + if testCase.Output != value { + t.Error("for input: ", testCase.Input, " expect output ", testCase.Output, " but actually ", value) + } + } + } +}