diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index 1c33410a2..e44e6b1b4 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -86,8 +86,8 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er var iConn internet.Connection = session - if config := v2tls.ConfigFromContext(ctx, v2tls.WithDestination(dest)); config != nil { - tlsConn := tls.Client(iConn, config.GetTLSConfig()) + if config := v2tls.ConfigFromContext(ctx); config != nil { + tlsConn := tls.Client(iConn, config.GetTLSConfig(v2tls.WithDestination(dest))) iConn = tlsConn } diff --git a/transport/internet/tcp/dialer.go b/transport/internet/tcp/dialer.go index 1b5f544d2..d90ed46e2 100644 --- a/transport/internet/tcp/dialer.go +++ b/transport/internet/tcp/dialer.go @@ -27,8 +27,8 @@ func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error return nil, err } - if config := tls.ConfigFromContext(ctx, tls.WithDestination(dest), tls.WithNextProto("h2")); config != nil { - conn = tls.Client(conn, config.GetTLSConfig()) + if config := tls.ConfigFromContext(ctx); config != nil { + conn = tls.Client(conn, config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("h2"))) } tcpSettings := getTCPSettingsFromContext(ctx) diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index 4f535a08b..e52cc5c00 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -39,8 +39,8 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler addConn: handler, } - if config := tls.ConfigFromContext(ctx, tls.WithNextProto("h2")); config != nil { - l.tlsConfig = config.GetTLSConfig() + if config := tls.ConfigFromContext(ctx); config != nil { + l.tlsConfig = config.GetTLSConfig(tls.WithNextProto("h2")) } if tcpSettings.HeaderSettings != nil { diff --git a/transport/internet/tls/config.go b/transport/internet/tls/config.go index 1faf1f7a5..7172df527 100644 --- a/transport/internet/tls/config.go +++ b/transport/internet/tls/config.go @@ -25,7 +25,7 @@ func (c *Config) BuildCertificates() []tls.Certificate { return certs } -func (c *Config) GetTLSConfig() *tls.Config { +func (c *Config) GetTLSConfig(opts ...Option) *tls.Config { config := &tls.Config{ ClientSessionCache: globalSessionCache, NextProtos: []string{"http/1.1"}, @@ -34,6 +34,10 @@ func (c *Config) GetTLSConfig() *tls.Config { return config } + for _, opt := range opts { + opt(config) + } + config.InsecureSkipVerify = c.AllowInsecure config.Certificates = c.BuildCertificates() config.BuildNameToCertificate() @@ -47,10 +51,10 @@ func (c *Config) GetTLSConfig() *tls.Config { return config } -type Option func(*Config) +type Option func(*tls.Config) func WithDestination(dest net.Destination) Option { - return func(config *Config) { + return func(config *tls.Config) { if dest.Address.Family().IsDomain() && len(config.ServerName) == 0 { config.ServerName = dest.Address.Domain() } @@ -58,23 +62,21 @@ func WithDestination(dest net.Destination) Option { } func WithNextProto(protocol ...string) Option { - return func(config *Config) { - if len(config.NextProtocol) == 0 { - config.NextProtocol = protocol + return func(config *tls.Config) { + if len(config.NextProtos) == 0 { + config.NextProtos = protocol } } } -func ConfigFromContext(ctx context.Context, opts ...Option) *Config { +func ConfigFromContext(ctx context.Context) *Config { securitySettings := internet.SecuritySettingsFromContext(ctx) if securitySettings == nil { return nil } - if config, ok := securitySettings.(*Config); ok { - for _, opt := range opts { - opt(config) - } - return config + config, ok := securitySettings.(*Config) + if !ok { + return nil } - return nil + return config } diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index 9d6cf9bf9..8a1b0c713 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -41,9 +41,9 @@ func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error) protocol := "ws" - if config := tls.ConfigFromContext(ctx, tls.WithDestination(dest)); config != nil { + if config := tls.ConfigFromContext(ctx); config != nil { protocol = "wss" - dialer.TLSClientConfig = config.GetTLSConfig() + dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest)) } host := dest.NetAddr()