diff --git a/proxy/trojan/client.go b/proxy/trojan/client.go index dd74345d5..3724f0b00 100644 --- a/proxy/trojan/client.go +++ b/proxy/trojan/client.go @@ -146,7 +146,14 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter } // write some request payload to buffer - if err = buf.CopyOnceTimeout(link.Reader, bodyWriter, proxy.FirstPayloadTimeout); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout { + err = buf.CopyOnceTimeout(link.Reader, bodyWriter, proxy.FirstPayloadTimeout) + switch err { + case buf.ErrNotTimeoutReader, buf.ErrReadTimeout: + if err := connWriter.WriteHeader(); err != nil { + return newError("failed to write request header").Base(err).AtWarning() + } + case nil: + default: return newError("failed to write a request payload").Base(err).AtWarning() } diff --git a/proxy/trojan/protocol.go b/proxy/trojan/protocol.go index 21e6a5a84..dd20e7248 100644 --- a/proxy/trojan/protocol.go +++ b/proxy/trojan/protocol.go @@ -37,7 +37,7 @@ type ConnWriter struct { // Write implements io.Writer func (c *ConnWriter) Write(p []byte) (n int, err error) { if !c.headerSent { - if err := c.writeHeader(); err != nil { + if err := c.WriteHeader(); err != nil { return 0, newError("failed to write request header").Base(err) } } @@ -60,7 +60,7 @@ func (c *ConnWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { return nil } -func (c *ConnWriter) writeHeader() error { +func (c *ConnWriter) WriteHeader() error { buffer := buf.StackNew() defer buffer.Release()