From bf7b8798a96e5785fe70cdd892fd614e80b5da08 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sun, 3 Dec 2017 21:29:27 +0100 Subject: [PATCH] simplify kcp interface --- transport/internet/kcp/connection.go | 28 ++++--- transport/internet/kcp/connection_test.go | 54 ++++--------- transport/internet/kcp/dialer.go | 99 +++++------------------ transport/internet/kcp/listener.go | 67 ++------------- 4 files changed, 60 insertions(+), 188 deletions(-) diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 3d064fb3e..220d4a08d 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -168,13 +168,15 @@ type SystemConnection interface { Overhead() int } -var ( - _ buf.Reader = (*Connection)(nil) -) +type ConnMetadata struct { + LocalAddr net.Addr + RemoteAddr net.Addr +} // Connection is a KCP connection over UDP. type Connection struct { - conn SystemConnection + meta *ConnMetadata + closer io.Closer rd time.Time wd time.Time // write deadline since int64 @@ -201,24 +203,24 @@ type Connection struct { } // NewConnection create a new KCP connection between local and remote. -func NewConnection(conv uint16, sysConn SystemConnection, config *Config) *Connection { +func NewConnection(conv uint16, meta *ConnMetadata, writer PacketWriter, closer io.Closer, config *Config) *Connection { log.Trace(newError("creating connection ", conv)) conn := &Connection{ conv: conv, - conn: sysConn, + meta: meta, + closer: closer, since: nowMillisec(), dataInput: make(chan bool, 1), dataOutput: make(chan bool, 1), Config: config, - output: NewRetryableWriter(NewSegmentWriter(sysConn)), - mss: config.GetMTUValue() - uint32(sysConn.Overhead()) - DataSegmentOverhead, + output: NewRetryableWriter(NewSegmentWriter(writer)), + mss: config.GetMTUValue() - uint32(writer.Overhead()) - DataSegmentOverhead, roundTrip: &RoundTripInfo{ rto: 100, minRtt: config.GetTTIValue(), }, } - sysConn.Reset(conn.Input) conn.receivingWorker = NewReceivingWorker(conn) conn.sendingWorker = NewSendingWorker(conn) @@ -413,7 +415,7 @@ func (v *Connection) Close() error { if state.Is(StateReadyToClose, StateTerminating, StateTerminated) { return ErrClosedConnection } - log.Trace(newError("closing connection to ", v.conn.RemoteAddr())) + log.Trace(newError("closing connection to ", v.meta.RemoteAddr)) if state == StateActive { v.SetState(StateReadyToClose) @@ -433,7 +435,7 @@ func (v *Connection) LocalAddr() net.Addr { if v == nil { return nil } - return v.conn.LocalAddr() + return v.meta.LocalAddr } // RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it. @@ -441,7 +443,7 @@ func (v *Connection) RemoteAddr() net.Addr { if v == nil { return nil } - return v.conn.RemoteAddr() + return v.meta.RemoteAddr } // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. @@ -488,7 +490,7 @@ func (v *Connection) Terminate() { v.OnDataInput() v.OnDataOutput() - v.conn.Close() + v.closer.Close() v.sendingWorker.Release() v.receivingWorker.Release() } diff --git a/transport/internet/kcp/connection_test.go b/transport/internet/kcp/connection_test.go index f397a34f9..85fc6bc29 100644 --- a/transport/internet/kcp/connection_test.go +++ b/transport/internet/kcp/connection_test.go @@ -1,59 +1,27 @@ package kcp_test import ( - "net" + "io" "testing" "time" + "v2ray.com/core/common/buf" . "v2ray.com/core/transport/internet/kcp" . "v2ray.com/ext/assert" ) -type NoOpConn struct{} +type NoOpCloser int -func (o *NoOpConn) Overhead() int { - return 0 -} - -// Write implements io.Writer. -func (o *NoOpConn) Write(b []byte) (int, error) { - return len(b), nil -} - -func (o *NoOpConn) Close() error { +func (NoOpCloser) Close() error { return nil } -func (o *NoOpConn) Read([]byte) (int, error) { - panic("Should not be called.") -} - -func (o *NoOpConn) LocalAddr() net.Addr { - return nil -} - -func (o *NoOpConn) RemoteAddr() net.Addr { - return nil -} - -func (o *NoOpConn) SetDeadline(time.Time) error { - return nil -} - -func (o *NoOpConn) SetReadDeadline(time.Time) error { - return nil -} - -func (o *NoOpConn) SetWriteDeadline(time.Time) error { - return nil -} - -func (o *NoOpConn) Reset(input func([]Segment)) {} - func TestConnectionReadTimeout(t *testing.T) { assert := With(t) - conn := NewConnection(1, &NoOpConn{}, &Config{}) + conn := NewConnection(1, &ConnMetadata{}, &KCPPacketWriter{ + Writer: buf.DiscardBytes, + }, NoOpCloser(0), &Config{}) conn.SetReadDeadline(time.Now().Add(time.Second)) b := make([]byte, 1024) @@ -63,3 +31,11 @@ func TestConnectionReadTimeout(t *testing.T) { conn.Terminate() } + +func TestConnectionInterface(t *testing.T) { + assert := With(t) + + assert((*Connection)(nil), Implements, (*io.Writer)(nil)) + assert((*Connection)(nil), Implements, (*io.Reader)(nil)) + assert((*Connection)(nil), Implements, (*buf.Reader)(nil)) +} diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index 73d9a90c1..ddb0c7411 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -2,9 +2,8 @@ package kcp import ( "context" - "crypto/cipher" "crypto/tls" - "sync" + "io" "sync/atomic" "v2ray.com/core/app/log" @@ -20,84 +19,20 @@ var ( globalConv = uint32(dice.RollUint16()) ) -type ClientConnection struct { - sync.RWMutex - net.Conn - input func([]Segment) - reader PacketReader - writer PacketWriter -} - -func (c *ClientConnection) Overhead() int { - c.RLock() - defer c.RUnlock() - if c.writer == nil { - return 0 - } - return c.writer.Overhead() -} - -// Write implements io.Writer. -func (c *ClientConnection) Write(b []byte) (int, error) { - c.RLock() - defer c.RUnlock() - - if c.writer == nil { - return len(b), nil - } - - return c.writer.Write(b) -} - -func (*ClientConnection) Read([]byte) (int, error) { - panic("KCP|ClientConnection: Read should not be called.") -} - -func (c *ClientConnection) Close() error { - return c.Conn.Close() -} - -func (c *ClientConnection) Reset(inputCallback func([]Segment)) { - c.Lock() - c.input = inputCallback - c.Unlock() -} - -func (c *ClientConnection) ResetSecurity(header internet.PacketHeader, security cipher.AEAD) { - c.Lock() - if c.reader == nil { - c.reader = new(KCPPacketReader) - } - c.reader.(*KCPPacketReader).Header = header - c.reader.(*KCPPacketReader).Security = security - if c.writer == nil { - c.writer = new(KCPPacketWriter) - } - c.writer.(*KCPPacketWriter).Header = header - c.writer.(*KCPPacketWriter).Security = security - c.writer.(*KCPPacketWriter).Writer = c.Conn - - c.Unlock() -} - -func (c *ClientConnection) Run() { +func fetchInput(ctx context.Context, input io.Reader, reader PacketReader, conn *Connection) { payload := buf.New() defer payload.Release() for { - err := payload.Reset(buf.ReadFrom(c.Conn)) + err := payload.Reset(buf.ReadFrom(input)) if err != nil { payload.Release() return } - c.RLock() - if c.input != nil { - segments := c.reader.Read(payload.Bytes()) - if len(segments) > 0 { - c.input(segments) - } + segments := reader.Read(payload.Bytes()) + if len(segments) > 0 { + conn.Input(segments) } - c.RUnlock() } } @@ -110,10 +45,6 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er if err != nil { return nil, newError("failed to dial to dest: ", err).AtWarning().Base(err) } - conn := &ClientConnection{ - Conn: rawConn, - } - go conn.Run() kcpSettings := internet.TransportSettingsFromContext(ctx).(*Config) @@ -125,9 +56,23 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er if err != nil { return nil, newError("failed to create security").Base(err) } - conn.ResetSecurity(header, security) + reader := &KCPPacketReader{ + Header: header, + Security: security, + } + writer := &KCPPacketWriter{ + Header: header, + Security: security, + Writer: rawConn, + } + conv := uint16(atomic.AddUint32(&globalConv, 1)) - session := NewConnection(conv, conn, kcpSettings) + session := NewConnection(conv, &ConnMetadata{ + LocalAddr: rawConn.LocalAddr(), + RemoteAddr: rawConn.RemoteAddr(), + }, writer, rawConn, kcpSettings) + + go fetchInput(ctx, rawConn, reader, session) var iConn internet.Connection = session diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index 8c4d82a40..b111c5137 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -4,9 +4,7 @@ import ( "context" "crypto/cipher" "crypto/tls" - "io" "sync" - "time" "v2ray.com/core/app/log" "v2ray.com/core/common" @@ -23,52 +21,6 @@ type ConnectionID struct { Conv uint16 } -type ServerConnection struct { - local net.Addr - remote net.Addr - writer PacketWriter - closer io.Closer -} - -func (c *ServerConnection) Overhead() int { - return c.writer.Overhead() -} - -func (*ServerConnection) Read([]byte) (int, error) { - panic("KCP|ServerConnection: Read should not be called.") -} - -func (c *ServerConnection) Write(b []byte) (int, error) { - return c.writer.Write(b) -} - -func (c *ServerConnection) Close() error { - return c.closer.Close() -} - -func (*ServerConnection) Reset(input func([]Segment)) { -} - -func (c *ServerConnection) LocalAddr() net.Addr { - return c.local -} - -func (c *ServerConnection) RemoteAddr() net.Addr { - return c.remote -} - -func (*ServerConnection) SetDeadline(time.Time) error { - return nil -} - -func (*ServerConnection) SetReadDeadline(time.Time) error { - return nil -} - -func (*ServerConnection) SetWriteDeadline(time.Time) error { - return nil -} - // Listener defines a server listening for connections type Listener struct { sync.Mutex @@ -172,17 +124,14 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src net.Destination, originalD Port: int(src.Port), } localAddr := v.hub.Addr() - sConn := &ServerConnection{ - local: localAddr, - remote: remoteAddr, - writer: &KCPPacketWriter{ - Header: v.header, - Writer: writer, - Security: v.security, - }, - closer: writer, - } - conn = NewConnection(conv, sConn, v.config) + conn = NewConnection(conv, &ConnMetadata{ + LocalAddr: localAddr, + RemoteAddr: remoteAddr, + }, &KCPPacketWriter{ + Header: v.header, + Security: v.security, + Writer: writer, + }, writer, v.config) var netConn internet.Connection = conn if v.tlsConfig != nil { tlsConn := tls.Server(conn, v.tlsConfig)