mirror of
https://github.com/v2fly/v2ray-core.git
synced 2024-12-22 01:57:12 -05:00
clean up dns package
This commit is contained in:
parent
a430e2065a
commit
0dbfb66126
@ -2,12 +2,14 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"v2ray.com/core/app/dispatcher"
|
"v2ray.com/core/app/dispatcher"
|
||||||
"v2ray.com/core/app/log"
|
"v2ray.com/core/app/log"
|
||||||
|
"v2ray.com/core/common"
|
||||||
"v2ray.com/core/common/buf"
|
"v2ray.com/core/common/buf"
|
||||||
"v2ray.com/core/common/dice"
|
"v2ray.com/core/common/dice"
|
||||||
"v2ray.com/core/common/net"
|
"v2ray.com/core/common/net"
|
||||||
@ -15,7 +17,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
DefaultTTL = uint32(3600)
|
|
||||||
CleanupInterval = time.Second * 120
|
CleanupInterval = time.Second * 120
|
||||||
CleanupThreshold = 512
|
CleanupThreshold = 512
|
||||||
)
|
)
|
||||||
@ -55,7 +56,6 @@ func NewUDPNameServer(address net.Destination, dispatcher dispatcher.Interface)
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// Private: Visible for testing.
|
|
||||||
func (v *UDPNameServer) Cleanup() {
|
func (v *UDPNameServer) Cleanup() {
|
||||||
expiredRequests := make([]uint16, 0, 16)
|
expiredRequests := make([]uint16, 0, 16)
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@ -70,10 +70,8 @@ func (v *UDPNameServer) Cleanup() {
|
|||||||
delete(v.requests, id)
|
delete(v.requests, id)
|
||||||
}
|
}
|
||||||
v.Unlock()
|
v.Unlock()
|
||||||
expiredRequests = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Private: Visible for testing.
|
|
||||||
func (v *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
|
func (v *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
|
||||||
var id uint16
|
var id uint16
|
||||||
v.Lock()
|
v.Lock()
|
||||||
@ -98,7 +96,6 @@ func (v *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
|
|||||||
return id
|
return id
|
||||||
}
|
}
|
||||||
|
|
||||||
// Private: Visible for testing.
|
|
||||||
func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
|
func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
|
||||||
msg := new(dns.Msg)
|
msg := new(dns.Msg)
|
||||||
err := msg.Unpack(payload.Bytes())
|
err := msg.Unpack(payload.Bytes())
|
||||||
@ -110,8 +107,8 @@ func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
|
|||||||
IPs: make([]net.IP, 0, 16),
|
IPs: make([]net.IP, 0, 16),
|
||||||
}
|
}
|
||||||
id := msg.Id
|
id := msg.Id
|
||||||
ttl := DefaultTTL
|
ttl := uint32(3600) // an hour
|
||||||
log.Trace(newError("handling response for id ", id, " content: ", msg.String()).AtDebug())
|
log.Trace(newError("handling response for id ", id, " content: ", msg).AtDebug())
|
||||||
|
|
||||||
v.Lock()
|
v.Lock()
|
||||||
request, found := v.requests[id]
|
request, found := v.requests[id]
|
||||||
@ -126,6 +123,7 @@ func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
|
|||||||
switch rr := rr.(type) {
|
switch rr := rr.(type) {
|
||||||
case *dns.A:
|
case *dns.A:
|
||||||
record.IPs = append(record.IPs, rr.A)
|
record.IPs = append(record.IPs, rr.A)
|
||||||
|
fmt.Println("Adding ans:", rr.A)
|
||||||
if rr.Hdr.Ttl < ttl {
|
if rr.Hdr.Ttl < ttl {
|
||||||
ttl = rr.Hdr.Ttl
|
ttl = rr.Hdr.Ttl
|
||||||
}
|
}
|
||||||
@ -152,13 +150,18 @@ func (v *UDPNameServer) BuildQueryA(domain string, id uint16) *buf.Buffer {
|
|||||||
Name: dns.Fqdn(domain),
|
Name: dns.Fqdn(domain),
|
||||||
Qtype: dns.TypeA,
|
Qtype: dns.TypeA,
|
||||||
Qclass: dns.ClassINET,
|
Qclass: dns.ClassINET,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: dns.Fqdn(domain),
|
||||||
|
Qtype: dns.TypeAAAA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
buffer := buf.New()
|
buffer := buf.New()
|
||||||
buffer.AppendSupplier(func(b []byte) (int, error) {
|
common.Must(buffer.Reset(func(b []byte) (int, error) {
|
||||||
writtenBuffer, err := msg.PackBuffer(b)
|
writtenBuffer, err := msg.PackBuffer(b)
|
||||||
return len(writtenBuffer), err
|
return len(writtenBuffer), err
|
||||||
})
|
}))
|
||||||
|
|
||||||
return buffer
|
return buffer
|
||||||
}
|
}
|
||||||
@ -167,7 +170,7 @@ func (v *UDPNameServer) QueryA(domain string) <-chan *ARecord {
|
|||||||
response := make(chan *ARecord, 1)
|
response := make(chan *ARecord, 1)
|
||||||
id := v.AssignUnusedID(response)
|
id := v.AssignUnusedID(response)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*8)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
|
v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@ -176,11 +179,10 @@ func (v *UDPNameServer) QueryA(domain string) <-chan *ARecord {
|
|||||||
v.Lock()
|
v.Lock()
|
||||||
_, found := v.requests[id]
|
_, found := v.requests[id]
|
||||||
v.Unlock()
|
v.Unlock()
|
||||||
if found {
|
if !found {
|
||||||
v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
|
|
||||||
} else {
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
@ -205,7 +207,7 @@ func (v *LocalNameServer) QueryA(domain string) <-chan *ARecord {
|
|||||||
|
|
||||||
response <- &ARecord{
|
response <- &ARecord{
|
||||||
IPs: ips,
|
IPs: ips,
|
||||||
Expire: time.Now().Add(time.Second * time.Duration(DefaultTTL)),
|
Expire: time.Now().Add(time.Hour),
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -1,20 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"v2ray.com/core/common/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
type IPResult struct {
|
|
||||||
IP []net.IP
|
|
||||||
TTL time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
type Querier interface {
|
|
||||||
QueryDomain(domain string) <-chan *IPResult
|
|
||||||
}
|
|
||||||
|
|
||||||
type UDPQuerier struct {
|
|
||||||
server net.Destination
|
|
||||||
}
|
|
@ -21,22 +21,22 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type DomainRecord struct {
|
type DomainRecord struct {
|
||||||
A *ARecord
|
|
||||||
}
|
|
||||||
|
|
||||||
type Record struct {
|
|
||||||
IP []net.IP
|
IP []net.IP
|
||||||
Expire time.Time
|
Expire time.Time
|
||||||
LastAccess time.Time
|
LastAccess time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Record) Expired() bool {
|
func (r *DomainRecord) Expired() bool {
|
||||||
|
return r.Expire.Before(time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *DomainRecord) Inactive() bool {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
return r.Expire.Before(now) || r.LastAccess.Add(time.Hour).Before(now)
|
return r.Expire.Before(now) || r.LastAccess.Add(time.Minute*5).Before(now)
|
||||||
}
|
}
|
||||||
|
|
||||||
type CacheServer struct {
|
type CacheServer struct {
|
||||||
sync.RWMutex
|
sync.Mutex
|
||||||
hosts map[string]net.IP
|
hosts map[string]net.IP
|
||||||
records map[string]*DomainRecord
|
records map[string]*DomainRecord
|
||||||
servers []NameServer
|
servers []NameServer
|
||||||
@ -90,15 +90,33 @@ func (*CacheServer) Start() error {
|
|||||||
func (*CacheServer) Close() {}
|
func (*CacheServer) Close() {}
|
||||||
|
|
||||||
func (s *CacheServer) GetCached(domain string) []net.IP {
|
func (s *CacheServer) GetCached(domain string) []net.IP {
|
||||||
s.RLock()
|
s.Lock()
|
||||||
defer s.RUnlock()
|
defer s.Unlock()
|
||||||
|
|
||||||
if record, found := s.records[domain]; found && record.A.Expire.After(time.Now()) {
|
if record, found := s.records[domain]; found && !record.Expired() {
|
||||||
return record.A.IPs
|
record.LastAccess = time.Now()
|
||||||
|
return record.IP
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *CacheServer) tryCleanup() {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
|
||||||
|
if len(s.records) > 256 {
|
||||||
|
domains := make([]string, 0, 256)
|
||||||
|
for d, r := range s.records {
|
||||||
|
if r.Expired() {
|
||||||
|
domains = append(domains, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, d := range domains {
|
||||||
|
delete(s.records, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *CacheServer) Get(domain string) []net.IP {
|
func (s *CacheServer) Get(domain string) []net.IP {
|
||||||
if ip, found := s.hosts[domain]; found {
|
if ip, found := s.hosts[domain]; found {
|
||||||
return []net.IP{ip}
|
return []net.IP{ip}
|
||||||
@ -110,6 +128,8 @@ func (s *CacheServer) Get(domain string) []net.IP {
|
|||||||
return ips
|
return ips
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.tryCleanup()
|
||||||
|
|
||||||
for _, server := range s.servers {
|
for _, server := range s.servers {
|
||||||
response := server.QueryA(domain)
|
response := server.QueryA(domain)
|
||||||
select {
|
select {
|
||||||
@ -119,7 +139,9 @@ func (s *CacheServer) Get(domain string) []net.IP {
|
|||||||
}
|
}
|
||||||
s.Lock()
|
s.Lock()
|
||||||
s.records[domain] = &DomainRecord{
|
s.records[domain] = &DomainRecord{
|
||||||
A: a,
|
IP: a.IPs,
|
||||||
|
Expire: a.Expire,
|
||||||
|
LastAccess: time.Now(),
|
||||||
}
|
}
|
||||||
s.Unlock()
|
s.Unlock()
|
||||||
log.Trace(newError("returning ", len(a.IPs), " IPs for domain ", domain).AtDebug())
|
log.Trace(newError("returning ", len(a.IPs), " IPs for domain ", domain).AtDebug())
|
||||||
|
102
app/dns/server/server_test.go
Normal file
102
app/dns/server/server_test.go
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"v2ray.com/core/app"
|
||||||
|
"v2ray.com/core/app/dispatcher"
|
||||||
|
_ "v2ray.com/core/app/dispatcher/impl"
|
||||||
|
. "v2ray.com/core/app/dns"
|
||||||
|
_ "v2ray.com/core/app/dns/server"
|
||||||
|
"v2ray.com/core/app/proxyman"
|
||||||
|
_ "v2ray.com/core/app/proxyman/outbound"
|
||||||
|
"v2ray.com/core/common"
|
||||||
|
"v2ray.com/core/common/net"
|
||||||
|
"v2ray.com/core/common/serial"
|
||||||
|
"v2ray.com/core/proxy/freedom"
|
||||||
|
"v2ray.com/core/testing/servers/udp"
|
||||||
|
. "v2ray.com/ext/assert"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
type staticHandler struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
ans := new(dns.Msg)
|
||||||
|
ans.Id = r.Id
|
||||||
|
for _, q := range r.Question {
|
||||||
|
if q.Name == "google.com." && q.Qtype == dns.TypeA {
|
||||||
|
rr, _ := dns.NewRR("google.com. IN A 8.8.8.8")
|
||||||
|
ans.Answer = append(ans.Answer, rr)
|
||||||
|
} else if q.Name == "facebook.com." && q.Qtype == dns.TypeA {
|
||||||
|
rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9")
|
||||||
|
ans.Answer = append(ans.Answer, rr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteMsg(ans)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPServer(t *testing.T) {
|
||||||
|
assert := With(t)
|
||||||
|
|
||||||
|
port := udp.PickPort()
|
||||||
|
|
||||||
|
dnsServer := dns.Server{
|
||||||
|
Addr: "127.0.0.1:" + port.String(),
|
||||||
|
Net: "udp",
|
||||||
|
Handler: &staticHandler{},
|
||||||
|
UDPSize: 1200,
|
||||||
|
}
|
||||||
|
|
||||||
|
go dnsServer.ListenAndServe()
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
NameServers: []*net.Endpoint{
|
||||||
|
{
|
||||||
|
Network: net.Network_UDP,
|
||||||
|
Address: &net.IPOrDomain{
|
||||||
|
Address: &net.IPOrDomain_Ip{
|
||||||
|
Ip: []byte{127, 0, 0, 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Port: uint32(port),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
space := app.NewSpace()
|
||||||
|
|
||||||
|
ctx = app.ContextWithSpace(ctx, space)
|
||||||
|
common.Must(app.AddApplicationToSpace(ctx, config))
|
||||||
|
common.Must(app.AddApplicationToSpace(ctx, &dispatcher.Config{}))
|
||||||
|
common.Must(app.AddApplicationToSpace(ctx, &proxyman.OutboundConfig{}))
|
||||||
|
|
||||||
|
om := proxyman.OutboundHandlerManagerFromSpace(space)
|
||||||
|
om.AddHandler(ctx, &proxyman.OutboundHandlerConfig{
|
||||||
|
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
|
||||||
|
})
|
||||||
|
|
||||||
|
common.Must(space.Initialize())
|
||||||
|
common.Must(space.Start())
|
||||||
|
|
||||||
|
server := FromSpace(space)
|
||||||
|
assert(server, IsNotNil)
|
||||||
|
|
||||||
|
ips := server.Get("google.com")
|
||||||
|
assert(len(ips), Equals, 1)
|
||||||
|
assert([]byte(ips[0]), Equals, []byte{8, 8, 8, 8})
|
||||||
|
|
||||||
|
ips = server.Get("facebook.com")
|
||||||
|
assert(len(ips), Equals, 1)
|
||||||
|
assert([]byte(ips[0]), Equals, []byte{9, 9, 9, 9})
|
||||||
|
|
||||||
|
dnsServer.Shutdown()
|
||||||
|
|
||||||
|
ips = server.Get("google.com")
|
||||||
|
assert(len(ips), Equals, 1)
|
||||||
|
assert([]byte(ips[0]), Equals, []byte{8, 8, 8, 8})
|
||||||
|
}
|
@ -53,8 +53,7 @@ func (t *ActivityTimer) run() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.Context, *ActivityTimer) {
|
func CancelAfterInactivity(ctx context.Context, cancel context.CancelFunc, timeout time.Duration) *ActivityTimer {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
timer := &ActivityTimer{
|
timer := &ActivityTimer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
@ -63,5 +62,5 @@ func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.
|
|||||||
}
|
}
|
||||||
timer.timeout <- timeout
|
timer.timeout <- timeout
|
||||||
go timer.run()
|
go timer.run()
|
||||||
return ctx, timer
|
return timer
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,8 @@ import (
|
|||||||
func TestActivityTimer(t *testing.T) {
|
func TestActivityTimer(t *testing.T) {
|
||||||
assert := With(t)
|
assert := With(t)
|
||||||
|
|
||||||
ctx, timer := CancelAfterInactivity(context.Background(), time.Second*5)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
timer := CancelAfterInactivity(ctx, cancel, time.Second*5)
|
||||||
time.Sleep(time.Second * 6)
|
time.Sleep(time.Second * 6)
|
||||||
assert(ctx.Err(), IsNotNil)
|
assert(ctx.Err(), IsNotNil)
|
||||||
runtime.KeepAlive(timer)
|
runtime.KeepAlive(timer)
|
||||||
@ -22,7 +23,8 @@ func TestActivityTimer(t *testing.T) {
|
|||||||
func TestActivityTimerUpdate(t *testing.T) {
|
func TestActivityTimerUpdate(t *testing.T) {
|
||||||
assert := With(t)
|
assert := With(t)
|
||||||
|
|
||||||
ctx, timer := CancelAfterInactivity(context.Background(), time.Second*10)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
timer := CancelAfterInactivity(ctx, cancel, time.Second*10)
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
assert(ctx.Err(), IsNil)
|
assert(ctx.Err(), IsNil)
|
||||||
timer.SetTimeout(time.Second * 1)
|
timer.SetTimeout(time.Second * 1)
|
||||||
|
@ -64,7 +64,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
|
|||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = time.Minute * 5
|
timeout = time.Minute * 5
|
||||||
}
|
}
|
||||||
ctx, timer := signal.CancelAfterInactivity(ctx, timeout)
|
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
timer := signal.CancelAfterInactivity(ctx, cancel, timeout)
|
||||||
|
|
||||||
inboundRay, err := dispatcher.Dispatch(ctx, dest)
|
inboundRay, err := dispatcher.Dispatch(ctx, dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -107,7 +107,8 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
|
|||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = time.Minute * 5
|
timeout = time.Minute * 5
|
||||||
}
|
}
|
||||||
ctx, timer := signal.CancelAfterInactivity(ctx, timeout)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
timer := signal.CancelAfterInactivity(ctx, cancel, timeout)
|
||||||
|
|
||||||
requestDone := signal.ExecuteAsync(func() error {
|
requestDone := signal.ExecuteAsync(func() error {
|
||||||
var writer buf.Writer
|
var writer buf.Writer
|
||||||
|
@ -153,7 +153,8 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
|
|||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = time.Minute * 5
|
timeout = time.Minute * 5
|
||||||
}
|
}
|
||||||
ctx, timer := signal.CancelAfterInactivity(ctx, timeout)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
timer := signal.CancelAfterInactivity(ctx, cancel, timeout)
|
||||||
ray, err := dispatcher.Dispatch(ctx, dest)
|
ray, err := dispatcher.Dispatch(ctx, dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -90,7 +90,8 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
|
|||||||
request.Option |= RequestOptionOneTimeAuth
|
request.Option |= RequestOptionOneTimeAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*5)
|
||||||
|
|
||||||
if request.Command == protocol.RequestCommandTCP {
|
if request.Command == protocol.RequestCommandTCP {
|
||||||
bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
|
bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
|
||||||
|
@ -146,7 +146,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
|
|||||||
ctx = protocol.ContextWithUser(ctx, request.User)
|
ctx = protocol.ContextWithUser(ctx, request.User)
|
||||||
|
|
||||||
userSettings := s.user.GetSettings()
|
userSettings := s.user.GetSettings()
|
||||||
ctx, timer := signal.CancelAfterInactivity(ctx, userSettings.PayloadTimeout)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
timer := signal.CancelAfterInactivity(ctx, cancel, userSettings.PayloadTimeout)
|
||||||
ray, err := dispatcher.Dispatch(ctx, dest)
|
ray, err := dispatcher.Dispatch(ctx, dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -83,7 +83,8 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay, dialer proxy.
|
|||||||
return newError("failed to establish connection to server").AtWarning().Base(err)
|
return newError("failed to establish connection to server").AtWarning().Base(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*5)
|
||||||
|
|
||||||
var requestFunc func() error
|
var requestFunc func() error
|
||||||
var responseFunc func() error
|
var responseFunc func() error
|
||||||
|
@ -107,7 +107,8 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
|
|||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = time.Minute * 5
|
timeout = time.Minute * 5
|
||||||
}
|
}
|
||||||
ctx, timer := signal.CancelAfterInactivity(ctx, timeout)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
timer := signal.CancelAfterInactivity(ctx, cancel, timeout)
|
||||||
|
|
||||||
ray, err := dispatcher.Dispatch(ctx, dest)
|
ray, err := dispatcher.Dispatch(ctx, dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -204,7 +204,8 @@ func (v *Handler) Process(ctx context.Context, network net.Network, connection i
|
|||||||
|
|
||||||
ctx = protocol.ContextWithUser(ctx, request.User)
|
ctx = protocol.ContextWithUser(ctx, request.User)
|
||||||
|
|
||||||
ctx, timer := signal.CancelAfterInactivity(ctx, userSettings.PayloadTimeout)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
timer := signal.CancelAfterInactivity(ctx, cancel, userSettings.PayloadTimeout)
|
||||||
ray, err := dispatcher.Dispatch(ctx, request.Destination())
|
ray, err := dispatcher.Dispatch(ctx, request.Destination())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return newError("failed to dispatch request to ", request.Destination()).Base(err)
|
return newError("failed to dispatch request to ", request.Destination()).Base(err)
|
||||||
|
@ -103,7 +103,8 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
|
|||||||
|
|
||||||
session := encoding.NewClientSession(protocol.DefaultIDHash)
|
session := encoding.NewClientSession(protocol.DefaultIDHash)
|
||||||
|
|
||||||
ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*5)
|
||||||
|
|
||||||
requestDone := signal.ExecuteAsync(func() error {
|
requestDone := signal.ExecuteAsync(func() error {
|
||||||
writer := buf.NewBufferedWriter(buf.NewWriter(conn))
|
writer := buf.NewBufferedWriter(buf.NewWriter(conn))
|
||||||
|
@ -3,25 +3,33 @@ package udp
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"v2ray.com/core/app/dispatcher"
|
"v2ray.com/core/app/dispatcher"
|
||||||
"v2ray.com/core/app/log"
|
"v2ray.com/core/app/log"
|
||||||
"v2ray.com/core/common/buf"
|
"v2ray.com/core/common/buf"
|
||||||
"v2ray.com/core/common/net"
|
"v2ray.com/core/common/net"
|
||||||
|
"v2ray.com/core/common/signal"
|
||||||
"v2ray.com/core/transport/ray"
|
"v2ray.com/core/transport/ray"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ResponseCallback func(payload *buf.Buffer)
|
type ResponseCallback func(payload *buf.Buffer)
|
||||||
|
|
||||||
|
type connEntry struct {
|
||||||
|
inbound ray.InboundRay
|
||||||
|
timer signal.ActivityUpdater
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
type Dispatcher struct {
|
type Dispatcher struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
conns map[net.Destination]ray.InboundRay
|
conns map[net.Destination]*connEntry
|
||||||
dispatcher dispatcher.Interface
|
dispatcher dispatcher.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDispatcher(dispatcher dispatcher.Interface) *Dispatcher {
|
func NewDispatcher(dispatcher dispatcher.Interface) *Dispatcher {
|
||||||
return &Dispatcher{
|
return &Dispatcher{
|
||||||
conns: make(map[net.Destination]ray.InboundRay),
|
conns: make(map[net.Destination]*connEntry),
|
||||||
dispatcher: dispatcher,
|
dispatcher: dispatcher,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -30,51 +38,72 @@ func (v *Dispatcher) RemoveRay(dest net.Destination) {
|
|||||||
v.Lock()
|
v.Lock()
|
||||||
defer v.Unlock()
|
defer v.Unlock()
|
||||||
if conn, found := v.conns[dest]; found {
|
if conn, found := v.conns[dest]; found {
|
||||||
conn.InboundInput().Close()
|
conn.inbound.InboundInput().Close()
|
||||||
conn.InboundOutput().Close()
|
conn.inbound.InboundOutput().Close()
|
||||||
delete(v.conns, dest)
|
delete(v.conns, dest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (ray.InboundRay, bool) {
|
func (v *Dispatcher) getInboundRay(dest net.Destination, callback ResponseCallback) *connEntry {
|
||||||
v.Lock()
|
v.Lock()
|
||||||
defer v.Unlock()
|
defer v.Unlock()
|
||||||
|
|
||||||
if entry, found := v.conns[dest]; found {
|
if entry, found := v.conns[dest]; found {
|
||||||
return entry, true
|
return entry
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace(newError("establishing new connection for ", dest))
|
log.Trace(newError("establishing new connection for ", dest))
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
removeRay := func() {
|
||||||
|
cancel()
|
||||||
|
v.RemoveRay(dest)
|
||||||
|
}
|
||||||
|
timer := signal.CancelAfterInactivity(ctx, removeRay, time.Second*4)
|
||||||
inboundRay, _ := v.dispatcher.Dispatch(ctx, dest)
|
inboundRay, _ := v.dispatcher.Dispatch(ctx, dest)
|
||||||
v.conns[dest] = inboundRay
|
entry := &connEntry{
|
||||||
return inboundRay, false
|
inbound: inboundRay,
|
||||||
|
timer: timer,
|
||||||
|
cancel: removeRay,
|
||||||
|
}
|
||||||
|
v.conns[dest] = entry
|
||||||
|
go handleInput(ctx, entry, callback)
|
||||||
|
return entry
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer, callback ResponseCallback) {
|
func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer, callback ResponseCallback) {
|
||||||
// TODO: Add user to destString
|
// TODO: Add user to destString
|
||||||
log.Trace(newError("dispatch request to: ", destination).AtDebug())
|
log.Trace(newError("dispatch request to: ", destination).AtDebug())
|
||||||
|
|
||||||
inboundRay, existing := v.getInboundRay(ctx, destination)
|
conn := v.getInboundRay(destination, callback)
|
||||||
outputStream := inboundRay.InboundInput()
|
outputStream := conn.inbound.InboundInput()
|
||||||
if outputStream != nil {
|
if outputStream != nil {
|
||||||
if err := outputStream.WriteMultiBuffer(buf.NewMultiBufferValue(payload)); err != nil {
|
if err := outputStream.WriteMultiBuffer(buf.NewMultiBufferValue(payload)); err != nil {
|
||||||
v.RemoveRay(destination)
|
log.Trace(newError("failed to write first UDP payload").Base(err))
|
||||||
|
conn.cancel()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !existing {
|
|
||||||
go func() {
|
|
||||||
handleInput(inboundRay.InboundOutput(), callback)
|
|
||||||
v.RemoveRay(destination)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleInput(input ray.InputStream, callback ResponseCallback) {
|
func handleInput(ctx context.Context, conn *connEntry, callback ResponseCallback) {
|
||||||
|
input := conn.inbound.InboundOutput()
|
||||||
|
timer := conn.timer
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
mb, err := input.ReadMultiBuffer()
|
mb, err := input.ReadMultiBuffer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
log.Trace(newError("failed to handl UDP input").Base(err))
|
||||||
|
conn.cancel()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
timer.Update()
|
||||||
for _, b := range mb {
|
for _, b := range mb {
|
||||||
callback(b)
|
callback(b)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user