diff --git a/transport/internet/request/assembler/packetconn/req2packet.go b/transport/internet/request/assembler/packetconn/req2packet.go index baf12c366..c8e7e881a 100644 --- a/transport/internet/request/assembler/packetconn/req2packet.go +++ b/transport/internet/request/assembler/packetconn/req2packet.go @@ -5,6 +5,7 @@ import ( "context" "crypto/rand" "io" + "sync" "time" "github.com/golang-collections/go-datastructures/queue" @@ -102,7 +103,7 @@ copyFromChan: waitTimer.Stop() go func() { reader, writer := io.Pipe() - defer writer.Close() + defer writer.Close() streamingRespOpt := &pipedStreamingRespOption{writer} go func() { for { @@ -176,7 +177,7 @@ func (r *requestToPacketConnClientSession) Close() error { func newRequestToPacketConnServer(ctx context.Context, config *ServerConfig) *requestToPacketConnServer { return &requestToPacketConnServer{ - sessionMap: make(map[string]*requestToPacketConnServerSession), + sessionMap: sync.Map{}, ctx: ctx, config: config, } @@ -185,7 +186,7 @@ func newRequestToPacketConnServer(ctx context.Context, config *ServerConfig) *re type requestToPacketConnServer struct { packetSessionReceiver request.SessionReceiver - sessionMap map[string]*requestToPacketConnServerSession + sessionMap sync.Map ctx context.Context config *ServerConfig @@ -203,7 +204,15 @@ func (r *requestToPacketConnServer) OnRoundTrip(ctx context.Context, req request return request.Response{}, newError("nil session id") } sessionID := string(SessionID) - session, found := r.sessionMap[sessionID] + var session *requestToPacketConnServerSession + sessionAny, found := r.sessionMap.Load(sessionID) + if found { + var ok bool + session, ok = sessionAny.(*requestToPacketConnServerSession) + if !ok { + return request.Response{}, newError("failed to cast session") + } + } if !found { ctxWithFinish, finish := context.WithCancel(ctx) session = &requestToPacketConnServerSession{ @@ -218,8 +227,10 @@ func (r *requestToPacketConnServer) OnRoundTrip(ctx context.Context, req request maxWriteDuration: int(r.config.MaxWriteDurationMs), maxSimultaneousWriteConnection: int(r.config.MaxSimultaneousWriteConnection), } - r.sessionMap[sessionID] = session - err = r.packetSessionReceiver.OnNewSession(ctx, session) + _, loaded := r.sessionMap.LoadOrStore(sessionID, session) + if !loaded { + err = r.packetSessionReceiver.OnNewSession(ctx, session) + } } if err != nil { return request.Response{}, err @@ -228,7 +239,7 @@ func (r *requestToPacketConnServer) OnRoundTrip(ctx context.Context, req request } func (r *requestToPacketConnServer) removeSessionID(sessionID []byte) { - delete(r.sessionMap, string(sessionID)) + r.sessionMap.Delete(string(sessionID)) } type requestToPacketConnServerSession struct {