mirror of
https://github.com/v2fly/v2ray-core.git
synced 2025-02-20 23:47:21 -05:00
http request decide protocol based on ALPN
This commit is contained in:
parent
0e519b9fb3
commit
b7e8554ee3
@ -7,10 +7,12 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
gonet "net"
|
||||
"net/http"
|
||||
|
||||
"github.com/v2fly/v2ray-core/v5/common"
|
||||
"github.com/v2fly/v2ray-core/v5/transport/internet/transportcommon"
|
||||
|
||||
"github.com/v2fly/v2ray-core/v5/common"
|
||||
"github.com/v2fly/v2ray-core/v5/common/net"
|
||||
"github.com/v2fly/v2ray-core/v5/transport/internet/request"
|
||||
)
|
||||
@ -25,20 +27,22 @@ type httpTripperClient struct {
|
||||
assembly request.TransportClientAssembly
|
||||
}
|
||||
|
||||
type unimplementedBackDrop struct {
|
||||
}
|
||||
|
||||
func (u unimplementedBackDrop) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return nil, newError("unimplemented")
|
||||
}
|
||||
|
||||
func (h *httpTripperClient) OnTransportClientAssemblyReady(assembly request.TransportClientAssembly) {
|
||||
h.assembly = assembly
|
||||
}
|
||||
|
||||
func (h *httpTripperClient) RoundTrip(ctx context.Context, req request.Request, opts ...request.RoundTripperOption) (resp request.Response, err error) {
|
||||
if h.httpRTT == nil {
|
||||
h.httpRTT = &http.Transport{
|
||||
DialContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
|
||||
return h.assembly.AutoImplDialer().Dial(ctx)
|
||||
},
|
||||
DialTLSContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
|
||||
return h.assembly.AutoImplDialer().Dial(ctx)
|
||||
},
|
||||
}
|
||||
h.httpRTT = transportcommon.NewALPNAwareHTTPRoundTripper(ctx, func(ctx context.Context, addr string) (gonet.Conn, error) {
|
||||
return h.assembly.AutoImplDialer().Dial(ctx)
|
||||
}, unimplementedBackDrop{})
|
||||
}
|
||||
|
||||
connectionTagStr := base64.RawURLEncoding.EncodeToString(req.ConnectionTag)
|
||||
|
@ -38,7 +38,8 @@ func meekDial(ctx context.Context, dest net.Destination, streamSettings *interne
|
||||
}
|
||||
httprtSetting := &httprt.ClientConfig{Http: &httprt.HTTPConfig{
|
||||
UrlPrefix: meekSetting.Url,
|
||||
}}
|
||||
},
|
||||
}
|
||||
request := &assembly.Config{
|
||||
Assembler: serial.ToTypedMessage(simpleAssembler),
|
||||
Roundtripper: serial.ToTypedMessage(httprtSetting),
|
||||
|
5
transport/internet/security/connprop.go
Normal file
5
transport/internet/security/connprop.go
Normal file
@ -0,0 +1,5 @@
|
||||
package security
|
||||
|
||||
type ConnectionApplicationProtocol interface {
|
||||
GetConnectionApplicationProtocol() (string, error)
|
||||
}
|
@ -17,6 +17,13 @@ type Conn struct {
|
||||
*tls.Conn
|
||||
}
|
||||
|
||||
func (c *Conn) GetConnectionApplicationProtocol() (string, error) {
|
||||
if err := c.Handshake(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return c.ConnectionState().NegotiatedProtocol, nil
|
||||
}
|
||||
|
||||
func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||
mb = buf.Compact(mb)
|
||||
mb, err := buf.WriteMultiBuffer(c, mb)
|
||||
|
@ -90,7 +90,18 @@ func (e Engine) Client(conn net.Conn, opts ...security.Option) (security.Conn, e
|
||||
if err != nil {
|
||||
return nil, newError("unable to finish utls handshake").Base(err)
|
||||
}
|
||||
return utlsClientConn, nil
|
||||
return uTLSClientConnection{utlsClientConn}, nil
|
||||
}
|
||||
|
||||
type uTLSClientConnection struct {
|
||||
*utls.UConn
|
||||
}
|
||||
|
||||
func (u uTLSClientConnection) GetConnectionApplicationProtocol() (string, error) {
|
||||
if err := u.Handshake(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return u.ConnectionState().NegotiatedProtocol, nil
|
||||
}
|
||||
|
||||
func uTLSConfigFromTLSConfig(config *systls.Config) (*utls.Config, error) { // nolint: unparam
|
||||
|
215
transport/internet/transportcommon/httpDialer.go
Normal file
215
transport/internet/transportcommon/httpDialer.go
Normal file
@ -0,0 +1,215 @@
|
||||
package transportcommon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/v2fly/v2ray-core/v5/transport/internet/security"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type DialerFunc func(ctx context.Context, addr string) (net.Conn, error)
|
||||
|
||||
// NewALPNAwareHTTPRoundTripper creates an instance of RoundTripper that dial to remote HTTPS endpoint with
|
||||
// an alternative version of TLS implementation.
|
||||
func NewALPNAwareHTTPRoundTripper(ctx context.Context, dialer DialerFunc,
|
||||
backdropTransport http.RoundTripper) http.RoundTripper {
|
||||
rtImpl := &alpnAwareHTTPRoundTripperImpl{
|
||||
connectWithH1: map[string]bool{},
|
||||
backdropTransport: backdropTransport,
|
||||
pendingConn: map[pendingConnKey]*unclaimedConnection{},
|
||||
dialer: dialer,
|
||||
ctx: ctx,
|
||||
}
|
||||
rtImpl.init()
|
||||
return rtImpl
|
||||
}
|
||||
|
||||
type alpnAwareHTTPRoundTripperImpl struct {
|
||||
accessConnectWithH1 sync.Mutex
|
||||
connectWithH1 map[string]bool
|
||||
|
||||
httpsH1Transport http.RoundTripper
|
||||
httpsH2Transport http.RoundTripper
|
||||
backdropTransport http.RoundTripper
|
||||
|
||||
accessDialingConnection sync.Mutex
|
||||
pendingConn map[pendingConnKey]*unclaimedConnection
|
||||
|
||||
ctx context.Context
|
||||
dialer DialerFunc
|
||||
}
|
||||
|
||||
type pendingConnKey struct {
|
||||
isH2 bool
|
||||
dest string
|
||||
}
|
||||
|
||||
var errEAGAIN = errors.New("incorrect ALPN negotiated, try again with another ALPN")
|
||||
var errEAGAINTooMany = errors.New("incorrect ALPN negotiated")
|
||||
var errExpired = errors.New("connection have expired")
|
||||
|
||||
func (r *alpnAwareHTTPRoundTripperImpl) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req.URL.Scheme != "https" {
|
||||
return r.backdropTransport.RoundTrip(req)
|
||||
}
|
||||
for retryCount := 0; retryCount < 5; retryCount++ {
|
||||
effectivePort := req.URL.Port()
|
||||
if effectivePort == "" {
|
||||
effectivePort = "443"
|
||||
}
|
||||
if r.getShouldConnectWithH1(fmt.Sprintf("%v:%v", req.URL.Hostname(), effectivePort)) {
|
||||
resp, err := r.httpsH1Transport.RoundTrip(req)
|
||||
if errors.Is(err, errEAGAIN) {
|
||||
continue
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
resp, err := r.httpsH2Transport.RoundTrip(req)
|
||||
if errors.Is(err, errEAGAIN) {
|
||||
continue
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
return nil, errEAGAINTooMany
|
||||
}
|
||||
|
||||
func (r *alpnAwareHTTPRoundTripperImpl) getShouldConnectWithH1(domainName string) bool {
|
||||
r.accessConnectWithH1.Lock()
|
||||
defer r.accessConnectWithH1.Unlock()
|
||||
if value, set := r.connectWithH1[domainName]; set {
|
||||
return value
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *alpnAwareHTTPRoundTripperImpl) setShouldConnectWithH1(domainName string) {
|
||||
r.accessConnectWithH1.Lock()
|
||||
defer r.accessConnectWithH1.Unlock()
|
||||
r.connectWithH1[domainName] = true
|
||||
}
|
||||
|
||||
func (r *alpnAwareHTTPRoundTripperImpl) clearShouldConnectWithH1(domainName string) {
|
||||
r.accessConnectWithH1.Lock()
|
||||
defer r.accessConnectWithH1.Unlock()
|
||||
r.connectWithH1[domainName] = false
|
||||
}
|
||||
|
||||
func getPendingConnectionID(dest string, alpnIsH2 bool) pendingConnKey {
|
||||
return pendingConnKey{isH2: alpnIsH2, dest: dest}
|
||||
}
|
||||
|
||||
func (r *alpnAwareHTTPRoundTripperImpl) putConn(addr string, alpnIsH2 bool, conn net.Conn) {
|
||||
connId := getPendingConnectionID(addr, alpnIsH2)
|
||||
r.pendingConn[connId] = NewUnclaimedConnection(conn, time.Minute)
|
||||
}
|
||||
func (r *alpnAwareHTTPRoundTripperImpl) getConn(addr string, alpnIsH2 bool) net.Conn {
|
||||
connId := getPendingConnectionID(addr, alpnIsH2)
|
||||
if conn, ok := r.pendingConn[connId]; ok {
|
||||
delete(r.pendingConn, connId)
|
||||
if claimedConnection, err := conn.claimConnection(); err == nil {
|
||||
return claimedConnection
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (r *alpnAwareHTTPRoundTripperImpl) dialOrGetTLSWithExpectedALPN(ctx context.Context, addr string, expectedH2 bool) (net.Conn, error) {
|
||||
r.accessDialingConnection.Lock()
|
||||
defer r.accessDialingConnection.Unlock()
|
||||
|
||||
if r.getShouldConnectWithH1(addr) == expectedH2 {
|
||||
return nil, errEAGAIN
|
||||
}
|
||||
|
||||
//Get a cached connection if possible to reduce preflight connection closed without sending data
|
||||
if gconn := r.getConn(addr, expectedH2); gconn != nil {
|
||||
return gconn, nil
|
||||
}
|
||||
|
||||
conn, err := r.dialTLS(ctx, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protocol := ""
|
||||
if connAPLNGetter, ok := conn.(security.ConnectionApplicationProtocol); ok {
|
||||
connectionALPN, err := connAPLNGetter.GetConnectionApplicationProtocol()
|
||||
if err != nil {
|
||||
return nil, newError("failed to get connection ALPN").Base(err).AtWarning()
|
||||
}
|
||||
protocol = connectionALPN
|
||||
}
|
||||
|
||||
protocolIsH2 := protocol == http2.NextProtoTLS
|
||||
|
||||
if protocolIsH2 == expectedH2 {
|
||||
return conn, err
|
||||
}
|
||||
|
||||
r.putConn(addr, protocolIsH2, conn)
|
||||
|
||||
if protocolIsH2 {
|
||||
r.clearShouldConnectWithH1(addr)
|
||||
} else {
|
||||
r.setShouldConnectWithH1(addr)
|
||||
}
|
||||
|
||||
return nil, errEAGAIN
|
||||
}
|
||||
|
||||
func (r *alpnAwareHTTPRoundTripperImpl) dialTLS(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return r.dialer(r.ctx, addr)
|
||||
}
|
||||
|
||||
func (r *alpnAwareHTTPRoundTripperImpl) init() {
|
||||
r.httpsH2Transport = &http2.Transport{
|
||||
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
return r.dialOrGetTLSWithExpectedALPN(context.Background(), addr, true)
|
||||
},
|
||||
}
|
||||
r.httpsH1Transport = &http.Transport{
|
||||
DialTLSContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||
return r.dialOrGetTLSWithExpectedALPN(ctx, addr, false)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewUnclaimedConnection(conn net.Conn, expireTime time.Duration) *unclaimedConnection {
|
||||
c := &unclaimedConnection{
|
||||
Conn: conn,
|
||||
}
|
||||
time.AfterFunc(expireTime, c.tick)
|
||||
return c
|
||||
}
|
||||
|
||||
type unclaimedConnection struct {
|
||||
net.Conn
|
||||
claimed bool
|
||||
access sync.Mutex
|
||||
}
|
||||
|
||||
func (c *unclaimedConnection) claimConnection() (net.Conn, error) {
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
if !c.claimed {
|
||||
c.claimed = true
|
||||
return c.Conn, nil
|
||||
}
|
||||
return nil, errExpired
|
||||
}
|
||||
|
||||
func (c *unclaimedConnection) tick() {
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
if !c.claimed {
|
||||
c.claimed = true
|
||||
c.Conn.Close()
|
||||
c.Conn = nil
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user