diff --git a/common/platform/ctlcmd/ctlcmd.go b/common/platform/ctlcmd/ctlcmd.go index 19789d763..28b5fd250 100644 --- a/common/platform/ctlcmd/ctlcmd.go +++ b/common/platform/ctlcmd/ctlcmd.go @@ -39,16 +39,16 @@ func Run(args []string, input io.Reader) (buf.MultiBuffer, error) { } var content buf.MultiBuffer - loadTask := signal.ExecuteAsync(func() error { + loadTask := func() error { c, err := buf.ReadAllToMultiBuffer(stdoutReader) if err != nil { return err } content = c return nil - }) + } - waitTask := signal.ExecuteAsync(func() error { + waitTask := func() error { if err := cmd.Wait(); err != nil { msg := "failed to execute v2ctl" if errBuffer.Len() > 0 { @@ -57,9 +57,9 @@ func Run(args []string, input io.Reader) (buf.MultiBuffer, error) { return newError(msg).Base(err) } return nil - }) + } - if err := signal.ErrorOrFinish2(context.Background(), loadTask, waitTask); err != nil { + if err := signal.ExecuteParallel(context.Background(), loadTask, waitTask); err != nil { return nil, err } diff --git a/common/signal/exec.go b/common/signal/exec.go index 632253cfc..7b2dc7c22 100644 --- a/common/signal/exec.go +++ b/common/signal/exec.go @@ -4,14 +4,6 @@ import ( "context" ) -func executeAndFulfill(f func() error, done chan<- error) { - err := f() - if err != nil { - done <- err - } - close(done) -} - // Execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass. func Execute(tasks ...func() error) error { for _, task := range tasks { @@ -22,35 +14,34 @@ func Execute(tasks ...func() error) error { return nil } -// ExecuteAsync executes a function asynchronously and return its result. -func ExecuteAsync(f func() error) <-chan error { +// ExecuteParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass. +func ExecuteParallel(ctx context.Context, tasks ...func() error) error { + n := len(tasks) + s := NewSemaphore(n) done := make(chan error, 1) - go executeAndFulfill(f, done) - return done -} -func ErrorOrFinish1(ctx context.Context, c <-chan error) error { - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-c: - return err + for _, task := range tasks { + <-s.Wait() + go func(f func() error) { + if err := f(); err != nil { + select { + case done <- err: + default: + } + } + s.Signal() + }(task) } -} -func ErrorOrFinish2(ctx context.Context, c1, c2 <-chan error) error { - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-c1: - if err != nil { + for i := 0; i < n; i++ { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: return err + case <-s.Wait(): } - return ErrorOrFinish1(ctx, c2) - case err := <-c2: - if err != nil { - return err - } - return ErrorOrFinish1(ctx, c1) } + + return nil } diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 7626911dc..b78902a30 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -75,7 +75,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in return newError("failed to dispatch request").Base(err) } - requestDone := signal.ExecuteAsync(func() error { + requestDone := func() error { defer inboundRay.InboundInput().Close() defer timer.SetTimeout(d.policy().Timeouts.DownlinkOnly) @@ -86,9 +86,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in } return nil - }) + } - responseDone := signal.ExecuteAsync(func() error { + responseDone := func() error { defer timer.SetTimeout(d.policy().Timeouts.UplinkOnly) var writer buf.Writer @@ -113,9 +113,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in } return nil - }) + } - if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { inboundRay.InboundInput().CloseError() inboundRay.InboundOutput().CloseError() return newError("connection ends").Base(err) diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index 7868c3c96..466716f31 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -109,7 +109,7 @@ func (h *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, cancel, h.policy().Timeouts.ConnectionIdle) - requestDone := signal.ExecuteAsync(func() error { + requestDone := func() error { defer timer.SetTimeout(h.policy().Timeouts.DownlinkOnly) var writer buf.Writer @@ -123,9 +123,9 @@ func (h *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial } return nil - }) + } - responseDone := signal.ExecuteAsync(func() error { + responseDone := func() error { defer timer.SetTimeout(h.policy().Timeouts.UplinkOnly) v2reader := buf.NewReader(conn) @@ -134,9 +134,9 @@ func (h *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial } return nil - }) + } - if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { input.CloseError() output.CloseError() return newError("connection ends").Base(err) diff --git a/proxy/http/server.go b/proxy/http/server.go index 495a0865f..bbe2dc60d 100755 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -182,15 +182,15 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade reader = nil } - requestDone := signal.ExecuteAsync(func() error { + requestDone := func() error { defer ray.InboundInput().Close() defer timer.SetTimeout(s.policy().Timeouts.DownlinkOnly) v2reader := buf.NewReader(conn) return buf.Copy(v2reader, ray.InboundInput(), buf.UpdateActivity(timer)) - }) + } - responseDone := signal.ExecuteAsync(func() error { + responseDone := func() error { defer timer.SetTimeout(s.policy().Timeouts.UplinkOnly) v2writer := buf.NewWriter(conn) @@ -199,9 +199,9 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade } return nil - }) + } - if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { ray.InboundInput().CloseError() ray.InboundOutput().CloseError() return newError("connection ends").Base(err) @@ -251,7 +251,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri var result error = errWaitAnother - requestDone := signal.ExecuteAsync(func() error { + requestDone := func() error { request.Header.Set("Connection", "close") requestWriter := buf.NewBufferedWriter(ray.InboundInput()) @@ -260,9 +260,9 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri return newError("failed to write whole request").Base(err).AtWarning() } return nil - }) + } - responseDone := signal.ExecuteAsync(func() error { + responseDone := func() error { responseReader := bufio.NewReaderSize(buf.NewBufferedReader(ray.InboundOutput()), buf.Size) response, err := http.ReadResponse(responseReader, request) if err == nil { @@ -296,9 +296,9 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri return newError("failed to write response").Base(err).AtWarning() } return nil - }) + } - if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { input.CloseError() output.CloseError() return newError("connection ends").Base(err) diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index 8cb20d6d8..05f2ad1e0 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -105,12 +105,12 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale return err } - requestDone := signal.ExecuteAsync(func() error { + requestDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) return buf.Copy(outboundRay.OutboundInput(), bodyWriter, buf.UpdateActivity(timer)) - }) + } - responseDone := signal.ExecuteAsync(func() error { + responseDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) responseReader, err := ReadTCPResponse(user, conn) @@ -119,9 +119,9 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale } return buf.Copy(responseReader, outboundRay.OutboundOutput(), buf.UpdateActivity(timer)) - }) + } - if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { return newError("connection ends").Base(err) } @@ -135,16 +135,16 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale Request: request, }) - requestDone := signal.ExecuteAsync(func() error { + requestDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) if err := buf.Copy(outboundRay.OutboundInput(), writer, buf.UpdateActivity(timer)); err != nil { return newError("failed to transport all UDP request").Base(err) } return nil - }) + } - responseDone := signal.ExecuteAsync(func() error { + responseDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) reader := &UDPReader{ @@ -156,9 +156,9 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale return newError("failed to transport all UDP response").Base(err) } return nil - }) + } - if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { return newError("connection ends").Base(err) } diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index a58d19720..a93ab8632 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -172,7 +172,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, return err } - responseDone := signal.ExecuteAsync(func() error { + responseDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn)) @@ -200,9 +200,9 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, } return nil - }) + } - requestDone := signal.ExecuteAsync(func() error { + requestDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) defer ray.InboundInput().Close() @@ -211,9 +211,9 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, } return nil - }) + } - if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { ray.InboundInput().CloseError() ray.InboundOutput().CloseError() return newError("connection ends").Base(err) diff --git a/proxy/socks/client.go b/proxy/socks/client.go index 3958ae9de..9abd15f61 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -130,9 +130,7 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay, dialer proxy. } } - requestDone := signal.ExecuteAsync(requestFunc) - responseDone := signal.ExecuteAsync(responseFunc) - if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, requestFunc, responseFunc); err != nil { return newError("connection ends").Base(err) } diff --git a/proxy/socks/server.go b/proxy/socks/server.go index c89648afd..de3f0b0e4 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -137,7 +137,7 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ input := ray.InboundInput() output := ray.InboundOutput() - requestDone := signal.ExecuteAsync(func() error { + requestDone := func() error { defer timer.SetTimeout(s.policy().Timeouts.DownlinkOnly) defer input.Close() @@ -147,9 +147,9 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ } return nil - }) + } - responseDone := signal.ExecuteAsync(func() error { + responseDone := func() error { defer timer.SetTimeout(s.policy().Timeouts.UplinkOnly) v2writer := buf.NewWriter(writer) @@ -158,9 +158,9 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ } return nil - }) + } - if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { input.CloseError() output.CloseError() return newError("connection ends").Base(err) diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index dafe2bb43..883e6ea63 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -280,12 +280,12 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i input := ray.InboundInput() output := ray.InboundOutput() - requestDone := signal.ExecuteAsync(func() error { + requestDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) return transferRequest(timer, session, request, reader, input) - }) + } - responseDone := signal.ExecuteAsync(func() error { + responseDone := func() error { writer := buf.NewBufferedWriter(buf.NewWriter(connection)) defer writer.Flush() defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) @@ -294,9 +294,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i Command: h.generateCommand(ctx, request), } return transferResponse(timer, session, request, response, output, writer) - }) + } - if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { input.CloseError() output.CloseError() return newError("connection ends").Base(err) diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index a59904d39..54f86588a 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -104,7 +104,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) - requestDone := signal.ExecuteAsync(func() error { + requestDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) writer := buf.NewBufferedWriter(buf.NewWriter(conn)) @@ -140,9 +140,9 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial } return nil - }) + } - responseDone := signal.ExecuteAsync(func() error { + responseDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) reader := buf.NewBufferedReader(buf.NewReader(conn)) @@ -156,9 +156,9 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial bodyReader := session.DecodeResponseBody(request, reader) return buf.Copy(bodyReader, output, buf.UpdateActivity(timer)) - }) + } - if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { + if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil { return newError("connection ends").Base(err) }