From 4c74e25319ef4ba1a24b09b5abf17b4548c046aa Mon Sep 17 00:00:00 2001 From: v2ray Date: Fri, 1 Jul 2016 11:57:13 +0200 Subject: [PATCH] sending window --- transport/internet/kcp/kcp.go | 115 +++++--------------- transport/internet/kcp/sending.go | 141 +++++++++++++++++++++++++ transport/internet/kcp/sending_test.go | 33 ++++++ 3 files changed, 202 insertions(+), 87 deletions(-) diff --git a/transport/internet/kcp/kcp.go b/transport/internet/kcp/kcp.go index f13df2381..58d4cbe08 100644 --- a/transport/internet/kcp/kcp.go +++ b/transport/internet/kcp/kcp.go @@ -54,7 +54,7 @@ type KCP struct { snd_queue *SendingQueue rcv_queue *ReceivingQueue - snd_buf []*DataSegment + snd_buf *SendingWindow rcv_buf *ReceivingWindow acklist *ACKList @@ -82,6 +82,7 @@ func NewKCP(conv uint16, mtu uint32, sendingWindowSize uint32, receivingWindowSi kcp.snd_queue = NewSendingQueue(sendingQueueSize) kcp.rcv_queue = NewReceivingQueue() kcp.acklist = NewACKList(kcp) + kcp.snd_buf = NewSendingWindow(kcp, sendingWindowSize) kcp.cwnd = kcp.snd_wnd return kcp } @@ -194,8 +195,8 @@ func (kcp *KCP) update_ack(rtt int32) { func (kcp *KCP) shrink_buf() { prevUna := kcp.snd_una - if len(kcp.snd_buf) > 0 { - seg := kcp.snd_buf[0] + if kcp.snd_buf.Len() > 0 { + seg := kcp.snd_buf.First() kcp.snd_una = seg.Number } else { kcp.snd_una = kcp.snd_nxt @@ -210,16 +211,7 @@ func (kcp *KCP) parse_ack(sn uint32) { return } - for k, seg := range kcp.snd_buf { - if sn == seg.Number { - kcp.snd_buf = append(kcp.snd_buf[:k], kcp.snd_buf[k+1:]...) - seg.Release() - break - } - if _itimediff(sn, seg.Number) < 0 { - break - } - } + kcp.snd_buf.Remove(sn - kcp.snd_una) } func (kcp *KCP) parse_fastack(sn uint32) { @@ -227,26 +219,11 @@ func (kcp *KCP) parse_fastack(sn uint32) { return } - for _, seg := range kcp.snd_buf { - if _itimediff(sn, seg.Number) < 0 { - break - } else if sn != seg.Number { - seg.ackSkipped++ - } - } + kcp.snd_buf.HandleFastAck(sn) } func (kcp *KCP) HandleReceivingNext(receivingNext uint32) { - count := 0 - for _, seg := range kcp.snd_buf { - if _itimediff(receivingNext, seg.Number) > 0 { - seg.Release() - count++ - } else { - break - } - } - kcp.snd_buf = kcp.snd_buf[count:] + kcp.snd_buf.Clear(receivingNext) } func (kcp *KCP) HandleSendingNext(sendingNext uint32) { @@ -362,7 +339,6 @@ func (kcp *KCP) flush() { } current := kcp.current - lost := false // flush acknowledges if kcp.acklist.Flush() { @@ -385,47 +361,13 @@ func (kcp *KCP) flush() { seg.timeout = current seg.ackSkipped = 0 seg.transmit = 0 - kcp.snd_buf = append(kcp.snd_buf, seg) + kcp.snd_buf.Push(seg) kcp.snd_nxt++ } - // calculate resent - resent := uint32(kcp.fastresend) - if kcp.fastresend <= 0 { - resent = 0xffffffff - } - // flush data segments - for _, segment := range kcp.snd_buf { - needsend := false - if segment.transmit == 0 { - needsend = true - segment.transmit++ - segment.timeout = current + kcp.rx_rto - } else if _itimediff(current, segment.timeout) >= 0 { - needsend = true - segment.transmit++ - segment.timeout = current + kcp.rx_rto - lost = true - } else if segment.ackSkipped >= resent { - needsend = true - segment.transmit++ - segment.ackSkipped = 0 - segment.timeout = current + kcp.rx_rto - lost = true - } - - if needsend { - segment.Timestamp = current - segment.SendingNext = kcp.snd_una - segment.Opt = 0 - if kcp.state == StateReadyToClose { - segment.Opt = SegmentOptionClose - } - - kcp.output.Write(segment) - kcp.sendingUpdated = false - } + if kcp.snd_buf.Flush() { + kcp.sendingUpdated = false } if kcp.sendingUpdated || kcp.receivingUpdated || _itimediff(kcp.current, kcp.lastPingTime) >= 5000 { @@ -447,18 +389,22 @@ func (kcp *KCP) flush() { // flash remain segments kcp.output.Flush() - if kcp.congestionControl { - if lost { - kcp.cwnd = 3 * kcp.cwnd / 4 - } else { - kcp.cwnd += kcp.cwnd / 4 - } - if kcp.cwnd < 4 { - kcp.cwnd = 4 - } - if kcp.cwnd > kcp.snd_wnd { - kcp.cwnd = kcp.snd_wnd - } +} + +func (kcp *KCP) HandleLost(lost bool) { + if !kcp.congestionControl { + return + } + if lost { + kcp.cwnd = 3 * kcp.cwnd / 4 + } else { + kcp.cwnd += kcp.cwnd / 4 + } + if kcp.cwnd < 4 { + kcp.cwnd = 4 + } + if kcp.cwnd > kcp.snd_wnd { + kcp.cwnd = kcp.snd_wnd } } @@ -488,15 +434,10 @@ func (kcp *KCP) NoDelay(interval uint32, resend int, congestionControl bool) int // WaitSnd gets how many packet is waiting to be sent func (kcp *KCP) WaitSnd() uint32 { - return uint32(len(kcp.snd_buf)) + kcp.snd_queue.Len() + return uint32(kcp.snd_buf.Len()) + kcp.snd_queue.Len() } func (this *KCP) ClearSendQueue() { this.snd_queue.Clear() - - for _, seg := range this.snd_buf { - seg.Release() - } - - this.snd_buf = nil + this.snd_buf.Clear(0xFFFFFFFF) } diff --git a/transport/internet/kcp/sending.go b/transport/internet/kcp/sending.go index 9f28ee26a..281f2d715 100644 --- a/transport/internet/kcp/sending.go +++ b/transport/internet/kcp/sending.go @@ -1,5 +1,146 @@ package kcp +type SendingWindow struct { + start uint32 + cap uint32 + len uint32 + last uint32 + + data []*DataSegment + prev []uint32 + next []uint32 + + kcp *KCP +} + +func NewSendingWindow(kcp *KCP, size uint32) *SendingWindow { + window := &SendingWindow{ + start: 0, + cap: size, + len: 0, + last: 0, + data: make([]*DataSegment, size), + prev: make([]uint32, size), + next: make([]uint32, size), + } + return window +} + +func (this *SendingWindow) Len() int { + return int(this.len) +} + +func (this *SendingWindow) Push(seg *DataSegment) { + pos := (this.start + this.len) % this.cap + this.data[pos] = seg + if this.len > 0 { + this.next[this.last] = pos + this.prev[pos] = this.last + } + this.last = pos + this.len++ +} + +func (this *SendingWindow) First() *DataSegment { + return this.data[this.start] +} + +func (this *SendingWindow) Clear(una uint32) { + for this.Len() > 0 { + if this.data[this.start].Number < una { + this.Remove(0) + } + } +} + +func (this *SendingWindow) Remove(idx uint32) { + pos := (this.start + idx) % this.cap + seg := this.data[pos] + seg.Release() + this.data[pos] = nil + if pos == this.start { + if this.len == 1 { + this.len = 0 + this.start = 0 + this.last = 0 + } else { + delta := this.next[pos] - this.start + this.start = this.next[pos] + this.len -= delta + } + } else if pos == this.last { + this.last = this.prev[pos] + } else { + this.next[this.prev[pos]] = this.next[pos] + this.prev[this.next[pos]] = this.prev[pos] + } +} + +func (this *SendingWindow) HandleFastAck(number uint32) { + for i := this.start; ; i = this.next[i] { + seg := this.data[i] + if _itimediff(number, seg.Number) < 0 { + break + } + if number != seg.Number { + seg.ackSkipped++ + } + if i == this.last { + break + } + } +} + +func (this *SendingWindow) Flush() bool { + current := this.kcp.current + resent := uint32(this.kcp.fastresend) + if this.kcp.fastresend <= 0 { + resent = 0xffffffff + } + lost := false + segSent := false + + for i := this.start; ; i = this.next[i] { + segment := this.data[i] + needsend := false + if segment.transmit == 0 { + needsend = true + segment.transmit++ + segment.timeout = current + this.kcp.rx_rto + } else if _itimediff(current, segment.timeout) >= 0 { + needsend = true + segment.transmit++ + segment.timeout = current + this.kcp.rx_rto + lost = true + } else if segment.ackSkipped >= resent { + needsend = true + segment.transmit++ + segment.ackSkipped = 0 + segment.timeout = current + this.kcp.rx_rto + lost = true + } + + if needsend { + segment.Timestamp = current + segment.SendingNext = this.kcp.snd_una + segment.Opt = 0 + if this.kcp.state == StateReadyToClose { + segment.Opt = SegmentOptionClose + } + + this.kcp.output.Write(segment) + segSent = true + } + if i == this.last { + break + } + } + + this.kcp.HandleLost(lost) + + return segSent +} + type SendingQueue struct { start uint32 cap uint32 diff --git a/transport/internet/kcp/sending_test.go b/transport/internet/kcp/sending_test.go index ca6486fc4..d6f1539f0 100644 --- a/transport/internet/kcp/sending_test.go +++ b/transport/internet/kcp/sending_test.go @@ -62,3 +62,36 @@ func TestSendingQueueClear(t *testing.T) { queue.Clear() assert.Bool(queue.IsEmpty()).IsTrue() } + +func TestSendingWindow(t *testing.T) { + assert := assert.On(t) + + window := NewSendingWindow(nil, 5) + window.Push(&DataSegment{ + Number: 0, + }) + window.Push(&DataSegment{ + Number: 1, + }) + window.Push(&DataSegment{ + Number: 2, + }) + assert.Int(window.Len()).Equals(3) + + window.Remove(1) + assert.Int(window.Len()).Equals(3) + assert.Uint32(window.First().Number).Equals(0) + + window.Remove(0) + assert.Int(window.Len()).Equals(1) + assert.Uint32(window.First().Number).Equals(2) + + window.Remove(0) + assert.Int(window.Len()).Equals(0) + + window.Push(&DataSegment{ + Number: 4, + }) + assert.Int(window.Len()).Equals(1) + assert.Uint32(window.First().Number).Equals(4) +}