1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-26 20:16:55 -05:00

Add full VLESS fallbacks support to Trojan (#254)

* Add full VLESS fallbacks support to Trojan

* Adjustments according to linter

* Use common.Must2() for pro.Write()
This commit is contained in:
RPRX 2020-10-03 13:12:35 +00:00 committed by GitHub
parent 2e4042ebd7
commit 271532fc84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 278 additions and 76 deletions

View File

@ -1,9 +1,11 @@
package conf package conf
import ( import (
"encoding/json"
"strconv" "strconv"
"github.com/golang/protobuf/proto" // nolint: staticcheck "github.com/golang/protobuf/proto" // nolint: staticcheck
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"
@ -68,8 +70,11 @@ func (c *TrojanClientConfig) Build() (proto.Message, error) {
// TrojanInboundFallback is fallback configuration // TrojanInboundFallback is fallback configuration
type TrojanInboundFallback struct { type TrojanInboundFallback struct {
Alpn string `json:"alpn"`
Path string `json:"path"`
Type string `json:"type"` Type string `json:"type"`
Dest string `json:"dest"` Dest json.RawMessage `json:"dest"`
Xver uint64 `json:"xver"`
} }
// TrojanUserConfig is user configuration // TrojanUserConfig is user configuration
@ -82,7 +87,8 @@ type TrojanUserConfig struct {
// TrojanServerConfig is Inbound configuration // TrojanServerConfig is Inbound configuration
type TrojanServerConfig struct { type TrojanServerConfig struct {
Clients []*TrojanUserConfig `json:"clients"` Clients []*TrojanUserConfig `json:"clients"`
Fallback *TrojanInboundFallback `json:"fallback"` Fallback json.RawMessage `json:"fallback"`
Fallbacks []*TrojanInboundFallback `json:"fallbacks"`
} }
// Build implements Buildable // Build implements Buildable
@ -107,11 +113,37 @@ func (c *TrojanServerConfig) Build() (proto.Message, error) {
} }
if c.Fallback != nil { if c.Fallback != nil {
fb := &trojan.Fallback{ return nil, newError(`Trojan settings: please use "fallbacks":[{}] instead of "fallback":{}`)
Dest: c.Fallback.Dest, }
for _, fb := range c.Fallbacks {
var i uint16
var s string
if err := json.Unmarshal(fb.Dest, &i); err == nil {
s = strconv.Itoa(int(i))
} else {
_ = json.Unmarshal(fb.Dest, &s)
}
config.Fallbacks = append(config.Fallbacks, &trojan.Fallback{
Alpn: fb.Alpn,
Path: fb.Path,
Type: fb.Type,
Dest: s,
Xver: fb.Xver,
})
}
for _, fb := range config.Fallbacks {
/*
if fb.Alpn == "h2" && fb.Path != "" {
return nil, newError(`Trojan fallbacks: "alpn":"h2" doesn't support "path"`)
}
*/
if fb.Path != "" && fb.Path[0] != '/' {
return nil, newError(`Trojan fallbacks: "path" must be empty or start with "/"`)
} }
if fb.Type == "" && fb.Dest != "" { if fb.Type == "" && fb.Dest != "" {
if fb.Dest == "serve-ws-none" {
fb.Type = "serve"
} else {
switch fb.Dest[0] { switch fb.Dest[0] {
case '@', '/': case '@', '/':
fb.Type = "unix" fb.Type = "unix"
@ -124,11 +156,13 @@ func (c *TrojanServerConfig) Build() (proto.Message, error) {
} }
} }
} }
if fb.Type == "" {
return nil, newError("please fill in a valid value for trojan fallback type")
} }
if fb.Type == "" {
config.Fallback = fb return nil, newError(`Trojan fallbacks: please fill in a valid value for every "dest"`)
}
if fb.Xver > 2 {
return nil, newError(`Trojan fallbacks: invalid PROXY protocol version, "xver" only accepts 0, 1, 2`)
}
} }
return config, nil return config, nil

View File

@ -78,8 +78,11 @@ type Fallback struct {
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"` Alpn string `protobuf:"bytes,1,opt,name=alpn,proto3" json:"alpn,omitempty"`
Dest string `protobuf:"bytes,2,opt,name=dest,proto3" json:"dest,omitempty"` Path string `protobuf:"bytes,2,opt,name=path,proto3" json:"path,omitempty"`
Type string `protobuf:"bytes,3,opt,name=type,proto3" json:"type,omitempty"`
Dest string `protobuf:"bytes,4,opt,name=dest,proto3" json:"dest,omitempty"`
Xver uint64 `protobuf:"varint,5,opt,name=xver,proto3" json:"xver,omitempty"`
} }
func (x *Fallback) Reset() { func (x *Fallback) Reset() {
@ -114,6 +117,20 @@ func (*Fallback) Descriptor() ([]byte, []int) {
return file_proxy_trojan_config_proto_rawDescGZIP(), []int{1} return file_proxy_trojan_config_proto_rawDescGZIP(), []int{1}
} }
func (x *Fallback) GetAlpn() string {
if x != nil {
return x.Alpn
}
return ""
}
func (x *Fallback) GetPath() string {
if x != nil {
return x.Path
}
return ""
}
func (x *Fallback) GetType() string { func (x *Fallback) GetType() string {
if x != nil { if x != nil {
return x.Type return x.Type
@ -128,6 +145,13 @@ func (x *Fallback) GetDest() string {
return "" return ""
} }
func (x *Fallback) GetXver() uint64 {
if x != nil {
return x.Xver
}
return 0
}
type ClientConfig struct { type ClientConfig struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
@ -181,7 +205,7 @@ type ServerConfig struct {
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
Users []*protocol.User `protobuf:"bytes,1,rep,name=users,proto3" json:"users,omitempty"` Users []*protocol.User `protobuf:"bytes,1,rep,name=users,proto3" json:"users,omitempty"`
Fallback *Fallback `protobuf:"bytes,2,opt,name=fallback,proto3" json:"fallback,omitempty"` Fallbacks []*Fallback `protobuf:"bytes,3,rep,name=fallbacks,proto3" json:"fallbacks,omitempty"`
} }
func (x *ServerConfig) Reset() { func (x *ServerConfig) Reset() {
@ -223,9 +247,9 @@ func (x *ServerConfig) GetUsers() []*protocol.User {
return nil return nil
} }
func (x *ServerConfig) GetFallback() *Fallback { func (x *ServerConfig) GetFallbacks() []*Fallback {
if x != nil { if x != nil {
return x.Fallback return x.Fallbacks
} }
return nil return nil
} }
@ -242,30 +266,34 @@ var file_proxy_trojan_config_proto_rawDesc = []byte{
0x6c, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x73, 0x70, 0x65, 0x63, 0x2e, 0x70, 0x72, 0x6c, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x73, 0x70, 0x65, 0x63, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x22, 0x25, 0x0a, 0x07, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x1a, 0x6f, 0x74, 0x6f, 0x22, 0x25, 0x0a, 0x07, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x1a,
0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x32, 0x0a, 0x08, 0x46, 0x61, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x6e, 0x0a, 0x08, 0x46, 0x61,
0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x61, 0x6c, 0x70, 0x6e, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x65, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x61, 0x6c, 0x70, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61,
0x73, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x64, 0x65, 0x73, 0x74, 0x22, 0x52, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x12,
0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x42, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x79,
0x0a, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2a, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x65, 0x73, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09,
0x52, 0x04, 0x64, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x78, 0x76, 0x65, 0x72, 0x18, 0x05,
0x20, 0x01, 0x28, 0x04, 0x52, 0x04, 0x78, 0x76, 0x65, 0x72, 0x22, 0x52, 0x0a, 0x0c, 0x43, 0x6c,
0x69, 0x65, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x42, 0x0a, 0x06, 0x73, 0x65,
0x72, 0x76, 0x65, 0x72, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2a, 0x2e, 0x76, 0x32, 0x72,
0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x45, 0x6e,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x52, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x22, 0x87,
0x01, 0x0a, 0x0c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
0x36, 0x0a, 0x05, 0x75, 0x73, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x20,
0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x6d,
0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x55, 0x73, 0x65, 0x72,
0x65, 0x72, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x52, 0x06, 0x73, 0x65, 0x72, 0x76, 0x52, 0x05, 0x75, 0x73, 0x65, 0x72, 0x73, 0x12, 0x3f, 0x0a, 0x09, 0x66, 0x61, 0x6c, 0x6c, 0x62,
0x65, 0x72, 0x22, 0x85, 0x01, 0x0a, 0x0c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x61, 0x63, 0x6b, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x76, 0x32, 0x72,
0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x05, 0x75, 0x73, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x74, 0x72,
0x28, 0x0b, 0x32, 0x20, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x6f, 0x6a, 0x61, 0x6e, 0x2e, 0x46, 0x61, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x52, 0x09, 0x66,
0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x61, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x73, 0x42, 0x56, 0x0a, 0x1b, 0x63, 0x6f, 0x6d, 0x2e,
0x55, 0x73, 0x65, 0x72, 0x52, 0x05, 0x75, 0x73, 0x65, 0x72, 0x73, 0x12, 0x3d, 0x0a, 0x08, 0x66,
0x61, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e,
0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79,
0x2e, 0x74, 0x72, 0x6f, 0x6a, 0x61, 0x6e, 0x2e, 0x46, 0x61, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x2e, 0x74, 0x72, 0x6f, 0x6a, 0x61, 0x6e, 0x50, 0x01, 0x5a, 0x1b, 0x76, 0x32, 0x72, 0x61, 0x79,
0x52, 0x08, 0x66, 0x61, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x42, 0x56, 0x0a, 0x1b, 0x63, 0x6f, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f,
0x6d, 0x2e, 0x76, 0x32, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x72, 0x6f, 0x6a, 0x61, 0x6e, 0xaa, 0x02, 0x17, 0x56, 0x32, 0x52, 0x61, 0x79, 0x2e, 0x43,
0x78, 0x79, 0x2e, 0x74, 0x72, 0x6f, 0x6a, 0x61, 0x6e, 0x50, 0x01, 0x5a, 0x1b, 0x76, 0x32, 0x72, 0x6f, 0x72, 0x65, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x54, 0x72, 0x6f, 0x6a, 0x61, 0x6e,
0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x79, 0x2f, 0x74, 0x72, 0x6f, 0x6a, 0x61, 0x6e, 0xaa, 0x02, 0x17, 0x56, 0x32, 0x52, 0x61, 0x79,
0x2e, 0x43, 0x6f, 0x72, 0x65, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x54, 0x72, 0x6f, 0x6a,
0x61, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (
@ -292,7 +320,7 @@ var file_proxy_trojan_config_proto_goTypes = []interface{}{
var file_proxy_trojan_config_proto_depIdxs = []int32{ var file_proxy_trojan_config_proto_depIdxs = []int32{
4, // 0: v2ray.core.proxy.trojan.ClientConfig.server:type_name -> v2ray.core.common.protocol.ServerEndpoint 4, // 0: v2ray.core.proxy.trojan.ClientConfig.server:type_name -> v2ray.core.common.protocol.ServerEndpoint
5, // 1: v2ray.core.proxy.trojan.ServerConfig.users:type_name -> v2ray.core.common.protocol.User 5, // 1: v2ray.core.proxy.trojan.ServerConfig.users:type_name -> v2ray.core.common.protocol.User
1, // 2: v2ray.core.proxy.trojan.ServerConfig.fallback:type_name -> v2ray.core.proxy.trojan.Fallback 1, // 2: v2ray.core.proxy.trojan.ServerConfig.fallbacks:type_name -> v2ray.core.proxy.trojan.Fallback
3, // [3:3] is the sub-list for method output_type 3, // [3:3] is the sub-list for method output_type
3, // [3:3] is the sub-list for method input_type 3, // [3:3] is the sub-list for method input_type
3, // [3:3] is the sub-list for extension type_name 3, // [3:3] is the sub-list for extension type_name

View File

@ -14,8 +14,11 @@ message Account {
} }
message Fallback { message Fallback {
string type = 1; string alpn = 1;
string dest = 2; string path = 2;
string type = 3;
string dest = 4;
uint64 xver = 5;
} }
message ClientConfig { message ClientConfig {
@ -24,5 +27,5 @@ message ClientConfig {
message ServerConfig { message ServerConfig {
repeated v2ray.core.common.protocol.User users = 1; repeated v2ray.core.common.protocol.User users = 1;
Fallback fallback = 2; repeated Fallback fallbacks = 3;
} }

View File

@ -4,7 +4,9 @@ package trojan
import ( import (
"context" "context"
"crypto/tls"
"io" "io"
"strconv"
"time" "time"
"v2ray.com/core" "v2ray.com/core"
@ -23,6 +25,7 @@ import (
"v2ray.com/core/features/routing" "v2ray.com/core/features/routing"
"v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/udp" "v2ray.com/core/transport/internet/udp"
"v2ray.com/core/transport/internet/xtls"
) )
func init() { func init() {
@ -33,9 +36,9 @@ func init() {
// Server is an inbound connection handler that handles messages in trojan protocol. // Server is an inbound connection handler that handles messages in trojan protocol.
type Server struct { type Server struct {
validator *Validator
policyManager policy.Manager policyManager policy.Manager
config *ServerConfig validator *Validator
fallbacks map[string]map[string]*Fallback // or nil
} }
// NewServer creates a new trojan inbound handler. // NewServer creates a new trojan inbound handler.
@ -56,7 +59,27 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
server := &Server{ server := &Server{
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
validator: validator, validator: validator,
config: config, }
if config.Fallbacks != nil {
server.fallbacks = make(map[string]map[string]*Fallback)
for _, fb := range config.Fallbacks {
if server.fallbacks[fb.Alpn] == nil {
server.fallbacks[fb.Alpn] = make(map[string]*Fallback)
}
server.fallbacks[fb.Alpn][fb.Path] = fb
}
if server.fallbacks[""] != nil {
for alpn, pfb := range server.fallbacks {
if alpn != "" { // && alpn != "h2" {
for path, fb := range server.fallbacks[""] {
if pfb[path] == nil {
pfb[path] = fb
}
}
}
}
}
} }
return server, nil return server, nil
@ -69,54 +92,68 @@ func (s *Server) Network() []net.Network {
// Process implements proxy.Inbound.Process(). // Process implements proxy.Inbound.Process().
func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error { // nolint: funlen,lll func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error { // nolint: funlen,lll
sid := session.ExportIDToError(ctx)
iConn := conn
if statConn, ok := iConn.(*internet.StatCouterConnection); ok {
iConn = statConn.Connection
}
sessionPolicy := s.policyManager.ForLevel(0) sessionPolicy := s.policyManager.ForLevel(0)
if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil {
return newError("unable to set read deadline").Base(err).AtWarning() return newError("unable to set read deadline").Base(err).AtWarning()
} }
buffer := buf.New() first := buf.New()
defer buffer.Release() defer first.Release()
n, err := buffer.ReadFrom(conn) firstLen, err := first.ReadFrom(conn)
if err != nil { if err != nil {
return newError("failed to read first request").Base(err) return newError("failed to read first request").Base(err)
} }
newError("firstLen = ", firstLen).AtInfo().WriteToLog(sid)
bufferedReader := &buf.BufferedReader{ bufferedReader := &buf.BufferedReader{
Reader: buf.NewReader(conn), Reader: buf.NewReader(conn),
Buffer: buf.MultiBuffer{buffer}, Buffer: buf.MultiBuffer{first},
} }
var user *protocol.MemoryUser var user *protocol.MemoryUser
fallbackEnabled := s.config.Fallback != nil
apfb := s.fallbacks
isfb := apfb != nil
shouldFallback := false shouldFallback := false
if n < 56 { // nolint: gomnd if firstLen < 58 || first.Byte(56) != '\r' { // nolint: gomnd
// invalid protocol // invalid protocol
err = newError("not trojan protocol")
log.Record(&log.AccessMessage{ log.Record(&log.AccessMessage{
From: conn.RemoteAddr(), From: conn.RemoteAddr(),
To: "", To: "",
Status: log.AccessRejected, Status: log.AccessRejected,
Reason: newError("not trojan protocol"), Reason: err,
}) })
shouldFallback = true shouldFallback = true
} else { } else {
user = s.validator.Get(hexString(buffer.BytesTo(56))) // nolint: gomnd user = s.validator.Get(hexString(first.BytesTo(56))) // nolint: gomnd
if user == nil { if user == nil {
// invalid user, let's fallback // invalid user, let's fallback
err = newError("not a valid user")
log.Record(&log.AccessMessage{ log.Record(&log.AccessMessage{
From: conn.RemoteAddr(), From: conn.RemoteAddr(),
To: "", To: "",
Status: log.AccessRejected, Status: log.AccessRejected,
Reason: newError("not a valid user"), Reason: err,
}) })
shouldFallback = true shouldFallback = true
} }
} }
if fallbackEnabled && shouldFallback { if isfb && shouldFallback {
return s.fallback(ctx, sessionPolicy, bufferedReader, buf.NewWriter(conn)) return s.fallback(ctx, sid, err, sessionPolicy, conn, iConn, apfb, first, firstLen, bufferedReader)
} else if shouldFallback { } else if shouldFallback {
return newError("invalid protocol or invalid user") return newError("invalid protocol or invalid user")
} }
@ -158,7 +195,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
Email: user.Email, Email: user.Email,
}) })
newError("received request for ", destination).WriteToLog(session.ExportIDToError(ctx)) newError("received request for ", destination).WriteToLog(sid)
return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher) return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher)
} }
@ -240,15 +277,70 @@ func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Sess
return nil return nil
} }
func (s *Server) fallback(ctx context.Context, sessionPolicy policy.Session, requestReader buf.Reader, responseWriter buf.Writer) error { // nolint: lll func (s *Server) fallback(ctx context.Context, sid errors.ExportOption, err error, sessionPolicy policy.Session, connection internet.Connection, iConn internet.Connection, apfb map[string]map[string]*Fallback, first *buf.Buffer, firstLen int64, reader buf.Reader) error { // nolint: lll
if err := connection.SetReadDeadline(time.Time{}); err != nil {
newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid)
}
newError("fallback starts").Base(err).AtInfo().WriteToLog(sid)
alpn := ""
if len(apfb) > 1 || apfb[""] == nil {
if tlsConn, ok := iConn.(*tls.Conn); ok {
alpn = tlsConn.ConnectionState().NegotiatedProtocol
newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid)
} else if xtlsConn, ok := iConn.(*xtls.Conn); ok {
alpn = xtlsConn.ConnectionState().NegotiatedProtocol
newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid)
}
if apfb[alpn] == nil {
alpn = ""
}
}
pfb := apfb[alpn]
if pfb == nil {
return newError(`failed to find the default "alpn" config`).AtWarning()
}
path := ""
if len(pfb) > 1 || pfb[""] == nil {
if firstLen >= 18 && first.Byte(4) != '*' { // not h2c
firstBytes := first.Bytes()
for i := 4; i <= 8; i++ { // 5 -> 9
if firstBytes[i] == '/' && firstBytes[i-1] == ' ' {
search := len(firstBytes)
if search > 64 {
search = 64 // up to about 60
}
for j := i + 1; j < search; j++ {
k := firstBytes[j]
if k == '\r' || k == '\n' { // avoid logging \r or \n
break
}
if k == ' ' {
path = string(firstBytes[i:j])
newError("realPath = " + path).AtInfo().WriteToLog(sid)
if pfb[path] == nil {
path = ""
}
break
}
}
break
}
}
}
}
fb := pfb[path]
if fb == nil {
return newError(`failed to find the default "path" config`).AtWarning()
}
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
var conn net.Conn var conn net.Conn
var err error if err := retry.ExponentialBackoff(5, 100).On(func() error {
fb := s.config.Fallback
if err := retry.ExponentialBackoff(5, 100).On(func() error { // nolint: gomnd
var dialer net.Dialer var dialer net.Dialer
conn, err = dialer.DialContext(ctx, fb.Type, fb.Dest) conn, err = dialer.DialContext(ctx, fb.Type, fb.Dest)
if err != nil { if err != nil {
@ -263,24 +355,69 @@ func (s *Server) fallback(ctx context.Context, sessionPolicy policy.Session, req
serverReader := buf.NewReader(conn) serverReader := buf.NewReader(conn)
serverWriter := buf.NewWriter(conn) serverWriter := buf.NewWriter(conn)
requestDone := func() error { postRequest := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
if fb.Xver != 0 {
if err := buf.Copy(requestReader, serverWriter, buf.UpdateActivity(timer)); err != nil { remoteAddr, remotePort, err := net.SplitHostPort(connection.RemoteAddr().String())
if err != nil {
return err
}
localAddr, localPort, err := net.SplitHostPort(connection.LocalAddr().String())
if err != nil {
return err
}
ipv4 := true
for i := 0; i < len(remoteAddr); i++ {
if remoteAddr[i] == ':' {
ipv4 = false
break
}
}
pro := buf.New()
defer pro.Release()
switch fb.Xver {
case 1:
if ipv4 {
common.Must2(pro.Write([]byte("PROXY TCP4 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n")))
} else {
common.Must2(pro.Write([]byte("PROXY TCP6 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n")))
}
case 2:
common.Must2(pro.Write([]byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A\x21"))) // signature + v2 + PROXY
if ipv4 {
common.Must2(pro.Write([]byte("\x11\x00\x0C"))) // AF_INET + STREAM + 12 bytes
common.Must2(pro.Write(net.ParseIP(remoteAddr).To4()))
common.Must2(pro.Write(net.ParseIP(localAddr).To4()))
} else {
common.Must2(pro.Write([]byte("\x21\x00\x24"))) // AF_INET6 + STREAM + 36 bytes
common.Must2(pro.Write(net.ParseIP(remoteAddr).To16()))
common.Must2(pro.Write(net.ParseIP(localAddr).To16()))
}
p1, _ := strconv.ParseUint(remotePort, 10, 16)
p2, _ := strconv.ParseUint(localPort, 10, 16)
common.Must2(pro.Write([]byte{byte(p1 >> 8), byte(p1), byte(p2 >> 8), byte(p2)}))
}
if err := serverWriter.WriteMultiBuffer(buf.MultiBuffer{pro}); err != nil {
return newError("failed to set PROXY protocol v", fb.Xver).Base(err).AtWarning()
}
}
if err := buf.Copy(reader, serverWriter, buf.UpdateActivity(timer)); err != nil {
return newError("failed to fallback request payload").Base(err).AtInfo() return newError("failed to fallback request payload").Base(err).AtInfo()
} }
return nil return nil
} }
responseDone := func() error { writer := buf.NewWriter(connection)
getResponse := func() error {
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
if err := buf.Copy(serverReader, responseWriter, buf.UpdateActivity(timer)); err != nil { if err := buf.Copy(serverReader, writer, buf.UpdateActivity(timer)); err != nil {
return newError("failed to deliver response payload").Base(err).AtInfo() return newError("failed to deliver response payload").Base(err).AtInfo()
} }
return nil return nil
} }
if err := task.Run(ctx, task.OnSuccess(requestDone, task.Close(serverWriter)), task.OnSuccess(responseDone, task.Close(responseWriter))); err != nil { // nolint: lll if err := task.Run(ctx, task.OnSuccess(postRequest, task.Close(serverWriter)), task.OnSuccess(getResponse, task.Close(writer))); err != nil {
common.Must(common.Interrupt(serverReader)) common.Must(common.Interrupt(serverReader))
common.Must(common.Interrupt(serverWriter)) common.Must(common.Interrupt(serverWriter))
return newError("fallback ends").Base(err).AtInfo() return newError("fallback ends").Base(err).AtInfo()