1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-02 15:36:41 -05:00

better way to run tasks in parallel

This commit is contained in:
Darien Raymond 2018-04-11 16:45:09 +02:00
parent 9d7f43a299
commit 0caf92726b
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
11 changed files with 79 additions and 90 deletions

View File

@ -39,16 +39,16 @@ func Run(args []string, input io.Reader) (buf.MultiBuffer, error) {
} }
var content buf.MultiBuffer var content buf.MultiBuffer
loadTask := signal.ExecuteAsync(func() error { loadTask := func() error {
c, err := buf.ReadAllToMultiBuffer(stdoutReader) c, err := buf.ReadAllToMultiBuffer(stdoutReader)
if err != nil { if err != nil {
return err return err
} }
content = c content = c
return nil return nil
}) }
waitTask := signal.ExecuteAsync(func() error { waitTask := func() error {
if err := cmd.Wait(); err != nil { if err := cmd.Wait(); err != nil {
msg := "failed to execute v2ctl" msg := "failed to execute v2ctl"
if errBuffer.Len() > 0 { if errBuffer.Len() > 0 {
@ -57,9 +57,9 @@ func Run(args []string, input io.Reader) (buf.MultiBuffer, error) {
return newError(msg).Base(err) return newError(msg).Base(err)
} }
return nil return nil
}) }
if err := signal.ErrorOrFinish2(context.Background(), loadTask, waitTask); err != nil { if err := signal.ExecuteParallel(context.Background(), loadTask, waitTask); err != nil {
return nil, err return nil, err
} }

View File

@ -4,14 +4,6 @@ import (
"context" "context"
) )
func executeAndFulfill(f func() error, done chan<- error) {
err := f()
if err != nil {
done <- err
}
close(done)
}
// Execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass. // Execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass.
func Execute(tasks ...func() error) error { func Execute(tasks ...func() error) error {
for _, task := range tasks { for _, task := range tasks {
@ -22,35 +14,34 @@ func Execute(tasks ...func() error) error {
return nil return nil
} }
// ExecuteAsync executes a function asynchronously and return its result. // ExecuteParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
func ExecuteAsync(f func() error) <-chan error { func ExecuteParallel(ctx context.Context, tasks ...func() error) error {
n := len(tasks)
s := NewSemaphore(n)
done := make(chan error, 1) done := make(chan error, 1)
go executeAndFulfill(f, done)
return done
}
func ErrorOrFinish1(ctx context.Context, c <-chan error) error { for _, task := range tasks {
select { <-s.Wait()
case <-ctx.Done(): go func(f func() error) {
return ctx.Err() if err := f(); err != nil {
case err := <-c: select {
return err case done <- err:
default:
}
}
s.Signal()
}(task)
} }
}
func ErrorOrFinish2(ctx context.Context, c1, c2 <-chan error) error { for i := 0; i < n; i++ {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case err := <-c1: case err := <-done:
if err != nil {
return err return err
case <-s.Wait():
} }
return ErrorOrFinish1(ctx, c2)
case err := <-c2:
if err != nil {
return err
}
return ErrorOrFinish1(ctx, c1)
} }
return nil
} }

View File

@ -75,7 +75,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
return newError("failed to dispatch request").Base(err) return newError("failed to dispatch request").Base(err)
} }
requestDone := signal.ExecuteAsync(func() error { requestDone := func() error {
defer inboundRay.InboundInput().Close() defer inboundRay.InboundInput().Close()
defer timer.SetTimeout(d.policy().Timeouts.DownlinkOnly) defer timer.SetTimeout(d.policy().Timeouts.DownlinkOnly)
@ -86,9 +86,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
} }
return nil return nil
}) }
responseDone := signal.ExecuteAsync(func() error { responseDone := func() error {
defer timer.SetTimeout(d.policy().Timeouts.UplinkOnly) defer timer.SetTimeout(d.policy().Timeouts.UplinkOnly)
var writer buf.Writer var writer buf.Writer
@ -113,9 +113,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
} }
return nil return nil
}) }
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
inboundRay.InboundInput().CloseError() inboundRay.InboundInput().CloseError()
inboundRay.InboundOutput().CloseError() inboundRay.InboundOutput().CloseError()
return newError("connection ends").Base(err) return newError("connection ends").Base(err)

View File

