From 19e0cb40e9f19683c1718f6266f755f245f7c611 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Mon, 2 Jan 2017 07:43:02 +0100 Subject: [PATCH] locker protected connection --- transport/internet/tcp/connection.go | 81 ++++++++++++++++++++++------ 1 file changed, 64 insertions(+), 17 deletions(-) diff --git a/transport/internet/tcp/connection.go b/transport/internet/tcp/connection.go index 2995db272..5806d51ba 100644 --- a/transport/internet/tcp/connection.go +++ b/transport/internet/tcp/connection.go @@ -3,15 +3,12 @@ package tcp import ( "io" "net" + "sync" "time" "v2ray.com/core/transport/internet/internal" ) -type ConnectionManager interface { - Put(internal.ConnectionID, net.Conn) -} - type RawConnection struct { net.TCPConn } @@ -27,14 +24,15 @@ func (v *RawConnection) SysFd() (int, error) { } type Connection struct { + sync.RWMutex id internal.ConnectionID reusable bool conn net.Conn - listener ConnectionManager + listener internal.ConnectionRecyler config *Config } -func NewConnection(id internal.ConnectionID, conn net.Conn, manager ConnectionManager, config *Config) *Connection { +func NewConnection(id internal.ConnectionID, conn net.Conn, manager internal.ConnectionRecyler, config *Config) *Connection { return &Connection{ id: id, conn: conn, @@ -45,22 +43,30 @@ func NewConnection(id internal.ConnectionID, conn net.Conn, manager ConnectionMa } func (v *Connection) Read(b []byte) (int, error) { - if v == nil || v.conn == nil { + conn := v.underlyingConn() + if conn == nil { return 0, io.EOF } - return v.conn.Read(b) + return conn.Read(b) } func (v *Connection) Write(b []byte) (int, error) { - if v == nil || v.conn == nil { + conn := v.underlyingConn() + if conn == nil { return 0, io.ErrClosedPipe } - return v.conn.Write(b) + return conn.Write(b) } func (v *Connection) Close() error { - if v == nil || v.conn == nil { + if v == nil { + return io.ErrClosedPipe + } + + v.Lock() + defer v.Unlock() + if v.conn == nil { return io.ErrClosedPipe } if v.Reusable() { @@ -73,33 +79,74 @@ func (v *Connection) Close() error { } func (v *Connection) LocalAddr() net.Addr { - return v.conn.LocalAddr() + conn := v.underlyingConn() + if conn == nil { + return nil + } + return conn.LocalAddr() } func (v *Connection) RemoteAddr() net.Addr { - return v.conn.RemoteAddr() + conn := v.underlyingConn() + if conn == nil { + return nil + } + return conn.RemoteAddr() } func (v *Connection) SetDeadline(t time.Time) error { - return v.conn.SetDeadline(t) + conn := v.underlyingConn() + if conn == nil { + return nil + } + return conn.SetDeadline(t) } func (v *Connection) SetReadDeadline(t time.Time) error { - return v.conn.SetReadDeadline(t) + conn := v.underlyingConn() + if conn == nil { + return nil + } + return conn.SetReadDeadline(t) } func (v *Connection) SetWriteDeadline(t time.Time) error { - return v.conn.SetWriteDeadline(t) + conn := v.underlyingConn() + if conn == nil { + return nil + } + return conn.SetWriteDeadline(t) } func (v *Connection) SetReusable(reusable bool) { + if v == nil { + return + } v.reusable = reusable } func (v *Connection) Reusable() bool { + if v == nil { + return false + } return v.config.IsConnectionReuse() && v.reusable } func (v *Connection) SysFd() (int, error) { - return internal.GetSysFd(v.conn) + conn := v.underlyingConn() + if conn == nil { + return 0, io.ErrClosedPipe + } + return internal.GetSysFd(conn) +} + +func (v *Connection) underlyingConn() net.Conn { + if v == nil { + return nil + } + + v.RLock() + defer v.RUnlock() + + return v.conn }