1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-02 15:36:41 -05:00

aggressively close connection when response is done

This commit is contained in:
Darien Raymond 2017-09-27 15:29:00 +02:00
parent 2ce1d8ffa2
commit 109a37fe7e
8 changed files with 59 additions and 12 deletions

View File

@ -54,7 +54,7 @@ func IgnoreWriterError() CopyOption {
} }
} }
func UpdateActivity(timer signal.ActivityTimer) CopyOption { func UpdateActivity(timer signal.ActivityUpdater) CopyOption {
return func(handler *copyHandler) { return func(handler *copyHandler) {
handler.onData = append(handler.onData, func(MultiBuffer) { handler.onData = append(handler.onData, func(MultiBuffer) {
timer.Update() timer.Update()

View File

@ -5,26 +5,30 @@ import (
"time" "time"
) )
type ActivityTimer interface { type ActivityUpdater interface {
Update() Update()
} }
type realActivityTimer struct { type ActivityTimer struct {
updated chan bool updated chan bool
timeout time.Duration timeout chan time.Duration
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
} }
func (t *realActivityTimer) Update() { func (t *ActivityTimer) Update() {
select { select {
case t.updated <- true: case t.updated <- true:
default: default:
} }
} }
func (t *realActivityTimer) run() { func (t *ActivityTimer) SetTimeout(timeout time.Duration) {
ticker := time.NewTicker(t.timeout) t.timeout <- timeout
}
func (t *ActivityTimer) run() {
ticker := time.NewTicker(<-t.timeout)
defer ticker.Stop() defer ticker.Stop()
for { for {
@ -32,6 +36,9 @@ func (t *realActivityTimer) run() {
case <-ticker.C: case <-ticker.C:
case <-t.ctx.Done(): case <-t.ctx.Done():
return return
case timeout := <-t.timeout:
ticker.Stop()
ticker = time.NewTicker(timeout)
} }
select { select {
@ -44,14 +51,15 @@ func (t *realActivityTimer) run() {
} }
} }
func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.Context, ActivityTimer) { func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.Context, *ActivityTimer) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
timer := &realActivityTimer{ timer := &ActivityTimer{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
timeout: timeout, timeout: make(chan time.Duration, 1),
updated: make(chan bool, 1), updated: make(chan bool, 1),
} }
timer.timeout <- timeout
go timer.run() go timer.run()
return ctx, timer return ctx, timer
} }

View File

@ -0,0 +1,32 @@
package signal_test
import (
"context"
"runtime"
"testing"
"time"
. "v2ray.com/core/common/signal"
"v2ray.com/core/testing/assert"
)
func TestActivityTimer(t *testing.T) {
assert := assert.On(t)
ctx, timer := CancelAfterInactivity(context.Background(), time.Second*5)
time.Sleep(time.Second * 6)
assert.Error(ctx.Err()).IsNotNil()
runtime.KeepAlive(timer)
}
func TestActivityTimerUpdate(t *testing.T) {
assert := assert.On(t)
ctx, timer := CancelAfterInactivity(context.Background(), time.Second*10)
time.Sleep(time.Second * 3)
assert.Error(ctx.Err()).IsNil()
timer.SetTimeout(time.Second * 1)
time.Sleep(time.Second * 2)
assert.Error(ctx.Err()).IsNotNil()
runtime.KeepAlive(timer)
}

View File

@ -94,6 +94,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
if err := buf.Copy(inboundRay.InboundOutput(), writer, buf.UpdateActivity(timer)); err != nil { if err := buf.Copy(inboundRay.InboundOutput(), writer, buf.UpdateActivity(timer)); err != nil {
return newError("failed to transport response").Base(err) return newError("failed to transport response").Base(err)
} }
timer.SetTimeout(time.Second * 2)
return nil return nil
}) })

View File

@ -148,6 +148,7 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
if err := buf.Copy(ray.InboundOutput(), v2writer, buf.UpdateActivity(timer)); err != nil { if err := buf.Copy(ray.InboundOutput(), v2writer, buf.UpdateActivity(timer)); err != nil {
return err return err
} }
timer.SetTimeout(time.Second * 2)
return nil return nil
}) })

View File

@ -177,6 +177,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
return newError("failed to transport all TCP response").Base(err) return newError("failed to transport all TCP response").Base(err)
} }
timer.SetTimeout(time.Second * 2)
return nil return nil
}) })

View File

@ -135,6 +135,7 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
if err := buf.Copy(output, v2writer, buf.UpdateActivity(timer)); err != nil { if err := buf.Copy(output, v2writer, buf.UpdateActivity(timer)); err != nil {
return newError("failed to transport all TCP response").Base(err) return newError("failed to transport all TCP response").Base(err)
} }
timer.SetTimeout(time.Second * 2)
return nil return nil
}) })

View File

@ -127,7 +127,7 @@ func (v *Handler) GetUser(email string) *protocol.User {
return user return user
} }
func transferRequest(timer signal.ActivityTimer, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output ray.OutputStream) error { func transferRequest(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output ray.OutputStream) error {
defer output.Close() defer output.Close()
bodyReader := session.DecodeRequestBody(request, input) bodyReader := session.DecodeRequestBody(request, input)
@ -137,7 +137,7 @@ func transferRequest(timer signal.ActivityTimer, session *encoding.ServerSession
return nil return nil
} }
func transferResponse(timer signal.ActivityTimer, session *encoding.ServerSession, request *protocol.RequestHeader, response *protocol.ResponseHeader, input buf.Reader, output io.Writer) error { func transferResponse(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, response *protocol.ResponseHeader, input buf.Reader, output io.Writer) error {
session.EncodeResponseHeader(response, output) session.EncodeResponseHeader(response, output)
bodyWriter := session.EncodeResponseBody(request, output) bodyWriter := session.EncodeResponseBody(request, output)