mirror of
https://github.com/v2fly/v2ray-core.git
synced 2025-05-18 13:09:04 -04:00
simplify task execution
This commit is contained in:
parent
cf1705267e
commit
427679e66d
@ -2,7 +2,8 @@ package task
|
|||||||
|
|
||||||
import "v2ray.com/core/common"
|
import "v2ray.com/core/common"
|
||||||
|
|
||||||
func Close(v interface{}) Task {
|
// Close returns a func() that closes v.
|
||||||
|
func Close(v interface{}) func() error {
|
||||||
return func() error {
|
return func() error {
|
||||||
return common.Close(v)
|
return common.Close(v)
|
||||||
}
|
}
|
||||||
|
@ -6,121 +6,25 @@ import (
|
|||||||
"v2ray.com/core/common/signal/semaphore"
|
"v2ray.com/core/common/signal/semaphore"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Task func() error
|
// OnSuccess executes g() after f() returns nil.
|
||||||
|
func OnSuccess(f func() error, g func() error) func() error {
|
||||||
type executionContext struct {
|
|
||||||
ctx context.Context
|
|
||||||
tasks []Task
|
|
||||||
onSuccess Task
|
|
||||||
onFailure Task
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *executionContext) executeTask() error {
|
|
||||||
if len(c.tasks) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reuse current goroutine if we only have one task to run.
|
|
||||||
if len(c.tasks) == 1 && c.ctx == nil {
|
|
||||||
return c.tasks[0]()
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
if c.ctx != nil {
|
|
||||||
ctx = c.ctx
|
|
||||||
}
|
|
||||||
|
|
||||||
return executeParallel(ctx, c.tasks)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *executionContext) run() error {
|
|
||||||
err := c.executeTask()
|
|
||||||
if err == nil && c.onSuccess != nil {
|
|
||||||
return c.onSuccess()
|
|
||||||
}
|
|
||||||
if err != nil && c.onFailure != nil {
|
|
||||||
return c.onFailure()
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
type ExecutionOption func(*executionContext)
|
|
||||||
|
|
||||||
func WithContext(ctx context.Context) ExecutionOption {
|
|
||||||
return func(c *executionContext) {
|
|
||||||
c.ctx = ctx
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Parallel(tasks ...Task) ExecutionOption {
|
|
||||||
return func(c *executionContext) {
|
|
||||||
c.tasks = append(c.tasks, tasks...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sequential runs all tasks sequentially, and returns the first error encountered.Sequential
|
|
||||||
// Once a task returns an error, the following tasks will not run.
|
|
||||||
func Sequential(tasks ...Task) ExecutionOption {
|
|
||||||
return func(c *executionContext) {
|
|
||||||
switch len(tasks) {
|
|
||||||
case 0:
|
|
||||||
return
|
|
||||||
case 1:
|
|
||||||
c.tasks = append(c.tasks, tasks[0])
|
|
||||||
default:
|
|
||||||
c.tasks = append(c.tasks, func() error {
|
|
||||||
return execute(tasks...)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func OnSuccess(task Task) ExecutionOption {
|
|
||||||
return func(c *executionContext) {
|
|
||||||
c.onSuccess = task
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func OnFailure(task Task) ExecutionOption {
|
|
||||||
return func(c *executionContext) {
|
|
||||||
c.onFailure = task
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Single(task Task, opts ...ExecutionOption) Task {
|
|
||||||
return Run(append([]ExecutionOption{Sequential(task)}, opts...)...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Run(opts ...ExecutionOption) Task {
|
|
||||||
var c executionContext
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(&c)
|
|
||||||
}
|
|
||||||
return func() error {
|
return func() error {
|
||||||
return c.run()
|
if err := f(); err != nil {
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass.
|
|
||||||
func execute(tasks ...Task) error {
|
|
||||||
for _, task := range tasks {
|
|
||||||
if err := task(); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return g()
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
|
// Run executes a list of tasks in parallel, returns the first error encountered or nil if all tasks pass.
|
||||||
func executeParallel(ctx context.Context, tasks []Task) error {
|
func Run(ctx context.Context, tasks ...func() error) error {
|
||||||
n := len(tasks)
|
n := len(tasks)
|
||||||
s := semaphore.New(n)
|
s := semaphore.New(n)
|
||||||
done := make(chan error, 1)
|
done := make(chan error, 1)
|
||||||
|
|
||||||
for _, task := range tasks {
|
for _, task := range tasks {
|
||||||
<-s.Wait()
|
<-s.Wait()
|
||||||
go func(f Task) {
|
go func(f func() error) {
|
||||||
err := f()
|
err := f()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
s.Signal()
|
s.Signal()
|
||||||
|
@ -14,13 +14,14 @@ import (
|
|||||||
func TestExecuteParallel(t *testing.T) {
|
func TestExecuteParallel(t *testing.T) {
|
||||||
assert := With(t)
|
assert := With(t)
|
||||||
|
|
||||||
err := Run(Parallel(func() error {
|
err := Run(context.Background(),
|
||||||
time.Sleep(time.Millisecond * 200)
|
func() error {
|
||||||
return errors.New("test")
|
time.Sleep(time.Millisecond * 200)
|
||||||
}, func() error {
|
return errors.New("test")
|
||||||
time.Sleep(time.Millisecond * 500)
|
}, func() error {
|
||||||
return errors.New("test2")
|
time.Sleep(time.Millisecond * 500)
|
||||||
}))()
|
return errors.New("test2")
|
||||||
|
})
|
||||||
|
|
||||||
assert(err.Error(), Equals, "test")
|
assert(err.Error(), Equals, "test")
|
||||||
}
|
}
|
||||||
@ -29,7 +30,7 @@ func TestExecuteParallelContextCancel(t *testing.T) {
|
|||||||
assert := With(t)
|
assert := With(t)
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
err := Run(WithContext(ctx), Parallel(func() error {
|
err := Run(ctx, func() error {
|
||||||
time.Sleep(time.Millisecond * 2000)
|
time.Sleep(time.Millisecond * 2000)
|
||||||
return errors.New("test")
|
return errors.New("test")
|
||||||
}, func() error {
|
}, func() error {
|
||||||
@ -38,7 +39,7 @@ func TestExecuteParallelContextCancel(t *testing.T) {
|
|||||||
}, func() error {
|
}, func() error {
|
||||||
cancel()
|
cancel()
|
||||||
return nil
|
return nil
|
||||||
}))()
|
})
|
||||||
|
|
||||||
assert(err.Error(), HasSubstring, "canceled")
|
assert(err.Error(), HasSubstring, "canceled")
|
||||||
}
|
}
|
||||||
@ -48,7 +49,7 @@ func BenchmarkExecuteOne(b *testing.B) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
common.Must(Run(Parallel(noop))())
|
common.Must(Run(context.Background(), noop))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,17 +58,6 @@ func BenchmarkExecuteTwo(b *testing.B) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
common.Must(Run(Parallel(noop, noop))())
|
common.Must(Run(context.Background(), noop, noop))
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkExecuteContext(b *testing.B) {
|
|
||||||
noop := func() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
background := context.Background()
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
common.Must(Run(WithContext(background), Parallel(noop, noop))())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -147,10 +147,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := task.Run(task.WithContext(ctx),
|
if err := task.Run(ctx, task.OnSuccess(requestDone, task.Close(link.Writer)), responseDone); err != nil {
|
||||||
task.Parallel(
|
|
||||||
task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))),
|
|
||||||
responseDone))(); err != nil {
|
|
||||||
pipe.CloseError(link.Reader)
|
pipe.CloseError(link.Reader)
|
||||||
pipe.CloseError(link.Writer)
|
pipe.CloseError(link.Writer)
|
||||||
return newError("connection ends").Base(err)
|
return newError("connection ends").Base(err)
|
||||||
|
@ -167,7 +167,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, task.Single(responseDone, task.OnSuccess(task.Close(output)))))(); err != nil {
|
if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil {
|
||||||
return newError("connection ends").Base(err)
|
return newError("connection ends").Base(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,8 +210,8 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var closeWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
|
var closeWriter = task.OnSuccess(requestDone, task.Close(link.Writer))
|
||||||
if err := task.Run(task.WithContext(ctx), task.Parallel(closeWriter, responseDone))(); err != nil {
|
if err := task.Run(ctx, closeWriter, responseDone); err != nil {
|
||||||
pipe.CloseError(link.Reader)
|
pipe.CloseError(link.Reader)
|
||||||
pipe.CloseError(link.Writer)
|
pipe.CloseError(link.Writer)
|
||||||
return newError("connection ends").Base(err)
|
return newError("connection ends").Base(err)
|
||||||
@ -307,7 +307,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDone))(); err != nil {
|
if err := task.Run(ctx, requestDone, responseDone); err != nil {
|
||||||
pipe.CloseError(link.Reader)
|
pipe.CloseError(link.Reader)
|
||||||
pipe.CloseError(link.Writer)
|
pipe.CloseError(link.Writer)
|
||||||
return newError("connection ends").Base(err)
|
return newError("connection ends").Base(err)
|
||||||
|
@ -62,8 +62,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
|
|||||||
return buf.Copy(connReader, link.Writer)
|
return buf.Copy(connReader, link.Writer)
|
||||||
}
|
}
|
||||||
|
|
||||||
var responseDoneAndCloseWriter = task.Single(response, task.OnSuccess(task.Close(link.Writer)))
|
var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer))
|
||||||
if err := task.Run(task.WithContext(ctx), task.Parallel(request, responseDoneAndCloseWriter))(); err != nil {
|
if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil {
|
||||||
return newError("connection ends").Base(err)
|
return newError("connection ends").Base(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,8 +141,8 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
|
|||||||
return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer))
|
return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer))
|
||||||
}
|
}
|
||||||
|
|
||||||
var responseDoneAndCloseWriter = task.Single(response, task.OnSuccess(task.Close(link.Writer)))
|
var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer))
|
||||||
if err := task.Run(task.WithContext(ctx), task.Parallel(request, responseDoneAndCloseWriter))(); err != nil {
|
if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil {
|
||||||
pipe.CloseError(link.Reader)
|
pipe.CloseError(link.Reader)
|
||||||
pipe.CloseError(link.Writer)
|
pipe.CloseError(link.Writer)
|
||||||
return newError("connection ends").Base(err)
|
return newError("connection ends").Base(err)
|
||||||
|
@ -129,8 +129,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
|
|||||||
return buf.Copy(responseReader, link.Writer, buf.UpdateActivity(timer))
|
return buf.Copy(responseReader, link.Writer, buf.UpdateActivity(timer))
|
||||||
}
|
}
|
||||||
|
|
||||||
var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer)))
|
var responseDoneAndCloseWriter = task.OnSuccess(responseDone, task.Close(link.Writer))
|
||||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil {
|
if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil {
|
||||||
return newError("connection ends").Base(err)
|
return newError("connection ends").Base(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -167,8 +167,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer)))
|
var responseDoneAndCloseWriter = task.OnSuccess(responseDone, task.Close(link.Writer))
|
||||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil {
|
if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil {
|
||||||
return newError("connection ends").Base(err)
|
return newError("connection ends").Base(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -229,8 +229,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestDoneAndCloseWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
|
var requestDoneAndCloseWriter = task.OnSuccess(requestDone, task.Close(link.Writer))
|
||||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDoneAndCloseWriter, responseDone))(); err != nil {
|
if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil {
|
||||||
pipe.CloseError(link.Reader)
|
pipe.CloseError(link.Reader)
|
||||||
pipe.CloseError(link.Writer)
|
pipe.CloseError(link.Writer)
|
||||||
return newError("connection ends").Base(err)
|
return newError("connection ends").Base(err)
|
||||||
|
@ -137,8 +137,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var responseDonePost = task.Single(responseFunc, task.OnSuccess(task.Close(link.Writer)))
|
var responseDonePost = task.OnSuccess(responseFunc, task.Close(link.Writer))
|
||||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestFunc, responseDonePost))(); err != nil {
|
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
|
||||||
return newError("connection ends").Base(err)
|
return newError("connection ends").Base(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -164,8 +164,8 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
|
var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
|
||||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil {
|
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
|
||||||
pipe.CloseError(link.Reader)
|
pipe.CloseError(link.Reader)
|
||||||
pipe.CloseError(link.Writer)
|
pipe.CloseError(link.Writer)
|
||||||
return newError("connection ends").Base(err)
|
return newError("connection ends").Base(err)
|
||||||
|
@ -302,8 +302,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
|
|||||||
return transferResponse(timer, svrSession, request, response, link.Reader, writer)
|
return transferResponse(timer, svrSession, request, response, link.Reader, writer)
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
|
var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
|
||||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil {
|
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
|
||||||
pipe.CloseError(link.Reader)
|
pipe.CloseError(link.Reader)
|
||||||
pipe.CloseError(link.Writer)
|
pipe.CloseError(link.Writer)
|
||||||
return newError("connection ends").Base(err)
|
return newError("connection ends").Base(err)
|
||||||
|
@ -161,8 +161,8 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
|||||||
return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
|
return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
|
||||||
}
|
}
|
||||||
|
|
||||||
var responseDonePost = task.Single(responseDone, task.OnSuccess(task.Close(output)))
|
var responseDonePost = task.OnSuccess(responseDone, task.Close(output))
|
||||||
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDonePost))(); err != nil {
|
if err := task.Run(ctx, requestDone, responseDonePost); err != nil {
|
||||||
return newError("connection ends").Base(err)
|
return newError("connection ends").Base(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ func (server *Server) handleConnection(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pReader, pWriter := pipe.New(pipe.WithoutSizeLimit())
|
pReader, pWriter := pipe.New(pipe.WithoutSizeLimit())
|
||||||
err := task.Run(task.Parallel(func() error {
|
err := task.Run(context.Background(), func() error {
|
||||||
defer pWriter.Close() // nolint: errcheck
|
defer pWriter.Close() // nolint: errcheck
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -96,7 +96,7 @@ func (server *Server) handleConnection(conn net.Conn) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}))()
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("failed to transfer data: ", err.Error())
|
fmt.Println("failed to transfer data: ", err.Error())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user