From 8a09c6c926cb69376978ad4188639d753fd0ade3 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Wed, 27 Dec 2017 21:33:42 +0100 Subject: [PATCH] migrate to signal.Semaphore and Notifier --- common/signal/notifier.go | 22 ++ transport/internet/kcp/connection.go | 370 +++++++++++++-------------- transport/ray/direct.go | 47 ++-- 3 files changed, 215 insertions(+), 224 deletions(-) create mode 100644 common/signal/notifier.go diff --git a/common/signal/notifier.go b/common/signal/notifier.go new file mode 100644 index 000000000..0d98c2205 --- /dev/null +++ b/common/signal/notifier.go @@ -0,0 +1,22 @@ +package signal + +type Notifier struct { + c chan bool +} + +func NewNotifier() *Notifier { + return &Notifier{ + c: make(chan bool, 1), + } +} + +func (n *Notifier) Signal() { + select { + case n.c <- true: + default: + } +} + +func (n *Notifier) Wait() <-chan bool { + return n.c +} diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 8e4e05815..66d0f0baa 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -9,6 +9,7 @@ import ( "v2ray.com/core/common/buf" "v2ray.com/core/common/predicate" + "v2ray.com/core/common/signal" ) var ( @@ -120,7 +121,7 @@ type Updater struct { shouldContinue predicate.Predicate shouldTerminate predicate.Predicate updateFunc func() - notifier chan bool + notifier *signal.Semaphore } func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTerminate predicate.Predicate, updateFunc func()) *Updater { @@ -129,31 +130,31 @@ func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTermi shouldContinue: shouldContinue, shouldTerminate: shouldTerminate, updateFunc: updateFunc, - notifier: make(chan bool, 1), + notifier: signal.NewSemaphore(1), } - go u.Run() return u } func (u *Updater) WakeUp() { select { - case u.notifier <- true: + case <-u.notifier.Wait(): + go u.run() default: } } -func (u *Updater) Run() { - for <-u.notifier { - if u.shouldTerminate() { - return - } - ticker := time.NewTicker(u.Interval()) - for u.shouldContinue() { - u.updateFunc() - <-ticker.C - } - ticker.Stop() +func (u *Updater) run() { + defer u.notifier.Signal() + + if u.shouldTerminate() { + return } + ticker := time.NewTicker(u.Interval()) + for u.shouldContinue() { + u.updateFunc() + <-ticker.C + } + ticker.Stop() } func (u *Updater) Interval() time.Duration { @@ -177,8 +178,8 @@ type Connection struct { rd time.Time wd time.Time // write deadline since int64 - dataInput chan bool - dataOutput chan bool + dataInput *signal.Notifier + dataOutput *signal.Notifier Config *Config state State @@ -206,8 +207,8 @@ func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, con meta: meta, closer: closer, since: nowMillisec(), - dataInput: make(chan bool, 1), - dataOutput: make(chan bool, 1), + dataInput: signal.NewNotifier(), + dataOutput: signal.NewNotifier(), Config: config, output: NewRetryableWriter(NewSegmentWriter(writer)), mss: config.GetMTUValue() - uint32(writer.Overhead()) - DataSegmentOverhead, @@ -241,66 +242,52 @@ func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, con return conn } -func (v *Connection) Elapsed() uint32 { - return uint32(nowMillisec() - v.since) -} - -func (v *Connection) OnDataInput() { - select { - case v.dataInput <- true: - default: - } -} - -func (v *Connection) OnDataOutput() { - select { - case v.dataOutput <- true: - default: - } +func (c *Connection) Elapsed() uint32 { + return uint32(nowMillisec() - c.since) } // ReadMultiBuffer implements buf.Reader. -func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) { - if v == nil { +func (c *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) { + if c == nil { return nil, io.EOF } for { - if v.State().Is(StateReadyToClose, StateTerminating, StateTerminated) { + if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) { return nil, io.EOF } - mb := v.receivingWorker.ReadMultiBuffer() + mb := c.receivingWorker.ReadMultiBuffer() if !mb.IsEmpty() { return mb, nil } - if v.State() == StatePeerTerminating { + if c.State() == StatePeerTerminating { return nil, io.EOF } - if err := v.waitForDataInput(); err != nil { + if err := c.waitForDataInput(); err != nil { return nil, err } } } -func (v *Connection) waitForDataInput() error { - if v.State() == StatePeerTerminating { +func (c *Connection) waitForDataInput() error { + if c.State() == StatePeerTerminating { return io.EOF } duration := time.Minute - if !v.rd.IsZero() { - duration = time.Until(v.rd) + if !c.rd.IsZero() { + duration = time.Until(c.rd) if duration < 0 { return ErrIOTimeout } } select { - case <-v.dataInput: + case <-c.dataInput.Wait(): case <-time.After(duration): - if !v.rd.IsZero() && v.rd.Before(time.Now()) { + if !c.rd.IsZero() && c.rd.Before(time.Now()) { return ErrIOTimeout } } @@ -309,39 +296,39 @@ func (v *Connection) waitForDataInput() error { } // Read implements the Conn Read method. -func (v *Connection) Read(b []byte) (int, error) { - if v == nil { +func (c *Connection) Read(b []byte) (int, error) { + if c == nil { return 0, io.EOF } for { - if v.State().Is(StateReadyToClose, StateTerminating, StateTerminated) { + if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) { return 0, io.EOF } - nBytes := v.receivingWorker.Read(b) + nBytes := c.receivingWorker.Read(b) if nBytes > 0 { return nBytes, nil } - if err := v.waitForDataInput(); err != nil { + if err := c.waitForDataInput(); err != nil { return 0, err } } } -func (v *Connection) waitForDataOutput() error { +func (c *Connection) waitForDataOutput() error { duration := time.Minute - if !v.wd.IsZero() { - duration = time.Until(v.wd) + if !c.wd.IsZero() { + duration = time.Until(c.wd) if duration < 0 { return ErrIOTimeout } } select { - case <-v.dataOutput: + case <-c.dataOutput.Wait(): case <-time.After(duration): - if !v.wd.IsZero() && v.wd.Before(time.Now()) { + if !c.wd.IsZero() && c.wd.Before(time.Now()) { return ErrIOTimeout } } @@ -350,295 +337,290 @@ func (v *Connection) waitForDataOutput() error { } // Write implements io.Writer. -func (v *Connection) Write(b []byte) (int, error) { +func (c *Connection) Write(b []byte) (int, error) { totalWritten := 0 for { - if v == nil || v.State() != StateActive { + if c == nil || c.State() != StateActive { return totalWritten, io.ErrClosedPipe } - for v.sendingWorker.Push(func(bb []byte) (int, error) { - n := copy(bb[:v.mss], b[totalWritten:]) + for c.sendingWorker.Push(func(bb []byte) (int, error) { + n := copy(bb[:c.mss], b[totalWritten:]) totalWritten += n return n, nil }) { - v.dataUpdater.WakeUp() + c.dataUpdater.WakeUp() if totalWritten == len(b) { return totalWritten, nil } } - if err := v.waitForDataOutput(); err != nil { + if err := c.waitForDataOutput(); err != nil { return totalWritten, err } } } // WriteMultiBuffer implements buf.Writer. -func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { +func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { defer mb.Release() for { - if v == nil || v.State() != StateActive { + if c == nil || c.State() != StateActive { return io.ErrClosedPipe } - for v.sendingWorker.Push(func(bb []byte) (int, error) { - return mb.Read(bb[:v.mss]) + for c.sendingWorker.Push(func(bb []byte) (int, error) { + return mb.Read(bb[:c.mss]) }) { - v.dataUpdater.WakeUp() + c.dataUpdater.WakeUp() if mb.IsEmpty() { return nil } } - if err := v.waitForDataOutput(); err != nil { + if err := c.waitForDataOutput(); err != nil { return err } } } -func (v *Connection) SetState(state State) { - current := v.Elapsed() - atomic.StoreInt32((*int32)(&v.state), int32(state)) - atomic.StoreUint32(&v.stateBeginTime, current) - newError("#", v.meta.Conversation, " entering state ", state, " at ", current).AtDebug().WriteToLog() +func (c *Connection) SetState(state State) { + current := c.Elapsed() + atomic.StoreInt32((*int32)(&c.state), int32(state)) + atomic.StoreUint32(&c.stateBeginTime, current) + newError("#", c.meta.Conversation, " entering state ", state, " at ", current).AtDebug().WriteToLog() switch state { case StateReadyToClose: - v.receivingWorker.CloseRead() + c.receivingWorker.CloseRead() case StatePeerClosed: - v.sendingWorker.CloseWrite() + c.sendingWorker.CloseWrite() case StateTerminating: - v.receivingWorker.CloseRead() - v.sendingWorker.CloseWrite() - v.pingUpdater.SetInterval(time.Second) + c.receivingWorker.CloseRead() + c.sendingWorker.CloseWrite() + c.pingUpdater.SetInterval(time.Second) case StatePeerTerminating: - v.sendingWorker.CloseWrite() - v.pingUpdater.SetInterval(time.Second) + c.sendingWorker.CloseWrite() + c.pingUpdater.SetInterval(time.Second) case StateTerminated: - v.receivingWorker.CloseRead() - v.sendingWorker.CloseWrite() - v.pingUpdater.SetInterval(time.Second) - v.dataUpdater.WakeUp() - v.pingUpdater.WakeUp() - go v.Terminate() + c.receivingWorker.CloseRead() + c.sendingWorker.CloseWrite() + c.pingUpdater.SetInterval(time.Second) + c.dataUpdater.WakeUp() + c.pingUpdater.WakeUp() + go c.Terminate() } } // Close closes the connection. -func (v *Connection) Close() error { - if v == nil { +func (c *Connection) Close() error { + if c == nil { return ErrClosedConnection } - v.OnDataInput() - v.OnDataOutput() + c.dataInput.Signal() + c.dataOutput.Signal() - state := v.State() - if state.Is(StateReadyToClose, StateTerminating, StateTerminated) { + switch c.State() { + case StateReadyToClose, StateTerminating, StateTerminated: return ErrClosedConnection + case StateActive: + c.SetState(StateReadyToClose) + case StatePeerClosed: + c.SetState(StateTerminating) + case StatePeerTerminating: + c.SetState(StateTerminated) } - newError("closing connection to ", v.meta.RemoteAddr).WriteToLog() - if state == StateActive { - v.SetState(StateReadyToClose) - } - if state == StatePeerClosed { - v.SetState(StateTerminating) - } - if state == StatePeerTerminating { - v.SetState(StateTerminated) - } + newError("closing connection to ", c.meta.RemoteAddr).WriteToLog() return nil } // LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it. -func (v *Connection) LocalAddr() net.Addr { - if v == nil { +func (c *Connection) LocalAddr() net.Addr { + if c == nil { return nil } - return v.meta.LocalAddr + return c.meta.LocalAddr } // RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it. -func (v *Connection) RemoteAddr() net.Addr { - if v == nil { +func (c *Connection) RemoteAddr() net.Addr { + if c == nil { return nil } - return v.meta.RemoteAddr + return c.meta.RemoteAddr } // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. -func (v *Connection) SetDeadline(t time.Time) error { - if err := v.SetReadDeadline(t); err != nil { +func (c *Connection) SetDeadline(t time.Time) error { + if err := c.SetReadDeadline(t); err != nil { return err } - if err := v.SetWriteDeadline(t); err != nil { + if err := c.SetWriteDeadline(t); err != nil { return err } return nil } // SetReadDeadline implements the Conn SetReadDeadline method. -func (v *Connection) SetReadDeadline(t time.Time) error { - if v == nil || v.State() != StateActive { +func (c *Connection) SetReadDeadline(t time.Time) error { + if c == nil || c.State() != StateActive { return ErrClosedConnection } - v.rd = t + c.rd = t return nil } // SetWriteDeadline implements the Conn SetWriteDeadline method. -func (v *Connection) SetWriteDeadline(t time.Time) error { - if v == nil || v.State() != StateActive { +func (c *Connection) SetWriteDeadline(t time.Time) error { + if c == nil || c.State() != StateActive { return ErrClosedConnection } - v.wd = t + c.wd = t return nil } // kcp update, input loop -func (v *Connection) updateTask() { - v.flush() +func (c *Connection) updateTask() { + c.flush() } -func (v *Connection) Terminate() { - if v == nil { +func (c *Connection) Terminate() { + if c == nil { return } - newError("terminating connection to ", v.RemoteAddr()).WriteToLog() + newError("terminating connection to ", c.RemoteAddr()).WriteToLog() //v.SetState(StateTerminated) - v.OnDataInput() - v.OnDataOutput() + c.dataInput.Signal() + c.dataOutput.Signal() - v.closer.Close() - v.sendingWorker.Release() - v.receivingWorker.Release() + c.closer.Close() + c.sendingWorker.Release() + c.receivingWorker.Release() } -func (v *Connection) HandleOption(opt SegmentOption) { +func (c *Connection) HandleOption(opt SegmentOption) { if (opt & SegmentOptionClose) == SegmentOptionClose { - v.OnPeerClosed() + c.OnPeerClosed() } } -func (v *Connection) OnPeerClosed() { - state := v.State() - if state == StateReadyToClose { - v.SetState(StateTerminating) - } - if state == StateActive { - v.SetState(StatePeerClosed) +func (c *Connection) OnPeerClosed() { + switch c.State() { + case StateReadyToClose: + c.SetState(StateTerminating) + case StateActive: + c.SetState(StatePeerClosed) } } // Input when you received a low level packet (eg. UDP packet), call it -func (v *Connection) Input(segments []Segment) { - current := v.Elapsed() - atomic.StoreUint32(&v.lastIncomingTime, current) +func (c *Connection) Input(segments []Segment) { + current := c.Elapsed() + atomic.StoreUint32(&c.lastIncomingTime, current) for _, seg := range segments { - if seg.Conversation() != v.meta.Conversation { + if seg.Conversation() != c.meta.Conversation { break } switch seg := seg.(type) { case *DataSegment: - v.HandleOption(seg.Option) - v.receivingWorker.ProcessSegment(seg) - if v.receivingWorker.IsDataAvailable() { - v.OnDataInput() + c.HandleOption(seg.Option) + c.receivingWorker.ProcessSegment(seg) + if c.receivingWorker.IsDataAvailable() { + c.dataInput.Signal() } - v.dataUpdater.WakeUp() + c.dataUpdater.WakeUp() case *AckSegment: - v.HandleOption(seg.Option) - v.sendingWorker.ProcessSegment(current, seg, v.roundTrip.Timeout()) - v.OnDataOutput() - v.dataUpdater.WakeUp() + c.HandleOption(seg.Option) + c.sendingWorker.ProcessSegment(current, seg, c.roundTrip.Timeout()) + c.dataOutput.Signal() + c.dataUpdater.WakeUp() case *CmdOnlySegment: - v.HandleOption(seg.Option) + c.HandleOption(seg.Option) if seg.Command() == CommandTerminate { - state := v.State() - if state == StateActive || - state == StatePeerClosed { - v.SetState(StatePeerTerminating) - } else if state == StateReadyToClose { - v.SetState(StateTerminating) - } else if state == StateTerminating { - v.SetState(StateTerminated) + switch c.State() { + case StateActive, StatePeerClosed: + c.SetState(StatePeerTerminating) + case StateReadyToClose: + c.SetState(StateTerminating) + case StateTerminating: + c.SetState(StateTerminated) } } if seg.Option == SegmentOptionClose || seg.Command() == CommandTerminate { - v.OnDataInput() - v.OnDataOutput() + c.dataInput.Signal() + c.dataOutput.Signal() } - v.sendingWorker.ProcessReceivingNext(seg.ReceivinNext) - v.receivingWorker.ProcessSendingNext(seg.SendingNext) - v.roundTrip.UpdatePeerRTO(seg.PeerRTO, current) + c.sendingWorker.ProcessReceivingNext(seg.ReceivinNext) + c.receivingWorker.ProcessSendingNext(seg.SendingNext) + c.roundTrip.UpdatePeerRTO(seg.PeerRTO, current) seg.Release() default: } } } -func (v *Connection) flush() { - current := v.Elapsed() +func (c *Connection) flush() { + current := c.Elapsed() - if v.State() == StateTerminated { + if c.State() == StateTerminated { return } - if v.State() == StateActive && current-atomic.LoadUint32(&v.lastIncomingTime) >= 30000 { - v.Close() + if c.State() == StateActive && current-atomic.LoadUint32(&c.lastIncomingTime) >= 30000 { + c.Close() } - if v.State() == StateReadyToClose && v.sendingWorker.IsEmpty() { - v.SetState(StateTerminating) + if c.State() == StateReadyToClose && c.sendingWorker.IsEmpty() { + c.SetState(StateTerminating) } - if v.State() == StateTerminating { - newError("#", v.meta.Conversation, " sending terminating cmd.").AtDebug().WriteToLog() - v.Ping(current, CommandTerminate) + if c.State() == StateTerminating { + newError("#", c.meta.Conversation, " sending terminating cmd.").AtDebug().WriteToLog() + c.Ping(current, CommandTerminate) - if current-atomic.LoadUint32(&v.stateBeginTime) > 8000 { - v.SetState(StateTerminated) + if current-atomic.LoadUint32(&c.stateBeginTime) > 8000 { + c.SetState(StateTerminated) } return } - if v.State() == StatePeerTerminating && current-atomic.LoadUint32(&v.stateBeginTime) > 4000 { - v.SetState(StateTerminating) + if c.State() == StatePeerTerminating && current-atomic.LoadUint32(&c.stateBeginTime) > 4000 { + c.SetState(StateTerminating) } - if v.State() == StateReadyToClose && current-atomic.LoadUint32(&v.stateBeginTime) > 15000 { - v.SetState(StateTerminating) + if c.State() == StateReadyToClose && current-atomic.LoadUint32(&c.stateBeginTime) > 15000 { + c.SetState(StateTerminating) } // flush acknowledges - v.receivingWorker.Flush(current) - v.sendingWorker.Flush(current) + c.receivingWorker.Flush(current) + c.sendingWorker.Flush(current) - if current-atomic.LoadUint32(&v.lastPingTime) >= 3000 { - v.Ping(current, CommandPing) + if current-atomic.LoadUint32(&c.lastPingTime) >= 3000 { + c.Ping(current, CommandPing) } } -func (v *Connection) State() State { - return State(atomic.LoadInt32((*int32)(&v.state))) +func (c *Connection) State() State { + return State(atomic.LoadInt32((*int32)(&c.state))) } -func (v *Connection) Ping(current uint32, cmd Command) { +func (c *Connection) Ping(current uint32, cmd Command) { seg := NewCmdOnlySegment() - seg.Conv = v.meta.Conversation + seg.Conv = c.meta.Conversation seg.Cmd = cmd - seg.ReceivinNext = v.receivingWorker.NextNumber() - seg.SendingNext = v.sendingWorker.FirstUnacknowledged() - seg.PeerRTO = v.roundTrip.Timeout() - if v.State() == StateReadyToClose { + seg.ReceivinNext = c.receivingWorker.NextNumber() + seg.SendingNext = c.sendingWorker.FirstUnacknowledged() + seg.PeerRTO = c.roundTrip.Timeout() + if c.State() == StateReadyToClose { seg.Option = SegmentOptionClose } - v.output.Write(seg) - atomic.StoreUint32(&v.lastPingTime, current) + c.output.Write(seg) + atomic.StoreUint32(&c.lastPingTime, current) seg.Release() } diff --git a/transport/ray/direct.go b/transport/ray/direct.go index 2df368b49..efcb71cd4 100644 --- a/transport/ray/direct.go +++ b/transport/ray/direct.go @@ -9,6 +9,7 @@ import ( "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/platform" + "v2ray.com/core/common/signal" ) // NewRay creates a new Ray for direct traffic transport. @@ -57,8 +58,8 @@ type Stream struct { data buf.MultiBuffer size uint64 ctx context.Context - readSignal chan bool - writeSignal chan bool + readSignal *signal.Notifier + writeSignal *signal.Notifier close bool err bool } @@ -67,8 +68,8 @@ type Stream struct { func NewStream(ctx context.Context) *Stream { return &Stream{ ctx: ctx, - readSignal: make(chan bool, 1), - writeSignal: make(chan bool, 1), + readSignal: signal.NewNotifier(), + writeSignal: signal.NewNotifier(), size: 0, } } @@ -105,7 +106,7 @@ func (s *Stream) Peek(b *buf.Buffer) { })) } -// Read reads data from the Stream. +// ReadMultiBuffer reads data from the Stream. func (s *Stream) ReadMultiBuffer() (buf.MultiBuffer, error) { for { mb, err := s.getData() @@ -114,14 +115,14 @@ func (s *Stream) ReadMultiBuffer() (buf.MultiBuffer, error) { } if mb != nil { - s.notifyRead() + s.readSignal.Signal() return mb, nil } select { case <-s.ctx.Done(): return nil, io.EOF - case <-s.writeSignal: + case <-s.writeSignal.Wait(): } } } @@ -135,7 +136,7 @@ func (s *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) { } if mb != nil { - s.notifyRead() + s.readSignal.Signal() return mb, nil } @@ -144,7 +145,7 @@ func (s *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) { return nil, io.EOF case <-time.After(timeout): return nil, buf.ErrReadTimeout - case <-s.writeSignal: + case <-s.writeSignal.Wait(): } } } @@ -167,7 +168,7 @@ func (s *Stream) waitForStreamSize() error { select { case <-s.ctx.Done(): return io.ErrClosedPipe - case <-s.readSignal: + case <-s.readSignal.Wait(): if s.err || s.close { return io.ErrClosedPipe } @@ -177,7 +178,7 @@ func (s *Stream) waitForStreamSize() error { return nil } -// Write writes more data into the Stream. +// WriteMultiBuffer writes more data into the Stream. func (s *Stream) WriteMultiBuffer(data buf.MultiBuffer) error { if data.IsEmpty() { return nil @@ -202,31 +203,17 @@ func (s *Stream) WriteMultiBuffer(data buf.MultiBuffer) error { s.data.AppendMulti(data) s.size += uint64(data.Len()) - s.notifyWrite() + s.writeSignal.Signal() return nil } -func (s *Stream) notifyRead() { - select { - case s.readSignal <- true: - default: - } -} - -func (s *Stream) notifyWrite() { - select { - case s.writeSignal <- true: - default: - } -} - // Close closes the stream for writing. Read() still works until EOF. func (s *Stream) Close() { s.access.Lock() s.close = true - s.notifyRead() - s.notifyWrite() + s.readSignal.Signal() + s.writeSignal.Signal() s.access.Unlock() } @@ -239,7 +226,7 @@ func (s *Stream) CloseError() { s.data = nil s.size = 0 } - s.notifyRead() - s.notifyWrite() + s.readSignal.Signal() + s.writeSignal.Signal() s.access.Unlock() }