@ -109,7 +109,7 @@ func (h *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, h.policy().Timeouts.ConnectionIdle) timer := signal.CancelAfterInactivity(ctx, cancel, h.policy().Timeouts.ConnectionIdle)
requestDone := signal.ExecuteAsync(func() error { requestDone := func() error {
defer timer.SetTimeout(h.policy().Timeouts.DownlinkOnly) defer timer.SetTimeout(h.policy().Timeouts.DownlinkOnly)
var writer buf.Writer var writer buf.Writer
@ -123,9 +123,9 @@ func (h *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
} }
return nil return nil
}) }
responseDone := signal.ExecuteAsync(func() error { responseDone := func() error {
defer timer.SetTimeout(h.policy().Timeouts.UplinkOnly) defer timer.SetTimeout(h.policy().Timeouts.UplinkOnly)
v2reader := buf.NewReader(conn) v2reader := buf.NewReader(conn)
@ -134,9 +134,9 @@ func (h *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
} }
return nil return nil
}) }
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
input.CloseError() input.CloseError()
output.CloseError() output.CloseError()
return newError("connection ends").Base(err) return newError("connection ends").Base(err)

View File

@ -182,15 +182,15 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
reader = nil reader = nil
} }
requestDone := signal.ExecuteAsync(func() error { requestDone := func() error {
defer ray.InboundInput().Close() defer ray.InboundInput().Close()
defer timer.SetTimeout(s.policy().Timeouts.DownlinkOnly) defer timer.SetTimeout(s.policy().Timeouts.DownlinkOnly)
v2reader := buf.NewReader(conn) v2reader := buf.NewReader(conn)
return buf.Copy(v2reader, ray.InboundInput(), buf.UpdateActivity(timer)) return buf.Copy(v2reader, ray.InboundInput(), buf.UpdateActivity(timer))
}) }
responseDone := signal.ExecuteAsync(func() error { responseDone := func() error {
defer timer.SetTimeout(s.policy().Timeouts.UplinkOnly) defer timer.SetTimeout(s.policy().Timeouts.UplinkOnly)
v2writer := buf.NewWriter(conn) v2writer := buf.NewWriter(conn)
@ -199,9 +199,9 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
} }
return nil return nil
}) }
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
ray.InboundInput().CloseError() ray.InboundInput().CloseError()
ray.InboundOutput().CloseError() ray.InboundOutput().CloseError()
return newError("connection ends").Base(err) return newError("connection ends").Base(err)
@ -251,7 +251,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
var result error = errWaitAnother var result error = errWaitAnother
requestDone := signal.ExecuteAsync(func() error { requestDone := func() error {
request.Header.Set("Connection", "close") request.Header.Set("Connection", "close")
requestWriter := buf.NewBufferedWriter(ray.InboundInput()) requestWriter := buf.NewBufferedWriter(ray.InboundInput())
@ -260,9 +260,9 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
return newError("failed to write whole request").Base(err).AtWarning() return newError("failed to write whole request").Base(err).AtWarning()
} }
return nil return nil
}) }
responseDone := signal.ExecuteAsync(func() error { responseDone := func() error {
responseReader := bufio.NewReaderSize(buf.NewBufferedReader(ray.InboundOutput()), buf.Size) responseReader := bufio.NewReaderSize(buf.NewBufferedReader(ray.InboundOutput()), buf.Size)
response, err := http.ReadResponse(responseReader, request) response, err := http.ReadResponse(responseReader, request)
if err == nil { if err == nil {
@ -296,9 +296,9 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
return newError("failed to write response").Base(err).AtWarning() return newError("failed to write response").Base(err).AtWarning()
} }
return nil return nil
}) }
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
input.CloseError() input.CloseError()
output.CloseError() output.CloseError()
return newError("connection ends").Base(err) return newError("connection ends").Base(err)

View File

@ -105,12 +105,12 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
return err return err
} }
requestDone := signal.ExecuteAsync(func() error { requestDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
return buf.Copy(outboundRay.OutboundInput(), bodyWriter, buf.UpdateActivity(timer)) return buf.Copy(outboundRay.OutboundInput(), bodyWriter, buf.UpdateActivity(timer))
}) }
responseDone := signal.ExecuteAsync(func() error { responseDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
responseReader, err := ReadTCPResponse(user, conn) responseReader, err := ReadTCPResponse(user, conn)
@ -119,9 +119,9 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
} }
return buf.Copy(responseReader, outboundRay.OutboundOutput(), buf.UpdateActivity(timer)) return buf.Copy(responseReader, outboundRay.OutboundOutput(), buf.UpdateActivity(timer))
}) }
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
return newError("connection ends").Base(err) return newError("connection ends").Base(err)
} }
@ -135,16 +135,16 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
Request: request, Request: request,
}) })
requestDone := signal.ExecuteAsync(func() error { requestDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
if err := buf.Copy(outboundRay.OutboundInput(), writer, buf.UpdateActivity(timer)); err != nil { if err := buf.Copy(outboundRay.OutboundInput(), writer, buf.UpdateActivity(timer)); err != nil {
return newError("failed to transport all UDP request").Base(err) return newError("failed to transport all UDP request").Base(err)
} }
return nil return nil
}) }
responseDone := signal.ExecuteAsync(func() error { responseDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
reader := &UDPReader{ reader := &UDPReader{
@ -156,9 +156,9 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
return newError("failed to transport all UDP response").Base(err) return newError("failed to transport all UDP response").Base(err)
} }
return nil return nil
}) }
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
return newError("connection ends").Base(err) return newError("connection ends").Base(err)
} }

