diff --git a/proxy/socks/client.go b/proxy/socks/client.go index 31d5f268c..bdb24e12c 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -146,10 +146,21 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter if err := conn.SetDeadline(time.Now().Add(p.Timeouts.Handshake)); err != nil { newError("failed to set deadline for handshake").Base(err).WriteToLog(session.ExportIDToError(ctx)) } - udpRequest, err := ClientHandshake(request, conn, conn) - if err != nil { - return newError("failed to establish connection to server").AtWarning().Base(err) + + var udpRequest *protocol.RequestHeader + var err error + if request.Version == socks4Version { + err = ClientHandshake4(request, conn, conn) + if err != nil { + return newError("failed to establish connection to server").AtWarning().Base(err) + } + } else { + udpRequest, err = ClientHandshake(request, conn, conn) + if err != nil { + return newError("failed to establish connection to server").AtWarning().Base(err) + } } + if udpRequest != nil { if udpRequest.Address == net.AnyIP || udpRequest.Address == net.AnyIPv6 { udpRequest.Address = dest.Address diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index 32c8c690b..e88c03aa0 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -537,3 +537,46 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i return nil, nil } + +func ClientHandshake4(request *protocol.RequestHeader, reader io.Reader, writer io.Writer) error { + b := buf.New() + defer b.Release() + + common.Must2(b.Write([]byte{socks4Version, cmdTCPConnect})) + portBytes := b.Extend(2) + binary.BigEndian.PutUint16(portBytes, request.Port.Value()) + switch request.Address.Family() { + case net.AddressFamilyIPv4: + common.Must2(b.Write(request.Address.IP())) + case net.AddressFamilyDomain: + common.Must2(b.Write([]byte{0x00, 0x00, 0x00, 0x01})) + case net.AddressFamilyIPv6: + return newError("ipv6 is not supported in socks4") + default: + panic("Unknown family type.") + } + if request.User != nil { + account := request.User.Account.(*Account) + common.Must2(b.WriteString(account.Username)) + } + common.Must(b.WriteByte(0x00)) + if request.Address.Family() == net.AddressFamilyDomain { + common.Must2(b.WriteString(request.Address.Domain())) + common.Must(b.WriteByte(0x00)) + } + if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil { + return err + } + + b.Clear() + if _, err := b.ReadFullFrom(reader, 8); err != nil { + return err + } + if b.Byte(0) != 0x00 { + return newError("unexpected version of the reply code: ", b.Byte(0)) + } + if b.Byte(1) != socks4RequestGranted { + return newError("server rejects request: ", b.Byte(1)) + } + return nil +}