simplify task execution

This commit is contained in:
Darien Raymond 2018-12-06 11:35:02 +01:00
parent cf1705267e
commit 427679e66d
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
15 changed files with 46 additions and 154 deletions

View File

@ -2,7 +2,8 @@ package task
import "v2ray.com/core/common"
func Close(v interface{}) Task {
// Close returns a func() that closes v.
func Close(v interface{}) func() error {
return func() error {
return common.Close(v)
}

View File

@ -6,121 +6,25 @@ import (
"v2ray.com/core/common/signal/semaphore"
)
type Task func() error
type executionContext struct {
ctx context.Context
tasks []Task
onSuccess Task
onFailure Task
}
func (c *executionContext) executeTask() error {
if len(c.tasks) == 0 {
return nil
}
// Reuse current goroutine if we only have one task to run.
if len(c.tasks) == 1 && c.ctx == nil {
return c.tasks[0]()
}
ctx := context.Background()
if c.ctx != nil {
ctx = c.ctx
}
return executeParallel(ctx, c.tasks)
}
func (c *executionContext) run() error {
err := c.executeTask()
if err == nil && c.onSuccess != nil {
return c.onSuccess()
}
if err != nil && c.onFailure != nil {
return c.onFailure()
}
return err
}
type ExecutionOption func(*executionContext)
func WithContext(ctx context.Context) ExecutionOption {
return func(c *executionContext) {
c.ctx = ctx
}
}
func Parallel(tasks ...Task) ExecutionOption {
return func(c *executionContext) {
c.tasks = append(c.tasks, tasks...)
}
}
// Sequential runs all tasks sequentially, and returns the first error encountered.Sequential
// Once a task returns an error, the following tasks will not run.
func Sequential(tasks ...Task) ExecutionOption {
return func(c *executionContext) {
switch len(tasks) {
case 0:
return
case 1:
c.tasks = append(c.tasks, tasks[0])
default:
c.tasks = append(c.tasks, func() error {
return execute(tasks...)
})
}
}
}
func OnSuccess(task Task) ExecutionOption {
return func(c *executionContext) {
c.onSuccess = task
}
}
func OnFailure(task Task) ExecutionOption {
return func(c *executionContext) {
c.onFailure = task
}
}
func Single(task Task, opts ...ExecutionOption) Task {
return Run(append([]ExecutionOption{Sequential(task)}, opts...)...)
}
func Run(opts ...ExecutionOption) Task {
var c executionContext
for _, opt := range opts {
opt(&c)
}
// OnSuccess executes g() after f() returns nil.
func OnSuccess(f func() error, g func() error) func() error {
return func() error {
return c.run()
}
}
// execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass.
func execute(tasks ...Task) error {
for _, task := range tasks {
if err := task(); err != nil {
if err := f(); err != nil {
return err
}
return g()
}
return nil
}
// executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
func executeParallel(ctx context.Context, tasks []Task) error {
// Run executes a list of tasks in parallel, returns the first error encountered or nil if all tasks pass.
func Run(ctx context.Context, tasks ...func() error) error {
n := len(tasks)
s := semaphore.New(n)
done := make(chan error, 1)
for _, task := range tasks {
<-s.Wait()
go func(f Task) {
go func(f func() error) {
err := f()
if err == nil {
s.Signal()

View File

@ -14,13 +14,14 @@ import (
func TestExecuteParallel(t *testing.T) {
assert := With(t)
err := Run(Parallel(func() error {
time.Sleep(time.Millisecond * 200)
return errors.New("test")
}, func() error {
time.Sleep(time.Millisecond * 500)
return errors.New("test2")
}))()
err := Run(context.Background(),
func() error {
time.Sleep(time.Millisecond * 200)
return errors.New("test")
}, func() error {
time.Sleep(time.Millisecond * 500)
return errors.New("test2")
})
assert(err.Error(), Equals, "test")
}
@ -29,7 +30,7 @@ func TestExecuteParallelContextCancel(t *testing.T) {
assert := With(t)
ctx, cancel := context.WithCancel(context.Background())
err := Run(WithContext(ctx), Parallel(func() error {
err := Run(ctx, func() error {
time.Sleep(time.Millisecond * 2000)
return errors.New("test")
}, func() error {
@ -38,7 +39,7 @@ func TestExecuteParallelContextCancel(t *testing.T) {
}, func() error {
cancel()
return nil
}))()
})
assert(err.Error(), HasSubstring, "canceled")
}
@ -48,7 +49,7 @@ func BenchmarkExecuteOne(b *testing.B) {
return nil
}
for i := 0; i < b.N; i++ {
common.Must(Run(Parallel(noop))())
common.Must(Run(context.Background(), noop))
}
}
@ -57,17 +58,6 @@ func BenchmarkExecuteTwo(b *testing.B) {
return nil
}
for i := 0; i < b.N; i++ {
common.Must(Run(Parallel(noop, noop))())
}
}
func BenchmarkExecuteContext(b *testing.B) {
noop := func() error {
return nil
}
background := context.Background()
for i := 0; i < b.N; i++ {
common.Must(Run(WithContext(background), Parallel(noop, noop))())
common.Must(Run(context.Background(), noop, noop))
}
}

View File

@ -147,10 +147,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
return nil
}
if err := task.Run(task.WithContext(ctx),
task.Parallel(
task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))),
responseDone))(); err != nil {
if err := task.Run(ctx, task.OnSuccess(requestDone, task.Close(link.Writer)), responseDone); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
return newError("connection ends").Base(err)

View File

@ -167,7 +167,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
return nil
}
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, task.Single(responseDone, task.OnSuccess(task.Close(output)))))(); err != nil {
if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil {
return newError("connection ends").Base(err)
}

View File

@ -210,8 +210,8 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
return nil
}
var closeWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(closeWriter, responseDone))(); err != nil {
var closeWriter = task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, closeWriter, responseDone); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
return newError("connection ends").Base(err)
@ -307,7 +307,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
return nil
}
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDone))(); err != nil {
if err := task.Run(ctx, requestDone, responseDone); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
return newError("connection ends").Base(err)

View File

@ -62,8 +62,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return buf.Copy(connReader, link.Writer)
}
var responseDoneAndCloseWriter = task.Single(response, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(request, responseDoneAndCloseWriter))(); err != nil {
var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer))
if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil {
return newError("connection ends").Base(err)
}

