1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-03 07:56:42 -05:00

adopt context in listeners

This commit is contained in:
Darien Raymond 2017-02-24 01:05:16 +01:00
parent 702cfd69de
commit 7792237b50
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
6 changed files with 49 additions and 102 deletions

View File

@ -10,7 +10,6 @@ import (
"time" "time"
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
"v2ray.com/core/common/serial"
"v2ray.com/core/testing/assert" "v2ray.com/core/testing/assert"
"v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet"
. "v2ray.com/core/transport/internet/kcp" . "v2ray.com/core/transport/internet/kcp"
@ -19,17 +18,7 @@ import (
func TestDialAndListen(t *testing.T) { func TestDialAndListen(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
listerner, err := NewListener(v2net.LocalHostIP, v2net.Port(0), internet.ListenOptions{ listerner, err := NewListener(internet.ContextWithTransportSettings(context.Background(), &Config{}), v2net.LocalHostIP, v2net.Port(0))
Stream: &internet.StreamConfig{
Protocol: internet.TransportProtocol_MKCP,
TransportSettings: []*internet.TransportConfig{
{
Protocol: internet.TransportProtocol_MKCP,
Settings: serial.ToTypedMessage(&Config{}),
},
},
},
})
assert.Error(err).IsNil() assert.Error(err).IsNil()
port := v2net.Port(listerner.Addr().(*net.UDPAddr).Port) port := v2net.Port(listerner.Addr().(*net.UDPAddr).Port)

View File

@ -1,6 +1,7 @@
package kcp package kcp
import ( import (
"context"
"crypto/cipher" "crypto/cipher"
"crypto/tls" "crypto/tls"
"io" "io"
@ -90,12 +91,8 @@ type Listener struct {
security cipher.AEAD security cipher.AEAD
} }
func NewListener(address v2net.Address, port v2net.Port, options internet.ListenOptions) (*Listener, error) { func NewListener(ctx context.Context, address v2net.Address, port v2net.Port) (*Listener, error) {
networkSettings, err := options.Stream.GetEffectiveTransportSettings() networkSettings := internet.TransportSettingsFromContext(ctx)
if err != nil {
log.Error("KCP|Listener: Failed to get KCP settings: ", err)
return nil, err
}
kcpSettings := networkSettings.(*Config) kcpSettings := networkSettings.(*Config)
kcpSettings.ConnectionReuse = &ConnectionReuse{Enable: false} kcpSettings.ConnectionReuse = &ConnectionReuse{Enable: false}
@ -119,12 +116,8 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen
closed: make(chan bool), closed: make(chan bool),
config: kcpSettings, config: kcpSettings,
} }
if options.Stream != nil && options.Stream.HasSecuritySettings() { securitySettings := internet.SecuritySettingsFromContext(ctx)
securitySettings, err := options.Stream.GetEffectiveSecuritySettings() if securitySettings != nil {
if err != nil {
log.Error("KCP|Listener: Failed to get security settings: ", err)
return nil, err
}
switch securitySettings := securitySettings.(type) { switch securitySettings := securitySettings.(type) {
case *v2tls.Config: case *v2tls.Config:
l.tlsConfig = securitySettings.GetTLSConfig() l.tlsConfig = securitySettings.GetTLSConfig()
@ -295,8 +288,8 @@ func (v *Writer) Close() error {
return nil return nil
} }
func ListenKCP(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) { func ListenKCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
return NewListener(address, port, options) return NewListener(ctx, address, port)
} }
func init() { func init() {

View File

@ -1,6 +1,7 @@
package tcp package tcp
import ( import (
"context"
"crypto/tls" "crypto/tls"
"net" "net"
"sync" "sync"
@ -34,7 +35,7 @@ type TCPListener struct {
config *Config config *Config
} }
func ListenTCP(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) { func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{ listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: address.IP(), IP: address.IP(),
Port: int(port), Port: int(port),
@ -43,10 +44,7 @@ func ListenTCP(address v2net.Address, port v2net.Port, options internet.ListenOp
return nil, err return nil, err
} }
log.Info("TCP|Listener: Listening on ", address, ":", port) log.Info("TCP|Listener: Listening on ", address, ":", port)
networkSettings, err := options.Stream.GetEffectiveTransportSettings() networkSettings := internet.TransportSettingsFromContext(ctx)
if err != nil {
return nil, err
}
tcpSettings := networkSettings.(*Config) tcpSettings := networkSettings.(*Config)
l := &TCPListener{ l := &TCPListener{
@ -55,12 +53,7 @@ func ListenTCP(address v2net.Address, port v2net.Port, options internet.ListenOp
awaitingConns: make(chan *ConnectionWithError, 32), awaitingConns: make(chan *ConnectionWithError, 32),
config: tcpSettings, config: tcpSettings,
} }
if options.Stream != nil && options.Stream.HasSecuritySettings() { if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
securitySettings, err := options.Stream.GetEffectiveSecuritySettings()
if err != nil {
log.Error("TCP: Failed to get security config: ", err)
return nil, err
}
tlsConfig, ok := securitySettings.(*v2tls.Config) tlsConfig, ok := securitySettings.(*v2tls.Config)
if ok { if ok {
l.tlsConfig = tlsConfig.GetTLSConfig() l.tlsConfig = tlsConfig.GetTLSConfig()

View File

@ -3,6 +3,8 @@ package internet
import ( import (
"net" "net"
"context"
"v2ray.com/core/app/log" "v2ray.com/core/app/log"
"v2ray.com/core/common/errors" "v2ray.com/core/common/errors"
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
@ -21,11 +23,7 @@ func RegisterTransportListener(protocol TransportProtocol, listener ListenFunc)
return nil return nil
} }
type ListenFunc func(address v2net.Address, port v2net.Port, options ListenOptions) (Listener, error) type ListenFunc func(ctx context.Context, address v2net.Address, port v2net.Port) (Listener, error)
type ListenOptions struct {
Stream *StreamConfig
RecvOrigDest bool
}
type Listener interface { type Listener interface {
Accept() (Connection, error) Accept() (Connection, error)
@ -40,15 +38,25 @@ type TCPHub struct {
} }
func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandler, settings *StreamConfig) (*TCPHub, error) { func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandler, settings *StreamConfig) (*TCPHub, error) {
options := ListenOptions{ ctx := context.Background()
Stream: settings,
}
protocol := settings.GetEffectiveProtocol() protocol := settings.GetEffectiveProtocol()
transportSettings, err := settings.GetEffectiveTransportSettings()
if err != nil {
return nil, err
}
ctx = ContextWithTransportSettings(ctx, transportSettings)
if settings != nil && settings.HasSecuritySettings() {
securitySettings, err := settings.GetEffectiveSecuritySettings()
if err != nil {
return nil, err
}
ctx = ContextWithSecuritySettings(ctx, securitySettings)
}
listenFunc := transportListenerCache[protocol] listenFunc := transportListenerCache[protocol]
if listenFunc == nil { if listenFunc == nil {
return nil, errors.New("Internet|TCPHub: ", protocol, " listener not registered.") return nil, errors.New("Internet|TCPHub: ", protocol, " listener not registered.")
} }
listener, err := listenFunc(address, port, options) listener, err := listenFunc(ctx, address, port)
if err != nil { if err != nil {
return nil, errors.Base(err).Message("Internet|TCPHub: Failed to listen on address: ", address, ":", port) return nil, errors.Base(err).Message("Internet|TCPHub: Failed to listen on address: ", address, ":", port)
} }

View File

@ -1,6 +1,7 @@
package websocket package websocket
import ( import (
"context"
"crypto/tls" "crypto/tls"
"net" "net"
"net/http" "net/http"
@ -59,11 +60,8 @@ type Listener struct {
config *Config config *Config
} }
func ListenWS(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) { func ListenWS(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
networkSettings, err := options.Stream.GetEffectiveTransportSettings() networkSettings := internet.TransportSettingsFromContext(ctx)
if err != nil {
return nil, err
}
wsSettings := networkSettings.(*Config) wsSettings := networkSettings.(*Config)
l := &Listener{ l := &Listener{
@ -71,18 +69,14 @@ func ListenWS(address v2net.Address, port v2net.Port, options internet.ListenOpt
awaitingConns: make(chan *ConnectionWithError, 32), awaitingConns: make(chan *ConnectionWithError, 32),
config: wsSettings, config: wsSettings,
} }
if options.Stream != nil && options.Stream.HasSecuritySettings() { if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
securitySettings, err := options.Stream.GetEffectiveSecuritySettings()
if err != nil {
return nil, errors.Base(err).Message("WebSocket: Failed to create apply TLS config.")
}
tlsConfig, ok := securitySettings.(*v2tls.Config) tlsConfig, ok := securitySettings.(*v2tls.Config)
if ok { if ok {
l.tlsConfig = tlsConfig.GetTLSConfig() l.tlsConfig = tlsConfig.GetTLSConfig()
} }
} }
err = l.listenws(address, port) err := l.listenws(address, port)
return l, err return l, err
} }

View File

@ -1,15 +1,12 @@
package websocket_test package websocket_test
import ( import (
"bytes"
"context"
"testing" "testing"
"time" "time"
"bytes"
"context"
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
"v2ray.com/core/common/serial"
"v2ray.com/core/testing/assert" "v2ray.com/core/testing/assert"
tlsgen "v2ray.com/core/testing/tls" tlsgen "v2ray.com/core/testing/tls"
"v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet"
@ -19,19 +16,9 @@ import (
func Test_listenWSAndDial(t *testing.T) { func Test_listenWSAndDial(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
listen, err := ListenWS(v2net.DomainAddress("localhost"), 13146, internet.ListenOptions{ listen, err := ListenWS(internet.ContextWithTransportSettings(context.Background(), &Config{
Stream: &internet.StreamConfig{
Protocol: internet.TransportProtocol_WebSocket,
TransportSettings: []*internet.TransportConfig{
{
Protocol: internet.TransportProtocol_WebSocket,
Settings: serial.ToTypedMessage(&Config{
Path: "ws", Path: "ws",
}), }), v2net.DomainAddress("localhost"), 13146)
},
},
},
})
assert.Error(err).IsNil() assert.Error(err).IsNil()
go func() { go func() {
for { for {
@ -99,33 +86,6 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
assert.Fail("Too slow") assert.Fail("Too slow")
}() }()
listen, err := ListenWS(v2net.DomainAddress("localhost"), 13143, internet.ListenOptions{
Stream: &internet.StreamConfig{
SecurityType: serial.GetMessageType(new(v2tls.Config)),
SecuritySettings: []*serial.TypedMessage{serial.ToTypedMessage(&v2tls.Config{
Certificate: []*v2tls.Certificate{tlsgen.GenerateCertificateForTest()},
})},
Protocol: internet.TransportProtocol_WebSocket,
TransportSettings: []*internet.TransportConfig{
{
Protocol: internet.TransportProtocol_WebSocket,
Settings: serial.ToTypedMessage(&Config{
Path: "wss",
ConnectionReuse: &ConnectionReuse{
Enable: true,
},
}),
},
},
},
})
assert.Error(err).IsNil()
go func() {
conn, err := listen.Accept()
assert.Error(err).IsNil()
conn.Close()
listen.Close()
}()
ctx := internet.ContextWithTransportSettings(context.Background(), &Config{ ctx := internet.ContextWithTransportSettings(context.Background(), &Config{
Path: "wss", Path: "wss",
ConnectionReuse: &ConnectionReuse{ ConnectionReuse: &ConnectionReuse{
@ -134,7 +94,17 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
}) })
ctx = internet.ContextWithSecuritySettings(ctx, &v2tls.Config{ ctx = internet.ContextWithSecuritySettings(ctx, &v2tls.Config{
AllowInsecure: true, AllowInsecure: true,
Certificate: []*v2tls.Certificate{tlsgen.GenerateCertificateForTest()},
}) })
listen, err := ListenWS(ctx, v2net.DomainAddress("localhost"), 13143)
assert.Error(err).IsNil()
go func() {
conn, err := listen.Accept()
assert.Error(err).IsNil()
conn.Close()
listen.Close()
}()
conn, err := Dial(ctx, v2net.TCPDestination(v2net.DomainAddress("localhost"), 13143)) conn, err := Dial(ctx, v2net.TCPDestination(v2net.DomainAddress("localhost"), 13143))
assert.Error(err).IsNil() assert.Error(err).IsNil()
conn.Close() conn.Close()