1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-22 01:57:12 -05:00

adopt context in listeners

This commit is contained in:
Darien Raymond 2017-02-24 01:05:16 +01:00
parent 702cfd69de
commit 7792237b50
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
6 changed files with 49 additions and 102 deletions

View File

@ -10,7 +10,6 @@ import (
"time"
v2net "v2ray.com/core/common/net"
"v2ray.com/core/common/serial"
"v2ray.com/core/testing/assert"
"v2ray.com/core/transport/internet"
. "v2ray.com/core/transport/internet/kcp"
@ -19,17 +18,7 @@ import (
func TestDialAndListen(t *testing.T) {
assert := assert.On(t)
listerner, err := NewListener(v2net.LocalHostIP, v2net.Port(0), internet.ListenOptions{
Stream: &internet.StreamConfig{
Protocol: internet.TransportProtocol_MKCP,
TransportSettings: []*internet.TransportConfig{
{
Protocol: internet.TransportProtocol_MKCP,
Settings: serial.ToTypedMessage(&Config{}),
},
},
},
})
listerner, err := NewListener(internet.ContextWithTransportSettings(context.Background(), &Config{}), v2net.LocalHostIP, v2net.Port(0))
assert.Error(err).IsNil()
port := v2net.Port(listerner.Addr().(*net.UDPAddr).Port)

View File

@ -1,6 +1,7 @@
package kcp
import (
"context"
"crypto/cipher"
"crypto/tls"
"io"
@ -90,12 +91,8 @@ type Listener struct {
security cipher.AEAD
}
func NewListener(address v2net.Address, port v2net.Port, options internet.ListenOptions) (*Listener, error) {
networkSettings, err := options.Stream.GetEffectiveTransportSettings()
if err != nil {
log.Error("KCP|Listener: Failed to get KCP settings: ", err)
return nil, err
}
func NewListener(ctx context.Context, address v2net.Address, port v2net.Port) (*Listener, error) {
networkSettings := internet.TransportSettingsFromContext(ctx)
kcpSettings := networkSettings.(*Config)
kcpSettings.ConnectionReuse = &ConnectionReuse{Enable: false}
@ -119,12 +116,8 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen
closed: make(chan bool),
config: kcpSettings,
}
if options.Stream != nil && options.Stream.HasSecuritySettings() {
securitySettings, err := options.Stream.GetEffectiveSecuritySettings()
if err != nil {
log.Error("KCP|Listener: Failed to get security settings: ", err)
return nil, err
}
securitySettings := internet.SecuritySettingsFromContext(ctx)
if securitySettings != nil {
switch securitySettings := securitySettings.(type) {
case *v2tls.Config:
l.tlsConfig = securitySettings.GetTLSConfig()
@ -295,8 +288,8 @@ func (v *Writer) Close() error {
return nil
}
func ListenKCP(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) {
return NewListener(address, port, options)
func ListenKCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
return NewListener(ctx, address, port)
}
func init() {

View File

@ -1,6 +1,7 @@
package tcp
import (
"context"
"crypto/tls"
"net"
"sync"
@ -34,7 +35,7 @@ type TCPListener struct {
config *Config
}
func ListenTCP(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) {
func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: address.IP(),
Port: int(port),
@ -43,10 +44,7 @@ func ListenTCP(address v2net.Address, port v2net.Port, options internet.ListenOp
return nil, err
}
log.Info("TCP|Listener: Listening on ", address, ":", port)
networkSettings, err := options.Stream.GetEffectiveTransportSettings()
if err != nil {
return nil, err
}
networkSettings := internet.TransportSettingsFromContext(ctx)
tcpSettings := networkSettings.(*Config)
l := &TCPListener{
@ -55,12 +53,7 @@ func ListenTCP(address v2net.Address, port v2net.Port, options internet.ListenOp
awaitingConns: make(chan *ConnectionWithError, 32),
config: tcpSettings,
}
if options.Stream != nil && options.Stream.HasSecuritySettings() {
securitySettings, err := options.Stream.GetEffectiveSecuritySettings()
if err != nil {
log.Error("TCP: Failed to get security config: ", err)
return nil, err
}
if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
tlsConfig, ok := securitySettings.(*v2tls.Config)
if ok {
l.tlsConfig = tlsConfig.GetTLSConfig()

View File

@ -3,6 +3,8 @@ package internet
import (
"net"
"context"
"v2ray.com/core/app/log"
"v2ray.com/core/common/errors"
v2net "v2ray.com/core/common/net"
@ -21,11 +23,7 @@ func RegisterTransportListener(protocol TransportProtocol, listener ListenFunc)
return nil
}
type ListenFunc func(address v2net.Address, port v2net.Port, options ListenOptions) (Listener, error)
type ListenOptions struct {
Stream *StreamConfig
RecvOrigDest bool
}
type ListenFunc func(ctx context.Context, address v2net.Address, port v2net.Port) (Listener, error)
type Listener interface {
Accept() (Connection, error)
@ -40,15 +38,25 @@ type TCPHub struct {
}
func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandler, settings *StreamConfig) (*TCPHub, error) {
options := ListenOptions{
Stream: settings,
}
ctx := context.Background()
protocol := settings.GetEffectiveProtocol()
transportSettings, err := settings.GetEffectiveTransportSettings()
if err != nil {
return nil, err
}
ctx = ContextWithTransportSettings(ctx, transportSettings)
if settings != nil && settings.HasSecuritySettings() {
securitySettings, err := settings.GetEffectiveSecuritySettings()
if err != nil {
return nil, err
}
ctx = ContextWithSecuritySettings(ctx, securitySettings)
}
listenFunc := transportListenerCache[protocol]
if listenFunc == nil {
return nil, errors.New("Internet|TCPHub: ", protocol, " listener not registered.")
}
listener, err := listenFunc(address, port, options)
listener, err := listenFunc(ctx, address, port)
if err != nil {
return nil, errors.Base(err).Message("Internet|TCPHub: Failed to listen on address: ", address, ":", port)
}

View File

@ -1,6 +1,7 @@
package websocket
import (
"context"
"crypto/tls"
"net"
"net/http"
@ -59,11 +60,8 @@ type Listener struct {
config *Config
}
func ListenWS(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) {
networkSettings, err := options.Stream.GetEffectiveTransportSettings()
if err != nil {
return nil, err
}
func ListenWS(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
networkSettings := internet.TransportSettingsFromContext(ctx)
wsSettings := networkSettings.(*Config)
l := &Listener{
@ -71,18 +69,14 @@ func ListenWS(address v2net.Address, port v2net.Port, options internet.ListenOpt
awaitingConns: make(chan *ConnectionWithError, 32),
config: wsSettings,
}
if options.Stream != nil && options.Stream.HasSecuritySettings() {
securitySettings, err := options.Stream.GetEffectiveSecuritySettings()
if err != nil {
return nil, errors.Base(err).Message("WebSocket: Failed to create apply TLS config.")
}
if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
tlsConfig, ok := securitySettings.(*v2tls.Config)
if ok {
l.tlsConfig = tlsConfig.GetTLSConfig()
}
}
err = l.listenws(address, port)
err := l.listenws(address, port)
return l, err
}

View File

@ -1,15 +1,12 @@
package websocket_test
import (
"bytes"
"context"
"testing"
"time"
"bytes"
"context"
v2net "v2ray.com/core/common/net"
"v2ray.com/core/common/serial"
"v2ray.com/core/testing/assert"
tlsgen "v2ray.com/core/testing/tls"
"v2ray.com/core/transport/internet"
@ -19,19 +16,9 @@ import (
func Test_listenWSAndDial(t *testing.T) {
assert := assert.On(t)
listen, err := ListenWS(v2net.DomainAddress("localhost"), 13146, internet.ListenOptions{
Stream: &internet.StreamConfig{
Protocol: internet.TransportProtocol_WebSocket,
TransportSettings: []*internet.TransportConfig{
{
Protocol: internet.TransportProtocol_WebSocket,
Settings: serial.ToTypedMessage(&Config{
Path: "ws",
}),
},
},
},
})
listen, err := ListenWS(internet.ContextWithTransportSettings(context.Background(), &Config{
Path: "ws",
}), v2net.DomainAddress("localhost"), 13146)
assert.Error(err).IsNil()
go func() {
for {
@ -99,33 +86,6 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
assert.Fail("Too slow")
}()
listen, err := ListenWS(v2net.DomainAddress("localhost"), 13143, internet.ListenOptions{
Stream: &internet.StreamConfig{
SecurityType: serial.GetMessageType(new(v2tls.Config)),
SecuritySettings: []*serial.TypedMessage{serial.ToTypedMessage(&v2tls.Config{
Certificate: []*v2tls.Certificate{tlsgen.GenerateCertificateForTest()},
})},
Protocol: internet.TransportProtocol_WebSocket,
TransportSettings: []*internet.TransportConfig{
{
Protocol: internet.TransportProtocol_WebSocket,
Settings: serial.ToTypedMessage(&Config{
Path: "wss",
ConnectionReuse: &ConnectionReuse{
Enable: true,
},
}),
},
},
},
})
assert.Error(err).IsNil()
go func() {
conn, err := listen.Accept()
assert.Error(err).IsNil()
conn.Close()
listen.Close()
}()
ctx := internet.ContextWithTransportSettings(context.Background(), &Config{
Path: "wss",
ConnectionReuse: &ConnectionReuse{
@ -134,7 +94,17 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
})
ctx = internet.ContextWithSecuritySettings(ctx, &v2tls.Config{
AllowInsecure: true,
Certificate: []*v2tls.Certificate{tlsgen.GenerateCertificateForTest()},
})
listen, err := ListenWS(ctx, v2net.DomainAddress("localhost"), 13143)
assert.Error(err).IsNil()
go func() {
conn, err := listen.Accept()
assert.Error(err).IsNil()
conn.Close()
listen.Close()
}()
conn, err := Dial(ctx, v2net.TCPDestination(v2net.DomainAddress("localhost"), 13143))
assert.Error(err).IsNil()
conn.Close()