diff --git a/binding/binding.go b/binding/binding.go new file mode 100644 index 0000000..c826073 --- /dev/null +++ b/binding/binding.go @@ -0,0 +1,190 @@ +package binding + +import ( + "encoding/json" + "encoding/xml" + "errors" + "net/http" + "reflect" + "strconv" + "strings" +) + +type ( + Binding interface { + Bind(*http.Request, interface{}) error + } + + // JSON binding + jsonBinding struct{} + + // XML binding + xmlBinding struct{} + + // // form binding + formBinding struct{} +) + +var ( + JSON = jsonBinding{} + XML = xmlBinding{} + Form = formBinding{} // todo +) + +func (_ jsonBinding) Bind(req *http.Request, obj interface{}) error { + decoder := json.NewDecoder(req.Body) + if err := decoder.Decode(obj); err == nil { + return Validate(obj) + } else { + return err + } +} + +func (_ xmlBinding) Bind(req *http.Request, obj interface{}) error { + decoder := xml.NewDecoder(req.Body) + if err := decoder.Decode(obj); err == nil { + return Validate(obj) + } else { + return err + } +} + +func (_ formBinding) Bind(req *http.Request, obj interface{}) error { + if err := req.ParseForm(); err != nil { + return err + } + if err := mapForm(obj, req.Form); err != nil { + return err + } + return Validate(obj) +} + +func mapForm(ptr interface{}, form map[string][]string) error { + typ := reflect.TypeOf(ptr).Elem() + formStruct := reflect.ValueOf(ptr).Elem() + for i := 0; i < typ.NumField(); i++ { + typeField := typ.Field(i) + if inputFieldName := typeField.Tag.Get("form"); inputFieldName != "" { + structField := formStruct.Field(i) + if !structField.CanSet() { + continue + } + + inputValue, exists := form[inputFieldName] + if !exists { + continue + } + numElems := len(inputValue) + if structField.Kind() == reflect.Slice && numElems > 0 { + sliceOf := structField.Type().Elem().Kind() + slice := reflect.MakeSlice(structField.Type(), numElems, numElems) + for i := 0; i < numElems; i++ { + if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil { + return err + } + } + formStruct.Elem().Field(i).Set(slice) + } else { + if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil { + return err + } + } + } + } + return nil +} + +func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { + switch valueKind { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if val == "" { + val = "0" + } + intVal, err := strconv.Atoi(val) + if err != nil { + return err + } else { + structField.SetInt(int64(intVal)) + } + case reflect.Bool: + if val == "" { + val = "false" + } + boolVal, err := strconv.ParseBool(val) + if err != nil { + return err + } else { + structField.SetBool(boolVal) + } + case reflect.Float32: + if val == "" { + val = "0.0" + } + floatVal, err := strconv.ParseFloat(val, 32) + if err != nil { + return err + } else { + structField.SetFloat(floatVal) + } + case reflect.Float64: + if val == "" { + val = "0.0" + } + floatVal, err := strconv.ParseFloat(val, 64) + if err != nil { + return err + } else { + structField.SetFloat(floatVal) + } + case reflect.String: + structField.SetString(val) + } + return nil +} + +// Don't pass in pointers to bind to. Can lead to bugs. See: +// https://github.com/codegangsta/martini-contrib/issues/40 +// https://github.com/codegangsta/martini-contrib/pull/34#issuecomment-29683659 +func ensureNotPointer(obj interface{}) { + if reflect.TypeOf(obj).Kind() == reflect.Ptr { + panic("Pointers are not accepted as binding models") + } +} + +func Validate(obj interface{}) error { + + typ := reflect.TypeOf(obj) + val := reflect.ValueOf(obj) + + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + val = val.Elem() + } + + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + fieldValue := val.Field(i).Interface() + zero := reflect.Zero(field.Type).Interface() + + // Validate nested and embedded structs (if pointer, only do so if not nil) + if field.Type.Kind() == reflect.Struct || + (field.Type.Kind() == reflect.Ptr && !reflect.DeepEqual(zero, fieldValue)) { + if err := Validate(fieldValue); err != nil { + return err + } + } + + if strings.Index(field.Tag.Get("binding"), "required") > -1 { + if reflect.DeepEqual(zero, fieldValue) { + name := field.Name + if j := field.Tag.Get("json"); j != "" { + name = j + } else if f := field.Tag.Get("form"); f != "" { + name = f + } + return errors.New("Required " + name) + } + } + } + return nil +} diff --git a/deprecated.go b/deprecated.go new file mode 100644 index 0000000..b629098 --- /dev/null +++ b/deprecated.go @@ -0,0 +1,17 @@ +package gin + +import ( + "github.com/gin-gonic/gin/binding" +) + +// DEPRECATED, use Bind() instead. +// Like ParseBody() but this method also writes a 400 error if the json is not valid. +func (c *Context) EnsureBody(item interface{}) bool { + return c.Bind(item) +} + +// DEPRECATED use bindings directly +// Parses the body content as a JSON input. It decodes the json payload into the struct specified as a pointer. +func (c *Context) ParseBody(item interface{}) error { + return binding.JSON.Bind(c.Req, item) +} diff --git a/examples/example_basic.go b/examples/example_basic.go index 5040357..919580d 100644 --- a/examples/example_basic.go +++ b/examples/example_basic.go @@ -44,7 +44,8 @@ func main() { var json struct { Value string `json:"value" binding:"required"` } - if c.EnsureBody(&json) { + + if c.Bind(&json) { DB[user] = json.Value c.JSON(200, gin.H{"status": "ok"}) } diff --git a/gin.go b/gin.go index 5a17f52..f470fc8 100644 --- a/gin.go +++ b/gin.go @@ -6,6 +6,7 @@ import ( "encoding/xml" "errors" "fmt" + "github.com/gin-gonic/gin/binding" "github.com/julienschmidt/httprouter" "html/template" "log" @@ -16,6 +17,11 @@ import ( const ( AbortIndex = math.MaxInt8 / 2 + MIMEJSON = "application/json" + MIMEHTML = "text/html" + MIMEXML = "application/xml" + MIMEXML2 = "text/xml" + MIMEPlain = "text/plain" ) type ( @@ -379,25 +385,46 @@ func (c *Context) MustGet(key string) interface{} { /******** ENCOGING MANAGEMENT********/ /************************************/ -// Like ParseBody() but this method also writes a 400 error if the json is not valid. -func (c *Context) EnsureBody(item interface{}) bool { - if err := c.ParseBody(item); err != nil { +func filterFlags(content string) string { + for i, a := range content { + if a == ' ' || a == ';' { + return content[:i] + } + } + return content +} + +// This function checks the Content-Type to select a binding engine automatically, +// Depending the "Content-Type" header different bindings are used: +// "application/json" --> JSON binding +// "application/xml" --> XML binding +// else --> returns an error +// if Parses the request's body as JSON if Content-Type == "application/json" using JSON or XML as a JSON input. It decodes the json payload into the struct specified as a pointer.Like ParseBody() but this method also writes a 400 error if the json is not valid. +func (c *Context) Bind(obj interface{}) bool { + var b binding.Binding + ctype := filterFlags(c.Req.Header.Get("Content-Type")) + switch { + case c.Req.Method == "GET": + b = binding.Form + case ctype == MIMEJSON: + b = binding.JSON + case ctype == MIMEXML || ctype == MIMEXML2: + b = binding.XML + default: + c.Fail(400, errors.New("unknown content-type: "+ctype)) + return false + } + return c.BindWith(obj, b) +} + +func (c *Context) BindWith(obj interface{}, b binding.Binding) bool { + if err := b.Bind(c.Req, obj); err != nil { c.Fail(400, err) return false } return true } -// Parses the body content as a JSON input. It decodes the json payload into the struct specified as a pointer. -func (c *Context) ParseBody(item interface{}) error { - decoder := json.NewDecoder(c.Req.Body) - if err := decoder.Decode(&item); err == nil { - return Validate(c, item) - } else { - return err - } -} - // Serializes the given struct as JSON into the response body in a fast and efficient way. // It also sets the Content-Type as "application/json". func (c *Context) JSON(code int, obj interface{}) { diff --git a/validation.go b/validation.go deleted file mode 100644 index 501ee50..0000000 --- a/validation.go +++ /dev/null @@ -1,45 +0,0 @@ -package gin - -import ( - "errors" - "reflect" - "strings" -) - -func Validate(c *Context, obj interface{}) error { - - var err error - typ := reflect.TypeOf(obj) - val := reflect.ValueOf(obj) - - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - val = val.Elem() - } - - for i := 0; i < typ.NumField(); i++ { - field := typ.Field(i) - fieldValue := val.Field(i).Interface() - zero := reflect.Zero(field.Type).Interface() - - // Validate nested and embedded structs (if pointer, only do so if not nil) - if field.Type.Kind() == reflect.Struct || - (field.Type.Kind() == reflect.Ptr && !reflect.DeepEqual(zero, fieldValue)) { - err = Validate(c, fieldValue) - } - - if strings.Index(field.Tag.Get("binding"), "required") > -1 { - if reflect.DeepEqual(zero, fieldValue) { - name := field.Name - if j := field.Tag.Get("json"); j != "" { - name = j - } else if f := field.Tag.Get("form"); f != "" { - name = f - } - err = errors.New("Required " + name) - c.Error(err, "json validation") - } - } - } - return err -}