Files
x/net/http/http_test.go
Colin Henry 54aae5f242
All checks were successful
Go / build (1.23) (push) Successful in 3m51s
big updates: tests, bug fixed, documentation. oh my
2026-01-03 15:53:50 -08:00

355 lines
9.4 KiB
Go

package http
import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
// TestStatusHandler serves a GET request through several StatusHandler values
// and checks that each one replies with the status code and the exact body
// listed for it in the table below.
func TestStatusHandler(t *testing.T) {
	type scenario struct {
		label    string
		h        StatusHandler
		wantCode int
		wantBody string
	}

	scenarios := []scenario{
		{
			label:    "not found handler",
			h:        NotFoundHandler,
			wantCode: http.StatusNotFound,
			wantBody: "Not Found",
		},
		{
			label:    "not implemented handler",
			h:        NotImplementedHandler,
			wantCode: http.StatusNotImplemented,
			wantBody: "Not Implemented",
		},
		{
			label:    "not allowed handler",
			h:        NotAllowedHandler,
			wantCode: http.StatusMethodNotAllowed,
			wantBody: "Method Not Allowed",
		},
		{
			label:    "not legal handler",
			h:        NotLegalHandler,
			wantCode: http.StatusUnavailableForLegalReasons,
			wantBody: "Unavailable For Legal Reasons",
		},
		{
			// A handler converted directly from a status code rather than
			// taken from one of the package-level values.
			label:    "custom status",
			h:        StatusHandler(http.StatusTeapot),
			wantCode: http.StatusTeapot,
			wantBody: "I'm a teapot",
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.label, func(t *testing.T) {
			rec := httptest.NewRecorder()
			sc.h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

			if got := rec.Code; got != sc.wantCode {
				t.Errorf("expected status code %d, got %d", sc.wantCode, got)
			}
			if got := rec.Body.String(); got != sc.wantBody {
				t.Errorf("expected body %q, got %q", sc.wantBody, got)
			}
		})
	}
}
// TestBasicAuth exercises the BasicAuth wrapper around a handler that always
// answers 200 with the body "success". It covers one accepted request and
// three rejected ones (no credentials, unknown user, wrong password).
//
// According to the note left with the original test, BasicAuth takes the SHA1
// hash of the stored password and compares it (case-insensitively) with the
// password sent by the client, so a client authenticates by sending the
// base64-encoded SHA1 digest of the stored value rather than the stored text.
func TestBasicAuth(t *testing.T) {
	const (
		realm = "test realm"

		// validPassword is the base64-encoded SHA1 digest of "pass", i.e. the
		// credential the "user" entry below is expected to accept.
		validPassword = "nU4eI71bcnBGqeO0t9tXvY1u5oQ="
	)

	protectedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("success"))
	})

	// serve wraps protectedHandler with a fresh BasicAuth instance, runs req
	// through it and returns the recorded response. Building the wrapper per
	// call keeps the subtests independent of each other.
	serve := func(req *http.Request) *httptest.ResponseRecorder {
		htpasswd := map[string]string{
			"user": "pass",
		}
		handler := BasicAuth(protectedHandler, htpasswd, realm)
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
		return rec
	}

	t.Run("successful authentication", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req.SetBasicAuth("user", validPassword)

		rec := serve(req)

		if rec.Code != http.StatusOK {
			t.Errorf("expected status %d, got %d", http.StatusOK, rec.Code)
		}
		if body := rec.Body.String(); body != "success" {
			t.Errorf("expected body 'success', got %q", body)
		}
	})

	t.Run("missing credentials", func(t *testing.T) {
		rec := serve(httptest.NewRequest(http.MethodGet, "/", nil))

		if rec.Code != http.StatusUnauthorized {
			t.Errorf("expected status %d, got %d", http.StatusUnauthorized, rec.Code)
		}

		// The challenge must name the realm passed to BasicAuth.
		expectedHeader := `Basic realm="` + realm + `"`
		if authHeader := rec.Header().Get("WWW-Authenticate"); authHeader != expectedHeader {
			t.Errorf("expected WWW-Authenticate header %q, got %q", expectedHeader, authHeader)
		}
	})

	rejected := []struct {
		name     string
		username string
		password string
	}{
		// The password here is the one that works for "user", so the only
		// thing wrong with the request is the username. Sending an invalid
		// password as well would let this case pass even if the username
		// were never checked.
		{name: "wrong username", username: "wronguser", password: validPassword},
		{name: "wrong password", username: "user", password: "wrongpass"},
	}
	for _, tc := range rejected {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.SetBasicAuth(tc.username, tc.password)

			rec := serve(req)

			if rec.Code != http.StatusUnauthorized {
				t.Errorf("expected status %d, got %d", http.StatusUnauthorized, rec.Code)
			}
		})
	}
}
// TestMultiHandler checks MultiHandler's per-method dispatch: registered
// methods reach their own handler, unregistered methods get 405, a
// non-standard method name is rejected at construction time, and an empty
// map yields a handler that answers 405.
func TestMultiHandler(t *testing.T) {
	getHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("GET response"))
	})
	postHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("POST response"))
	})

	t.Run("valid methods", func(t *testing.T) {
		handler, err := MultiHandler(map[string]http.Handler{
			http.MethodGet:  getHandler,
			http.MethodPost: postHandler,
		})
		if err != nil {
			t.Fatalf("expected no error, got %v", err)
		}

		// do sends a request with the given method to "/" and returns the
		// recorded response.
		do := func(method string) *httptest.ResponseRecorder {
			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, httptest.NewRequest(method, "/", nil))
			return rec
		}

		t.Run("GET request", func(t *testing.T) {
			if body := do(http.MethodGet).Body.String(); body != "GET response" {
				t.Errorf("expected 'GET response', got %q", body)
			}
		})

		t.Run("POST request", func(t *testing.T) {
			if body := do(http.MethodPost).Body.String(); body != "POST response" {
				t.Errorf("expected 'POST response', got %q", body)
			}
		})

		t.Run("unsupported method", func(t *testing.T) {
			// DELETE was never registered above.
			if code := do(http.MethodDelete).Code; code != http.StatusMethodNotAllowed {
				t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, code)
			}
		})
	})

	t.Run("invalid method returns error", func(t *testing.T) {
		handler, err := MultiHandler(map[string]http.Handler{
			"INVALID": getHandler,
		})
		if err == nil {
			t.Fatal("expected error for invalid method, got nil")
		}
		if handler != nil {
			t.Error("expected nil handler when error occurs")
		}

		const expectedErr = "invalid HTTP method: INVALID"
		if got := err.Error(); got != expectedErr {
			t.Errorf("expected error %q, got %q", expectedErr, got)
		}
	})

	t.Run("all standard methods", func(t *testing.T) {
		echoMethod := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintf(w, "handled %s", r.Method)
		})

		methods := []string{
			http.MethodGet, http.MethodHead, http.MethodPost,
			http.MethodPut, http.MethodPatch, http.MethodDelete,
			http.MethodConnect, http.MethodOptions, http.MethodTrace,
		}

		// Register the same handler under every method in the list.
		handlers := make(map[string]http.Handler, len(methods))
		for _, m := range methods {
			handlers[m] = echoMethod
		}

		handler, err := MultiHandler(handlers)
		if err != nil {
			t.Fatalf("expected no error for standard methods, got %v", err)
		}

		for _, m := range methods {
			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, httptest.NewRequest(m, "/", nil))
			if rec.Code != http.StatusOK {
				t.Errorf("method %s: expected status 200, got %d", m, rec.Code)
			}
		}
	})

	t.Run("empty handlers map", func(t *testing.T) {
		handler, err := MultiHandler(map[string]http.Handler{})
		if err != nil {
			t.Fatalf("expected no error for empty map, got %v", err)
		}

		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
		if rec.Code != http.StatusMethodNotAllowed {
			t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
		}
	})
}
// TestDownloadFile checks DownloadFile against local httptest servers only:
// a successful download whose content is read back from disk, a non-200
// reply, a URL nothing is listening on, and a destination that cannot be
// created. No subtest touches the network beyond loopback or the filesystem
// outside t.TempDir().
func TestDownloadFile(t *testing.T) {
	// newServer starts a throwaway server that answers every request with the
	// given status and body, and shuts it down when the calling test ends.
	newServer := func(t *testing.T, status int, body string) *httptest.Server {
		t.Helper()
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(status)
			w.Write([]byte(body))
		}))
		t.Cleanup(srv.Close)
		return srv
	}

	t.Run("successful download", func(t *testing.T) {
		content := "test file content"
		srv := newServer(t, http.StatusOK, content)
		destPath := filepath.Join(t.TempDir(), "downloaded.txt")

		if err := DownloadFile(srv.URL, destPath); err != nil {
			t.Fatalf("expected no error, got %v", err)
		}

		data, err := os.ReadFile(destPath)
		if err != nil {
			t.Fatalf("failed to read downloaded file: %v", err)
		}
		if string(data) != content {
			t.Errorf("expected %q, got %q", content, string(data))
		}
	})

	t.Run("non-200 status code", func(t *testing.T) {
		srv := newServer(t, http.StatusNotFound, "")
		destPath := filepath.Join(t.TempDir(), "downloaded.txt")

		err := DownloadFile(srv.URL, destPath)
		if err == nil {
			t.Fatal("expected error for non-200 status code, got nil")
		}

		expectedMsg := "download failed: unexpected status code 404"
		if !strings.Contains(err.Error(), expectedMsg) {
			t.Errorf("expected error to contain %q, got %q", expectedMsg, err.Error())
		}
	})

	t.Run("invalid URL", func(t *testing.T) {
		// Use the address of a local server that has already been closed.
		// The request fails without a DNS lookup, so the result does not
		// depend on the resolver (some resolvers answer for names that do
		// not exist) or on outbound network access being available.
		srv := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
		deadURL := srv.URL
		srv.Close()

		destPath := filepath.Join(t.TempDir(), "downloaded.txt")
		if err := DownloadFile(deadURL, destPath); err == nil {
			t.Fatal("expected error for invalid URL, got nil")
		}
	})

	t.Run("invalid destination path", func(t *testing.T) {
		srv := newServer(t, http.StatusOK, "content")

		// Place the destination underneath a regular file. Nothing can be
		// created below a non-directory, whatever user or OS the tests run
		// on, and the path stays inside the per-test temp directory instead
		// of pointing at a fixed location in the filesystem root.
		blocker := filepath.Join(t.TempDir(), "not-a-dir")
		if err := os.WriteFile(blocker, nil, 0o600); err != nil {
			t.Fatalf("failed to create blocking file: %v", err)
		}
		destPath := filepath.Join(blocker, "file.txt")

		if err := DownloadFile(srv.URL, destPath); err == nil {
			t.Fatal("expected error for invalid path, got nil")
		}
	})
}