clean udp writer

This commit is contained in:
Darien Raymond 2017-04-21 14:51:09 +02:00
parent eda72624e2
commit 498c7dafdf
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
10 changed files with 48 additions and 55 deletions

View File

@ -128,6 +128,12 @@ func NewMergingWriterSize(writer io.Writer, size uint32) Writer {
}
}
func NewSequentialWriter(writer io.Writer) Writer {
return &seqWriter{
writer: writer,
}
}
// ToBytesWriter converts a Writer to io.Writer
func ToBytesWriter(writer Writer) io.Writer {
return &bytesToBufferWriter{

View File

@ -42,6 +42,25 @@ func (w *mergingWriter) Write(mb MultiBuffer) error {
return nil
}
type seqWriter struct {
writer io.Writer
}
func (w *seqWriter) Write(mb MultiBuffer) error {
defer mb.Release()
for _, b := range mb {
if b.IsEmpty() {
continue
}
if _, err := w.writer.Write(b.Bytes()); err != nil {
return err
}
}
return nil
}
type bytesToBufferWriter struct {
writer Writer
}

View File

@ -4,7 +4,6 @@ package freedom
import (
"context"
"io"
"runtime"
"time"
@ -117,7 +116,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
if destination.Network == net.Network_TCP {
writer = buf.NewWriter(conn)
} else {
writer = &seqWriter{writer: conn}
writer = buf.NewSequentialWriter(conn)
}
if err := buf.Copy(timer, input, writer); err != nil {
return newError("failed to process request").Base(err)
@ -151,19 +150,3 @@ func init() {
return New(ctx, config.(*Config))
}))
}
type seqWriter struct {
writer io.Writer
}
func (w *seqWriter) Write(mb buf.MultiBuffer) error {
defer mb.Release()
for _, b := range mb {
if _, err := w.writer.Write(b.Bytes()); err != nil {
return err
}
}
return nil
}

View File

@ -135,10 +135,10 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
if request.Command == protocol.RequestCommandUDP {
writer := &UDPWriter{
writer := buf.NewSequentialWriter(&UDPWriter{
Writer: conn,
Request: request,
}
})
requestDone := signal.ExecuteAsync(func() error {
if err := buf.Copy(timer, outboundRay.OutboundInput(), writer); err != nil {

View File

@ -238,7 +238,7 @@ func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Wr
return buf.NewWriter(crypto.NewCryptionWriter(stream, writer)), nil
}
func EncodeUDPPacket(request *protocol.RequestHeader, payload *buf.Buffer) (*buf.Buffer, error) {
func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buffer, error) {
user := request.User
rawAccount, err := user.GetTypedAccount()
if err != nil {
@ -266,7 +266,7 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload *buf.Buffer) (*buf
}
buffer.AppendSupplier(serial.WriteUint16(uint16(request.Port)))
buffer.Append(payload.Bytes())
buffer.Append(payload)
if request.Option.Has(RequestOptionOneTimeAuth) {
authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv))
@ -382,23 +382,12 @@ type UDPWriter struct {
Request *protocol.RequestHeader
}
func (w *UDPWriter) Write(mb buf.MultiBuffer) error {
defer mb.Release()
for _, b := range mb {
if err := w.writeInternal(b); err != nil {
return err
}
}
return nil
}
func (w *UDPWriter) writeInternal(buffer *buf.Buffer) error {
payload, err := EncodeUDPPacket(w.Request, buffer)
func (w *UDPWriter) Write(payload []byte) (int, error) {
packet, err := EncodeUDPPacket(w.Request, payload)
if err != nil {
return err
return 0, err
}
_, err = w.Writer.Write(payload.Bytes())
payload.Release()
return err
_, err = w.Writer.Write(packet.Bytes())
packet.Release()
return len(payload), err
}

View File

@ -31,7 +31,7 @@ func TestUDPEncoding(t *testing.T) {
data := buf.NewLocal(256)
data.AppendSupplier(serial.WriteString("test string"))
encodedData, err := EncodeUDPPacket(request, data)
encodedData, err := EncodeUDPPacket(request, data.Bytes())
assert.Error(err).IsNil()
decodedRequest, decodedData, err := DecodeUDPPacket(request.User, encodedData)
@ -88,7 +88,7 @@ func TestUDPReaderWriter(t *testing.T) {
}),
}
cache := buf.New()
writer := &UDPWriter{
writer := buf.NewSequentialWriter(&UDPWriter{
Writer: cache,
Request: &protocol.RequestHeader{
Version: Version,
@ -97,7 +97,7 @@ func TestUDPReaderWriter(t *testing.T) {
User: user,
Option: RequestOptionOneTimeAuth,
},
}
})
reader := &UDPReader{
Reader: cache,

View File

@ -113,7 +113,7 @@ func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
udpServer.Dispatch(ctx, dest, data, func(payload *buf.Buffer) {
defer payload.Release()
data, err := EncodeUDPPacket(request, payload)
data, err := EncodeUDPPacket(request, payload.Bytes())
if err != nil {
log.Trace(newError("failed to encode UDP packet").Base(err).AtWarning())
return

View File

@ -103,7 +103,7 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay, dialer proxy.
}
defer udpConn.Close()
requestFunc = func() error {
return buf.Copy(timer, ray.OutboundInput(), &UDPWriter{request: request, writer: udpConn})
return buf.Copy(timer, ray.OutboundInput(), buf.NewSequentialWriter(NewUDPWriter(request, udpConn)))
}
responseFunc = func() error {
defer ray.OutboundOutput().Close()

View File

@ -369,17 +369,13 @@ func NewUDPWriter(request *protocol.RequestHeader, writer io.Writer) *UDPWriter
}
}
func (w *UDPWriter) Write(mb buf.MultiBuffer) error {
defer mb.Release()
for _, b := range mb {
eb := EncodeUDPPacket(w.request, b.Bytes())
defer eb.Release()
if _, err := w.writer.Write(eb.Bytes()); err != nil {
return err
}
func (w *UDPWriter) Write(b []byte) (int, error) {
eb := EncodeUDPPacket(w.request, b)
defer eb.Release()
if _, err := w.writer.Write(eb.Bytes()); err != nil {
return 0, err
}
return nil
return len(b), nil
}
func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {

View File

@ -19,7 +19,7 @@ func TestUDPEncoding(t *testing.T) {
Address: net.IPAddress([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}),
Port: 1024,
}
writer := NewUDPWriter(request, b)
writer := buf.NewSequentialWriter(NewUDPWriter(request, b))
content := []byte{'a'}
payload := buf.New()