355 lines
9.4 KiB
Go
355 lines
9.4 KiB
Go
package http
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestStatusHandler(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
handler StatusHandler
|
|
expectedCode int
|
|
expectedBody string
|
|
}{
|
|
{
|
|
name: "not found handler",
|
|
handler: NotFoundHandler,
|
|
expectedCode: http.StatusNotFound,
|
|
expectedBody: "Not Found",
|
|
},
|
|
{
|
|
name: "not implemented handler",
|
|
handler: NotImplementedHandler,
|
|
expectedCode: http.StatusNotImplemented,
|
|
expectedBody: "Not Implemented",
|
|
},
|
|
{
|
|
name: "not allowed handler",
|
|
handler: NotAllowedHandler,
|
|
expectedCode: http.StatusMethodNotAllowed,
|
|
expectedBody: "Method Not Allowed",
|
|
},
|
|
{
|
|
name: "not legal handler",
|
|
handler: NotLegalHandler,
|
|
expectedCode: http.StatusUnavailableForLegalReasons,
|
|
expectedBody: "Unavailable For Legal Reasons",
|
|
},
|
|
{
|
|
name: "custom status",
|
|
handler: StatusHandler(http.StatusTeapot),
|
|
expectedCode: http.StatusTeapot,
|
|
expectedBody: "I'm a teapot",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
tt.handler.ServeHTTP(w, req)
|
|
|
|
if w.Code != tt.expectedCode {
|
|
t.Errorf("expected status code %d, got %d", tt.expectedCode, w.Code)
|
|
}
|
|
|
|
if w.Body.String() != tt.expectedBody {
|
|
t.Errorf("expected body %q, got %q", tt.expectedBody, w.Body.String())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBasicAuth(t *testing.T) {
|
|
protectedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("success"))
|
|
})
|
|
|
|
t.Run("successful authentication", func(t *testing.T) {
|
|
// Note: The BasicAuth implementation uses SHA1 hash of the stored password
|
|
// and compares it (case-insensitive) with the incoming password
|
|
// So to authenticate, the password sent must be the base64 SHA1 hash of the stored value
|
|
htpasswd := map[string]string{
|
|
"user": "pass",
|
|
}
|
|
handler := BasicAuth(protectedHandler, htpasswd, "test realm")
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
// SHA1 hash of "pass" as base64: nU4eI71bcnBGqeO0t9tXvY1u5oQ=
|
|
req.SetBasicAuth("user", "nU4eI71bcnBGqeO0t9tXvY1u5oQ=")
|
|
w := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
|
}
|
|
|
|
if w.Body.String() != "success" {
|
|
t.Errorf("expected body 'success', got %q", w.Body.String())
|
|
}
|
|
})
|
|
|
|
t.Run("missing credentials", func(t *testing.T) {
|
|
htpasswd := map[string]string{
|
|
"user": "pass",
|
|
}
|
|
handler := BasicAuth(protectedHandler, htpasswd, "test realm")
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusUnauthorized {
|
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
|
}
|
|
|
|
authHeader := w.Header().Get("WWW-Authenticate")
|
|
expectedHeader := `Basic realm="test realm"`
|
|
if authHeader != expectedHeader {
|
|
t.Errorf("expected WWW-Authenticate header %q, got %q", expectedHeader, authHeader)
|
|
}
|
|
})
|
|
|
|
t.Run("wrong username", func(t *testing.T) {
|
|
htpasswd := map[string]string{
|
|
"user": "pass",
|
|
}
|
|
handler := BasicAuth(protectedHandler, htpasswd, "test realm")
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.SetBasicAuth("wronguser", "pass")
|
|
w := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusUnauthorized {
|
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
|
}
|
|
})
|
|
|
|
t.Run("wrong password", func(t *testing.T) {
|
|
htpasswd := map[string]string{
|
|
"user": "pass",
|
|
}
|
|
handler := BasicAuth(protectedHandler, htpasswd, "test realm")
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.SetBasicAuth("user", "wrongpass")
|
|
w := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusUnauthorized {
|
|
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestMultiHandler(t *testing.T) {
|
|
getHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("GET response"))
|
|
})
|
|
postHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("POST response"))
|
|
})
|
|
|
|
t.Run("valid methods", func(t *testing.T) {
|
|
handlers := map[string]http.Handler{
|
|
http.MethodGet: getHandler,
|
|
http.MethodPost: postHandler,
|
|
}
|
|
|
|
handler, err := MultiHandler(handlers)
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
|
|
t.Run("GET request", func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
w := httptest.NewRecorder()
|
|
handler.ServeHTTP(w, req)
|
|
|
|
if w.Body.String() != "GET response" {
|
|
t.Errorf("expected 'GET response', got %q", w.Body.String())
|
|
}
|
|
})
|
|
|
|
t.Run("POST request", func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
|
w := httptest.NewRecorder()
|
|
handler.ServeHTTP(w, req)
|
|
|
|
if w.Body.String() != "POST response" {
|
|
t.Errorf("expected 'POST response', got %q", w.Body.String())
|
|
}
|
|
})
|
|
|
|
t.Run("unsupported method", func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodDelete, "/", nil)
|
|
w := httptest.NewRecorder()
|
|
handler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusMethodNotAllowed {
|
|
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
|
|
}
|
|
})
|
|
})
|
|
|
|
t.Run("invalid method returns error", func(t *testing.T) {
|
|
handlers := map[string]http.Handler{
|
|
"INVALID": getHandler,
|
|
}
|
|
|
|
handler, err := MultiHandler(handlers)
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid method, got nil")
|
|
}
|
|
|
|
if handler != nil {
|
|
t.Error("expected nil handler when error occurs")
|
|
}
|
|
|
|
expectedErr := "invalid HTTP method: INVALID"
|
|
if err.Error() != expectedErr {
|
|
t.Errorf("expected error %q, got %q", expectedErr, err.Error())
|
|
}
|
|
})
|
|
|
|
t.Run("all standard methods", func(t *testing.T) {
|
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprintf(w, "handled %s", r.Method)
|
|
})
|
|
|
|
handlers := map[string]http.Handler{
|
|
http.MethodGet: testHandler,
|
|
http.MethodHead: testHandler,
|
|
http.MethodPost: testHandler,
|
|
http.MethodPut: testHandler,
|
|
http.MethodPatch: testHandler,
|
|
http.MethodDelete: testHandler,
|
|
http.MethodConnect: testHandler,
|
|
http.MethodOptions: testHandler,
|
|
http.MethodTrace: testHandler,
|
|
}
|
|
|
|
handler, err := MultiHandler(handlers)
|
|
if err != nil {
|
|
t.Fatalf("expected no error for standard methods, got %v", err)
|
|
}
|
|
|
|
methods := []string{
|
|
http.MethodGet, http.MethodHead, http.MethodPost,
|
|
http.MethodPut, http.MethodPatch, http.MethodDelete,
|
|
http.MethodConnect, http.MethodOptions, http.MethodTrace,
|
|
}
|
|
|
|
for _, method := range methods {
|
|
req := httptest.NewRequest(method, "/", nil)
|
|
w := httptest.NewRecorder()
|
|
handler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("method %s: expected status 200, got %d", method, w.Code)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("empty handlers map", func(t *testing.T) {
|
|
handlers := map[string]http.Handler{}
|
|
|
|
handler, err := MultiHandler(handlers)
|
|
if err != nil {
|
|
t.Fatalf("expected no error for empty map, got %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
w := httptest.NewRecorder()
|
|
handler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusMethodNotAllowed {
|
|
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDownloadFile(t *testing.T) {
|
|
t.Run("successful download", func(t *testing.T) {
|
|
content := "test file content"
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(content))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tmpDir := t.TempDir()
|
|
destPath := filepath.Join(tmpDir, "downloaded.txt")
|
|
|
|
err := DownloadFile(server.URL, destPath)
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
|
|
data, err := os.ReadFile(destPath)
|
|
if err != nil {
|
|
t.Fatalf("failed to read downloaded file: %v", err)
|
|
}
|
|
|
|
if string(data) != content {
|
|
t.Errorf("expected %q, got %q", content, string(data))
|
|
}
|
|
})
|
|
|
|
t.Run("non-200 status code", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}))
|
|
defer server.Close()
|
|
|
|
tmpDir := t.TempDir()
|
|
destPath := filepath.Join(tmpDir, "downloaded.txt")
|
|
|
|
err := DownloadFile(server.URL, destPath)
|
|
if err == nil {
|
|
t.Fatal("expected error for non-200 status code, got nil")
|
|
}
|
|
|
|
expectedMsg := "download failed: unexpected status code 404"
|
|
if !strings.Contains(err.Error(), expectedMsg) {
|
|
t.Errorf("expected error to contain %q, got %q", expectedMsg, err.Error())
|
|
}
|
|
})
|
|
|
|
t.Run("invalid URL", func(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
destPath := filepath.Join(tmpDir, "downloaded.txt")
|
|
|
|
err := DownloadFile("http://invalid.nonexistent.domain.test", destPath)
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid URL, got nil")
|
|
}
|
|
})
|
|
|
|
t.Run("invalid destination path", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("content"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
err := DownloadFile(server.URL, "/invalid/nonexistent/path/file.txt")
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid path, got nil")
|
|
}
|
|
})
|
|
}
|