From 427679e66d54db4d2b83c7589816e303c6466035 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Thu, 6 Dec 2018 11:35:02 +0100 Subject: [PATCH] simplify task execution --- common/task/common.go | 3 +- common/task/task.go | 110 ++----------------------------- common/task/task_test.go | 34 ++++------ proxy/dokodemo/dokodemo.go | 5 +- proxy/freedom/freedom.go | 2 +- proxy/http/server.go | 6 +- proxy/mtproto/client.go | 4 +- proxy/mtproto/server.go | 4 +- proxy/shadowsocks/client.go | 8 +-- proxy/shadowsocks/server.go | 4 +- proxy/socks/client.go | 4 +- proxy/socks/server.go | 4 +- proxy/vmess/inbound/inbound.go | 4 +- proxy/vmess/outbound/outbound.go | 4 +- testing/servers/tcp/tcp.go | 4 +- 15 files changed, 46 insertions(+), 154 deletions(-) diff --git a/common/task/common.go b/common/task/common.go index 95d5d9431..42d3b8abe 100644 --- a/common/task/common.go +++ b/common/task/common.go @@ -2,7 +2,8 @@ package task import "v2ray.com/core/common" -func Close(v interface{}) Task { +// Close returns a func() that closes v. +func Close(v interface{}) func() error { return func() error { return common.Close(v) } diff --git a/common/task/task.go b/common/task/task.go index 1b6a53197..1fc52be50 100644 --- a/common/task/task.go +++ b/common/task/task.go @@ -6,121 +6,25 @@ import ( "v2ray.com/core/common/signal/semaphore" ) -type Task func() error - -type executionContext struct { - ctx context.Context - tasks []Task - onSuccess Task - onFailure Task -} - -func (c *executionContext) executeTask() error { - if len(c.tasks) == 0 { - return nil - } - - // Reuse current goroutine if we only have one task to run. - if len(c.tasks) == 1 && c.ctx == nil { - return c.tasks[0]() - } - - ctx := context.Background() - - if c.ctx != nil { - ctx = c.ctx - } - - return executeParallel(ctx, c.tasks) -} - -func (c *executionContext) run() error { - err := c.executeTask() - if err == nil && c.onSuccess != nil { - return c.onSuccess() - } - if err != nil && c.onFailure != nil { - return c.onFailure() - } - return err -} - -type ExecutionOption func(*executionContext) - -func WithContext(ctx context.Context) ExecutionOption { - return func(c *executionContext) { - c.ctx = ctx - } -} - -func Parallel(tasks ...Task) ExecutionOption { - return func(c *executionContext) { - c.tasks = append(c.tasks, tasks...) - } -} - -// Sequential runs all tasks sequentially, and returns the first error encountered.Sequential -// Once a task returns an error, the following tasks will not run. -func Sequential(tasks ...Task) ExecutionOption { - return func(c *executionContext) { - switch len(tasks) { - case 0: - return - case 1: - c.tasks = append(c.tasks, tasks[0]) - default: - c.tasks = append(c.tasks, func() error { - return execute(tasks...) - }) - } - } -} - -func OnSuccess(task Task) ExecutionOption { - return func(c *executionContext) { - c.onSuccess = task - } -} - -func OnFailure(task Task) ExecutionOption { - return func(c *executionContext) { - c.onFailure = task - } -} - -func Single(task Task, opts ...ExecutionOption) Task { - return Run(append([]ExecutionOption{Sequential(task)}, opts...)...) -} - -func Run(opts ...ExecutionOption) Task { - var c executionContext - for _, opt := range opts { - opt(&c) - } +// OnSuccess executes g() after f() returns nil. +func OnSuccess(f func() error, g func() error) func() error { return func() error { - return c.run() - } -} - -// execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass. -func execute(tasks ...Task) error { - for _, task := range tasks { - if err := task(); err != nil { + if err := f(); err != nil { return err } + return g() } - return nil } -// executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass. -func executeParallel(ctx context.Context, tasks []Task) error { +// Run executes a list of tasks in parallel, returns the first error encountered or nil if all tasks pass. +func Run(ctx context.Context, tasks ...func() error) error { n := len(tasks) s := semaphore.New(n) done := make(chan error, 1) for _, task := range tasks { <-s.Wait() - go func(f Task) { + go func(f func() error) { err := f() if err == nil { s.Signal() diff --git a/common/task/task_test.go b/common/task/task_test.go index 02335ba99..cc9a1963d 100644 --- a/common/task/task_test.go +++ b/common/task/task_test.go @@ -14,13 +14,14 @@ import ( func TestExecuteParallel(t *testing.T) { assert := With(t) - err := Run(Parallel(func() error { - time.Sleep(time.Millisecond * 200) - return errors.New("test") - }, func() error { - time.Sleep(time.Millisecond * 500) - return errors.New("test2") - }))() + err := Run(context.Background(), + func() error { + time.Sleep(time.Millisecond * 200) + return errors.New("test") + }, func() error { + time.Sleep(time.Millisecond * 500) + return errors.New("test2") + }) assert(err.Error(), Equals, "test") } @@ -29,7 +30,7 @@ func TestExecuteParallelContextCancel(t *testing.T) { assert := With(t) ctx, cancel := context.WithCancel(context.Background()) - err := Run(WithContext(ctx), Parallel(func() error { + err := Run(ctx, func() error { time.Sleep(time.Millisecond * 2000) return errors.New("test") }, func() error { @@ -38,7 +39,7 @@ func TestExecuteParallelContextCancel(t *testing.T) { }, func() error { cancel() return nil - }))() + }) assert(err.Error(), HasSubstring, "canceled") } @@ -48,7 +49,7 @@ func BenchmarkExecuteOne(b *testing.B) { return nil } for i := 0; i < b.N; i++ { - common.Must(Run(Parallel(noop))()) + common.Must(Run(context.Background(), noop)) } } @@ -57,17 +58,6 @@ func BenchmarkExecuteTwo(b *testing.B) { return nil } for i := 0; i < b.N; i++ { - common.Must(Run(Parallel(noop, noop))()) - } -} - -func BenchmarkExecuteContext(b *testing.B) { - noop := func() error { - return nil - } - background := context.Background() - - for i := 0; i < b.N; i++ { - common.Must(Run(WithContext(background), Parallel(noop, noop))()) + common.Must(Run(context.Background(), noop, noop)) } } diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index eab0566f5..117a9656d 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -147,10 +147,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in return nil } - if err := task.Run(task.WithContext(ctx), - task.Parallel( - task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))), - responseDone))(); err != nil { + if err := task.Run(ctx, task.OnSuccess(requestDone, task.Close(link.Writer)), responseDone); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index ab7fcf09b..49d05d148 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -167,7 +167,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte return nil } - if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, task.Single(responseDone, task.OnSuccess(task.Close(output)))))(); err != nil { + if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil { return newError("connection ends").Base(err) } diff --git a/proxy/http/server.go b/proxy/http/server.go index 32e2eaadd..1695eb576 100755 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -210,8 +210,8 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade return nil } - var closeWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))) - if err := task.Run(task.WithContext(ctx), task.Parallel(closeWriter, responseDone))(); err != nil { + var closeWriter = task.OnSuccess(requestDone, task.Close(link.Writer)) + if err := task.Run(ctx, closeWriter, responseDone); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) @@ -307,7 +307,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri return nil } - if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDone))(); err != nil { + if err := task.Run(ctx, requestDone, responseDone); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/mtproto/client.go b/proxy/mtproto/client.go index 77d4624ac..6db10e048 100644 --- a/proxy/mtproto/client.go +++ b/proxy/mtproto/client.go @@ -62,8 +62,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter return buf.Copy(connReader, link.Writer) } - var responseDoneAndCloseWriter = task.Single(response, task.OnSuccess(task.Close(link.Writer))) - if err := task.Run(task.WithContext(ctx), task.Parallel(request, responseDoneAndCloseWriter))(); err != nil { + var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer)) + if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil { return newError("connection ends").Base(err) } diff --git a/proxy/mtproto/server.go b/proxy/mtproto/server.go index ed6009e9f..7bc510b92 100644 --- a/proxy/mtproto/server.go +++ b/proxy/mtproto/server.go @@ -141,8 +141,8 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)) } - var responseDoneAndCloseWriter = task.Single(response, task.OnSuccess(task.Close(link.Writer))) - if err := task.Run(task.WithContext(ctx), task.Parallel(request, responseDoneAndCloseWriter))(); err != nil { + var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer)) + if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index bd3a4965d..f6b339747 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -129,8 +129,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter return buf.Copy(responseReader, link.Writer, buf.UpdateActivity(timer)) } - var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer))) - if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil { + var responseDoneAndCloseWriter = task.OnSuccess(responseDone, task.Close(link.Writer)) + if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil { return newError("connection ends").Base(err) } @@ -167,8 +167,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter return nil } - var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer))) - if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil { + var responseDoneAndCloseWriter = task.OnSuccess(responseDone, task.Close(link.Writer)) + if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil { return newError("connection ends").Base(err) } diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 9f53bffc7..57933781c 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -229,8 +229,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, return nil } - var requestDoneAndCloseWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))) - if err := task.Run(task.WithContext(ctx), task.Parallel(requestDoneAndCloseWriter, responseDone))(); err != nil { + var requestDoneAndCloseWriter = task.OnSuccess(requestDone, task.Close(link.Writer)) + if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/socks/client.go b/proxy/socks/client.go index acd2cef36..83dea6294 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -137,8 +137,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter } } - var responseDonePost = task.Single(responseFunc, task.OnSuccess(task.Close(link.Writer))) - if err := task.Run(task.WithContext(ctx), task.Parallel(requestFunc, responseDonePost))(); err != nil { + var responseDonePost = task.OnSuccess(responseFunc, task.Close(link.Writer)) + if err := task.Run(ctx, requestFunc, responseDonePost); err != nil { return newError("connection ends").Base(err) } diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 7674ac870..a557f765b 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -164,8 +164,8 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ return nil } - var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))) - if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil { + var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer)) + if err := task.Run(ctx, requestDonePost, responseDone); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 5d6a05eb6..6bf89cfd0 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -302,8 +302,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i return transferResponse(timer, svrSession, request, response, link.Reader, writer) } - var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))) - if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil { + var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer)) + if err := task.Run(ctx, requestDonePost, responseDone); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 6ec89abeb..4c4c74745 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -161,8 +161,8 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte return buf.Copy(bodyReader, output, buf.UpdateActivity(timer)) } - var responseDonePost = task.Single(responseDone, task.OnSuccess(task.Close(output))) - if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDonePost))(); err != nil { + var responseDonePost = task.OnSuccess(responseDone, task.Close(output)) + if err := task.Run(ctx, requestDone, responseDonePost); err != nil { return newError("connection ends").Base(err) } diff --git a/testing/servers/tcp/tcp.go b/testing/servers/tcp/tcp.go index 7d21fb567..10fc94bee 100644 --- a/testing/servers/tcp/tcp.go +++ b/testing/servers/tcp/tcp.go @@ -64,7 +64,7 @@ func (server *Server) handleConnection(conn net.Conn) { } pReader, pWriter := pipe.New(pipe.WithoutSizeLimit()) - err := task.Run(task.Parallel(func() error { + err := task.Run(context.Background(), func() error { defer pWriter.Close() // nolint: errcheck for { @@ -96,7 +96,7 @@ func (server *Server) handleConnection(conn net.Conn) { return err } } - }))() + }) if err != nil { fmt.Println("failed to transfer data: ", err.Error())