diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index b1c0a6d58..bc603565e 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -7,11 +7,13 @@ import ( "context" "encoding/base64" "io" + "net/http" "time" "github.com/v2fly/v2ray-core/v4/features/extension" "github.com/gorilla/websocket" + core "github.com/v2fly/v2ray-core/v4" "github.com/v2fly/v2ray-core/v4/common" @@ -91,7 +93,7 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in }), nil } - conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader()) + conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader()) // nolint: bodyclose if err != nil { var reason string if resp != nil { @@ -124,7 +126,20 @@ func (d dialerWithEarlyData) Dial(earlyData []byte) (*websocket.Conn, error) { return nil, newError("websocket delayed dialer cannot encode early data tail").Base(errc) } - conn, resp, err := d.dialer.Dial(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader()) + dialFunction := func() (*websocket.Conn, *http.Response, error) { + return d.dialer.Dial(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader()) + } + + if d.config.EarlyDataHeaderName != "" { + dialFunction = func() (*websocket.Conn, *http.Response, error) { + earlyDataStr := earlyDataBuf.String() + currentHeader := d.config.GetRequestHeader() + currentHeader.Set(d.config.EarlyDataHeaderName, earlyDataStr) + return d.dialer.Dial(d.uriBase, currentHeader) + } + } + + conn, resp, err := dialFunction() // nolint: bodyclose if err != nil { var reason string if resp != nil { @@ -161,7 +176,18 @@ func (d dialerWithEarlyDataRelayed) Dial(earlyData []byte) (io.ReadWriteCloser, return nil, newError("websocket delayed dialer cannot encode early data tail").Base(errc) } - conn, err := d.forwarder.DialWebsocket(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader()) + dialFunction := func() (io.ReadWriteCloser, error) { + return d.forwarder.DialWebsocket(d.uriBase+earlyDataBuf.String(), d.config.GetRequestHeader()) + } + + if d.config.EarlyDataHeaderName != "" { + earlyDataStr := earlyDataBuf.String() + currentHeader := d.config.GetRequestHeader() + currentHeader.Set(d.config.EarlyDataHeaderName, earlyDataStr) + return d.forwarder.DialWebsocket(d.uriBase, currentHeader) + } + + conn, err := dialFunction() if err != nil { var reason string return nil, newError("failed to dial to (", d.uriBase, ") with early data: ", reason).Base(err)