View File

@ -172,7 +172,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
return err return err
} }
responseDone := signal.ExecuteAsync(func() error { responseDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn)) bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
@ -200,9 +200,9 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
} }
return nil return nil
}) }
requestDone := signal.ExecuteAsync(func() error { requestDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
defer ray.InboundInput().Close() defer ray.InboundInput().Close()
@ -211,9 +211,9 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
} }
return nil return nil
}) }
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
ray.InboundInput().CloseError() ray.InboundInput().CloseError()
ray.InboundOutput().CloseError() ray.InboundOutput().CloseError()
return newError("connection ends").Base(err) return newError("connection ends").Base(err)

View File

@ -130,9 +130,7 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay, dialer proxy.
} }
} }
requestDone := signal.ExecuteAsync(requestFunc) if err := signal.ExecuteParallel(ctx, requestFunc, responseFunc); err != nil {
responseDone := signal.ExecuteAsync(responseFunc)
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
return newError("connection ends").Base(err) return newError("connection ends").Base(err)
} }

View File

@ -137,7 +137,7 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
input := ray.InboundInput() input := ray.InboundInput()
output := ray.InboundOutput() output := ray.InboundOutput()
requestDone := signal.ExecuteAsync(func() error { requestDone := func() error {
defer timer.SetTimeout(s.policy().Timeouts.DownlinkOnly) defer timer.SetTimeout(s.policy().Timeouts.DownlinkOnly)
defer input.Close() defer input.Close()
@ -147,9 +147,9 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
} }
return nil return nil
}) }
responseDone := signal.ExecuteAsync(func() error { responseDone := func() error {
defer timer.SetTimeout(s.policy().Timeouts.UplinkOnly) defer timer.SetTimeout(s.policy().Timeouts.UplinkOnly)
v2writer := buf.NewWriter(writer) v2writer := buf.NewWriter(writer)
@ -158,9 +158,9 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
} }
return nil return nil
}) }
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
input.CloseError() input.CloseError()
output.CloseError() output.CloseError()
return newError("connection ends").Base(err) return newError("connection ends").Base(err)

View File

@ -280,12 +280,12 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
input := ray.InboundInput() input := ray.InboundInput()
output := ray.InboundOutput() output := ray.InboundOutput()
requestDone := signal.ExecuteAsync(func() error { requestDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
return transferRequest(timer, session, request, reader, input) return transferRequest(timer, session, request, reader, input)
}) }
responseDone := signal.ExecuteAsync(func() error { responseDone := func() error {
writer := buf.NewBufferedWriter(buf.NewWriter(connection)) writer := buf.NewBufferedWriter(buf.NewWriter(connection))
defer writer.Flush() defer writer.Flush()
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
@ -294,9 +294,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
Command: h.generateCommand(ctx, request), Command: h.generateCommand(ctx, request),
} }
return transferResponse(timer, session, request, response, output, writer) return transferResponse(timer, session, request, response, output, writer)
}) }
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
input.CloseError() input.CloseError()
output.CloseError() output.CloseError()
return newError("connection ends").Base(err) return newError("connection ends").Base(err)

View File

@ -104,7 +104,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
requestDone := signal.ExecuteAsync(func() error { requestDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
writer := buf.NewBufferedWriter(buf.NewWriter(conn)) writer := buf.NewBufferedWriter(buf.NewWriter(conn))
@ -140,9 +140,9 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
} }
return nil return nil
}) }
responseDone := signal.ExecuteAsync(func() error { responseDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
reader := buf.NewBufferedReader(buf.NewReader(conn)) reader := buf.NewBufferedReader(buf.NewReader(conn))
@ -156,9 +156,9 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
bodyReader := session.DecodeResponseBody(request, reader) bodyReader := session.DecodeResponseBody(request, reader)
return buf.Copy(bodyReader, output, buf.UpdateActivity(timer)) return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
}) }
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil { if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
return newError("connection ends").Base(err) return newError("connection ends").Base(err)
} }