diff --git a/.travis.yml b/.travis.yml index 98b4346..3d33833 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,6 @@ language: go - +sudo: false go: - 1.3 + - 1.4 - tip diff --git a/AUTHORS.md b/AUTHORS.md index c09e263..accf8b0 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -1,12 +1,11 @@ -List of all the awesome people working to make Gin the best Web Framework in Go! +List of all the awesome people working to make Gin the best Web Framework in Go. ##gin 0.x series authors -**Lead Developer:** Manu Martinez-Almeida (@manucorporat) -**Staff:** -Javier Provecho (@javierprovecho) +**Original Developer:** Manu Martinez-Almeida (@manucorporat) +**Long-term Maintainer:** Javier Provecho (@javierprovecho) People and companies, who have contributed, in alphabetical order. @@ -31,6 +30,14 @@ People and companies, who have contributed, in alphabetical order. - Added travis CI integration +**@andredublin (Andre Dublin)** +- Fix typo in comment + + +**@bredov (Ludwig Valda Vasquez)** +- Fix html templating in debug mode + + **@bluele (Jun Kimura)** - Fixes code examples in README @@ -41,20 +48,38 @@ People and companies, who have contributed, in alphabetical order. **@dickeyxxx (Jeff Dickey)** - Typos in README +- Add example about serving static files + + +**@dutchcoders (DutchCoders)** +- ★ Fix security bug that allows client to spoof ip +- Fix typo. r.HTMLTemplates -> SetHTMLTemplate **@fmd (Fareed Dudhia)** - Fix typo. SetHTTPTemplate -> SetHTMLTemplate +**@jammie-stackhouse (Jamie Stackhouse) +- Add more shortcuts for router methods + + **@jasonrhansen** - Fix spelling and grammar errors in documentation +**@JasonSoft (Jason Lee)** +- Fix typo in comment + + **@julienschmidt (Julien Schmidt)** - gofmt the code examples +**@kelcecil (Kel Cecil)** +- Fix readme typo + + **@kyledinh (Kyle Dinh)** - Adds RunTLS() @@ -63,6 +88,10 @@ People and companies, who have contributed, in alphabetical order. - Small fixes in README +**@loongmxbt (Saint Asky)** +- Fix typo in example + + **@lucas-clemente (Lucas Clemente)** - ★ work around path.Join removing trailing slashes from routes @@ -73,10 +102,15 @@ People and companies, who have contributed, in alphabetical order. - Fixes Content-Type for json render +**@mirzac (Mirza Ceric)** +- Fix debug printing + + **@mopemope (Yutaka Matsubara)** - ★ Adds Godep support (Dependencies Manager) - Fix variadic parameter in the flexible render API - Fix Corrupted plain render +- Add Pluggable View Renderer Example **@msemenistyi (Mykyta Semenistyi)** @@ -96,6 +130,10 @@ People and companies, who have contributed, in alphabetical order. - Fix Port usage in README. +**@se77en (Damon Zhao)** +- Improve color logging + + **@silasb (Silas Baronda)** - Fixing quotes in README @@ -104,5 +142,17 @@ People and companies, who have contributed, in alphabetical order. - Fixes some texts in README II +**@slimmy (Jimmy Pettersson) +- Added messages for required bindings + + +**@smira (Andrey Smirnov)** +- Add support for ignored/unexported fields in binding + + +**@yosssi (Keiji Yoshida)** +- Fix link in README + + **@yuyabee** - Fixed README \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ec8c5a..461ea02 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ #Changelog +###Gin 0.5 (Jan 4, 2015) + +- [NEW] Content Negotiation +- [FIX] Solved security bug that allow a client to spoof ip +- [FIX] Fix unexported/ignored fields in binding + + ###Gin 0.4 (Aug 21, 2014) - [NEW] Development mode @@ -34,7 +41,7 @@ - [NEW] New API for serving static files. gin.Static() - [NEW] gin.H() can be serialized into XML - [NEW] Typed errors. Errors can be typed. Internet/external/custom. -- [NEW] Support for Godebs +- [NEW] Support for Godeps - [NEW] Travis/Godocs badges in README - [NEW] New Bind() and BindWith() methods for parsing request body. - [NEW] Add Content.Copy() diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index d963b7e..20da1fc 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -4,7 +4,7 @@ "Deps": [ { "ImportPath": "github.com/julienschmidt/httprouter", - "Rev": "7deadb6844d2c6ff1dfb812eaa439b87cdaedf20" + "Rev": "aeec11926f7a8fab580383810e1b1bbba99bdaa7" } ] } diff --git a/README.md b/README.md index 94fd949..d38ca93 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,35 @@ [![GoDoc](https://godoc.org/github.com/gin-gonic/gin?status.svg)](https://godoc.org/github.com/gin-gonic/gin) [![Build Status](https://travis-ci.org/gin-gonic/gin.svg)](https://travis-ci.org/gin-gonic/gin) -Gin is a web framework written in Golang. It features a martini-like API with much better performance, up to 40 times faster. If you need performance and good productivity, you will love Gin. -![Gin console logger](http://gin-gonic.github.io/gin/other/console.png) +Gin is a web framework written in Golang. It features a martini-like API with much better performance, up to 40 times faster thanks to [httprouter](https://github.com/julienschmidt/httprouter). If you need performance and good productivity, you will love Gin. + +![Gin console logger](https://gin-gonic.github.io/gin/other/console.png) + +``` +$ cat test.go +``` +```go +package main + +import "github.com/gin-gonic/gin" + +func main() { + router := gin.Default() + router.GET("/", func(c *gin.Context) { + c.String(200, "hello world") + }) + router.GET("/ping", func(c *gin.Context) { + c.String(200, "pong") + }) + router.POST("/submit", func(c *gin.Context) { + c.String(401, "not authorized") + }) + router.PUT("/error", func(c *gin.Context) { + c.String(500, "and error hapenned :(") + }) + router.Run(":8080") +} +``` ##Gin is new, will it be supported? @@ -24,19 +51,19 @@ Yes, Gin is an internal project of [my](https://github.com/manucorporat) upcomin - [x] Flexible rendering system - [ ] More powerful validation API - [ ] Improve documentation -- [ ] Add more cool middlewares, for example redis caching (this also helps developers to understand the framework). +- [X] Add more cool middlewares, for example redis caching (this also helps developers to understand the framework). - [x] Continuous integration ## Start using it -Obviously, you need to have Git and Go! already installed to run Gin. +Obviously, you need to have Git and Go already installed to run Gin. Run this in your terminal ``` go get github.com/gin-gonic/gin ``` -Then import it in your Go! code: +Then import it in your Go code: ``` import "github.com/gin-gonic/gin" @@ -223,7 +250,7 @@ func main() { r := gin.Default() // Example for binding JSON ({"user": "manu", "password": "123"}) - r.POST("/login", func(c *gin.Context) { + r.POST("/loginJSON", func(c *gin.Context) { var json LoginJSON c.Bind(&json) // This will infer what binder to use depending on the content-type header. @@ -234,8 +261,8 @@ func main() { } }) - // Example for binding a HTLM form (user=manu&password=123) - r.POST("/login", func(c *gin.Context) { + // Example for binding a HTML form (user=manu&password=123) + r.POST("/loginHTML", func(c *gin.Context) { var form LoginForm c.BindWith(&form, binding.Form) // You can also specify which binder to use. We support binding.Form, binding.JSON and binding.XML. @@ -257,7 +284,7 @@ func main() { func main() { r := gin.Default() - // gin.H is a shortcup for map[string]interface{} + // gin.H is a shortcut for map[string]interface{} r.GET("/someJSON", func(c *gin.Context) { c.JSON(200, gin.H{"message": "hey", "status": 200}) }) @@ -286,6 +313,21 @@ func main() { } ``` +####Serving static files + +Use Engine.ServeFiles(path string, root http.FileSystem): + +```go +func main() { + r := gin.Default() + r.Static("/assets", "./assets") + + // Listen and server on 0.0.0.0:8080 + r.Run(":8080") +} +``` + +Note: this will use `httpNotFound` instead of the Router's `NotFound` handler. ####Serving static files @@ -310,7 +352,7 @@ Using LoadHTMLTemplates() ```go func main() { r := gin.Default() - r.LoadHTMLTemplates("templates/*") + r.LoadHTMLGlob("templates/*") r.GET("/index", func(c *gin.Context) { obj := gin.H{"title": "Main website"} c.HTML(200, "index.tmpl", obj) @@ -320,6 +362,11 @@ func main() { r.Run(":8080") } ``` +```html +

