diff --git a/transport/internet/quic/dialer.go b/transport/internet/quic/dialer.go index c8fe92c51..e299a7f28 100644 --- a/transport/internet/quic/dialer.go +++ b/transport/internet/quic/dialer.go @@ -13,47 +13,54 @@ import ( "v2ray.com/core/transport/internet/tls" ) -type clientSessions struct { - access sync.Mutex - sessions map[net.Destination][]quic.Session +type sessionContext struct { + rawConn *sysConn + session quic.Session } -func removeInactiveSessions(sessions []quic.Session) []quic.Session { - lastActive := 0 +type clientSessions struct { + access sync.Mutex + sessions map[net.Destination][]*sessionContext +} + +func isActive(s quic.Session) bool { + select { + case <-s.Context().Done(): + return false + default: + return true + } +} + +func removeInactiveSessions(sessions []*sessionContext) []*sessionContext { + activeSessions := make([]*sessionContext, 0, len(sessions)) for _, s := range sessions { - active := true - select { - case <-s.Context().Done(): - active = false - default: - } - if active { - sessions[lastActive] = s - lastActive++ + if isActive(s.session) { + activeSessions = append(activeSessions, s) + } else { + s.rawConn.Close() + s.session.Close() } } - if lastActive < len(sessions) { - for i := lastActive; i < len(sessions); i++ { - sessions[i] = nil - } - sessions = sessions[:lastActive] + if len(activeSessions) < len(sessions) { + return activeSessions } return sessions } -func openStream(sessions []quic.Session) (quic.Stream, net.Addr, error) { +func openStream(sessions []*sessionContext) (quic.Stream, net.Addr) { for _, s := range sessions { - stream, err := s.OpenStream() + stream, err := s.session.OpenStream() if err != nil { - newError("failed to create stream").Base(err).WriteToLog() + newError("failed to create stream").Base(err).AtWarning().WriteToLog() continue } - return stream, s.LocalAddr(), nil + return stream, s.session.LocalAddr() } - return nil, nil, nil + return nil, nil } func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) { @@ -61,12 +68,12 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo defer s.access.Unlock() if s.sessions == nil { - s.sessions = make(map[net.Destination][]quic.Session) + s.sessions = make(map[net.Destination][]*sessionContext) } dest := net.DestinationFromAddr(destAddr) - var sessions []quic.Session + var sessions []*sessionContext if s, found := s.sessions[dest]; found { sessions = s } @@ -74,10 +81,7 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo sessions = removeInactiveSessions(sessions) s.sessions[dest] = sessions - stream, local, err := openStream(sessions) - if err != nil { - return nil, err - } + stream, local := openStream(sessions) if stream != nil { return &interConn{ stream: stream, @@ -96,8 +100,8 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo quicConfig := &quic.Config{ ConnectionIDLength: 12, - HandshakeTimeout: time.Second * 4, - IdleTimeout: time.Second * 60, + HandshakeTimeout: time.Second * 8, + IdleTimeout: time.Second * 600, MaxReceiveStreamFlowControlWindow: 512 * 1024, MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024, MaxIncomingUniStreams: -1, @@ -111,11 +115,14 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo session, err := quic.DialContext(context.Background(), conn, destAddr, "", tlsConfig.GetTLSConfig(tls.WithDestination(dest)), quicConfig) if err != nil { - rawConn.Close() + conn.Close() return nil, err } - s.sessions[dest] = append(sessions, session) + s.sessions[dest] = append(sessions, &sessionContext{ + session: session, + rawConn: conn, + }) stream, err = session.OpenStream() if err != nil { return nil, err diff --git a/transport/internet/quic/hub.go b/transport/internet/quic/hub.go index 4cfbcef1a..c22a876f0 100644 --- a/transport/internet/quic/hub.go +++ b/transport/internet/quic/hub.go @@ -8,13 +8,16 @@ import ( "v2ray.com/core/common" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol/tls/cert" + "v2ray.com/core/common/signal/done" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet/tls" ) // Listener is an internet.Listener that listens for TCP connections. type Listener struct { + rawConn *sysConn listener quic.Listener + done *done.Instance addConn internet.ConnHandler } @@ -22,9 +25,17 @@ func (l *Listener) acceptStreams(session quic.Session) { for { stream, err := session.AcceptStream() if err != nil { - newError("failed to accept stream").Base(err).WriteToLog() - session.Close() - return + newError("failed to accept stream").Base(err).AtWarning().WriteToLog() + select { + case <-session.Context().Done(): + return + case <-l.done.Wait(): + session.Close() + return + default: + time.Sleep(time.Second) + continue + } } conn := &interConn{ @@ -42,7 +53,10 @@ func (l *Listener) keepAccepting() { for { conn, err := l.listener.Accept() if err != nil { - newError("failed to accept QUIC sessions").Base(err).WriteToLog() + newError("failed to accept QUIC sessions").Base(err).AtWarning().WriteToLog() + if l.done.Done() { + break + } time.Sleep(time.Second) continue } @@ -57,7 +71,10 @@ func (l *Listener) Addr() net.Addr { // Close implements internet.Listener.Close. func (l *Listener) Close() error { - return l.listener.Close() + l.done.Close() + l.listener.Close() + l.rawConn.Close() + return nil } // Listen creates a new Listener based on configurations. @@ -85,11 +102,11 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti quicConfig := &quic.Config{ ConnectionIDLength: 12, - HandshakeTimeout: time.Second * 4, - IdleTimeout: time.Second * 60, + HandshakeTimeout: time.Second * 8, + IdleTimeout: time.Second * 600, MaxReceiveStreamFlowControlWindow: 512 * 1024, MaxReceiveConnectionFlowControlWindow: 4 * 1024 * 1024, - MaxIncomingStreams: 64, + MaxIncomingStreams: 8192, MaxIncomingUniStreams: -1, } @@ -101,11 +118,13 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti qListener, err := quic.Listen(conn, tlsConfig.GetTLSConfig(), quicConfig) if err != nil { - rawConn.Close() + conn.Close() return nil, err } listener := &Listener{ + done: done.New(), + rawConn: conn, listener: qListener, addConn: handler, }