diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index f8500dfea..e344426cf 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -38,6 +38,7 @@ var upgrader = &websocket.Upgrader{ } func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + var earlyDataStr string var earlyData io.Reader if !h.earlyDataEnabled { // nolint: gocritic if request.URL.Path != h.path { @@ -49,11 +50,11 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req writer.WriteHeader(http.StatusNotFound) return } - earlyDataStr := request.Header.Get(h.earlyDataHeaderName) + earlyDataStr = request.Header.Get(h.earlyDataHeaderName) earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr))) } else { if strings.HasPrefix(request.URL.RequestURI(), h.path) { - earlyDataStr := request.URL.RequestURI()[len(h.path):] + earlyDataStr = request.URL.RequestURI()[len(h.path):] earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr))) } else { writer.WriteHeader(http.StatusNotFound) @@ -61,7 +62,12 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } } - conn, err := upgrader.Upgrade(writer, request, nil) + responseHeader := http.Header{} + if h.earlyDataEnabled && h.earlyDataHeaderName != "" { + responseHeader.Set(h.earlyDataHeaderName, earlyDataStr) + } + + conn, err := upgrader.Upgrade(writer, request, responseHeader) if err != nil { newError("failed to convert to WebSocket connection").Base(err).WriteToLog() return