From 4a887e3b7770c22e93e7cce3ce25b599e3d67390 Mon Sep 17 00:00:00 2001 From: Shelikhoo Date: Fri, 16 Dec 2022 19:16:00 +0000 Subject: [PATCH] Use security engine for (tls like) security client in websocket transport --- transport/internet/websocket/dialer.go | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index a56e0a6d2..9244938f7 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -4,7 +4,9 @@ import ( "bytes" "context" "encoding/base64" + "github.com/v2fly/v2ray-core/v5/transport/internet/security" "io" + gonet "net" "net/http" "time" @@ -16,7 +18,6 @@ import ( "github.com/v2fly/v2ray-core/v5/common/session" "github.com/v2fly/v2ray-core/v5/features/extension" "github.com/v2fly/v2ray-core/v5/transport/internet" - "github.com/v2fly/v2ray-core/v5/transport/internet/tls" ) // Dial dials a WebSocket connection to the given destination. @@ -48,9 +49,27 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in protocol := "ws" - if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { + securityEngine, err := security.CreateSecurityEngineFromSettings(ctx, streamSettings) + if err != nil { + return nil, newError("unable to create security engine").Base(err) + } + + if securityEngine != nil { protocol = "wss" - dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1")) + + dialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (gonet.Conn, error) { + conn, err := dialer.NetDial(network, addr) + if err != nil { + return nil, newError("dial TLS connection failed").Base(err) + } + conn, err = securityEngine.Client(conn, + security.OptionWithDestination{Dest: dest}, + security.OptionWithALPN{ALPNs: []string{"http/1.1"}}) + if err != nil { + return nil, newError("unable to create security protocol client from security engine").Base(err) + } + return conn, nil + } } host := dest.NetAddr()