diff --git a/common/buf/copy.go b/common/buf/copy.go index 4046bbc34..873e187aa 100644 --- a/common/buf/copy.go +++ b/common/buf/copy.go @@ -54,7 +54,7 @@ func IgnoreWriterError() CopyOption { } } -func UpdateActivity(timer signal.ActivityTimer) CopyOption { +func UpdateActivity(timer signal.ActivityUpdater) CopyOption { return func(handler *copyHandler) { handler.onData = append(handler.onData, func(MultiBuffer) { timer.Update() diff --git a/common/signal/timer.go b/common/signal/timer.go index 44ac8a5c1..e33a43fba 100644 --- a/common/signal/timer.go +++ b/common/signal/timer.go @@ -5,26 +5,30 @@ import ( "time" ) -type ActivityTimer interface { +type ActivityUpdater interface { Update() } -type realActivityTimer struct { +type ActivityTimer struct { updated chan bool - timeout time.Duration + timeout chan time.Duration ctx context.Context cancel context.CancelFunc } -func (t *realActivityTimer) Update() { +func (t *ActivityTimer) Update() { select { case t.updated <- true: default: } } -func (t *realActivityTimer) run() { - ticker := time.NewTicker(t.timeout) +func (t *ActivityTimer) SetTimeout(timeout time.Duration) { + t.timeout <- timeout +} + +func (t *ActivityTimer) run() { + ticker := time.NewTicker(<-t.timeout) defer ticker.Stop() for { @@ -32,6 +36,9 @@ func (t *realActivityTimer) run() { case <-ticker.C: case <-t.ctx.Done(): return + case timeout := <-t.timeout: + ticker.Stop() + ticker = time.NewTicker(timeout) } select { @@ -44,14 +51,15 @@ func (t *realActivityTimer) run() { } } -func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.Context, ActivityTimer) { +func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.Context, *ActivityTimer) { ctx, cancel := context.WithCancel(ctx) - timer := &realActivityTimer{ + timer := &ActivityTimer{ ctx: ctx, cancel: cancel, - timeout: timeout, + timeout: make(chan time.Duration, 1), updated: make(chan bool, 1), } + timer.timeout <- timeout go timer.run() return ctx, timer } diff --git a/common/signal/timer_test.go b/common/signal/timer_test.go new file mode 100644 index 000000000..8c5f57e6b --- /dev/null +++ b/common/signal/timer_test.go @@ -0,0 +1,32 @@ +package signal_test + +import ( + "context" + "runtime" + "testing" + "time" + + . "v2ray.com/core/common/signal" + "v2ray.com/core/testing/assert" +) + +func TestActivityTimer(t *testing.T) { + assert := assert.On(t) + + ctx, timer := CancelAfterInactivity(context.Background(), time.Second*5) + time.Sleep(time.Second * 6) + assert.Error(ctx.Err()).IsNotNil() + runtime.KeepAlive(timer) +} + +func TestActivityTimerUpdate(t *testing.T) { + assert := assert.On(t) + + ctx, timer := CancelAfterInactivity(context.Background(), time.Second*10) + time.Sleep(time.Second * 3) + assert.Error(ctx.Err()).IsNil() + timer.SetTimeout(time.Second * 1) + time.Sleep(time.Second * 2) + assert.Error(ctx.Err()).IsNotNil() + runtime.KeepAlive(timer) +} diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 5b031aa25..33bcb120f 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -94,6 +94,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in if err := buf.Copy(inboundRay.InboundOutput(), writer, buf.UpdateActivity(timer)); err != nil { return newError("failed to transport response").Base(err) } + + timer.SetTimeout(time.Second * 2) + return nil }) diff --git a/proxy/http/server.go b/proxy/http/server.go index 28e423b42..a515e7072 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -148,6 +148,7 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade if err := buf.Copy(ray.InboundOutput(), v2writer, buf.UpdateActivity(timer)); err != nil { return err } + timer.SetTimeout(time.Second * 2) return nil }) diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 62a86e4ff..509291dc3 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -177,6 +177,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, return newError("failed to transport all TCP response").Base(err) } + timer.SetTimeout(time.Second * 2) + return nil }) diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 8103452ff..a7e3a0ca4 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -135,6 +135,7 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ if err := buf.Copy(output, v2writer, buf.UpdateActivity(timer)); err != nil { return newError("failed to transport all TCP response").Base(err) } + timer.SetTimeout(time.Second * 2) return nil }) diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index bbb03b423..9cc3c2753 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -127,7 +127,7 @@ func (v *Handler) GetUser(email string) *protocol.User { return user } -func transferRequest(timer signal.ActivityTimer, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output ray.OutputStream) error { +func transferRequest(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output ray.OutputStream) error { defer output.Close() bodyReader := session.DecodeRequestBody(request, input) @@ -137,7 +137,7 @@ func transferRequest(timer signal.ActivityTimer, session *encoding.ServerSession return nil } -func transferResponse(timer signal.ActivityTimer, session *encoding.ServerSession, request *protocol.RequestHeader, response *protocol.ResponseHeader, input buf.Reader, output io.Writer) error { +func transferResponse(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, response *protocol.ResponseHeader, input buf.Reader, output io.Writer) error { session.EncodeResponseHeader(response, output) bodyWriter := session.EncodeResponseBody(request, output)