1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-10-19 10:13:37 -04:00
v2fly/transport/internet/transportcommon/httpDialer.go
Xiaokang Wang (Shelikhoo) 7db39fb566
Add (Experimental) Meyka Building Blocks to request Transport (#3120)
* add packetconn assembler

* let kcp use environment dependency injection

* Add destination override to simplified setting

* add dtls dialer

* add dtls listener

* add dtls to default

* fix bugs

* add debug options to freedom outbound

* fix kcp test failure for transport environment
2024-08-22 04:05:05 +01:00

268 lines
6.9 KiB
Go

package transportcommon
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"sync"
"time"
"golang.org/x/net/http2"
"github.com/v2fly/v2ray-core/v5/transport/internet/security"
)
// DialerFunc opens a connection to addr. In this file its result is inspected
// for the security.ConnectionApplicationProtocol interface to learn the
// negotiated ALPN; connections that do not implement it are treated as having
// negotiated no protocol.
type DialerFunc func(ctx context.Context, addr string) (net.Conn, error)
// NewALPNAwareHTTPRoundTripper is a convenience wrapper around
// NewALPNAwareHTTPRoundTripperWithH2Pool that always requests an HTTP/2 pool
// size of 1.
//
// ctx, dialer and backdropTransport are forwarded unchanged.
func NewALPNAwareHTTPRoundTripper(ctx context.Context, dialer DialerFunc,
	backdropTransport http.RoundTripper,
) http.RoundTripper {
	const defaultH2PoolSize = 1
	return NewALPNAwareHTTPRoundTripperWithH2Pool(ctx, dialer, backdropTransport, defaultH2PoolSize)
}
// NewALPNAwareHTTPRoundTripperWithH2Pool creates an instance of RoundTripper that dials the remote HTTPS endpoint with
// an alternative version of TLS implementation.
//
// ctx is the context handed to dialer for every dial. dialer opens the
// underlying connection. backdropTransport serves every request whose URL
// scheme is not "https". h2PoolSize selects how HTTP/2 is served: a value of 2
// or more makes init build a pool of that many HTTP/2 transports, any smaller
// value uses a single HTTP/2 transport.
func NewALPNAwareHTTPRoundTripperWithH2Pool(ctx context.Context, dialer DialerFunc,
	backdropTransport http.RoundTripper,
	h2PoolSize int,
) http.RoundTripper {
	rtImpl := &alpnAwareHTTPRoundTripperImpl{
		connectWithH1:     map[string]bool{},
		backdropTransport: backdropTransport,
		pendingConn:       map[pendingConnKey]*unclaimedConnection{},
		dialer:            dialer,
		ctx:               ctx,
		// Must be set before init runs: init reads it to decide between a
		// single HTTP/2 transport and a pool.
		h2PoolSize: h2PoolSize,
	}
	rtImpl.init()
	return rtImpl
}
// alpnAwareHTTPRoundTripperImpl is an http.RoundTripper that sends https
// requests through either an HTTP/1.1 or an HTTP/2 transport, depending on the
// ALPN protocol observed on connections returned by dialer.
type alpnAwareHTTPRoundTripperImpl struct {
// accessConnectWithH1 guards connectWithH1.
accessConnectWithH1 sync.Mutex
// connectWithH1 maps a destination address to whether requests for it
// should use the HTTP/1.1 transport; absent entries read as false.
connectWithH1 map[string]bool
// httpsH1Transport and httpsH2Transport are built by init.
httpsH1Transport http.RoundTripper
httpsH2Transport http.RoundTripper
// backdropTransport serves every request whose scheme is not "https".
backdropTransport http.RoundTripper
// accessDialingConnection is held for the whole of
// dialOrGetTLSWithExpectedALPN, which is the only user of pendingConn.
accessDialingConnection sync.Mutex
// pendingConn parks connections whose negotiated ALPN did not match what
// the dialing transport expected, so the other transport can pick them up.
pendingConn map[pendingConnKey]*unclaimedConnection
// ctx is the context passed to dialer on every dial.
ctx context.Context
dialer DialerFunc
// h2PoolSize is read by init; values of 2 or more select the HTTP/2 pool.
h2PoolSize int
}
// pendingConnKey identifies a parked connection in pendingConn by destination
// address and by whether the connection negotiated HTTP/2.
type pendingConnKey struct {
// isH2 is true when the parked connection negotiated HTTP/2.
isH2 bool
// dest is the address the connection was dialed to.
dest string
}
var (
// errEAGAIN is returned by the dial path when the connection's ALPN does not
// match the transport that asked for it; RoundTrip reacts by retrying.
errEAGAIN = errors.New("incorrect ALPN negotiated, try again with another ALPN")
// errEAGAINTooMany is returned by RoundTrip once its retry budget is used up.
errEAGAINTooMany = errors.New("incorrect ALPN negotiated")
// errExpired is returned by claimConnection when the parked connection has
// already been claimed or closed by its expiry timer.
errExpired = errors.New("connection have expired")
)
// RoundTrip implements http.RoundTripper.
//
// Requests whose scheme is not "https" go straight to backdropTransport. For
// https requests the HTTP/1.1 or HTTP/2 transport is chosen from the recorded
// ALPN decision for the destination. When the chosen transport reports
// errEAGAIN (the dial path found the other protocol), the decision is looked
// up again and the request is retried, up to 5 attempts in total. If every
// attempt ends in errEAGAIN the returned error wraps errEAGAINTooMany.
func (r *alpnAwareHTTPRoundTripperImpl) RoundTrip(req *http.Request) (*http.Response, error) {
	if req.URL.Scheme != "https" {
		return r.backdropTransport.RoundTrip(req)
	}

	port := req.URL.Port()
	if port == "" {
		port = "443"
	}
	// The dial callbacks record the ALPN decision under the addr they are
	// given by the transports, which is host:port with IPv6 literals in
	// brackets. URL.Hostname strips those brackets, so the key has to be
	// rebuilt with net.JoinHostPort or the lookup never matches for IPv6 hosts.
	addr := net.JoinHostPort(req.URL.Hostname(), port)

	for retryCount := 0; retryCount < 5; retryCount++ {
		transport := r.httpsH2Transport
		if r.getShouldConnectWithH1(addr) {
			transport = r.httpsH1Transport
		}
		resp, err := transport.RoundTrip(req)
		if errors.Is(err, errEAGAIN) {
			// The dial path has updated the decision; look it up again.
			continue
		}
		return resp, err
	}
	// Wrap the sentinel so errors.Is still matches while the message names
	// the destination that could not settle on a protocol.
	return nil, fmt.Errorf("%w for %v", errEAGAINTooMany, addr)
}
// getShouldConnectWithH1 reports whether domainName is currently marked for
// the HTTP/1.1 transport. Destinations with no recorded decision report false.
func (r *alpnAwareHTTPRoundTripperImpl) getShouldConnectWithH1(domainName string) bool {
	r.accessConnectWithH1.Lock()
	defer r.accessConnectWithH1.Unlock()

	// Indexing a map with a missing key yields the zero value, false, which is
	// exactly the default for an unknown destination.
	return r.connectWithH1[domainName]
}
// setShouldConnectWithH1 marks domainName so that later lookups through
// getShouldConnectWithH1 report true.
func (r *alpnAwareHTTPRoundTripperImpl) setShouldConnectWithH1(domainName string) {
	const useH1 = true

	r.accessConnectWithH1.Lock()
	r.connectWithH1[domainName] = useH1
	r.accessConnectWithH1.Unlock()
}
// clearShouldConnectWithH1 marks domainName so that later lookups through
// getShouldConnectWithH1 report false. The entry is kept in the map with the
// value false rather than being deleted.
func (r *alpnAwareHTTPRoundTripperImpl) clearShouldConnectWithH1(domainName string) {
	const useH1 = false

	r.accessConnectWithH1.Lock()
	r.connectWithH1[domainName] = useH1
	r.accessConnectWithH1.Unlock()
}
// getPendingConnectionID builds the pendingConn map key for a destination and
// the protocol (HTTP/2 or not) negotiated on the connection.
func getPendingConnectionID(dest string, alpnIsH2 bool) pendingConnKey {
	var key pendingConnKey
	key.dest = dest
	key.isH2 = alpnIsH2
	return key
}
// putConn parks conn under (addr, alpnIsH2) for up to one minute, replacing
// any entry already stored under that key.
//
// It takes no lock of its own; the caller in this file,
// dialOrGetTLSWithExpectedALPN, holds accessDialingConnection while calling it.
func (r *alpnAwareHTTPRoundTripperImpl) putConn(addr string, alpnIsH2 bool, conn net.Conn) {
	key := getPendingConnectionID(addr, alpnIsH2)
	parked := NewUnclaimedConnection(conn, time.Minute)
	r.pendingConn[key] = parked
}
// getConn removes and returns the connection parked under (addr, alpnIsH2).
// It returns nil when nothing is parked under that key or when the parked
// connection can no longer be claimed. The map entry is removed in both of the
// latter cases.
//
// It takes no lock of its own; the caller in this file,
// dialOrGetTLSWithExpectedALPN, holds accessDialingConnection while calling it.
func (r *alpnAwareHTTPRoundTripperImpl) getConn(addr string, alpnIsH2 bool) net.Conn {
	key := getPendingConnectionID(addr, alpnIsH2)

	parked, found := r.pendingConn[key]
	if !found {
		return nil
	}
	delete(r.pendingConn, key)

	claimed, err := parked.claimConnection()
	if err != nil {
		return nil
	}
	return claimed
}
// dialOrGetTLSWithExpectedALPN returns a connection to addr whose negotiated
// ALPN matches expectedH2 (true for HTTP/2, false otherwise).
//
// The whole call runs under accessDialingConnection. In order it:
//   - returns errEAGAIN at once when the recorded decision for addr already
//     contradicts expectedH2;
//   - reuses a parked connection for (addr, expectedH2) if one is available;
//   - otherwise dials, reads the ALPN from the connection if it implements
//     security.ConnectionApplicationProtocol (connections that do not are
//     treated as not HTTP/2), and returns the connection when it matches.
//
// A freshly dialed connection that does not match is parked for the other
// transport, the decision for addr is updated, and errEAGAIN is returned so the
// caller retries through that transport.
func (r *alpnAwareHTTPRoundTripperImpl) dialOrGetTLSWithExpectedALPN(ctx context.Context, addr string, expectedH2 bool) (net.Conn, error) {
	r.accessDialingConnection.Lock()
	defer r.accessDialingConnection.Unlock()

	if r.getShouldConnectWithH1(addr) == expectedH2 {
		return nil, errEAGAIN
	}

	// Get a cached connection if possible to reduce preflight connection closed without sending data
	if gconn := r.getConn(addr, expectedH2); gconn != nil {
		return gconn, nil
	}

	conn, err := r.dialTLS(ctx, addr)
	if err != nil {
		return nil, err
	}

	protocol := ""
	if connALPNGetter, ok := conn.(security.ConnectionApplicationProtocol); ok {
		connectionALPN, err := connALPNGetter.GetConnectionApplicationProtocol()
		if err != nil {
			// The connection is neither returned nor parked on this path, so
			// close it here to avoid leaking it. The close error is dropped:
			// the ALPN failure is the one worth reporting.
			conn.Close()
			return nil, newError("failed to get connection ALPN").Base(err).AtWarning()
		}
		protocol = connectionALPN
	}

	protocolIsH2 := protocol == http2.NextProtoTLS
	if protocolIsH2 == expectedH2 {
		return conn, nil
	}

	// Wrong protocol for this transport: keep the connection for the other
	// one and record which transport should serve addr from now on.
	r.putConn(addr, protocolIsH2, conn)
	if protocolIsH2 {
		r.clearShouldConnectWithH1(addr)
	} else {
		r.setShouldConnectWithH1(addr)
	}
	return nil, errEAGAIN
}
// dialTLS opens a new connection to addr through the configured dialer.
//
// The context argument is deliberately unused; the dial always runs with the
// context stored on the round tripper.
// NOTE(review): presumably this keeps the connection from being tied to a
// single request's context — confirm.
func (r *alpnAwareHTTPRoundTripperImpl) dialTLS(_ context.Context, addr string) (net.Conn, error) {
	return r.dialer(r.ctx, addr)
}
// init builds the HTTP/1.1 and HTTP/2 transports. Both route their TLS dials
// through dialOrGetTLSWithExpectedALPN, each stating the protocol it expects.
//
// With h2PoolSize of 2 or more the HTTP/2 side is a pool of that many
// transports; otherwise it is a single http2.Transport.
func (r *alpnAwareHTTPRoundTripperImpl) init() {
	// HTTP/2 dials carry no per-request context, so a background context is
	// passed along.
	dialH2 := func(network, addr string, cfg *tls.Config) (net.Conn, error) {
		return r.dialOrGetTLSWithExpectedALPN(context.Background(), addr, true)
	}
	newH2Transport := func() *http2.Transport {
		return &http2.Transport{DialTLS: dialH2}
	}

	switch {
	case r.h2PoolSize >= 2:
		r.httpsH2Transport = newH2TransportPool(int64(r.h2PoolSize), newH2Transport)
	default:
		r.httpsH2Transport = newH2Transport()
	}

	dialH1 := func(ctx context.Context, network string, addr string) (net.Conn, error) {
		return r.dialOrGetTLSWithExpectedALPN(ctx, addr, false)
	}
	r.httpsH1Transport = &http.Transport{DialTLSContext: dialH1}
}
// NewUnclaimedConnection wraps conn so that it is closed automatically unless
// it is claimed within expireTime. The expiry timer is started before the
// wrapper is returned.
func NewUnclaimedConnection(conn net.Conn, expireTime time.Duration) *unclaimedConnection {
	parked := new(unclaimedConnection)
	parked.Conn = conn
	time.AfterFunc(expireTime, parked.tick)
	return parked
}
// unclaimedConnection holds a connection that is waiting to be picked up. It
// can be taken exactly once via claimConnection; if the expiry timer fires
// first, tick closes the connection instead.
type unclaimedConnection struct {
// Conn is the parked connection; tick sets it to nil after closing it.
net.Conn
// claimed is set once the connection has been handed out or expired.
claimed bool
// access guards claimed and Conn.
access sync.Mutex
}
// claimConnection hands out the parked connection and marks it as claimed.
// It returns errExpired if the connection was already claimed or has been
// closed by the expiry timer.
func (c *unclaimedConnection) claimConnection() (net.Conn, error) {
	c.access.Lock()
	defer c.access.Unlock()

	if c.claimed {
		return nil, errExpired
	}
	c.claimed = true
	return c.Conn, nil
}
// tick is the expiry callback. If the connection has not been claimed yet it
// marks it as claimed, closes it, and drops the reference; otherwise it does
// nothing.
func (c *unclaimedConnection) tick() {
	c.access.Lock()
	defer c.access.Unlock()

	if c.claimed {
		return
	}
	c.claimed = true
	c.Conn.Close()
	c.Conn = nil
}
// h2TransportFactory creates a new HTTP/2 transport; h2TransportPool calls it
// to fill a pool slot the first time that slot is used.
type h2TransportFactory func() *http2.Transport
// newH2TransportPool creates a pool with size empty slots. No transport is
// created here; slots are filled by h2factory when they are first used.
func newH2TransportPool(size int64, h2factory h2TransportFactory) *h2TransportPool {
	p := new(h2TransportPool)
	p.size = size
	p.h2factory = h2factory
	p.pool = make([]*http2.Transport, size)
	return p
}
// h2TransportPool spreads requests over a fixed number of HTTP/2 transports in
// round-robin order. Transports are created lazily, one per slot, the first
// time the slot is selected.
type h2TransportPool struct {
	// access guards pool and usageCount. An http.RoundTripper must be safe
	// for concurrent use, and both fields are mutated on every request.
	access sync.Mutex
	// pool holds one transport per slot; nil until the slot is first used.
	pool []*http2.Transport
	// h2factory creates the transport for an empty slot.
	h2factory h2TransportFactory
	// usageCount counts the requests served so far and drives slot selection.
	usageCount int64
	// size is the number of slots in pool.
	size int64
}

// RoundTrip implements http.RoundTripper by forwarding request to the next
// transport in round-robin order.
func (h *h2TransportPool) RoundTrip(request *http.Request) (*http.Response, error) {
	return h.nextTransport().RoundTrip(request)
}

// nextTransport picks the slot for the current request, creating its transport
// on first use. The lock covers only the selection, so the round trips
// themselves still run in parallel.
func (h *h2TransportPool) nextTransport() *http2.Transport {
	h.access.Lock()
	defer h.access.Unlock()

	currentSlot := h.usageCount % h.size
	h.usageCount++
	if h.pool[currentSlot] == nil {
		h.pool[currentSlot] = h.h2factory()
	}
	return h.pool[currentSlot]
}