better way to run tasks in parallel

This commit is contained in:
Darien Raymond 2018-04-11 16:45:09 +02:00
parent 9d7f43a299
commit 0caf92726b
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
11 changed files with 79 additions and 90 deletions

View File

@ -39,16 +39,16 @@ func Run(args []string, input io.Reader) (buf.MultiBuffer, error) {
}
var content buf.MultiBuffer
loadTask := signal.ExecuteAsync(func() error {
loadTask := func() error {
c, err := buf.ReadAllToMultiBuffer(stdoutReader)
if err != nil {
return err
}
content = c
return nil
})
}
waitTask := signal.ExecuteAsync(func() error {
waitTask := func() error {
if err := cmd.Wait(); err != nil {
msg := "failed to execute v2ctl"
if errBuffer.Len() > 0 {
@ -57,9 +57,9 @@ func Run(args []string, input io.Reader) (buf.MultiBuffer, error) {
return newError(msg).Base(err)
}
return nil
})
}
if err := signal.ErrorOrFinish2(context.Background(), loadTask, waitTask); err != nil {
if err := signal.ExecuteParallel(context.Background(), loadTask, waitTask); err != nil {
return nil, err
}

View File

@ -4,14 +4,6 @@ import (
"context"
)
func executeAndFulfill(f func() error, done chan<- error) {
err := f()
if err != nil {
done <- err
}
close(done)
}
// Execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass.
func Execute(tasks ...func() error) error {
for _, task := range tasks {
@ -22,35 +14,34 @@ func Execute(tasks ...func() error) error {
return nil
}
// ExecuteAsync executes a function asynchronously and return its result.
func ExecuteAsync(f func() error) <-chan error {
// ExecuteParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
func ExecuteParallel(ctx context.Context, tasks ...func() error) error {
n := len(tasks)
s := NewSemaphore(n)
done := make(chan error, 1)
go executeAndFulfill(f, done)
return done
}
func ErrorOrFinish1(ctx context.Context, c <-chan error) error {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-c:
return err
for _, task := range tasks {
<-s.Wait()
go func(f func() error) {
if err := f(); err != nil {
select {
case done <- err:
default:
}
}
s.Signal()
}(task)
}
}
func ErrorOrFinish2(ctx context.Context, c1, c2 <-chan error) error {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-c1:
if err != nil {
for i := 0; i < n; i++ {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
return err
case <-s.Wait():
}
return ErrorOrFinish1(ctx, c2)
case err := <-c2:
if err != nil {
return err
}
return ErrorOrFinish1(ctx, c1)
}
return nil
}

View File

@ -75,7 +75,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
return newError("failed to dispatch request").Base(err)
}
requestDone := signal.ExecuteAsync(func() error {
requestDone := func() error {
defer inboundRay.InboundInput().Close()
defer timer.SetTimeout(d.policy().Timeouts.DownlinkOnly)
@ -86,9 +86,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
}
return nil
})
}
responseDone := signal.ExecuteAsync(func() error {
responseDone := func() error {
defer timer.SetTimeout(d.policy().Timeouts.UplinkOnly)
var writer buf.Writer
@ -113,9 +113,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
}
return nil
})
}
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
inboundRay.InboundInput().CloseError()
inboundRay.InboundOutput().CloseError()
return newError("connection ends").Base(err)

View File

@ -109,7 +109,7 @@ func (h *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, h.policy().Timeouts.ConnectionIdle)
requestDone := signal.ExecuteAsync(func() error {
requestDone := func() error {
defer timer.SetTimeout(h.policy().Timeouts.DownlinkOnly)
var writer buf.Writer
@ -123,9 +123,9 @@ func (h *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
}
return nil
})
}
responseDone := signal.ExecuteAsync(func() error {
responseDone := func() error {
defer timer.SetTimeout(h.policy().Timeouts.UplinkOnly)
v2reader := buf.NewReader(conn)
@ -134,9 +134,9 @@ func (h *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
}
return nil
})
}
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
input.CloseError()
output.CloseError()
return newError("connection ends").Base(err)

View File

@ -182,15 +182,15 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
reader = nil
}
requestDone := signal.ExecuteAsync(func() error {
requestDone := func() error {
defer ray.InboundInput().Close()
defer timer.SetTimeout(s.policy().Timeouts.DownlinkOnly)
v2reader := buf.NewReader(conn)
return buf.Copy(v2reader, ray.InboundInput(), buf.UpdateActivity(timer))
})
}
responseDone := signal.ExecuteAsync(func() error {
responseDone := func() error {
defer timer.SetTimeout(s.policy().Timeouts.UplinkOnly)
v2writer := buf.NewWriter(conn)
@ -199,9 +199,9 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
}
return nil
})
}
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
ray.InboundInput().CloseError()
ray.InboundOutput().CloseError()
return newError("connection ends").Base(err)
@ -251,7 +251,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
var result error = errWaitAnother
requestDone := signal.ExecuteAsync(func() error {
requestDone := func() error {
request.Header.Set("Connection", "close")
requestWriter := buf.NewBufferedWriter(ray.InboundInput())
@ -260,9 +260,9 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
return newError("failed to write whole request").Base(err).AtWarning()
}
return nil
})
}
responseDone := signal.ExecuteAsync(func() error {
responseDone := func() error {
responseReader := bufio.NewReaderSize(buf.NewBufferedReader(ray.InboundOutput()), buf.Size)
response, err := http.ReadResponse(responseReader, request)
if err == nil {
@ -296,9 +296,9 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
return newError("failed to write response").Base(err).AtWarning()
}
return nil
})
}
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
input.CloseError()
output.CloseError()
return newError("connection ends").Base(err)

