1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-21 17:46:58 -05:00

integrate tls settings in ws

This commit is contained in:
Darien Raymond 2016-09-30 16:53:40 +02:00
parent af6abfa3e3
commit 5ec948f690
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
16 changed files with 194 additions and 155 deletions

10
transport/config.proto Normal file
View File

@ -0,0 +1,10 @@
syntax = "proto3";
package v2ray.core.transport;
option go_package = "transport";
option java_package = "com.v2ray.core.transport";
option java_outer_classname = "ConfigProto";
message Config {
}

View File

@ -39,10 +39,13 @@ type TLSSettings struct {
func (this *TLSSettings) GetTLSConfig() *tls.Config { func (this *TLSSettings) GetTLSConfig() *tls.Config {
config := &tls.Config{ config := &tls.Config{
InsecureSkipVerify: this.AllowInsecure,
ClientSessionCache: globalSessionCache, ClientSessionCache: globalSessionCache,
} }
if this == nil {
return config
}
config.InsecureSkipVerify = this.AllowInsecure
config.Certificates = this.Certs config.Certificates = this.Certs
config.BuildNameToCertificate() config.BuildNameToCertificate()

View File

@ -1,19 +1,21 @@
package internet package internet
import ( import (
"crypto/tls"
"errors" "errors"
"net" "net"
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
v2tls "v2ray.com/core/transport/internet/tls"
) )
var ( var (
ErrUnsupportedStreamType = errors.New("Unsupported stream type.") ErrUnsupportedStreamType = errors.New("Unsupported stream type.")
) )
type Dialer func(src v2net.Address, dest v2net.Destination) (Connection, error) type DialerOptions struct {
Stream *StreamSettings
}
type Dialer func(src v2net.Address, dest v2net.Destination, options DialerOptions) (Connection, error)
var ( var (
TCPDialer Dialer TCPDialer Dialer
@ -27,18 +29,21 @@ func Dial(src v2net.Address, dest v2net.Destination, settings *StreamSettings) (
var connection Connection var connection Connection
var err error var err error
dialerOptions := DialerOptions{
Stream: settings,
}
if dest.Network == v2net.Network_TCP { if dest.Network == v2net.Network_TCP {
switch { switch {
case settings.IsCapableOf(StreamConnectionTypeTCP): case settings.IsCapableOf(StreamConnectionTypeTCP):
connection, err = TCPDialer(src, dest) connection, err = TCPDialer(src, dest, dialerOptions)
case settings.IsCapableOf(StreamConnectionTypeKCP): case settings.IsCapableOf(StreamConnectionTypeKCP):
connection, err = KCPDialer(src, dest) connection, err = KCPDialer(src, dest, dialerOptions)
case settings.IsCapableOf(StreamConnectionTypeWebSocket): case settings.IsCapableOf(StreamConnectionTypeWebSocket):
connection, err = WSDialer(src, dest) connection, err = WSDialer(src, dest, dialerOptions)
// This check has to be the last one. // This check has to be the last one.
case settings.IsCapableOf(StreamConnectionTypeRawTCP): case settings.IsCapableOf(StreamConnectionTypeRawTCP):
connection, err = RawTCPDialer(src, dest) connection, err = RawTCPDialer(src, dest, dialerOptions)
default: default:
return nil, ErrUnsupportedStreamType return nil, ErrUnsupportedStreamType
} }
@ -46,19 +51,10 @@ func Dial(src v2net.Address, dest v2net.Destination, settings *StreamSettings) (
return nil, err return nil, err
} }
if settings.Security == StreamSecurityTypeNone { return connection, nil
return connection, nil
}
config := settings.TLSSettings.GetTLSConfig()
if dest.Address.Family().IsDomain() {
config.ServerName = dest.Address.Domain()
}
tlsConn := tls.Client(connection, config)
return v2tls.NewConnection(tlsConn), nil
} }
return UDPDialer(src, dest) return UDPDialer(src, dest, dialerOptions)
} }
func DialToDest(src v2net.Address, dest v2net.Destination) (net.Conn, error) { func DialToDest(src v2net.Address, dest v2net.Destination) (net.Conn, error) {

View File

@ -1,6 +1,7 @@
package kcp package kcp
import ( import (
"crypto/tls"
"net" "net"
"sync/atomic" "sync/atomic"
@ -8,13 +9,14 @@ import (
"v2ray.com/core/common/log" "v2ray.com/core/common/log"
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
"v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet"
v2tls "v2ray.com/core/transport/internet/tls"
) )
var ( var (
globalConv = uint32(dice.Roll(65536)) globalConv = uint32(dice.Roll(65536))
) )
func DialKCP(src v2net.Address, dest v2net.Destination) (internet.Connection, error) { func DialKCP(src v2net.Address, dest v2net.Destination, options internet.DialerOptions) (internet.Connection, error) {
dest.Network = v2net.Network_UDP dest.Network = v2net.Network_UDP
log.Info("KCP|Dialer: Dialing KCP to ", dest) log.Info("KCP|Dialer: Dialing KCP to ", dest)
conn, err := internet.DialToDest(src, dest) conn, err := internet.DialToDest(src, dest)
@ -32,7 +34,19 @@ func DialKCP(src v2net.Address, dest v2net.Destination) (internet.Connection, er
session := NewConnection(conv, conn, conn.LocalAddr().(*net.UDPAddr), conn.RemoteAddr().(*net.UDPAddr), cpip) session := NewConnection(conv, conn, conn.LocalAddr().(*net.UDPAddr), conn.RemoteAddr().(*net.UDPAddr), cpip)
session.FetchInputFrom(conn) session.FetchInputFrom(conn)
return session, nil var iConn internet.Connection
iConn = session
if options.Stream != nil && options.Stream.Security == internet.StreamSecurityTypeTLS {
config := options.Stream.TLSSettings.GetTLSConfig()
if dest.Address.Family().IsDomain() {
config.ServerName = dest.Address.Domain()
}
tlsConn := tls.Client(conn, config)
iConn = v2tls.NewConnection(tlsConn)
}
return iConn, nil
} }
func init() { func init() {

View File

@ -10,13 +10,14 @@ import (
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
"v2ray.com/core/testing/assert" "v2ray.com/core/testing/assert"
"v2ray.com/core/transport/internet"
. "v2ray.com/core/transport/internet/kcp" . "v2ray.com/core/transport/internet/kcp"
) )
func TestDialAndListen(t *testing.T) { func TestDialAndListen(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
listerner, err := NewListener(v2net.LocalHostIP, v2net.Port(0)) listerner, err := NewListener(v2net.LocalHostIP, v2net.Port(0), internet.ListenOptions{})
assert.Error(err).IsNil() assert.Error(err).IsNil()
port := v2net.Port(listerner.Addr().(*net.UDPAddr).Port) port := v2net.Port(listerner.Addr().(*net.UDPAddr).Port)
@ -45,7 +46,7 @@ func TestDialAndListen(t *testing.T) {
wg := new(sync.WaitGroup) wg := new(sync.WaitGroup)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
clientConn, err := DialKCP(v2net.LocalHostIP, v2net.UDPDestination(v2net.LocalHostIP, port)) clientConn, err := DialKCP(v2net.LocalHostIP, v2net.UDPDestination(v2net.LocalHostIP, port), internet.DialerOptions{})
assert.Error(err).IsNil() assert.Error(err).IsNil()
wg.Add(1) wg.Add(1)

View File

@ -1,6 +1,7 @@
package kcp package kcp
import ( import (
"crypto/tls"
"net" "net"
"sync" "sync"
"time" "time"
@ -11,6 +12,7 @@ import (
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"
"v2ray.com/core/proxy" "v2ray.com/core/proxy"
"v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet"
v2tls "v2ray.com/core/transport/internet/tls"
"v2ray.com/core/transport/internet/udp" "v2ray.com/core/transport/internet/udp"
) )
@ -22,9 +24,10 @@ type Listener struct {
sessions map[string]*Connection sessions map[string]*Connection
awaitingConns chan *Connection awaitingConns chan *Connection
hub *udp.UDPHub hub *udp.UDPHub
tlsConfig *tls.Config
} }
func NewListener(address v2net.Address, port v2net.Port) (*Listener, error) { func NewListener(address v2net.Address, port v2net.Port, options internet.ListenOptions) (*Listener, error) {
auth, err := effectiveConfig.GetAuthenticator() auth, err := effectiveConfig.GetAuthenticator()
if err != nil { if err != nil {
return nil, err return nil, err
@ -35,6 +38,9 @@ func NewListener(address v2net.Address, port v2net.Port) (*Listener, error) {
awaitingConns: make(chan *Connection, 64), awaitingConns: make(chan *Connection, 64),
running: true, running: true,
} }
if options.Stream != nil && options.Stream.Security == internet.StreamSecurityTypeTLS {
l.tlsConfig = options.Stream.TLSSettings.GetTLSConfig()
}
hub, err := udp.ListenUDP(address, port, udp.ListenOption{Callback: l.OnReceive}) hub, err := udp.ListenUDP(address, port, udp.ListenOption{Callback: l.OnReceive})
if err != nil { if err != nil {
return nil, err return nil, err
@ -120,6 +126,10 @@ func (this *Listener) Accept() (internet.Connection, error) {
} }
select { select {
case conn := <-this.awaitingConns: case conn := <-this.awaitingConns:
if this.tlsConfig != nil {
tlsConn := tls.Server(conn, this.tlsConfig)
return v2tls.NewConnection(tlsConn), nil
}
return conn, nil return conn, nil
case <-time.After(time.Second): case <-time.After(time.Second):
@ -173,8 +183,8 @@ func (this *Writer) Close() error {
return nil return nil
} }
func ListenKCP(address v2net.Address, port v2net.Port) (internet.Listener, error) { func ListenKCP(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) {
return NewListener(address, port) return NewListener(address, port, options)
} }
func init() { func init() {

View File

@ -3,6 +3,7 @@ package tcp
import ( import (
"net" "net"
"crypto/tls"
"v2ray.com/core/common/log" "v2ray.com/core/common/log"
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
"v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet"
@ -12,7 +13,7 @@ var (
globalCache = NewConnectionCache() globalCache = NewConnectionCache()
) )
func Dial(src v2net.Address, dest v2net.Destination) (internet.Connection, error) { func Dial(src v2net.Address, dest v2net.Destination, options internet.DialerOptions) (internet.Connection, error) {
log.Info("Dailing TCP to ", dest) log.Info("Dailing TCP to ", dest)
if src == nil { if src == nil {
src = v2net.AnyIP src = v2net.AnyIP
@ -29,15 +30,23 @@ func Dial(src v2net.Address, dest v2net.Destination) (internet.Connection, error
return nil, err return nil, err
} }
} }
if options.Stream != nil && options.Stream.Security == internet.StreamSecurityTypeTLS {
config := options.Stream.TLSSettings.GetTLSConfig()
if dest.Address.Family().IsDomain() {
config.ServerName = dest.Address.Domain()
}
conn = tls.Client(conn, config)
}
return NewConnection(id, conn, globalCache), nil return NewConnection(id, conn, globalCache), nil
} }
func DialRaw(src v2net.Address, dest v2net.Destination) (internet.Connection, error) { func DialRaw(src v2net.Address, dest v2net.Destination, options internet.DialerOptions) (internet.Connection, error) {
log.Info("Dailing Raw TCP to ", dest) log.Info("Dailing Raw TCP to ", dest)
conn, err := internet.DialToDest(src, dest) conn, err := internet.DialToDest(src, dest)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO: handle dialer options
return &RawConnection{ return &RawConnection{
TCPConn: *conn.(*net.TCPConn), TCPConn: *conn.(*net.TCPConn),
}, nil }, nil

View File

@ -1,6 +1,7 @@
package tcp package tcp
import ( import (
"crypto/tls"
"errors" "errors"
"net" "net"
"sync" "sync"
@ -24,9 +25,10 @@ type TCPListener struct {
acccepting bool acccepting bool
listener *net.TCPListener listener *net.TCPListener
awaitingConns chan *ConnectionWithError awaitingConns chan *ConnectionWithError
tlsConfig *tls.Config
} }
func ListenTCP(address v2net.Address, port v2net.Port) (internet.Listener, error) { func ListenTCP(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{ listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: address.IP(), IP: address.IP(),
Port: int(port), Port: int(port),
@ -39,6 +41,9 @@ func ListenTCP(address v2net.Address, port v2net.Port) (internet.Listener, error
listener: listener, listener: listener,
awaitingConns: make(chan *ConnectionWithError, 32), awaitingConns: make(chan *ConnectionWithError, 32),
} }
if options.Stream != nil && options.Stream.Security == internet.StreamSecurityTypeTLS {
l.tlsConfig = options.Stream.TLSSettings.GetTLSConfig()
}
go l.KeepAccepting() go l.KeepAccepting()
return l, nil return l, nil
} }
@ -53,7 +58,11 @@ func (this *TCPListener) Accept() (internet.Connection, error) {
if connErr.err != nil { if connErr.err != nil {
return nil, connErr.err return nil, connErr.err
} }
return NewConnection("", connErr.conn, this), nil conn := connErr.conn
if this.tlsConfig != nil {
conn = tls.Server(conn, this.tlsConfig)
}
return NewConnection("", conn, this), nil
case <-time.After(time.Second * 2): case <-time.After(time.Second * 2):
} }
} }
@ -139,7 +148,7 @@ func (this *RawTCPListener) Close() error {
return nil return nil
} }
func ListenRawTCP(address v2net.Address, port v2net.Port) (internet.Listener, error) { func ListenRawTCP(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{ listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: address.IP(), IP: address.IP(),
Port: int(port), Port: int(port),
@ -147,6 +156,7 @@ func ListenRawTCP(address v2net.Address, port v2net.Port) (internet.Listener, er
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO: handle listen options
return &RawTCPListener{ return &RawTCPListener{
accepting: true, accepting: true,
listener: listener, listener: listener,

View File

@ -1,14 +1,12 @@
package internet package internet
import ( import (
"crypto/tls"
"errors" "errors"
"net" "net"
"sync" "sync"
"v2ray.com/core/common/log" "v2ray.com/core/common/log"
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
v2tls "v2ray.com/core/transport/internet/tls"
) )
var ( var (
@ -20,7 +18,11 @@ var (
WSListenFunc ListenFunc WSListenFunc ListenFunc
) )
type ListenFunc func(address v2net.Address, port v2net.Port) (Listener, error) type ListenFunc func(address v2net.Address, port v2net.Port, options ListenOptions) (Listener, error)
type ListenOptions struct {
Stream *StreamSettings
}
type Listener interface { type Listener interface {
Accept() (Connection, error) Accept() (Connection, error)
Close() error Close() error
@ -32,21 +34,23 @@ type TCPHub struct {
listener Listener listener Listener
connCallback ConnectionHandler connCallback ConnectionHandler
accepting bool accepting bool
tlsConfig *tls.Config
} }
func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandler, settings *StreamSettings) (*TCPHub, error) { func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandler, settings *StreamSettings) (*TCPHub, error) {
var listener Listener var listener Listener
var err error var err error
options := ListenOptions{
Stream: settings,
}
switch { switch {
case settings.IsCapableOf(StreamConnectionTypeTCP): case settings.IsCapableOf(StreamConnectionTypeTCP):
listener, err = TCPListenFunc(address, port) listener, err = TCPListenFunc(address, port, options)
case settings.IsCapableOf(StreamConnectionTypeKCP): case settings.IsCapableOf(StreamConnectionTypeKCP):
listener, err = KCPListenFunc(address, port) listener, err = KCPListenFunc(address, port, options)
case settings.IsCapableOf(StreamConnectionTypeWebSocket): case settings.IsCapableOf(StreamConnectionTypeWebSocket):
listener, err = WSListenFunc(address, port) listener, err = WSListenFunc(address, port, options)
case settings.IsCapableOf(StreamConnectionTypeRawTCP): case settings.IsCapableOf(StreamConnectionTypeRawTCP):
listener, err = RawTCPListenFunc(address, port) listener, err = RawTCPListenFunc(address, port, options)
default: default:
log.Error("Internet|Listener: Unknown stream type: ", settings.Type) log.Error("Internet|Listener: Unknown stream type: ", settings.Type)
err = ErrUnsupportedStreamType err = ErrUnsupportedStreamType
@ -57,15 +61,9 @@ func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandle
return nil, err return nil, err
} }
var tlsConfig *tls.Config
if settings.Security == StreamSecurityTypeTLS {
tlsConfig = settings.TLSSettings.GetTLSConfig()
}
hub := &TCPHub{ hub := &TCPHub{
listener: listener, listener: listener,
connCallback: callback, connCallback: callback,
tlsConfig: tlsConfig,
} }
go hub.start() go hub.start()
@ -88,10 +86,6 @@ func (this *TCPHub) start() {
} }
continue continue
} }
if this.tlsConfig != nil {
tlsConn := tls.Server(conn, this.tlsConfig)
conn = v2tls.NewConnection(tlsConn)
}
go this.connCallback(conn) go this.connCallback(conn)
} }
} }

View File

@ -18,11 +18,12 @@ func (this *Connection) Reusable() bool {
func (this *Connection) SetReusable(b bool) {} func (this *Connection) SetReusable(b bool) {}
func init() { func init() {
internet.UDPDialer = func(src v2net.Address, dest v2net.Destination) (internet.Connection, error) { internet.UDPDialer = func(src v2net.Address, dest v2net.Destination, options internet.DialerOptions) (internet.Connection, error) {
conn, err := internet.DialToDest(src, dest) conn, err := internet.DialToDest(src, dest)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO: handle dialer options
return &Connection{ return &Connection{
UDPConn: *(conn.(*net.UDPConn)), UDPConn: *(conn.(*net.UDPConn)),
}, nil }, nil

View File

@ -1,12 +1,8 @@
package ws package ws
type Config struct { type Config struct {
ConnectionReuse bool ConnectionReuse bool
Path string Path string
Pto string
Cert string
PrivKey string
DeveloperInsecureSkipVerify bool
} }
func (this *Config) Apply() { func (this *Config) Apply() {
@ -17,6 +13,5 @@ var (
effectiveConfig = &Config{ effectiveConfig = &Config{
ConnectionReuse: true, ConnectionReuse: true,
Path: "", Path: "",
Pto: "",
} }
) )

View File

@ -8,23 +8,15 @@ func (this *Config) UnmarshalJSON(data []byte) error {
type JsonConfig struct { type JsonConfig struct {
ConnectionReuse bool `json:"connectionReuse"` ConnectionReuse bool `json:"connectionReuse"`
Path string `json:"Path"` Path string `json:"Path"`
Pto string `json:"Pto"`
Cert string `json:"Cert"`
PrivKey string `json:"PrivKey"`
} }
jsonConfig := &JsonConfig{ jsonConfig := &JsonConfig{
ConnectionReuse: true, ConnectionReuse: true,
Path: "", Path: "",
Pto: "",
} }
if err := json.Unmarshal(data, jsonConfig); err != nil { if err := json.Unmarshal(data, jsonConfig); err != nil {
return err return err
} }
this.ConnectionReuse = jsonConfig.ConnectionReuse this.ConnectionReuse = jsonConfig.ConnectionReuse
this.Path = jsonConfig.Path this.Path = jsonConfig.Path
this.Pto = jsonConfig.Pto
this.PrivKey = jsonConfig.PrivKey
this.Cert = jsonConfig.Cert
this.DeveloperInsecureSkipVerify = false
return nil return nil
} }

View File

@ -1,7 +1,6 @@
package ws package ws
import ( import (
"crypto/tls"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@ -16,7 +15,7 @@ var (
globalCache = NewConnectionCache() globalCache = NewConnectionCache()
) )
func Dial(src v2net.Address, dest v2net.Destination) (internet.Connection, error) { func Dial(src v2net.Address, dest v2net.Destination, options internet.DialerOptions) (internet.Connection, error) {
log.Info("WebSocket|Dailer: Creating connection to ", dest) log.Info("WebSocket|Dailer: Creating connection to ", dest)
if src == nil { if src == nil {
src = v2net.AnyIP src = v2net.AnyIP
@ -31,7 +30,7 @@ func Dial(src v2net.Address, dest v2net.Destination) (internet.Connection, error
} }
if conn == nil { if conn == nil {
var err error var err error
conn, err = wsDial(src, dest) conn, err = wsDial(src, dest, options)
if err != nil { if err != nil {
log.Warning("WebSocket|Dialer: Dial failed: ", err) log.Warning("WebSocket|Dialer: Dial failed: ", err)
return nil, err return nil, err
@ -44,20 +43,30 @@ func init() {
internet.WSDialer = Dial internet.WSDialer = Dial
} }
func wsDial(src v2net.Address, dest v2net.Destination) (*wsconn, error) { func wsDial(src v2net.Address, dest v2net.Destination, options internet.DialerOptions) (*wsconn, error) {
commonDial := func(network, addr string) (net.Conn, error) { commonDial := func(network, addr string) (net.Conn, error) {
return internet.DialToDest(src, dest) return internet.DialToDest(src, dest)
} }
tlsconf := &tls.Config{ServerName: dest.Address.Domain(), InsecureSkipVerify: effectiveConfig.DeveloperInsecureSkipVerify} dialer := websocket.Dialer{
NetDial: commonDial,
ReadBufferSize: 65536,
WriteBufferSize: 65536,
}
dialer := websocket.Dialer{NetDial: commonDial, ReadBufferSize: 65536, WriteBufferSize: 65536, TLSClientConfig: tlsconf} protocol := "ws"
effpto := calcPto(dest) if options.Stream != nil && options.Stream.Security == internet.StreamSecurityTypeTLS {
protocol = "wss"
dialer.TLSClientConfig = options.Stream.TLSSettings.GetTLSConfig()
if dest.Address.Family().IsDomain() {
dialer.TLSClientConfig.ServerName = dest.Address.Domain()
}
}
uri := func(dst v2net.Destination, pto string, path string) string { uri := func(dst v2net.Destination, pto string, path string) string {
return fmt.Sprintf("%v://%v/%v", pto, dst.NetAddr(), path) return fmt.Sprintf("%v://%v/%v", pto, dst.NetAddr(), path)
}(dest, effpto, effectiveConfig.Path) }(dest, protocol, effectiveConfig.Path)
conn, resp, err := dialer.Dial(uri, nil) conn, resp, err := dialer.Dial(uri, nil)
if err != nil { if err != nil {
@ -73,45 +82,3 @@ func wsDial(src v2net.Address, dest v2net.Destination) (*wsconn, error) {
return connv2ray return connv2ray
}().(*wsconn), nil }().(*wsconn), nil
} }
func calcPto(dst v2net.Destination) string {
if effectiveConfig.Pto != "" {
return effectiveConfig.Pto
}
switch dst.Port.Value() {
/*
Since the value is not given explicitly,
We are guessing it now.
HTTP Port:
80
8080
8880
2052
2082
2086
2095
HTTPS Port:
443
2053
2083
2087
2096
8443
if the port you are using is not well-known,
specify it to avoid this process.
We will return "CRASH"turn "unknown" if we can't guess it, cause Dial to fail.
*/
case 80, 8080, 8880, 2052, 2082, 2086, 2095:
return "ws"
case 443, 2053, 2083, 2087, 2096, 8443:
return "wss"
default:
return "unknown"
}
}

View File

@ -1,6 +1,7 @@
package ws package ws
import ( import (
"crypto/tls"
"errors" "errors"
"net" "net"
"net/http" "net/http"
@ -28,14 +29,18 @@ type WSListener struct {
acccepting bool acccepting bool
awaitingConns chan *ConnectionWithError awaitingConns chan *ConnectionWithError
listener *StoppableListener listener *StoppableListener
tlsConfig *tls.Config
} }
func ListenWS(address v2net.Address, port v2net.Port) (internet.Listener, error) { func ListenWS(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) {
l := &WSListener{ l := &WSListener{
acccepting: true, acccepting: true,
awaitingConns: make(chan *ConnectionWithError, 32), awaitingConns: make(chan *ConnectionWithError, 32),
} }
if options.Stream != nil && options.Stream.Security == internet.StreamSecurityTypeTLS {
l.tlsConfig = options.Stream.TLSSettings.GetTLSConfig()
}
err := l.listenws(address, port) err := l.listenws(address, port)
@ -77,10 +82,10 @@ func (wsl *WSListener) listenws(address v2net.Address, port v2net.Port) error {
return http.Serve(wsl.listener, nil) return http.Serve(wsl.listener, nil)
} }
if effectiveConfig.Pto == "wss" { if wsl.tlsConfig != nil {
listenerfunc = func() error { listenerfunc = func() error {
var err error var err error
wsl.listener, err = getstopableTLSlistener(effectiveConfig.Cert, effectiveConfig.PrivKey, address.String()+":"+strconv.Itoa(int(port.Value()))) wsl.listener, err = getstopableTLSlistener(wsl.tlsConfig, address.String()+":"+strconv.Itoa(int(port.Value())))
if err != nil { if err != nil {
return err return err
} }

View File

@ -2,14 +2,8 @@ package ws
import "crypto/tls" import "crypto/tls"
func getstopableTLSlistener(cert, key, listenaddr string) (*StoppableListener, error) { func getstopableTLSlistener(tlsConfig *tls.Config, listenaddr string) (*StoppableListener, error) {
cer, err := tls.LoadX509KeyPair(cert, key) ln, err := tls.Listen("tcp", listenaddr, tlsConfig)
if err != nil {
return nil, err
}
config := &tls.Config{Certificates: []tls.Certificate{cer}}
ln, err := tls.Listen("tcp", listenaddr, config)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,18 +1,20 @@
package ws_test package ws_test
import ( import (
"crypto/tls"
"testing" "testing"
"time" "time"
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
"v2ray.com/core/testing/assert" "v2ray.com/core/testing/assert"
"v2ray.com/core/transport/internet"
. "v2ray.com/core/transport/internet/ws" . "v2ray.com/core/transport/internet/ws"
) )
func Test_Connect_ws(t *testing.T) { func Test_Connect_ws(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
(&Config{Pto: "ws", Path: ""}).Apply() (&Config{Path: ""}).Apply()
conn, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("echo.websocket.org"), 80)) conn, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("echo.websocket.org"), 80), internet.DialerOptions{})
assert.Error(err).IsNil() assert.Error(err).IsNil()
conn.Write([]byte("echo")) conn.Write([]byte("echo"))
s := make(chan int) s := make(chan int)
@ -31,8 +33,12 @@ func Test_Connect_ws(t *testing.T) {
func Test_Connect_wss(t *testing.T) { func Test_Connect_wss(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
(&Config{Pto: "wss", Path: ""}).Apply() (&Config{Path: ""}).Apply()
conn, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("echo.websocket.org"), 443)) conn, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("echo.websocket.org"), 443), internet.DialerOptions{
Stream: &internet.StreamSettings{
Security: internet.StreamSecurityTypeTLS,
},
})
assert.Error(err).IsNil() assert.Error(err).IsNil()
conn.Write([]byte("echo")) conn.Write([]byte("echo"))
s := make(chan int) s := make(chan int)
@ -51,8 +57,12 @@ func Test_Connect_wss(t *testing.T) {
func Test_Connect_wss_1_nil(t *testing.T) { func Test_Connect_wss_1_nil(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
(&Config{Pto: "wss", Path: ""}).Apply() (&Config{Path: ""}).Apply()
conn, err := Dial(nil, v2net.TCPDestination(v2net.DomainAddress("echo.websocket.org"), 443)) conn, err := Dial(nil, v2net.TCPDestination(v2net.DomainAddress("echo.websocket.org"), 443), internet.DialerOptions{
Stream: &internet.StreamSettings{
Security: internet.StreamSecurityTypeTLS,
},
})
assert.Error(err).IsNil() assert.Error(err).IsNil()
conn.Write([]byte("echo")) conn.Write([]byte("echo"))
s := make(chan int) s := make(chan int)
@ -71,8 +81,8 @@ func Test_Connect_wss_1_nil(t *testing.T) {
func Test_Connect_ws_guess(t *testing.T) { func Test_Connect_ws_guess(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
(&Config{Pto: "", Path: ""}).Apply() (&Config{Path: ""}).Apply()
conn, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("echo.websocket.org"), 80)) conn, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("echo.websocket.org"), 80), internet.DialerOptions{})
assert.Error(err).IsNil() assert.Error(err).IsNil()
conn.Write([]byte("echo")) conn.Write([]byte("echo"))
s := make(chan int) s := make(chan int)
@ -91,8 +101,12 @@ func Test_Connect_ws_guess(t *testing.T) {
func Test_Connect_wss_guess(t *testing.T) { func Test_Connect_wss_guess(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
(&Config{Pto: "", Path: ""}).Apply() (&Config{Path: ""}).Apply()
conn, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("echo.websocket.org"), 443)) conn, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("echo.websocket.org"), 443), internet.DialerOptions{
Stream: &internet.StreamSettings{
Security: internet.StreamSecurityTypeTLS,
},
})
assert.Error(err).IsNil() assert.Error(err).IsNil()
conn.Write([]byte("echo")) conn.Write([]byte("echo"))
s := make(chan int) s := make(chan int)
@ -111,24 +125,25 @@ func Test_Connect_wss_guess(t *testing.T) {
func Test_Connect_wss_guess_fail(t *testing.T) { func Test_Connect_wss_guess_fail(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
(&Config{Pto: "", Path: ""}).Apply() (&Config{Path: ""}).Apply()
_, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("static.kkdev.org"), 443)) _, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("static.kkdev.org"), 443), internet.DialerOptions{
assert.Error(err).IsNotNil() Stream: &internet.StreamSettings{
} Security: internet.StreamSecurityTypeTLS,
},
func Test_Connect_wss_guess_fail_port(t *testing.T) { })
assert := assert.On(t)
(&Config{Pto: "", Path: ""}).Apply()
_, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("static.kkdev.org"), 179))
assert.Error(err).IsNotNil() assert.Error(err).IsNotNil()
} }
func Test_Connect_wss_guess_reuse(t *testing.T) { func Test_Connect_wss_guess_reuse(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
(&Config{Pto: "", Path: "", ConnectionReuse: true}).Apply() (&Config{Path: "", ConnectionReuse: true}).Apply()
i := 3 i := 3
for i != 0 { for i != 0 {
conn, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("echo.websocket.org"), 443)) conn, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("echo.websocket.org"), 443), internet.DialerOptions{
Stream: &internet.StreamSettings{
Security: internet.StreamSecurityTypeTLS,
},
})
assert.Error(err).IsNil() assert.Error(err).IsNil()
conn.Write([]byte("echo")) conn.Write([]byte("echo"))
s := make(chan int) s := make(chan int)
@ -155,8 +170,8 @@ func Test_Connect_wss_guess_reuse(t *testing.T) {
func Test_listenWSAndDial(t *testing.T) { func Test_listenWSAndDial(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
(&Config{Pto: "ws", Path: "ws"}).Apply() (&Config{Path: "ws"}).Apply()
listen, err := ListenWS(v2net.DomainAddress("localhost"), 13142) listen, err := ListenWS(v2net.DomainAddress("localhost"), 13142, internet.ListenOptions{})
assert.Error(err).IsNil() assert.Error(err).IsNil()
go func() { go func() {
conn, err := listen.Accept() conn, err := listen.Accept()
@ -170,15 +185,15 @@ func Test_listenWSAndDial(t *testing.T) {
conn.Close() conn.Close()
listen.Close() listen.Close()
}() }()
conn, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("localhost"), 13142)) conn, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("localhost"), 13142), internet.DialerOptions{})
assert.Error(err).IsNil() assert.Error(err).IsNil()
conn.Close() conn.Close()
<-time.After(time.Second * 5) <-time.After(time.Second * 5)
conn, err = Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("localhost"), 13142)) conn, err = Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("localhost"), 13142), internet.DialerOptions{})
assert.Error(err).IsNil() assert.Error(err).IsNil()
conn.Close() conn.Close()
<-time.After(time.Second * 15) <-time.After(time.Second * 15)
conn, err = Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("localhost"), 13142)) conn, err = Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("localhost"), 13142), internet.DialerOptions{})
assert.Error(err).IsNil() assert.Error(err).IsNil()
conn.Close() conn.Close()
} }
@ -189,8 +204,17 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
<-time.After(time.Second * 5) <-time.After(time.Second * 5)
assert.Fail("Too slow") assert.Fail("Too slow")
}() }()
(&Config{Pto: "wss", Path: "wss", ConnectionReuse: true, DeveloperInsecureSkipVerify: true, PrivKey: "./../../../testing/tls/key.pem", Cert: "./../../../testing/tls/cert.pem"}).Apply() (&Config{Path: "wss", ConnectionReuse: true}).Apply()
listen, err := ListenWS(v2net.DomainAddress("localhost"), 13143)
listen, err := ListenWS(v2net.DomainAddress("localhost"), 13143, internet.ListenOptions{
Stream: &internet.StreamSettings{
Security: internet.StreamSecurityTypeTLS,
TLSSettings: &internet.TLSSettings{
AllowInsecure: true,
Certs: LoadTestCert(assert),
},
},
})
assert.Error(err).IsNil() assert.Error(err).IsNil()
go func() { go func() {
conn, err := listen.Accept() conn, err := listen.Accept()
@ -198,7 +222,21 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
conn.Close() conn.Close()
listen.Close() listen.Close()
}() }()
conn, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("localhost"), 13143)) conn, err := Dial(v2net.AnyIP, v2net.TCPDestination(v2net.DomainAddress("localhost"), 13143), internet.DialerOptions{
Stream: &internet.StreamSettings{
Security: internet.StreamSecurityTypeTLS,
TLSSettings: &internet.TLSSettings{
AllowInsecure: true,
Certs: LoadTestCert(assert),
},
},
})
assert.Error(err).IsNil() assert.Error(err).IsNil()
conn.Close() conn.Close()
} }
func LoadTestCert(assert *assert.Assert) []tls.Certificate {
cert, err := tls.LoadX509KeyPair("./../../../testing/tls/cert.pem", "./../../../testing/tls/key.pem")
assert.Error(err).IsNil()
return []tls.Certificate{cert}
}