From 3a77bbdf6522021b482556e04e2605159973976f Mon Sep 17 00:00:00 2001 From: Shelikhoo Date: Mon, 5 Apr 2021 19:34:22 +0100 Subject: [PATCH] fix early data listener bug --- transport/internet/websocket/hub.go | 42 +++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index 4537dd38d..97fa2cd01 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -3,9 +3,13 @@ package websocket import ( + "bytes" "context" "crypto/tls" + "encoding/base64" + "io" "net/http" + "strings" "sync" "time" @@ -20,8 +24,9 @@ import ( ) type requestHandler struct { - path string - ln *Listener + path string + ln *Listener + earlyDataEnabled bool } var upgrader = &websocket.Upgrader{ @@ -34,10 +39,22 @@ var upgrader = &websocket.Upgrader{ } func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - if request.URL.Path != h.path { - writer.WriteHeader(http.StatusNotFound) - return + var earlyData io.Reader + if !h.earlyDataEnabled { + if request.URL.Path != h.path { + writer.WriteHeader(http.StatusNotFound) + return + } + } else { + if strings.HasPrefix(request.URL.RequestURI(), h.path) { + earlyDataStr := request.URL.RequestURI()[len(h.path):] + earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr))) + } else { + writer.WriteHeader(http.StatusNotFound) + return + } } + conn, err := upgrader.Upgrade(writer, request, nil) if err != nil { newError("failed to convert to WebSocket connection").Base(err).WriteToLog() @@ -52,8 +69,12 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req Port: int(0), } } + if earlyData == nil { + h.ln.addConn(newConnection(conn, remoteAddr)) + } else { + h.ln.addConn(newConnectionWithEarlyData(conn, remoteAddr, earlyData)) + } - h.ln.addConn(newConnection(conn, remoteAddr)) } type Listener struct { @@ -114,11 +135,16 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet } l.listener = listener + var useEarlyData = false + if wsSettings.MaxEarlyData != 0 { + useEarlyData = true + } l.server = http.Server{ Handler: &requestHandler{ - path: wsSettings.GetNormalizedPath(), - ln: l, + path: wsSettings.GetNormalizedPath(), + ln: l, + earlyDataEnabled: useEarlyData, }, ReadHeaderTimeout: time.Second * 4, MaxHeaderBytes: 2048,