diff --git a/binding/binding.go b/binding/binding.go index b0f561a..f76efba 100644 --- a/binding/binding.go +++ b/binding/binding.go @@ -4,282 +4,43 @@ package binding -import ( - "encoding/json" - "encoding/xml" - "errors" - "log" - "net/http" - "reflect" - "strconv" - "strings" +import "net/http" + +const ( + MIMEJSON = "application/json" + MIMEHTML = "text/html" + MIMEXML = "application/xml" + MIMEXML2 = "text/xml" + MIMEPlain = "text/plain" + MIMEPOSTForm = "application/x-www-form-urlencoded" + MIMEMultipartPOSTForm = "multipart/form-data" ) -type ( - Binding interface { - Bind(*http.Request, interface{}) error - } - - // JSON binding - jsonBinding struct{} - - // XML binding - xmlBinding struct{} - - // form binding - formBinding struct{} - - // multipart form binding - multipartFormBinding struct{} -) - -const MAX_MEMORY = 1 * 1024 * 1024 +type Binding interface { + Name() string + Bind(*http.Request, interface{}) error +} var ( - JSON = jsonBinding{} - XML = xmlBinding{} - Form = formBinding{} // todo - MultipartForm = multipartFormBinding{} + JSON = jsonBinding{} + XML = xmlBinding{} + GETForm = getFormBinding{} + POSTForm = postFormBinding{} ) -func (_ jsonBinding) Bind(req *http.Request, obj interface{}) error { - decoder := json.NewDecoder(req.Body) - if err := decoder.Decode(obj); err == nil { - return Validate(obj) +func Default(method, contentType string) Binding { + if method == "GET" { + return GETForm } else { - return err - } -} - -func (_ xmlBinding) Bind(req *http.Request, obj interface{}) error { - decoder := xml.NewDecoder(req.Body) - if err := decoder.Decode(obj); err == nil { - return Validate(obj) - } else { - return err - } -} - -func (_ formBinding) Bind(req *http.Request, obj interface{}) error { - if err := req.ParseForm(); err != nil { - return err - } - if err := mapForm(obj, req.Form); err != nil { - return err - } - return Validate(obj) -} - -func (_ multipartFormBinding) Bind(req *http.Request, obj interface{}) error { - if err := req.ParseMultipartForm(MAX_MEMORY); err != nil { - return err - } - if err := mapForm(obj, req.Form); err != nil { - return err - } - return Validate(obj) -} - -func mapForm(ptr interface{}, form map[string][]string) error { - typ := reflect.TypeOf(ptr).Elem() - formStruct := reflect.ValueOf(ptr).Elem() - for i := 0; i < typ.NumField(); i++ { - typeField := typ.Field(i) - if inputFieldName := typeField.Tag.Get("form"); inputFieldName != "" { - structField := formStruct.Field(i) - if !structField.CanSet() { - continue - } - - inputValue, exists := form[inputFieldName] - if !exists { - continue - } - numElems := len(inputValue) - if structField.Kind() == reflect.Slice && numElems > 0 { - sliceOf := structField.Type().Elem().Kind() - slice := reflect.MakeSlice(structField.Type(), numElems, numElems) - for i := 0; i < numElems; i++ { - if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil { - return err - } - } - formStruct.Field(i).Set(slice) - } else { - if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil { - return err - } - } + switch contentType { + case MIMEPOSTForm: + return POSTForm + case MIMEJSON: + return JSON + case MIMEXML, MIMEXML2: + return XML + default: + return GETForm } } - return nil -} - -func setIntField(val string, bitSize int, structField reflect.Value) error { - if val == "" { - val = "0" - } - - intVal, err := strconv.ParseInt(val, 10, bitSize) - if err == nil { - structField.SetInt(intVal) - } - - return err -} - -func setUintField(val string, bitSize int, structField reflect.Value) error { - if val == "" { - val = "0" - } - - uintVal, err := strconv.ParseUint(val, 10, bitSize) - if err == nil { - structField.SetUint(uintVal) - } - - return err -} - -func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { - switch valueKind { - case reflect.Int: - return setIntField(val, 0, structField) - case reflect.Int8: - return setIntField(val, 8, structField) - case reflect.Int16: - return setIntField(val, 16, structField) - case reflect.Int32: - return setIntField(val, 32, structField) - case reflect.Int64: - return setIntField(val, 64, structField) - case reflect.Uint: - return setUintField(val, 0, structField) - case reflect.Uint8: - return setUintField(val, 8, structField) - case reflect.Uint16: - return setUintField(val, 16, structField) - case reflect.Uint32: - return setUintField(val, 32, structField) - case reflect.Uint64: - return setUintField(val, 64, structField) - case reflect.Bool: - if val == "" { - val = "false" - } - boolVal, err := strconv.ParseBool(val) - if err != nil { - return err - } else { - structField.SetBool(boolVal) - } - case reflect.Float32: - if val == "" { - val = "0.0" - } - floatVal, err := strconv.ParseFloat(val, 32) - if err != nil { - return err - } else { - structField.SetFloat(floatVal) - } - case reflect.Float64: - if val == "" { - val = "0.0" - } - floatVal, err := strconv.ParseFloat(val, 64) - if err != nil { - return err - } else { - structField.SetFloat(floatVal) - } - case reflect.String: - structField.SetString(val) - } - return nil -} - -// Don't pass in pointers to bind to. Can lead to bugs. See: -// https://github.com/codegangsta/martini-contrib/issues/40 -// https://github.com/codegangsta/martini-contrib/pull/34#issuecomment-29683659 -func ensureNotPointer(obj interface{}) { - if reflect.TypeOf(obj).Kind() == reflect.Ptr { - log.Panic("Pointers are not accepted as binding models") - } -} - -func Validate(obj interface{}, parents ...string) error { - typ := reflect.TypeOf(obj) - val := reflect.ValueOf(obj) - - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - val = val.Elem() - } - - switch typ.Kind() { - case reflect.Struct: - for i := 0; i < typ.NumField(); i++ { - field := typ.Field(i) - - // Allow ignored and unexported fields in the struct - if len(field.PkgPath) > 0 || field.Tag.Get("form") == "-" { - continue - } - - fieldValue := val.Field(i).Interface() - zero := reflect.Zero(field.Type).Interface() - - if strings.Index(field.Tag.Get("binding"), "required") > -1 { - fieldType := field.Type.Kind() - if fieldType == reflect.Struct { - if reflect.DeepEqual(zero, fieldValue) { - return errors.New("Required " + field.Name) - } - err := Validate(fieldValue, field.Name) - if err != nil { - return err - } - } else if reflect.DeepEqual(zero, fieldValue) { - if len(parents) > 0 { - return errors.New("Required " + field.Name + " on " + parents[0]) - } else { - return errors.New("Required " + field.Name) - } - } else if fieldType == reflect.Slice && field.Type.Elem().Kind() == reflect.Struct { - err := Validate(fieldValue) - if err != nil { - return err - } - } - } else { - fieldType := field.Type.Kind() - if fieldType == reflect.Struct { - if reflect.DeepEqual(zero, fieldValue) { - continue - } - err := Validate(fieldValue, field.Name) - if err != nil { - return err - } - } else if fieldType == reflect.Slice && field.Type.Elem().Kind() == reflect.Struct { - err := Validate(fieldValue, field.Name) - if err != nil { - return err - } - } - } - } - case reflect.Slice: - for i := 0; i < val.Len(); i++ { - fieldValue := val.Index(i).Interface() - err := Validate(fieldValue) - if err != nil { - return err - } - } - default: - return nil - } - return nil } diff --git a/binding/form_mapping.go b/binding/form_mapping.go new file mode 100644 index 0000000..e406245 --- /dev/null +++ b/binding/form_mapping.go @@ -0,0 +1,143 @@ +// Copyright 2014 Manu Martinez-Almeida. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package binding + +import ( + "errors" + "fmt" + "log" + "reflect" + "strconv" +) + +func mapForm(ptr interface{}, form map[string][]string) error { + typ := reflect.TypeOf(ptr).Elem() + val := reflect.ValueOf(ptr).Elem() + for i := 0; i < typ.NumField(); i++ { + typeField := typ.Field(i) + structField := val.Field(i) + if !structField.CanSet() { + continue + } + + inputFieldName := typeField.Tag.Get("form") + if inputFieldName == "" { + inputFieldName = typeField.Name + } + inputValue, exists := form[inputFieldName] + fmt.Println("Field: "+inputFieldName+" Value: ", inputValue) + + if !exists { + continue + } + + numElems := len(inputValue) + if structField.Kind() == reflect.Slice && numElems > 0 { + sliceOf := structField.Type().Elem().Kind() + slice := reflect.MakeSlice(structField.Type(), numElems, numElems) + for i := 0; i < numElems; i++ { + if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil { + return err + } + } + val.Field(i).Set(slice) + } else { + if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil { + return err + } + } + + } + return nil +} + +func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { + switch valueKind { + case reflect.Int: + return setIntField(val, 0, structField) + case reflect.Int8: + return setIntField(val, 8, structField) + case reflect.Int16: + return setIntField(val, 16, structField) + case reflect.Int32: + return setIntField(val, 32, structField) + case reflect.Int64: + return setIntField(val, 64, structField) + case reflect.Uint: + return setUintField(val, 0, structField) + case reflect.Uint8: + return setUintField(val, 8, structField) + case reflect.Uint16: + return setUintField(val, 16, structField) + case reflect.Uint32: + return setUintField(val, 32, structField) + case reflect.Uint64: + return setUintField(val, 64, structField) + case reflect.Bool: + return setBoolField(val, structField) + case reflect.Float32: + return setFloatField(val, 32, structField) + case reflect.Float64: + return setFloatField(val, 64, structField) + case reflect.String: + structField.SetString(val) + default: + return errors.New("Unknown type") + } + return nil +} + +func setIntField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0" + } + intVal, err := strconv.ParseInt(val, 10, bitSize) + if err == nil { + field.SetInt(intVal) + } + return err +} + +func setUintField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0" + } + uintVal, err := strconv.ParseUint(val, 10, bitSize) + if err == nil { + field.SetUint(uintVal) + } + return err +} + +func setBoolField(val string, field reflect.Value) error { + if val == "" { + val = "false" + } + boolVal, err := strconv.ParseBool(val) + if err == nil { + field.SetBool(boolVal) + } + return nil +} + +func setFloatField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0.0" + } + floatVal, err := strconv.ParseFloat(val, bitSize) + if err == nil { + field.SetFloat(floatVal) + } + return err +} + +// Don't pass in pointers to bind to. Can lead to bugs. See: +// https://github.com/codegangsta/martini-contrib/issues/40 +// https://github.com/codegangsta/martini-contrib/pull/34#issuecomment-29683659 +func ensureNotPointer(obj interface{}) { + if reflect.TypeOf(obj).Kind() == reflect.Ptr { + log.Panic("Pointers are not accepted as binding models") + } +} diff --git a/binding/get_form.go b/binding/get_form.go new file mode 100644 index 0000000..6226c51 --- /dev/null +++ b/binding/get_form.go @@ -0,0 +1,23 @@ +// Copyright 2014 Manu Martinez-Almeida. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package binding + +import "net/http" + +type getFormBinding struct{} + +func (_ getFormBinding) Name() string { + return "get_form" +} + +func (_ getFormBinding) Bind(req *http.Request, obj interface{}) error { + if err := req.ParseForm(); err != nil { + return err + } + if err := mapForm(obj, req.Form); err != nil { + return err + } + return Validate(obj) +} diff --git a/binding/json.go b/binding/json.go new file mode 100644 index 0000000..731626c --- /dev/null +++ b/binding/json.go @@ -0,0 +1,26 @@ +// Copyright 2014 Manu Martinez-Almeida. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package binding + +import ( + "encoding/json" + + "net/http" +) + +type jsonBinding struct{} + +func (_ jsonBinding) Name() string { + return "json" +} + +func (_ jsonBinding) Bind(req *http.Request, obj interface{}) error { + decoder := json.NewDecoder(req.Body) + if err := decoder.Decode(obj); err == nil { + return Validate(obj) + } else { + return err + } +} diff --git a/binding/post_form.go b/binding/post_form.go new file mode 100644 index 0000000..9a0f0b6 --- /dev/null +++ b/binding/post_form.go @@ -0,0 +1,23 @@ +// Copyright 2014 Manu Martinez-Almeida. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package binding + +import "net/http" + +type postFormBinding struct{} + +func (_ postFormBinding) Name() string { + return "post_form" +} + +func (_ postFormBinding) Bind(req *http.Request, obj interface{}) error { + if err := req.ParseForm(); err != nil { + return err + } + if err := mapForm(obj, req.PostForm); err != nil { + return err + } + return Validate(obj) +} diff --git a/binding/validate.go b/binding/validate.go new file mode 100644 index 0000000..b743405 --- /dev/null +++ b/binding/validate.go @@ -0,0 +1,79 @@ +// Copyright 2014 Manu Martinez-Almeida. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package binding + +import ( + "errors" + "reflect" + "strings" +) + +func Validate(obj interface{}) error { + return validate(obj, "{{ROOT}}") +} + +func validate(obj interface{}, parent string) error { + typ, val := inspectObject(obj) + switch typ.Kind() { + case reflect.Struct: + return validateStruct(typ, val, parent) + + case reflect.Slice: + return validateSlice(typ, val, parent) + + default: + return errors.New("The object is not a slice or struct.") + } +} + +func inspectObject(obj interface{}) (typ reflect.Type, val reflect.Value) { + typ = reflect.TypeOf(obj) + val = reflect.ValueOf(obj) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + val = val.Elem() + } + return +} + +func validateSlice(typ reflect.Type, val reflect.Value, parent string) error { + if typ.Elem().Kind() == reflect.Struct { + for i := 0; i < val.Len(); i++ { + itemValue := val.Index(i).Interface() + if err := validate(itemValue, parent); err != nil { + return err + } + } + } + return nil +} + +func validateStruct(typ reflect.Type, val reflect.Value, parent string) error { + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + // Allow ignored and unexported fields in the struct + // TODO should include || field.Tag.Get("form") == "-" + if len(field.PkgPath) > 0 { + continue + } + + fieldValue := val.Field(i).Interface() + requiredField := strings.Index(field.Tag.Get("binding"), "required") > -1 + + if requiredField { + zero := reflect.Zero(field.Type).Interface() + if reflect.DeepEqual(zero, fieldValue) { + return errors.New("Required " + field.Name + " in " + parent) + } + } + fieldType := field.Type.Kind() + if fieldType == reflect.Struct || fieldType == reflect.Slice { + if err := validate(fieldValue, field.Name); err != nil { + return err + } + } + } + return nil +} diff --git a/binding/xml.go b/binding/xml.go new file mode 100644 index 0000000..b6c07c2 --- /dev/null +++ b/binding/xml.go @@ -0,0 +1,25 @@ +// Copyright 2014 Manu Martinez-Almeida. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package binding + +import ( + "encoding/xml" + "net/http" +) + +type xmlBinding struct{} + +func (_ xmlBinding) Name() string { + return "xml" +} + +func (_ xmlBinding) Bind(req *http.Request, obj interface{}) error { + decoder := xml.NewDecoder(req.Body) + if err := decoder.Decode(obj); err == nil { + return Validate(obj) + } else { + return err + } +} diff --git a/context.go b/context.go index a092565..c028a79 100644 --- a/context.go +++ b/context.go @@ -179,21 +179,7 @@ func (c *Context) ContentType() string { // else --> returns an error // if Parses the request's body as JSON if Content-Type == "application/json" using JSON or XML as a JSON input. It decodes the json payload into the struct specified as a pointer.Like ParseBody() but this method also writes a 400 error if the json is not valid. func (c *Context) Bind(obj interface{}) bool { - var b binding.Binding - ctype := filterFlags(c.Request.Header.Get("Content-Type")) - switch { - case c.Request.Method == "GET" || ctype == MIMEPOSTForm: - b = binding.Form - case ctype == MIMEMultipartPOSTForm: - b = binding.MultipartForm - case ctype == MIMEJSON: - b = binding.JSON - case ctype == MIMEXML || ctype == MIMEXML2: - b = binding.XML - default: - c.Fail(400, errors.New("unknown content-type: "+ctype)) - return false - } + b := binding.Default(c.Request.Method, c.ContentType()) return c.BindWith(obj, b) } @@ -283,18 +269,18 @@ type Negotiate struct { func (c *Context) Negotiate(code int, config Negotiate) { switch c.NegotiateFormat(config.Offered...) { - case MIMEJSON: + case binding.MIMEJSON: data := chooseData(config.JSONData, config.Data) c.JSON(code, data) - case MIMEHTML: - data := chooseData(config.HTMLData, config.Data) + case binding.MIMEHTML: if len(config.HTMLPath) == 0 { log.Panic("negotiate config is wrong. html path is needed") } + data := chooseData(config.HTMLData, config.Data) c.HTML(code, config.HTMLPath, data) - case MIMEXML: + case binding.MIMEXML: data := chooseData(config.XMLData, config.Data) c.XML(code, data) diff --git a/deprecated.go b/deprecated.go index a1a1024..ebee67f 100644 --- a/deprecated.go +++ b/deprecated.go @@ -13,6 +13,16 @@ import ( "github.com/gin-gonic/gin/binding" ) +const ( + MIMEJSON = binding.MIMEJSON + MIMEHTML = binding.MIMEHTML + MIMEXML = binding.MIMEXML + MIMEXML2 = binding.MIMEXML2 + MIMEPlain = binding.MIMEPlain + MIMEPOSTForm = binding.MIMEPOSTForm + MIMEMultipartPOSTForm = binding.MIMEMultipartPOSTForm +) + // DEPRECATED, use Bind() instead. // Like ParseBody() but this method also writes a 400 error if the json is not valid. func (c *Context) EnsureBody(item interface{}) bool { diff --git a/gin.go b/gin.go index 6fdb156..fa8b12d 100644 --- a/gin.go +++ b/gin.go @@ -9,19 +9,11 @@ import ( "net/http" "sync" + "github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/render" "github.com/julienschmidt/httprouter" ) -const ( - MIMEJSON = "application/json" - MIMEHTML = "text/html" - MIMEXML = "application/xml" - MIMEXML2 = "text/xml" - MIMEPlain = "text/plain" - MIMEPOSTForm = "application/x-www-form-urlencoded" - MIMEMultipartPOSTForm = "multipart/form-data" -) type ( HandlerFunc func(*Context) @@ -147,7 +139,7 @@ func (engine *Engine) handle404(w http.ResponseWriter, req *http.Request) { c.Next() if !c.Writer.Written() { if c.Writer.Status() == 404 { - c.Data(-1, MIMEPlain, engine.Default404Body) + c.Data(-1, binding.MIMEPlain, engine.Default404Body) } else { c.Writer.WriteHeaderNow() } @@ -162,7 +154,7 @@ func (engine *Engine) handle405(w http.ResponseWriter, req *http.Request) { c.Next() if !c.Writer.Written() { if c.Writer.Status() == 405 { - c.Data(-1, MIMEPlain, engine.Default405Body) + c.Data(-1, binding.MIMEPlain, engine.Default405Body) } else { c.Writer.WriteHeaderNow() }