From c055a08b2c106b01410e9648954b717378818099 Mon Sep 17 00:00:00 2001 From: Shelikhoo Date: Sun, 19 Feb 2023 14:41:28 +0000 Subject: [PATCH] Refine header based websocket earlydata fix --- transport/internet/websocket/hub.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index e344426cf..d3775334c 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -38,7 +38,8 @@ var upgrader = &websocket.Upgrader{ } func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - var earlyDataStr string + responseHeader := http.Header{} + var earlyData io.Reader if !h.earlyDataEnabled { // nolint: gocritic if request.URL.Path != h.path { @@ -50,11 +51,14 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req writer.WriteHeader(http.StatusNotFound) return } - earlyDataStr = request.Header.Get(h.earlyDataHeaderName) + earlyDataStr := request.Header.Get(h.earlyDataHeaderName) earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr))) + if strings.EqualFold("Sec-WebSocket-Protocol", h.earlyDataHeaderName) { + responseHeader.Set(h.earlyDataHeaderName, earlyDataStr) + } } else { if strings.HasPrefix(request.URL.RequestURI(), h.path) { - earlyDataStr = request.URL.RequestURI()[len(h.path):] + earlyDataStr := request.URL.RequestURI()[len(h.path):] earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr))) } else { writer.WriteHeader(http.StatusNotFound) @@ -62,11 +66,6 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } } - responseHeader := http.Header{} - if h.earlyDataEnabled && h.earlyDataHeaderName != "" { - responseHeader.Set(h.earlyDataHeaderName, earlyDataStr) - } - conn, err := upgrader.Upgrade(writer, request, responseHeader) if err != nil { newError("failed to convert to WebSocket connection").Base(err).WriteToLog()