diff --git a/app/dispatcher/impl/default.go b/app/dispatcher/impl/default.go index d3de9d873..89bc5570c 100644 --- a/app/dispatcher/impl/default.go +++ b/app/dispatcher/impl/default.go @@ -88,30 +88,11 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin return outbound, nil } -func trySnif(sniferList []proxyman.KnownProtocols, b []byte) (string, error) { - for _, protocol := range sniferList { - var f func([]byte) (string, error) - switch protocol { - case proxyman.KnownProtocols_HTTP: - f = SniffHTTP - case proxyman.KnownProtocols_TLS: - f = SniffTLS - default: - panic("Unsupported protocol") - } - - domain, err := f(b) - if err != ErrMoreData { - return domain, err - } - } - return "", ErrMoreData -} - func snifer(ctx context.Context, sniferList []proxyman.KnownProtocols, outbound ray.OutboundRay) (string, error) { payload := buf.New() defer payload.Release() + sniffer := NewSniffer(sniferList) totalAttempt := 0 for { select { @@ -124,7 +105,7 @@ func snifer(ctx context.Context, sniferList []proxyman.KnownProtocols, outbound } outbound.OutboundInput().Peek(payload) if !payload.IsEmpty() { - domain, err := trySnif(sniferList, payload.Bytes()) + domain, err := sniffer.Sniff(payload.Bytes()) if err != ErrMoreData { return domain, err } diff --git a/app/dispatcher/impl/sniffer.go b/app/dispatcher/impl/sniffer.go index c566c89cf..ba4add252 100644 --- a/app/dispatcher/impl/sniffer.go +++ b/app/dispatcher/impl/sniffer.go @@ -4,6 +4,7 @@ import ( "bytes" "strings" + "v2ray.com/core/app/proxyman" "v2ray.com/core/common/serial" ) @@ -166,3 +167,49 @@ func SniffTLS(b []byte) (string, error) { } return ReadClientHello(b[5 : 5+headerLen]) } + +type Sniffer struct { + slist []func([]byte) (string, error) + err []error +} + +func NewSniffer(sniferList []proxyman.KnownProtocols) *Sniffer { + s := new(Sniffer) + + for _, protocol := range sniferList { + var f func([]byte) (string, error) + switch protocol { + case proxyman.KnownProtocols_HTTP: + f = SniffHTTP + case proxyman.KnownProtocols_TLS: + f = SniffTLS + default: + panic("Unsupported protocol") + } + s.slist = append(s.slist, f) + } + s.err = make([]error, len(s.slist)) + + return s +} + +func (s *Sniffer) Sniff(payload []byte) (string, error) { + sniffed := false + for idx, sniffer := range s.slist { + if s.err[idx] != nil { + continue + } + sniffed = true + domain, err := sniffer(payload) + if err == nil { + return domain, nil + } + if err != ErrMoreData { + s.err[idx] = err + } + } + if sniffed { + return "", ErrMoreData + } + return "", s.err[0] +}