From 3cc6d8f653474131ddab271370b82f0ce6df82f7 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sun, 17 Dec 2017 01:22:39 +0100 Subject: [PATCH] fix a data race in KCP --- transport/internet/kcp/connection.go | 26 ++++++++------------------ transport/internet/kcp/sending.go | 8 +++++--- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 871e7aa8e..ba503ecee 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -8,7 +8,6 @@ import ( "time" "v2ray.com/core/app/log" - "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/predicate" ) @@ -360,16 +359,12 @@ func (v *Connection) Write(b []byte) (int, error) { return totalWritten, io.ErrClosedPipe } - for { - rb := v.sendingWorker.Push() - if rb == nil { - break - } - common.Must(rb.Reset(func(bb []byte) (int, error) { - return copy(bb[:v.mss], b[totalWritten:]), nil - })) + for v.sendingWorker.Push(func(bb []byte) (int, error) { + n := copy(bb[:v.mss], b[totalWritten:]) + totalWritten += n + return n, nil + }) { v.dataUpdater.WakeUp() - totalWritten += rb.Len() if totalWritten == len(b) { return totalWritten, nil } @@ -390,14 +385,9 @@ func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { return io.ErrClosedPipe } - for { - rb := v.sendingWorker.Push() - if rb == nil { - break - } - common.Must(rb.Reset(func(bb []byte) (int, error) { - return mb.Read(bb[:v.mss]) - })) + for v.sendingWorker.Push(func(bb []byte) (int, error) { + return mb.Read(bb[:v.mss]) + }) { v.dataUpdater.WakeUp() if mb.IsEmpty() { return nil diff --git a/transport/internet/kcp/sending.go b/transport/internet/kcp/sending.go index cfd40eaee..7b2d9769d 100644 --- a/transport/internet/kcp/sending.go +++ b/transport/internet/kcp/sending.go @@ -3,6 +3,7 @@ package kcp import ( "sync" + "v2ray.com/core/common" "v2ray.com/core/common/buf" ) @@ -284,17 +285,18 @@ func (v *SendingWorker) ProcessSegment(current uint32, seg *AckSegment, rto uint } } -func (v *SendingWorker) Push() *buf.Buffer { +func (v *SendingWorker) Push(f buf.Supplier) bool { v.Lock() defer v.Unlock() if v.window.IsFull() { - return nil + return false } b := v.window.Push(v.nextNumber) v.nextNumber++ - return b + common.Must(b.Reset(f)) + return true } func (v *SendingWorker) Write(seg Segment) error {