diff --git a/transport/internet/internal/pool.go b/transport/internet/internal/pool.go index 07af634a9..4fa5a2dcb 100644 --- a/transport/internet/internal/pool.go +++ b/transport/internet/internal/pool.go @@ -9,50 +9,59 @@ import ( "v2ray.com/core/common/signal" ) +// ConnectionRecyler is the interface for recycling connections. type ConnectionRecyler interface { - Put(ConnectionId, net.Conn) + // Put returns a connection back to a connection pool. + Put(ConnectionID, net.Conn) } -type ConnectionId struct { +// ConnectionID is the ID of a connection. +type ConnectionID struct { Local v2net.Address Remote v2net.Address RemotePort v2net.Port } -func NewConnectionId(source v2net.Address, dest v2net.Destination) ConnectionId { - return ConnectionId{ +// NewConnectionID creates a new ConnectionId. +func NewConnectionID(source v2net.Address, dest v2net.Destination) ConnectionID { + return ConnectionID{ Local: source, Remote: dest.Address, RemotePort: dest.Port, } } +// ExpiringConnection is a connection that will expire in certain time. type ExpiringConnection struct { conn net.Conn expire time.Time } -func (o *ExpiringConnection) Expired() bool { - return o.expire.Before(time.Now()) +// Expired returns true if the connection has expired. +func (ec *ExpiringConnection) Expired() bool { + return ec.expire.Before(time.Now()) } +// Pool is a connection pool. type Pool struct { sync.Mutex - connsByDest map[ConnectionId][]*ExpiringConnection + connsByDest map[ConnectionID][]*ExpiringConnection cleanupOnce signal.Once } +// NewConnectionPool creates a new Pool. func NewConnectionPool() *Pool { return &Pool{ - connsByDest: make(map[ConnectionId][]*ExpiringConnection), + connsByDest: make(map[ConnectionID][]*ExpiringConnection), } } -func (o *Pool) Get(id ConnectionId) net.Conn { - o.Lock() - defer o.Unlock() +// Get returns a connection with matching connection ID. Nil if not found. +func (p *Pool) Get(id ConnectionID) net.Conn { + p.Lock() + defer p.Unlock() - list, found := o.connsByDest[id] + list, found := p.connsByDest[id] if !found { return nil } @@ -72,18 +81,18 @@ func (o *Pool) Get(id ConnectionId) net.Conn { list[connIdx] = list[listLen-1] } list = list[:listLen-1] - o.connsByDest[id] = list + p.connsByDest[id] = list return conn.conn } -func (o *Pool) Cleanup() { - defer o.cleanupOnce.Reset() +func (p *Pool) cleanup() { + defer p.cleanupOnce.Reset() - for len(o.connsByDest) > 0 { + for len(p.connsByDest) > 0 { time.Sleep(time.Second * 5) expiredConns := make([]net.Conn, 0, 16) - o.Lock() - for dest, list := range o.connsByDest { + p.Lock() + for dest, list := range p.connsByDest { validConns := make([]*ExpiringConnection, 0, len(list)) for _, conn := range list { if conn.Expired() { @@ -93,34 +102,35 @@ func (o *Pool) Cleanup() { } } if len(validConns) != len(list) { - o.connsByDest[dest] = validConns + p.connsByDest[dest] = validConns } } - o.Unlock() + p.Unlock() for _, conn := range expiredConns { conn.Close() } } } -func (o *Pool) Put(id ConnectionId, conn net.Conn) { +// Put implements ConnectionRecyler.Put(). +func (p *Pool) Put(id ConnectionID, conn net.Conn) { expiringConn := &ExpiringConnection{ conn: conn, expire: time.Now().Add(time.Second * 4), } - o.Lock() - defer o.Unlock() + p.Lock() + defer p.Unlock() - list, found := o.connsByDest[id] + list, found := p.connsByDest[id] if !found { list = []*ExpiringConnection{expiringConn} } else { list = append(list, expiringConn) } - o.connsByDest[id] = list + p.connsByDest[id] = list - o.cleanupOnce.Do(func() { - go o.Cleanup() + p.cleanupOnce.Do(func() { + go p.cleanup() }) } diff --git a/transport/internet/internal/pool_test.go b/transport/internet/internal/pool_test.go index 0c4592779..65f5a6d0d 100644 --- a/transport/internet/internal/pool_test.go +++ b/transport/internet/internal/pool_test.go @@ -51,11 +51,11 @@ func TestConnectionCache(t *testing.T) { assert := assert.On(t) pool := NewConnectionPool() - conn := pool.Get(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80)))) + conn := pool.Get(NewConnectionID(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80)))) assert.Pointer(conn).IsNil() - pool.Put(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80))), &TestConnection{id: "test"}) - conn = pool.Get(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80)))) + pool.Put(NewConnectionID(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80))), &TestConnection{id: "test"}) + conn = pool.Get(NewConnectionID(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80)))) assert.String(conn.(*TestConnection).id).Equals("test") } @@ -64,9 +64,9 @@ func TestConnectionRecycle(t *testing.T) { pool := NewConnectionPool() c := &TestConnection{id: "test"} - pool.Put(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80))), c) + pool.Put(NewConnectionID(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80))), c) time.Sleep(6 * time.Second) assert.Bool(c.closed).IsTrue() - conn := pool.Get(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80)))) + conn := pool.Get(NewConnectionID(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80)))) assert.Pointer(conn).IsNil() } diff --git a/transport/internet/internal/sysfd.go b/transport/internet/internal/sysfd.go index bf0f6493b..31be9c53b 100644 --- a/transport/internet/internal/sysfd.go +++ b/transport/internet/internal/sysfd.go @@ -3,13 +3,15 @@ package internal import ( "net" "reflect" + "v2ray.com/core/common/errors" ) var ( - ErrInvalidConn = errors.New("Invalid Connection.") + errInvalidConn = errors.New("Invalid Connection.") ) +// GetSysFd returns the underlying fd of a connection. func GetSysFd(conn net.Conn) (int, error) { cv := reflect.ValueOf(conn) switch ce := cv.Elem(); ce.Kind() { @@ -21,5 +23,5 @@ func GetSysFd(conn net.Conn) (int, error) { return int(fd.Int()), nil } } - return 0, ErrInvalidConn + return 0, errInvalidConn } diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index b18d9c89a..dd2b69c16 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -156,7 +156,7 @@ func (v *Updater) Run() { type SystemConnection interface { net.Conn - Id() internal.ConnectionId + Id() internal.ConnectionID Reset(func([]Segment)) Overhead() int } diff --git a/transport/internet/kcp/connection_test.go b/transport/internet/kcp/connection_test.go index 5edfa5625..1c47b5acd 100644 --- a/transport/internet/kcp/connection_test.go +++ b/transport/internet/kcp/connection_test.go @@ -48,15 +48,15 @@ func (o *NoOpConn) SetWriteDeadline(time.Time) error { return nil } -func (o *NoOpConn) Id() internal.ConnectionId { - return internal.ConnectionId{} +func (o *NoOpConn) Id() internal.ConnectionID { + return internal.ConnectionID{} } func (o *NoOpConn) Reset(input func([]Segment)) {} type NoOpRecycler struct{} -func (o *NoOpRecycler) Put(internal.ConnectionId, net.Conn) {} +func (o *NoOpRecycler) Put(internal.ConnectionID, net.Conn) {} func TestConnectionReadTimeout(t *testing.T) { assert := assert.On(t) diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index d02325038..90cdb5a6e 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -26,7 +26,7 @@ var ( type ClientConnection struct { sync.RWMutex net.Conn - id internal.ConnectionId + id internal.ConnectionID input func([]Segment) reader PacketReader writer PacketWriter @@ -56,7 +56,7 @@ func (o *ClientConnection) Read([]byte) (int, error) { panic("KCP|ClientConnection: Read should not be called.") } -func (o *ClientConnection) Id() internal.ConnectionId { +func (o *ClientConnection) Id() internal.ConnectionID { return o.id } @@ -112,7 +112,7 @@ func DialKCP(src v2net.Address, dest v2net.Destination, options internet.DialerO dest.Network = v2net.Network_UDP log.Info("KCP|Dialer: Dialing KCP to ", dest) - id := internal.NewConnectionId(src, dest) + id := internal.NewConnectionID(src, dest) conn := globalPool.Get(id) if conn == nil { rawConn, err := internet.DialToDest(src, dest) diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index b3db87c06..dae4dc0d2 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -27,7 +27,7 @@ type ConnectionID struct { } type ServerConnection struct { - id internal.ConnectionId + id internal.ConnectionID local net.Addr remote net.Addr writer PacketWriter @@ -73,7 +73,7 @@ func (o *ServerConnection) SetWriteDeadline(time.Time) error { return nil } -func (o *ServerConnection) Id() internal.ConnectionId { +func (o *ServerConnection) Id() internal.ConnectionID { return o.id } @@ -188,7 +188,7 @@ func (v *Listener) OnReceive(payload *buf.Buffer, session *proxy.SessionInfo) { } localAddr := v.hub.Addr() sConn := &ServerConnection{ - id: internal.NewConnectionId(v2net.LocalHostIP, src), + id: internal.NewConnectionID(v2net.LocalHostIP, src), local: localAddr, remote: remoteAddr, writer: &KCPPacketWriter{ @@ -274,7 +274,7 @@ func (v *Listener) Addr() net.Addr { return v.hub.Addr() } -func (v *Listener) Put(internal.ConnectionId, net.Conn) {} +func (v *Listener) Put(internal.ConnectionID, net.Conn) {} type Writer struct { id ConnectionID diff --git a/transport/internet/tcp/connection.go b/transport/internet/tcp/connection.go index 37e98dd9c..2995db272 100644 --- a/transport/internet/tcp/connection.go +++ b/transport/internet/tcp/connection.go @@ -9,7 +9,7 @@ import ( ) type ConnectionManager interface { - Put(internal.ConnectionId, net.Conn) + Put(internal.ConnectionID, net.Conn) } type RawConnection struct { @@ -27,14 +27,14 @@ func (v *RawConnection) SysFd() (int, error) { } type Connection struct { - id internal.ConnectionId + id internal.ConnectionID reusable bool conn net.Conn listener ConnectionManager config *Config } -func NewConnection(id internal.ConnectionId, conn net.Conn, manager ConnectionManager, config *Config) *Connection { +func NewConnection(id internal.ConnectionID, conn net.Conn, manager ConnectionManager, config *Config) *Connection { return &Connection{ id: id, conn: conn, diff --git a/transport/internet/tcp/dialer.go b/transport/internet/tcp/dialer.go index 7f4696a53..3cf4df13e 100644 --- a/transport/internet/tcp/dialer.go +++ b/transport/internet/tcp/dialer.go @@ -27,7 +27,7 @@ func Dial(src v2net.Address, dest v2net.Destination, options internet.DialerOpti } tcpSettings := networkSettings.(*Config) - id := internal.NewConnectionId(src, dest) + id := internal.NewConnectionID(src, dest) var conn net.Conn if dest.Network == v2net.Network_TCP && tcpSettings.IsConnectionReuse() { conn = globalCache.Get(id) diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index 6ed346f4c..834919b2e 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -89,7 +89,7 @@ func (v *TCPListener) Accept() (internet.Connection, error) { return nil, connErr.err } conn := connErr.conn - return NewConnection(internal.ConnectionId{}, conn, v, v.config), nil + return NewConnection(internal.ConnectionID{}, conn, v, v.config), nil case <-time.After(time.Second * 2): } } @@ -125,7 +125,7 @@ func (v *TCPListener) KeepAccepting() { } } -func (v *TCPListener) Put(id internal.ConnectionId, conn net.Conn) { +func (v *TCPListener) Put(id internal.ConnectionID, conn net.Conn) { v.Lock() defer v.Unlock() if !v.acccepting {