set engine.TrustedProxies For items that don't use gin.RUN (#2692)
Co-authored-by: Bo-Yi Wu <appleboy.tw@gmail.com>
This commit is contained in:
		@ -1392,14 +1392,10 @@ func TestContextAbortWithError(t *testing.T) {
 | 
				
			|||||||
	assert.True(t, c.IsAborted())
 | 
						assert.True(t, c.IsAborted())
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func resetTrustedCIDRs(c *Context) {
 | 
					 | 
				
			||||||
	c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestContextClientIP(t *testing.T) {
 | 
					func TestContextClientIP(t *testing.T) {
 | 
				
			||||||
	c, _ := CreateTestContext(httptest.NewRecorder())
 | 
						c, _ := CreateTestContext(httptest.NewRecorder())
 | 
				
			||||||
	c.Request, _ = http.NewRequest("POST", "/", nil)
 | 
						c.Request, _ = http.NewRequest("POST", "/", nil)
 | 
				
			||||||
	resetTrustedCIDRs(c)
 | 
						c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs()
 | 
				
			||||||
	resetContextForClientIPTests(c)
 | 
						resetContextForClientIPTests(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Legacy tests (validating that the defaults don't break the
 | 
						// Legacy tests (validating that the defaults don't break the
 | 
				
			||||||
@ -1428,57 +1424,47 @@ func TestContextClientIP(t *testing.T) {
 | 
				
			|||||||
	resetContextForClientIPTests(c)
 | 
						resetContextForClientIPTests(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// No trusted proxies
 | 
						// No trusted proxies
 | 
				
			||||||
	c.engine.TrustedProxies = []string{}
 | 
						_ = c.engine.SetTrustedProxies([]string{})
 | 
				
			||||||
	resetTrustedCIDRs(c)
 | 
					 | 
				
			||||||
	c.engine.RemoteIPHeaders = []string{"X-Forwarded-For"}
 | 
						c.engine.RemoteIPHeaders = []string{"X-Forwarded-For"}
 | 
				
			||||||
	assert.Equal(t, "40.40.40.40", c.ClientIP())
 | 
						assert.Equal(t, "40.40.40.40", c.ClientIP())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Last proxy is trusted, but the RemoteAddr is not
 | 
						// Last proxy is trusted, but the RemoteAddr is not
 | 
				
			||||||
	c.engine.TrustedProxies = []string{"30.30.30.30"}
 | 
						_ = c.engine.SetTrustedProxies([]string{"30.30.30.30"})
 | 
				
			||||||
	resetTrustedCIDRs(c)
 | 
					 | 
				
			||||||
	assert.Equal(t, "40.40.40.40", c.ClientIP())
 | 
						assert.Equal(t, "40.40.40.40", c.ClientIP())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Only trust RemoteAddr
 | 
						// Only trust RemoteAddr
 | 
				
			||||||
	c.engine.TrustedProxies = []string{"40.40.40.40"}
 | 
						_ = c.engine.SetTrustedProxies([]string{"40.40.40.40"})
 | 
				
			||||||
	resetTrustedCIDRs(c)
 | 
					 | 
				
			||||||
	assert.Equal(t, "20.20.20.20", c.ClientIP())
 | 
						assert.Equal(t, "20.20.20.20", c.ClientIP())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// All steps are trusted
 | 
						// All steps are trusted
 | 
				
			||||||
	c.engine.TrustedProxies = []string{"40.40.40.40", "30.30.30.30", "20.20.20.20"}
 | 
						_ = c.engine.SetTrustedProxies([]string{"40.40.40.40", "30.30.30.30", "20.20.20.20"})
 | 
				
			||||||
	resetTrustedCIDRs(c)
 | 
					 | 
				
			||||||
	assert.Equal(t, "20.20.20.20", c.ClientIP())
 | 
						assert.Equal(t, "20.20.20.20", c.ClientIP())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Use CIDR
 | 
						// Use CIDR
 | 
				
			||||||
	c.engine.TrustedProxies = []string{"40.40.25.25/16", "30.30.30.30"}
 | 
						_ = c.engine.SetTrustedProxies([]string{"40.40.25.25/16", "30.30.30.30"})
 | 
				
			||||||
	resetTrustedCIDRs(c)
 | 
					 | 
				
			||||||
	assert.Equal(t, "20.20.20.20", c.ClientIP())
 | 
						assert.Equal(t, "20.20.20.20", c.ClientIP())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Use hostname that resolves to all the proxies
 | 
						// Use hostname that resolves to all the proxies
 | 
				
			||||||
	c.engine.TrustedProxies = []string{"foo"}
 | 
						_ = c.engine.SetTrustedProxies([]string{"foo"})
 | 
				
			||||||
	resetTrustedCIDRs(c)
 | 
					 | 
				
			||||||
	assert.Equal(t, "40.40.40.40", c.ClientIP())
 | 
						assert.Equal(t, "40.40.40.40", c.ClientIP())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Use hostname that returns an error
 | 
						// Use hostname that returns an error
 | 
				
			||||||
	c.engine.TrustedProxies = []string{"bar"}
 | 
						_ = c.engine.SetTrustedProxies([]string{"bar"})
 | 
				
			||||||
	resetTrustedCIDRs(c)
 | 
					 | 
				
			||||||
	assert.Equal(t, "40.40.40.40", c.ClientIP())
 | 
						assert.Equal(t, "40.40.40.40", c.ClientIP())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// X-Forwarded-For has a non-IP element
 | 
						// X-Forwarded-For has a non-IP element
 | 
				
			||||||
	c.engine.TrustedProxies = []string{"40.40.40.40"}
 | 
						_ = c.engine.SetTrustedProxies([]string{"40.40.40.40"})
 | 
				
			||||||
	resetTrustedCIDRs(c)
 | 
					 | 
				
			||||||
	c.Request.Header.Set("X-Forwarded-For", " blah ")
 | 
						c.Request.Header.Set("X-Forwarded-For", " blah ")
 | 
				
			||||||
	assert.Equal(t, "40.40.40.40", c.ClientIP())
 | 
						assert.Equal(t, "40.40.40.40", c.ClientIP())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Result from LookupHost has non-IP element. This should never
 | 
						// Result from LookupHost has non-IP element. This should never
 | 
				
			||||||
	// happen, but we should test it to make sure we handle it
 | 
						// happen, but we should test it to make sure we handle it
 | 
				
			||||||
	// gracefully.
 | 
						// gracefully.
 | 
				
			||||||
	c.engine.TrustedProxies = []string{"baz"}
 | 
						_ = c.engine.SetTrustedProxies([]string{"baz"})
 | 
				
			||||||
	resetTrustedCIDRs(c)
 | 
					 | 
				
			||||||
	c.Request.Header.Set("X-Forwarded-For", " 30.30.30.30 ")
 | 
						c.Request.Header.Set("X-Forwarded-For", " 30.30.30.30 ")
 | 
				
			||||||
	assert.Equal(t, "40.40.40.40", c.ClientIP())
 | 
						assert.Equal(t, "40.40.40.40", c.ClientIP())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	c.engine.TrustedProxies = []string{"40.40.40.40"}
 | 
						_ = c.engine.SetTrustedProxies([]string{"40.40.40.40"})
 | 
				
			||||||
	resetTrustedCIDRs(c)
 | 
					 | 
				
			||||||
	c.Request.Header.Del("X-Forwarded-For")
 | 
						c.Request.Header.Del("X-Forwarded-For")
 | 
				
			||||||
	c.engine.RemoteIPHeaders = []string{"X-Forwarded-For", "X-Real-IP"}
 | 
						c.engine.RemoteIPHeaders = []string{"X-Forwarded-For", "X-Real-IP"}
 | 
				
			||||||
	assert.Equal(t, "10.10.10.10", c.ClientIP())
 | 
						assert.Equal(t, "10.10.10.10", c.ClientIP())
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										38
									
								
								gin.go
									
									
									
									
									
								
							
							
						
						
									
										38
									
								
								gin.go
									
									
									
									
									
								
							@ -326,11 +326,11 @@ func iterate(path, method string, routes RoutesInfo, root *node) RoutesInfo {
 | 
				
			|||||||
func (engine *Engine) Run(addr ...string) (err error) {
 | 
					func (engine *Engine) Run(addr ...string) (err error) {
 | 
				
			||||||
	defer func() { debugPrintError(err) }()
 | 
						defer func() { debugPrintError(err) }()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	trustedCIDRs, err := engine.prepareTrustedCIDRs()
 | 
						err = engine.parseTrustedProxies()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	engine.trustedCIDRs = trustedCIDRs
 | 
					
 | 
				
			||||||
	address := resolveAddress(addr)
 | 
						address := resolveAddress(addr)
 | 
				
			||||||
	debugPrint("Listening and serving HTTP on %s\n", address)
 | 
						debugPrint("Listening and serving HTTP on %s\n", address)
 | 
				
			||||||
	err = http.ListenAndServe(address, engine)
 | 
						err = http.ListenAndServe(address, engine)
 | 
				
			||||||
@ -366,6 +366,19 @@ func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) {
 | 
				
			|||||||
	return cidr, nil
 | 
						return cidr, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SetTrustedProxies  set Engine.TrustedProxies
 | 
				
			||||||
 | 
					func (engine *Engine) SetTrustedProxies(trustedProxies []string) error {
 | 
				
			||||||
 | 
						engine.TrustedProxies = trustedProxies
 | 
				
			||||||
 | 
						return engine.parseTrustedProxies()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// parseTrustedProxies parse Engine.TrustedProxies to Engine.trustedCIDRs
 | 
				
			||||||
 | 
					func (engine *Engine) parseTrustedProxies() error {
 | 
				
			||||||
 | 
						trustedCIDRs, err := engine.prepareTrustedCIDRs()
 | 
				
			||||||
 | 
						engine.trustedCIDRs = trustedCIDRs
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// parseIP parse a string representation of an IP and returns a net.IP with the
 | 
					// parseIP parse a string representation of an IP and returns a net.IP with the
 | 
				
			||||||
// minimum byte representation or nil if input is invalid.
 | 
					// minimum byte representation or nil if input is invalid.
 | 
				
			||||||
func parseIP(ip string) net.IP {
 | 
					func parseIP(ip string) net.IP {
 | 
				
			||||||
@ -387,6 +400,11 @@ func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) {
 | 
				
			|||||||
	debugPrint("Listening and serving HTTPS on %s\n", addr)
 | 
						debugPrint("Listening and serving HTTPS on %s\n", addr)
 | 
				
			||||||
	defer func() { debugPrintError(err) }()
 | 
						defer func() { debugPrintError(err) }()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = engine.parseTrustedProxies()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = http.ListenAndServeTLS(addr, certFile, keyFile, engine)
 | 
						err = http.ListenAndServeTLS(addr, certFile, keyFile, engine)
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -398,6 +416,11 @@ func (engine *Engine) RunUnix(file string) (err error) {
 | 
				
			|||||||
	debugPrint("Listening and serving HTTP on unix:/%s", file)
 | 
						debugPrint("Listening and serving HTTP on unix:/%s", file)
 | 
				
			||||||
	defer func() { debugPrintError(err) }()
 | 
						defer func() { debugPrintError(err) }()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = engine.parseTrustedProxies()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	listener, err := net.Listen("unix", file)
 | 
						listener, err := net.Listen("unix", file)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
@ -416,6 +439,11 @@ func (engine *Engine) RunFd(fd int) (err error) {
 | 
				
			|||||||
	debugPrint("Listening and serving HTTP on fd@%d", fd)
 | 
						debugPrint("Listening and serving HTTP on fd@%d", fd)
 | 
				
			||||||
	defer func() { debugPrintError(err) }()
 | 
						defer func() { debugPrintError(err) }()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = engine.parseTrustedProxies()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	f := os.NewFile(uintptr(fd), fmt.Sprintf("fd@%d", fd))
 | 
						f := os.NewFile(uintptr(fd), fmt.Sprintf("fd@%d", fd))
 | 
				
			||||||
	listener, err := net.FileListener(f)
 | 
						listener, err := net.FileListener(f)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@ -431,6 +459,12 @@ func (engine *Engine) RunFd(fd int) (err error) {
 | 
				
			|||||||
func (engine *Engine) RunListener(listener net.Listener) (err error) {
 | 
					func (engine *Engine) RunListener(listener net.Listener) (err error) {
 | 
				
			||||||
	debugPrint("Listening and serving HTTP on listener what's bind with address@%s", listener.Addr())
 | 
						debugPrint("Listening and serving HTTP on listener what's bind with address@%s", listener.Addr())
 | 
				
			||||||
	defer func() { debugPrintError(err) }()
 | 
						defer func() { debugPrintError(err) }()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = engine.parseTrustedProxies()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = http.Serve(listener, engine)
 | 
						err = http.Serve(listener, engine)
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -55,13 +55,74 @@ func TestRunEmpty(t *testing.T) {
 | 
				
			|||||||
	testRequest(t, "http://localhost:8080/example")
 | 
						testRequest(t, "http://localhost:8080/example")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestTrustedCIDRsForRun(t *testing.T) {
 | 
					func TestBadTrustedCIDRsForRun(t *testing.T) {
 | 
				
			||||||
	os.Setenv("PORT", "")
 | 
						os.Setenv("PORT", "")
 | 
				
			||||||
	router := New()
 | 
						router := New()
 | 
				
			||||||
	router.TrustedProxies = []string{"hello/world"}
 | 
						router.TrustedProxies = []string{"hello/world"}
 | 
				
			||||||
	assert.Error(t, router.Run(":8080"))
 | 
						assert.Error(t, router.Run(":8080"))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestBadTrustedCIDRsForRunUnix(t *testing.T) {
 | 
				
			||||||
 | 
						router := New()
 | 
				
			||||||
 | 
						router.TrustedProxies = []string{"hello/world"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						unixTestSocket := filepath.Join(os.TempDir(), "unix_unit_test")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer os.Remove(unixTestSocket)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							router.GET("/example", func(c *Context) { c.String(http.StatusOK, "it worked") })
 | 
				
			||||||
 | 
							assert.Error(t, router.RunUnix(unixTestSocket))
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
						// have to wait for the goroutine to start and run the server
 | 
				
			||||||
 | 
						// otherwise the main thread will complete
 | 
				
			||||||
 | 
						time.Sleep(5 * time.Millisecond)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestBadTrustedCIDRsForRunFd(t *testing.T) {
 | 
				
			||||||
 | 
						router := New()
 | 
				
			||||||
 | 
						router.TrustedProxies = []string{"hello/world"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
 | 
				
			||||||
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
						listener, err := net.ListenTCP("tcp", addr)
 | 
				
			||||||
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
						socketFile, err := listener.File()
 | 
				
			||||||
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							router.GET("/example", func(c *Context) { c.String(http.StatusOK, "it worked") })
 | 
				
			||||||
 | 
							assert.Error(t, router.RunFd(int(socketFile.Fd())))
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
						// have to wait for the goroutine to start and run the server
 | 
				
			||||||
 | 
						// otherwise the main thread will complete
 | 
				
			||||||
 | 
						time.Sleep(5 * time.Millisecond)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestBadTrustedCIDRsForRunListener(t *testing.T) {
 | 
				
			||||||
 | 
						router := New()
 | 
				
			||||||
 | 
						router.TrustedProxies = []string{"hello/world"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
 | 
				
			||||||
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
						listener, err := net.ListenTCP("tcp", addr)
 | 
				
			||||||
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							router.GET("/example", func(c *Context) { c.String(http.StatusOK, "it worked") })
 | 
				
			||||||
 | 
							assert.Error(t, router.RunListener(listener))
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
						// have to wait for the goroutine to start and run the server
 | 
				
			||||||
 | 
						// otherwise the main thread will complete
 | 
				
			||||||
 | 
						time.Sleep(5 * time.Millisecond)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestBadTrustedCIDRsForRunTLS(t *testing.T) {
 | 
				
			||||||
 | 
						os.Setenv("PORT", "")
 | 
				
			||||||
 | 
						router := New()
 | 
				
			||||||
 | 
						router.TrustedProxies = []string{"hello/world"}
 | 
				
			||||||
 | 
						assert.Error(t, router.RunTLS(":8080", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem"))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestRunTLS(t *testing.T) {
 | 
					func TestRunTLS(t *testing.T) {
 | 
				
			||||||
	router := New()
 | 
						router := New()
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user