From e023859ef0d946ae75311da280828ed8985d47fc Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Mon, 10 Oct 2016 16:50:54 +0200 Subject: [PATCH] stop data updating thread when there is no data --- common/predicate/predicate.go | 25 +++++++ transport/internet/kcp/connection.go | 79 +++++++++++++++++++---- transport/internet/kcp/connection_test.go | 2 + transport/internet/kcp/receiving.go | 4 ++ transport/internet/kcp/sending.go | 4 ++ 5 files changed, 102 insertions(+), 12 deletions(-) create mode 100644 common/predicate/predicate.go diff --git a/common/predicate/predicate.go b/common/predicate/predicate.go new file mode 100644 index 000000000..d107071f8 --- /dev/null +++ b/common/predicate/predicate.go @@ -0,0 +1,25 @@ +package predicate + +type Predicate func() bool + +func All(predicates ...Predicate) Predicate { + return func() bool { + for _, p := range predicates { + if !p() { + return false + } + } + return true + } +} + +func Any(predicates ...Predicate) Predicate { + return func() bool { + for _, p := range predicates { + if p() { + return true + } + } + return false + } +} diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 8e0bde83c..ecd80a3d7 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -10,6 +10,7 @@ import ( "v2ray.com/core/common/alloc" "v2ray.com/core/common/log" + "v2ray.com/core/common/predicate" "v2ray.com/core/transport/internet" ) @@ -119,6 +120,45 @@ func (this *RoundTripInfo) SmoothedTime() uint32 { return this.srtt } +type Updater struct { + interval time.Duration + shouldContinue predicate.Predicate + shouldTerminate predicate.Predicate + updateFunc func() + notifier chan bool +} + +func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTerminate predicate.Predicate, updateFunc func()) *Updater { + u := &Updater{ + interval: time.Duration(interval) * time.Millisecond, + shouldContinue: shouldContinue, + shouldTerminate: shouldTerminate, + updateFunc: updateFunc, + notifier: make(chan bool, 1), + } + go u.Run() + return u +} + +func (this *Updater) WakeUp() { + select { + case this.notifier <- true: + default: + } +} + +func (this *Updater) Run() { + for <-this.notifier { + if this.shouldTerminate() { + return + } + for this.shouldContinue() { + this.updateFunc() + time.Sleep(this.interval) + } + } +} + // Connection is a KCP connection over UDP. type Connection struct { block internet.Authenticator @@ -147,6 +187,9 @@ type Connection struct { fastresend uint32 congestionControl bool output *BufferedSegmentWriter + + dataUpdater *Updater + pingUpdater *Updater } // NewConnection create a new KCP connection between local and remote. @@ -182,7 +225,18 @@ func NewConnection(conv uint16, writerCloser io.WriteCloser, local *net.UDPAddr, conn.congestionControl = config.Congestion conn.sendingWorker = NewSendingWorker(conn) - go conn.updateTask() + conn.dataUpdater = NewUpdater( + conn.interval, + predicate.Any(conn.sendingWorker.UpdateNecessary, conn.receivingWorker.UpdateNecessary), + func() bool { + return conn.State() == StateTerminated + }, + conn.updateTask) + conn.pingUpdater = NewUpdater( + 3000, // 3 seconds + func() bool { return conn.State() != StateTerminated }, + func() bool { return conn.State() == StateTerminated }, + conn.updateTask) return conn } @@ -240,6 +294,7 @@ func (this *Connection) Write(b []byte) (int, error) { } nBytes := this.sendingWorker.Push(b[totalWritten:]) + this.dataUpdater.WakeUp() if nBytes > 0 { totalWritten += nBytes if totalWritten == len(b) { @@ -278,16 +333,24 @@ func (this *Connection) SetState(state State) { switch state { case StateReadyToClose: this.receivingWorker.CloseRead() + this.dataUpdater.WakeUp() case StatePeerClosed: this.sendingWorker.CloseWrite() + this.dataUpdater.WakeUp() case StateTerminating: this.receivingWorker.CloseRead() this.sendingWorker.CloseWrite() + this.dataUpdater.interval = time.Second + this.dataUpdater.WakeUp() case StatePeerTerminating: this.sendingWorker.CloseWrite() + this.dataUpdater.WakeUp() case StateTerminated: this.receivingWorker.CloseRead() this.sendingWorker.CloseWrite() + this.dataUpdater.interval = time.Second + this.dataUpdater.WakeUp() + this.Terminate() } } @@ -366,16 +429,7 @@ func (this *Connection) SetWriteDeadline(t time.Time) error { // kcp update, input loop func (this *Connection) updateTask() { - for this.State() != StateTerminated { - this.flush() - - interval := time.Duration(this.Config.Tti.GetValue()) * time.Millisecond - if this.State() == StateTerminating { - interval = time.Second - } - time.Sleep(interval) - } - this.Terminate() + this.flush() } func (this *Connection) FetchInputFrom(conn io.Reader) { @@ -408,7 +462,7 @@ func (this *Connection) Terminate() { } log.Info("KCP|Connection: Terminating connection to ", this.RemoteAddr()) - this.SetState(StateTerminated) + //this.SetState(StateTerminated) this.dataInputCond.Broadcast() this.dataOutputCond.Broadcast() this.writer.Close() @@ -434,6 +488,7 @@ func (this *Connection) OnPeerClosed() { func (this *Connection) Input(data []byte) int { current := this.Elapsed() atomic.StoreUint32(&this.lastIncomingTime, current) + this.dataUpdater.WakeUp() var seg Segment for { diff --git a/transport/internet/kcp/connection_test.go b/transport/internet/kcp/connection_test.go index f3250326f..d095cbc2b 100644 --- a/transport/internet/kcp/connection_test.go +++ b/transport/internet/kcp/connection_test.go @@ -34,6 +34,8 @@ func TestConnectionReadTimeout(t *testing.T) { nBytes, err := conn.Read(b) assert.Int(nBytes).Equals(0) assert.Error(err).IsNotNil() + + conn.Terminate() } func TestConnectionReadWrite(t *testing.T) { diff --git a/transport/internet/kcp/receiving.go b/transport/internet/kcp/receiving.go index ebc248732..eaabc14f9 100644 --- a/transport/internet/kcp/receiving.go +++ b/transport/internet/kcp/receiving.go @@ -213,3 +213,7 @@ func (this *ReceivingWorker) Write(seg Segment) { func (this *ReceivingWorker) CloseRead() { } + +func (this *ReceivingWorker) UpdateNecessary() bool { + return len(this.acklist.numbers) > 0 +} diff --git a/transport/internet/kcp/sending.go b/transport/internet/kcp/sending.go index b324314b7..a76a24cac 100644 --- a/transport/internet/kcp/sending.go +++ b/transport/internet/kcp/sending.go @@ -338,3 +338,7 @@ func (this *SendingWorker) IsEmpty() bool { return this.window.IsEmpty() } + +func (this *SendingWorker) UpdateNecessary() bool { + return !this.IsEmpty() +}