View File

@ -105,12 +105,12 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
return err
}
requestDone := signal.ExecuteAsync(func() error {
requestDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
return buf.Copy(outboundRay.OutboundInput(), bodyWriter, buf.UpdateActivity(timer))
})
}
responseDone := signal.ExecuteAsync(func() error {
responseDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
responseReader, err := ReadTCPResponse(user, conn)
@ -119,9 +119,9 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
}
return buf.Copy(responseReader, outboundRay.OutboundOutput(), buf.UpdateActivity(timer))
})
}
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
return newError("connection ends").Base(err)
}
@ -135,16 +135,16 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
Request: request,
})
requestDone := signal.ExecuteAsync(func() error {
requestDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
if err := buf.Copy(outboundRay.OutboundInput(), writer, buf.UpdateActivity(timer)); err != nil {
return newError("failed to transport all UDP request").Base(err)
}
return nil
})
}
responseDone := signal.ExecuteAsync(func() error {
responseDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
reader := &UDPReader{
@ -156,9 +156,9 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
return newError("failed to transport all UDP response").Base(err)
}
return nil
})
}
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
return newError("connection ends").Base(err)
}

View File

@ -172,7 +172,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
return err
}
responseDone := signal.ExecuteAsync(func() error {
responseDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
@ -200,9 +200,9 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
}
return nil
})
}
requestDone := signal.ExecuteAsync(func() error {
requestDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
defer ray.InboundInput().Close()
@ -211,9 +211,9 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
}
return nil
})
}
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
ray.InboundInput().CloseError()
ray.InboundOutput().CloseError()
return newError("connection ends").Base(err)

View File

@ -130,9 +130,7 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay, dialer proxy.
}
}
requestDone := signal.ExecuteAsync(requestFunc)
responseDone := signal.ExecuteAsync(responseFunc)
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
if err := signal.ExecuteParallel(ctx, requestFunc, responseFunc); err != nil {
return newError("connection ends").Base(err)
}

View File

@ -137,7 +137,7 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
input := ray.InboundInput()
output := ray.InboundOutput()
requestDone := signal.ExecuteAsync(func() error {
requestDone := func() error {
defer timer.SetTimeout(s.policy().Timeouts.DownlinkOnly)
defer input.Close()
@ -147,9 +147,9 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
}
return nil
})
}
responseDone := signal.ExecuteAsync(func() error {
responseDone := func() error {
defer timer.SetTimeout(s.policy().Timeouts.UplinkOnly)
v2writer := buf.NewWriter(writer)
@ -158,9 +158,9 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
}
return nil
})
}
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
input.CloseError()
output.CloseError()
return newError("connection ends").Base(err)

View File

@ -280,12 +280,12 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
input := ray.InboundInput()
output := ray.InboundOutput()
requestDone := signal.ExecuteAsync(func() error {
requestDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
return transferRequest(timer, session, request, reader, input)
})
}
responseDone := signal.ExecuteAsync(func() error {
responseDone := func() error {
writer := buf.NewBufferedWriter(buf.NewWriter(connection))
defer writer.Flush()
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
@ -294,9 +294,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
Command: h.generateCommand(ctx, request),
}
return transferResponse(timer, session, request, response, output, writer)
})
}
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
input.CloseError()
output.CloseError()
return newError("connection ends").Base(err)

View File

@ -104,7 +104,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
requestDone := signal.ExecuteAsync(func() error {
requestDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
writer := buf.NewBufferedWriter(buf.NewWriter(conn))
@ -140,9 +140,9 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
}
return nil
})
}
responseDone := signal.ExecuteAsync(func() error {
responseDone := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
reader := buf.NewBufferedReader(buf.NewReader(conn))
@ -156,9 +156,9 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
bodyReader := session.DecodeResponseBody(request, reader)
return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
})
}
if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
return newError("connection ends").Base(err)
}