From f751bb610c8a9debb9e12e807d9d3107949a944e Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Fri, 13 Apr 2018 13:54:36 +0200 Subject: [PATCH] refine ctlcmd --- common/buf/multi_buffer.go | 59 +++++++++++++++++++++++--------- common/buf/multi_buffer_test.go | 8 +++++ common/platform/ctlcmd/ctlcmd.go | 45 +++++++----------------- 3 files changed, 63 insertions(+), 49 deletions(-) diff --git a/common/buf/multi_buffer.go b/common/buf/multi_buffer.go index 95149b4a2..89fd05c74 100644 --- a/common/buf/multi_buffer.go +++ b/common/buf/multi_buffer.go @@ -3,7 +3,6 @@ package buf import ( "io" "net" - "os" "v2ray.com/core/common" "v2ray.com/core/common/errors" @@ -14,22 +13,12 @@ import ( func ReadAllToMultiBuffer(reader io.Reader) (MultiBuffer, error) { mb := NewMultiBufferCap(128) - for { - b := New() - err := b.Reset(ReadFrom(reader)) - if b.IsEmpty() { - b.Release() - } else { - mb.Append(b) - } - if err != nil { - if errors.Cause(err) == io.EOF || errors.Cause(err) == os.ErrClosed { - return mb, nil - } - mb.Release() - return nil, err - } + if _, err := mb.ReadFrom(reader); err != nil { + mb.Release() + return nil, err } + + return mb, nil } // ReadSizeToMultiBuffer reads specific number of bytes from reader into a MultiBuffer. @@ -102,6 +91,28 @@ func (mb MultiBuffer) Copy(b []byte) int { return total } +// ReadFrom implements io.ReaderFrom. +func (mb *MultiBuffer) ReadFrom(reader io.Reader) (int64, error) { + totalBytes := int64(0) + + for { + b := New() + err := b.Reset(ReadFrom(reader)) + if b.IsEmpty() { + b.Release() + } else { + mb.Append(b) + } + totalBytes += int64(b.Len()) + if err != nil { + if errors.Cause(err) == io.EOF { + return totalBytes, nil + } + return totalBytes, err + } + } +} + // Read implements io.Reader. func (mb *MultiBuffer) Read(b []byte) (int, error) { if mb.Len() == 0 { @@ -125,6 +136,22 @@ func (mb *MultiBuffer) Read(b []byte) (int, error) { return totalBytes, nil } +// WriteTo implements io.WriterTo. +func (mb *MultiBuffer) WriteTo(writer io.Writer) (int64, error) { + defer mb.Release() + + totalBytes := int64(0) + for _, b := range *mb { + nBytes, err := writer.Write(b.Bytes()) + totalBytes += int64(nBytes) + if err != nil { + return totalBytes, err + } + } + + return totalBytes, nil +} + // Write implements io.Writer. func (mb *MultiBuffer) Write(b []byte) (int, error) { totalBytes := len(b) diff --git a/common/buf/multi_buffer_test.go b/common/buf/multi_buffer_test.go index 656d0656e..8a55e5562 100644 --- a/common/buf/multi_buffer_test.go +++ b/common/buf/multi_buffer_test.go @@ -2,6 +2,7 @@ package buf_test import ( "crypto/rand" + "io" "testing" "v2ray.com/core/common" @@ -48,3 +49,10 @@ func TestMultiBufferSliceBySizeLarge(t *testing.T) { mb2 := mb.SliceBySize(4 * 1024) assert(mb2.Len(), Equals, int32(4*1024)) } + +func TestInterface(t *testing.T) { + assert := With(t) + + assert((*MultiBuffer)(nil), Implements, (*io.WriterTo)(nil)) + assert((*MultiBuffer)(nil), Implements, (*io.ReaderFrom)(nil)) +} diff --git a/common/platform/ctlcmd/ctlcmd.go b/common/platform/ctlcmd/ctlcmd.go index 45611c1ca..60d219cf4 100644 --- a/common/platform/ctlcmd/ctlcmd.go +++ b/common/platform/ctlcmd/ctlcmd.go @@ -1,14 +1,12 @@ package ctlcmd import ( - "context" "io" "os" "os/exec" "v2ray.com/core/common/buf" "v2ray.com/core/common/platform" - "v2ray.com/core/common/signal" ) //go:generate go run $GOPATH/src/v2ray.com/core/common/errors/errorgen/main.go -pkg ctlcmd -path Command,Platform,CtlCmd @@ -19,49 +17,30 @@ func Run(args []string, input io.Reader) (buf.MultiBuffer, error) { return nil, newError("v2ctl doesn't exist").Base(err) } - errBuffer := &buf.MultiBuffer{} + errBuffer := buf.MultiBuffer{} + outBuffer := buf.MultiBuffer{} cmd := exec.Command(v2ctl, args...) - cmd.Stderr = errBuffer + cmd.Stderr = &errBuffer + cmd.Stdout = &outBuffer cmd.SysProcAttr = getSysProcAttr() if input != nil { cmd.Stdin = input } - stdoutReader, err := cmd.StdoutPipe() - if err != nil { - return nil, newError("failed to get stdout from v2ctl").Base(err) - } - defer stdoutReader.Close() - if err := cmd.Start(); err != nil { return nil, newError("failed to start v2ctl").Base(err) } - var content buf.MultiBuffer - loadTask := func() error { - c, err := buf.ReadAllToMultiBuffer(stdoutReader) - if err != nil { - return newError("failed to read config").Base(err) + if err := cmd.Wait(); err != nil { + msg := "failed to execute v2ctl" + if errBuffer.Len() > 0 { + msg += ": " + errBuffer.String() } - content = c - return nil + errBuffer.Release() + outBuffer.Release() + return nil, newError(msg).Base(err) } - waitTask := func() error { - if err := cmd.Wait(); err != nil { - msg := "failed to execute v2ctl" - if errBuffer.Len() > 0 { - msg += ": " + errBuffer.String() - } - return newError(msg).Base(err) - } - return nil - } - - if err := signal.ExecuteParallel(context.Background(), loadTask, waitTask); err != nil { - return nil, err - } - - return content, nil + return outBuffer, nil }