From 13f3c356caebedc4300fce81470fbf6a1121aa5c Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sun, 27 May 2018 13:02:29 +0200 Subject: [PATCH] unified task package --- app/dns/nameserver.go | 6 +- app/dns/server.go | 6 +- app/proxyman/inbound/dynamic.go | 6 +- common/functions/functions.go | 23 --- common/task/common.go | 9 ++ common/{signal/task.go => task/periodic.go} | 12 +- .../task_test.go => task/periodic_test.go} | 9 +- common/task/task.go | 137 ++++++++++++++++++ common/task/task_test.go | 43 ++++++ proxy/dokodemo/dokodemo.go | 7 +- proxy/freedom/freedom.go | 4 +- proxy/http/client.go | 2 + proxy/http/server.go | 6 +- proxy/shadowsocks/client.go | 6 +- proxy/shadowsocks/server.go | 6 +- proxy/socks/client.go | 6 +- proxy/socks/server.go | 6 +- proxy/vmess/encoding/server.go | 6 +- proxy/vmess/inbound/inbound.go | 6 +- proxy/vmess/outbound/outbound.go | 6 +- proxy/vmess/vmess.go | 6 +- 21 files changed, 252 insertions(+), 66 deletions(-) delete mode 100644 common/functions/functions.go create mode 100644 common/task/common.go rename common/{signal/task.go => task/periodic.go} (81%) rename common/{signal/task_test.go => task/periodic_test.go} (84%) create mode 100644 common/task/task.go create mode 100644 common/task/task_test.go diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index 1d7667124..da7a11ac4 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -11,7 +11,7 @@ import ( "v2ray.com/core/common/buf" "v2ray.com/core/common/dice" "v2ray.com/core/common/net" - "v2ray.com/core/common/signal" + "v2ray.com/core/common/task" "v2ray.com/core/transport/internet/udp" ) @@ -42,7 +42,7 @@ type UDPNameServer struct { address net.Destination requests map[uint16]*PendingRequest udpServer *udp.Dispatcher - cleanup *signal.PeriodicTask + cleanup *task.Periodic } func NewUDPNameServer(address net.Destination, dispatcher core.Dispatcher) *UDPNameServer { @@ -51,7 +51,7 @@ func NewUDPNameServer(address net.Destination, dispatcher core.Dispatcher) *UDPN requests: make(map[uint16]*PendingRequest), udpServer: udp.NewDispatcher(dispatcher), } - s.cleanup = &signal.PeriodicTask{ + s.cleanup = &task.Periodic{ Interval: time.Minute, Execute: s.Cleanup, } diff --git a/app/dns/server.go b/app/dns/server.go index dbe879baf..fe55c7f7a 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -11,7 +11,7 @@ import ( "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/net" - "v2ray.com/core/common/signal" + "v2ray.com/core/common/task" ) const ( @@ -33,7 +33,7 @@ type Server struct { hosts map[string]net.IP records map[string]*DomainRecord servers []NameServer - task *signal.PeriodicTask + task *task.Periodic } func New(ctx context.Context, config *Config) (*Server, error) { @@ -42,7 +42,7 @@ func New(ctx context.Context, config *Config) (*Server, error) { servers: make([]NameServer, len(config.NameServers)), hosts: config.GetInternalHosts(), } - server.task = &signal.PeriodicTask{ + server.task = &task.Periodic{ Interval: time.Minute * 10, Execute: func() error { server.cleanup() diff --git a/app/proxyman/inbound/dynamic.go b/app/proxyman/inbound/dynamic.go index 957279f31..49a4d76e7 100644 --- a/app/proxyman/inbound/dynamic.go +++ b/app/proxyman/inbound/dynamic.go @@ -10,7 +10,7 @@ import ( "v2ray.com/core/app/proxyman/mux" "v2ray.com/core/common/dice" "v2ray.com/core/common/net" - "v2ray.com/core/common/signal" + "v2ray.com/core/common/task" "v2ray.com/core/proxy" ) @@ -25,7 +25,7 @@ type DynamicInboundHandler struct { worker []worker lastRefresh time.Time mux *mux.Server - task *signal.PeriodicTask + task *task.Periodic } func NewDynamicInboundHandler(ctx context.Context, tag string, receiverConfig *proxyman.ReceiverConfig, proxyConfig interface{}) (*DynamicInboundHandler, error) { @@ -39,7 +39,7 @@ func NewDynamicInboundHandler(ctx context.Context, tag string, receiverConfig *p v: v, } - h.task = &signal.PeriodicTask{ + h.task = &task.Periodic{ Interval: time.Minute * time.Duration(h.receiverConfig.AllocationStrategy.GetRefreshValue()), Execute: h.refresh, } diff --git a/common/functions/functions.go b/common/functions/functions.go deleted file mode 100644 index 1b790baef..000000000 --- a/common/functions/functions.go +++ /dev/null @@ -1,23 +0,0 @@ -package functions - -import "v2ray.com/core/common" - -// Task is a function that may return an error. -type Task func() error - -// OnSuccess returns a Task to run a follow task if pre-condition passes, otherwise the error in pre-condition is returned. -func OnSuccess(pre func() error, followup Task) Task { - return func() error { - if err := pre(); err != nil { - return err - } - return followup() - } -} - -// Close returns a Task to close the object. -func Close(obj interface{}) Task { - return func() error { - return common.Close(obj) - } -} diff --git a/common/task/common.go b/common/task/common.go new file mode 100644 index 000000000..95d5d9431 --- /dev/null +++ b/common/task/common.go @@ -0,0 +1,9 @@ +package task + +import "v2ray.com/core/common" + +func Close(v interface{}) Task { + return func() error { + return common.Close(v) + } +} diff --git a/common/signal/task.go b/common/task/periodic.go similarity index 81% rename from common/signal/task.go rename to common/task/periodic.go index 11f98c4af..00c491e6e 100644 --- a/common/signal/task.go +++ b/common/task/periodic.go @@ -1,12 +1,12 @@ -package signal +package task import ( "sync" "time" ) -// PeriodicTask is a task that runs periodically. -type PeriodicTask struct { +// Periodic is a task that runs periodically. +type Periodic struct { // Interval of the task being run Interval time.Duration // Execute is the task function @@ -19,7 +19,7 @@ type PeriodicTask struct { closed bool } -func (t *PeriodicTask) checkedExecute() error { +func (t *Periodic) checkedExecute() error { t.access.Lock() defer t.access.Unlock() @@ -41,7 +41,7 @@ func (t *PeriodicTask) checkedExecute() error { } // Start implements common.Runnable. Start must not be called multiple times without Close being called. -func (t *PeriodicTask) Start() error { +func (t *Periodic) Start() error { t.access.Lock() t.closed = false t.access.Unlock() @@ -55,7 +55,7 @@ func (t *PeriodicTask) Start() error { } // Close implements common.Closable. -func (t *PeriodicTask) Close() error { +func (t *Periodic) Close() error { t.access.Lock() defer t.access.Unlock() diff --git a/common/signal/task_test.go b/common/task/periodic_test.go similarity index 84% rename from common/signal/task_test.go rename to common/task/periodic_test.go index 75be9742b..1abfa1b95 100644 --- a/common/signal/task_test.go +++ b/common/task/periodic_test.go @@ -1,19 +1,20 @@ -package signal_test +package task_test import ( "testing" "time" - "v2ray.com/core/common" - . "v2ray.com/core/common/signal" + . "v2ray.com/core/common/task" . "v2ray.com/ext/assert" + + "v2ray.com/core/common" ) func TestPeriodicTaskStop(t *testing.T) { assert := With(t) value := 0 - task := &PeriodicTask{ + task := &Periodic{ Interval: time.Second * 2, Execute: func() error { value++ diff --git a/common/task/task.go b/common/task/task.go new file mode 100644 index 000000000..96307f079 --- /dev/null +++ b/common/task/task.go @@ -0,0 +1,137 @@ +package task + +import ( + "context" + + "v2ray.com/core/common/signal" +) + +type Task func() error + +type executionContext struct { + ctx context.Context + task Task + onSuccess Task + onFailure Task +} + +func (c *executionContext) executeTask() error { + if c.ctx == nil && c.task == nil { + return nil + } + + if c.ctx == nil { + return c.task() + } + + if c.task == nil { + <-c.ctx.Done() + return c.ctx.Err() + } + + return executeParallel(func() error { + <-c.ctx.Done() + return c.ctx.Err() + }, c.task) +} + +func (c *executionContext) run() error { + err := c.executeTask() + if err == nil && c.onSuccess != nil { + return c.onSuccess() + } + if err != nil && c.onFailure != nil { + return c.onFailure() + } + return err +} + +type ExecutionOption func(*executionContext) + +func WithContext(ctx context.Context) ExecutionOption { + return func(c *executionContext) { + c.ctx = ctx + } +} + +func Parallel(tasks ...Task) ExecutionOption { + return func(c *executionContext) { + c.task = func() error { + return executeParallel(tasks...) + } + } +} + +func Sequential(tasks ...Task) ExecutionOption { + return func(c *executionContext) { + c.task = func() error { + return execute(tasks...) + } + } +} + +func OnSuccess(task Task) ExecutionOption { + return func(c *executionContext) { + c.onSuccess = task + } +} + +func OnFailure(task Task) ExecutionOption { + return func(c *executionContext) { + c.onFailure = task + } +} + +func Single(task Task, opts ExecutionOption) Task { + return Run(append([]ExecutionOption{Sequential(task)}, opts)...) +} + +func Run(opts ...ExecutionOption) Task { + var c executionContext + for _, opt := range opts { + opt(&c) + } + return func() error { + return c.run() + } +} + +// execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass. +func execute(tasks ...Task) error { + for _, task := range tasks { + if err := task(); err != nil { + return err + } + } + return nil +} + +// executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass. +func executeParallel(tasks ...Task) error { + n := len(tasks) + s := signal.NewSemaphore(n) + done := make(chan error, 1) + + for _, task := range tasks { + <-s.Wait() + go func(f func() error) { + if err := f(); err != nil { + select { + case done <- err: + default: + } + } + s.Signal() + }(task) + } + + for i := 0; i < n; i++ { + select { + case err := <-done: + return err + case <-s.Wait(): + } + } + + return nil +} diff --git a/common/task/task_test.go b/common/task/task_test.go new file mode 100644 index 000000000..886564106 --- /dev/null +++ b/common/task/task_test.go @@ -0,0 +1,43 @@ +package task_test + +import ( + "context" + "errors" + "testing" + "time" + + . "v2ray.com/core/common/task" + . "v2ray.com/ext/assert" +) + +func TestExecuteParallel(t *testing.T) { + assert := With(t) + + err := Run(Parallel(func() error { + time.Sleep(time.Millisecond * 200) + return errors.New("test") + }, func() error { + time.Sleep(time.Millisecond * 500) + return errors.New("test2") + }))() + + assert(err.Error(), Equals, "test") +} + +func TestExecuteParallelContextCancel(t *testing.T) { + assert := With(t) + + ctx, cancel := context.WithCancel(context.Background()) + err := Run(WithContext(ctx), Parallel(func() error { + time.Sleep(time.Millisecond * 2000) + return errors.New("test") + }, func() error { + time.Sleep(time.Millisecond * 5000) + return errors.New("test2") + }, func() error { + cancel() + return nil + }))() + + assert(err.Error(), HasSubstring, "canceled") +} diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index cba00c2c1..40bf3289d 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -9,9 +9,9 @@ import ( "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/buf" - "v2ray.com/core/common/functions" "v2ray.com/core/common/net" "v2ray.com/core/common/signal" + "v2ray.com/core/common/task" "v2ray.com/core/proxy" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet/udp" @@ -118,7 +118,10 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in return nil } - if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil { + if err := task.Run(task.WithContext(ctx), + task.Parallel( + task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))), + responseDone))(); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index ddf888d5f..91a68e071 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -10,10 +10,10 @@ import ( "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/dice" - "v2ray.com/core/common/functions" "v2ray.com/core/common/net" "v2ray.com/core/common/retry" "v2ray.com/core/common/signal" + "v2ray.com/core/common/task" "v2ray.com/core/proxy" "v2ray.com/core/transport/internet" ) @@ -136,7 +136,7 @@ func (h *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia return nil } - if err := signal.ExecuteParallel(ctx, requestDone, functions.OnSuccess(responseDone, functions.Close(output))); err != nil { + if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, task.Single(responseDone, task.OnSuccess(task.Close(output)))))(); err != nil { return newError("connection ends").Base(err) } diff --git a/proxy/http/client.go b/proxy/http/client.go index e60d9dd9f..347f59429 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -1,4 +1,6 @@ package http +/* type Client struct { } +*/ diff --git a/proxy/http/server.go b/proxy/http/server.go index 4ff9dedfe..d2653f351 100755 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -10,13 +10,14 @@ import ( "strings" "time" + "v2ray.com/core/common/task" + "v2ray.com/core/transport/pipe" "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/errors" - "v2ray.com/core/common/functions" "v2ray.com/core/common/log" "v2ray.com/core/common/net" http_proto "v2ray.com/core/common/protocol/http" @@ -210,7 +211,8 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade return nil } - if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil { + var closeWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))) + if err := task.Run(task.WithContext(ctx), task.Parallel(closeWriter, responseDone))(); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index 84427caba..51eb92e66 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -3,10 +3,11 @@ package shadowsocks import ( "context" + "v2ray.com/core/common/task" + "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/buf" - "v2ray.com/core/common/functions" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" "v2ray.com/core/common/retry" @@ -158,7 +159,8 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial return nil } - if err := signal.ExecuteParallel(ctx, requestDone, functions.OnSuccess(responseDone, functions.Close(link.Writer))); err != nil { + var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer))) + if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil { return newError("connection ends").Base(err) } diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index bc6653f47..a8558fb0d 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -4,10 +4,11 @@ import ( "context" "time" + "v2ray.com/core/common/task" + "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/buf" - "v2ray.com/core/common/functions" "v2ray.com/core/common/log" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" @@ -216,7 +217,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, return nil } - if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil { + var requestDoneAndCloseWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))) + if err := task.Run(task.WithContext(ctx), task.Parallel(requestDoneAndCloseWriter, responseDone))(); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/socks/client.go b/proxy/socks/client.go index 4078345c0..ede7e6b49 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -4,10 +4,11 @@ import ( "context" "time" + "v2ray.com/core/common/task" + "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/buf" - "v2ray.com/core/common/functions" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" "v2ray.com/core/common/retry" @@ -130,7 +131,8 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial } } - if err := signal.ExecuteParallel(ctx, requestFunc, functions.OnSuccess(responseFunc, functions.Close(link.Writer))); err != nil { + var responseDonePost = task.Single(responseFunc, task.OnSuccess(task.Close(link.Writer))) + if err := task.Run(task.WithContext(ctx), task.Parallel(requestFunc, responseDonePost))(); err != nil { return newError("connection ends").Base(err) } diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 928dd820e..bd3b165ca 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -5,10 +5,11 @@ import ( "io" "time" + "v2ray.com/core/common/task" + "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/buf" - "v2ray.com/core/common/functions" "v2ray.com/core/common/log" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" @@ -160,7 +161,8 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ return nil } - if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil { + var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))) + if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 0071f15d7..58677212f 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -19,7 +19,7 @@ import ( "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" "v2ray.com/core/common/serial" - "v2ray.com/core/common/signal" + "v2ray.com/core/common/task" "v2ray.com/core/proxy/vmess" ) @@ -33,7 +33,7 @@ type sessionId struct { type SessionHistory struct { sync.RWMutex cache map[sessionId]time.Time - task *signal.PeriodicTask + task *task.Periodic } // NewSessionHistory creates a new SessionHistory object. @@ -41,7 +41,7 @@ func NewSessionHistory() *SessionHistory { h := &SessionHistory{ cache: make(map[sessionId]time.Time, 128), } - h.task = &signal.PeriodicTask{ + h.task = &task.Periodic{ Interval: time.Second * 30, Execute: func() error { h.removeExpiredEntries() diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index ffc2a3de3..f2890c99b 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -9,11 +9,12 @@ import ( "sync" "time" + "v2ray.com/core/common/task" + "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/errors" - "v2ray.com/core/common/functions" "v2ray.com/core/common/log" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" @@ -294,7 +295,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i return transferResponse(timer, session, request, response, link.Reader, writer) } - if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil { + var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))) + if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil { pipe.CloseError(link.Reader) pipe.CloseError(link.Writer) return newError("connection ends").Base(err) diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 5629cf516..cc4a32f90 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -6,12 +6,13 @@ import ( "context" "time" + "v2ray.com/core/common/task" + "v2ray.com/core/transport/pipe" "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/buf" - "v2ray.com/core/common/functions" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" "v2ray.com/core/common/retry" @@ -161,7 +162,8 @@ func (v *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia return buf.Copy(bodyReader, output, buf.UpdateActivity(timer)) } - if err := signal.ExecuteParallel(ctx, requestDone, functions.OnSuccess(responseDone, functions.Close(output))); err != nil { + var responseDonePost = task.Single(responseDone, task.OnSuccess(task.Close(output))) + if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDonePost))(); err != nil { return newError("connection ends").Base(err) } diff --git a/proxy/vmess/vmess.go b/proxy/vmess/vmess.go index 6efd5719a..0c7038b81 100644 --- a/proxy/vmess/vmess.go +++ b/proxy/vmess/vmess.go @@ -14,7 +14,7 @@ import ( "v2ray.com/core/common" "v2ray.com/core/common/protocol" - "v2ray.com/core/common/signal" + "v2ray.com/core/common/task" ) const ( @@ -34,7 +34,7 @@ type TimedUserValidator struct { userHash map[[16]byte]indexTimePair hasher protocol.IDHash baseTime protocol.Timestamp - task *signal.PeriodicTask + task *task.Periodic } type indexTimePair struct { @@ -49,7 +49,7 @@ func NewTimedUserValidator(hasher protocol.IDHash) *TimedUserValidator { hasher: hasher, baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2), } - tuv.task = &signal.PeriodicTask{ + tuv.task = &task.Periodic{ Interval: updateInterval, Execute: func() error { tuv.updateUserHash()