diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 42d24710f..6ea411ac2 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -331,7 +331,7 @@ func (this *Connection) updateTask() { this.Terminate() } -func (this *Connection) FetchInputFrom(conn net.Conn) { +func (this *Connection) FetchInputFrom(conn io.Reader) { go func() { payload := alloc.NewBuffer() defer payload.Release() @@ -344,8 +344,6 @@ func (this *Connection) FetchInputFrom(conn net.Conn) { payload.Slice(0, nBytes) if this.block.Open(payload) { this.Input(payload.Value) - } else { - log.Info("KCP|Connection: Invalid response from ", conn.RemoteAddr()) } } }() diff --git a/transport/internet/kcp/connection_test.go b/transport/internet/kcp/connection_test.go index 36643f723..e53d67502 100644 --- a/transport/internet/kcp/connection_test.go +++ b/transport/internet/kcp/connection_test.go @@ -1,17 +1,31 @@ package kcp_test import ( + "crypto/rand" + "io" + "net" "testing" "time" + v2net "github.com/v2ray/v2ray-core/common/net" "github.com/v2ray/v2ray-core/testing/assert" . "github.com/v2ray/v2ray-core/transport/internet/kcp" ) +type NoOpWriteCloser struct{} + +func (this *NoOpWriteCloser) Write(b []byte) (int, error) { + return len(b), nil +} + +func (this *NoOpWriteCloser) Close() error { + return nil +} + func TestConnectionReadTimeout(t *testing.T) { assert := assert.On(t) - conn := NewConnection(1, nil, nil, nil, NewSimpleAuthenticator()) + conn := NewConnection(1, &NoOpWriteCloser{}, nil, nil, NewSimpleAuthenticator()) conn.SetReadDeadline(time.Now().Add(time.Second)) b := make([]byte, 1024) @@ -19,3 +33,32 @@ func TestConnectionReadTimeout(t *testing.T) { assert.Int(nBytes).Equals(0) assert.Error(err).IsNotNil() } + +func TestConnectionReadWrite(t *testing.T) { + assert := assert.On(t) + + upReader, upWriter := io.Pipe() + downReader, downWriter := io.Pipe() + + connClient := NewConnection(1, upWriter, &net.UDPAddr{IP: v2net.LocalHostIP.IP(), Port: 1}, &net.UDPAddr{IP: v2net.LocalHostIP.IP(), Port: 2}, NewSimpleAuthenticator()) + go connClient.FetchInputFrom(downReader) + + connServer := NewConnection(1, downWriter, &net.UDPAddr{IP: v2net.LocalHostIP.IP(), Port: 2}, &net.UDPAddr{IP: v2net.LocalHostIP.IP(), Port: 1}, NewSimpleAuthenticator()) + go connServer.FetchInputFrom(upReader) + + totalWritten := 1024 * 1024 + clientSend := make([]byte, totalWritten) + rand.Read(clientSend) + nBytes, err := connClient.Write(clientSend) + assert.Int(nBytes).Equals(totalWritten) + assert.Error(err).IsNil() + + serverReceived := make([]byte, totalWritten) + totalRead := 0 + for totalRead < totalWritten { + nBytes, err = connServer.Read(serverReceived[totalRead:]) + assert.Error(err).IsNil() + totalRead += nBytes + } + assert.Bytes(serverReceived).Equals(clientSend) +}