+ {{ .title }} +

+``` You can also use your own html template render @@ -329,7 +376,7 @@ import "html/template" func main() { r := gin.Default() html := template.Must(template.ParseFiles("file1", "file2")) - r.HTMLTemplates = html + r.SetHTMLTemplate(html) // Listen and server on 0.0.0.0:8080 r.Run(":8080") @@ -413,7 +460,7 @@ func main() { // hit "localhost:8080/admin/secrets authorized.GET("/secrets", func(c *gin.Context) { // get user, it was setted by the BasicAuth middleware - user := c.Get(gin.AuthUserKey).(string) + user := c.MustGet(gin.AuthUserKey).(string) if secret, ok := secrets[user]; ok { c.JSON(200, gin.H{"user": user, "secret": secret}) } else { diff --git a/auth.go b/auth.go index 248f97d..7602d72 100644 --- a/auth.go +++ b/auth.go @@ -16,70 +16,29 @@ const ( ) type ( - BasicAuthPair struct { - Code string - User string - } Accounts map[string]string - Pairs []BasicAuthPair + authPair struct { + Value string + User string + } + authPairs []authPair ) -func (a Pairs) Len() int { return len(a) } -func (a Pairs) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a Pairs) Less(i, j int) bool { return a[i].Code < a[j].Code } - -func processCredentials(accounts Accounts) (Pairs, error) { - if len(accounts) == 0 { - return nil, errors.New("Empty list of authorized credentials.") - } - pairs := make(Pairs, 0, len(accounts)) - for user, password := range accounts { - if len(user) == 0 || len(password) == 0 { - return nil, errors.New("User or password is empty") - } - base := user + ":" + password - code := "Basic " + base64.StdEncoding.EncodeToString([]byte(base)) - pairs = append(pairs, BasicAuthPair{code, user}) - } - // We have to sort the credentials in order to use bsearch later. - sort.Sort(pairs) - return pairs, nil -} - -func secureCompare(given, actual string) bool { - if subtle.ConstantTimeEq(int32(len(given)), int32(len(actual))) == 1 { - return subtle.ConstantTimeCompare([]byte(given), []byte(actual)) == 1 - } else { - /* Securely compare actual to itself to keep constant time, but always return false */ - return subtle.ConstantTimeCompare([]byte(actual), []byte(actual)) == 1 && false - } -} - -func searchCredential(pairs Pairs, auth string) string { - if len(auth) == 0 { - return "" - } - // Search user in the slice of allowed credentials - r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Code >= auth }) - if r < len(pairs) && secureCompare(pairs[r].Code, auth) { - return pairs[r].User - } else { - return "" - } -} +func (a authPairs) Len() int { return len(a) } +func (a authPairs) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a authPairs) Less(i, j int) bool { return a[i].Value < a[j].Value } // Implements a basic Basic HTTP Authorization. It takes as argument a map[string]string where // the key is the user name and the value is the password. func BasicAuth(accounts Accounts) HandlerFunc { - - pairs, err := processCredentials(accounts) + pairs, err := processAccounts(accounts) if err != nil { panic(err) } return func(c *Context) { // Search user in the slice of allowed credentials - user := searchCredential(pairs, c.Request.Header.Get("Authorization")) - if len(user) == 0 { + user, ok := searchCredential(pairs, c.Request.Header.Get("Authorization")) + if !ok { // Credentials doesn't match, we return 401 Unauthorized and abort request. c.Writer.Header().Set("WWW-Authenticate", "Basic realm=\"Authorization Required\"") c.Fail(401, errors.New("Unauthorized")) @@ -90,3 +49,46 @@ func BasicAuth(accounts Accounts) HandlerFunc { } } } + +func processAccounts(accounts Accounts) (authPairs, error) { + if len(accounts) == 0 { + return nil, errors.New("Empty list of authorized credentials") + } + pairs := make(authPairs, 0, len(accounts)) + for user, password := range accounts { + if len(user) == 0 { + return nil, errors.New("User can not be empty") + } + base := user + ":" + password + value := "Basic " + base64.StdEncoding.EncodeToString([]byte(base)) + pairs = append(pairs, authPair{ + Value: value, + User: user, + }) + } + // We have to sort the credentials in order to use bsearch later. + sort.Sort(pairs) + return pairs, nil +} + +func searchCredential(pairs authPairs, auth string) (string, bool) { + if len(auth) == 0 { + return "", false + } + // Search user in the slice of allowed credentials + r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Value >= auth }) + if r < len(pairs) && secureCompare(pairs[r].Value, auth) { + return pairs[r].User, true + } else { + return "", false + } +} + +func secureCompare(given, actual string) bool { + if subtle.ConstantTimeEq(int32(len(given)), int32(len(actual))) == 1 { + return subtle.ConstantTimeCompare([]byte(given), []byte(actual)) == 1 + } else { + /* Securely compare actual to itself to keep constant time, but always return false */ + return subtle.ConstantTimeCompare([]byte(actual), []byte(actual)) == 1 && false + } +} diff --git a/binding/binding.go b/binding/binding.go index 81ac3fa..92460a5 100644 --- a/binding/binding.go +++ b/binding/binding.go @@ -87,7 +87,7 @@ func mapForm(ptr interface{}, form map[string][]string) error { return err } } - formStruct.Elem().Field(i).Set(slice) + formStruct.Field(i).Set(slice) } else { if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil { return err @@ -169,8 +169,8 @@ func Validate(obj interface{}, parents ...string) error { for i := 0; i < typ.NumField(); i++ { field := typ.Field(i) - // Allow ignored fields in the struct - if field.Tag.Get("form") == "-" { + // Allow ignored and unexported fields in the struct + if field.Tag.Get("form") == "-" || field.PkgPath != "" { continue } diff --git a/context.go b/context.go index be25db8..2f0e2d8 100644 --- a/context.go +++ b/context.go @@ -12,7 +12,9 @@ import ( "github.com/gin-gonic/gin/render" "github.com/julienschmidt/httprouter" "log" + "net" "net/http" + "strings" ) const ( @@ -67,27 +69,29 @@ type Context struct { Engine *Engine handlers []HandlerFunc index int8 + accepted []string } /************************************/ -/********** ROUTES GROUPING *********/ +/********** CONTEXT CREATION ********/ /************************************/ func (engine *Engine) createContext(w http.ResponseWriter, req *http.Request, params httprouter.Params, handlers []HandlerFunc) *Context { - c := engine.cache.Get().(*Context) + c := engine.pool.Get().(*Context) c.writermem.reset(w) c.Request = req c.Params = params c.handlers = handlers c.Keys = nil c.index = -1 + c.accepted = nil c.Errors = c.Errors[0:0] return c } -/************************************/ -/****** FLOW AND ERROR MANAGEMENT****/ -/************************************/ +func (engine *Engine) reuseContext(c *Context) { + engine.pool.Put(c) +} func (c *Context) Copy() *Context { var cp Context = *c @@ -96,6 +100,10 @@ func (c *Context) Copy() *Context { return &cp } +/************************************/ +/*************** FLOW ***************/ +/************************************/ + // Next should be used only in the middlewares. // It executes the pending handlers in the chain inside the calling handler. // See example in github. @@ -107,25 +115,31 @@ func (c *Context) Next() { } } -// Forces the system to do not continue calling the pending handlers. -// For example, the first handler checks if the request is authorized. If it's not, context.Abort(401) should be called. -// The rest of pending handlers would never be called for that request. -func (c *Context) Abort(code int) { - if code >= 0 { - c.Writer.WriteHeader(code) - } +// Forces the system to do not continue calling the pending handlers in the chain. +func (c *Context) Abort() { c.index = AbortIndex } +// Same than AbortWithStatus() but also writes the specified response status code. +// For example, the first handler checks if the request is authorized. If it's not, context.AbortWithStatus(401) should be called. +func (c *Context) AbortWithStatus(code int) { + c.Writer.WriteHeader(code) + c.Abort() +} + +/************************************/ +/********* ERROR MANAGEMENT *********/ +/************************************/ + // Fail is the same as Abort plus an error message. // Calling `context.Fail(500, err)` is equivalent to: // ``` // context.Error("Operation aborted", err) -// context.Abort(500) +// context.AbortWithStatus(500) // ``` func (c *Context) Fail(code int, err error) { c.Error(err, "Operation aborted") - c.Abort(code) + c.AbortWithStatus(code) } func (c *Context) ErrorTyped(err error, typ uint32, meta interface{}) { @@ -144,9 +158,9 @@ func (c *Context) Error(err error, meta interface{}) { } func (c *Context) LastError() error { - s := len(c.Errors) - if s > 0 { - return errors.New(c.Errors[s-1].Err) + nuErrors := len(c.Errors) + if nuErrors > 0 { + return errors.New(c.Errors[nuErrors-1].Err) } else { return nil } @@ -168,9 +182,9 @@ func (c *Context) Set(key string, item interface{}) { // Get returns the value for the given key or an error if the key does not exist. func (c *Context) Get(key string) (interface{}, error) { if c.Keys != nil { - item, ok := c.Keys[key] + value, ok := c.Keys[key] if ok { - return item, nil + return value, nil } } return nil, errors.New("Key does not exist.") @@ -180,13 +194,93 @@ func (c *Context) Get(key string) (interface{}, error) { func (c *Context) MustGet(key string) interface{} { value, err := c.Get(key) if err != nil || value == nil { - log.Panicf("Key %s doesn't exist", key) + log.Panicf("Key %s doesn't exist", value) } return value } +func ipInMasks(ip net.IP, masks []interface{}) bool { + for _, proxy := range masks { + var mask *net.IPNet + var err error + + switch t := proxy.(type) { + case string: + if _, mask, err = net.ParseCIDR(t); err != nil { + panic(err) + } + case net.IP: + mask = &net.IPNet{IP: t, Mask: net.CIDRMask(len(t)*8, len(t)*8)} + case net.IPNet: + mask = &t + } + + if mask.Contains(ip) { + return true + } + } + + return false +} + +// the ForwardedFor middleware unwraps the X-Forwarded-For headers, be careful to only use this +// middleware if you've got servers in front of this server. The list with (known) proxies and +// local ips are being filtered out of the forwarded for list, giving the last not local ip being +// the real client ip. +func ForwardedFor(proxies ...interface{}) HandlerFunc { + if len(proxies) == 0 { + // default to local ips + var reservedLocalIps = []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"} + + proxies = make([]interface{}, len(reservedLocalIps)) + + for i, v := range reservedLocalIps { + proxies[i] = v + } + } + + return func(c *Context) { + // the X-Forwarded-For header contains an array with left most the client ip, then + // comma separated, all proxies the request passed. The last proxy appears + // as the remote address of the request. Returning the client + // ip to comply with default RemoteAddr response. + + // check if remoteaddr is local ip or in list of defined proxies + remoteIp := net.ParseIP(strings.Split(c.Request.RemoteAddr, ":")[0]) + + if !ipInMasks(remoteIp, proxies) { + return + } + + if forwardedFor := c.Request.Header.Get("X-Forwarded-For"); forwardedFor != "" { + parts := strings.Split(forwardedFor, ",") + + for i := len(parts) - 1; i >= 0; i-- { + part := parts[i] + + ip := net.ParseIP(strings.TrimSpace(part)) + + if ipInMasks(ip, proxies) { + continue + } + + // returning remote addr conform the original remote addr format + c.Request.RemoteAddr = ip.String() + ":0" + + // remove forwarded for address + c.Request.Header.Set("X-Forwarded-For", "") + return + } + } + } +} + +func (c *Context) ClientIP() string { + return c.Request.RemoteAddr +} + /************************************/ -/******** ENCODING MANAGEMENT********/ +/********* PARSING REQUEST **********/ /************************************/ // This function checks the Content-Type to select a binding engine automatically, @@ -220,10 +314,14 @@ func (c *Context) BindWith(obj interface{}, b binding.Binding) bool { return true } +/************************************/ +/******** RESPONSE RENDERING ********/ +/************************************/ + func (c *Context) Render(code int, render render.Render, obj ...interface{}) { if err := render.Render(c.Writer, code, obj...); err != nil { c.ErrorTyped(err, ErrorTypeInternal, obj) - c.Abort(500) + c.AbortWithStatus(500) } } @@ -265,9 +363,7 @@ func (c *Context) Data(code int, contentType string, data []byte) { if len(contentType) > 0 { c.Writer.Header().Set("Content-Type", contentType) } - if code >= 0 { - c.Writer.WriteHeader(code) - } + c.Writer.WriteHeader(code) c.Writer.Write(data) } @@ -275,3 +371,64 @@ func (c *Context) Data(code int, contentType string, data []byte) { func (c *Context) File(filepath string) { http.ServeFile(c.Writer, c.Request, filepath) } + +/************************************/ +/******** CONTENT NEGOTIATION *******/ +/************************************/ + +type Negotiate struct { + Offered []string + HTMLPath string + HTMLData interface{} + JSONData interface{} + XMLData interface{} + Data interface{} +} + +func (c *Context) Negotiate(code int, config Negotiate) { + switch c.NegotiateFormat(config.Offered...) { + case MIMEJSON: + data := chooseData(config.JSONData, config.Data) + c.JSON(code, data) + + case MIMEHTML: + data := chooseData(config.HTMLData, config.Data) + if len(config.HTMLPath) == 0 { + panic("negotiate config is wrong. html path is needed") + } + c.HTML(code, config.HTMLPath, data) + + case MIMEXML: + data := chooseData(config.XMLData, config.Data) + c.XML(code, data) + + default: + c.Fail(http.StatusNotAcceptable, errors.New("the accepted formats are not offered by the server")) + } +} + +func (c *Context) NegotiateFormat(offered ...string) string { + if len(offered) == 0 { + panic("you must provide at least one offer") + } + if c.accepted == nil { + c.accepted = parseAccept(c.Request.Header.Get("Accept")) + } + if len(c.accepted) == 0 { + return offered[0] + + } else { + for _, accepted := range c.accepted { + for _, offert := range offered { + if accepted == offert { + return offert + } + } + } + return "" + } +} + +func (c *Context) SetAccepted(formats ...string) { + c.accepted = formats +} diff --git a/context_test.go b/context_test.go index 6df824c..851a56c 100644 --- a/context_test.go +++ b/context_test.go @@ -232,13 +232,13 @@ func TestBadAbortHandlersChain(t *testing.T) { c.Next() stepsPassed += 1 // after check and abort - c.Abort(409) + c.AbortWithStatus(409) }) r.Use(func(c *Context) { stepsPassed += 1 c.Next() stepsPassed += 1 - c.Abort(403) + c.AbortWithStatus(403) }) // RUN @@ -260,7 +260,7 @@ func TestAbortHandlersChain(t *testing.T) { r := New() r.Use(func(context *Context) { stepsPassed += 1 - context.Abort(409) + context.AbortWithStatus(409) }) r.Use(func(context *Context) { stepsPassed += 1 @@ -440,3 +440,44 @@ func TestBindingJSONMalformed(t *testing.T) { t.Errorf("Content-Type should not be application/json, was %s", w.HeaderMap.Get("Content-Type")) } } + +func TestClientIP(t *testing.T) { + r := New() + + var clientIP string = "" + r.GET("/", func(c *Context) { + clientIP = c.ClientIP() + }) + + body := bytes.NewBuffer([]byte("")) + req, _ := http.NewRequest("GET", "/", body) + req.RemoteAddr = "clientip:1234" + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if clientIP != "clientip:1234" { + t.Errorf("ClientIP should not be %s, but clientip:1234", clientIP) + } +} + +func TestClientIPWithXForwardedForWithProxy(t *testing.T) { + r := New() + r.Use(ForwardedFor()) + + var clientIP string = "" + r.GET("/", func(c *Context) { + clientIP = c.ClientIP() + }) + + body := bytes.NewBuffer([]byte("")) + req, _ := http.NewRequest("GET", "/", body) + req.RemoteAddr = "172.16.8.3:1234" + req.Header.Set("X-Real-Ip", "realip") + req.Header.Set("X-Forwarded-For", "1.2.3.4, 10.10.0.4, 192.168.0.43, 172.16.8.4") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if clientIP != "1.2.3.4:0" { + t.Errorf("ClientIP should not be %s, but 1.2.3.4:0", clientIP) + } +} diff --git a/deprecated.go b/deprecated.go index eb248dd..7188153 100644 --- a/deprecated.go +++ b/deprecated.go @@ -41,7 +41,7 @@ func (engine *Engine) LoadHTMLTemplates(pattern string) { engine.LoadHTMLGlob(pattern) } -// DEPRECATED. Use NotFound() instead +// DEPRECATED. Use NoRoute() instead func (engine *Engine) NotFound404(handlers ...HandlerFunc) { engine.NoRoute(handlers...) } diff --git a/examples/pluggable_renderer/example_pongo2.go b/examples/pluggable_renderer/example_pongo2.go new file mode 100644 index 0000000..9f745e1 --- /dev/null +++ b/examples/pluggable_renderer/example_pongo2.go @@ -0,0 +1,58 @@ +package main + +import ( + "github.com/flosch/pongo2" + "github.com/gin-gonic/gin" + "net/http" +) + +type pongoRender struct { + cache map[string]*pongo2.Template +} + +func newPongoRender() *pongoRender { + return &pongoRender{map[string]*pongo2.Template{}} +} + +func writeHeader(w http.ResponseWriter, code int, contentType string) { + if code >= 0 { + w.Header().Set("Content-Type", contentType) + w.WriteHeader(code) + } +} + +func (p *pongoRender) Render(w http.ResponseWriter, code int, data ...interface{}) error { + file := data[0].(string) + ctx := data[1].(pongo2.Context) + var t *pongo2.Template + + if tmpl, ok := p.cache[file]; ok { + t = tmpl + } else { + tmpl, err := pongo2.FromFile(file) + if err != nil { + return err + } + p.cache[file] = tmpl + t = tmpl + } + writeHeader(w, code, "text/html") + return t.ExecuteWriter(ctx, w) +} + +func main() { + r := gin.Default() + r.HTMLRender = newPongoRender() + + r.GET("/index", func(c *gin.Context) { + name := c.Request.FormValue("name") + ctx := pongo2.Context{ + "title": "Gin meets pongo2 !", + "name": name, + } + c.HTML(200, "index.html", ctx) + }) + + // Listen and server on 0.0.0.0:8080 + r.Run(":8080") +} diff --git a/examples/pluggable_renderer/index.html b/examples/pluggable_renderer/index.html new file mode 100644 index 0000000..8b293ed --- /dev/null +++ b/examples/pluggable_renderer/index.html @@ -0,0 +1,12 @@ + + + + + {{ title }} + + + + + Hello {{ name }} ! + + diff --git a/gin.go b/gin.go index 45b3807..37e6e4d 100644 --- a/gin.go +++ b/gin.go @@ -5,13 +5,11 @@ package gin import ( - "fmt" "github.com/gin-gonic/gin/render" "github.com/julienschmidt/httprouter" "html/template" "math" "net/http" - "path" "sync" ) @@ -28,49 +26,31 @@ const ( type ( HandlerFunc func(*Context) - // Used internally to configure router, a RouterGroup is associated with a prefix - // and an array of handlers (middlewares) - RouterGroup struct { - Handlers []HandlerFunc - prefix string - parent *RouterGroup - engine *Engine - } - // Represents the web framework, it wraps the blazing fast httprouter multiplexer and a list of global middlewares. Engine struct { *RouterGroup - HTMLRender render.Render - cache sync.Pool - finalNoRoute []HandlerFunc - noRoute []HandlerFunc - router *httprouter.Router + HTMLRender render.Render + Default404Body []byte + pool sync.Pool + allNoRoute []HandlerFunc + noRoute []HandlerFunc + router *httprouter.Router } ) -func (engine *Engine) handle404(w http.ResponseWriter, req *http.Request) { - c := engine.createContext(w, req, nil, engine.finalNoRoute) - // set 404 by default, useful for logging - c.Writer.WriteHeader(404) - c.Next() - if !c.Writer.Written() { - if c.Writer.Status() == 404 { - c.Data(-1, MIMEPlain, []byte("404 page not found")) - } else { - c.Writer.WriteHeaderNow() - } - } - engine.cache.Put(c) -} - // Returns a new blank Engine instance without any middleware attached. // The most basic configuration func New() *Engine { engine := &Engine{} - engine.RouterGroup = &RouterGroup{nil, "/", nil, engine} + engine.RouterGroup = &RouterGroup{ + Handlers: nil, + absolutePath: "/", + engine: engine, + } engine.router = httprouter.New() + engine.Default404Body = []byte("404 page not found") engine.router.NotFound = engine.handle404 - engine.cache.New = func() interface{} { + engine.pool.New = func() interface{} { c := &Context{Engine: engine} c.Writer = &c.writermem return c @@ -86,7 +66,8 @@ func Default() *Engine { } func (engine *Engine) LoadHTMLGlob(pattern string) { - if gin_mode == debugCode { + if IsDebugging() { + render.HTMLDebug.AddGlob(pattern) engine.HTMLRender = render.HTMLDebug } else { templ := template.Must(template.ParseGlob(pattern)) @@ -95,7 +76,8 @@ func (engine *Engine) LoadHTMLGlob(pattern string) { } func (engine *Engine) LoadHTMLFiles(files ...string) { - if gin_mode == debugCode { + if IsDebugging() { + render.HTMLDebug.AddFiles(files...) engine.HTMLRender = render.HTMLDebug } else { templ := template.Must(template.ParseFiles(files...)) @@ -112,151 +94,50 @@ func (engine *Engine) SetHTMLTemplate(templ *template.Template) { // Adds handlers for NoRoute. It return a 404 code by default. func (engine *Engine) NoRoute(handlers ...HandlerFunc) { engine.noRoute = handlers - engine.finalNoRoute = engine.combineHandlers(engine.noRoute) + engine.rebuild404Handlers() } func (engine *Engine) Use(middlewares ...HandlerFunc) { engine.RouterGroup.Use(middlewares...) - engine.finalNoRoute = engine.combineHandlers(engine.noRoute) + engine.rebuild404Handlers() +} + +func (engine *Engine) rebuild404Handlers() { + engine.allNoRoute = engine.combineHandlers(engine.noRoute) +} + +func (engine *Engine) handle404(w http.ResponseWriter, req *http.Request) { + c := engine.createContext(w, req, nil, engine.allNoRoute) + // set 404 by default, useful for logging + c.Writer.WriteHeader(404) + c.Next() + if !c.Writer.Written() { + if c.Writer.Status() == 404 { + c.Data(-1, MIMEPlain, engine.Default404Body) + } else { + c.Writer.WriteHeaderNow() + } + } + engine.reuseContext(c) } // ServeHTTP makes the router implement the http.Handler interface. -func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { - engine.router.ServeHTTP(w, req) +func (engine *Engine) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + engine.router.ServeHTTP(writer, request) } -func (engine *Engine) Run(addr string) { - if gin_mode == debugCode { - fmt.Println("[GIN-debug] Listening and serving HTTP on " + addr) - } +func (engine *Engine) Run(addr string) error { + debugPrint("Listening and serving HTTP on %s", addr) if err := http.ListenAndServe(addr, engine); err != nil { - panic(err) + return err } + return nil } -func (engine *Engine) RunTLS(addr string, cert string, key string) { - if gin_mode == debugCode { - fmt.Println("[GIN-debug] Listening and serving HTTPS on " + addr) - } +func (engine *Engine) RunTLS(addr string, cert string, key string) error { + debugPrint("Listening and serving HTTPS on %s", addr) if err := http.ListenAndServeTLS(addr, cert, key, engine); err != nil { - panic(err) + return err } -} - -/************************************/ -/********** ROUTES GROUPING *********/ -/************************************/ - -// Adds middlewares to the group, see example code in github. -func (group *RouterGroup) Use(middlewares ...HandlerFunc) { - group.Handlers = append(group.Handlers, middlewares...) -} - -// Creates a new router group. You should add all the routes that have common middlwares or the same path prefix. -// For example, all the routes that use a common middlware for authorization could be grouped. -func (group *RouterGroup) Group(component string, handlers ...HandlerFunc) *RouterGroup { - prefix := group.pathFor(component) - - return &RouterGroup{ - Handlers: group.combineHandlers(handlers), - parent: group, - prefix: prefix, - engine: group.engine, - } -} - -func (group *RouterGroup) pathFor(p string) string { - joined := path.Join(group.prefix, p) - // Append a '/' if the last component had one, but only if it's not there already - if len(p) > 0 && p[len(p)-1] == '/' && joined[len(joined)-1] != '/' { - return joined + "/" - } - return joined -} - -// Handle registers a new request handle and middlewares with the given path and method. -// The last handler should be the real handler, the other ones should be middlewares that can and should be shared among different routes. -// See the example code in github. -// -// For GET, POST, PUT, PATCH and DELETE requests the respective shortcut -// functions can be used. -// -// This function is intended for bulk loading and to allow the usage of less -// frequently used, non-standardized or custom methods (e.g. for internal -// communication with a proxy). -func (group *RouterGroup) Handle(method, p string, handlers []HandlerFunc) { - p = group.pathFor(p) - handlers = group.combineHandlers(handlers) - if gin_mode == debugCode { - nuHandlers := len(handlers) - name := funcName(handlers[nuHandlers-1]) - fmt.Printf("[GIN-debug] %-5s %-25s --> %s (%d handlers)\n", method, p, name, nuHandlers) - } - group.engine.router.Handle(method, p, func(w http.ResponseWriter, req *http.Request, params httprouter.Params) { - c := group.engine.createContext(w, req, params, handlers) - c.Next() - c.Writer.WriteHeaderNow() - group.engine.cache.Put(c) - }) -} - -// POST is a shortcut for router.Handle("POST", path, handle) -func (group *RouterGroup) POST(path string, handlers ...HandlerFunc) { - group.Handle("POST", path, handlers) -} - -// GET is a shortcut for router.Handle("GET", path, handle) -func (group *RouterGroup) GET(path string, handlers ...HandlerFunc) { - group.Handle("GET", path, handlers) -} - -// DELETE is a shortcut for router.Handle("DELETE", path, handle) -func (group *RouterGroup) DELETE(path string, handlers ...HandlerFunc) { - group.Handle("DELETE", path, handlers) -} - -// PATCH is a shortcut for router.Handle("PATCH", path, handle) -func (group *RouterGroup) PATCH(path string, handlers ...HandlerFunc) { - group.Handle("PATCH", path, handlers) -} - -// PUT is a shortcut for router.Handle("PUT", path, handle) -func (group *RouterGroup) PUT(path string, handlers ...HandlerFunc) { - group.Handle("PUT", path, handlers) -} - -// OPTIONS is a shortcut for router.Handle("OPTIONS", path, handle) -func (group *RouterGroup) OPTIONS(path string, handlers ...HandlerFunc) { - group.Handle("OPTIONS", path, handlers) -} - -// HEAD is a shortcut for router.Handle("HEAD", path, handle) -func (group *RouterGroup) HEAD(path string, handlers ...HandlerFunc) { - group.Handle("HEAD", path, handlers) -} - -// Static serves files from the given file system root. -// Internally a http.FileServer is used, therefore http.NotFound is used instead -// of the Router's NotFound handler. -// To use the operating system's file system implementation, -// use : -// router.Static("/static", "/var/www") -func (group *RouterGroup) Static(p, root string) { - prefix := group.pathFor(p) - p = path.Join(p, "/*filepath") - fileServer := http.StripPrefix(prefix, http.FileServer(http.Dir(root))) - group.GET(p, func(c *Context) { - fileServer.ServeHTTP(c.Writer, c.Request) - }) - group.HEAD(p, func(c *Context) { - fileServer.ServeHTTP(c.Writer, c.Request) - }) -} - -func (group *RouterGroup) combineHandlers(handlers []HandlerFunc) []HandlerFunc { - s := len(group.Handlers) + len(handlers) - h := make([]HandlerFunc, 0, s) - h = append(h, group.Handlers...) - h = append(h, handlers...) - return h + return nil } diff --git a/gin_test.go b/gin_test.go index 3397943..ba74c15 100644 --- a/gin_test.go +++ b/gin_test.go @@ -108,9 +108,8 @@ func testRouteNotOK2(method string, t *testing.T) { if passed == true { t.Errorf(method + " route handler was invoked, when it should not") } - if w.Code != http.StatusNotFound { - // If this fails, it's because httprouter needs to be updated to at least f78f58a0db - t.Errorf("Status code should be %v, was %d. Location: %s", http.StatusNotFound, w.Code, w.HeaderMap.Get("Location")) + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("Status code should be %v, was %d. Location: %s", http.StatusMethodNotAllowed, w.Code, w.HeaderMap.Get("Location")) } } @@ -146,7 +145,7 @@ func TestHandleStaticFile(t *testing.T) { // TEST if w.Code != 200 { - t.Errorf("Response code should be Ok, was: %s", w.Code) + t.Errorf("Response code should be 200, was: %d", w.Code) } if w.Body.String() != "Gin Web Framework" { t.Errorf("Response should be test, was: %s", w.Body.String()) @@ -168,7 +167,7 @@ func TestHandleStaticDir(t *testing.T) { // TEST bodyAsString := w.Body.String() if w.Code != 200 { - t.Errorf("Response code should be Ok, was: %s", w.Code) + t.Errorf("Response code should be 200, was: %d", w.Code) } if len(bodyAsString) == 0 { t.Errorf("Got empty body instead of file tree") diff --git a/logger.go b/logger.go index 56602c0..5054f6e 100644 --- a/logger.go +++ b/logger.go @@ -10,6 +10,17 @@ import ( "time" ) +var ( + green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109}) + white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109}) + yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109}) + red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109}) + blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109}) + magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109}) + cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109}) + reset = string([]byte{27, 91, 48, 109}) +) + func ErrorLogger() HandlerFunc { return ErrorLoggerT(ErrorTypeAll) } @@ -26,14 +37,6 @@ func ErrorLoggerT(typ uint32) HandlerFunc { } } -var ( - green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109}) - white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109}) - yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109}) - red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109}) - reset = string([]byte{27, 91, 48, 109}) -) - func Logger() HandlerFunc { stdlogger := log.New(os.Stdout, "", 0) //errlogger := log.New(os.Stderr, "", 0) @@ -45,38 +48,58 @@ func Logger() HandlerFunc { // Process request c.Next() - // save the IP of the requester - requester := c.Request.Header.Get("X-Real-IP") - // if the requester-header is empty, check the forwarded-header - if len(requester) == 0 { - requester = c.Request.Header.Get("X-Forwarded-For") - } - // if the requester is still empty, use the hard-coded address from the socket - if len(requester) == 0 { - requester = c.Request.RemoteAddr - } - - var color string - code := c.Writer.Status() - switch { - case code >= 200 && code <= 299: - color = green - case code >= 300 && code <= 399: - color = white - case code >= 400 && code <= 499: - color = yellow - default: - color = red - } + // Stop timer end := time.Now() latency := end.Sub(start) - stdlogger.Printf("[GIN] %v |%s %3d %s| %12v | %s %4s %s\n%s", + + clientIP := c.ClientIP() + method := c.Request.Method + statusCode := c.Writer.Status() + statusColor := colorForStatus(statusCode) + methodColor := colorForMethod(method) + + stdlogger.Printf("[GIN] %v |%s %3d %s| %12v | %s |%s %s %-7s %s\n%s", end.Format("2006/01/02 - 15:04:05"), - color, code, reset, + statusColor, statusCode, reset, latency, - requester, - c.Request.Method, c.Request.URL.Path, + clientIP, + methodColor, reset, method, + c.Request.URL.Path, c.Errors.String(), ) } } + +func colorForStatus(code int) string { + switch { + case code >= 200 && code <= 299: + return green + case code >= 300 && code <= 399: + return white + case code >= 400 && code <= 499: + return yellow + default: + return red + } +} + +func colorForMethod(method string) string { + switch { + case method == "GET": + return blue + case method == "POST": + return cyan + case method == "PUT": + return yellow + case method == "DELETE": + return red + case method == "PATCH": + return green + case method == "HEAD": + return magenta + case method == "OPTIONS": + return white + default: + return reset + } +} diff --git a/mode.go b/mode.go index 166c09c..0495b83 100644 --- a/mode.go +++ b/mode.go @@ -5,6 +5,7 @@ package gin import ( + "fmt" "os" ) @@ -22,6 +23,16 @@ const ( ) var gin_mode int = debugCode +var mode_name string = DebugMode + +func init() { + value := os.Getenv(GIN_MODE) + if len(value) == 0 { + SetMode(DebugMode) + } else { + SetMode(value) + } +} func SetMode(value string) { switch value { @@ -32,15 +43,21 @@ func SetMode(value string) { case TestMode: gin_mode = testCode default: - panic("gin mode unknown, the allowed modes are: " + DebugMode + " and " + ReleaseMode) + panic("gin mode unknown: " + value) } + mode_name = value } -func init() { - value := os.Getenv(GIN_MODE) - if len(value) == 0 { - SetMode(DebugMode) - } else { - SetMode(value) +func Mode() string { + return mode_name +} + +func IsDebugging() bool { + return gin_mode == debugCode +} + +func debugPrint(format string, values ...interface{}) { + if IsDebugging() { + fmt.Printf("[GIN-debug] "+format, values...) } } diff --git a/recovery_test.go b/recovery_test.go index 756c7c2..f9047e2 100644 --- a/recovery_test.go +++ b/recovery_test.go @@ -39,7 +39,7 @@ func TestPanicWithAbort(t *testing.T) { r := New() r.Use(Recovery()) r.GET("/recovery", func(c *Context) { - c.Abort(400) + c.AbortWithStatus(400) panic("Oupps, Houston, we have a problem") }) diff --git a/render/render.go b/render/render.go index 699b4e9..a81fffe 100644 --- a/render/render.go +++ b/render/render.go @@ -30,7 +30,10 @@ type ( redirectRender struct{} // Redirects - htmlDebugRender struct{} + htmlDebugRender struct { + files []string + globs []string + } // form binding HTMLRender struct { @@ -43,7 +46,7 @@ var ( XML = xmlRender{} Plain = plainRender{} Redirect = redirectRender{} - HTMLDebug = htmlDebugRender{} + HTMLDebug = &htmlDebugRender{} ) func writeHeader(w http.ResponseWriter, code int, contentType string) { @@ -82,14 +85,33 @@ func (_ plainRender) Render(w http.ResponseWriter, code int, data ...interface{} return err } -func (_ htmlDebugRender) Render(w http.ResponseWriter, code int, data ...interface{}) error { +func (r *htmlDebugRender) AddGlob(pattern string) { + r.globs = append(r.globs, pattern) +} + +func (r *htmlDebugRender) AddFiles(files ...string) { + r.files = append(r.files, files...) +} + +func (r *htmlDebugRender) Render(w http.ResponseWriter, code int, data ...interface{}) error { writeHeader(w, code, "text/html") file := data[0].(string) obj := data[1] - t, err := template.ParseFiles(file) - if err != nil { - return err + + t := template.New("") + + if len(r.files) > 0 { + if _, err := t.ParseFiles(r.files...); err != nil { + return err + } } + + for _, glob := range r.globs { + if _, err := t.ParseGlob(glob); err != nil { + return err + } + } + return t.ExecuteTemplate(w, file, obj) } diff --git a/response_writer.go b/response_writer.go index 3ce8414..9899395 100644 --- a/response_writer.go +++ b/response_writer.go @@ -12,6 +12,10 @@ import ( "net/http" ) +const ( + NoWritten = -1 +) + type ( ResponseWriter interface { http.ResponseWriter @@ -20,50 +24,57 @@ type ( http.CloseNotifier Status() int + Size() int Written() bool WriteHeaderNow() } responseWriter struct { http.ResponseWriter - status int - written bool + status int + size int } ) func (w *responseWriter) reset(writer http.ResponseWriter) { w.ResponseWriter = writer w.status = 200 - w.written = false + w.size = NoWritten } func (w *responseWriter) WriteHeader(code int) { if code > 0 { w.status = code - if w.written { + if w.Written() { log.Println("[GIN] WARNING. Headers were already written!") } } } func (w *responseWriter) WriteHeaderNow() { - if !w.written { - w.written = true + if !w.Written() { + w.size = 0 w.ResponseWriter.WriteHeader(w.status) } } func (w *responseWriter) Write(data []byte) (n int, err error) { w.WriteHeaderNow() - return w.ResponseWriter.Write(data) + n, err = w.ResponseWriter.Write(data) + w.size += n + return } func (w *responseWriter) Status() int { return w.status } +func (w *responseWriter) Size() int { + return w.size +} + func (w *responseWriter) Written() bool { - return w.written + return w.size != NoWritten } // Implements the http.Hijacker interface diff --git a/routergroup.go b/routergroup.go new file mode 100644 index 0000000..8e02a40 --- /dev/null +++ b/routergroup.go @@ -0,0 +1,148 @@ +// Copyright 2014 Manu Martinez-Almeida. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package gin + +import ( + "github.com/julienschmidt/httprouter" + "net/http" + "path" +) + +// Used internally to configure router, a RouterGroup is associated with a prefix +// and an array of handlers (middlewares) +type RouterGroup struct { + Handlers []HandlerFunc + absolutePath string + engine *Engine +} + +// Adds middlewares to the group, see example code in github. +func (group *RouterGroup) Use(middlewares ...HandlerFunc) { + group.Handlers = append(group.Handlers, middlewares...) +} + +// Creates a new router group. You should add all the routes that have common middlwares or the same path prefix. +// For example, all the routes that use a common middlware for authorization could be grouped. +func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup { + return &RouterGroup{ + Handlers: group.combineHandlers(handlers), + absolutePath: group.calculateAbsolutePath(relativePath), + engine: group.engine, + } +} + +// Handle registers a new request handle and middlewares with the given path and method. +// The last handler should be the real handler, the other ones should be middlewares that can and should be shared among different routes. +// See the example code in github. +// +// For GET, POST, PUT, PATCH and DELETE requests the respective shortcut +// functions can be used. +// +// This function is intended for bulk loading and to allow the usage of less +// frequently used, non-standardized or custom methods (e.g. for internal +// communication with a proxy). +func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers []HandlerFunc) { + absolutePath := group.calculateAbsolutePath(relativePath) + handlers = group.combineHandlers(handlers) + if IsDebugging() { + nuHandlers := len(handlers) + handlerName := nameOfFunction(handlers[nuHandlers-1]) + debugPrint("%-5s %-25s --> %s (%d handlers)\n", httpMethod, absolutePath, handlerName, nuHandlers) + } + + group.engine.router.Handle(httpMethod, absolutePath, func(w http.ResponseWriter, req *http.Request, params httprouter.Params) { + context := group.engine.createContext(w, req, params, handlers) + context.Next() + context.Writer.WriteHeaderNow() + group.engine.reuseContext(context) + }) +} + +// POST is a shortcut for router.Handle("POST", path, handle) +func (group *RouterGroup) POST(relativePath string, handlers ...HandlerFunc) { + group.Handle("POST", relativePath, handlers) +} + +// GET is a shortcut for router.Handle("GET", path, handle) +func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) { + group.Handle("GET", relativePath, handlers) +} + +// DELETE is a shortcut for router.Handle("DELETE", path, handle) +func (group *RouterGroup) DELETE(relativePath string, handlers ...HandlerFunc) { + group.Handle("DELETE", relativePath, handlers) +} + +// PATCH is a shortcut for router.Handle("PATCH", path, handle) +func (group *RouterGroup) PATCH(relativePath string, handlers ...HandlerFunc) { + group.Handle("PATCH", relativePath, handlers) +} + +// PUT is a shortcut for router.Handle("PUT", path, handle) +func (group *RouterGroup) PUT(relativePath string, handlers ...HandlerFunc) { + group.Handle("PUT", relativePath, handlers) +} + +// OPTIONS is a shortcut for router.Handle("OPTIONS", path, handle) +func (group *RouterGroup) OPTIONS(relativePath string, handlers ...HandlerFunc) { + group.Handle("OPTIONS", relativePath, handlers) +} + +// HEAD is a shortcut for router.Handle("HEAD", path, handle) +func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) { + group.Handle("HEAD", relativePath, handlers) +} + +// LINK is a shortcut for router.Handle("LINK", path, handle) +func (group *RouterGroup) LINK(relativePath string, handlers ...HandlerFunc) { + group.Handle("LINK", relativePath, handlers) +} + +// UNLINK is a shortcut for router.Handle("UNLINK", path, handle) +func (group *RouterGroup) UNLINK(relativePath string, handlers ...HandlerFunc) { + group.Handle("UNLINK", relativePath, handlers) +} + +// Static serves files from the given file system root. +// Internally a http.FileServer is used, therefore http.NotFound is used instead +// of the Router's NotFound handler. +// To use the operating system's file system implementation, +// use : +// router.Static("/static", "/var/www") +func (group *RouterGroup) Static(relativePath, root string) { + absolutePath := group.calculateAbsolutePath(relativePath) + handler := group.createStaticHandler(absolutePath, root) + absolutePath = path.Join(absolutePath, "/*filepath") + + // Register GET and HEAD handlers + group.GET(absolutePath, handler) + group.HEAD(absolutePath, handler) +} + +func (group *RouterGroup) createStaticHandler(absolutePath, root string) func(*Context) { + fileServer := http.StripPrefix(absolutePath, http.FileServer(http.Dir(root))) + return func(c *Context) { + fileServer.ServeHTTP(c.Writer, c.Request) + } +} + +func (group *RouterGroup) combineHandlers(handlers []HandlerFunc) []HandlerFunc { + finalSize := len(group.Handlers) + len(handlers) + mergedHandlers := make([]HandlerFunc, 0, finalSize) + mergedHandlers = append(mergedHandlers, group.Handlers...) + return append(mergedHandlers, handlers...) +} + +func (group *RouterGroup) calculateAbsolutePath(relativePath string) string { + if len(relativePath) == 0 { + return group.absolutePath + } + absolutePath := path.Join(group.absolutePath, relativePath) + appendSlash := lastChar(relativePath) == '/' && lastChar(absolutePath) != '/' + if appendSlash { + return absolutePath + "/" + } + return absolutePath +} diff --git a/utils.go b/utils.go index f58097a..43ddaec 100644 --- a/utils.go +++ b/utils.go @@ -8,6 +8,7 @@ import ( "encoding/xml" "reflect" "runtime" + "strings" ) type H map[string]interface{} @@ -37,14 +38,45 @@ func (h H) MarshalXML(e *xml.Encoder, start xml.StartElement) error { } func filterFlags(content string) string { - for i, a := range content { - if a == ' ' || a == ';' { + for i, char := range content { + if char == ' ' || char == ';' { return content[:i] } } return content } -func funcName(f interface{}) string { +func chooseData(custom, wildcard interface{}) interface{} { + if custom == nil { + if wildcard == nil { + panic("negotiation config is invalid") + } + return wildcard + } + return custom +} + +func parseAccept(accept string) []string { + parts := strings.Split(accept, ",") + for i, part := range parts { + index := strings.IndexByte(part, ';') + if index >= 0 { + part = part[0:index] + } + part = strings.TrimSpace(part) + parts[i] = part + } + return parts +} + +func lastChar(str string) uint8 { + size := len(str) + if size == 0 { + panic("The length of the string can't be 0") + } + return str[size-1] +} + +func nameOfFunction(f interface{}) string { return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name() }