diff --git a/cache/tiered.go b/cache/tiered.go index e538c93..6411012 100644 --- a/cache/tiered.go +++ b/cache/tiered.go @@ -1,6 +1,10 @@ package cache -import "reflect" +import ( + "reflect" + + "git.sdf.org/jchenry/x" +) type tieredCache[K comparable, V any] struct { inner Interface[K, V] @@ -8,6 +12,8 @@ type tieredCache[K comparable, V any] struct { } func NewTieredCache[K comparable, V any](inner, outer Interface[K, V]) Interface[K, V] { + x.Assert(inner != nil, "cache.NewTieredCache: inner cannot be nil") + x.Assert(outer != nil, "cache.NewTieredCache: outer cannot be nil") return &tieredCache[K, V]{ inner: inner, outer: outer, diff --git a/database/actor.go b/database/actor.go index ab7f6dc..dad527d 100644 --- a/database/actor.go +++ b/database/actor.go @@ -3,6 +3,8 @@ package database import ( "context" "database/sql" + + "git.sdf.org/jchenry/x" ) type Func func(db *sql.DB) @@ -13,6 +15,7 @@ type Actor struct { } func (a *Actor) Run(ctx context.Context) error { + x.Assert(ctx != nil, "Actor.Run: context cannot be nil") for { select { case f := <-a.ActionChan: diff --git a/net/http/auth.go b/net/http/auth.go index 7068ef8..5bae5fa 100644 --- a/net/http/auth.go +++ b/net/http/auth.go @@ -6,9 +6,13 @@ import ( "fmt" "net/http" "strings" + + "git.sdf.org/jchenry/x" ) func BasicAuth(h http.Handler, htpasswd map[string]string, realm string) http.HandlerFunc { + x.Assert(len(htpasswd) > 0, "http.BasicAuth: htpassword cannot be empty") + x.Assert(len(realm) > 0, "http.BasicAuth: realm cannot be empty") rlm := fmt.Sprintf(`Basic realm="%s"`, realm) sha1 := func(password string) string { s := sha1.New() diff --git a/pkg.go b/pkg.go index 51bb964..9d22d90 100644 --- a/pkg.go +++ b/pkg.go @@ -9,3 +9,31 @@ func Assert(cond bool, msg string) { panic(errors.New(msg)) } } + +func Check(cond bool, err error) *invariants { + return new(invariants).Check(cond, err) +} + +type invariants struct { + errs []error +} + +func (i *invariants) Check(cond bool, err error) *invariants { + if !cond { + i.errs = append(i.errs, err) + } + return i +} + +func (i *invariants) Join() error { + return errors.Join(i.errs...) +} +func (i *invariants) First() error { + if len(i.errs) > 0 { + return i.errs[0] + } + return nil +} +func (i *invariants) All() []error { + return i.errs +}