package main import ( "fmt" "net/http" "net/http/httptest" "testing" ) func init() { allowedHosts = []string{"teapot-dummy-target.example.com"} } func TestHttpReverseProxy(t *testing.T) { var header map[string][]string teapot := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { header = r.Header.Clone() w.WriteHeader(http.StatusTeapot) })) defer teapot.Close() targetHost = teapot.Listener.Addr().String() targetScheme = "http" reverseProxy := newHttpReverseProxy(":8081", false) fmt.Println("Proxy server listening on " + reverseProxy.Addr) fmt.Println("Dummy target server listening on " + targetHost) var tests = []struct { host string wantXForwardedHostHeader string wantResponseCode int }{ { "teapot-dummy-target.example.com", "teapot-dummy-target.example.com", http.StatusTeapot, }, { myPtrAddr, "", http.StatusOK, }, { reverseProxy.Addr, "", http.StatusGatewayTimeout, }, } for _, test := range tests { t.Run(fmt.Sprintf("requesting host %q returns status %d and header X-Forwarded-Host set to %q", test.host, test.wantResponseCode, test.wantXForwardedHostHeader), func(t *testing.T) { request, _ := http.NewRequest("GET", "/", nil) request.Host = test.host response := httptest.NewRecorder() reverseProxy.Handler.ServeHTTP(response, request) gotHeader := header header = nil assertStatus(t, response.Code, test.wantResponseCode) if len(test.wantXForwardedHostHeader) > 0 { assertHeader(t, gotHeader, "X-Forwarded-Host", test.wantXForwardedHostHeader) } }) } } func assertHeader(t *testing.T, gotHeader map[string][]string, headerName, want string) { t.Helper() var got string if lookup, ok := gotHeader[headerName]; ok { got = lookup[0] } if got != want { t.Errorf("got %s, want %s", got, want) } } func assertStatus(t *testing.T, got, want int) { t.Helper() if got != want { t.Errorf("got %d, want %d", got, want) } }