diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 25fec1979..88d7d05a3 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -38,7 +38,7 @@ const ( StatePeerClosed State = 2 // Connection is closed on remote StateTerminating State = 3 // Connection is ready to be destroyed locally StatePeerTerminating State = 4 // Connection is ready to be destroyed on remote - StateTerminated State = 5 // Connection is detroyed. + StateTerminated State = 5 // Connection is destroyed. ) func nowMillisec() int64 { @@ -137,21 +137,21 @@ func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTermi return u } -func (v *Updater) WakeUp() { +func (u *Updater) WakeUp() { select { - case v.notifier <- true: + case u.notifier <- true: default: } } -func (v *Updater) Run() { - for <-v.notifier { - if v.shouldTerminate() { +func (u *Updater) Run() { + for <-u.notifier { + if u.shouldTerminate() { return } - interval := v.Interval() - for v.shouldContinue() { - v.updateFunc() + interval := u.Interval() + for u.shouldContinue() { + u.updateFunc() time.Sleep(interval) } } @@ -280,24 +280,36 @@ func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) { return nil, io.EOF } - duration := time.Minute - if !v.rd.IsZero() { - duration = time.Until(v.rd) - if duration < 0 { - return nil, ErrIOTimeout - } - } - - select { - case <-v.dataInput: - case <-time.After(duration): - if !v.rd.IsZero() && v.rd.Before(time.Now()) { - return nil, ErrIOTimeout - } + if err := v.waitForDataInput(); err != nil { + return nil, err } } } +func (v *Connection) waitForDataInput() error { + if v.State() == StatePeerTerminating { + return io.EOF + } + + duration := time.Minute + if !v.rd.IsZero() { + duration = time.Until(v.rd) + if duration < 0 { + return ErrIOTimeout + } + } + + select { + case <-v.dataInput: + case <-time.After(duration): + if !v.rd.IsZero() && v.rd.Before(time.Now()) { + return ErrIOTimeout + } + } + + return nil +} + // Read implements the Conn Read method. func (v *Connection) Read(b []byte) (int, error) { if v == nil { @@ -313,28 +325,32 @@ func (v *Connection) Read(b []byte) (int, error) { return nBytes, nil } - if v.State() == StatePeerTerminating { - return 0, io.EOF - } - - duration := time.Minute - if !v.rd.IsZero() { - duration = time.Until(v.rd) - if duration < 0 { - return 0, ErrIOTimeout - } - } - - select { - case <-v.dataInput: - case <-time.After(duration): - if !v.rd.IsZero() && v.rd.Before(time.Now()) { - return 0, ErrIOTimeout - } + if err := v.waitForDataInput(); err != nil { + return 0, err } } } +func (v *Connection) waitForDataOutput() error { + duration := time.Minute + if !v.wd.IsZero() { + duration = time.Until(v.wd) + if duration < 0 { + return ErrIOTimeout + } + } + + select { + case <-v.dataOutput: + case <-time.After(duration): + if !v.wd.IsZero() && v.wd.Before(time.Now()) { + return ErrIOTimeout + } + } + + return nil +} + // Write implements io.Writer. func (v *Connection) Write(b []byte) (int, error) { totalWritten := 0 @@ -359,20 +375,8 @@ func (v *Connection) Write(b []byte) (int, error) { } } - duration := time.Minute - if !v.wd.IsZero() { - duration = time.Until(v.wd) - if duration < 0 { - return totalWritten, ErrIOTimeout - } - } - - select { - case <-v.dataOutput: - case <-time.After(duration): - if !v.wd.IsZero() && v.wd.Before(time.Now()) { - return totalWritten, ErrIOTimeout - } + if err := v.waitForDataOutput(); err != nil { + return totalWritten, err } } } @@ -400,20 +404,8 @@ func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { } } - duration := time.Minute - if !v.wd.IsZero() { - duration = time.Until(v.wd) - if duration < 0 { - return ErrIOTimeout - } - } - - select { - case <-v.dataOutput: - case <-time.After(duration): - if !v.wd.IsZero() && v.wd.Before(time.Now()) { - return ErrIOTimeout - } + if err := v.waitForDataOutput(); err != nil { + return err } } }