diff --git a/common/signal/exec.go b/common/signal/exec.go new file mode 100644 index 000000000..2ef03dcc8 --- /dev/null +++ b/common/signal/exec.go @@ -0,0 +1,30 @@ +package signal + +func executeAndFulfill(f func() error, done chan<- error) { + err := f() + if err != nil { + done <- err + } + close(done) +} + +func ExecuteAsync(f func() error) <-chan error { + done := make(chan error, 1) + go executeAndFulfill(f, done) + return done +} + +func ErrorOrFinish2(c1, c2 <-chan error) error { + select { + case err, failed := <-c1: + if failed { + return err + } + return <-c2 + case err, failed := <-c2: + if failed { + return err + } + return <-c1 + } +} diff --git a/common/task/task.go b/common/task/task.go deleted file mode 100644 index 50f79cfac..000000000 --- a/common/task/task.go +++ /dev/null @@ -1,41 +0,0 @@ -package task - -import ( - "sync" -) - -type Task interface { - Execute() error -} - -type ParallelExecutor struct { - sync.Mutex - tasks sync.WaitGroup - errors []error -} - -func (pe *ParallelExecutor) track(err error) { - if err == nil { - return - } - - pe.Lock() - pe.errors = append(pe.errors, err) - pe.Unlock() -} - -func (pe *ParallelExecutor) Execute(task Task) { - pe.tasks.Add(1) - go func() { - pe.track(task.Execute()) - pe.tasks.Done() - }() -} - -func (pe *ParallelExecutor) Wait() { - pe.tasks.Wait() -} - -func (pe *ParallelExecutor) Errors() []error { - return pe.errors -} diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 5b08f9016..c0bcc1621 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -11,6 +11,7 @@ import ( "v2ray.com/core/common/log" v2net "v2ray.com/core/common/net" "v2ray.com/core/common/serial" + "v2ray.com/core/common/signal" "v2ray.com/core/proxy" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet/udp" @@ -168,35 +169,37 @@ func (v *DokodemoDoor) HandleTCPConnection(conn internet.Connection) { }) defer ray.InboundOutput().Release() - var wg sync.WaitGroup - reader := v2net.NewTimeOutReader(v.config.Timeout, conn) defer reader.Release() - wg.Add(1) - go func() { + requestDone := signal.ExecuteAsync(func() error { + defer ray.InboundInput().Close() + v2reader := buf.NewReader(reader) defer v2reader.Release() if err := buf.PipeUntilEOF(v2reader, ray.InboundInput()); err != nil { log.Info("Dokodemo: Failed to transport all TCP request: ", err) + return err } - wg.Done() - ray.InboundInput().Close() - }() - wg.Add(1) - go func() { + return nil + }) + + responseDone := signal.ExecuteAsync(func() error { + defer ray.InboundOutput().Release() + v2writer := buf.NewWriter(conn) defer v2writer.Release() if err := buf.PipeUntilEOF(ray.InboundOutput(), v2writer); err != nil { log.Info("Dokodemo: Failed to transport all TCP response: ", err) + return err } - wg.Done() - }() + return nil + }) - wg.Wait() + signal.ErrorOrFinish2(requestDone, responseDone) } type Factory struct{} diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index d313feaef..dcb6667ea 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -15,7 +15,7 @@ import ( v2net "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" "v2ray.com/core/common/serial" - "v2ray.com/core/common/task" + "v2ray.com/core/common/signal" "v2ray.com/core/common/uuid" "v2ray.com/core/proxy" "v2ray.com/core/proxy/vmess" @@ -24,66 +24,6 @@ import ( "v2ray.com/core/transport/ray" ) -type requestProcessor struct { - session *encoding.ServerSession - request *protocol.RequestHeader - input io.Reader - output ray.OutputStream -} - -func (r *requestProcessor) Execute() error { - defer r.output.Close() - - bodyReader := r.session.DecodeRequestBody(r.request, r.input) - defer bodyReader.Release() - - if err := buf.PipeUntilEOF(bodyReader, r.output); err != nil { - log.Debug("VMess|Inbound: Error when sending data to outbound: ", err) - return err - } - - return nil -} - -type responseProcessor struct { - session *encoding.ServerSession - request *protocol.RequestHeader - response *protocol.ResponseHeader - input ray.InputStream - output io.Writer -} - -func (r *responseProcessor) Execute() error { - defer r.input.Release() - r.session.EncodeResponseHeader(r.response, r.output) - - bodyWriter := r.session.EncodeResponseBody(r.request, r.output) - - // Optimize for small response packet - if data, err := r.input.Read(); err == nil { - if err := bodyWriter.Write(data); err != nil { - return err - } - - if bufferedWriter, ok := r.output.(*bufio.BufferedWriter); ok { - bufferedWriter.SetBuffered(false) - } - - if err := buf.PipeUntilEOF(r.input, bodyWriter); err != nil { - log.Debug("VMess|Inbound: Error when sending data to downstream: ", err) - return err - } - } - - if r.request.Option.Has(protocol.RequestOptionChunkStream) { - if err := bodyWriter.Write(buf.NewLocal(8)); err != nil { - return err - } - } - - return nil -} - type userByEmail struct { sync.RWMutex cache map[string]*protocol.User @@ -190,6 +130,50 @@ func (v *VMessInboundHandler) Start() error { return nil } +func transferRequest(session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output ray.OutputStream) error { + defer output.Close() + + bodyReader := session.DecodeRequestBody(request, input) + defer bodyReader.Release() + + if err := buf.PipeUntilEOF(bodyReader, output); err != nil { + log.Debug("VMess|Inbound: Error when sending data to outbound: ", err) + return err + } + return nil +} + +func transferResponse(session *encoding.ServerSession, request *protocol.RequestHeader, response *protocol.ResponseHeader, input ray.InputStream, output io.Writer) error { + defer input.Release() + session.EncodeResponseHeader(response, output) + + bodyWriter := session.EncodeResponseBody(request, output) + + // Optimize for small response packet + if data, err := input.Read(); err == nil { + if err := bodyWriter.Write(data); err != nil { + return err + } + + if bufferedWriter, ok := output.(*bufio.BufferedWriter); ok { + bufferedWriter.SetBuffered(false) + } + + if err := buf.PipeUntilEOF(input, bodyWriter); err != nil { + log.Debug("VMess|Inbound: Error when sending data to downstream: ", err) + return err + } + } + + if request.Option.Has(protocol.RequestOptionChunkStream) { + if err := bodyWriter.Write(buf.NewLocal(8)); err != nil { + return err + } + } + + return nil +} + func (v *VMessInboundHandler) HandleConnection(connection internet.Connection) { defer connection.Close() @@ -242,12 +226,8 @@ func (v *VMessInboundHandler) HandleConnection(connection internet.Connection) { connReader.SetTimeOut(userSettings.PayloadReadTimeout) reader.SetBuffered(false) - var executor task.ParallelExecutor - executor.Execute(&requestProcessor{ - session: session, - request: request, - input: reader, - output: input, + requestDone := signal.ExecuteAsync(func() error { + return transferRequest(session, request, reader, input) }) writer := bufio.NewWriter(connection) @@ -261,23 +241,19 @@ func (v *VMessInboundHandler) HandleConnection(connection internet.Connection) { response.Option.Set(protocol.ResponseOptionConnectionReuse) } - executor.Execute(&responseProcessor{ - session: session, - request: request, - response: response, - input: output, - output: writer, + responseDone := signal.ExecuteAsync(func() error { + return transferResponse(session, request, response, output, writer) }) - executor.Wait() + err = signal.ErrorOrFinish2(requestDone, responseDone) + if err != nil { + connection.SetReusable(false) + return + } if err := writer.Flush(); err != nil { connection.SetReusable(false) - } - - errors := executor.Errors() - if len(errors) > 0 { - connection.SetReusable(false) + return } }