From 45fbf6f0596dea5cbc7614e13af925c358af1382 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sun, 25 Nov 2018 21:56:24 +0100 Subject: [PATCH] update quic connection handling --- transport/internet/quic/conn.go | 23 ++++- transport/internet/quic/dialer.go | 164 +++++++++++++++++++++++------- transport/internet/quic/hub.go | 13 ++- 3 files changed, 150 insertions(+), 50 deletions(-) diff --git a/transport/internet/quic/conn.go b/transport/internet/quic/conn.go index f1ea166ea..f09682375 100644 --- a/transport/internet/quic/conn.go +++ b/transport/internet/quic/conn.go @@ -10,6 +10,7 @@ import ( "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/net" + "v2ray.com/core/common/signal/done" "v2ray.com/core/transport/internet" ) @@ -133,9 +134,11 @@ func (c *sysConn) SetWriteDeadline(t time.Time) error { } type interConn struct { - stream quic.Stream - local net.Addr - remote net.Addr + context *sessionContext + stream quic.Stream + done *done.Instance + local net.Addr + remote net.Addr } func (c *interConn) Read(b []byte) (int, error) { @@ -162,10 +165,13 @@ func (c *interConn) WriteMultiBuffer(mb buf.MultiBuffer) error { defer reader.Close() for { - nBytes, err := reader.Read(b[:1380]) + nBytes, err := reader.Read(b[:1200]) if err != nil { break } + if nBytes == 0 { + continue + } if _, err := c.Write(b[:nBytes]); err != nil { return err } @@ -179,7 +185,14 @@ func (c *interConn) Write(b []byte) (int, error) { } func (c *interConn) Close() error { - return c.stream.Close() + if c.context != nil { + defer c.context.onInterConnClose() + } + + common.Must(c.done.Close()) + c.stream.CancelRead(1) + c.stream.CancelWrite(1) + return nil } func (c *interConn) LocalAddr() net.Addr { diff --git a/transport/internet/quic/dialer.go b/transport/internet/quic/dialer.go index b797605c7..315b7983f 100644 --- a/transport/internet/quic/dialer.go +++ b/transport/internet/quic/dialer.go @@ -6,21 +6,77 @@ import ( "time" quic "github.com/lucas-clemente/quic-go" - "v2ray.com/core/common" "v2ray.com/core/common/net" + "v2ray.com/core/common/signal/done" + "v2ray.com/core/common/task" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet/tls" ) type sessionContext struct { - rawConn *sysConn - session quic.Session + access sync.Mutex + done *done.Instance + rawConn *sysConn + session quic.Session + interConns []*interConn +} + +var errSessionClosed = newError("session closed") + +func (c *sessionContext) openStream(destAddr net.Addr) (*interConn, error) { + c.access.Lock() + defer c.access.Unlock() + + if c.done.Done() { + return nil, errSessionClosed + } + + stream, err := c.session.OpenStream() + if err != nil { + return nil, err + } + + conn := &interConn{ + stream: stream, + done: done.New(), + local: c.session.LocalAddr(), + remote: destAddr, + context: c, + } + + c.interConns = append(c.interConns, conn) + return conn, nil +} + +func (c *sessionContext) onInterConnClose() { + c.access.Lock() + defer c.access.Unlock() + + if c.done.Done() { + return + } + + activeConns := 0 + for _, conn := range c.interConns { + if !conn.done.Done() { + activeConns++ + } + } + + if activeConns > 0 { + return + } + + c.done.Close() + c.session.Close() + c.rawConn.Close() } type clientSessions struct { access sync.Mutex sessions map[net.Destination][]*sessionContext + cleanup *task.Periodic } func isActive(s quic.Session) bool { @@ -37,8 +93,13 @@ func removeInactiveSessions(sessions []*sessionContext) []*sessionContext { for _, s := range sessions { if isActive(s.session) { activeSessions = append(activeSessions, s) - } else { - s.rawConn.Close() + continue + } + if err := s.session.Close(); err != nil { + newError("failed to close session").Base(err).AtWarning().WriteToLog() + } + if err := s.rawConn.Close(); err != nil { + newError("failed to close raw connection").Base(err).AtWarning().WriteToLog() } } @@ -49,21 +110,42 @@ func removeInactiveSessions(sessions []*sessionContext) []*sessionContext { return sessions } -func openStream(sessions []*sessionContext) (quic.Stream, net.Addr) { +func openStream(sessions []*sessionContext, destAddr net.Addr) *interConn { for _, s := range sessions { if !isActive(s.session) { continue } - stream, err := s.session.OpenStream() + conn, err := s.openStream(destAddr) if err != nil { - newError("failed to create stream").Base(err).AtWarning().WriteToLog() continue } - return stream, s.session.LocalAddr() + + return conn } - return nil, nil + return nil +} + +func (s *clientSessions) cleanSessions() error { + s.access.Lock() + defer s.access.Unlock() + + if len(s.sessions) == 0 { + return nil + } + + newSessionMap := make(map[net.Destination][]*sessionContext) + + for dest, sessions := range s.sessions { + sessions = removeInactiveSessions(sessions) + if len(sessions) > 0 { + newSessionMap[dest] = sessions + } + } + + s.sessions = newSessionMap + return nil } func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) { @@ -81,14 +163,10 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo sessions = s } - { - stream, local := openStream(sessions) - if stream != nil { - return &interConn{ - stream: stream, - local: local, - remote: destAddr, - }, nil + if true { + conn := openStream(sessions, destAddr) + if conn != nil { + return conn, nil } } @@ -103,13 +181,11 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo } quicConfig := &quic.Config{ - ConnectionIDLength: 8, - HandshakeTimeout: time.Second * 8, - IdleTimeout: time.Second * 30, - MaxReceiveStreamFlowControlWindow: 128 * 1024, - MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024, - MaxIncomingUniStreams: -1, - MaxIncomingStreams: 32, + ConnectionIDLength: 12, + HandshakeTimeout: time.Second * 8, + IdleTimeout: time.Second * 30, + MaxIncomingUniStreams: -1, + MaxIncomingStreams: -1, } conn, err := wrapSysConn(rawConn, config) @@ -124,23 +200,26 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo return nil, err } - s.sessions[dest] = append(sessions, &sessionContext{ + context := &sessionContext{ session: session, rawConn: conn, - }) - stream, err := session.OpenStream() - if err != nil { - return nil, err + done: done.New(), } - return &interConn{ - stream: stream, - local: session.LocalAddr(), - remote: destAddr, - }, nil + s.sessions[dest] = append(sessions, context) + return context.openStream(destAddr) } var client clientSessions +func init() { + client.sessions = make(map[net.Destination][]*sessionContext) + client.cleanup = &task.Periodic{ + Interval: time.Minute, + Execute: client.cleanSessions, + } + common.Must(client.cleanup.Start()) +} + func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { tlsConfig := tls.ConfigFromStreamSettings(streamSettings) if tlsConfig == nil { @@ -150,9 +229,18 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me } } - destAddr, err := net.ResolveUDPAddr("udp", dest.NetAddr()) - if err != nil { - return nil, err + var destAddr *net.UDPAddr + if dest.Address.Family().IsIP() { + destAddr = &net.UDPAddr{ + IP: dest.Address.IP(), + Port: int(dest.Port), + } + } else { + addr, err := net.ResolveUDPAddr("udp", dest.NetAddr()) + if err != nil { + return nil, err + } + destAddr = addr } config := streamSettings.ProtocolSettings.(*Config) diff --git a/transport/internet/quic/hub.go b/transport/internet/quic/hub.go index 3724aa81d..eb11b4d07 100644 --- a/transport/internet/quic/hub.go +++ b/transport/internet/quic/hub.go @@ -40,6 +40,7 @@ func (l *Listener) acceptStreams(session quic.Session) { conn := &interConn{ stream: stream, + done: done.New(), local: session.LocalAddr(), remote: session.RemoteAddr(), } @@ -101,13 +102,11 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti } quicConfig := &quic.Config{ - ConnectionIDLength: 8, - HandshakeTimeout: time.Second * 8, - IdleTimeout: time.Second * 30, - MaxReceiveStreamFlowControlWindow: 128 * 1024, - MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024, - MaxIncomingStreams: 32, - MaxIncomingUniStreams: -1, + ConnectionIDLength: 12, + HandshakeTimeout: time.Second * 8, + IdleTimeout: time.Second * 30, + MaxIncomingStreams: 256, + MaxIncomingUniStreams: -1, } conn, err := wrapSysConn(rawConn, config)