diff --git a/app/restful-api/restful-api.go b/app/restful-api/restful-api.go index 1608cf98c..8526d16f2 100644 --- a/app/restful-api/restful-api.go +++ b/app/restful-api/restful-api.go @@ -1,10 +1,11 @@ package restful_api import ( - "encoding/json" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/render" "github.com/go-playground/validator/v10" + core "github.com/v2fly/v2ray-core/v4" "github.com/v2fly/v2ray-core/v4/common/net" "github.com/v2fly/v2ray-core/v4/transport/internet" @@ -12,105 +13,69 @@ import ( "strings" ) -func JSONResponse(w http.ResponseWriter, data interface{}, code int) { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") - w.WriteHeader(code) - _ = json.NewEncoder(w).Encode(data) -} - var validate *validator.Validate -type StatsUser struct { - uuid string `validate:"required_without=email,uuid4"` - email string `validate:"required_without=uuid,email"` -} - -type StatsUserResponse struct { - Uplink int64 `json:"uplink"` - Downlink int64 `json:"downlink"` -} - -func (rs *restfulService) statsUser(w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - statsUser := &StatsUser{ - uuid: query.Get("uuid"), - email: query.Get("email"), - } - - if err := validate.Struct(statsUser); err != nil { - JSONResponse(w, http.StatusText(422), 422) - } - - response := &StatsUserResponse{ - Uplink: 0, - Downlink: 0, - } - - JSONResponse(w, response, 200) -} - -type Stats struct { - tag string `validate:"required,alpha,min=1,max=255"` -} - type StatsBound struct { // Better name? Uplink int64 `json:"uplink"` Downlink int64 `json:"downlink"` } -type StatsResponse struct { - Inbound StatsBound `json:"inbound"` - Outbound StatsBound `json:"outbound"` +func (rs *restfulService) tagStats(w http.ResponseWriter, r *http.Request) { + boundType := chi.URLParam(r, "bound_type") + tag := chi.URLParam(r, "tag") + + if validate.Var(boundType, "required,oneof=inbounds outbounds") != nil || + validate.Var(tag, "required,min=1,max=255") != nil { + render.Status(r, http.StatusUnprocessableEntity) + render.JSON(w, r, render.M{}) + return + } + + bound := boundType[:len(boundType)-1] + upCounter := rs.stats.GetCounter(bound + ">>>" + tag + ">>>traffic>>>uplink") + downCounter := rs.stats.GetCounter(bound + ">>>" + tag + ">>>traffic>>>downlink") + if upCounter == nil || downCounter == nil { + render.Status(r, http.StatusNotFound) + render.JSON(w, r, render.M{}) + return + } + + render.JSON(w, r, &StatsBound{ + Uplink: upCounter.Value(), + Downlink: downCounter.Value(), + }) } -func (rs *restfulService) statsRequest(w http.ResponseWriter, r *http.Request) { - stats := &Stats{ - tag: r.URL.Query().Get("tag"), - } - if err := validate.Struct(stats); err != nil { - JSONResponse(w, http.StatusText(422), 422) - } - - response := StatsResponse{ - Inbound: StatsBound{ - Uplink: 1, - Downlink: 1, - }, - Outbound: StatsBound{ - Uplink: 1, - Downlink: 1, - }} - - JSONResponse(w, response, 200) +func (rs *restfulService) version(w http.ResponseWriter, r *http.Request) { + render.JSON(w, r, render.M{"version": core.Version()}) } func (rs *restfulService) TokenAuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - auth := r.Header.Get("Authorization") - const prefix = "Bearer " - if !strings.HasPrefix(auth, prefix) { - JSONResponse(w, http.StatusText(403), 403) - return - } - auth = strings.TrimPrefix(auth, prefix) - if auth != rs.config.AuthToken { - JSONResponse(w, http.StatusText(403), 403) + header := r.Header.Get("Authorization") + text := strings.SplitN(header, " ", 2) + + hasInvalidHeader := text[0] != "Bearer" + hasInvalidSecret := len(text) != 2 || text[1] != rs.config.AuthToken + if hasInvalidHeader || hasInvalidSecret { + render.Status(r, http.StatusUnauthorized) + render.JSON(w, r, render.M{}) return } + next.ServeHTTP(w, r) }) } func (rs *restfulService) start() error { r := chi.NewRouter() - r.Use(rs.TokenAuthMiddleware) r.Use(middleware.Heartbeat("/ping")) + validate = validator.New() r.Route("/v1", func(r chi.Router) { - r.Get("/stats/user", rs.statsUser) - r.Get("/stats", rs.statsRequest) + r.Get("/{bound_type}/{tag}/stats", rs.tagStats) }) + r.Get("/version", rs.version) var listener net.Listener var err error @@ -134,6 +99,5 @@ func (rs *restfulService) start() error { newError("unable to serve restful api").WriteToLog() } }() - return nil } diff --git a/go.mod b/go.mod index bff270ab0..5f733908b 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.17 require ( github.com/go-chi/chi/v5 v5.0.4 + github.com/go-chi/render v1.0.1 github.com/go-playground/validator/v10 v10.9.0 github.com/golang/mock v1.6.0 github.com/golang/protobuf v1.5.2 diff --git a/go.sum b/go.sum index 16f0c45b4..ff2f87a39 100644 --- a/go.sum +++ b/go.sum @@ -74,6 +74,8 @@ github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeME github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/go-chi/chi/v5 v5.0.4 h1:5e494iHzsYBiyXQAHHuI4tyJS9M3V84OuX3ufIIGHFo= github.com/go-chi/chi/v5 v5.0.4/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= +github.com/go-chi/render v1.0.1/go.mod h1:pq4Rr7HbnsdaeHagklXub+p6Wd16Af5l9koip1OvJns= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=