View File

@ -141,8 +141,8 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer))
}
var responseDoneAndCloseWriter = task.Single(response, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(request, responseDoneAndCloseWriter))(); err != nil {
var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer))
if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
return newError("connection ends").Base(err)

View File

@ -129,8 +129,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return buf.Copy(responseReader, link.Writer, buf.UpdateActivity(timer))
}
var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil {
var responseDoneAndCloseWriter = task.OnSuccess(responseDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil {
return newError("connection ends").Base(err)
}
@ -167,8 +167,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return nil
}
var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil {
var responseDoneAndCloseWriter = task.OnSuccess(responseDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil {
return newError("connection ends").Base(err)
}

View File

@ -229,8 +229,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
return nil
}
var requestDoneAndCloseWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDoneAndCloseWriter, responseDone))(); err != nil {
var requestDoneAndCloseWriter = task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
return newError("connection ends").Base(err)

View File

@ -137,8 +137,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
}
}
var responseDonePost = task.Single(responseFunc, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestFunc, responseDonePost))(); err != nil {
var responseDonePost = task.OnSuccess(responseFunc, task.Close(link.Writer))
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
return newError("connection ends").Base(err)
}

View File

@ -164,8 +164,8 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
return nil
}
var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil {
var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
return newError("connection ends").Base(err)

View File

@ -302,8 +302,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
return transferResponse(timer, svrSession, request, response, link.Reader, writer)
}
var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil {
var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
return newError("connection ends").Base(err)

View File

@ -161,8 +161,8 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
}
var responseDonePost = task.Single(responseDone, task.OnSuccess(task.Close(output)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDonePost))(); err != nil {
var responseDonePost = task.OnSuccess(responseDone, task.Close(output))
if err := task.Run(ctx, requestDone, responseDonePost); err != nil {
return newError("connection ends").Base(err)
}

View File

@ -64,7 +64,7 @@ func (server *Server) handleConnection(conn net.Conn) {
}
pReader, pWriter := pipe.New(pipe.WithoutSizeLimit())
err := task.Run(task.Parallel(func() error {
err := task.Run(context.Background(), func() error {
defer pWriter.Close() // nolint: errcheck
for {
@ -96,7 +96,7 @@ func (server *Server) handleConnection(conn net.Conn) {
return err
}
}
}))()
})
if err != nil {
fmt.Println("failed to transfer data: ", err.Error())