diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 64316170f..cff2ca0c2 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -277,18 +277,6 @@ func (c *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) { } func (c *Connection) waitForDataInput() error { - if c.State() == StatePeerTerminating { - return io.EOF - } - - duration := time.Second * 8 - if !c.rd.IsZero() { - duration = time.Until(c.rd) - if duration < 0 { - return ErrIOTimeout - } - } - for i := 0; i < 16; i++ { select { case <-c.dataInput.Wait(): @@ -298,6 +286,14 @@ func (c *Connection) waitForDataInput() error { } } + duration := time.Second * 16 + if !c.rd.IsZero() { + duration = time.Until(c.rd) + if duration < 0 { + return ErrIOTimeout + } + } + timeout := time.NewTimer(duration) defer timeout.Stop() @@ -335,7 +331,16 @@ func (c *Connection) Read(b []byte) (int, error) { } func (c *Connection) waitForDataOutput() error { - duration := time.Minute + for i := 0; i < 16; i++ { + select { + case <-c.dataOutput.Wait(): + return nil + default: + runtime.Gosched() + } + } + + duration := time.Second * 16 if !c.wd.IsZero() { duration = time.Until(c.wd) if duration < 0 { @@ -343,15 +348,6 @@ func (c *Connection) waitForDataOutput() error { } } - for i := 0; i < 16; i++ { - select { - case <-c.dataInput.Wait(): - return nil - default: - runtime.Gosched() - } - } - timeout := time.NewTimer(duration) defer timeout.Stop()