diff --git a/transport/internet/udp/dispatcher.go b/transport/internet/udp/dispatcher.go index 969aa056a..f526d52f2 100644 --- a/transport/internet/udp/dispatcher.go +++ b/transport/internet/udp/dispatcher.go @@ -15,57 +15,56 @@ type ResponseCallback func(payload *buf.Buffer) type Dispatcher struct { sync.RWMutex - conns map[string]ray.InboundRay + conns map[v2net.Destination]ray.InboundRay dispatcher dispatcher.Interface } func NewDispatcher(dispatcher dispatcher.Interface) *Dispatcher { return &Dispatcher{ - conns: make(map[string]ray.InboundRay), + conns: make(map[v2net.Destination]ray.InboundRay), dispatcher: dispatcher, } } -func (v *Dispatcher) RemoveRay(name string) { +func (v *Dispatcher) RemoveRay(dest v2net.Destination) { v.Lock() defer v.Unlock() - if conn, found := v.conns[name]; found { + if conn, found := v.conns[dest]; found { conn.InboundInput().Close() conn.InboundOutput().Close() - delete(v.conns, name) + delete(v.conns, dest) } } func (v *Dispatcher) getInboundRay(ctx context.Context, dest v2net.Destination) (ray.InboundRay, bool) { - destString := dest.String() v.Lock() defer v.Unlock() - if entry, found := v.conns[destString]; found { + if entry, found := v.conns[dest]; found { return entry, true } log.Info("UDP|Server: establishing new connection for ", dest) inboundRay, _ := v.dispatcher.Dispatch(ctx, dest) + v.conns[dest] = inboundRay return inboundRay, false } func (v *Dispatcher) Dispatch(ctx context.Context, destination v2net.Destination, payload *buf.Buffer, callback ResponseCallback) { // TODO: Add user to destString - destString := destination.String() - log.Debug("UDP|Server: Dispatch request: ", destString) + log.Debug("UDP|Server: Dispatch request: ", destination) inboundRay, existing := v.getInboundRay(ctx, destination) outputStream := inboundRay.InboundInput() if outputStream != nil { if err := outputStream.Write(payload); err != nil { - v.RemoveRay(destString) + v.RemoveRay(destination) } } if !existing { go func() { handleInput(inboundRay.InboundOutput(), callback) - v.RemoveRay(destString) + v.RemoveRay(destination) }() } }