From f1ab89d9d88b79e57275eff631e2811cc2fc3c7b Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sun, 28 Oct 2018 09:08:43 +0100 Subject: [PATCH] long running reverse test case --- app/reverse/portal.go | 33 +++-- common/mux/client.go | 4 + common/mux/session.go | 1 + testing/scenarios/reverse_test.go | 215 +++++++++++++++++++++++++++++- 4 files changed, 239 insertions(+), 14 deletions(-) diff --git a/app/reverse/portal.go b/app/reverse/portal.go index 99f76955f..2ef8c85e9 100644 --- a/app/reverse/portal.go +++ b/app/reverse/portal.go @@ -68,10 +68,7 @@ func (s *Portal) HandleConnection(ctx context.Context, link *vio.Link) error { } if isDomain(outboundMeta.Target, s.domain) { - muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{ - MaxConcurrency: 0, - MaxConnection: 256, - }) + muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{}) if err != nil { return newError("failed to create mux client worker").Base(err).AtWarning() } @@ -157,7 +154,7 @@ func (p *StaticMuxPicker) PickAvailable() (*mux.ClientWorker, error) { var minIdx int = -1 var minConn uint32 = 9999 for i, w := range p.workers { - if w.IsFull() { + if w.draining { continue } if w.client.ActiveConnections() < minConn { @@ -166,6 +163,18 @@ func (p *StaticMuxPicker) PickAvailable() (*mux.ClientWorker, error) { } } + if minIdx == -1 { + for i, w := range p.workers { + if w.IsFull() { + continue + } + if w.client.ActiveConnections() < minConn { + minConn = w.client.ActiveConnections() + minIdx = i + } + } + } + if minIdx != -1 { return p.workers[minIdx].client, nil } @@ -181,10 +190,11 @@ func (p *StaticMuxPicker) AddWorker(worker *PortalWorker) { } type PortalWorker struct { - client *mux.ClientWorker - control *task.Periodic - writer buf.Writer - reader buf.Reader + client *mux.ClientWorker + control *task.Periodic + writer buf.Writer + reader buf.Reader + draining bool } func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) { @@ -221,14 +231,15 @@ func (w *PortalWorker) heartbeat() error { return newError("client worker stopped") } - if w.writer == nil { + if w.draining || w.writer == nil { return newError("already disposed") } msg := &Control{} msg.FillInRandom() - if w.client.IsClosing() { + if w.client.TotalConnections() > 256 { + w.draining = true msg.State = Control_DRAIN defer func() { diff --git a/common/mux/client.go b/common/mux/client.go index ecb7ddf8f..6df74a33e 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -190,6 +190,10 @@ func NewClientWorker(stream vio.Link, s ClientStrategy) (*ClientWorker, error) { return c, nil } +func (m *ClientWorker) TotalConnections() uint32 { + return uint32(m.sessionManager.Count()) +} + func (m *ClientWorker) ActiveConnections() uint32 { return uint32(m.sessionManager.Size()) } diff --git a/common/mux/session.go b/common/mux/session.go index c1ff40320..53c041510 100644 --- a/common/mux/session.go +++ b/common/mux/session.go @@ -61,6 +61,7 @@ func (m *SessionManager) Add(s *Session) { return } + m.count++ m.sessions[s.ID] = s } diff --git a/testing/scenarios/reverse_test.go b/testing/scenarios/reverse_test.go index 8e98ab2c8..aa311865d 100644 --- a/testing/scenarios/reverse_test.go +++ b/testing/scenarios/reverse_test.go @@ -6,13 +6,15 @@ import ( "testing" "time" + "v2ray.com/core" + "v2ray.com/core/app/log" + "v2ray.com/core/app/policy" + "v2ray.com/core/app/proxyman" "v2ray.com/core/app/reverse" "v2ray.com/core/app/router" - - "v2ray.com/core" - "v2ray.com/core/app/proxyman" "v2ray.com/core/common" "v2ray.com/core/common/compare" + clog "v2ray.com/core/common/log" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" "v2ray.com/core/common/serial" @@ -210,3 +212,210 @@ func TestReverseProxy(t *testing.T) { } wg.Wait() } + +func TestReverseProxyLongRunning(t *testing.T) { + tcpServer := tcp.Server{ + MsgProcessor: xor, + } + dest, err := tcpServer.Start() + common.Must(err) + + defer tcpServer.Close() + + userID := protocol.NewID(uuid.New()) + externalPort := tcp.PickPort() + reversePort := tcp.PickPort() + + serverConfig := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&log.Config{ + ErrorLogLevel: clog.Severity_Warning, + ErrorLogType: log.LogType_Console, + }), + serial.ToTypedMessage(&policy.Config{ + Level: map[uint32]*policy.Policy{ + 0: { + Timeout: &policy.Policy_Timeout{ + UplinkOnly: &policy.Second{Value: 0}, + DownlinkOnly: &policy.Second{Value: 0}, + }, + }, + }, + }), + serial.ToTypedMessage(&reverse.Config{ + PortalConfig: []*reverse.PortalConfig{ + { + Tag: "portal", + Domain: "test.v2ray.com", + }, + }, + }), + serial.ToTypedMessage(&router.Config{ + Rule: []*router.RoutingRule{ + { + Domain: []*router.Domain{ + {Type: router.Domain_Full, Value: "test.v2ray.com"}, + }, + Tag: "portal", + }, + { + InboundTag: []string{"external"}, + Tag: "portal", + }, + }, + }), + }, + Inbound: []*core.InboundHandlerConfig{ + { + Tag: "external", + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortRange: net.SinglePortRange(externalPort), + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + Address: net.NewIPOrDomain(dest.Address), + Port: uint32(dest.Port), + NetworkList: &net.NetworkList{ + Network: []net.Network{net.Network_TCP}, + }, + }), + }, + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortRange: net.SinglePortRange(reversePort), + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&inbound.Config{ + User: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vmess.Account{ + Id: userID.String(), + AlterId: 64, + }), + }, + }, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&blackhole.Config{}), + }, + }, + } + + clientPort := tcp.PickPort() + clientConfig := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&log.Config{ + ErrorLogLevel: clog.Severity_Warning, + ErrorLogType: log.LogType_Console, + }), + serial.ToTypedMessage(&policy.Config{ + Level: map[uint32]*policy.Policy{ + 0: { + Timeout: &policy.Policy_Timeout{ + UplinkOnly: &policy.Second{Value: 0}, + DownlinkOnly: &policy.Second{Value: 0}, + }, + }, + }, + }), + serial.ToTypedMessage(&reverse.Config{ + BridgeConfig: []*reverse.BridgeConfig{ + { + Tag: "bridge", + Domain: "test.v2ray.com", + }, + }, + }), + serial.ToTypedMessage(&router.Config{ + Rule: []*router.RoutingRule{ + { + Domain: []*router.Domain{ + {Type: router.Domain_Full, Value: "test.v2ray.com"}, + }, + Tag: "reverse", + }, + { + InboundTag: []string{"bridge"}, + Tag: "freedom", + }, + }, + }), + }, + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortRange: net.SinglePortRange(clientPort), + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + Address: net.NewIPOrDomain(dest.Address), + Port: uint32(dest.Port), + NetworkList: &net.NetworkList{ + Network: []net.Network{net.Network_TCP}, + }, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + Tag: "freedom", + ProxySettings: serial.ToTypedMessage(&freedom.Config{}), + }, + { + Tag: "reverse", + ProxySettings: serial.ToTypedMessage(&outbound.Config{ + Receiver: []*protocol.ServerEndpoint{ + { + Address: net.NewIPOrDomain(net.LocalHostIP), + Port: uint32(reversePort), + User: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vmess.Account{ + Id: userID.String(), + AlterId: 64, + SecuritySettings: &protocol.SecurityConfig{ + Type: protocol.SecurityType_AES128_GCM, + }, + }), + }, + }, + }, + }, + }), + }, + }, + } + + servers, err := InitializeServerConfigs(serverConfig, clientConfig) + common.Must(err) + + defer CloseAllServers(servers) + + for i := 0; i < 4096; i++ { + conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: int(externalPort), + }) + common.Must(err) + + payload := make([]byte, 1024) + rand.Read(payload) + + nBytes, err := conn.Write([]byte(payload)) + common.Must(err) + + if nBytes != len(payload) { + t.Error("only part of payload is written: ", nBytes) + } + + response := readFrom(conn, time.Second*5, 1024) + if err := compare.BytesEqualWithDetail(response, xor([]byte(payload))); err != nil { + t.Error(err) + } + + conn.Close() + } +}