Compare commits
1 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
19860f713c |
95
net/http/mux.go
Normal file
95
net/http/mux.go
Normal file
@ -0,0 +1,95 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ServeMux struct {
|
||||
routes []route
|
||||
}
|
||||
|
||||
func (mux *ServeMux) Handle(pattern string, handler http.Handler, pathParams ...any) {
|
||||
mux.routes = append(mux.routes, newRoute(pattern, handler, pathParams...))
|
||||
}
|
||||
|
||||
func (mux *ServeMux) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request), pathParams ...any) {
|
||||
mux.routes = append(mux.routes, newRoute(pattern, http.HandlerFunc(handler), pathParams...))
|
||||
}
|
||||
|
||||
func (mux *ServeMux) Handler(r *http.Request) (h http.Handler, pattern string) {
|
||||
for _, rte := range mux.routes {
|
||||
switch {
|
||||
case rte.matcher(r):
|
||||
return rte.handler, rte.pattern
|
||||
}
|
||||
}
|
||||
return http.HandlerFunc(http.NotFound), ""
|
||||
}
|
||||
|
||||
func (mux *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.RequestURI == "*" {
|
||||
if r.ProtoAtLeast(1, 1) {
|
||||
w.Header().Set("Connection", "close")
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
h, _ := mux.Handler(r)
|
||||
h.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
type route struct {
|
||||
pattern string
|
||||
matcher func(r *http.Request) bool
|
||||
handler http.HandlerFunc
|
||||
}
|
||||
|
||||
func newRoute(pattern string, handler http.Handler, vars ...interface{}) route {
|
||||
return route{
|
||||
pattern,
|
||||
func(r *http.Request) bool {
|
||||
return match(r.URL.Path, pattern, vars...)
|
||||
},
|
||||
handler.ServeHTTP,
|
||||
}
|
||||
}
|
||||
|
||||
// match reports whether path matches the given pattern, which is a
|
||||
// path with '+' wildcards wherever you want to use a parameter. Path
|
||||
// parameters are assigned to the pointers in vars (len(vars) must be
|
||||
// the number of wildcards), which must be of type *string or *int.
|
||||
func match(path, pattern string, vars ...interface{}) bool {
|
||||
for ; pattern != "" && path != ""; pattern = pattern[1:] {
|
||||
switch pattern[0] {
|
||||
case '+':
|
||||
// '+' matches till next slash in path
|
||||
slash := strings.IndexByte(path, '/')
|
||||
if slash < 0 {
|
||||
slash = len(path)
|
||||
}
|
||||
segment := path[:slash]
|
||||
path = path[slash:]
|
||||
switch p := vars[0].(type) {
|
||||
case *string:
|
||||
*p = segment
|
||||
case *int:
|
||||
n, err := strconv.Atoi(segment)
|
||||
if err != nil || n < 0 {
|
||||
return false
|
||||
}
|
||||
*p = n
|
||||
default:
|
||||
panic("vars must be *string or *int")
|
||||
}
|
||||
vars = vars[1:]
|
||||
case path[0]:
|
||||
// non-'+' pattern byte must match path byte
|
||||
path = path[1:]
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return path == "" && pattern == ""
|
||||
}
|
Loading…
Reference in New Issue
Block a user