1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-30 05:56:54 -05:00

Refine code according to golangci-lint results

This commit is contained in:
loyalsoldier 2020-10-11 19:22:46 +08:00
parent ff6bb51732
commit 784775f689
No known key found for this signature in database
GPG Key ID: 23829BBC1ACF2C90
118 changed files with 515 additions and 532 deletions

View File

@ -22,9 +22,9 @@ func (l *OutboundListener) add(conn net.Conn) {
select { select {
case l.buffer <- conn: case l.buffer <- conn:
case <-l.done.Wait(): case <-l.done.Wait():
conn.Close() // nolint: errcheck conn.Close()
default: default:
conn.Close() // nolint: errcheck conn.Close()
} }
} }
@ -45,7 +45,7 @@ L:
for { for {
select { select {
case c := <-l.buffer: case c := <-l.buffer:
c.Close() // nolint: errcheck c.Close()
default: default:
break L break L
} }

View File

@ -89,7 +89,6 @@ func Test_parseResponse(t *testing.T) {
} }
func Test_buildReqMsgs(t *testing.T) { func Test_buildReqMsgs(t *testing.T) {
stubID := func() uint16 { stubID := func() uint16 {
return uint16(rand.Uint32()) return uint16(rand.Uint32())
} }

View File

@ -43,7 +43,6 @@ type DoHNameServer struct {
// NewDoHNameServer creates DOH client object for remote resolving // NewDoHNameServer creates DOH client object for remote resolving
func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) { func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) {
newError("DNS: created Remote DOH client for ", url.String()).AtInfo().WriteToLog() newError("DNS: created Remote DOH client for ", url.String()).AtInfo().WriteToLog()
s := baseDOHNameServer(url, "DOH", clientIP) s := baseDOHNameServer(url, "DOH", clientIP)
@ -112,7 +111,6 @@ func NewDoHLocalNameServer(url *url.URL, clientIP net.IP) *DoHNameServer {
} }
func baseDOHNameServer(url *url.URL, prefix string, clientIP net.IP) *DoHNameServer { func baseDOHNameServer(url *url.URL, prefix string, clientIP net.IP) *DoHNameServer {
s := &DoHNameServer{ s := &DoHNameServer{
ips: make(map[string]record), ips: make(map[string]record),
clientIP: clientIP, clientIP: clientIP,

View File

@ -66,7 +66,8 @@ func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDoma
} }
id := g.Add(matcher) id := g.Add(matcher)
ips := make([]net.Address, 0, len(mapping.Ip)+1) ips := make([]net.Address, 0, len(mapping.Ip)+1)
if len(mapping.Ip) > 0 { switch {
case len(mapping.Ip) > 0:
for _, ip := range mapping.Ip { for _, ip := range mapping.Ip {
addr := net.IPAddress(ip) addr := net.IPAddress(ip)
if addr == nil { if addr == nil {
@ -74,9 +75,11 @@ func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDoma
} }
ips = append(ips, addr) ips = append(ips, addr)
} }
} else if len(mapping.ProxiedDomain) > 0 {
case len(mapping.ProxiedDomain) > 0:
ips = append(ips, net.DomainAddress(mapping.ProxiedDomain)) ips = append(ips, net.DomainAddress(mapping.ProxiedDomain))
} else {
default:
return nil, newError("neither IP address nor proxied domain specified for domain: ", mapping.Domain).AtWarning() return nil, newError("neither IP address nor proxied domain specified for domain: ", mapping.Domain).AtWarning()
} }

View File

@ -24,11 +24,11 @@ type Client interface {
QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error)
} }
type localNameServer struct { type LocalNameServer struct {
client *localdns.Client client *localdns.Client
} }
func (s *localNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) { func (s *LocalNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
if option.IPv4Enable && option.IPv6Enable { if option.IPv4Enable && option.IPv6Enable {
return s.client.LookupIP(domain) return s.client.LookupIP(domain)
} }
@ -44,13 +44,13 @@ func (s *localNameServer) QueryIP(ctx context.Context, domain string, option IPO
return nil, newError("neither IPv4 nor IPv6 is enabled") return nil, newError("neither IPv4 nor IPv6 is enabled")
} }
func (s *localNameServer) Name() string { func (s *LocalNameServer) Name() string {
return "localhost" return "localhost"
} }
func NewLocalNameServer() *localNameServer { func NewLocalNameServer() *LocalNameServer {
newError("DNS: created localhost client").AtInfo().WriteToLog() newError("DNS: created localhost client").AtInfo().WriteToLog()
return &localNameServer{ return &LocalNameServer{
client: localdns.New(), client: localdns.New(),
} }
} }

View File

@ -97,7 +97,9 @@ func New(ctx context.Context, config *Config) (*Server, error) {
addNameServer := func(ns *NameServer) int { addNameServer := func(ns *NameServer) int {
endpoint := ns.Address endpoint := ns.Address
address := endpoint.Address.AsAddress() address := endpoint.Address.AsAddress()
if address.Family().IsDomain() && address.Domain() == "localhost" {
switch {
case address.Family().IsDomain() && address.Domain() == "localhost":
server.clients = append(server.clients, NewLocalNameServer()) server.clients = append(server.clients, NewLocalNameServer())
// Priotize local domains with specific TLDs or without any dot to local DNS // Priotize local domains with specific TLDs or without any dot to local DNS
// References: // References:
@ -115,7 +117,8 @@ func New(ctx context.Context, config *Config) (*Server, error) {
{Type: DomainMatchingType_Subdomain, Domain: "test"}, {Type: DomainMatchingType_Subdomain, Domain: "test"},
} }
ns.PrioritizedDomain = append(ns.PrioritizedDomain, localTLDsAndDotlessDomains...) ns.PrioritizedDomain = append(ns.PrioritizedDomain, localTLDsAndDotlessDomains...)
} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "https+local://") {
case address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "https+local://"):
// URI schemed string treated as domain // URI schemed string treated as domain
// DOH Local mode // DOH Local mode
u, err := url.Parse(address.Domain()) u, err := url.Parse(address.Domain())
@ -123,7 +126,8 @@ func New(ctx context.Context, config *Config) (*Server, error) {
log.Fatalln(newError("DNS config error").Base(err)) log.Fatalln(newError("DNS config error").Base(err))
} }
server.clients = append(server.clients, NewDoHLocalNameServer(u, server.clientIP)) server.clients = append(server.clients, NewDoHLocalNameServer(u, server.clientIP))
} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "https://") {
case address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "https://"):
// DOH Remote mode // DOH Remote mode
u, err := url.Parse(address.Domain()) u, err := url.Parse(address.Domain())
if err != nil { if err != nil {
@ -140,7 +144,8 @@ func New(ctx context.Context, config *Config) (*Server, error) {
} }
server.clients[idx] = c server.clients[idx] = c
})) }))
} else {
default:
// UDP classic DNS mode // UDP classic DNS mode
dest := endpoint.AsDestination() dest := endpoint.AsDestination()
if dest.Network == net.Network_Unknown { if dest.Network == net.Network_Unknown {

View File

@ -42,7 +42,8 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
for _, q := range r.Question { for _, q := range r.Question {
if q.Name == "google.com." && q.Qtype == dns.TypeA { switch {
case q.Name == "google.com." && q.Qtype == dns.TypeA:
if clientIP == nil { if clientIP == nil {
rr, _ := dns.NewRR("google.com. IN A 8.8.8.8") rr, _ := dns.NewRR("google.com. IN A 8.8.8.8")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
@ -50,44 +51,57 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
rr, _ := dns.NewRR("google.com. IN A 8.8.4.4") rr, _ := dns.NewRR("google.com. IN A 8.8.4.4")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} }
} else if q.Name == "api.google.com." && q.Qtype == dns.TypeA {
case q.Name == "api.google.com." && q.Qtype == dns.TypeA:
rr, _ := dns.NewRR("api.google.com. IN A 8.8.7.7") rr, _ := dns.NewRR("api.google.com. IN A 8.8.7.7")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "v2.api.google.com." && q.Qtype == dns.TypeA {
case q.Name == "v2.api.google.com." && q.Qtype == dns.TypeA:
rr, _ := dns.NewRR("v2.api.google.com. IN A 8.8.7.8") rr, _ := dns.NewRR("v2.api.google.com. IN A 8.8.7.8")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "facebook.com." && q.Qtype == dns.TypeA {
case q.Name == "facebook.com." && q.Qtype == dns.TypeA:
rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9") rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "ipv6.google.com." && q.Qtype == dns.TypeA {
case q.Name == "ipv6.google.com." && q.Qtype == dns.TypeA:
rr, err := dns.NewRR("ipv6.google.com. IN A 8.8.8.7") rr, err := dns.NewRR("ipv6.google.com. IN A 8.8.8.7")
common.Must(err) common.Must(err)
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "ipv6.google.com." && q.Qtype == dns.TypeAAAA {
case q.Name == "ipv6.google.com." && q.Qtype == dns.TypeAAAA:
rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888") rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888")
common.Must(err) common.Must(err)
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA {
case q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA:
ans.MsgHdr.Rcode = dns.RcodeNameError ans.MsgHdr.Rcode = dns.RcodeNameError
} else if q.Name == "hostname." && q.Qtype == dns.TypeA {
case q.Name == "hostname." && q.Qtype == dns.TypeA:
rr, _ := dns.NewRR("hostname. IN A 127.0.0.1") rr, _ := dns.NewRR("hostname. IN A 127.0.0.1")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "hostname.local." && q.Qtype == dns.TypeA {
case q.Name == "hostname.local." && q.Qtype == dns.TypeA:
rr, _ := dns.NewRR("hostname.local. IN A 127.0.0.1") rr, _ := dns.NewRR("hostname.local. IN A 127.0.0.1")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "hostname.localdomain." && q.Qtype == dns.TypeA {
case q.Name == "hostname.localdomain." && q.Qtype == dns.TypeA:
rr, _ := dns.NewRR("hostname.localdomain. IN A 127.0.0.1") rr, _ := dns.NewRR("hostname.localdomain. IN A 127.0.0.1")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "localhost." && q.Qtype == dns.TypeA {
case q.Name == "localhost." && q.Qtype == dns.TypeA:
rr, _ := dns.NewRR("localhost. IN A 127.0.0.2") rr, _ := dns.NewRR("localhost. IN A 127.0.0.2")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "localhost-a." && q.Qtype == dns.TypeA {
case q.Name == "localhost-a." && q.Qtype == dns.TypeA:
rr, _ := dns.NewRR("localhost-a. IN A 127.0.0.3") rr, _ := dns.NewRR("localhost-a. IN A 127.0.0.3")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "localhost-b." && q.Qtype == dns.TypeA {
case q.Name == "localhost-b." && q.Qtype == dns.TypeA:
rr, _ := dns.NewRR("localhost-b. IN A 127.0.0.4") rr, _ := dns.NewRR("localhost-b. IN A 127.0.0.4")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "Mijia\\ Cloud." && q.Qtype == dns.TypeA {
case q.Name == "Mijia\\ Cloud." && q.Qtype == dns.TypeA:
rr, _ := dns.NewRR("Mijia\\ Cloud. IN A 127.0.0.1") rr, _ := dns.NewRR("Mijia\\ Cloud. IN A 127.0.0.1")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} }

View File

@ -36,7 +36,6 @@ type ClassicNameServer struct {
} }
func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer { func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
// default to 53 if unspecific // default to 53 if unspecific
if address.Port == 0 { if address.Port == 0 {
address.Port = net.Port(53) address.Port = net.Port(53)
@ -105,7 +104,6 @@ func (s *ClassicNameServer) Cleanup() error {
} }
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) { func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
ipRec, err := parseResponse(packet.Payload.Bytes()) ipRec, err := parseResponse(packet.Payload.Bytes())
if err != nil { if err != nil {
newError(s.name, " fail to parse responded DNS udp").AtError().WriteToLog() newError(s.name, " fail to parse responded DNS udp").AtError().WriteToLog()
@ -240,7 +238,6 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]
} }
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) { func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
fqdn := Fqdn(domain) fqdn := Fqdn(domain)
ips, err := s.findIPsForDomain(fqdn, option) ips, err := s.findIPsForDomain(fqdn, option)

View File

@ -127,10 +127,10 @@ func (g *Instance) Close() error {
g.active = false g.active = false
common.Close(g.accessLogger) // nolint: errcheck common.Close(g.accessLogger)
g.accessLogger = nil g.accessLogger = nil
common.Close(g.errorLogger) // nolint: errcheck common.Close(g.errorLogger)
g.errorLogger = nil g.errorLogger = nil
return nil return nil

View File

@ -280,7 +280,7 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest
conn, existing := w.getConnection(id) conn, existing := w.getConnection(id)
// payload will be discarded in pipe is full. // payload will be discarded in pipe is full.
conn.writer.WriteMultiBuffer(buf.MultiBuffer{b}) // nolint: errcheck conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
if !existing { if !existing {
common.Must(w.checker.Start()) common.Must(w.checker.Start())
@ -303,7 +303,7 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest
if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil { if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil {
newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx)) newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
} }
conn.Close() // nolint: errcheck conn.Close()
w.removeConn(id) w.removeConn(id)
}() }()
} }
@ -332,9 +332,9 @@ func (w *udpWorker) clean() error {
} }
for addr, conn := range w.activeConn { for addr, conn := range w.activeConn {
if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 8 { //TODO Timeout too small if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 8 { // TODO Timeout too small
delete(w.activeConn, addr) delete(w.activeConn, addr)
conn.Close() // nolint: errcheck conn.Close()
} }
} }

View File

@ -39,7 +39,7 @@ func (s *routingServer) TestRoute(ctx context.Context, request *TestRouteRequest
return nil, err return nil, err
} }
if request.PublishResult && s.routingStats != nil { if request.PublishResult && s.routingStats != nil {
ctx, _ := context.WithTimeout(context.Background(), 4*time.Second) // nolint: govet ctx, _ := context.WithTimeout(context.Background(), 4*time.Second)
s.routingStats.Publish(ctx, route) s.routingStats.Publish(ctx, route)
} }
return AsProtobufMessage(request.FieldSelectors)(route), nil return AsProtobufMessage(request.FieldSelectors)(route), nil
@ -54,7 +54,7 @@ func (s *routingServer) SubscribeRoutingStats(request *SubscribeRoutingStatsRequ
if err != nil { if err != nil {
return err return err
} }
defer stats.UnsubscribeClosableChannel(s.routingStats, subscriber) // nolint: errcheck defer stats.UnsubscribeClosableChannel(s.routingStats, subscriber)
for { for {
select { select {
case value, ok := <-subscriber: case value, ok := <-subscriber:

View File

@ -34,12 +34,12 @@ func TestSimpleRouter(t *testing.T) {
mockCtl := gomock.NewController(t) mockCtl := gomock.NewController(t)
defer mockCtl.Finish() defer mockCtl.Finish()
mockDns := mocks.NewDNSClient(mockCtl) mockDNS := mocks.NewDNSClient(mockCtl)
mockOhm := mocks.NewOutboundManager(mockCtl) mockOhm := mocks.NewOutboundManager(mockCtl)
mockHs := mocks.NewOutboundHandlerSelector(mockCtl) mockHs := mocks.NewOutboundHandlerSelector(mockCtl)
r := new(Router) r := new(Router)
common.Must(r.Init(config, mockDns, &mockOutboundManager{ common.Must(r.Init(config, mockDNS, &mockOutboundManager{
Manager: mockOhm, Manager: mockOhm,
HandlerSelector: mockHs, HandlerSelector: mockHs,
})) }))
@ -73,14 +73,14 @@ func TestSimpleBalancer(t *testing.T) {
mockCtl := gomock.NewController(t) mockCtl := gomock.NewController(t)
defer mockCtl.Finish() defer mockCtl.Finish()
mockDns := mocks.NewDNSClient(mockCtl) mockDNS := mocks.NewDNSClient(mockCtl)
mockOhm := mocks.NewOutboundManager(mockCtl) mockOhm := mocks.NewOutboundManager(mockCtl)
mockHs := mocks.NewOutboundHandlerSelector(mockCtl) mockHs := mocks.NewOutboundHandlerSelector(mockCtl)
mockHs.EXPECT().Select(gomock.Eq([]string{"test-"})).Return([]string{"test"}) mockHs.EXPECT().Select(gomock.Eq([]string{"test-"})).Return([]string{"test"})
r := new(Router) r := new(Router)
common.Must(r.Init(config, mockDns, &mockOutboundManager{ common.Must(r.Init(config, mockDNS, &mockOutboundManager{
Manager: mockOhm, Manager: mockOhm,
HandlerSelector: mockHs, HandlerSelector: mockHs,
})) }))
@ -114,11 +114,11 @@ func TestIPOnDemand(t *testing.T) {
mockCtl := gomock.NewController(t) mockCtl := gomock.NewController(t)
defer mockCtl.Finish() defer mockCtl.Finish()
mockDns := mocks.NewDNSClient(mockCtl) mockDNS := mocks.NewDNSClient(mockCtl)
mockDns.EXPECT().LookupIP(gomock.Eq("v2ray.com")).Return([]net.IP{{192, 168, 0, 1}}, nil).AnyTimes() mockDNS.EXPECT().LookupIP(gomock.Eq("v2ray.com")).Return([]net.IP{{192, 168, 0, 1}}, nil).AnyTimes()
r := new(Router) r := new(Router)
common.Must(r.Init(config, mockDns, nil)) common.Must(r.Init(config, mockDNS, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
@ -149,11 +149,11 @@ func TestIPIfNonMatchDomain(t *testing.T) {
mockCtl := gomock.NewController(t) mockCtl := gomock.NewController(t)
defer mockCtl.Finish() defer mockCtl.Finish()
mockDns := mocks.NewDNSClient(mockCtl) mockDNS := mocks.NewDNSClient(mockCtl)
mockDns.EXPECT().LookupIP(gomock.Eq("v2ray.com")).Return([]net.IP{{192, 168, 0, 1}}, nil).AnyTimes() mockDNS.EXPECT().LookupIP(gomock.Eq("v2ray.com")).Return([]net.IP{{192, 168, 0, 1}}, nil).AnyTimes()
r := new(Router) r := new(Router)
common.Must(r.Init(config, mockDns, nil)) common.Must(r.Init(config, mockDNS, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
@ -184,10 +184,10 @@ func TestIPIfNonMatchIP(t *testing.T) {
mockCtl := gomock.NewController(t) mockCtl := gomock.NewController(t)
defer mockCtl.Finish() defer mockCtl.Finish()
mockDns := mocks.NewDNSClient(mockCtl) mockDNS := mocks.NewDNSClient(mockCtl)
r := new(Router) r := new(Router)
common.Must(r.Init(config, mockDns, nil)) common.Must(r.Init(config, mockDNS, nil))
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)}) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))

View File

@ -388,7 +388,7 @@ func TestStatsChannelConcurrency(t *testing.T) {
if ok { if ok {
errCh <- fmt.Sprint("unexpected receiving: ", v) errCh <- fmt.Sprint("unexpected receiving: ", v)
} else { } else {
errCh <- fmt.Sprint("unexpected closing of channel") errCh <- "unexpected closing of channel"
} }
default: default:
} }

View File

@ -1,14 +1,15 @@
package antireplay package antireplay
import ( import (
cuckoo "github.com/seiflotfy/cuckoofilter"
"sync" "sync"
"time" "time"
cuckoo "github.com/seiflotfy/cuckoofilter"
) )
func NewAntiReplayWindow(AntiReplayTime int64) *AntiReplayWindow { func NewAntiReplayWindow(antiReplayTime int64) *AntiReplayWindow {
arw := &AntiReplayWindow{} arw := &AntiReplayWindow{}
arw.AntiReplayTime = AntiReplayTime arw.AntiReplayTime = antiReplayTime
return arw return arw
} }

View File

@ -4,12 +4,11 @@ import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"io" "io"
"io/ioutil"
"os"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"io/ioutil"
"os"
"v2ray.com/core/common" "v2ray.com/core/common"
. "v2ray.com/core/common/buf" . "v2ray.com/core/common/buf"
) )
@ -110,7 +109,6 @@ func TestMultiBufferReadAllToByte(t *testing.T) {
if l := len(b); l != 8*1024 { if l := len(b); l != 8*1024 {
t.Error("unexpceted length from ReadAllToBytes", l) t.Error("unexpceted length from ReadAllToBytes", l)
} }
} }
{ {
const dat = "data/test_MultiBufferReadAllToByte.dat" const dat = "data/test_MultiBufferReadAllToByte.dat"

View File

@ -88,7 +88,6 @@ func TestReadBuffer(t *testing.T) {
} }
buf.Release() buf.Release()
} }
} }
func TestReadAtMost(t *testing.T) { func TestReadAtMost(t *testing.T) {

View File

@ -23,11 +23,11 @@ func TestReadvReader(t *testing.T) {
} }
dest, err := tcpServer.Start() dest, err := tcpServer.Start()
common.Must(err) common.Must(err)
defer tcpServer.Close() // nolint: errcheck defer tcpServer.Close()
conn, err := net.Dial("tcp", dest.NetAddr()) conn, err := net.Dial("tcp", dest.NetAddr())
common.Must(err) common.Must(err)
defer conn.Close() // nolint: errcheck defer conn.Close()
const size = 8192 const size = 8192
data := make([]byte, 8192) data := make([]byte, 8192)

View File

@ -65,7 +65,7 @@ func Free(b []byte) {
b = b[0:cap(b)] b = b[0:cap(b)]
for i := numPools - 1; i >= 0; i-- { for i := numPools - 1; i >= 0; i-- {
if size >= poolSize[i] { if size >= poolSize[i] {
pool[i].Put(b) // nolint: megacheck pool[i].Put(b)
return return
} }
} }

View File

@ -32,15 +32,15 @@ func RollUint64() uint64 {
return rand.Uint64() return rand.Uint64()
} }
func NewDeterministicDice(seed int64) *deterministicDice { func NewDeterministicDice(seed int64) *DeterministicDice {
return &deterministicDice{rand.New(rand.NewSource(seed))} return &DeterministicDice{rand.New(rand.NewSource(seed))}
} }
type deterministicDice struct { type DeterministicDice struct {
*rand.Rand *rand.Rand
} }
func (dd *deterministicDice) Roll(n int) int { func (dd *DeterministicDice) Roll(n int) int {
if n == 1 { if n == 1 {
return 0 return 0
} }

View File

@ -48,14 +48,14 @@ func (l *generalLogger) run() {
if logger == nil { if logger == nil {
return return
} }
defer logger.Close() // nolint: errcheck defer logger.Close()
for { for {
select { select {
case <-l.done.Wait(): case <-l.done.Wait():
return return
case msg := <-l.buffer: case msg := <-l.buffer:
logger.Write(msg.String() + platform.LineSeparator()) // nolint: errcheck logger.Write(msg.String() + platform.LineSeparator())
dataWritten = true dataWritten = true
case <-ticker.C: case <-ticker.C:
if !dataWritten { if !dataWritten {

View File

@ -29,7 +29,7 @@ func TestFileLogger(t *testing.T) {
f, err = os.Open(path) f, err = os.Open(path)
common.Must(err) common.Must(err)
defer f.Close() // nolint: errcheck defer f.Close()
b, err := buf.ReadAllToBytes(f) b, err := buf.ReadAllToBytes(f)
common.Must(err) common.Must(err)

View File

@ -214,8 +214,8 @@ func (m *ClientWorker) monitor() {
select { select {
case <-m.done.Wait(): case <-m.done.Wait():
m.sessionManager.Close() m.sessionManager.Close()
common.Close(m.link.Writer) // nolint: errcheck common.Close(m.link.Writer)
common.Interrupt(m.link.Reader) // nolint: errcheck common.Interrupt(m.link.Reader)
return return
case <-timer.C: case <-timer.C:
size := m.sessionManager.Size() size := m.sessionManager.Size()
@ -247,8 +247,8 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
} }
s.transferType = transferType s.transferType = transferType
writer := NewWriter(s.ID, dest, output, transferType) writer := NewWriter(s.ID, dest, output, transferType)
defer s.Close() // nolint: errcheck defer s.Close()
defer writer.Close() // nolint: errcheck defer writer.Close()
newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx)) newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx))
if err := writeFirstPayload(s.input, writer); err != nil { if err := writeFirstPayload(s.input, writer); err != nil {

View File

@ -232,7 +232,7 @@ func (w *ServerWorker) run(ctx context.Context) {
input := w.link.Reader input := w.link.Reader
reader := &buf.BufferedReader{Reader: input} reader := &buf.BufferedReader{Reader: input}
defer w.sessionManager.Close() // nolint: errcheck defer w.sessionManager.Close()
for { for {
select { select {

View File

@ -126,8 +126,8 @@ func (m *SessionManager) Close() error {
m.closed = true m.closed = true
for _, s := range m.sessions { for _, s := range m.sessions {
common.Close(s.input) // nolint: errcheck common.Close(s.input)
common.Close(s.output) // nolint: errcheck common.Close(s.output)
} }
m.sessions = nil m.sessions = nil
@ -145,8 +145,8 @@ type Session struct {
// Close closes all resources associated with this session. // Close closes all resources associated with this session.
func (s *Session) Close() error { func (s *Session) Close() error {
common.Close(s.output) // nolint: errcheck common.Close(s.output)
common.Close(s.input) // nolint: errcheck common.Close(s.input)
s.parent.Remove(s.ID) s.parent.Remove(s.ID)
return nil return nil
} }

View File

@ -121,6 +121,6 @@ func (w *Writer) Close() error {
frame := buf.New() frame := buf.New()
common.Must(meta.WriteTo(frame)) common.Must(meta.WriteTo(frame))
w.writer.WriteMultiBuffer(buf.MultiBuffer{frame}) // nolint: errcheck w.writer.WriteMultiBuffer(buf.MultiBuffer{frame})
return nil return nil
} }

View File

@ -90,10 +90,8 @@ func TestDestinationParse(t *testing.T) {
if d != testcase.Output { if d != testcase.Output {
t.Error("for test case: ", testcase.Input, " expected output: ", testcase.Output.String(), " but got ", d.String()) t.Error("for test case: ", testcase.Input, " expected output: ", testcase.Output.String(), " but got ", d.String())
} }
} else { } else if err == nil {
if err == nil { t.Error("for test case: ", testcase.Input, " expected error, but got nil")
t.Error("for test case: ", testcase.Input, " expected error, but got nil")
}
} }
} }
} }

View File

@ -49,7 +49,7 @@ func (f EnvFlag) GetValueAsInt(defaultValue int) int {
} }
func NormalizeEnvName(name string) string { func NormalizeEnvName(name string) string {
return strings.Replace(strings.ToUpper(strings.TrimSpace(name)), ".", "_", -1) return strings.ReplaceAll(strings.ToUpper(strings.TrimSpace(name)), ".", "_")
} }
func getExecutableDir() string { func getExecutableDir() string {

View File

@ -8,6 +8,7 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/task" "v2ray.com/core/common/task"
) )

View File

@ -10,7 +10,7 @@ import (
) )
var ( var (
errorTestOnly = errors.New("This is a fake error.") errorTestOnly = errors.New("this is a fake error")
) )
func TestNoRetry(t *testing.T) { func TestNoRetry(t *testing.T) {

View File

@ -45,7 +45,7 @@ func (t *ActivityTimer) finish() {
t.onTimeout = nil t.onTimeout = nil
} }
if t.checkTask != nil { if t.checkTask != nil {
t.checkTask.Close() // nolint: errcheck t.checkTask.Close()
t.checkTask = nil t.checkTask = nil
} }
} }
@ -64,7 +64,7 @@ func (t *ActivityTimer) SetTimeout(timeout time.Duration) {
t.Lock() t.Lock()
if t.checkTask != nil { if t.checkTask != nil {
t.checkTask.Close() // nolint: errcheck t.checkTask.Close()
} }
t.checkTask = checkTask t.checkTask = checkTask
t.Unlock() t.Unlock()

View File

@ -44,7 +44,7 @@ func (t *Periodic) checkedExecute() error {
} }
t.timer = time.AfterFunc(t.Interval, func() { t.timer = time.AfterFunc(t.Interval, func() {
t.checkedExecute() // nolint: errcheck t.checkedExecute()
}) })
return nil return nil

View File

@ -117,8 +117,8 @@ func defaultBufferPolicy() Buffer {
func SessionDefault() Session { func SessionDefault() Session {
return Session{ return Session{
Timeouts: Timeout{ Timeouts: Timeout{
//Align Handshake timeout with nginx client_header_timeout // Align Handshake timeout with nginx client_header_timeout
//So that this value will not indicate server identity // So that this value will not indicate server identity
Handshake: time.Second * 60, Handshake: time.Second * 60,
ConnectionIdle: time.Second * 300, ConnectionIdle: time.Second * 300,
UplinkOnly: time.Second * 1, UplinkOnly: time.Second * 1,

View File

@ -10,14 +10,14 @@ import (
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"
) )
type ApiConfig struct { type APIConfig struct {
Tag string `json:"tag"` Tag string `json:"tag"`
Services []string `json:"services"` Services []string `json:"services"`
} }
func (c *ApiConfig) Build() (*commander.Config, error) { func (c *APIConfig) Build() (*commander.Config, error) {
if c.Tag == "" { if c.Tag == "" {
return nil, newError("Api tag can't be empty.") return nil, newError("API tag can't be empty.")
} }
services := make([]*serial.TypedMessage, 0, 16) services := make([]*serial.TypedMessage, 0, 16)

View File

@ -15,9 +15,9 @@ func (*NoneResponse) Build() (proto.Message, error) {
return new(blackhole.NoneResponse), nil return new(blackhole.NoneResponse), nil
} }
type HttpResponse struct{} type HTTPResponse struct{}
func (*HttpResponse) Build() (proto.Message, error) { func (*HTTPResponse) Build() (proto.Message, error) {
return new(blackhole.HTTPResponse), nil return new(blackhole.HTTPResponse), nil
} }
@ -46,7 +46,7 @@ var (
configLoader = NewJSONConfigLoader( configLoader = NewJSONConfigLoader(
ConfigCreatorCache{ ConfigCreatorCache{
"none": func() interface{} { return new(NoneResponse) }, "none": func() interface{} { return new(NoneResponse) },
"http": func() interface{} { return new(HttpResponse) }, "http": func() interface{} { return new(HTTPResponse) },
}, },
"type", "type",
"") "")

View File

@ -221,7 +221,7 @@ func (list *PortList) UnmarshalJSON(data []byte) error {
} }
} }
if number != 0 { if number != 0 {
list.Range = append(list.Range, PortRange{From: uint32(number), To: uint32(number)}) list.Range = append(list.Range, PortRange{From: number, To: number})
} }
return nil return nil
} }

View File

@ -2,31 +2,31 @@ package conf_test
import ( import (
"encoding/json" "encoding/json"
"github.com/google/go-cmp/cmp/cmpopts"
"os" "os"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"v2ray.com/core/common/protocol" "github.com/google/go-cmp/cmp/cmpopts"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol"
. "v2ray.com/core/infra/conf" . "v2ray.com/core/infra/conf"
) )
func TestStringListUnmarshalError(t *testing.T) { func TestStringListUnmarshalError(t *testing.T) {
rawJson := `1234` rawJSON := `1234`
list := new(StringList) list := new(StringList)
err := json.Unmarshal([]byte(rawJson), list) err := json.Unmarshal([]byte(rawJSON), list)
if err == nil { if err == nil {
t.Error("expected error, but got nil") t.Error("expected error, but got nil")
} }
} }
func TestStringListLen(t *testing.T) { func TestStringListLen(t *testing.T) {
rawJson := `"a, b, c, d"` rawJSON := `"a, b, c, d"`
var list StringList var list StringList
err := json.Unmarshal([]byte(rawJson), &list) err := json.Unmarshal([]byte(rawJSON), &list)
common.Must(err) common.Must(err)
if r := cmp.Diff([]string(list), []string{"a", " b", " c", " d"}); r != "" { if r := cmp.Diff([]string(list), []string{"a", " b", " c", " d"}); r != "" {
t.Error(r) t.Error(r)
@ -34,9 +34,9 @@ func TestStringListLen(t *testing.T) {
} }
func TestIPParsing(t *testing.T) { func TestIPParsing(t *testing.T) {
rawJson := "\"8.8.8.8\"" rawJSON := "\"8.8.8.8\""
var address Address var address Address
err := json.Unmarshal([]byte(rawJson), &address) err := json.Unmarshal([]byte(rawJSON), &address)
common.Must(err) common.Must(err)
if r := cmp.Diff(address.IP(), net.IP{8, 8, 8, 8}); r != "" { if r := cmp.Diff(address.IP(), net.IP{8, 8, 8, 8}); r != "" {
t.Error(r) t.Error(r)
@ -44,9 +44,9 @@ func TestIPParsing(t *testing.T) {
} }
func TestDomainParsing(t *testing.T) { func TestDomainParsing(t *testing.T) {
rawJson := "\"v2ray.com\"" rawJSON := "\"v2ray.com\""
var address Address var address Address
common.Must(json.Unmarshal([]byte(rawJson), &address)) common.Must(json.Unmarshal([]byte(rawJSON), &address))
if address.Domain() != "v2ray.com" { if address.Domain() != "v2ray.com" {
t.Error("domain: ", address.Domain()) t.Error("domain: ", address.Domain())
} }
@ -54,17 +54,17 @@ func TestDomainParsing(t *testing.T) {
func TestURLParsing(t *testing.T) { func TestURLParsing(t *testing.T) {
{ {
rawJson := "\"https://dns.google/dns-query\"" rawJSON := "\"https://dns.google/dns-query\""
var address Address var address Address
common.Must(json.Unmarshal([]byte(rawJson), &address)) common.Must(json.Unmarshal([]byte(rawJSON), &address))
if address.Domain() != "https://dns.google/dns-query" { if address.Domain() != "https://dns.google/dns-query" {
t.Error("URL: ", address.Domain()) t.Error("URL: ", address.Domain())
} }
} }
{ {
rawJson := "\"https+local://dns.google/dns-query\"" rawJSON := "\"https+local://dns.google/dns-query\""
var address Address var address Address
common.Must(json.Unmarshal([]byte(rawJson), &address)) common.Must(json.Unmarshal([]byte(rawJSON), &address))
if address.Domain() != "https+local://dns.google/dns-query" { if address.Domain() != "https+local://dns.google/dns-query" {
t.Error("URL: ", address.Domain()) t.Error("URL: ", address.Domain())
} }
@ -72,9 +72,9 @@ func TestURLParsing(t *testing.T) {
} }
func TestInvalidAddressJson(t *testing.T) { func TestInvalidAddressJson(t *testing.T) {
rawJson := "1234" rawJSON := "1234"
var address Address var address Address
err := json.Unmarshal([]byte(rawJson), &address) err := json.Unmarshal([]byte(rawJSON), &address)
if err == nil { if err == nil {
t.Error("nil error") t.Error("nil error")
} }

View File

@ -106,8 +106,8 @@ var typeMap = map[router.Domain_Type]dns.DomainMatchingType{
router.Domain_Regex: dns.DomainMatchingType_Regex, router.Domain_Regex: dns.DomainMatchingType_Regex,
} }
// DnsConfig is a JSON serializable object for dns.Config. // DNSConfig is a JSON serializable object for dns.Config.
type DnsConfig struct { type DNSConfig struct {
Servers []*NameServerConfig `json:"servers"` Servers []*NameServerConfig `json:"servers"`
Hosts map[string]*Address `json:"hosts"` Hosts map[string]*Address `json:"hosts"`
ClientIP *Address `json:"clientIp"` ClientIP *Address `json:"clientIp"`
@ -127,7 +127,7 @@ func getHostMapping(addr *Address) *dns.Config_HostMapping {
} }
// Build implements Buildable // Build implements Buildable
func (c *DnsConfig) Build() (*dns.Config, error) { func (c *DNSConfig) Build() (*dns.Config, error) {
config := &dns.Config{ config := &dns.Config{
Tag: c.Tag, Tag: c.Tag,
} }
@ -153,16 +153,18 @@ func (c *DnsConfig) Build() (*dns.Config, error) {
domains = append(domains, domain) domains = append(domains, domain)
} }
sort.Strings(domains) sort.Strings(domains)
for _, domain := range domains { for _, domain := range domains {
addr := c.Hosts[domain] addr := c.Hosts[domain]
var mappings []*dns.Config_HostMapping var mappings []*dns.Config_HostMapping
if strings.HasPrefix(domain, "domain:") { switch {
case strings.HasPrefix(domain, "domain:"):
mapping := getHostMapping(addr) mapping := getHostMapping(addr)
mapping.Type = dns.DomainMatchingType_Subdomain mapping.Type = dns.DomainMatchingType_Subdomain
mapping.Domain = domain[7:] mapping.Domain = domain[7:]
mappings = append(mappings, mapping) mappings = append(mappings, mapping)
} else if strings.HasPrefix(domain, "geosite:") {
case strings.HasPrefix(domain, "geosite:"):
domains, err := loadGeositeWithAttr("geosite.dat", strings.ToUpper(domain[8:])) domains, err := loadGeositeWithAttr("geosite.dat", strings.ToUpper(domain[8:]))
if err != nil { if err != nil {
return nil, newError("invalid geosite settings: ", domain).Base(err) return nil, newError("invalid geosite settings: ", domain).Base(err)
@ -171,28 +173,28 @@ func (c *DnsConfig) Build() (*dns.Config, error) {
mapping := getHostMapping(addr) mapping := getHostMapping(addr)
mapping.Type = typeMap[d.Type] mapping.Type = typeMap[d.Type]
mapping.Domain = d.Value mapping.Domain = d.Value
mappings = append(mappings, mapping) mappings = append(mappings, mapping)
} }
} else if strings.HasPrefix(domain, "regexp:") {
case strings.HasPrefix(domain, "regexp:"):
mapping := getHostMapping(addr) mapping := getHostMapping(addr)
mapping.Type = dns.DomainMatchingType_Regex mapping.Type = dns.DomainMatchingType_Regex
mapping.Domain = domain[7:] mapping.Domain = domain[7:]
mappings = append(mappings, mapping) mappings = append(mappings, mapping)
} else if strings.HasPrefix(domain, "keyword:") {
case strings.HasPrefix(domain, "keyword:"):
mapping := getHostMapping(addr) mapping := getHostMapping(addr)
mapping.Type = dns.DomainMatchingType_Keyword mapping.Type = dns.DomainMatchingType_Keyword
mapping.Domain = domain[8:] mapping.Domain = domain[8:]
mappings = append(mappings, mapping) mappings = append(mappings, mapping)
} else if strings.HasPrefix(domain, "full:") {
case strings.HasPrefix(domain, "full:"):
mapping := getHostMapping(addr) mapping := getHostMapping(addr)
mapping.Type = dns.DomainMatchingType_Full mapping.Type = dns.DomainMatchingType_Full
mapping.Domain = domain[5:] mapping.Domain = domain[5:]
mappings = append(mappings, mapping) mappings = append(mappings, mapping)
} else if strings.HasPrefix(domain, "dotless:") {
case strings.HasPrefix(domain, "dotless:"):
mapping := getHostMapping(addr) mapping := getHostMapping(addr)
mapping.Type = dns.DomainMatchingType_Regex mapping.Type = dns.DomainMatchingType_Regex
switch substr := domain[8:]; { switch substr := domain[8:]; {
@ -203,9 +205,9 @@ func (c *DnsConfig) Build() (*dns.Config, error) {
default: default:
return nil, newError("substr in dotless rule should not contain a dot: ", substr) return nil, newError("substr in dotless rule should not contain a dot: ", substr)
} }
mappings = append(mappings, mapping) mappings = append(mappings, mapping)
} else if strings.HasPrefix(domain, "ext:") {
case strings.HasPrefix(domain, "ext:"):
kv := strings.Split(domain[4:], ":") kv := strings.Split(domain[4:], ":")
if len(kv) != 2 { if len(kv) != 2 {
return nil, newError("invalid external resource: ", domain) return nil, newError("invalid external resource: ", domain)
@ -220,14 +222,13 @@ func (c *DnsConfig) Build() (*dns.Config, error) {
mapping := getHostMapping(addr) mapping := getHostMapping(addr)
mapping.Type = typeMap[d.Type] mapping.Type = typeMap[d.Type]
mapping.Domain = d.Value mapping.Domain = d.Value
mappings = append(mappings, mapping) mappings = append(mappings, mapping)
} }
} else {
default:
mapping := getHostMapping(addr) mapping := getHostMapping(addr)
mapping.Type = dns.DomainMatchingType_Full mapping.Type = dns.DomainMatchingType_Full
mapping.Domain = domain mapping.Domain = domain
mappings = append(mappings, mapping) mappings = append(mappings, mapping)
} }

View File

@ -6,13 +6,13 @@ import (
"v2ray.com/core/proxy/dns" "v2ray.com/core/proxy/dns"
) )
type DnsOutboundConfig struct { type DNSOutboundConfig struct {
Network Network `json:"network"` Network Network `json:"network"`
Address *Address `json:"address"` Address *Address `json:"address"`
Port uint16 `json:"port"` Port uint16 `json:"port"`
} }
func (c *DnsOutboundConfig) Build() (proto.Message, error) { func (c *DNSOutboundConfig) Build() (proto.Message, error) {
config := &dns.Config{ config := &dns.Config{
Server: &net.Endpoint{ Server: &net.Endpoint{
Network: c.Network.Build(), Network: c.Network.Build(),

View File

@ -10,7 +10,7 @@ import (
func TestDnsProxyConfig(t *testing.T) { func TestDnsProxyConfig(t *testing.T) {
creator := func() Buildable { creator := func() Buildable {
return new(DnsOutboundConfig) return new(DNSOutboundConfig)
} }
runMultiTestCase(t, []TestCase{ runMultiTestCase(t, []TestCase{

View File

@ -45,7 +45,7 @@ func init() {
common.Must(err) common.Must(err)
common.Must2(geositeFile.Write(listBytes)) common.Must2(geositeFile.Write(listBytes))
} }
func TestDnsConfigParsing(t *testing.T) { func TestDNSConfigParsing(t *testing.T) {
geositePath := platform.GetAssetLocation("geosite.dat") geositePath := platform.GetAssetLocation("geosite.dat")
defer func() { defer func() {
os.Remove(geositePath) os.Remove(geositePath)
@ -54,7 +54,7 @@ func TestDnsConfigParsing(t *testing.T) {
parserCreator := func() func(string) (proto.Message, error) { parserCreator := func() func(string) (proto.Message, error) {
return func(s string) (proto.Message, error) { return func(s string) (proto.Message, error) {
config := new(DnsConfig) config := new(DNSConfig)
if err := json.Unmarshal([]byte(s), config); err != nil { if err := json.Unmarshal([]byte(s), config); err != nil {
return nil, err return nil, err
} }

View File

@ -9,26 +9,26 @@ import (
"v2ray.com/core/proxy/http" "v2ray.com/core/proxy/http"
) )
type HttpAccount struct { type HTTPAccount struct {
Username string `json:"user"` Username string `json:"user"`
Password string `json:"pass"` Password string `json:"pass"`
} }
func (v *HttpAccount) Build() *http.Account { func (v *HTTPAccount) Build() *http.Account {
return &http.Account{ return &http.Account{
Username: v.Username, Username: v.Username,
Password: v.Password, Password: v.Password,
} }
} }
type HttpServerConfig struct { type HTTPServerConfig struct {
Timeout uint32 `json:"timeout"` Timeout uint32 `json:"timeout"`
Accounts []*HttpAccount `json:"accounts"` Accounts []*HTTPAccount `json:"accounts"`
Transparent bool `json:"allowTransparent"` Transparent bool `json:"allowTransparent"`
UserLevel uint32 `json:"userLevel"` UserLevel uint32 `json:"userLevel"`
} }
func (c *HttpServerConfig) Build() (proto.Message, error) { func (c *HTTPServerConfig) Build() (proto.Message, error) {
config := &http.ServerConfig{ config := &http.ServerConfig{
Timeout: c.Timeout, Timeout: c.Timeout,
AllowTransparent: c.Transparent, AllowTransparent: c.Transparent,
@ -45,16 +45,16 @@ func (c *HttpServerConfig) Build() (proto.Message, error) {
return config, nil return config, nil
} }
type HttpRemoteConfig struct { type HTTPRemoteConfig struct {
Address *Address `json:"address"` Address *Address `json:"address"`
Port uint16 `json:"port"` Port uint16 `json:"port"`
Users []json.RawMessage `json:"users"` Users []json.RawMessage `json:"users"`
} }
type HttpClientConfig struct { type HTTPClientConfig struct {
Servers []*HttpRemoteConfig `json:"servers"` Servers []*HTTPRemoteConfig `json:"servers"`
} }
func (v *HttpClientConfig) Build() (proto.Message, error) { func (v *HTTPClientConfig) Build() (proto.Message, error) {
config := new(http.ClientConfig) config := new(http.ClientConfig)
config.Server = make([]*protocol.ServerEndpoint, len(v.Servers)) config.Server = make([]*protocol.ServerEndpoint, len(v.Servers))
for idx, serverConfig := range v.Servers { for idx, serverConfig := range v.Servers {
@ -67,7 +67,7 @@ func (v *HttpClientConfig) Build() (proto.Message, error) {
if err := json.Unmarshal(rawUser, user); err != nil { if err := json.Unmarshal(rawUser, user); err != nil {
return nil, newError("failed to parse HTTP user").Base(err).AtError() return nil, newError("failed to parse HTTP user").Base(err).AtError()
} }
account := new(HttpAccount) account := new(HTTPAccount)
if err := json.Unmarshal(rawUser, account); err != nil { if err := json.Unmarshal(rawUser, account); err != nil {
return nil, newError("failed to parse HTTP account").Base(err).AtError() return nil, newError("failed to parse HTTP account").Base(err).AtError()
} }

View File

@ -7,9 +7,9 @@ import (
"v2ray.com/core/proxy/http" "v2ray.com/core/proxy/http"
) )
func TestHttpServerConfig(t *testing.T) { func TestHTTPServerConfig(t *testing.T) {
creator := func() Buildable { creator := func() Buildable {
return new(HttpServerConfig) return new(HTTPServerConfig)
} }
runMultiTestCase(t, []TestCase{ runMultiTestCase(t, []TestCase{

View File

@ -93,5 +93,4 @@ func TestReader1(t *testing.T) {
t.Error("got ", string(target), " want ", testCase.output) t.Error("got ", string(target), " want ", testCase.output)
} }
} }
} }

View File

@ -27,7 +27,7 @@ func TestBufferSize(t *testing.T) {
} }
for _, c := range cases { for _, c := range cases {
bs := int32(c.Input) bs := c.Input
pConf := Policy{ pConf := Policy{
BufferSize: &bs, BufferSize: &bs,
} }

View File

@ -5,11 +5,11 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/golang/protobuf/proto"
"v2ray.com/core/app/router" "v2ray.com/core/app/router"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/platform/filesystem" "v2ray.com/core/common/platform/filesystem"
"github.com/golang/protobuf/proto"
) )
type RouterRulesConfig struct { type RouterRulesConfig struct {
@ -67,10 +67,15 @@ func (c *RouterConfig) Build() (*router.Config, error) {
config := new(router.Config) config := new(router.Config)
config.DomainStrategy = c.getDomainStrategy() config.DomainStrategy = c.getDomainStrategy()
rawRuleList := c.RuleList var rawRuleList []json.RawMessage
if c.Settings != nil { if c != nil {
rawRuleList = append(c.RuleList, c.Settings.RuleList...) rawRuleList = c.RuleList
if c.Settings != nil {
c.RuleList = append(c.RuleList, c.Settings.RuleList...)
rawRuleList = c.RuleList
}
} }
for _, rawRule := range rawRuleList { for _, rawRule := range rawRuleList {
rule, err := ParseRule(rawRule) rule, err := ParseRule(rawRule)
if err != nil { if err != nil {
@ -290,15 +295,19 @@ func parseDomainRule(domain string) ([]*router.Domain, error) {
case strings.HasPrefix(domain, "regexp:"): case strings.HasPrefix(domain, "regexp:"):
domainRule.Type = router.Domain_Regex domainRule.Type = router.Domain_Regex
domainRule.Value = domain[7:] domainRule.Value = domain[7:]
case strings.HasPrefix(domain, "domain:"): case strings.HasPrefix(domain, "domain:"):
domainRule.Type = router.Domain_Domain domainRule.Type = router.Domain_Domain
domainRule.Value = domain[7:] domainRule.Value = domain[7:]
case strings.HasPrefix(domain, "full:"): case strings.HasPrefix(domain, "full:"):
domainRule.Type = router.Domain_Full domainRule.Type = router.Domain_Full
domainRule.Value = domain[5:] domainRule.Value = domain[5:]
case strings.HasPrefix(domain, "keyword:"): case strings.HasPrefix(domain, "keyword:"):
domainRule.Type = router.Domain_Plain domainRule.Type = router.Domain_Plain
domainRule.Value = domain[8:] domainRule.Value = domain[8:]
case strings.HasPrefix(domain, "dotless:"): case strings.HasPrefix(domain, "dotless:"):
domainRule.Type = router.Domain_Regex domainRule.Type = router.Domain_Regex
switch substr := domain[8:]; { switch substr := domain[8:]; {
@ -309,6 +318,7 @@ func parseDomainRule(domain string) ([]*router.Domain, error) {
default: default:
return nil, newError("substr in dotless rule should not contain a dot: ", substr) return nil, newError("substr in dotless rule should not contain a dot: ", substr)
} }
default: default:
domainRule.Type = router.Domain_Plain domainRule.Type = router.Domain_Plain
domainRule.Value = domain domainRule.Value = domain
@ -403,15 +413,16 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
} }
rule := new(router.RoutingRule) rule := new(router.RoutingRule)
if len(rawFieldRule.OutboundTag) > 0 { switch {
case len(rawFieldRule.OutboundTag) > 0:
rule.TargetTag = &router.RoutingRule_Tag{ rule.TargetTag = &router.RoutingRule_Tag{
Tag: rawFieldRule.OutboundTag, Tag: rawFieldRule.OutboundTag,
} }
} else if len(rawFieldRule.BalancerTag) > 0 { case len(rawFieldRule.BalancerTag) > 0:
rule.TargetTag = &router.RoutingRule_BalancingTag{ rule.TargetTag = &router.RoutingRule_BalancingTag{
BalancingTag: rawFieldRule.BalancerTag, BalancingTag: rawFieldRule.BalancerTag,
} }
} else { default:
return nil, newError("neither outboundTag nor balancerTag is specified in routing rule") return nil, newError("neither outboundTag nor balancerTag is specified in routing rule")
} }

View File

@ -43,7 +43,7 @@ func (v *SocksServerConfig) Build() (proto.Message, error) {
case AuthMethodUserPass: case AuthMethodUserPass:
config.AuthType = socks.AuthType_PASSWORD config.AuthType = socks.AuthType_PASSWORD
default: default:
//newError("unknown socks auth method: ", v.AuthMethod, ". Default to noauth.").AtWarning().WriteToLog() // newError("unknown socks auth method: ", v.AuthMethod, ". Default to noauth.").AtWarning().WriteToLog()
config.AuthType = socks.AuthType_NO_AUTH config.AuthType = socks.AuthType_NO_AUTH
} }

View File

@ -56,7 +56,7 @@ func (DTLSAuthenticator) Build() (proto.Message, error) {
return new(tls.PacketConfig), nil return new(tls.PacketConfig), nil
} }
type HTTPAuthenticatorRequest struct { type AuthenticatorRequest struct {
Version string `json:"version"` Version string `json:"version"`
Method string `json:"method"` Method string `json:"method"`
Path StringList `json:"path"` Path StringList `json:"path"`
@ -72,7 +72,7 @@ func sortMapKeys(m map[string]*StringList) []string {
return keys return keys
} }
func (v *HTTPAuthenticatorRequest) Build() (*http.RequestConfig, error) { func (v *AuthenticatorRequest) Build() (*http.RequestConfig, error) {
config := &http.RequestConfig{ config := &http.RequestConfig{
Uri: []string{"/"}, Uri: []string{"/"},
Header: []*http.Header{ Header: []*http.Header{
@ -132,14 +132,14 @@ func (v *HTTPAuthenticatorRequest) Build() (*http.RequestConfig, error) {
return config, nil return config, nil
} }
type HTTPAuthenticatorResponse struct { type AuthenticatorResponse struct {
Version string `json:"version"` Version string `json:"version"`
Status string `json:"status"` Status string `json:"status"`
Reason string `json:"reason"` Reason string `json:"reason"`
Headers map[string]*StringList `json:"headers"` Headers map[string]*StringList `json:"headers"`
} }
func (v *HTTPAuthenticatorResponse) Build() (*http.ResponseConfig, error) { func (v *AuthenticatorResponse) Build() (*http.ResponseConfig, error) {
config := &http.ResponseConfig{ config := &http.ResponseConfig{
Header: []*http.Header{ Header: []*http.Header{
{ {
@ -200,12 +200,12 @@ func (v *HTTPAuthenticatorResponse) Build() (*http.ResponseConfig, error) {
return config, nil return config, nil
} }
type HTTPAuthenticator struct { type Authenticator struct {
Request HTTPAuthenticatorRequest `json:"request"` Request AuthenticatorRequest `json:"request"`
Response HTTPAuthenticatorResponse `json:"response"` Response AuthenticatorResponse `json:"response"`
} }
func (v *HTTPAuthenticator) Build() (proto.Message, error) { func (v *Authenticator) Build() (proto.Message, error) {
config := new(http.Config) config := new(http.Config)
requestConfig, err := v.Request.Build() requestConfig, err := v.Request.Build()
if err != nil { if err != nil {

View File

@ -31,7 +31,7 @@ var (
tcpHeaderLoader = NewJSONConfigLoader(ConfigCreatorCache{ tcpHeaderLoader = NewJSONConfigLoader(ConfigCreatorCache{
"none": func() interface{} { return new(NoOpConnectionAuthenticator) }, "none": func() interface{} { return new(NoOpConnectionAuthenticator) },
"http": func() interface{} { return new(HTTPAuthenticator) }, "http": func() interface{} { return new(Authenticator) },
}, "type", "") }, "type", "")
) )
@ -473,7 +473,7 @@ func (c *StreamConfig) Build() (*internet.StreamConfig, error) {
ProtocolName: "tcp", ProtocolName: "tcp",
} }
if c.Network != nil { if c.Network != nil {
protocol, err := (*c.Network).Build() protocol, err := c.Network.Build()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -6,7 +6,7 @@ import (
"strconv" "strconv"
"syscall" "syscall"
"github.com/golang/protobuf/proto" // nolint: staticcheck "github.com/golang/protobuf/proto"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"

View File

@ -17,7 +17,7 @@ import (
var ( var (
inboundConfigLoader = NewJSONConfigLoader(ConfigCreatorCache{ inboundConfigLoader = NewJSONConfigLoader(ConfigCreatorCache{
"dokodemo-door": func() interface{} { return new(DokodemoConfig) }, "dokodemo-door": func() interface{} { return new(DokodemoConfig) },
"http": func() interface{} { return new(HttpServerConfig) }, "http": func() interface{} { return new(HTTPServerConfig) },
"shadowsocks": func() interface{} { return new(ShadowsocksServerConfig) }, "shadowsocks": func() interface{} { return new(ShadowsocksServerConfig) },
"socks": func() interface{} { return new(SocksServerConfig) }, "socks": func() interface{} { return new(SocksServerConfig) },
"vless": func() interface{} { return new(VLessInboundConfig) }, "vless": func() interface{} { return new(VLessInboundConfig) },
@ -29,14 +29,14 @@ var (
outboundConfigLoader = NewJSONConfigLoader(ConfigCreatorCache{ outboundConfigLoader = NewJSONConfigLoader(ConfigCreatorCache{
"blackhole": func() interface{} { return new(BlackholeConfig) }, "blackhole": func() interface{} { return new(BlackholeConfig) },
"freedom": func() interface{} { return new(FreedomConfig) }, "freedom": func() interface{} { return new(FreedomConfig) },
"http": func() interface{} { return new(HttpClientConfig) }, "http": func() interface{} { return new(HTTPClientConfig) },
"shadowsocks": func() interface{} { return new(ShadowsocksClientConfig) }, "shadowsocks": func() interface{} { return new(ShadowsocksClientConfig) },
"socks": func() interface{} { return new(SocksClientConfig) }, "socks": func() interface{} { return new(SocksClientConfig) },
"vless": func() interface{} { return new(VLessOutboundConfig) }, "vless": func() interface{} { return new(VLessOutboundConfig) },
"vmess": func() interface{} { return new(VMessOutboundConfig) }, "vmess": func() interface{} { return new(VMessOutboundConfig) },
"trojan": func() interface{} { return new(TrojanClientConfig) }, "trojan": func() interface{} { return new(TrojanClientConfig) },
"mtproto": func() interface{} { return new(MTProtoClientConfig) }, "mtproto": func() interface{} { return new(MTProtoClientConfig) },
"dns": func() interface{} { return new(DnsOutboundConfig) }, "dns": func() interface{} { return new(DNSOutboundConfig) },
}, "protocol", "settings") }, "protocol", "settings")
ctllog = log.New(os.Stderr, "v2ctl> ", 0) ctllog = log.New(os.Stderr, "v2ctl> ", 0)
@ -315,7 +315,7 @@ type Config struct {
Port uint16 `json:"port"` // Port of this Point server. Deprecated. Port uint16 `json:"port"` // Port of this Point server. Deprecated.
LogConfig *LogConfig `json:"log"` LogConfig *LogConfig `json:"log"`
RouterConfig *RouterConfig `json:"routing"` RouterConfig *RouterConfig `json:"routing"`
DNSConfig *DnsConfig `json:"dns"` DNSConfig *DNSConfig `json:"dns"`
InboundConfigs []InboundDetourConfig `json:"inbounds"` InboundConfigs []InboundDetourConfig `json:"inbounds"`
OutboundConfigs []OutboundDetourConfig `json:"outbounds"` OutboundConfigs []OutboundDetourConfig `json:"outbounds"`
InboundConfig *InboundDetourConfig `json:"inbound"` // Deprecated. InboundConfig *InboundDetourConfig `json:"inbound"` // Deprecated.
@ -324,7 +324,7 @@ type Config struct {
OutboundDetours []OutboundDetourConfig `json:"outboundDetour"` // Deprecated. OutboundDetours []OutboundDetourConfig `json:"outboundDetour"` // Deprecated.
Transport *TransportConfig `json:"transport"` Transport *TransportConfig `json:"transport"`
Policy *PolicyConfig `json:"policy"` Policy *PolicyConfig `json:"policy"`
Api *ApiConfig `json:"api"` API *APIConfig `json:"api"`
Stats *StatsConfig `json:"stats"` Stats *StatsConfig `json:"stats"`
Reverse *ReverseConfig `json:"reverse"` Reverse *ReverseConfig `json:"reverse"`
} }
@ -353,7 +353,6 @@ func (c *Config) findOutboundTag(tag string) int {
// Override method accepts another Config overrides the current attribute // Override method accepts another Config overrides the current attribute
func (c *Config) Override(o *Config, fn string) { func (c *Config) Override(o *Config, fn string) {
// only process the non-deprecated members // only process the non-deprecated members
if o.LogConfig != nil { if o.LogConfig != nil {
@ -371,8 +370,8 @@ func (c *Config) Override(o *Config, fn string) {
if o.Policy != nil { if o.Policy != nil {
c.Policy = o.Policy c.Policy = o.Policy
} }
if o.Api != nil { if o.API != nil {
c.Api = o.Api c.API = o.API
} }
if o.Stats != nil { if o.Stats != nil {
c.Stats = o.Stats c.Stats = o.Stats
@ -460,8 +459,8 @@ func (c *Config) Build() (*core.Config, error) {
}, },
} }
if c.Api != nil { if c.API != nil {
apiConf, err := c.Api.Build() apiConf, err := c.API.Build()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -384,10 +384,10 @@ func TestConfig_Override(t *testing.T) {
&Config{ &Config{
LogConfig: &LogConfig{}, LogConfig: &LogConfig{},
RouterConfig: &RouterConfig{}, RouterConfig: &RouterConfig{},
DNSConfig: &DnsConfig{}, DNSConfig: &DNSConfig{},
Transport: &TransportConfig{}, Transport: &TransportConfig{},
Policy: &PolicyConfig{}, Policy: &PolicyConfig{},
Api: &ApiConfig{}, API: &APIConfig{},
Stats: &StatsConfig{}, Stats: &StatsConfig{},
Reverse: &ReverseConfig{}, Reverse: &ReverseConfig{},
}, },
@ -395,10 +395,10 @@ func TestConfig_Override(t *testing.T) {
&Config{ &Config{
LogConfig: &LogConfig{}, LogConfig: &LogConfig{},
RouterConfig: &RouterConfig{}, RouterConfig: &RouterConfig{},
DNSConfig: &DnsConfig{}, DNSConfig: &DNSConfig{},
Transport: &TransportConfig{}, Transport: &TransportConfig{},
Policy: &PolicyConfig{}, Policy: &PolicyConfig{},
Api: &ApiConfig{}, API: &APIConfig{},
Stats: &StatsConfig{}, Stats: &StatsConfig{},
Reverse: &ReverseConfig{}, Reverse: &ReverseConfig{},
}, },

View File

@ -33,12 +33,7 @@ type VLessInboundConfig struct {
// Build implements Buildable // Build implements Buildable
func (c *VLessInboundConfig) Build() (proto.Message, error) { func (c *VLessInboundConfig) Build() (proto.Message, error) {
config := new(inbound.Config) config := new(inbound.Config)
if len(c.Clients) == 0 {
//return nil, newError(`VLESS settings: "clients" is empty`)
}
config.Clients = make([]*protocol.User, len(c.Clients)) config.Clients = make([]*protocol.User, len(c.Clients))
for idx, rawUser := range c.Clients { for idx, rawUser := range c.Clients {
user := new(protocol.User) user := new(protocol.User)
@ -142,7 +137,6 @@ type VLessOutboundConfig struct {
// Build implements Buildable // Build implements Buildable
func (c *VLessOutboundConfig) Build() (proto.Message, error) { func (c *VLessOutboundConfig) Build() (proto.Message, error) {
config := new(outbound.Config) config := new(outbound.Config)
if len(c.Vnext) == 0 { if len(c.Vnext) == 0 {

View File

@ -16,13 +16,13 @@ import (
"v2ray.com/core/common" "v2ray.com/core/common"
) )
type ApiCommand struct{} type APICommand struct{}
func (c *ApiCommand) Name() string { func (c *APICommand) Name() string {
return "api" return "api"
} }
func (c *ApiCommand) Description() Description { func (c *APICommand) Description() Description {
return Description{ return Description{
Short: "Call V2Ray API", Short: "Call V2Ray API",
Usage: []string{ Usage: []string{
@ -42,7 +42,7 @@ func (c *ApiCommand) Description() Description {
} }
} }
func (c *ApiCommand) Execute(args []string) error { func (c *APICommand) Execute(args []string) error {
fs := flag.NewFlagSet(c.Name(), flag.ContinueOnError) fs := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
serverAddrPtr := fs.String("server", "127.0.0.1:8080", "Server address") serverAddrPtr := fs.String("server", "127.0.0.1:8080", "Server address")
@ -154,5 +154,5 @@ func callStatsService(ctx context.Context, conn *grpc.ClientConn, method string,
} }
func init() { func init() {
common.Must(RegisterCommand(&ApiCommand{})) common.Must(RegisterCommand(&APICommand{}))
} }

View File

@ -53,7 +53,7 @@ func (c *CertificateCommand) Description() Description {
} }
} }
func (c *CertificateCommand) printJson(certificate *cert.Certificate) { func (c *CertificateCommand) printJSON(certificate *cert.Certificate) {
certPEM, keyPEM := certificate.ToPEM() certPEM, keyPEM := certificate.ToPEM()
jCert := &jsonCert{ jCert := &jsonCert{
Certificate: strings.Split(strings.TrimSpace(string(certPEM)), "\n"), Certificate: strings.Split(strings.TrimSpace(string(certPEM)), "\n"),
@ -122,7 +122,7 @@ func (c *CertificateCommand) Execute(args []string) error {
} }
if *jsonOutput { if *jsonOutput {
c.printJson(cert) c.printJSON(cert)
} }
if len(*fileOutput) > 0 { if len(*fileOutput) > 0 {

View File

@ -66,13 +66,15 @@ func (c *ConfigCommand) Execute(args []string) error {
// LoadArg loads one arg, maybe an remote url, or local file path // LoadArg loads one arg, maybe an remote url, or local file path
func (c *ConfigCommand) LoadArg(arg string) (out io.Reader, err error) { func (c *ConfigCommand) LoadArg(arg string) (out io.Reader, err error) {
var data []byte var data []byte
if strings.HasPrefix(arg, "http://") || strings.HasPrefix(arg, "https://") { switch {
case strings.HasPrefix(arg, "http://"), strings.HasPrefix(arg, "https://"):
data, err = FetchHTTPContent(arg) data, err = FetchHTTPContent(arg)
} else if arg == "stdin:" {
case arg == "stdin:":
data, err = ioutil.ReadAll(os.Stdin) data, err = ioutil.ReadAll(os.Stdin)
} else {
default:
data, err = ioutil.ReadFile(arg) data, err = ioutil.ReadFile(arg)
} }

View File

@ -39,7 +39,6 @@ func (c *FetchCommand) Execute(args []string) error {
// FetchHTTPContent dials https for remote content // FetchHTTPContent dials https for remote content
func FetchHTTPContent(target string) ([]byte, error) { func FetchHTTPContent(target string) ([]byte, error) {
parsedTarget, err := url.Parse(target) parsedTarget, err := url.Parse(target)
if err != nil { if err != nil {
return nil, newError("invalid URL: ", target).Base(err) return nil, newError("invalid URL: ", target).Base(err)

View File

@ -10,13 +10,13 @@ import (
"v2ray.com/core/common" "v2ray.com/core/common"
) )
type TlsPingCommand struct{} type TLSPingCommand struct{}
func (c *TlsPingCommand) Name() string { func (c *TLSPingCommand) Name() string {
return "tlsping" return "tlsping"
} }
func (c *TlsPingCommand) Description() Description { func (c *TLSPingCommand) Description() Description {
return Description{ return Description{
Short: "Ping the domain with TLS handshake", Short: "Ping the domain with TLS handshake",
Usage: []string{"v2ctl tlsping <domain> --ip <ip>"}, Usage: []string{"v2ctl tlsping <domain> --ip <ip>"},
@ -32,7 +32,7 @@ func printCertificates(certs []*x509.Certificate) {
} }
} }
func (c *TlsPingCommand) Execute(args []string) error { func (c *TLSPingCommand) Execute(args []string) error {
fs := flag.NewFlagSet(c.Name(), flag.ContinueOnError) fs := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
ipStr := fs.String("ip", "", "IP address of the domain") ipStr := fs.String("ip", "", "IP address of the domain")
@ -115,5 +115,5 @@ func (c *TlsPingCommand) Execute(args []string) error {
} }
func init() { func init() {
common.Must(RegisterCommand(&TlsPingCommand{})) common.Must(RegisterCommand(&TLSPingCommand{}))
} }

View File

@ -2,8 +2,9 @@ package control
import ( import (
"flag" "flag"
"github.com/xiaokangwang/VSign/signerVerify"
"os" "os"
"github.com/xiaokangwang/VSign/signerVerify"
"v2ray.com/core/common" "v2ray.com/core/common"
) )

View File

@ -18,13 +18,15 @@ import (
) )
func ConfigLoader(arg string) (out io.Reader, err error) { func ConfigLoader(arg string) (out io.Reader, err error) {
var data []byte var data []byte
if strings.HasPrefix(arg, "http://") || strings.HasPrefix(arg, "https://") { switch {
case strings.HasPrefix(arg, "http://"), strings.HasPrefix(arg, "https://"):
data, err = FetchHTTPContent(arg) data, err = FetchHTTPContent(arg)
} else if arg == "stdin:" {
case arg == "stdin:":
data, err = ioutil.ReadAll(os.Stdin) data, err = ioutil.ReadAll(os.Stdin)
} else {
default:
data, err = ioutil.ReadFile(arg) data, err = ioutil.ReadFile(arg)
} }
@ -36,7 +38,6 @@ func ConfigLoader(arg string) (out io.Reader, err error) {
} }
func FetchHTTPContent(target string) ([]byte, error) { func FetchHTTPContent(target string) ([]byte, error) {
parsedTarget, err := url.Parse(target) parsedTarget, err := url.Parse(target)
if err != nil { if err != nil {
return nil, newError("invalid URL: ", target).Base(err) return nil, newError("invalid URL: ", target).Base(err)

View File

@ -1,7 +1,8 @@
package debug package debug
import _ "net/http/pprof" import (
import "net/http" "net/http"
)
func init() { func init() {
go func() { go func() {

View File

@ -70,11 +70,9 @@ func getConfigFilePath() (cmdarg.Arg, error) {
if dirExists(configDir) { if dirExists(configDir) {
log.Println("Using confdir from arg:", configDir) log.Println("Using confdir from arg:", configDir)
readConfDir(configDir) readConfDir(configDir)
} else { } else if envConfDir := platform.GetConfDirPath(); dirExists(envConfDir) {
if envConfDir := platform.GetConfDirPath(); dirExists(envConfDir) { log.Println("Using confdir from env:", envConfDir)
log.Println("Using confdir from env:", envConfDir) readConfDir(envConfDir)
readConfDir(envConfDir)
}
} }
if len(configFiles) > 0 { if len(configFiles) > 0 {
@ -134,7 +132,6 @@ func printVersion() {
} }
func main() { func main() {
flag.Parse() flag.Parse()
printVersion() printVersion()

View File

@ -1,7 +1,6 @@
// +build !confonly // +build !confonly
// Package blackhole is an outbound handler that blocks all connections. // Package blackhole is an outbound handler that blocks all connections.
package blackhole package blackhole
//go:generate go run v2ray.com/core/common/errors/errorgen //go:generate go run v2ray.com/core/common/errors/errorgen

View File

@ -19,6 +19,7 @@ func TestHTTPResponse(t *testing.T) {
reader := bufio.NewReader(buffer) reader := bufio.NewReader(buffer)
response, err := http.ReadResponse(reader, nil) response, err := http.ReadResponse(reader, nil)
common.Must(err) common.Must(err)
if response.StatusCode != 403 { if response.StatusCode != 403 {
t.Error("expected status code 403, but got ", response.StatusCode) t.Error("expected status code 403, but got ", response.StatusCode)
} }

View File

@ -44,7 +44,8 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
for _, q := range r.Question { for _, q := range r.Question {
if q.Name == "google.com." && q.Qtype == dns.TypeA { switch {
case q.Name == "google.com." && q.Qtype == dns.TypeA:
if clientIP == nil { if clientIP == nil {
rr, _ := dns.NewRR("google.com. IN A 8.8.8.8") rr, _ := dns.NewRR("google.com. IN A 8.8.8.8")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
@ -52,18 +53,22 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
rr, _ := dns.NewRR("google.com. IN A 8.8.4.4") rr, _ := dns.NewRR("google.com. IN A 8.8.4.4")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} }
} else if q.Name == "facebook.com." && q.Qtype == dns.TypeA {
case q.Name == "facebook.com." && q.Qtype == dns.TypeA:
rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9") rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9")
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "ipv6.google.com." && q.Qtype == dns.TypeA {
case q.Name == "ipv6.google.com." && q.Qtype == dns.TypeA:
rr, err := dns.NewRR("ipv6.google.com. IN A 8.8.8.7") rr, err := dns.NewRR("ipv6.google.com. IN A 8.8.8.7")
common.Must(err) common.Must(err)
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "ipv6.google.com." && q.Qtype == dns.TypeAAAA {
case q.Name == "ipv6.google.com." && q.Qtype == dns.TypeAAAA:
rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888") rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888")
common.Must(err) common.Must(err)
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA {
case q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA:
ans.MsgHdr.Rcode = dns.RcodeNameError ans.MsgHdr.Rcode = dns.RcodeNameError
} }
} }

View File

@ -156,7 +156,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
if network == net.Network_TCP { if network == net.Network_TCP {
writer = buf.NewWriter(conn) writer = buf.NewWriter(conn)
} else { } else {
//if we are in TPROXY mode, use linux's udp forging functionality // if we are in TPROXY mode, use linux's udp forging functionality
if !destinationOverridden { if !destinationOverridden {
writer = &buf.SequentialWriter{Writer: conn} writer = &buf.SequentialWriter{Writer: conn}
} else { } else {

View File

@ -137,7 +137,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
if err != nil { if err != nil {
return newError("failed to open connection to ", destination).Base(err) return newError("failed to open connection to ", destination).Base(err)
} }
defer conn.Close() // nolint: errcheck defer conn.Close()
plcy := h.policy() plcy := h.policy()
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)

View File

@ -103,7 +103,7 @@ Start:
if err != nil { if err != nil {
trace := newError("failed to read http request").Base(err) trace := newError("failed to read http request").Base(err)
if errors.Cause(err) != io.EOF && !isTimeout(errors.Cause(err)) { if errors.Cause(err) != io.EOF && !isTimeout(errors.Cause(err)) {
trace.AtWarning() // nolint: errcheck trace.AtWarning()
} }
return trace return trace
} }
@ -159,7 +159,7 @@ Start:
return err return err
} }
func (s *Server) handleConnect(ctx context.Context, request *http.Request, reader *bufio.Reader, conn internet.Connection, dest net.Destination, dispatcher routing.Dispatcher) error { func (s *Server) handleConnect(ctx context.Context, _ *http.Request, reader *bufio.Reader, conn internet.Connection, dest net.Destination, dispatcher routing.Dispatcher) error {
_, err := conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")) _, err := conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n"))
if err != nil { if err != nil {
return newError("failed to write back OK response").Base(err) return newError("failed to write back OK response").Base(err)
@ -263,7 +263,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
} }
// Plain HTTP request is not a stream. The request always finishes before response. Hense request has to be closed later. // Plain HTTP request is not a stream. The request always finishes before response. Hense request has to be closed later.
defer common.Close(link.Writer) // nolint: errcheck defer common.Close(link.Writer)
var result error = errWaitAnother var result error = errWaitAnother
requestDone := func() error { requestDone := func() error {

View File

@ -34,7 +34,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
if err != nil { if err != nil {
return newError("failed to dial to ", dest).Base(err).AtWarning() return newError("failed to dial to ", dest).Base(err).AtWarning()
} }
defer conn.Close() // nolint: errcheck defer conn.Close()
sc := SessionContextFromContext(ctx) sc := SessionContextFromContext(ctx)
auth := NewAuthentication(sc) auth := NewAuthentication(sc)

View File

@ -136,7 +136,6 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
} }
if request.Command == protocol.RequestCommandUDP { if request.Command == protocol.RequestCommandUDP {
writer := &buf.SequentialWriter{Writer: &UDPWriter{ writer := &buf.SequentialWriter{Writer: &UDPWriter{
Writer: conn, Writer: conn,
Request: request, Request: request,

View File

@ -6,14 +6,13 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"hash"
"hash/crc32" "hash/crc32"
"io" "io"
"io/ioutil" "io/ioutil"
"v2ray.com/core/common/dice"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/dice"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
) )
@ -35,7 +34,7 @@ var addrParser = protocol.NewAddressParser(
func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) { func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) {
account := user.Account.(*MemoryAccount) account := user.Account.(*MemoryAccount)
hashkdf := hmac.New(func() hash.Hash { return sha256.New() }, []byte("SSBSKDF")) hashkdf := hmac.New(sha256.New, []byte("SSBSKDF"))
hashkdf.Write(account.Key) hashkdf.Write(account.Key)
behaviorSeed := crc32.ChecksumIEEE(hashkdf.Sum(nil)) behaviorSeed := crc32.ChecksumIEEE(hashkdf.Sum(nil))

View File

@ -133,7 +133,6 @@ func TestTCPRequest(t *testing.T) {
for _, test := range cases { for _, test := range cases {
runTest(test.request, test.payload) runTest(test.request, test.payload)
} }
} }
func TestUDPReaderWriter(t *testing.T) { func TestUDPReaderWriter(t *testing.T) {

View File

@ -127,7 +127,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
if err != nil { if err != nil {
return newError("failed to create UDP connection").Base(err) return newError("failed to create UDP connection").Base(err)
} }
defer udpConn.Close() // nolint: errcheck defer udpConn.Close()
requestFunc = func() error { requestFunc = func() error {
defer timer.SetTimeout(p.Timeouts.DownlinkOnly) defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
return buf.Copy(link.Reader, &buf.SequentialWriter{Writer: NewUDPWriter(request, udpConn)}, buf.UpdateActivity(timer)) return buf.Copy(link.Reader, &buf.SequentialWriter{Writer: NewUDPWriter(request, udpConn)}, buf.UpdateActivity(timer))

View File

@ -26,7 +26,7 @@ const (
socks4RequestRejected = 91 socks4RequestRejected = 91
authNotRequired = 0x00 authNotRequired = 0x00
//authGssAPI = 0x01 // authGssAPI = 0x01
authPassword = 0x02 authPassword = 0x02
authNoMatchingMethod = 0xFF authNoMatchingMethod = 0xFF
@ -47,7 +47,7 @@ type ServerSession struct {
func (s *ServerSession) handshake4(cmd byte, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) { func (s *ServerSession) handshake4(cmd byte, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {
if s.config.AuthType == AuthType_PASSWORD { if s.config.AuthType == AuthType_PASSWORD {
writeSocks4Response(writer, socks4RequestRejected, net.AnyIP, net.Port(0)) // nolint: errcheck writeSocks4Response(writer, socks4RequestRejected, net.AnyIP, net.Port(0))
return nil, newError("socks 4 is not allowed when auth is required.") return nil, newError("socks 4 is not allowed when auth is required.")
} }
@ -89,7 +89,7 @@ func (s *ServerSession) handshake4(cmd byte, reader io.Reader, writer io.Writer)
} }
return request, nil return request, nil
default: default:
writeSocks4Response(writer, socks4RequestRejected, net.AnyIP, net.Port(0)) // nolint: errcheck writeSocks4Response(writer, socks4RequestRejected, net.AnyIP, net.Port(0))
return nil, newError("unsupported command: ", cmd) return nil, newError("unsupported command: ", cmd)
} }
} }
@ -108,7 +108,7 @@ func (s *ServerSession) auth5(nMethod byte, reader io.Reader, writer io.Writer)
} }
if !hasAuthMethod(expectedAuth, buffer.BytesRange(0, int32(nMethod))) { if !hasAuthMethod(expectedAuth, buffer.BytesRange(0, int32(nMethod))) {
writeSocks5AuthenticationResponse(writer, socks5Version, authNoMatchingMethod) // nolint: errcheck writeSocks5AuthenticationResponse(writer, socks5Version, authNoMatchingMethod)
return "", newError("no matching auth method") return "", newError("no matching auth method")
} }
@ -123,7 +123,7 @@ func (s *ServerSession) auth5(nMethod byte, reader io.Reader, writer io.Writer)
} }
if !s.config.HasAccount(username, password) { if !s.config.HasAccount(username, password) {
writeSocks5AuthenticationResponse(writer, 0x01, 0xFF) // nolint: errcheck writeSocks5AuthenticationResponse(writer, 0x01, 0xFF)
return "", newError("invalid username or password") return "", newError("invalid username or password")
} }
@ -166,15 +166,15 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri
request.Command = protocol.RequestCommandTCP request.Command = protocol.RequestCommandTCP
case cmdUDPPort: case cmdUDPPort:
if !s.config.UdpEnabled { if !s.config.UdpEnabled {
writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) // nolint: errcheck writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0))
return nil, newError("UDP is not enabled.") return nil, newError("UDP is not enabled.")
} }
request.Command = protocol.RequestCommandUDP request.Command = protocol.RequestCommandUDP
case cmdTCPBind: case cmdTCPBind:
writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) // nolint: errcheck writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0))
return nil, newError("TCP bind is not supported.") return nil, newError("TCP bind is not supported.")
default: default:
writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) // nolint: errcheck writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0))
return nil, newError("unknown command ", cmd) return nil, newError("unknown command ", cmd)
} }

View File

@ -202,7 +202,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
newError("failed to write UDP response").AtWarning().Base(err).WriteToLog(session.ExportIDToError(ctx)) newError("failed to write UDP response").AtWarning().Base(err).WriteToLog(session.ExportIDToError(ctx))
} }
conn.Write(udpMessage.Bytes()) // nolint: errcheck conn.Write(udpMessage.Bytes())
}) })
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() { if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {

View File

@ -49,7 +49,7 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
} }
// Process implements OutboundHandler.Process(). // Process implements OutboundHandler.Process().
func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { // nolint: funlen func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
outbound := session.OutboundFromContext(ctx) outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() { if outbound == nil || !outbound.Target.IsValid() {
return newError("target not specified") return newError("target not specified")
@ -60,7 +60,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
var server *protocol.ServerSpec var server *protocol.ServerSpec
var conn internet.Connection var conn internet.Connection
err := retry.ExponentialBackoff(5, 100).On(func() error { // nolint: gomnd err := retry.ExponentialBackoff(5, 100).On(func() error {
server = c.serverPicker.PickServer() server = c.serverPicker.PickServer()
rawConn, err := dialer.Dial(ctx, server.Destination()) rawConn, err := dialer.Dial(ctx, server.Destination())
if err != nil { if err != nil {
@ -101,7 +101,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
} }
// write some request payload to buffer // write some request payload to buffer
if err = buf.CopyOnceTimeout(link.Reader, bodyWriter, time.Millisecond*100); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout { // nolint: lll,gomnd if err = buf.CopyOnceTimeout(link.Reader, bodyWriter, time.Millisecond*100); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout {
return newError("failed to write A reqeust payload").Base(err).AtWarning() return newError("failed to write A reqeust payload").Base(err).AtWarning()
} }
@ -140,7 +140,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
} }
func init() { func init() {
common.Must(common.RegisterConfig((*ClientConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { // nolint: lll common.Must(common.RegisterConfig((*ClientConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
return NewClient(ctx, config.(*ClientConfig)) return NewClient(ctx, config.(*ClientConfig))
})) }))
} }

View File

@ -13,9 +13,9 @@ var (
crlf = []byte{'\r', '\n'} crlf = []byte{'\r', '\n'}
addrParser = protocol.NewAddressParser( addrParser = protocol.NewAddressParser(
protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4), // nolint: gomnd protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6), // nolint: gomnd protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain), // nolint: gomnd protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
) )
) )
@ -129,7 +129,7 @@ func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net
return nil return nil
} }
func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) { // nolint: unparam func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) {
buffer := buf.StackNew() buffer := buf.StackNew()
defer buffer.Release() defer buffer.Release()

View File

@ -29,7 +29,7 @@ import (
) )
func init() { func init() {
common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { // nolint: lll common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
return NewServer(ctx, config.(*ServerConfig)) return NewServer(ctx, config.(*ServerConfig))
})) }))
} }
@ -91,8 +91,7 @@ func (s *Server) Network() []net.Network {
} }
// Process implements proxy.Inbound.Process(). // Process implements proxy.Inbound.Process().
func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error { // nolint: funlen,lll func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error {
sid := session.ExportIDToError(ctx) sid := session.ExportIDToError(ctx)
iConn := conn iConn := conn
@ -125,7 +124,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
isfb := apfb != nil isfb := apfb != nil
shouldFallback := false shouldFallback := false
if firstLen < 58 || first.Byte(56) != '\r' { // nolint: gomnd if firstLen < 58 || first.Byte(56) != '\r' {
// invalid protocol // invalid protocol
err = newError("not trojan protocol") err = newError("not trojan protocol")
log.Record(&log.AccessMessage{ log.Record(&log.AccessMessage{
@ -137,7 +136,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
shouldFallback = true shouldFallback = true
} else { } else {
user = s.validator.Get(hexString(first.BytesTo(56))) // nolint: gomnd user = s.validator.Get(hexString(first.BytesTo(56)))
if user == nil { if user == nil {
// invalid user, let's fallback // invalid user, let's fallback
err = newError("not a valid user") err = newError("not a valid user")
@ -199,7 +198,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher) return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher)
} }
func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error { // nolint: lll func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
common.Must(clientWriter.WriteMultiBufferWithMetadata(buf.MultiBuffer{packet.Payload}, packet.Source)) common.Must(clientWriter.WriteMultiBufferWithMetadata(buf.MultiBuffer{packet.Payload}, packet.Source))
}) })
@ -277,7 +276,7 @@ func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Sess
return nil return nil
} }
func (s *Server) fallback(ctx context.Context, sid errors.ExportOption, err error, sessionPolicy policy.Session, connection internet.Connection, iConn internet.Connection, apfb map[string]map[string]*Fallback, first *buf.Buffer, firstLen int64, reader buf.Reader) error { // nolint: lll func (s *Server) fallback(ctx context.Context, sid errors.ExportOption, err error, sessionPolicy policy.Session, connection internet.Connection, iConn internet.Connection, apfb map[string]map[string]*Fallback, first *buf.Buffer, firstLen int64, reader buf.Reader) error {
if err := connection.SetReadDeadline(time.Time{}); err != nil { if err := connection.SetReadDeadline(time.Time{}); err != nil {
newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid) newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid)
} }

View File

@ -13,10 +13,8 @@ import (
) )
func EncodeHeaderAddons(buffer *buf.Buffer, addons *Addons) error { func EncodeHeaderAddons(buffer *buf.Buffer, addons *Addons) error {
switch addons.Flow { switch addons.Flow {
case vless.XRO, vless.XRD: case vless.XRO, vless.XRD:
bytes, err := proto.Marshal(addons) bytes, err := proto.Marshal(addons)
if err != nil { if err != nil {
return newError("failed to marshal addons protobuf value").Base(err) return newError("failed to marshal addons protobuf value").Base(err)
@ -27,30 +25,23 @@ func EncodeHeaderAddons(buffer *buf.Buffer, addons *Addons) error {
if _, err := buffer.Write(bytes); err != nil { if _, err := buffer.Write(bytes); err != nil {
return newError("failed to write addons protobuf value").Base(err) return newError("failed to write addons protobuf value").Base(err)
} }
default: default:
if err := buffer.WriteByte(0); err != nil { if err := buffer.WriteByte(0); err != nil {
return newError("failed to write addons protobuf length").Base(err) return newError("failed to write addons protobuf length").Base(err)
} }
} }
return nil return nil
} }
func DecodeHeaderAddons(buffer *buf.Buffer, reader io.Reader) (*Addons, error) { func DecodeHeaderAddons(buffer *buf.Buffer, reader io.Reader) (*Addons, error) {
addons := new(Addons) addons := new(Addons)
buffer.Clear() buffer.Clear()
if _, err := buffer.ReadFullFrom(reader, 1); err != nil { if _, err := buffer.ReadFullFrom(reader, 1); err != nil {
return nil, newError("failed to read addons protobuf length").Base(err) return nil, newError("failed to read addons protobuf length").Base(err)
} }
if length := int32(buffer.Byte(0)); length != 0 { if length := int32(buffer.Byte(0)); length != 0 {
buffer.Clear() buffer.Clear()
if _, err := buffer.ReadFullFrom(reader, length); err != nil { if _, err := buffer.ReadFullFrom(reader, length); err != nil {
return nil, newError("failed to read addons protobuf value").Base(err) return nil, newError("failed to read addons protobuf value").Base(err)
@ -63,45 +54,32 @@ func DecodeHeaderAddons(buffer *buf.Buffer, reader io.Reader) (*Addons, error) {
// Verification. // Verification.
switch addons.Flow { switch addons.Flow {
default: default:
} }
} }
return addons, nil return addons, nil
} }
// EncodeBodyAddons returns a Writer that auto-encrypt content written by caller. // EncodeBodyAddons returns a Writer that auto-encrypt content written by caller.
func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, addons *Addons) buf.Writer { func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, addons *Addons) buf.Writer {
switch addons.Flow { switch addons.Flow {
default: default:
if request.Command == protocol.RequestCommandUDP { if request.Command == protocol.RequestCommandUDP {
return NewMultiLengthPacketWriter(writer.(buf.Writer)) return NewMultiLengthPacketWriter(writer.(buf.Writer))
} }
} }
return buf.NewWriter(writer) return buf.NewWriter(writer)
} }
// DecodeBodyAddons returns a Reader from which caller can fetch decrypted body. // DecodeBodyAddons returns a Reader from which caller can fetch decrypted body.
func DecodeBodyAddons(reader io.Reader, request *protocol.RequestHeader, addons *Addons) buf.Reader { func DecodeBodyAddons(reader io.Reader, request *protocol.RequestHeader, addons *Addons) buf.Reader {
switch addons.Flow { switch addons.Flow {
default: default:
if request.Command == protocol.RequestCommandUDP { if request.Command == protocol.RequestCommandUDP {
return NewLengthPacketReader(reader) return NewLengthPacketReader(reader)
} }
} }
return buf.NewReader(reader) return buf.NewReader(reader)
} }
func NewMultiLengthPacketWriter(writer buf.Writer) *MultiLengthPacketWriter { func NewMultiLengthPacketWriter(writer buf.Writer) *MultiLengthPacketWriter {
@ -157,7 +135,7 @@ type LengthPacketWriter struct {
func (w *LengthPacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { func (w *LengthPacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
length := mb.Len() // none of mb is nil length := mb.Len() // none of mb is nil
//fmt.Println("Write", length) // fmt.Println("Write", length)
if length == 0 { if length == 0 {
return nil return nil
} }
@ -193,7 +171,7 @@ func (r *LengthPacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
return nil, newError("failed to read packet length").Base(err) return nil, newError("failed to read packet length").Base(err)
} }
length := int32(r.cache[0])<<8 | int32(r.cache[1]) length := int32(r.cache[0])<<8 | int32(r.cache[1])
//fmt.Println("Read", length) // fmt.Println("Read", length)
mb := make(buf.MultiBuffer, 0, length/buf.Size+1) mb := make(buf.MultiBuffer, 0, length/buf.Size+1)
for length > 0 { for length > 0 {
size := length size := length

View File

@ -26,7 +26,6 @@ var addrParser = protocol.NewAddressParser(
// EncodeRequestHeader writes encoded request header into the given writer. // EncodeRequestHeader writes encoded request header into the given writer.
func EncodeRequestHeader(writer io.Writer, request *protocol.RequestHeader, requestAddons *Addons) error { func EncodeRequestHeader(writer io.Writer, request *protocol.RequestHeader, requestAddons *Addons) error {
buffer := buf.StackNew() buffer := buf.StackNew()
defer buffer.Release() defer buffer.Release()
@ -60,8 +59,7 @@ func EncodeRequestHeader(writer io.Writer, request *protocol.RequestHeader, requ
} }
// DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream. // DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream.
func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validator *vless.Validator) (*protocol.RequestHeader, *Addons, error, bool) { func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validator *vless.Validator) (*protocol.RequestHeader, *Addons, bool, error) {
buffer := buf.StackNew() buffer := buf.StackNew()
defer buffer.Release() defer buffer.Release()
@ -71,7 +69,7 @@ func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validat
request.Version = first.Byte(0) request.Version = first.Byte(0)
} else { } else {
if _, err := buffer.ReadFullFrom(reader, 1); err != nil { if _, err := buffer.ReadFullFrom(reader, 1); err != nil {
return nil, nil, newError("failed to read request version").Base(err), false return nil, nil, false, newError("failed to read request version").Base(err)
} }
request.Version = buffer.Byte(0) request.Version = buffer.Byte(0)
} }
@ -86,13 +84,13 @@ func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validat
} else { } else {
buffer.Clear() buffer.Clear()
if _, err := buffer.ReadFullFrom(reader, 16); err != nil { if _, err := buffer.ReadFullFrom(reader, 16); err != nil {
return nil, nil, newError("failed to read request user id").Base(err), false return nil, nil, false, newError("failed to read request user id").Base(err)
} }
copy(id[:], buffer.Bytes()) copy(id[:], buffer.Bytes())
} }
if request.User = validator.Get(id); request.User == nil { if request.User = validator.Get(id); request.User == nil {
return nil, nil, newError("invalid request user id"), isfb return nil, nil, isfb, newError("invalid request user id")
} }
if isfb { if isfb {
@ -101,12 +99,12 @@ func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validat
requestAddons, err := DecodeHeaderAddons(&buffer, reader) requestAddons, err := DecodeHeaderAddons(&buffer, reader)
if err != nil { if err != nil {
return nil, nil, newError("failed to decode request header addons").Base(err), false return nil, nil, false, newError("failed to decode request header addons").Base(err)
} }
buffer.Clear() buffer.Clear()
if _, err := buffer.ReadFullFrom(reader, 1); err != nil { if _, err := buffer.ReadFullFrom(reader, 1); err != nil {
return nil, nil, newError("failed to read request command").Base(err), false return nil, nil, false, newError("failed to read request command").Base(err)
} }
request.Command = protocol.RequestCommand(buffer.Byte(0)) request.Command = protocol.RequestCommand(buffer.Byte(0))
@ -120,24 +118,17 @@ func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validat
request.Port = port request.Port = port
} }
} }
if request.Address == nil { if request.Address == nil {
return nil, nil, newError("invalid request address"), false return nil, nil, false, newError("invalid request address")
} }
return request, requestAddons, false, nil
return request, requestAddons, nil, false
default: default:
return nil, nil, isfb, newError("invalid request version")
return nil, nil, newError("invalid request version"), isfb
} }
} }
// EncodeResponseHeader writes encoded response header into the given writer. // EncodeResponseHeader writes encoded response header into the given writer.
func EncodeResponseHeader(writer io.Writer, request *protocol.RequestHeader, responseAddons *Addons) error { func EncodeResponseHeader(writer io.Writer, request *protocol.RequestHeader, responseAddons *Addons) error {
buffer := buf.StackNew() buffer := buf.StackNew()
defer buffer.Release() defer buffer.Release()
@ -158,7 +149,6 @@ func EncodeResponseHeader(writer io.Writer, request *protocol.RequestHeader, res
// DecodeResponseHeader decodes and returns (if successful) a ResponseHeader from an input stream. // DecodeResponseHeader decodes and returns (if successful) a ResponseHeader from an input stream.
func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*Addons, error) { func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*Addons, error) {
buffer := buf.StackNew() buffer := buf.StackNew()
defer buffer.Release() defer buffer.Release()

View File

@ -46,7 +46,7 @@ func TestRequestSerialization(t *testing.T) {
Validator := new(vless.Validator) Validator := new(vless.Validator)
Validator.Add(user) Validator.Add(user)
actualRequest, actualAddons, err, _ := DecodeRequestHeader(false, nil, &buffer, Validator) actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
common.Must(err) common.Must(err)
if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" { if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {
@ -83,7 +83,7 @@ func TestInvalidRequest(t *testing.T) {
Validator := new(vless.Validator) Validator := new(vless.Validator)
Validator.Add(user) Validator.Add(user)
_, _, err, _ := DecodeRequestHeader(false, nil, &buffer, Validator) _, _, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
if err == nil { if err == nil {
t.Error("nil error") t.Error("nil error")
} }
@ -114,7 +114,7 @@ func TestMuxRequest(t *testing.T) {
Validator := new(vless.Validator) Validator := new(vless.Validator)
Validator.Add(user) Validator.Add(user)
actualRequest, actualAddons, err, _ := DecodeRequestHeader(false, nil, &buffer, Validator) actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
common.Must(err) common.Must(err)
if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" { if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {

View File

@ -64,12 +64,11 @@ type Handler struct {
validator *vless.Validator validator *vless.Validator
dns dns.Client dns dns.Client
fallbacks map[string]map[string]*Fallback // or nil fallbacks map[string]map[string]*Fallback // or nil
//regexps map[string]*regexp.Regexp // or nil // regexps map[string]*regexp.Regexp // or nil
} }
// New creates a new VLess inbound handler. // New creates a new VLess inbound handler.
func New(ctx context.Context, config *Config, dc dns.Client) (*Handler, error) { func New(ctx context.Context, config *Config, dc dns.Client) (*Handler, error) {
v := core.MustFromContext(ctx) v := core.MustFromContext(ctx)
handler := &Handler{ handler := &Handler{
inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager), inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager),
@ -90,7 +89,7 @@ func New(ctx context.Context, config *Config, dc dns.Client) (*Handler, error) {
if config.Fallbacks != nil { if config.Fallbacks != nil {
handler.fallbacks = make(map[string]map[string]*Fallback) handler.fallbacks = make(map[string]map[string]*Fallback)
//handler.regexps = make(map[string]*regexp.Regexp) // handler.regexps = make(map[string]*regexp.Regexp)
for _, fb := range config.Fallbacks { for _, fb := range config.Fallbacks {
if handler.fallbacks[fb.Alpn] == nil { if handler.fallbacks[fb.Alpn] == nil {
handler.fallbacks[fb.Alpn] = make(map[string]*Fallback) handler.fallbacks[fb.Alpn] = make(map[string]*Fallback)
@ -144,9 +143,7 @@ func (*Handler) Network() []net.Network {
// Process implements proxy.Inbound.Process(). // Process implements proxy.Inbound.Process().
func (h *Handler) Process(ctx context.Context, network net.Network, connection internet.Connection, dispatcher routing.Dispatcher) error { func (h *Handler) Process(ctx context.Context, network net.Network, connection internet.Connection, dispatcher routing.Dispatcher) error {
sid := session.ExportIDToError(ctx) sid := session.ExportIDToError(ctx)
iConn := connection iConn := connection
if statConn, ok := iConn.(*internet.StatCouterConnection); ok { if statConn, ok := iConn.(*internet.StatCouterConnection); ok {
iConn = statConn.Connection iConn = statConn.Connection
@ -178,11 +175,10 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
if isfb && firstLen < 18 { if isfb && firstLen < 18 {
err = newError("fallback directly") err = newError("fallback directly")
} else { } else {
request, requestAddons, err, isfb = encoding.DecodeRequestHeader(isfb, first, reader, h.validator) request, requestAddons, isfb, err = encoding.DecodeRequestHeader(isfb, first, reader, h.validator)
} }
if err != nil { if err != nil {
if isfb { if isfb {
if err := connection.SetReadDeadline(time.Time{}); err != nil { if err := connection.SetReadDeadline(time.Time{}); err != nil {
newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid) newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid)
@ -271,7 +267,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
}); err != nil { }); err != nil {
return newError("failed to dial to " + fb.Dest).Base(err).AtWarning() return newError("failed to dial to " + fb.Dest).Base(err).AtWarning()
} }
defer conn.Close() // nolint: errcheck defer conn.Close()
serverReader := buf.NewReader(conn) serverReader := buf.NewReader(conn)
serverWriter := buf.NewWriter(conn) serverWriter := buf.NewWriter(conn)
@ -303,6 +299,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
} else { } else {
pro.Write([]byte("PROXY TCP6 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n")) pro.Write([]byte("PROXY TCP6 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n"))
} }
case 2: case 2:
pro.Write([]byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A\x21")) // signature + v2 + PROXY pro.Write([]byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A\x21")) // signature + v2 + PROXY
if ipv4 { if ipv4 {
@ -372,7 +369,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
account := request.User.Account.(*vless.MemoryAccount) account := request.User.Account.(*vless.MemoryAccount)
responseAddons := &encoding.Addons{ responseAddons := &encoding.Addons{
//Flow: requestAddons.Flow, // Flow: requestAddons.Flow,
} }
switch requestAddons.Flow { switch requestAddons.Flow {
@ -473,7 +470,6 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
// Indicates the end of response payload. // Indicates the end of response payload.
switch responseAddons.Flow { switch responseAddons.Flow {
default: default:
} }
return nil return nil

View File

@ -52,7 +52,6 @@ type Handler struct {
// New creates a new VLess outbound handler. // New creates a new VLess outbound handler.
func New(ctx context.Context, config *Config) (*Handler, error) { func New(ctx context.Context, config *Config) (*Handler, error) {
serverList := protocol.NewServerList() serverList := protocol.NewServerList()
for _, rec := range config.Vnext { for _, rec := range config.Vnext {
s, err := protocol.NewServerSpecFromPB(rec) s, err := protocol.NewServerSpecFromPB(rec)
@ -74,7 +73,6 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
// Process implements proxy.Outbound.Process(). // Process implements proxy.Outbound.Process().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
var rec *protocol.ServerSpec var rec *protocol.ServerSpec
var conn internet.Connection var conn internet.Connection
@ -89,7 +87,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
}); err != nil { }); err != nil {
return newError("failed to find an available destination").Base(err).AtWarning() return newError("failed to find an available destination").Base(err).AtWarning()
} }
defer conn.Close() // nolint: errcheck defer conn.Close()
iConn := conn iConn := conn
if statConn, ok := iConn.(*internet.StatCouterConnection); ok { if statConn, ok := iConn.(*internet.StatCouterConnection); ok {
@ -193,9 +191,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
// Indicates the end of request payload. // Indicates the end of request payload.
switch requestAddons.Flow { switch requestAddons.Flow {
default: default:
} }
return nil return nil
} }

View File

@ -11,10 +11,16 @@ import (
"io" "io"
"math" "math"
"time" "time"
"v2ray.com/core/common" "v2ray.com/core/common"
antiReplayWindow "v2ray.com/core/common/antireplay" antiReplayWindow "v2ray.com/core/common/antireplay"
) )
var (
ErrNotFound = errors.New("user do not exist")
ErrReplay = errors.New("replayed request")
)
func CreateAuthID(cmdKey []byte, time int64) [16]byte { func CreateAuthID(cmdKey []byte, time int64) [16]byte {
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
common.Must(binary.Write(buf, binary.BigEndian, time)) common.Must(binary.Write(buf, binary.BigEndian, time))
@ -32,7 +38,7 @@ func CreateAuthID(cmdKey []byte, time int64) [16]byte {
} }
func NewCipherFromKey(cmdKey []byte) cipher.Block { func NewCipherFromKey(cmdKey []byte) cipher.Block {
aesBlock, err := aes.NewCipher(KDF16(cmdKey, KDFSaltConst_AuthIDEncryptionKey)) aesBlock, err := aes.NewCipher(KDF16(cmdKey, KDFSaltConstAuthIDEncryptionKey))
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -88,10 +94,9 @@ func (a *AuthIDDecoderHolder) RemoveUser(key [16]byte) {
delete(a.aidhi, string(key[:])) delete(a.aidhi, string(key[:]))
} }
func (a *AuthIDDecoderHolder) Match(AuthID [16]byte) (interface{}, error) { func (a *AuthIDDecoderHolder) Match(authID [16]byte) (interface{}, error) {
for _, v := range a.aidhi { for _, v := range a.aidhi {
t, z, r, d := v.dec.Decode(authID)
t, z, r, d := v.dec.Decode(AuthID)
if z != crc32.ChecksumIEEE(d[:12]) { if z != crc32.ChecksumIEEE(d[:12]) {
continue continue
} }
@ -104,18 +109,13 @@ func (a *AuthIDDecoderHolder) Match(AuthID [16]byte) (interface{}, error) {
continue continue
} }
if !a.apw.Check(AuthID[:]) { if !a.apw.Check(authID[:]) {
return nil, ErrReplay return nil, ErrReplay
} }
_ = r _ = r
return v.ticket, nil return v.ticket, nil
} }
return nil, ErrNotFound return nil, ErrNotFound
} }
var ErrNotFound = errors.New("user do not exist")
var ErrReplay = errors.New("replayed request")

View File

@ -2,10 +2,11 @@ package aead
import ( import (
"fmt" "fmt"
"github.com/stretchr/testify/assert"
"strconv" "strconv"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func TestCreateAuthID(t *testing.T) { func TestCreateAuthID(t *testing.T) {
@ -57,7 +58,6 @@ func TestCreateAuthIDAndDecode2(t *testing.T) {
res2, err2 := AuthDecoder.Match(authid2) res2, err2 := AuthDecoder.Match(authid2)
assert.EqualError(t, err2, "user do not exist") assert.EqualError(t, err2, "user do not exist")
assert.Nil(t, res2) assert.Nil(t, res2)
} }
func TestCreateAuthIDAndDecodeMassive(t *testing.T) { func TestCreateAuthIDAndDecodeMassive(t *testing.T) {
@ -89,7 +89,6 @@ func TestCreateAuthIDAndDecodeMassive(t *testing.T) {
res2, err2 := AuthDecoder.Match(authid3) res2, err2 := AuthDecoder.Match(authid3)
assert.Equal(t, "Demo User", res2) assert.Equal(t, "Demo User", res2)
assert.Nil(t, err2) assert.Nil(t, err2)
} }
func TestCreateAuthIDAndDecodeSuperMassive(t *testing.T) { func TestCreateAuthIDAndDecodeSuperMassive(t *testing.T) {
@ -125,5 +124,4 @@ func TestCreateAuthIDAndDecodeSuperMassive(t *testing.T) {
assert.Nil(t, err2) assert.Nil(t, err2)
fmt.Println(after.Sub(before).Seconds()) fmt.Println(after.Sub(before).Seconds())
} }

View File

@ -1,21 +1,14 @@
package aead package aead
const KDFSaltConst_AuthIDEncryptionKey = "AES Auth ID Encryption" const (
KDFSaltConstAuthIDEncryptionKey = "AES Auth ID Encryption"
const KDFSaltConst_AEADRespHeaderLenKey = "AEAD Resp Header Len Key" KDFSaltConstAEADRespHeaderLenKey = "AEAD Resp Header Len Key"
KDFSaltConstAEADRespHeaderLenIV = "AEAD Resp Header Len IV"
const KDFSaltConst_AEADRespHeaderLenIV = "AEAD Resp Header Len IV" KDFSaltConstAEADRespHeaderPayloadKey = "AEAD Resp Header Key"
KDFSaltConstAEADRespHeaderPayloadIV = "AEAD Resp Header IV"
const KDFSaltConst_AEADRespHeaderPayloadKey = "AEAD Resp Header Key" KDFSaltConstVMessAEADKDF = "VMess AEAD KDF"
KDFSaltConstVMessHeaderPayloadAEADKey = "VMess Header AEAD Key"
const KDFSaltConst_AEADRespHeaderPayloadIV = "AEAD Resp Header IV" KDFSaltConstVMessHeaderPayloadAEADIV = "VMess Header AEAD Nonce"
KDFSaltConstVMessHeaderPayloadLengthAEADKey = "VMess Header AEAD Key_Length"
const KDFSaltConst_VMessAEADKDF = "VMess AEAD KDF" KDFSaltConstVMessHeaderPayloadLengthAEADIV = "VMess Header AEAD Nonce_Length"
)
const KDFSaltConst_VMessHeaderPayloadAEADKey = "VMess Header AEAD Key"
const KDFSaltConst_VMessHeaderPayloadAEADIV = "VMess Header AEAD Nonce"
const KDFSaltConst_VMessHeaderPayloadLengthAEADKey = "VMess Header AEAD Key_Length"
const KDFSaltConst_VMessHeaderPayloadLengthAEADIV = "VMess Header AEAD Nonce_Length"

View File

@ -30,9 +30,9 @@ func SealVMessAEADHeader(key [16]byte, data []byte) []byte {
var payloadHeaderLengthAEADEncrypted []byte var payloadHeaderLengthAEADEncrypted []byte
{ {
payloadHeaderLengthAEADKey := KDF16(key[:], KDFSaltConst_VMessHeaderPayloadLengthAEADKey, string(generatedAuthID[:]), string(connectionNonce)) payloadHeaderLengthAEADKey := KDF16(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADKey, string(generatedAuthID[:]), string(connectionNonce))
payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConst_VMessHeaderPayloadLengthAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12] payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12]
payloadHeaderLengthAEADAESBlock, err := aes.NewCipher(payloadHeaderLengthAEADKey) payloadHeaderLengthAEADAESBlock, err := aes.NewCipher(payloadHeaderLengthAEADKey)
if err != nil { if err != nil {
@ -51,9 +51,9 @@ func SealVMessAEADHeader(key [16]byte, data []byte) []byte {
var payloadHeaderAEADEncrypted []byte var payloadHeaderAEADEncrypted []byte
{ {
payloadHeaderAEADKey := KDF16(key[:], KDFSaltConst_VMessHeaderPayloadAEADKey, string(generatedAuthID[:]), string(connectionNonce)) payloadHeaderAEADKey := KDF16(key[:], KDFSaltConstVMessHeaderPayloadAEADKey, string(generatedAuthID[:]), string(connectionNonce))
payloadHeaderAEADNonce := KDF(key[:], KDFSaltConst_VMessHeaderPayloadAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12] payloadHeaderAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12]
payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderAEADKey) payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderAEADKey)
if err != nil { if err != nil {
@ -71,18 +71,15 @@ func SealVMessAEADHeader(key [16]byte, data []byte) []byte {
var outputBuffer = bytes.NewBuffer(nil) var outputBuffer = bytes.NewBuffer(nil)
common.Must2(outputBuffer.Write(generatedAuthID[:])) //16 common.Must2(outputBuffer.Write(generatedAuthID[:])) // 16
common.Must2(outputBuffer.Write(payloadHeaderLengthAEADEncrypted)) // 2+16
common.Must2(outputBuffer.Write(payloadHeaderLengthAEADEncrypted)) //2+16 common.Must2(outputBuffer.Write(connectionNonce)) // 8
common.Must2(outputBuffer.Write(connectionNonce)) //8
common.Must2(outputBuffer.Write(payloadHeaderAEADEncrypted)) common.Must2(outputBuffer.Write(payloadHeaderAEADEncrypted))
return outputBuffer.Bytes() return outputBuffer.Bytes()
} }
func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte, bool, error, int) { func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte, bool, int, error) {
var payloadHeaderLengthAEADEncrypted [18]byte var payloadHeaderLengthAEADEncrypted [18]byte
var nonce [8]byte var nonce [8]byte
@ -91,23 +88,23 @@ func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte,
authidCheckValueReadBytesCounts, err := io.ReadFull(data, payloadHeaderLengthAEADEncrypted[:]) authidCheckValueReadBytesCounts, err := io.ReadFull(data, payloadHeaderLengthAEADEncrypted[:])
bytesRead += authidCheckValueReadBytesCounts bytesRead += authidCheckValueReadBytesCounts
if err != nil { if err != nil {
return nil, false, err, bytesRead return nil, false, bytesRead, err
} }
nonceReadBytesCounts, err := io.ReadFull(data, nonce[:]) nonceReadBytesCounts, err := io.ReadFull(data, nonce[:])
bytesRead += nonceReadBytesCounts bytesRead += nonceReadBytesCounts
if err != nil { if err != nil {
return nil, false, err, bytesRead return nil, false, bytesRead, err
} }
//Decrypt Length // Decrypt Length
var decryptedAEADHeaderLengthPayloadResult []byte var decryptedAEADHeaderLengthPayloadResult []byte
{ {
payloadHeaderLengthAEADKey := KDF16(key[:], KDFSaltConst_VMessHeaderPayloadLengthAEADKey, string(authid[:]), string(nonce[:])) payloadHeaderLengthAEADKey := KDF16(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADKey, string(authid[:]), string(nonce[:]))
payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConst_VMessHeaderPayloadLengthAEADIV, string(authid[:]), string(nonce[:]))[:12] payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADIV, string(authid[:]), string(nonce[:]))[:12]
payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderLengthAEADKey) payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderLengthAEADKey)
if err != nil { if err != nil {
@ -123,7 +120,7 @@ func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte,
decryptedAEADHeaderLengthPayload, erropenAEAD := payloadHeaderLengthAEAD.Open(nil, payloadHeaderLengthAEADNonce, payloadHeaderLengthAEADEncrypted[:], authid[:]) decryptedAEADHeaderLengthPayload, erropenAEAD := payloadHeaderLengthAEAD.Open(nil, payloadHeaderLengthAEADNonce, payloadHeaderLengthAEADEncrypted[:], authid[:])
if erropenAEAD != nil { if erropenAEAD != nil {
return nil, true, erropenAEAD, bytesRead return nil, true, bytesRead, erropenAEAD
} }
decryptedAEADHeaderLengthPayloadResult = decryptedAEADHeaderLengthPayload decryptedAEADHeaderLengthPayloadResult = decryptedAEADHeaderLengthPayload
@ -131,24 +128,24 @@ func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte,
var length uint16 var length uint16
common.Must(binary.Read(bytes.NewReader(decryptedAEADHeaderLengthPayloadResult[:]), binary.BigEndian, &length)) common.Must(binary.Read(bytes.NewReader(decryptedAEADHeaderLengthPayloadResult), binary.BigEndian, &length))
var decryptedAEADHeaderPayloadR []byte var decryptedAEADHeaderPayloadR []byte
var payloadHeaderAEADEncryptedReadedBytesCounts int var payloadHeaderAEADEncryptedReadedBytesCounts int
{ {
payloadHeaderAEADKey := KDF16(key[:], KDFSaltConst_VMessHeaderPayloadAEADKey, string(authid[:]), string(nonce[:])) payloadHeaderAEADKey := KDF16(key[:], KDFSaltConstVMessHeaderPayloadAEADKey, string(authid[:]), string(nonce[:]))
payloadHeaderAEADNonce := KDF(key[:], KDFSaltConst_VMessHeaderPayloadAEADIV, string(authid[:]), string(nonce[:]))[:12] payloadHeaderAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadAEADIV, string(authid[:]), string(nonce[:]))[:12]
//16 == AEAD Tag size // 16 == AEAD Tag size
payloadHeaderAEADEncrypted := make([]byte, length+16) payloadHeaderAEADEncrypted := make([]byte, length+16)
payloadHeaderAEADEncryptedReadedBytesCounts, err = io.ReadFull(data, payloadHeaderAEADEncrypted) payloadHeaderAEADEncryptedReadedBytesCounts, err = io.ReadFull(data, payloadHeaderAEADEncrypted)
bytesRead += payloadHeaderAEADEncryptedReadedBytesCounts bytesRead += payloadHeaderAEADEncryptedReadedBytesCounts
if err != nil { if err != nil {
return nil, false, err, bytesRead return nil, false, bytesRead, err
} }
payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderAEADKey) payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderAEADKey)
@ -165,11 +162,11 @@ func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte,
decryptedAEADHeaderPayload, erropenAEAD := payloadHeaderAEAD.Open(nil, payloadHeaderAEADNonce, payloadHeaderAEADEncrypted, authid[:]) decryptedAEADHeaderPayload, erropenAEAD := payloadHeaderAEAD.Open(nil, payloadHeaderAEADNonce, payloadHeaderAEADEncrypted, authid[:])
if erropenAEAD != nil { if erropenAEAD != nil {
return nil, true, erropenAEAD, bytesRead return nil, true, bytesRead, erropenAEAD
} }
decryptedAEADHeaderPayloadR = decryptedAEADHeaderPayload decryptedAEADHeaderPayloadR = decryptedAEADHeaderPayload
} }
return decryptedAEADHeaderPayloadR, false, nil, bytesRead return decryptedAEADHeaderPayloadR, false, bytesRead, nil
} }

View File

@ -22,7 +22,7 @@ func TestOpenVMessAEADHeader(t *testing.T) {
io.ReadFull(AEADR, authid[:]) io.ReadFull(AEADR, authid[:])
out, _, err, _ := OpenVMessAEADHeader(keyw, authid, AEADR) out, _, _, err := OpenVMessAEADHeader(keyw, authid, AEADR)
fmt.Println(string(out)) fmt.Println(string(out))
fmt.Println(err) fmt.Println(err)
@ -41,7 +41,7 @@ func TestOpenVMessAEADHeader2(t *testing.T) {
io.ReadFull(AEADR, authid[:]) io.ReadFull(AEADR, authid[:])
out, _, err, readen := OpenVMessAEADHeader(keyw, authid, AEADR) out, _, readen, err := OpenVMessAEADHeader(keyw, authid, AEADR)
assert.Equal(t, len(sealed)-16-AEADR.Len(), readen) assert.Equal(t, len(sealed)-16-AEADR.Len(), readen)
assert.Equal(t, string(TestHeader), string(out)) assert.Equal(t, string(TestHeader), string(out))
assert.Nil(t, err) assert.Nil(t, err)
@ -63,7 +63,7 @@ func TestOpenVMessAEADHeader4(t *testing.T) {
io.ReadFull(AEADR, authid[:]) io.ReadFull(AEADR, authid[:])
out, drain, err, readen := OpenVMessAEADHeader(keyw, authid, AEADR) out, drain, readen, err := OpenVMessAEADHeader(keyw, authid, AEADR)
assert.Equal(t, len(sealed)-16-AEADR.Len(), readen) assert.Equal(t, len(sealed)-16-AEADR.Len(), readen)
assert.Equal(t, true, drain) assert.Equal(t, true, drain)
assert.NotNil(t, err) assert.NotNil(t, err)
@ -72,12 +72,10 @@ func TestOpenVMessAEADHeader4(t *testing.T) {
} }
assert.Nil(t, out) assert.Nil(t, out)
} }
} }
func TestOpenVMessAEADHeader4Massive(t *testing.T) { func TestOpenVMessAEADHeader4Massive(t *testing.T) {
for j := 0; j < 1000; j++ { for j := 0; j < 1000; j++ {
for i := 0; i <= 60; i++ { for i := 0; i <= 60; i++ {
TestHeader := []byte("Test Header") TestHeader := []byte("Test Header")
key := KDF16([]byte("Demo Key for Auth ID Test"), "Demo Path for Auth ID Test") key := KDF16([]byte("Demo Key for Auth ID Test"), "Demo Path for Auth ID Test")
@ -93,7 +91,7 @@ func TestOpenVMessAEADHeader4Massive(t *testing.T) {
io.ReadFull(AEADR, authid[:]) io.ReadFull(AEADR, authid[:])
out, drain, err, readen := OpenVMessAEADHeader(keyw, authid, AEADR) out, drain, readen, err := OpenVMessAEADHeader(keyw, authid, AEADR)
assert.Equal(t, len(sealed)-16-AEADR.Len(), readen) assert.Equal(t, len(sealed)-16-AEADR.Len(), readen)
assert.Equal(t, true, drain) assert.Equal(t, true, drain)
assert.NotNil(t, err) assert.NotNil(t, err)

View File

@ -7,9 +7,7 @@ import (
) )
func KDF(key []byte, path ...string) []byte { func KDF(key []byte, path ...string) []byte {
hmacf := hmac.New(func() hash.Hash { hmacf := hmac.New(sha256.New, []byte(KDFSaltConstVMessAEADKDF))
return sha256.New()
}, []byte(KDFSaltConst_VMessAEADKDF))
for _, v := range path { for _, v := range path {
hmacf = hmac.New(func() hash.Hash { hmacf = hmac.New(func() hash.Hash {

View File

@ -47,8 +47,7 @@ type ClientSession struct {
} }
// NewClientSession creates a new ClientSession. // NewClientSession creates a new ClientSession.
func NewClientSession(isAEAD bool, idHash protocol.IDHash, ctx context.Context) *ClientSession { func NewClientSession(ctx context.Context, isAEAD bool, idHash protocol.IDHash) *ClientSession {
session := &ClientSession{ session := &ClientSession{
isAEAD: isAEAD, isAEAD: isAEAD,
idHash: idHash, idHash: idHash,
@ -114,7 +113,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
if !c.isAEAD { if !c.isAEAD {
iv := hashTimestamp(md5.New(), timestamp) iv := hashTimestamp(md5.New(), timestamp)
aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv[:]) aesStream := crypto.NewAesEncryptionStream(account.ID.CmdKey(), iv)
aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) aesStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
common.Must2(writer.Write(buffer.Bytes())) common.Must2(writer.Write(buffer.Bytes()))
} else { } else {
@ -193,8 +192,8 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:]) aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
c.responseReader = crypto.NewCryptionReader(aesStream, reader) c.responseReader = crypto.NewCryptionReader(aesStream, reader)
} else { } else {
aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConst_AEADRespHeaderLenKey) aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey)
aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConst_AEADRespHeaderLenIV)[:12] aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12]
aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block) aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block)
aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD) aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD)
@ -213,8 +212,8 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
decryptedResponseHeaderLength = int(decryptedResponseHeaderLengthBinaryDeserializeBuffer) decryptedResponseHeaderLength = int(decryptedResponseHeaderLengthBinaryDeserializeBuffer)
} }
aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConst_AEADRespHeaderPayloadKey) aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey)
aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConst_AEADRespHeaderPayloadIV)[:12] aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12]
aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block) aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block)
aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD) aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD)

View File

@ -43,7 +43,7 @@ func TestRequestSerialization(t *testing.T) {
} }
buffer := buf.New() buffer := buf.New()
client := NewClientSession(true, protocol.DefaultIDHash, context.TODO()) client := NewClientSession(context.TODO(), true, protocol.DefaultIDHash)
common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
buffer2 := buf.New() buffer2 := buf.New()
@ -93,7 +93,7 @@ func TestInvalidRequest(t *testing.T) {
} }
buffer := buf.New() buffer := buf.New()
client := NewClientSession(true, protocol.DefaultIDHash, context.TODO()) client := NewClientSession(context.TODO(), true, protocol.DefaultIDHash)
common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
buffer2 := buf.New() buffer2 := buf.New()
@ -134,7 +134,7 @@ func TestMuxRequest(t *testing.T) {
} }
buffer := buf.New() buffer := buf.New()
client := NewClientSession(true, protocol.DefaultIDHash, context.TODO()) client := NewClientSession(context.TODO(), true, protocol.DefaultIDHash)
common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
buffer2 := buf.New() buffer2 := buf.New()

View File

@ -26,7 +26,7 @@ import (
vmessaead "v2ray.com/core/proxy/vmess/aead" vmessaead "v2ray.com/core/proxy/vmess/aead"
) )
type sessionId struct { type sessionID struct {
user [16]byte user [16]byte
key [16]byte key [16]byte
nonce [16]byte nonce [16]byte
@ -35,14 +35,14 @@ type sessionId struct {
// SessionHistory keeps track of historical session ids, to prevent replay attacks. // SessionHistory keeps track of historical session ids, to prevent replay attacks.
type SessionHistory struct { type SessionHistory struct {
sync.RWMutex sync.RWMutex
cache map[sessionId]time.Time cache map[sessionID]time.Time
task *task.Periodic task *task.Periodic
} }
// NewSessionHistory creates a new SessionHistory object. // NewSessionHistory creates a new SessionHistory object.
func NewSessionHistory() *SessionHistory { func NewSessionHistory() *SessionHistory {
h := &SessionHistory{ h := &SessionHistory{
cache: make(map[sessionId]time.Time, 128), cache: make(map[sessionID]time.Time, 128),
} }
h.task = &task.Periodic{ h.task = &task.Periodic{
Interval: time.Second * 30, Interval: time.Second * 30,
@ -56,7 +56,7 @@ func (h *SessionHistory) Close() error {
return h.task.Close() return h.task.Close()
} }
func (h *SessionHistory) addIfNotExits(session sessionId) bool { func (h *SessionHistory) addIfNotExits(session sessionID) bool {
h.Lock() h.Lock()
if expire, found := h.cache[session]; found && expire.After(time.Now()) { if expire, found := h.cache[session]; found && expire.After(time.Now()) {
@ -87,7 +87,7 @@ func (h *SessionHistory) removeExpiredEntries() error {
} }
if len(h.cache) == 0 { if len(h.cache) == 0 {
h.cache = make(map[sessionId]time.Time, 128) h.cache = make(map[sessionID]time.Time, 128)
} }
return nil return nil
@ -141,7 +141,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
readSizeRemain := DrainSize readSizeRemain := DrainSize
drainConnection := func(e error) error { drainConnection := func(e error) error {
//We read a deterministic generated length of data before closing the connection to offset padding read pattern // We read a deterministic generated length of data before closing the connection to offset padding read pattern
readSizeRemain -= int(buffer.Len()) readSizeRemain -= int(buffer.Len())
if readSizeRemain > 0 { if readSizeRemain > 0 {
err := s.DrainConnN(reader, readSizeRemain) err := s.DrainConnN(reader, readSizeRemain)
@ -169,22 +169,24 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
var fixedSizeAuthID [16]byte var fixedSizeAuthID [16]byte
copy(fixedSizeAuthID[:], buffer.Bytes()) copy(fixedSizeAuthID[:], buffer.Bytes())
if foundAEAD { switch {
case foundAEAD:
vmessAccount = user.Account.(*vmess.MemoryAccount) vmessAccount = user.Account.(*vmess.MemoryAccount)
var fixedSizeCmdKey [16]byte var fixedSizeCmdKey [16]byte
copy(fixedSizeCmdKey[:], vmessAccount.ID.CmdKey()) copy(fixedSizeCmdKey[:], vmessAccount.ID.CmdKey())
aeadData, shouldDrain, errorReason, bytesRead := vmessaead.OpenVMessAEADHeader(fixedSizeCmdKey, fixedSizeAuthID, reader) aeadData, shouldDrain, bytesRead, errorReason := vmessaead.OpenVMessAEADHeader(fixedSizeCmdKey, fixedSizeAuthID, reader)
if errorReason != nil { if errorReason != nil {
if shouldDrain { if shouldDrain {
readSizeRemain -= bytesRead readSizeRemain -= bytesRead
return nil, drainConnection(newError("AEAD read failed").Base(errorReason)) return nil, drainConnection(newError("AEAD read failed").Base(errorReason))
} else { } else {
return nil, drainConnection(newError("AEAD read failed, drain skiped").Base(errorReason)) return nil, drainConnection(newError("AEAD read failed, drain skipped").Base(errorReason))
} }
} }
decryptor = bytes.NewReader(aeadData) decryptor = bytes.NewReader(aeadData)
s.isAEADRequest = true s.isAEADRequest = true
} else if !s.isAEADForced && errorAEAD == vmessaead.ErrNotFound {
case !s.isAEADForced && errorAEAD == vmessaead.ErrNotFound:
userLegacy, timestamp, valid, userValidationError := s.userValidator.Get(buffer.Bytes()) userLegacy, timestamp, valid, userValidationError := s.userValidator.Get(buffer.Bytes())
if !valid || userValidationError != nil { if !valid || userValidationError != nil {
return nil, drainConnection(newError("invalid user").Base(userValidationError)) return nil, drainConnection(newError("invalid user").Base(userValidationError))
@ -193,9 +195,10 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
iv := hashTimestamp(md5.New(), timestamp) iv := hashTimestamp(md5.New(), timestamp)
vmessAccount = userLegacy.Account.(*vmess.MemoryAccount) vmessAccount = userLegacy.Account.(*vmess.MemoryAccount)
aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv[:]) aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv)
decryptor = crypto.NewCryptionReader(aesStream, reader) decryptor = crypto.NewCryptionReader(aesStream, reader)
} else {
default:
return nil, drainConnection(newError("invalid user").Base(errorAEAD)) return nil, drainConnection(newError("invalid user").Base(errorAEAD))
} }
@ -212,7 +215,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
copy(s.requestBodyIV[:], buffer.BytesRange(1, 17)) // 16 bytes copy(s.requestBodyIV[:], buffer.BytesRange(1, 17)) // 16 bytes
copy(s.requestBodyKey[:], buffer.BytesRange(17, 33)) // 16 bytes copy(s.requestBodyKey[:], buffer.BytesRange(17, 33)) // 16 bytes
var sid sessionId var sid sessionID
copy(sid.user[:], vmessAccount.ID.Bytes()) copy(sid.user[:], vmessAccount.ID.Bytes())
sid.key = s.requestBodyKey sid.key = s.requestBodyKey
sid.nonce = s.requestBodyIV sid.nonce = s.requestBodyIV
@ -226,7 +229,6 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
} else { } else {
return nil, newError("duplicated session id, possibly under replay attack, but this is a AEAD request") return nil, newError("duplicated session id, possibly under replay attack, but this is a AEAD request")
} }
} }
s.responseHeader = buffer.Byte(33) // 1 byte s.responseHeader = buffer.Byte(33) // 1 byte
@ -240,6 +242,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
case protocol.RequestCommandMux: case protocol.RequestCommandMux:
request.Address = net.DomainAddress("v1.mux.cool") request.Address = net.DomainAddress("v1.mux.cool")
request.Port = 0 request.Port = 0
case protocol.RequestCommandTCP, protocol.RequestCommandUDP: case protocol.RequestCommandTCP, protocol.RequestCommandUDP:
if addr, port, err := addrParser.ReadAddressPort(buffer, decryptor); err == nil { if addr, port, err := addrParser.ReadAddressPort(buffer, decryptor); err == nil {
request.Address = addr request.Address = addr
@ -283,12 +286,11 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
if burnErr != nil { if burnErr != nil {
Autherr = newError("invalid auth, can't taint legacy userHash").Base(burnErr) Autherr = newError("invalid auth, can't taint legacy userHash").Base(burnErr)
} }
//It is possible that we are under attack described in https://github.com/v2ray/v2ray-core/issues/2523 // It is possible that we are under attack described in https://github.com/v2ray/v2ray-core/issues/2523
return nil, drainConnection(Autherr) return nil, drainConnection(Autherr)
} else { } else {
return nil, newError("invalid auth, but this is a AEAD request") return nil, newError("invalid auth, but this is a AEAD request")
} }
} }
if request.Address == nil { if request.Address == nil {
@ -327,8 +329,8 @@ func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
} }
return crypto.NewAuthenticationReader(auth, sizeParser, reader, protocol.TransferTypePacket, padding) return crypto.NewAuthenticationReader(auth, sizeParser, reader, protocol.TransferTypePacket, padding)
} }
return buf.NewReader(reader) return buf.NewReader(reader)
case protocol.SecurityType_LEGACY: case protocol.SecurityType_LEGACY:
aesStream := crypto.NewAesDecryptionStream(s.requestBodyKey[:], s.requestBodyIV[:]) aesStream := crypto.NewAesDecryptionStream(s.requestBodyKey[:], s.requestBodyIV[:])
cryptionReader := crypto.NewCryptionReader(aesStream, reader) cryptionReader := crypto.NewCryptionReader(aesStream, reader)
@ -340,17 +342,17 @@ func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
} }
return crypto.NewAuthenticationReader(auth, sizeParser, cryptionReader, request.Command.TransferType(), padding) return crypto.NewAuthenticationReader(auth, sizeParser, cryptionReader, request.Command.TransferType(), padding)
} }
return buf.NewReader(cryptionReader) return buf.NewReader(cryptionReader)
case protocol.SecurityType_AES128_GCM: case protocol.SecurityType_AES128_GCM:
aead := crypto.NewAesGcm(s.requestBodyKey[:]) aead := crypto.NewAesGcm(s.requestBodyKey[:])
auth := &crypto.AEADAuthenticator{ auth := &crypto.AEADAuthenticator{
AEAD: aead, AEAD: aead,
NonceGenerator: GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())), NonceGenerator: GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())),
AdditionalDataGenerator: crypto.GenerateEmptyBytes(), AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
} }
return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding) return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding)
case protocol.SecurityType_CHACHA20_POLY1305: case protocol.SecurityType_CHACHA20_POLY1305:
aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(s.requestBodyKey[:])) aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(s.requestBodyKey[:]))
@ -360,6 +362,7 @@ func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
AdditionalDataGenerator: crypto.GenerateEmptyBytes(), AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
} }
return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding) return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding)
default: default:
panic("Unknown security type.") panic("Unknown security type.")
} }
@ -395,9 +398,8 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr
} }
if s.isAEADRequest { if s.isAEADRequest {
aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey)
aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConst_AEADRespHeaderLenKey) aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12]
aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConst_AEADRespHeaderLenIV)[:12]
aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block) aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block)
aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD) aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD)
@ -411,8 +413,8 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr
AEADEncryptedLength := aeadResponseHeaderLengthEncryptionAEAD.Seal(nil, aeadResponseHeaderLengthEncryptionIV, aeadResponseHeaderLengthEncryptionBuffer.Bytes(), nil) AEADEncryptedLength := aeadResponseHeaderLengthEncryptionAEAD.Seal(nil, aeadResponseHeaderLengthEncryptionIV, aeadResponseHeaderLengthEncryptionBuffer.Bytes(), nil)
common.Must2(io.Copy(writer, bytes.NewReader(AEADEncryptedLength))) common.Must2(io.Copy(writer, bytes.NewReader(AEADEncryptedLength)))
aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConst_AEADRespHeaderPayloadKey) aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey)
aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConst_AEADRespHeaderPayloadIV)[:12] aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12]
aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block) aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block)
aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD) aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD)
@ -447,8 +449,8 @@ func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ
} }
return crypto.NewAuthenticationWriter(auth, sizeParser, writer, protocol.TransferTypePacket, padding) return crypto.NewAuthenticationWriter(auth, sizeParser, writer, protocol.TransferTypePacket, padding)
} }
return buf.NewWriter(writer) return buf.NewWriter(writer)
case protocol.SecurityType_LEGACY: case protocol.SecurityType_LEGACY:
if request.Option.Has(protocol.RequestOptionChunkStream) { if request.Option.Has(protocol.RequestOptionChunkStream) {
auth := &crypto.AEADAuthenticator{ auth := &crypto.AEADAuthenticator{
@ -458,17 +460,17 @@ func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ
} }
return crypto.NewAuthenticationWriter(auth, sizeParser, s.responseWriter, request.Command.TransferType(), padding) return crypto.NewAuthenticationWriter(auth, sizeParser, s.responseWriter, request.Command.TransferType(), padding)
} }
return &buf.SequentialWriter{Writer: s.responseWriter} return &buf.SequentialWriter{Writer: s.responseWriter}
case protocol.SecurityType_AES128_GCM: case protocol.SecurityType_AES128_GCM:
aead := crypto.NewAesGcm(s.responseBodyKey[:]) aead := crypto.NewAesGcm(s.responseBodyKey[:])
auth := &crypto.AEADAuthenticator{ auth := &crypto.AEADAuthenticator{
AEAD: aead, AEAD: aead,
NonceGenerator: GenerateChunkNonce(s.responseBodyIV[:], uint32(aead.NonceSize())), NonceGenerator: GenerateChunkNonce(s.responseBodyIV[:], uint32(aead.NonceSize())),
AdditionalDataGenerator: crypto.GenerateEmptyBytes(), AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
} }
return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding) return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding)
case protocol.SecurityType_CHACHA20_POLY1305: case protocol.SecurityType_CHACHA20_POLY1305:
aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(s.responseBodyKey[:])) aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(s.responseBodyKey[:]))
@ -478,6 +480,7 @@ func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ
AdditionalDataGenerator: crypto.GenerateEmptyBytes(), AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
} }
return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding) return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding)
default: default:
panic("Unknown security type.") panic("Unknown security type.")
} }

View File

@ -71,7 +71,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
if err != nil { if err != nil {
return newError("failed to find an available destination").Base(err).AtWarning() return newError("failed to find an available destination").Base(err).AtWarning()
} }
defer conn.Close() //nolint: errcheck defer conn.Close()
outbound := session.OutboundFromContext(ctx) outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() { if outbound == nil || !outbound.Target.IsValid() {
@ -114,11 +114,11 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
output := link.Writer output := link.Writer
isAEAD := false isAEAD := false
if !aead_disabled && len(account.AlterIDs) == 0 { if !aeadDisabled && len(account.AlterIDs) == 0 {
isAEAD = true isAEAD = true
} }
session := encoding.NewClientSession(isAEAD, protocol.DefaultIDHash, ctx) session := encoding.NewClientSession(ctx, isAEAD, protocol.DefaultIDHash)
sessionPolicy := h.policyManager.ForLevel(request.User.Level) sessionPolicy := h.policyManager.ForLevel(request.User.Level)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
@ -179,7 +179,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
var ( var (
enablePadding = false enablePadding = false
aead_disabled = false aeadDisabled = false
) )
func shouldEnablePadding(s protocol.SecurityType) bool { func shouldEnablePadding(s protocol.SecurityType) bool {
@ -198,8 +198,8 @@ func init() {
enablePadding = true enablePadding = true
} }
aeadDisabled := platform.NewEnvFlag("v2ray.vmess.aead.disabled").GetValue(func() string { return defaultFlagValue }) isAeadDisabled := platform.NewEnvFlag("v2ray.vmess.aead.disabled").GetValue(func() string { return defaultFlagValue })
if aeadDisabled == "true" { if isAeadDisabled == "true" {
aead_disabled = true aeadDisabled = true
} }
} }

View File

@ -5,7 +5,6 @@ package vmess
import ( import (
"crypto/hmac" "crypto/hmac"
"crypto/sha256" "crypto/sha256"
"hash"
"hash/crc64" "hash/crc64"
"strings" "strings"
"sync" "sync"
@ -142,7 +141,7 @@ func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error {
account := uu.user.Account.(*MemoryAccount) account := uu.user.Account.(*MemoryAccount)
if !v.behaviorFused { if !v.behaviorFused {
hashkdf := hmac.New(func() hash.Hash { return sha256.New() }, []byte("VMESSBSKDF")) hashkdf := hmac.New(sha256.New, []byte("VMESSBSKDF"))
hashkdf.Write(account.ID.Bytes()) hashkdf.Write(account.ID.Bytes())
v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), hashkdf.Sum(nil)) v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), hashkdf.Sum(nil))
} }

View File

@ -3,13 +3,13 @@ package scenarios
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/google/go-cmp/cmp/cmpopts"
"io" "io"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/grpc" "google.golang.org/grpc"
"v2ray.com/core" "v2ray.com/core"

View File

@ -203,6 +203,5 @@ func testTCPConn2(conn net.Conn, payloadSize int, timeout time.Duration) func()
} }
return nil return nil
} }
} }

View File

@ -2,6 +2,7 @@ package scenarios
import ( import (
"bytes" "bytes"
"context"
"crypto/rand" "crypto/rand"
"io" "io"
"io/ioutil" "io/ioutil"
@ -136,7 +137,7 @@ func TestHttpError(t *testing.T) {
} }
} }
func TestHttpConnectMethod(t *testing.T) { func TestHTTPConnectMethod(t *testing.T) {
tcpServer := tcp.Server{ tcpServer := tcp.Server{
MsgProcessor: xor, MsgProcessor: xor,
} }
@ -179,7 +180,9 @@ func TestHttpConnectMethod(t *testing.T) {
payload := make([]byte, 1024*64) payload := make([]byte, 1024*64)
common.Must2(rand.Read(payload)) common.Must2(rand.Read(payload))
req, err := http.NewRequest("Connect", "http://"+dest.NetAddr()+"/", bytes.NewReader(payload))
ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, "Connect", "http://"+dest.NetAddr()+"/", bytes.NewReader(payload))
req.Header.Set("X-a", "b") req.Header.Set("X-a", "b")
req.Header.Set("X-b", "d") req.Header.Set("X-b", "d")
common.Must(err) common.Must(err)
@ -334,7 +337,8 @@ func TestHttpBasicAuth(t *testing.T) {
} }
{ {
req, err := http.NewRequest("GET", "http://127.0.0.1:"+httpServerPort.String(), nil) ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, "GET", "http://127.0.0.1:"+httpServerPort.String(), nil)
common.Must(err) common.Must(err)
setProxyBasicAuth(req, "a", "c") setProxyBasicAuth(req, "a", "c")
@ -346,7 +350,8 @@ func TestHttpBasicAuth(t *testing.T) {
} }
{ {
req, err := http.NewRequest("GET", "http://127.0.0.1:"+httpServerPort.String(), nil) ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, "GET", "http://127.0.0.1:"+httpServerPort.String(), nil)
common.Must(err) common.Must(err)
setProxyBasicAuth(req, "a", "b") setProxyBasicAuth(req, "a", "b")

View File

@ -226,7 +226,7 @@ func TestShadowsocksAES128UDP(t *testing.T) {
payload := make([]byte, 1024) payload := make([]byte, 1024)
common.Must2(rand.Read(payload)) common.Must2(rand.Read(payload))
nBytes, err := conn.Write([]byte(payload)) nBytes, err := conn.Write(payload)
if err != nil { if err != nil {
return err return err
} }

View File

@ -31,7 +31,7 @@ import (
tcptransport "v2ray.com/core/transport/internet/tcp" tcptransport "v2ray.com/core/transport/internet/tcp"
) )
func TestHttpConnectionHeader(t *testing.T) { func TestHTTPConnectionHeader(t *testing.T) {
tcpServer := tcp.Server{ tcpServer := tcp.Server{
MsgProcessor: xor, MsgProcessor: xor,
} }

View File

@ -32,7 +32,7 @@ func (s *Server) Start() (net.Destination, error) {
Handler: s, Handler: s,
} }
go s.server.ListenAndServe() go s.server.ListenAndServe()
return net.TCPDestination(net.LocalHostIP, net.Port(s.Port)), nil return net.TCPDestination(net.LocalHostIP, s.Port), nil
} }
func (s *Server) Close() error { func (s *Server) Close() error {

View File

@ -65,7 +65,7 @@ func (server *Server) handleConnection(conn net.Conn) {
pReader, pWriter := pipe.New(pipe.WithoutSizeLimit()) pReader, pWriter := pipe.New(pipe.WithoutSizeLimit())
err := task.Run(context.Background(), func() error { err := task.Run(context.Background(), func() error {
defer pWriter.Close() // nolint: errcheck defer pWriter.Close()
for { for {
b := buf.New() b := buf.New()
@ -102,7 +102,7 @@ func (server *Server) handleConnection(conn net.Conn) {
fmt.Println("failed to transfer data: ", err.Error()) fmt.Println("failed to transfer data: ", err.Error())
} }
conn.Close() // nolint: errcheck conn.Close()
} }
func (server *Server) Close() error { func (server *Server) Close() error {

Some files were not shown because too many files have changed in this diff Show More