diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index d5412ee75..21894c8f9 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -137,31 +137,35 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ return nil } -func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, writer io.Writer) buf.Writer { +func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) { var sizeParser crypto.ChunkSizeEncoder = crypto.PlainChunkSizeParser{} if request.Option.Has(protocol.RequestOptionChunkMasking) { sizeParser = NewShakeSizeParser(c.requestBodyIV[:]) } var padding crypto.PaddingLengthGenerator if request.Option.Has(protocol.RequestOptionGlobalPadding) { - padding = sizeParser.(crypto.PaddingLengthGenerator) + var ok bool + padding, ok = sizeParser.(crypto.PaddingLengthGenerator) + if !ok { + return nil, newError("invalid option: RequestOptionGlobalPadding") + } } switch request.Security { case protocol.SecurityType_NONE: if request.Option.Has(protocol.RequestOptionChunkStream) { if request.Command.TransferType() == protocol.TransferTypeStream { - return crypto.NewChunkStreamWriter(sizeParser, writer) + return crypto.NewChunkStreamWriter(sizeParser, writer), nil } auth := &crypto.AEADAuthenticator{ AEAD: new(NoOpAuthenticator), NonceGenerator: crypto.GenerateEmptyBytes(), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } - return crypto.NewAuthenticationWriter(auth, sizeParser, writer, protocol.TransferTypePacket, padding) + return crypto.NewAuthenticationWriter(auth, sizeParser, writer, protocol.TransferTypePacket, padding), nil } - return buf.NewWriter(writer) + return buf.NewWriter(writer), nil case protocol.SecurityType_LEGACY: aesStream := crypto.NewAesEncryptionStream(c.requestBodyKey[:], c.requestBodyIV[:]) cryptionWriter := crypto.NewCryptionWriter(aesStream, writer) @@ -171,10 +175,10 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write NonceGenerator: crypto.GenerateEmptyBytes(), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } - return crypto.NewAuthenticationWriter(auth, sizeParser, cryptionWriter, request.Command.TransferType(), padding) + return crypto.NewAuthenticationWriter(auth, sizeParser, cryptionWriter, request.Command.TransferType(), padding), nil } - return &buf.SequentialWriter{Writer: cryptionWriter} + return &buf.SequentialWriter{Writer: cryptionWriter}, nil case protocol.SecurityType_AES128_GCM: aead := crypto.NewAesGcm(c.requestBodyKey[:]) auth := &crypto.AEADAuthenticator{ @@ -193,7 +197,7 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write } sizeParser = NewAEADSizeParser(lengthAuth) } - return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding) + return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding), nil case protocol.SecurityType_CHACHA20_POLY1305: aead, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(c.requestBodyKey[:])) common.Must(err) @@ -215,9 +219,9 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write } sizeParser = NewAEADSizeParser(lengthAuth) } - return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding) + return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding), nil default: - panic("Unknown security type.") + return nil, newError("invalid option: Security") } } @@ -306,21 +310,25 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon return header, nil } -func (c *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, reader io.Reader) buf.Reader { +func (c *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, reader io.Reader) (buf.Reader, error) { var sizeParser crypto.ChunkSizeDecoder = crypto.PlainChunkSizeParser{} if request.Option.Has(protocol.RequestOptionChunkMasking) { sizeParser = NewShakeSizeParser(c.responseBodyIV[:]) } var padding crypto.PaddingLengthGenerator if request.Option.Has(protocol.RequestOptionGlobalPadding) { - padding = sizeParser.(crypto.PaddingLengthGenerator) + var ok bool + padding, ok = sizeParser.(crypto.PaddingLengthGenerator) + if !ok { + return nil, newError("invalid option: RequestOptionGlobalPadding") + } } switch request.Security { case protocol.SecurityType_NONE: if request.Option.Has(protocol.RequestOptionChunkStream) { if request.Command.TransferType() == protocol.TransferTypeStream { - return crypto.NewChunkStreamReader(sizeParser, reader) + return crypto.NewChunkStreamReader(sizeParser, reader), nil } auth := &crypto.AEADAuthenticator{ @@ -329,10 +337,10 @@ func (c *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } - return crypto.NewAuthenticationReader(auth, sizeParser, reader, protocol.TransferTypePacket, padding) + return crypto.NewAuthenticationReader(auth, sizeParser, reader, protocol.TransferTypePacket, padding), nil } - return buf.NewReader(reader) + return buf.NewReader(reader), nil case protocol.SecurityType_LEGACY: if request.Option.Has(protocol.RequestOptionChunkStream) { auth := &crypto.AEADAuthenticator{ @@ -340,10 +348,10 @@ func (c *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read NonceGenerator: crypto.GenerateEmptyBytes(), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } - return crypto.NewAuthenticationReader(auth, sizeParser, c.responseReader, request.Command.TransferType(), padding) + return crypto.NewAuthenticationReader(auth, sizeParser, c.responseReader, request.Command.TransferType(), padding), nil } - return buf.NewReader(c.responseReader) + return buf.NewReader(c.responseReader), nil case protocol.SecurityType_AES128_GCM: aead := crypto.NewAesGcm(c.responseBodyKey[:]) @@ -363,7 +371,7 @@ func (c *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read } sizeParser = NewAEADSizeParser(lengthAuth) } - return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding) + return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding), nil case protocol.SecurityType_CHACHA20_POLY1305: aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(c.responseBodyKey[:])) @@ -384,9 +392,9 @@ func (c *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read } sizeParser = NewAEADSizeParser(lengthAuth) } - return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding) + return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding), nil default: - panic("Unknown security type.") + return nil, newError("invalid option: Security") } } diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 202de2a95..44f4a72a3 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -305,21 +305,25 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request } // DecodeRequestBody returns Reader from which caller can fetch decrypted body. -func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reader io.Reader) buf.Reader { +func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reader io.Reader) (buf.Reader, error) { var sizeParser crypto.ChunkSizeDecoder = crypto.PlainChunkSizeParser{} if request.Option.Has(protocol.RequestOptionChunkMasking) { sizeParser = NewShakeSizeParser(s.requestBodyIV[:]) } var padding crypto.PaddingLengthGenerator if request.Option.Has(protocol.RequestOptionGlobalPadding) { - padding = sizeParser.(crypto.PaddingLengthGenerator) + var ok bool + padding, ok = sizeParser.(crypto.PaddingLengthGenerator) + if !ok { + return nil, newError("invalid option: RequestOptionGlobalPadding") + } } switch request.Security { case protocol.SecurityType_NONE: if request.Option.Has(protocol.RequestOptionChunkStream) { if request.Command.TransferType() == protocol.TransferTypeStream { - return crypto.NewChunkStreamReader(sizeParser, reader) + return crypto.NewChunkStreamReader(sizeParser, reader), nil } auth := &crypto.AEADAuthenticator{ @@ -327,9 +331,9 @@ func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade NonceGenerator: crypto.GenerateEmptyBytes(), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } - return crypto.NewAuthenticationReader(auth, sizeParser, reader, protocol.TransferTypePacket, padding) + return crypto.NewAuthenticationReader(auth, sizeParser, reader, protocol.TransferTypePacket, padding), nil } - return buf.NewReader(reader) + return buf.NewReader(reader), nil case protocol.SecurityType_LEGACY: aesStream := crypto.NewAesDecryptionStream(s.requestBodyKey[:], s.requestBodyIV[:]) @@ -340,9 +344,9 @@ func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade NonceGenerator: crypto.GenerateEmptyBytes(), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } - return crypto.NewAuthenticationReader(auth, sizeParser, cryptionReader, request.Command.TransferType(), padding) + return crypto.NewAuthenticationReader(auth, sizeParser, cryptionReader, request.Command.TransferType(), padding), nil } - return buf.NewReader(cryptionReader) + return buf.NewReader(cryptionReader), nil case protocol.SecurityType_AES128_GCM: aead := crypto.NewAesGcm(s.requestBodyKey[:]) @@ -362,7 +366,7 @@ func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade } sizeParser = NewAEADSizeParser(lengthAuth) } - return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding) + return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding), nil case protocol.SecurityType_CHACHA20_POLY1305: aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(s.requestBodyKey[:])) @@ -384,10 +388,10 @@ func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade } sizeParser = NewAEADSizeParser(lengthAuth) } - return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding) + return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding), nil default: - panic("Unknown security type.") + return nil, newError("invalid option: Security") } } @@ -448,21 +452,25 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr } // EncodeResponseBody returns a Writer that auto-encrypt content written by caller. -func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writer io.Writer) buf.Writer { +func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) { var sizeParser crypto.ChunkSizeEncoder = crypto.PlainChunkSizeParser{} if request.Option.Has(protocol.RequestOptionChunkMasking) { sizeParser = NewShakeSizeParser(s.responseBodyIV[:]) } var padding crypto.PaddingLengthGenerator if request.Option.Has(protocol.RequestOptionGlobalPadding) { - padding = sizeParser.(crypto.PaddingLengthGenerator) + var ok bool + padding, ok = sizeParser.(crypto.PaddingLengthGenerator) + if !ok { + return nil, newError("invalid option: RequestOptionGlobalPadding") + } } switch request.Security { case protocol.SecurityType_NONE: if request.Option.Has(protocol.RequestOptionChunkStream) { if request.Command.TransferType() == protocol.TransferTypeStream { - return crypto.NewChunkStreamWriter(sizeParser, writer) + return crypto.NewChunkStreamWriter(sizeParser, writer), nil } auth := &crypto.AEADAuthenticator{ @@ -470,9 +478,9 @@ func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ NonceGenerator: crypto.GenerateEmptyBytes(), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } - return crypto.NewAuthenticationWriter(auth, sizeParser, writer, protocol.TransferTypePacket, padding) + return crypto.NewAuthenticationWriter(auth, sizeParser, writer, protocol.TransferTypePacket, padding), nil } - return buf.NewWriter(writer) + return buf.NewWriter(writer), nil case protocol.SecurityType_LEGACY: if request.Option.Has(protocol.RequestOptionChunkStream) { @@ -481,9 +489,9 @@ func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ NonceGenerator: crypto.GenerateEmptyBytes(), AdditionalDataGenerator: crypto.GenerateEmptyBytes(), } - return crypto.NewAuthenticationWriter(auth, sizeParser, s.responseWriter, request.Command.TransferType(), padding) + return crypto.NewAuthenticationWriter(auth, sizeParser, s.responseWriter, request.Command.TransferType(), padding), nil } - return &buf.SequentialWriter{Writer: s.responseWriter} + return &buf.SequentialWriter{Writer: s.responseWriter}, nil case protocol.SecurityType_AES128_GCM: aead := crypto.NewAesGcm(s.responseBodyKey[:]) @@ -503,7 +511,7 @@ func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ } sizeParser = NewAEADSizeParser(lengthAuth) } - return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding) + return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding), nil case protocol.SecurityType_CHACHA20_POLY1305: aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(s.responseBodyKey[:])) @@ -525,9 +533,9 @@ func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ } sizeParser = NewAEADSizeParser(lengthAuth) } - return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding) + return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding), nil default: - panic("Unknown security type.") + return nil, newError("invalid option: Security") } } diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 70eaa2575..25e2e59fa 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -182,8 +182,10 @@ func (h *Handler) RemoveUser(ctx context.Context, email string) error { func transferResponse(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, response *protocol.ResponseHeader, input buf.Reader, output *buf.BufferedWriter) error { session.EncodeResponseHeader(response, output) - bodyWriter := session.EncodeResponseBody(request, output) - + bodyWriter, err := session.EncodeResponseBody(request, output) + if err != nil { + return newError("failed to start decoding response").Base(err) + } { // Optimize for small response packet data, err := input.ReadMultiBuffer() @@ -290,7 +292,10 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i requestDone := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) - bodyReader := svrSession.DecodeRequestBody(request, reader) + bodyReader, err := svrSession.DecodeRequestBody(request, reader) + if err != nil { + return newError("failed to start decoding").Base(err) + } if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil { return newError("failed to transfer request").Base(err) } diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index e3b9daf96..50ca01564 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -151,7 +151,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte return newError("failed to encode request").Base(err).AtWarning() } - bodyWriter := session.EncodeRequestBody(request, writer) + bodyWriter, err := session.EncodeRequestBody(request, writer) + if err != nil { + return newError("failed to start encoding").Base(err) + } if err := buf.CopyOnceTimeout(input, bodyWriter, time.Millisecond*100); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout { return newError("failed to write first payload").Base(err) } @@ -183,8 +186,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } h.handleCommand(rec.Destination(), header.Command) - bodyReader := session.DecodeResponseBody(request, reader) - + bodyReader, err := session.DecodeResponseBody(request, reader) + if err != nil { + return newError("failed to start encoding response").Base(err) + } return buf.Copy(bodyReader, output, buf.UpdateActivity(timer)) }