diff --git a/README.md b/README.md index a22440f..eb9415f 100644 --- a/README.md +++ b/README.md @@ -1836,24 +1836,6 @@ $ curl "http://localhost:8080/getd?field_x=hello&field_d=world" {"d":"world","x":{"FieldX":"hello"}} ``` -**NOTE**: NOT support the follow style struct: - -```go -type StructX struct { - X struct {} `form:"name_x"` // HERE have form -} - -type StructY struct { - Y StructX `form:"name_y"` // HERE have form -} - -type StructZ struct { - Z *StructZ `form:"name_z"` // HERE have form -} -``` - -In a word, only support nested custom struct which have no `form` now. - ### Try to bind body into different structs The normal methods for binding request body consumes `c.Request.Body` and they diff --git a/binding/binding_test.go b/binding/binding_test.go index c9dea34..16ca202 100644 --- a/binding/binding_test.go +++ b/binding/binding_test.go @@ -12,6 +12,7 @@ import ( "mime/multipart" "net/http" "strconv" + "strings" "testing" "time" @@ -57,7 +58,6 @@ type FooStructForTimeTypeFailLocation struct { } type FooStructForMapType struct { - // Unknown type: not support map MapFoo map[string]interface{} `form:"map_foo"` } @@ -303,7 +303,7 @@ func TestBindingFormInvalidName2(t *testing.T) { func TestBindingFormForType(t *testing.T) { testFormBindingForType(t, "POST", "/", "/", - "map_foo=", "bar2=1", "Map") + "map_foo={\"bar\":123}", "map_foo=1", "Map") testFormBindingForType(t, "POST", "/", "/", @@ -508,20 +508,30 @@ func TestBindingYAMLFail(t *testing.T) { `foo:\nbar`, `bar: foo`) } -func createFormPostRequest() *http.Request { - req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo")) +func createFormPostRequest(t *testing.T) *http.Request { + req, err := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo")) + assert.NoError(t, err) req.Header.Set("Content-Type", MIMEPOSTForm) return req } -func createDefaultFormPostRequest() *http.Request { - req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar")) +func createDefaultFormPostRequest(t *testing.T) *http.Request { + req, err := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar")) + assert.NoError(t, err) req.Header.Set("Content-Type", MIMEPOSTForm) return req } -func createFormPostRequestFail() *http.Request { - req, _ := http.NewRequest("POST", "/?map_foo=getfoo", bytes.NewBufferString("map_foo=bar")) +func createFormPostRequestForMap(t *testing.T) *http.Request { + req, err := http.NewRequest("POST", "/?map_foo=getfoo", bytes.NewBufferString("map_foo={\"bar\":123}")) + assert.NoError(t, err) + req.Header.Set("Content-Type", MIMEPOSTForm) + return req +} + +func createFormPostRequestForMapFail(t *testing.T) *http.Request { + req, err := http.NewRequest("POST", "/?map_foo=getfoo", bytes.NewBufferString("map_foo=hello")) + assert.NoError(t, err) req.Header.Set("Content-Type", MIMEPOSTForm) return req } @@ -535,26 +545,42 @@ func createFormMultipartRequest(t *testing.T) *http.Request { assert.NoError(t, mw.SetBoundary(boundary)) assert.NoError(t, mw.WriteField("foo", "bar")) assert.NoError(t, mw.WriteField("bar", "foo")) - req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) + req, err := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) + assert.NoError(t, err) req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) return req } -func createFormMultipartRequestFail(t *testing.T) *http.Request { +func createFormMultipartRequestForMap(t *testing.T) *http.Request { boundary := "--testboundary" body := new(bytes.Buffer) mw := multipart.NewWriter(body) defer mw.Close() assert.NoError(t, mw.SetBoundary(boundary)) - assert.NoError(t, mw.WriteField("map_foo", "bar")) - req, _ := http.NewRequest("POST", "/?map_foo=getfoo", body) + assert.NoError(t, mw.WriteField("map_foo", "{\"bar\":123, \"name\":\"thinkerou\", \"pai\": 3.14}")) + req, err := http.NewRequest("POST", "/?map_foo=getfoo", body) + assert.NoError(t, err) + req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) + return req +} + +func createFormMultipartRequestForMapFail(t *testing.T) *http.Request { + boundary := "--testboundary" + body := new(bytes.Buffer) + mw := multipart.NewWriter(body) + defer mw.Close() + + assert.NoError(t, mw.SetBoundary(boundary)) + assert.NoError(t, mw.WriteField("map_foo", "3.14")) + req, err := http.NewRequest("POST", "/?map_foo=getfoo", body) + assert.NoError(t, err) req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) return req } func TestBindingFormPost(t *testing.T) { - req := createFormPostRequest() + req := createFormPostRequest(t) var obj FooBarStruct assert.NoError(t, FormPost.Bind(req, &obj)) @@ -564,7 +590,7 @@ func TestBindingFormPost(t *testing.T) { } func TestBindingDefaultValueFormPost(t *testing.T) { - req := createDefaultFormPostRequest() + req := createDefaultFormPostRequest(t) var obj FooDefaultBarStruct assert.NoError(t, FormPost.Bind(req, &obj)) @@ -572,8 +598,16 @@ func TestBindingDefaultValueFormPost(t *testing.T) { assert.Equal(t, "hello", obj.Bar) } -func TestBindingFormPostFail(t *testing.T) { - req := createFormPostRequestFail() +func TestBindingFormPostForMap(t *testing.T) { + req := createFormPostRequestForMap(t) + var obj FooStructForMapType + err := FormPost.Bind(req, &obj) + assert.NoError(t, err) + assert.Equal(t, float64(123), obj.MapFoo["bar"].(float64)) +} + +func TestBindingFormPostForMapFail(t *testing.T) { + req := createFormPostRequestForMapFail(t) var obj FooStructForMapType err := FormPost.Bind(req, &obj) assert.Error(t, err) @@ -589,8 +623,18 @@ func TestBindingFormMultipart(t *testing.T) { assert.Equal(t, "foo", obj.Bar) } -func TestBindingFormMultipartFail(t *testing.T) { - req := createFormMultipartRequestFail(t) +func TestBindingFormMultipartForMap(t *testing.T) { + req := createFormMultipartRequestForMap(t) + var obj FooStructForMapType + err := FormMultipart.Bind(req, &obj) + assert.NoError(t, err) + assert.Equal(t, float64(123), obj.MapFoo["bar"].(float64)) + assert.Equal(t, "thinkerou", obj.MapFoo["name"].(string)) + assert.Equal(t, float64(3.14), obj.MapFoo["pai"].(float64)) +} + +func TestBindingFormMultipartForMapFail(t *testing.T) { + req := createFormMultipartRequestForMapFail(t) var obj FooStructForMapType err := FormMultipart.Bind(req, &obj) assert.Error(t, err) @@ -773,6 +817,17 @@ func TestFormBindingFail(t *testing.T) { assert.Error(t, err) } +func TestFormBindingMultipartFail(t *testing.T) { + obj := FooBarStruct{} + req, err := http.NewRequest("POST", "/", strings.NewReader("foo=bar")) + assert.NoError(t, err) + req.Header.Set("Content-Type", MIMEMultipartPOSTForm+";boundary=testboundary") + _, err = req.MultipartReader() + assert.NoError(t, err) + err = Form.Bind(req, &obj) + assert.Error(t, err) +} + func TestFormPostBindingFail(t *testing.T) { b := FormPost assert.Equal(t, "form-urlencoded", b.Name()) @@ -1109,7 +1164,8 @@ func testFormBindingForType(t *testing.T, method, path, badPath, body, badBody s case "Map": obj := FooStructForMapType{} err := b.Bind(req, &obj) - assert.Error(t, err) + assert.NoError(t, err) + assert.Equal(t, float64(123), obj.MapFoo["bar"].(float64)) case "SliceMap": obj := FooStructForSliceMapType{} err := b.Bind(req, &obj) @@ -1317,3 +1373,43 @@ func TestCanSet(t *testing.T) { var c CanSetStruct assert.Nil(t, mapForm(&c, nil)) } + +func formPostRequest(path, body string) *http.Request { + req := requestWithBody("POST", path, body) + req.Header.Add("Content-Type", MIMEPOSTForm) + return req +} + +func TestBindingSliceDefault(t *testing.T) { + var s struct { + Friends []string `form:"friends,default=mike"` + } + req := formPostRequest("", "") + err := Form.Bind(req, &s) + assert.NoError(t, err) + + assert.Len(t, s.Friends, 1) + assert.Equal(t, "mike", s.Friends[0]) +} + +func TestBindingStructField(t *testing.T) { + var s struct { + Opts struct { + Port int + } `form:"opts"` + } + req := formPostRequest("", `opts={"Port": 8000}`) + err := Form.Bind(req, &s) + assert.NoError(t, err) + assert.Equal(t, 8000, s.Opts.Port) +} + +func TestBindingUnknownTypeChan(t *testing.T) { + var s struct { + Stop chan bool `form:"stop"` + } + req := formPostRequest("", "stop=true") + err := Form.Bind(req, &s) + assert.Error(t, err) + assert.Equal(t, errUnknownType, err) +} diff --git a/binding/form_mapping.go b/binding/form_mapping.go index 8eb5c0d..1109e4d 100644 --- a/binding/form_mapping.go +++ b/binding/form_mapping.go @@ -5,6 +5,7 @@ package binding import ( + "encoding/json" "errors" "reflect" "strconv" @@ -12,6 +13,8 @@ import ( "time" ) +var errUnknownType = errors.New("Unknown type") + func mapUri(ptr interface{}, m map[string][]string) error { return mapFormByTag(ptr, m, "uri") } @@ -20,124 +23,153 @@ func mapForm(ptr interface{}, form map[string][]string) error { return mapFormByTag(ptr, form, "form") } +var emptyField = reflect.StructField{} + func mapFormByTag(ptr interface{}, form map[string][]string, tag string) error { - typ := reflect.TypeOf(ptr).Elem() - val := reflect.ValueOf(ptr).Elem() - for i := 0; i < typ.NumField(); i++ { - typeField := typ.Field(i) - structField := val.Field(i) - if !structField.CanSet() { - continue - } - - structFieldKind := structField.Kind() - inputFieldName := typeField.Tag.Get(tag) - inputFieldNameList := strings.Split(inputFieldName, ",") - inputFieldName = inputFieldNameList[0] - var defaultValue string - if len(inputFieldNameList) > 1 { - defaultList := strings.SplitN(inputFieldNameList[1], "=", 2) - if defaultList[0] == "default" { - defaultValue = defaultList[1] - } - } - if inputFieldName == "-" { - continue - } - if inputFieldName == "" { - inputFieldName = typeField.Name - - // if "form" tag is nil, we inspect if the field is a struct or struct pointer. - // this would not make sense for JSON parsing but it does for a form - // since data is flatten - if structFieldKind == reflect.Ptr { - if !structField.Elem().IsValid() { - structField.Set(reflect.New(structField.Type().Elem())) - } - structField = structField.Elem() - structFieldKind = structField.Kind() - } - if structFieldKind == reflect.Struct { - err := mapFormByTag(structField.Addr().Interface(), form, tag) - if err != nil { - return err - } - continue - } - } - inputValue, exists := form[inputFieldName] - - if !exists { - if defaultValue == "" { - continue - } - inputValue = make([]string, 1) - inputValue[0] = defaultValue - } - - numElems := len(inputValue) - if structFieldKind == reflect.Slice && numElems > 0 { - sliceOf := structField.Type().Elem().Kind() - slice := reflect.MakeSlice(structField.Type(), numElems, numElems) - for i := 0; i < numElems; i++ { - if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil { - return err - } - } - val.Field(i).Set(slice) - continue - } - if _, isTime := structField.Interface().(time.Time); isTime { - if err := setTimeField(inputValue[0], typeField, structField); err != nil { - return err - } - continue - } - if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil { - return err - } - } - return nil + _, err := mapping(reflect.ValueOf(ptr), emptyField, form, tag) + return err } -func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { - switch valueKind { - case reflect.Int: - return setIntField(val, 0, structField) - case reflect.Int8: - return setIntField(val, 8, structField) - case reflect.Int16: - return setIntField(val, 16, structField) - case reflect.Int32: - return setIntField(val, 32, structField) - case reflect.Int64: - return setIntField(val, 64, structField) - case reflect.Uint: - return setUintField(val, 0, structField) - case reflect.Uint8: - return setUintField(val, 8, structField) - case reflect.Uint16: - return setUintField(val, 16, structField) - case reflect.Uint32: - return setUintField(val, 32, structField) - case reflect.Uint64: - return setUintField(val, 64, structField) - case reflect.Bool: - return setBoolField(val, structField) - case reflect.Float32: - return setFloatField(val, 32, structField) - case reflect.Float64: - return setFloatField(val, 64, structField) - case reflect.String: - structField.SetString(val) - case reflect.Ptr: - if !structField.Elem().IsValid() { - structField.Set(reflect.New(structField.Type().Elem())) +func mapping(value reflect.Value, field reflect.StructField, form map[string][]string, tag string) (bool, error) { + var vKind = value.Kind() + + if vKind == reflect.Ptr { + var isNew bool + vPtr := value + if value.IsNil() { + isNew = true + vPtr = reflect.New(value.Type().Elem()) } - structFieldElem := structField.Elem() - return setWithProperType(structFieldElem.Kind(), val, structFieldElem) + isSetted, err := mapping(vPtr.Elem(), field, form, tag) + if err != nil { + return false, err + } + if isNew && isSetted { + value.Set(vPtr) + } + return isSetted, nil + } + + ok, err := tryToSetValue(value, field, form, tag) + if err != nil { + return false, err + } + if ok { + return true, nil + } + + if vKind == reflect.Struct { + tValue := value.Type() + + var isSetted bool + for i := 0; i < value.NumField(); i++ { + if !value.Field(i).CanSet() { + continue + } + ok, err := mapping(value.Field(i), tValue.Field(i), form, tag) + if err != nil { + return false, err + } + isSetted = isSetted || ok + } + return isSetted, nil + } + return false, nil +} + +func tryToSetValue(value reflect.Value, field reflect.StructField, form map[string][]string, tag string) (bool, error) { + var tagValue, defaultValue string + var isDefaultExists bool + + tagValue = field.Tag.Get(tag) + tagValue, opts := head(tagValue, ",") + + if tagValue == "-" { // just ignoring this field + return false, nil + } + if tagValue == "" { // default value is FieldName + tagValue = field.Name + } + if tagValue == "" { // when field is "emptyField" variable + return false, nil + } + + var opt string + for len(opts) > 0 { + opt, opts = head(opts, ",") + + k, v := head(opt, "=") + switch k { + case "default": + isDefaultExists = true + defaultValue = v + } + } + + vs, ok := form[tagValue] + if !ok && !isDefaultExists { + return false, nil + } + + switch value.Kind() { + case reflect.Slice: + if !ok { + vs = []string{defaultValue} + } + return true, setSlice(vs, value, field) default: - return errors.New("Unknown type") + var val string + if !ok { + val = defaultValue + } + + if len(vs) > 0 { + val = vs[0] + } + return true, setWithProperType(val, value, field) + } +} + +func setWithProperType(val string, value reflect.Value, field reflect.StructField) error { + switch value.Kind() { + case reflect.Int: + return setIntField(val, 0, value) + case reflect.Int8: + return setIntField(val, 8, value) + case reflect.Int16: + return setIntField(val, 16, value) + case reflect.Int32: + return setIntField(val, 32, value) + case reflect.Int64: + return setIntField(val, 64, value) + case reflect.Uint: + return setUintField(val, 0, value) + case reflect.Uint8: + return setUintField(val, 8, value) + case reflect.Uint16: + return setUintField(val, 16, value) + case reflect.Uint32: + return setUintField(val, 32, value) + case reflect.Uint64: + return setUintField(val, 64, value) + case reflect.Bool: + return setBoolField(val, value) + case reflect.Float32: + return setFloatField(val, 32, value) + case reflect.Float64: + return setFloatField(val, 64, value) + case reflect.String: + value.SetString(val) + case reflect.Struct: + switch value.Interface().(type) { + case time.Time: + return setTimeField(val, field, value) + } + return json.Unmarshal([]byte(val), value.Addr().Interface()) + case reflect.Map: + return json.Unmarshal([]byte(val), value.Addr().Interface()) + default: + return errUnknownType } return nil } @@ -218,3 +250,23 @@ func setTimeField(val string, structField reflect.StructField, value reflect.Val value.Set(reflect.ValueOf(t)) return nil } + +func setSlice(vals []string, value reflect.Value, field reflect.StructField) error { + slice := reflect.MakeSlice(value.Type(), len(vals), len(vals)) + for i, s := range vals { + err := setWithProperType(s, slice.Index(i), field) + if err != nil { + return err + } + } + value.Set(slice) + return nil +} + +func head(str, sep string) (head string, tail string) { + idx := strings.Index(str, sep) + if idx < 0 { + return str, "" + } + return str[:idx], str[idx+len(sep):] +} diff --git a/binding/form_mapping_benchmark_test.go b/binding/form_mapping_benchmark_test.go new file mode 100644 index 0000000..0ef08f0 --- /dev/null +++ b/binding/form_mapping_benchmark_test.go @@ -0,0 +1,61 @@ +// Copyright 2019 Gin Core Team. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package binding + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var form = map[string][]string{ + "name": {"mike"}, + "friends": {"anna", "nicole"}, + "id_number": {"12345678"}, + "id_date": {"2018-01-20"}, +} + +type structFull struct { + Name string `form:"name"` + Age int `form:"age,default=25"` + Friends []string `form:"friends"` + ID *struct { + Number string `form:"id_number"` + DateOfIssue time.Time `form:"id_date" time_format:"2006-01-02" time_utc:"true"` + } + Nationality *string `form:"nationality"` +} + +func BenchmarkMapFormFull(b *testing.B) { + var s structFull + for i := 0; i < b.N; i++ { + mapForm(&s, form) + } + b.StopTimer() + + t := b + assert.Equal(t, "mike", s.Name) + assert.Equal(t, 25, s.Age) + assert.Equal(t, []string{"anna", "nicole"}, s.Friends) + assert.Equal(t, "12345678", s.ID.Number) + assert.Equal(t, time.Date(2018, 1, 20, 0, 0, 0, 0, time.UTC), s.ID.DateOfIssue) + assert.Nil(t, s.Nationality) +} + +type structName struct { + Name string `form:"name"` +} + +func BenchmarkMapFormName(b *testing.B) { + var s structName + for i := 0; i < b.N; i++ { + mapForm(&s, form) + } + b.StopTimer() + + t := b + assert.Equal(t, "mike", s.Name) +}