diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index d9dbffb47..356891369 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -139,34 +139,13 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol buffer.Clear() request.Version = socks5Version - switch addrType { - case addrTypeIPv4: - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil { - return nil, err - } - request.Address = net.IPAddress(buffer.Bytes()) - case addrTypeIPv6: - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil { - return nil, err - } - request.Address = net.IPAddress(buffer.Bytes()) - case addrTypeDomain: - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil { - return nil, err - } - domainLength := int(buffer.Byte(0)) - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil { - return nil, err - } - request.Address = net.ParseAddress(string(buffer.BytesFrom(-domainLength))) - default: - return nil, newError("Unknown address type: ", addrType) - } - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil { - return nil, err + addr, port, err := ReadAddress(buffer, addrType, reader) + if err != nil { + return nil, newError("failed to read address").Base(err) } - request.Port = net.PortFromBytes(buffer.BytesFrom(-2)) + request.Address = addr + request.Port = port responseAddress := net.AnyIP responsePort := net.Port(1717)