fixed unit tests

This commit is contained in:
Colin Henry 2022-09-15 21:21:51 -07:00
parent 74c99d013c
commit 593c8db66e
2 changed files with 257 additions and 280 deletions

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
@ -11,23 +12,28 @@ import (
"strings" "strings"
) )
func PathParam(ctx context.Context, Param func(ctx context.Context, paramName string) string, paramName string, required bool, dt string) (p any, err error) { func PathParam(ctx context.Context, Param func(ctx context.Context, paramName string) string, p interface{}, paramName string, required bool) (err error) {
s := Param(ctx, paramName) s := Param(ctx, paramName)
if s == "" && required { if s == "" && required {
return nil, errors.New("missing required parameter") switch v := p.(type) {
case *string:
p = v
p = nil
}
return errors.New("missing required parameter")
} }
switch dt { switch v := p.(type) {
case "int64": case *int64:
p, err = strconv.ParseInt(s, 10, 64) *v, err = strconv.ParseInt(s, 10, 64)
case "int32": case *int32:
var x int64 var x int64
x, err = strconv.ParseInt(s, 10, 32) x, err = strconv.ParseInt(s, 10, 32)
p = int32(x) *v = int32(x)
case "string": case *string:
p = s *v = s
default: default:
err = errors.New("no match for type") err = fmt.Errorf("no match for pointer type %T", v)
} }
return return
@ -41,8 +47,7 @@ func BodyParam(body io.ReadCloser, p any, v func(p any) error) (err error) {
return return
} }
func mappedParam(m map[string][]string, paramName string, required bool, dt string) (p any, err error) { func mappedParam(m map[string][]string, paramName string, p interface{}, required bool) (err error) {
var s string var s string
q, exists := m[paramName] q, exists := m[paramName]
if !exists { // intentionally left empty if !exists { // intentionally left empty
@ -53,60 +58,57 @@ func mappedParam(m map[string][]string, paramName string, required bool, dt stri
} }
if s == "" && required { if s == "" && required {
return nil, errors.New("missing required parameter") return errors.New("missing required parameter")
} }
switch dt { switch v := p.(type) {
case "int64": case *int64:
p, err = strconv.ParseInt(s, 10, 64) *v, err = strconv.ParseInt(s, 10, 64)
case "int32": case *int32:
var x int64 var x int64
x, err = strconv.ParseInt(s, 10, 32) x, err = strconv.ParseInt(s, 10, 32)
p = int32(x) *v = int32(x)
case "bool": case *bool:
var b bool *v, err = strconv.ParseBool(s)
b, err = strconv.ParseBool(s) case *string:
p = bool(b) *v = s
case "string": case *[]int64:
p = s
case "[]int64":
str := strings.Split(s, ",") str := strings.Split(s, ",")
ints := make([]int64, len(str)) for _, s := range str {
for i, s := range str { if e, err := strconv.ParseInt(s, 10, 64); err != nil {
if v, err := strconv.ParseInt(s, 10, 64); err != nil { return err
return nil, err
} else { } else {
ints[i] = v *v = append(*v, e)
// ints[i] = e
} }
} }
p = ints case *[]int32:
case "[]int32":
str := strings.Split(s, ",") str := strings.Split(s, ",")
ints := make([]int32, len(str)) for _, s := range str {
for i, s := range str { if e, err := strconv.ParseInt(s, 10, 32); err != nil {
if v, err := strconv.ParseInt(s, 10, 32); err != nil { return err
return nil, err
} else { } else {
ints[i] = int32(v) *v = append(*v, int32(e))
// ints[i] = e
} }
} }
p = ints case *[]string:
case "[]string": *v = strings.Split(s, ",")
p = strings.Split(s, ",")
default: default:
err = errors.New("no match for type") err = fmt.Errorf("no match for pointer type %T", v)
} }
return return
} }
func QueryParam(query url.Values, paramName string, required bool, dt string) (p any, err error) { func QueryParam(query url.Values, paramName string, p interface{}, required bool) (err error) {
return mappedParam(query, paramName, required, dt) return mappedParam(query, paramName, p, required)
} }
func HeaderParam(h http.Header, paramName string, required bool, dt string) (p any, err error) { func HeaderParam(h http.Header, paramName string, p interface{}, required bool) (err error) {
return mappedParam(h, paramName, required, dt) return mappedParam(h, paramName, p, required)
} }
func FormParam(form url.Values, paramName string, required bool, dt string) (p any, err error) { func FormParam(form url.Values, paramName string, p interface{}, required bool) (err error) {
return mappedParam(form, paramName, required, dt) return mappedParam(form, paramName, p, required)
} }

View File

@ -11,246 +11,213 @@ import (
) )
func TestPathParam(t *testing.T) { func TestPathParam(t *testing.T) {
type args struct { t.Run("test int64 parse", func(t *testing.T) {
ctx context.Context var p int64
Param func(ctx context.Context, paramName string) string err := PathParam(context.WithValue(context.Background(), contextKey("int64id"), "123"), Param, &p, "int64id", true)
paramName string if (err != nil) != false {
required bool t.Errorf("PathParam() error = %v, wantErr %v", err, false)
dt string return
} }
tests := []struct { if !reflect.DeepEqual(p, int64(123)) {
name string t.Errorf("PathParam() = %v, want %v", p, int64(123))
args args }
wantP any })
wantErr bool
}{ t.Run("test int32 parse", func(t *testing.T) {
{ var p int32
name: "test int64 parse", err := PathParam(context.WithValue(context.Background(), contextKey("int32id"), "123"), Param, &p, "int32id", true)
args: args{ if (err != nil) != false {
context.WithValue(context.Background(), contextKey("int64id"), "123"), t.Errorf("PathParam() error = %v, wantErr %v", err, false)
Param, return
"int64id", }
true, if !reflect.DeepEqual(p, int32(123)) {
"int64", t.Errorf("PathParam() = %v, want %v", p, int32(123))
}, }
wantP: int64(123), })
wantErr: false,
}, t.Run("test string parse", func(t *testing.T) {
{ var p string
name: "test int32 parse", err := PathParam(
args: args{ context.WithValue(
context.WithValue(context.Background(), contextKey("int32id"), "123"), context.Background(),
Param, contextKey("stringid"),
"int32id", "foo"),
true, Param,
"int32", &p,
}, "stringid",
wantP: int32(123), true)
wantErr: false, if (err != nil) != false {
}, t.Errorf("PathParam() error = %v, wantErr %v", err, false)
{ return
name: "test string parse", }
args: args{ if !reflect.DeepEqual(p, "foo") {
context.WithValue(context.Background(), contextKey("stringid"), "foo"), t.Errorf("PathParam() = %v, want %v", p, "foo")
Param, }
"stringid", })
true,
"string", t.Run("test missing required parameter", func(t *testing.T) {
}, var p string
wantP: string("foo"), err := PathParam(context.Background(), Param, &p, "stringid", true)
wantErr: false, if (err != nil) != true {
}, t.Errorf("PathParam() error = %v, wantErr %v", err, true)
{ return
name: "test missing required parameter", }
args: args{ if !reflect.DeepEqual(p, "") {
context.WithValue(context.Background(), contextKey("stringid"), ""), t.Errorf("PathParam() = %v, want %v", p, "")
Param, }
"stringid", })
true,
"string", t.Run("test unknown type parameter", func(t *testing.T) {
}, var p complex64
wantP: nil, err := PathParam(context.WithValue(context.Background(), contextKey("stringid"), "foo"),
wantErr: true, Param, p, "stringid", true)
}, if (err != nil) != true {
{ t.Errorf("PathParam() error = %v, wantErr %v", err, true)
name: "test unknown type parameter", return
args: args{ }
context.WithValue(context.Background(), contextKey("stringid"), "foo"), if !reflect.DeepEqual(p, complex64(0)) {
Param, t.Errorf("PathParam() = %v, want %v", p, complex64(0))
"stringid", }
true, })
"not_a_real_type",
},
wantP: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotP, err := PathParam(tt.args.ctx, tt.args.Param, tt.args.paramName, tt.args.required, tt.args.dt)
if (err != nil) != tt.wantErr {
t.Errorf("PathParam() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotP, tt.wantP) {
t.Errorf("PathParam() = %v, want %v", gotP, tt.wantP)
}
})
}
} }
func toValues(s string) url.Values { func toValues(s string) url.Values {
v, _ := url.ParseQuery(s) v, _ := url.ParseQuery(s)
return v return v
} }
func TestMappedParam(t *testing.T) {
type args struct { func TestMappedParam(t *testing.T) {
query url.Values t.Run("test int64 parse", func(t *testing.T) {
paramName string var p int64
required bool err := mappedParam(toValues("x=123"), "x", &p, true)
dt string if (err != nil) != false {
} t.Errorf("QueryParam() error = %v, wantErr %v", err, false)
tests := []struct { return
name string }
args args if !reflect.DeepEqual(p, int64(123)) {
wantP any t.Errorf("QueryParam() = %v, want %v", p, int64(123))
wantErr bool }
}{ })
{
name: "test int64 parse", t.Run("test int32 parse", func(t *testing.T) {
args: args{ var p int32
toValues("x=123"), err := mappedParam(toValues("x=123"), "x", &p, true)
"x", if (err != nil) != false {
true, t.Errorf("QueryParam() error = %v, wantErr %v", err, false)
"int64", return
}, }
wantP: int64(123), if !reflect.DeepEqual(p, int32(123)) {
wantErr: false, t.Errorf("QueryParam() = %v, want %v", p, int32(123))
}, }
{ })
name: "test int32 parse",
args: args{ t.Run("test bool parse", func(t *testing.T) {
toValues("x=123"), var p bool
"x", err := mappedParam(toValues("x=true"), "x", &p, true)
true, if (err != nil) != false {
"int32", t.Errorf("QueryParam() error = %v, wantErr %v", err, false)
}, return
wantP: int32(123), }
wantErr: false, if !reflect.DeepEqual(p, true) {
}, t.Errorf("QueryParam() = %v, want %v", p, true)
{ }
name: "test bool parse", })
args: args{
toValues("x=true"), t.Run("test string parse", func(t *testing.T) {
"x", var p string
true, err := mappedParam(toValues("x=foobar"), "x", &p, true)
"bool", if (err != nil) != false {
}, t.Errorf("QueryParam() error = %v, wantErr %v", err, false)
wantP: bool(true), return
wantErr: false, }
}, if !reflect.DeepEqual(p, "foobar") {
{ t.Errorf("QueryParam() = %v, want %v", p, "foobar")
name: "test string parse", }
args: args{ })
toValues("x=foobar"),
"x", t.Run("test []int64 parse", func(t *testing.T) {
true, var p []int64
"string", err := mappedParam(toValues("x=123&x=456"), "x", &p, true)
}, if (err != nil) != false {
wantP: string("foobar"), t.Errorf("QueryParam() error = %v, wantErr %v", err, false)
wantErr: false, return
}, }
{ if !reflect.DeepEqual(p, []int64{int64(123), int64(456)}) {
name: "test []int64 parse", t.Errorf("QueryParam() = %v, want %v", p, []int64{int64(123), int64(456)})
args: args{ }
toValues("x=123&x=456"), })
"x",
true, t.Run("test []int64 bad parse", func(t *testing.T) {
"[]int64", var p []int64
}, err := mappedParam(toValues("x=123&x=4q56"), "x", &p, true)
wantP: []int64{int64(123), int64(456)}, if (err != nil) != true {
wantErr: false, t.Errorf("QueryParam() error = %v, wantErr %v", err, true)
}, return
{ }
name: "test []int64 bad parse", if !reflect.DeepEqual(p, []int64{123}) {
args: args{ t.Errorf("QueryParam() = %v, want %v", p, []int64{})
toValues("x=123&x=4q56"), }
"x", })
true,
"[]int64", t.Run("test []int32 parse", func(t *testing.T) {
}, var p []int32
wantP: nil, err := mappedParam(toValues("x=123&x=456"), "x", &p, true)
wantErr: true, if (err != nil) != false {
}, t.Errorf("QueryParam() error = %v, wantErr %v", err, false)
{ return
name: "test []int32 parse", }
args: args{ if !reflect.DeepEqual(p, []int32{int32(123), int32(456)}) {
toValues("x=123&x=456"), t.Errorf("QueryParam() = %v, want %v", p, []int32{int32(123), int32(456)})
"x", }
true, })
"[]int32",
}, t.Run("test []int32 bad parse", func(t *testing.T) {
wantP: []int32{int32(123), int32(456)}, var p []int32
wantErr: false, err := mappedParam(toValues("x=123&x=4q56"), "x", &p, true)
}, if (err != nil) != true {
{ t.Errorf("QueryParam() error = %v, wantErr %v", err, true)
name: "test []int32 bad parse", return
args: args{ }
toValues("x=123&x=4q56"), if !reflect.DeepEqual(p, []int32{123}) {
"x", t.Errorf("QueryParam() = %v, want %v", p, []int32{123})
true, }
"[]int32", })
},
wantP: nil, t.Run("test []string parse", func(t *testing.T) {
wantErr: true, var p []string
}, err := mappedParam(toValues("x=foo&x=bar"), "x", &p, true)
{ if (err != nil) != false {
name: "test []string parse", t.Errorf("QueryParam() error = %v, wantErr %v", err, false)
args: args{ return
toValues("x=foo&x=bar"), }
"x", if !reflect.DeepEqual(p, []string{"foo", "bar"}) {
true, t.Errorf("QueryParam() = %v, want %v", p, []string{"foo", "bar"})
"[]string", }
}, })
wantP: []string{"foo", "bar"},
wantErr: false, t.Run("test missing required parameter", func(t *testing.T) {
}, var p string
{ err := mappedParam(toValues("y=hello"), "x", &p, true)
name: "test missing required parameter", if (err != nil) != true {
args: args{ t.Errorf("QueryParam() error = %v, wantErr %v", err, true)
toValues("y=hello"), return
"x", }
true, if !reflect.DeepEqual(p, "") {
"string", t.Errorf("QueryParam() = %v, want %v", p, "")
}, }
wantP: nil, })
wantErr: true,
}, t.Run("test unknown type parameter", func(t *testing.T) {
{ var p complex64
name: "test unknown type parameter", err := mappedParam(toValues("x=hello"), "x", &p, true)
args: args{ if (err != nil) != true {
toValues("x=hello"), t.Errorf("QueryParam() error = %v, wantErr %v", err, true)
"x", return
true, }
"not_a_real_type", if !reflect.DeepEqual(p, complex64(0)) {
}, t.Errorf("QueryParam() = %v, want %v", p, complex64(0))
wantP: nil, }
wantErr: true, })
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotP, err := mappedParam(tt.args.query, tt.args.paramName, tt.args.required, tt.args.dt)
if (err != nil) != tt.wantErr {
t.Errorf("QueryParam() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotP, tt.wantP) {
t.Errorf("QueryParam() = %v, want %v", gotP, tt.wantP)
}
})
}
} }
func TestBodyParam(t *testing.T) { func TestBodyParam(t *testing.T) {
@ -310,3 +277,11 @@ func TestBodyParam(t *testing.T) {
}) })
} }
} }
func Ptr[T any](v T) *T {
return &v
}
func NilString() *string {
return nil
}