From 22379e5a6b22eb5198318d1a032166a6fe145237 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Thu, 24 Nov 2016 23:16:05 +0100 Subject: [PATCH] refactor connection pool --- transport/internet/internal/pool.go | 121 +++++++++++++++++++++ transport/internet/internal/pool_test.go | 72 ++++++++++++ transport/internet/tcp/connection.go | 10 +- transport/internet/tcp/connection_cache.go | 112 ------------------- transport/internet/tcp/dialer.go | 5 +- transport/internet/tcp/hub.go | 5 +- 6 files changed, 204 insertions(+), 121 deletions(-) create mode 100644 transport/internet/internal/pool.go create mode 100644 transport/internet/internal/pool_test.go delete mode 100644 transport/internet/tcp/connection_cache.go diff --git a/transport/internet/internal/pool.go b/transport/internet/internal/pool.go new file mode 100644 index 000000000..6d622fcb3 --- /dev/null +++ b/transport/internet/internal/pool.go @@ -0,0 +1,121 @@ +package internal + +import ( + "net" + "sync" + "time" + v2net "v2ray.com/core/common/net" + "v2ray.com/core/common/signal" +) + +type ConnectionId struct { + Local v2net.Address + Remote v2net.Address + RemotePort v2net.Port +} + +func NewConnectionId(source v2net.Address, dest v2net.Destination) ConnectionId { + return ConnectionId{ + Local: source, + Remote: dest.Address, + RemotePort: dest.Port, + } +} + +type ExpiringConnection struct { + conn net.Conn + expire time.Time +} + +func (o *ExpiringConnection) Expired() bool { + return o.expire.Before(time.Now()) +} + +type Pool struct { + sync.Mutex + connsByDest map[ConnectionId][]*ExpiringConnection + cleanupOnce signal.Once +} + +func NewConnectionPool() *Pool { + return &Pool{ + connsByDest: make(map[ConnectionId][]*ExpiringConnection), + } +} + +func (o *Pool) Get(id ConnectionId) net.Conn { + o.Lock() + defer o.Unlock() + + list, found := o.connsByDest[id] + if !found { + return nil + } + connIdx := -1 + for idx, conn := range list { + if !conn.Expired() { + connIdx = idx + break + } + } + if connIdx == -1 { + return nil + } + listLen := len(list) + conn := list[connIdx] + if connIdx != listLen-1 { + list[connIdx] = list[listLen-1] + } + list = list[:listLen-1] + o.connsByDest[id] = list + return conn.conn +} + +func (o *Pool) Cleanup() { + defer o.cleanupOnce.Reset() + + for len(o.connsByDest) > 0 { + time.Sleep(time.Second * 5) + expiredConns := make([]net.Conn, 0, 16) + o.Lock() + for dest, list := range o.connsByDest { + validConns := make([]*ExpiringConnection, 0, len(list)) + for _, conn := range list { + if conn.Expired() { + expiredConns = append(expiredConns, conn.conn) + } else { + validConns = append(validConns, conn) + } + } + if len(validConns) != len(list) { + o.connsByDest[dest] = validConns + } + } + o.Unlock() + for _, conn := range expiredConns { + conn.Close() + } + } +} + +func (o *Pool) Put(id ConnectionId, conn net.Conn) { + expiringConn := &ExpiringConnection{ + conn: conn, + expire: time.Now().Add(time.Second * 4), + } + + o.Lock() + defer o.Unlock() + + list, found := o.connsByDest[id] + if !found { + list = []*ExpiringConnection{expiringConn} + } else { + list = append(list, expiringConn) + } + o.connsByDest[id] = list + + o.cleanupOnce.Do(func() { + go o.Cleanup() + }) +} diff --git a/transport/internet/internal/pool_test.go b/transport/internet/internal/pool_test.go new file mode 100644 index 000000000..0c4592779 --- /dev/null +++ b/transport/internet/internal/pool_test.go @@ -0,0 +1,72 @@ +package internal_test + +import ( + "net" + "testing" + "time" + v2net "v2ray.com/core/common/net" + "v2ray.com/core/testing/assert" + . "v2ray.com/core/transport/internet/internal" +) + +type TestConnection struct { + id string + closed bool +} + +func (o *TestConnection) Read([]byte) (int, error) { + return 0, nil +} + +func (o *TestConnection) Write([]byte) (int, error) { + return 0, nil +} + +func (o *TestConnection) Close() error { + o.closed = true + return nil +} + +func (o *TestConnection) LocalAddr() net.Addr { + return nil +} + +func (o *TestConnection) RemoteAddr() net.Addr { + return nil +} + +func (o *TestConnection) SetDeadline(t time.Time) error { + return nil +} + +func (o *TestConnection) SetReadDeadline(t time.Time) error { + return nil +} + +func (o *TestConnection) SetWriteDeadline(t time.Time) error { + return nil +} + +func TestConnectionCache(t *testing.T) { + assert := assert.On(t) + + pool := NewConnectionPool() + conn := pool.Get(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80)))) + assert.Pointer(conn).IsNil() + + pool.Put(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80))), &TestConnection{id: "test"}) + conn = pool.Get(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80)))) + assert.String(conn.(*TestConnection).id).Equals("test") +} + +func TestConnectionRecycle(t *testing.T) { + assert := assert.On(t) + + pool := NewConnectionPool() + c := &TestConnection{id: "test"} + pool.Put(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80))), c) + time.Sleep(6 * time.Second) + assert.Bool(c.closed).IsTrue() + conn := pool.Get(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80)))) + assert.Pointer(conn).IsNil() +} diff --git a/transport/internet/tcp/connection.go b/transport/internet/tcp/connection.go index 4a059b4e9..4e52b20bc 100644 --- a/transport/internet/tcp/connection.go +++ b/transport/internet/tcp/connection.go @@ -9,7 +9,7 @@ import ( ) type ConnectionManager interface { - Recycle(string, net.Conn) + Put(internal.ConnectionId, net.Conn) } type RawConnection struct { @@ -27,16 +27,16 @@ func (this *RawConnection) SysFd() (int, error) { } type Connection struct { - dest string + id internal.ConnectionId conn net.Conn listener ConnectionManager reusable bool config *Config } -func NewConnection(dest string, conn net.Conn, manager ConnectionManager, config *Config) *Connection { +func NewConnection(id internal.ConnectionId, conn net.Conn, manager ConnectionManager, config *Config) *Connection { return &Connection{ - dest: dest, + id: id, conn: conn, listener: manager, reusable: config.ConnectionReuse.IsEnabled(), @@ -64,7 +64,7 @@ func (this *Connection) Close() error { return io.ErrClosedPipe } if this.Reusable() { - this.listener.Recycle(this.dest, this.conn) + this.listener.Put(this.id, this.conn) return nil } err := this.conn.Close() diff --git a/transport/internet/tcp/connection_cache.go b/transport/internet/tcp/connection_cache.go deleted file mode 100644 index 32b1abe7a..000000000 --- a/transport/internet/tcp/connection_cache.go +++ /dev/null @@ -1,112 +0,0 @@ -package tcp - -import ( - "net" - "sync" - "time" - - "v2ray.com/core/common/signal" -) - -type AwaitingConnection struct { - conn net.Conn - expire time.Time -} - -func (this *AwaitingConnection) Expired() bool { - return this.expire.Before(time.Now()) -} - -type ConnectionCache struct { - sync.Mutex - cache map[string][]*AwaitingConnection - cleanupOnce signal.Once -} - -func NewConnectionCache() *ConnectionCache { - return &ConnectionCache{ - cache: make(map[string][]*AwaitingConnection), - } -} - -func (this *ConnectionCache) Cleanup() { - defer this.cleanupOnce.Reset() - - for len(this.cache) > 0 { - time.Sleep(time.Second * 4) - this.Lock() - for key, value := range this.cache { - size := len(value) - changed := false - for i := 0; i < size; { - if value[i].Expired() { - value[i].conn.Close() - value[i] = value[size-1] - size-- - changed = true - } else { - i++ - } - } - if changed { - for i := size; i < len(value); i++ { - value[i] = nil - } - value = value[:size] - this.cache[key] = value - } - } - this.Unlock() - } -} - -func (this *ConnectionCache) Recycle(dest string, conn net.Conn) { - this.Lock() - defer this.Unlock() - - aconn := &AwaitingConnection{ - conn: conn, - expire: time.Now().Add(time.Second * 4), - } - - var list []*AwaitingConnection - if v, found := this.cache[dest]; found { - v = append(v, aconn) - list = v - } else { - list = []*AwaitingConnection{aconn} - } - this.cache[dest] = list - - go this.cleanupOnce.Do(this.Cleanup) -} - -func FindFirstValid(list []*AwaitingConnection) int { - for idx, conn := range list { - if !conn.Expired() { - return idx - } - go conn.conn.Close() - } - return -1 -} - -func (this *ConnectionCache) Get(dest string) net.Conn { - this.Lock() - defer this.Unlock() - - list, found := this.cache[dest] - if !found { - return nil - } - - firstValid := FindFirstValid(list) - if firstValid == -1 { - delete(this.cache, dest) - return nil - } - res := list[firstValid].conn - list = list[firstValid+1:] - this.cache[dest] = list - return res -} diff --git a/transport/internet/tcp/dialer.go b/transport/internet/tcp/dialer.go index 54bc2802f..bfe7e51dc 100644 --- a/transport/internet/tcp/dialer.go +++ b/transport/internet/tcp/dialer.go @@ -8,11 +8,12 @@ import ( "v2ray.com/core/common/log" v2net "v2ray.com/core/common/net" "v2ray.com/core/transport/internet" + "v2ray.com/core/transport/internet/internal" v2tls "v2ray.com/core/transport/internet/tls" ) var ( - globalCache = NewConnectionCache() + globalCache = internal.NewConnectionPool() ) func Dial(src v2net.Address, dest v2net.Destination, options internet.DialerOptions) (internet.Connection, error) { @@ -26,7 +27,7 @@ func Dial(src v2net.Address, dest v2net.Destination, options internet.DialerOpti } tcpSettings := networkSettings.(*Config) - id := src.String() + "-" + dest.NetAddr() + id := internal.NewConnectionId(src, dest) var conn net.Conn if dest.Network == v2net.Network_TCP && tcpSettings.ConnectionReuse.IsEnabled() { conn = globalCache.Get(id) diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index a75912ed2..663510136 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -10,6 +10,7 @@ import ( "v2ray.com/core/common/log" v2net "v2ray.com/core/common/net" "v2ray.com/core/transport/internet" + "v2ray.com/core/transport/internet/internal" v2tls "v2ray.com/core/transport/internet/tls" ) @@ -89,7 +90,7 @@ func (this *TCPListener) Accept() (internet.Connection, error) { return nil, connErr.err } conn := connErr.conn - return NewConnection("", conn, this, this.config), nil + return NewConnection(internal.ConnectionId{}, conn, this, this.config), nil case <-time.After(time.Second * 2): } } @@ -125,7 +126,7 @@ func (this *TCPListener) KeepAccepting() { } } -func (this *TCPListener) Recycle(dest string, conn net.Conn) { +func (this *TCPListener) Put(id internal.ConnectionId, conn net.Conn) { this.Lock() defer this.Unlock() if !this.acccepting {