diff --git a/context.go b/context.go index 8225124..2f0e2d8 100644 --- a/context.go +++ b/context.go @@ -12,7 +12,9 @@ import ( "github.com/gin-gonic/gin/render" "github.com/julienschmidt/httprouter" "log" + "net" "net/http" + "strings" ) const ( @@ -197,15 +199,84 @@ func (c *Context) MustGet(key string) interface{} { return value } +func ipInMasks(ip net.IP, masks []interface{}) bool { + for _, proxy := range masks { + var mask *net.IPNet + var err error + + switch t := proxy.(type) { + case string: + if _, mask, err = net.ParseCIDR(t); err != nil { + panic(err) + } + case net.IP: + mask = &net.IPNet{IP: t, Mask: net.CIDRMask(len(t)*8, len(t)*8)} + case net.IPNet: + mask = &t + } + + if mask.Contains(ip) { + return true + } + } + + return false +} + +// the ForwardedFor middleware unwraps the X-Forwarded-For headers, be careful to only use this +// middleware if you've got servers in front of this server. The list with (known) proxies and +// local ips are being filtered out of the forwarded for list, giving the last not local ip being +// the real client ip. +func ForwardedFor(proxies ...interface{}) HandlerFunc { + if len(proxies) == 0 { + // default to local ips + var reservedLocalIps = []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"} + + proxies = make([]interface{}, len(reservedLocalIps)) + + for i, v := range reservedLocalIps { + proxies[i] = v + } + } + + return func(c *Context) { + // the X-Forwarded-For header contains an array with left most the client ip, then + // comma separated, all proxies the request passed. The last proxy appears + // as the remote address of the request. Returning the client + // ip to comply with default RemoteAddr response. + + // check if remoteaddr is local ip or in list of defined proxies + remoteIp := net.ParseIP(strings.Split(c.Request.RemoteAddr, ":")[0]) + + if !ipInMasks(remoteIp, proxies) { + return + } + + if forwardedFor := c.Request.Header.Get("X-Forwarded-For"); forwardedFor != "" { + parts := strings.Split(forwardedFor, ",") + + for i := len(parts) - 1; i >= 0; i-- { + part := parts[i] + + ip := net.ParseIP(strings.TrimSpace(part)) + + if ipInMasks(ip, proxies) { + continue + } + + // returning remote addr conform the original remote addr format + c.Request.RemoteAddr = ip.String() + ":0" + + // remove forwarded for address + c.Request.Header.Set("X-Forwarded-For", "") + return + } + } + } +} + func (c *Context) ClientIP() string { - clientIP := c.Request.Header.Get("X-Real-IP") - if len(clientIP) == 0 { - clientIP = c.Request.Header.Get("X-Forwarded-For") - } - if len(clientIP) == 0 { - clientIP = c.Request.RemoteAddr - } - return clientIP + return c.Request.RemoteAddr } /************************************/ diff --git a/context_test.go b/context_test.go index 8435ac5..851a56c 100644 --- a/context_test.go +++ b/context_test.go @@ -440,3 +440,44 @@ func TestBindingJSONMalformed(t *testing.T) { t.Errorf("Content-Type should not be application/json, was %s", w.HeaderMap.Get("Content-Type")) } } + +func TestClientIP(t *testing.T) { + r := New() + + var clientIP string = "" + r.GET("/", func(c *Context) { + clientIP = c.ClientIP() + }) + + body := bytes.NewBuffer([]byte("")) + req, _ := http.NewRequest("GET", "/", body) + req.RemoteAddr = "clientip:1234" + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if clientIP != "clientip:1234" { + t.Errorf("ClientIP should not be %s, but clientip:1234", clientIP) + } +} + +func TestClientIPWithXForwardedForWithProxy(t *testing.T) { + r := New() + r.Use(ForwardedFor()) + + var clientIP string = "" + r.GET("/", func(c *Context) { + clientIP = c.ClientIP() + }) + + body := bytes.NewBuffer([]byte("")) + req, _ := http.NewRequest("GET", "/", body) + req.RemoteAddr = "172.16.8.3:1234" + req.Header.Set("X-Real-Ip", "realip") + req.Header.Set("X-Forwarded-For", "1.2.3.4, 10.10.0.4, 192.168.0.43, 172.16.8.4") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if clientIP != "1.2.3.4:0" { + t.Errorf("ClientIP should not be %s, but 1.2.3.4:0", clientIP) + } +}