diff --git a/common/net/transport.go b/common/net/transport.go index 7f9eea9ee..374317e0b 100644 --- a/common/net/transport.go +++ b/common/net/transport.go @@ -8,13 +8,18 @@ const ( bufferSize = 4 * 1024 ) +func ReadFrom(reader io.Reader) ([]byte, error) { + buffer := make([]byte, bufferSize) + nBytes, err := reader.Read(buffer) + return buffer[:nBytes], err +} + // ReaderToChan dumps all content from a given reader to a chan by constantly reading it until EOF. func ReaderToChan(stream chan<- []byte, reader io.Reader) error { for { - buffer := make([]byte, bufferSize) - nBytes, err := reader.Read(buffer) - if nBytes > 0 { - stream <- buffer[:nBytes] + data, err := ReadFrom(reader) + if len(data) > 0 { + stream <- data } if err != nil { return err diff --git a/proxy/socks/socks.go b/proxy/socks/socks.go index 1f2b138dd..d33879f8f 100644 --- a/proxy/socks/socks.go +++ b/proxy/socks/socks.go @@ -41,6 +41,9 @@ func (server *SocksServer) Listen(port uint16) error { } server.accepting = true go server.AcceptConnections(listener) + if server.config.UDPEnabled { + server.ListenUDP(port) + } return nil } @@ -66,120 +69,142 @@ func (server *SocksServer) HandleConnection(connection net.Conn) error { return err } - var dest v2net.Destination - - // TODO refactor this part - if errors.HasCode(err, 1000) { - result := protocol.Socks4RequestGranted - if auth4.Command == protocol.CmdBind { - result = protocol.Socks4RequestRejected - } - socks4Response := protocol.NewSocks4AuthenticationResponse(result, auth4.Port, auth4.IP[:]) - connection.Write(socks4Response.ToBytes(nil)) - - if result == protocol.Socks4RequestRejected { - return errors.NewInvalidOperationError("Socks4 command " + strconv.Itoa(int(auth4.Command))) - } - - dest = v2net.NewTCPDestination(v2net.IPAddress(auth4.IP[:], auth4.Port)) + if err != nil && errors.HasCode(err, 1000) { + return server.handleSocks4(reader, connection, auth4) } else { - expectedAuthMethod := protocol.AuthNotRequired - if server.config.IsPassword() { - expectedAuthMethod = protocol.AuthUserPass - } + return server.handleSocks5(reader, connection, auth) + } +} - if !auth.HasAuthMethod(expectedAuthMethod) { - authResponse := protocol.NewAuthenticationResponse(protocol.AuthNoMatchingMethod) - err = protocol.WriteAuthentication(connection, authResponse) - if err != nil { - log.Error("Socks failed to write authentication: %v", err) - return err - } - log.Warning("Socks client doesn't support allowed any auth methods.") - return errors.NewInvalidOperationError("Unsupported auth methods.") - } +func (server *SocksServer) handleSocks5(reader io.Reader, writer io.Writer, auth protocol.Socks5AuthenticationRequest) error { + expectedAuthMethod := protocol.AuthNotRequired + if server.config.IsPassword() { + expectedAuthMethod = protocol.AuthUserPass + } - authResponse := protocol.NewAuthenticationResponse(expectedAuthMethod) - err = protocol.WriteAuthentication(connection, authResponse) + if !auth.HasAuthMethod(expectedAuthMethod) { + authResponse := protocol.NewAuthenticationResponse(protocol.AuthNoMatchingMethod) + err := protocol.WriteAuthentication(writer, authResponse) if err != nil { log.Error("Socks failed to write authentication: %v", err) return err } - if server.config.IsPassword() { - upRequest, err := protocol.ReadUserPassRequest(reader) - if err != nil { - log.Error("Socks failed to read username and password: %v", err) - return err - } - status := byte(0) - if !upRequest.IsValid(server.config.Username, server.config.Password) { - status = byte(0xFF) - } - upResponse := protocol.NewSocks5UserPassResponse(status) - err = protocol.WriteUserPassResponse(connection, upResponse) - if err != nil { - log.Error("Socks failed to write user pass response: %v", err) - return err - } - if status != byte(0) { - err = errors.NewAuthenticationError(upRequest.AuthDetail()) - log.Warning(err.Error()) - return err - } - } + log.Warning("Socks client doesn't support allowed any auth methods.") + return errors.NewInvalidOperationError("Unsupported auth methods.") + } - request, err := protocol.ReadRequest(reader) + authResponse := protocol.NewAuthenticationResponse(expectedAuthMethod) + err := protocol.WriteAuthentication(writer, authResponse) + if err != nil { + log.Error("Socks failed to write authentication: %v", err) + return err + } + if server.config.IsPassword() { + upRequest, err := protocol.ReadUserPassRequest(reader) if err != nil { - log.Error("Socks failed to read request: %v", err) + log.Error("Socks failed to read username and password: %v", err) return err } + status := byte(0) + if !upRequest.IsValid(server.config.Username, server.config.Password) { + status = byte(0xFF) + } + upResponse := protocol.NewSocks5UserPassResponse(status) + err = protocol.WriteUserPassResponse(writer, upResponse) + if err != nil { + log.Error("Socks failed to write user pass response: %v", err) + return err + } + if status != byte(0) { + err = errors.NewAuthenticationError(upRequest.AuthDetail()) + log.Warning(err.Error()) + return err + } + } + request, err := protocol.ReadRequest(reader) + if err != nil { + log.Error("Socks failed to read request: %v", err) + return err + } + + response := protocol.NewSocks5Response() + + if request.Command == protocol.CmdBind || (!server.config.UDPEnabled && request.Command == protocol.CmdUdpAssociate) { response := protocol.NewSocks5Response() - - if request.Command == protocol.CmdBind || request.Command == protocol.CmdUdpAssociate { - response := protocol.NewSocks5Response() - response.Error = protocol.ErrorCommandNotSupported - err = protocol.WriteResponse(connection, response) - if err != nil { - log.Error("Socks failed to write response: %v", err) - return err - } - log.Warning("Unsupported socks command %d", request.Command) - return errors.NewInvalidOperationError("Socks command " + strconv.Itoa(int(request.Command))) - } - - response.Error = protocol.ErrorSuccess - response.Port = request.Port - response.AddrType = request.AddrType - switch response.AddrType { - case protocol.AddrTypeIPv4: - copy(response.IPv4[:], request.IPv4[:]) - case protocol.AddrTypeIPv6: - copy(response.IPv6[:], request.IPv6[:]) - case protocol.AddrTypeDomain: - response.Domain = request.Domain - } - err = protocol.WriteResponse(connection, response) + response.Error = protocol.ErrorCommandNotSupported + err = protocol.WriteResponse(writer, response) if err != nil { log.Error("Socks failed to write response: %v", err) return err } - - dest = request.Destination() + log.Warning("Unsupported socks command %d", request.Command) + return errors.NewInvalidOperationError("Socks command " + strconv.Itoa(int(request.Command))) } - ray := server.vPoint.DispatchToOutbound(v2net.NewPacket(dest, nil, true)) + response.Error = protocol.ErrorSuccess + response.Port = request.Port + response.AddrType = request.AddrType + switch response.AddrType { + case protocol.AddrTypeIPv4: + copy(response.IPv4[:], request.IPv4[:]) + case protocol.AddrTypeIPv6: + copy(response.IPv6[:], request.IPv6[:]) + case protocol.AddrTypeDomain: + response.Domain = request.Domain + } + err = protocol.WriteResponse(writer, response) + if err != nil { + log.Error("Socks failed to write response: %v", err) + return err + } + + dest := request.Destination() + data, err := v2net.ReadFrom(reader) + if err != nil { + return err + } + + packet := v2net.NewPacket(dest, data, true) + server.transport(reader, writer, packet) + return nil +} + +func (server *SocksServer) handleSocks4(reader io.Reader, writer io.Writer, auth protocol.Socks4AuthenticationRequest) error { + result := protocol.Socks4RequestGranted + if auth.Command == protocol.CmdBind { + result = protocol.Socks4RequestRejected + } + socks4Response := protocol.NewSocks4AuthenticationResponse(result, auth.Port, auth.IP[:]) + writer.Write(socks4Response.ToBytes(nil)) + + if result == protocol.Socks4RequestRejected { + return errors.NewInvalidOperationError("Socks4 command " + strconv.Itoa(int(auth.Command))) + } + + dest := v2net.NewTCPDestination(v2net.IPAddress(auth.IP[:], auth.Port)) + data, err := v2net.ReadFrom(reader) + if err != nil { + return err + } + + packet := v2net.NewPacket(dest, data, true) + server.transport(reader, writer, packet) + return nil +} + +func (server *SocksServer) transport(reader io.Reader, writer io.Writer, firstPacket v2net.Packet) { + ray := server.vPoint.DispatchToOutbound(firstPacket) input := ray.InboundInput() output := ray.InboundOutput() - var readFinish, writeFinish sync.Mutex - readFinish.Lock() - writeFinish.Lock() - go dumpInput(reader, input, &readFinish) - go dumpOutput(connection, output, &writeFinish) - writeFinish.Lock() + var inputFinish, outputFinish sync.Mutex + inputFinish.Lock() + outputFinish.Lock() - return nil + go dumpInput(reader, input, &inputFinish) + go dumpOutput(writer, output, &outputFinish) + outputFinish.Lock() } func dumpInput(reader io.Reader, input chan<- []byte, finish *sync.Mutex) {