diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index c78835cf8..c195d4a76 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -74,11 +74,13 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou } h.mux = &mux.ClientManager{ Picker: &mux.IncrementalWorkerPicker{ - New: func() (*mux.ClientWorker, error) { - return mux.NewClientWorker(proxyHandler, h, mux.ClientStrategy{ + Factory: &mux.DialingWorkerFactory{ + Proxy: proxyHandler, + Dialer: h, + Strategy: mux.ClientStrategy{ MaxConcurrency: config.Concurrency, MaxConnection: 128, - }) + }, }, }, } diff --git a/common/mux/client.go b/common/mux/client.go index 136f2dd81..95fa9bdd3 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -41,7 +41,7 @@ type WorkerPicker interface { } type IncrementalWorkerPicker struct { - New func() (*ClientWorker, error) + Factory ClientWorkerFactory access sync.Mutex workers []*ClientWorker @@ -82,7 +82,7 @@ func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, error, bool) { p.cleanup() - worker, err := p.New() + worker, err := p.Factory.Create() if err != nil { return nil, err, false } @@ -107,6 +107,46 @@ func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) { return worker, err } +type ClientWorkerFactory interface { + Create() (*ClientWorker, error) +} + +type DialingWorkerFactory struct { + Proxy proxy.Outbound + Dialer internet.Dialer + Strategy ClientStrategy +} + +func (f *DialingWorkerFactory) Create() (*ClientWorker, error) { + opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)} + uplinkReader, upLinkWriter := pipe.New(opts...) + downlinkReader, downlinkWriter := pipe.New(opts...) + + c, err := NewClientWorker(vio.Link{ + Reader: downlinkReader, + Writer: upLinkWriter, + }, f.Strategy) + + if err != nil { + return nil, err + } + + go func(p proxy.Outbound, d internet.Dialer, c common.Closable) { + ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{ + Target: net.TCPDestination(muxCoolAddress, muxCoolPort), + }) + ctx, cancel := context.WithCancel(ctx) + + if err := p.Process(ctx, &vio.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil { + errors.New("failed to handler mux client connection").Base(err).WriteToLog() + } + common.Must(c.Close()) + cancel() + }(f.Proxy, f.Dialer, c.done) + + return c, nil +} + type ClientStrategy struct { MaxConcurrency uint32 MaxConnection uint32 @@ -123,36 +163,17 @@ var muxCoolAddress = net.DomainAddress("v1.mux.cool") var muxCoolPort = net.Port(9527) // NewClientWorker creates a new mux.Client. -func NewClientWorker(p proxy.Outbound, dialer internet.Dialer, s ClientStrategy) (*ClientWorker, error) { - ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{ - Target: net.TCPDestination(muxCoolAddress, muxCoolPort), - }) - ctx, cancel := context.WithCancel(ctx) - - opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)} - uplinkReader, upLinkWriter := pipe.New(opts...) - downlinkReader, downlinkWriter := pipe.New(opts...) - +func NewClientWorker(stream vio.Link, s ClientStrategy) (*ClientWorker, error) { c := &ClientWorker{ sessionManager: NewSessionManager(), - link: vio.Link{ - Reader: downlinkReader, - Writer: upLinkWriter, - }, - done: done.New(), - strategy: s, + link: stream, + done: done.New(), + strategy: s, } - go func() { - if err := p.Process(ctx, &vio.Link{Reader: uplinkReader, Writer: downlinkWriter}, dialer); err != nil { - errors.New("failed to handler mux client connection").Base(err).WriteToLog() - } - common.Must(c.done.Close()) - cancel() - }() - go c.fetchOutput() go c.monitor() + return c, nil } @@ -221,12 +242,21 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) { } } -func (m *ClientWorker) IsFull() bool { +func (m *ClientWorker) IsClosing() bool { sm := m.sessionManager - if m.strategy.MaxConcurrency > 0 && sm.Size() >= int(m.strategy.MaxConcurrency) { + if m.strategy.MaxConnection > 0 && sm.Count() >= int(m.strategy.MaxConnection) { return true } - if m.strategy.MaxConnection > 0 && sm.Count() >= int(m.strategy.MaxConnection) { + return false +} + +func (m *ClientWorker) IsFull() bool { + if m.IsClosing() { + return true + } + + sm := m.sessionManager + if m.strategy.MaxConcurrency > 0 && sm.Size() >= int(m.strategy.MaxConcurrency) { return true } return false diff --git a/common/mux/client_test.go b/common/mux/client_test.go index 8c86af811..86378d406 100644 --- a/common/mux/client_test.go +++ b/common/mux/client_test.go @@ -1,17 +1,28 @@ package mux_test import ( + "context" "testing" + "time" + "github.com/golang/mock/gomock" + "v2ray.com/core/common" "v2ray.com/core/common/errors" "v2ray.com/core/common/mux" + "v2ray.com/core/common/vio" + "v2ray.com/core/testing/mocks" + "v2ray.com/core/transport/pipe" ) func TestIncrementalPickerFailure(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockWorkerFactory := mocks.NewMuxClientWorkerFactory(mockCtl) + mockWorkerFactory.EXPECT().Create().Return(nil, errors.New("test")) + picker := mux.IncrementalWorkerPicker{ - New: func() (*mux.ClientWorker, error) { - return nil, errors.New("test") - }, + Factory: mockWorkerFactory, } _, err := picker.PickAvailable() @@ -19,3 +30,18 @@ func TestIncrementalPickerFailure(t *testing.T) { t.Error("expected error, but nil") } } + +func TestClientWorkerEOF(t *testing.T) { + reader, writer := pipe.New(pipe.WithoutSizeLimit()) + common.Must(writer.Close()) + + worker, err := mux.NewClientWorker(vio.Link{Reader: reader, Writer: writer}, mux.ClientStrategy{}) + common.Must(err) + + time.Sleep(time.Millisecond * 500) + + f := worker.Dispatch(context.Background(), nil) + if f { + t.Error("expected failed dispatching, but actually not") + } +} diff --git a/mocks.go b/mocks.go index 6c975e279..922e0f8b6 100644 --- a/mocks.go +++ b/mocks.go @@ -4,5 +4,6 @@ package core //go:generate go install github.com/golang/mock/mockgen //go:generate mockgen -package mocks -destination testing/mocks/io.go -mock_names Reader=Reader,Writer=Writer io Reader,Writer +//go:generate mockgen -package mocks -destination testing/mocks/mux.go -mock_names ClientWorkerFactory=MuxClientWorkerFactory v2ray.com/core/common/mux ClientWorkerFactory //go:generate mockgen -package mocks -destination testing/mocks/dns.go -mock_names Client=DNSClient v2ray.com/core/features/dns Client //go:generate mockgen -package mocks -destination testing/mocks/proxy.go -mock_names Inbound=ProxyInbound,Outbound=ProxyOutbound v2ray.com/core/proxy Inbound,Outbound diff --git a/testing/mocks/mux.go b/testing/mocks/mux.go new file mode 100644 index 000000000..6ef1697dd --- /dev/null +++ b/testing/mocks/mux.go @@ -0,0 +1,47 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: v2ray.com/core/common/mux (interfaces: ClientWorkerFactory) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" + mux "v2ray.com/core/common/mux" +) + +// MuxClientWorkerFactory is a mock of ClientWorkerFactory interface +type MuxClientWorkerFactory struct { + ctrl *gomock.Controller + recorder *MuxClientWorkerFactoryMockRecorder +} + +// MuxClientWorkerFactoryMockRecorder is the mock recorder for MuxClientWorkerFactory +type MuxClientWorkerFactoryMockRecorder struct { + mock *MuxClientWorkerFactory +} + +// NewMuxClientWorkerFactory creates a new mock instance +func NewMuxClientWorkerFactory(ctrl *gomock.Controller) *MuxClientWorkerFactory { + mock := &MuxClientWorkerFactory{ctrl: ctrl} + mock.recorder = &MuxClientWorkerFactoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MuxClientWorkerFactory) EXPECT() *MuxClientWorkerFactoryMockRecorder { + return m.recorder +} + +// Create mocks base method +func (m *MuxClientWorkerFactory) Create() (*mux.ClientWorker, error) { + ret := m.ctrl.Call(m, "Create") + ret0, _ := ret[0].(*mux.ClientWorker) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Create indicates an expected call of Create +func (mr *MuxClientWorkerFactoryMockRecorder) Create() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MuxClientWorkerFactory)(nil).Create)) +}