From 7792237b500d986c907aae7b071c8736ea9f8fd5 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Fri, 24 Feb 2017 01:05:16 +0100 Subject: [PATCH] adopt context in listeners --- transport/internet/kcp/kcp_test.go | 13 +----- transport/internet/kcp/listener.go | 21 +++------ transport/internet/tcp/hub.go | 15 ++----- transport/internet/tcp_hub.go | 26 +++++++---- transport/internet/websocket/hub.go | 16 +++---- transport/internet/websocket/ws_test.go | 60 +++++++------------------ 6 files changed, 49 insertions(+), 102 deletions(-) diff --git a/transport/internet/kcp/kcp_test.go b/transport/internet/kcp/kcp_test.go index 8e1048587..3159f3971 100644 --- a/transport/internet/kcp/kcp_test.go +++ b/transport/internet/kcp/kcp_test.go @@ -10,7 +10,6 @@ import ( "time" v2net "v2ray.com/core/common/net" - "v2ray.com/core/common/serial" "v2ray.com/core/testing/assert" "v2ray.com/core/transport/internet" . "v2ray.com/core/transport/internet/kcp" @@ -19,17 +18,7 @@ import ( func TestDialAndListen(t *testing.T) { assert := assert.On(t) - listerner, err := NewListener(v2net.LocalHostIP, v2net.Port(0), internet.ListenOptions{ - Stream: &internet.StreamConfig{ - Protocol: internet.TransportProtocol_MKCP, - TransportSettings: []*internet.TransportConfig{ - { - Protocol: internet.TransportProtocol_MKCP, - Settings: serial.ToTypedMessage(&Config{}), - }, - }, - }, - }) + listerner, err := NewListener(internet.ContextWithTransportSettings(context.Background(), &Config{}), v2net.LocalHostIP, v2net.Port(0)) assert.Error(err).IsNil() port := v2net.Port(listerner.Addr().(*net.UDPAddr).Port) diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index 440efa54e..2aa27dae1 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -1,6 +1,7 @@ package kcp import ( + "context" "crypto/cipher" "crypto/tls" "io" @@ -90,12 +91,8 @@ type Listener struct { security cipher.AEAD } -func NewListener(address v2net.Address, port v2net.Port, options internet.ListenOptions) (*Listener, error) { - networkSettings, err := options.Stream.GetEffectiveTransportSettings() - if err != nil { - log.Error("KCP|Listener: Failed to get KCP settings: ", err) - return nil, err - } +func NewListener(ctx context.Context, address v2net.Address, port v2net.Port) (*Listener, error) { + networkSettings := internet.TransportSettingsFromContext(ctx) kcpSettings := networkSettings.(*Config) kcpSettings.ConnectionReuse = &ConnectionReuse{Enable: false} @@ -119,12 +116,8 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen closed: make(chan bool), config: kcpSettings, } - if options.Stream != nil && options.Stream.HasSecuritySettings() { - securitySettings, err := options.Stream.GetEffectiveSecuritySettings() - if err != nil { - log.Error("KCP|Listener: Failed to get security settings: ", err) - return nil, err - } + securitySettings := internet.SecuritySettingsFromContext(ctx) + if securitySettings != nil { switch securitySettings := securitySettings.(type) { case *v2tls.Config: l.tlsConfig = securitySettings.GetTLSConfig() @@ -295,8 +288,8 @@ func (v *Writer) Close() error { return nil } -func ListenKCP(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) { - return NewListener(address, port, options) +func ListenKCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) { + return NewListener(ctx, address, port) } func init() { diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index df6cd83df..65a69ee5e 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -1,6 +1,7 @@ package tcp import ( + "context" "crypto/tls" "net" "sync" @@ -34,7 +35,7 @@ type TCPListener struct { config *Config } -func ListenTCP(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) { +func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) { listener, err := net.ListenTCP("tcp", &net.TCPAddr{ IP: address.IP(), Port: int(port), @@ -43,10 +44,7 @@ func ListenTCP(address v2net.Address, port v2net.Port, options internet.ListenOp return nil, err } log.Info("TCP|Listener: Listening on ", address, ":", port) - networkSettings, err := options.Stream.GetEffectiveTransportSettings() - if err != nil { - return nil, err - } + networkSettings := internet.TransportSettingsFromContext(ctx) tcpSettings := networkSettings.(*Config) l := &TCPListener{ @@ -55,12 +53,7 @@ func ListenTCP(address v2net.Address, port v2net.Port, options internet.ListenOp awaitingConns: make(chan *ConnectionWithError, 32), config: tcpSettings, } - if options.Stream != nil && options.Stream.HasSecuritySettings() { - securitySettings, err := options.Stream.GetEffectiveSecuritySettings() - if err != nil { - log.Error("TCP: Failed to get security config: ", err) - return nil, err - } + if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil { tlsConfig, ok := securitySettings.(*v2tls.Config) if ok { l.tlsConfig = tlsConfig.GetTLSConfig() diff --git a/transport/internet/tcp_hub.go b/transport/internet/tcp_hub.go index 786e8bd2f..5ab5d3bb0 100644 --- a/transport/internet/tcp_hub.go +++ b/transport/internet/tcp_hub.go @@ -3,6 +3,8 @@ package internet import ( "net" + "context" + "v2ray.com/core/app/log" "v2ray.com/core/common/errors" v2net "v2ray.com/core/common/net" @@ -21,11 +23,7 @@ func RegisterTransportListener(protocol TransportProtocol, listener ListenFunc) return nil } -type ListenFunc func(address v2net.Address, port v2net.Port, options ListenOptions) (Listener, error) -type ListenOptions struct { - Stream *StreamConfig - RecvOrigDest bool -} +type ListenFunc func(ctx context.Context, address v2net.Address, port v2net.Port) (Listener, error) type Listener interface { Accept() (Connection, error) @@ -40,15 +38,25 @@ type TCPHub struct { } func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandler, settings *StreamConfig) (*TCPHub, error) { - options := ListenOptions{ - Stream: settings, - } + ctx := context.Background() protocol := settings.GetEffectiveProtocol() + transportSettings, err := settings.GetEffectiveTransportSettings() + if err != nil { + return nil, err + } + ctx = ContextWithTransportSettings(ctx, transportSettings) + if settings != nil && settings.HasSecuritySettings() { + securitySettings, err := settings.GetEffectiveSecuritySettings() + if err != nil { + return nil, err + } + ctx = ContextWithSecuritySettings(ctx, securitySettings) + } listenFunc := transportListenerCache[protocol] if listenFunc == nil { return nil, errors.New("Internet|TCPHub: ", protocol, " listener not registered.") } - listener, err := listenFunc(address, port, options) + listener, err := listenFunc(ctx, address, port) if err != nil { return nil, errors.Base(err).Message("Internet|TCPHub: Failed to listen on address: ", address, ":", port) } diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index 12b1037bd..f7d1ebb4d 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -1,6 +1,7 @@ package websocket import ( + "context" "crypto/tls" "net" "net/http" @@ -59,11 +60,8 @@ type Listener struct { config *Config } -func ListenWS(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) { - networkSettings, err := options.Stream.GetEffectiveTransportSettings() - if err != nil { - return nil, err - } +func ListenWS(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) { + networkSettings := internet.TransportSettingsFromContext(ctx) wsSettings := networkSettings.(*Config) l := &Listener{ @@ -71,18 +69,14 @@ func ListenWS(address v2net.Address, port v2net.Port, options internet.ListenOpt awaitingConns: make(chan *ConnectionWithError, 32), config: wsSettings, } - if options.Stream != nil && options.Stream.HasSecuritySettings() { - securitySettings, err := options.Stream.GetEffectiveSecuritySettings() - if err != nil { - return nil, errors.Base(err).Message("WebSocket: Failed to create apply TLS config.") - } + if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil { tlsConfig, ok := securitySettings.(*v2tls.Config) if ok { l.tlsConfig = tlsConfig.GetTLSConfig() } } - err = l.listenws(address, port) + err := l.listenws(address, port) return l, err } diff --git a/transport/internet/websocket/ws_test.go b/transport/internet/websocket/ws_test.go index c351cd39c..3c6124a18 100644 --- a/transport/internet/websocket/ws_test.go +++ b/transport/internet/websocket/ws_test.go @@ -1,15 +1,12 @@ package websocket_test import ( + "bytes" + "context" "testing" "time" - "bytes" - - "context" - v2net "v2ray.com/core/common/net" - "v2ray.com/core/common/serial" "v2ray.com/core/testing/assert" tlsgen "v2ray.com/core/testing/tls" "v2ray.com/core/transport/internet" @@ -19,19 +16,9 @@ import ( func Test_listenWSAndDial(t *testing.T) { assert := assert.On(t) - listen, err := ListenWS(v2net.DomainAddress("localhost"), 13146, internet.ListenOptions{ - Stream: &internet.StreamConfig{ - Protocol: internet.TransportProtocol_WebSocket, - TransportSettings: []*internet.TransportConfig{ - { - Protocol: internet.TransportProtocol_WebSocket, - Settings: serial.ToTypedMessage(&Config{ - Path: "ws", - }), - }, - }, - }, - }) + listen, err := ListenWS(internet.ContextWithTransportSettings(context.Background(), &Config{ + Path: "ws", + }), v2net.DomainAddress("localhost"), 13146) assert.Error(err).IsNil() go func() { for { @@ -99,33 +86,6 @@ func Test_listenWSAndDial_TLS(t *testing.T) { assert.Fail("Too slow") }() - listen, err := ListenWS(v2net.DomainAddress("localhost"), 13143, internet.ListenOptions{ - Stream: &internet.StreamConfig{ - SecurityType: serial.GetMessageType(new(v2tls.Config)), - SecuritySettings: []*serial.TypedMessage{serial.ToTypedMessage(&v2tls.Config{ - Certificate: []*v2tls.Certificate{tlsgen.GenerateCertificateForTest()}, - })}, - Protocol: internet.TransportProtocol_WebSocket, - TransportSettings: []*internet.TransportConfig{ - { - Protocol: internet.TransportProtocol_WebSocket, - Settings: serial.ToTypedMessage(&Config{ - Path: "wss", - ConnectionReuse: &ConnectionReuse{ - Enable: true, - }, - }), - }, - }, - }, - }) - assert.Error(err).IsNil() - go func() { - conn, err := listen.Accept() - assert.Error(err).IsNil() - conn.Close() - listen.Close() - }() ctx := internet.ContextWithTransportSettings(context.Background(), &Config{ Path: "wss", ConnectionReuse: &ConnectionReuse{ @@ -134,7 +94,17 @@ func Test_listenWSAndDial_TLS(t *testing.T) { }) ctx = internet.ContextWithSecuritySettings(ctx, &v2tls.Config{ AllowInsecure: true, + Certificate: []*v2tls.Certificate{tlsgen.GenerateCertificateForTest()}, }) + listen, err := ListenWS(ctx, v2net.DomainAddress("localhost"), 13143) + assert.Error(err).IsNil() + go func() { + conn, err := listen.Accept() + assert.Error(err).IsNil() + conn.Close() + listen.Close() + }() + conn, err := Dial(ctx, v2net.TCPDestination(v2net.DomainAddress("localhost"), 13143)) assert.Error(err).IsNil() conn.Close()