add prefix from X-Forwarded-Prefix in redirectTrailingSlash (#1238)
* add prefix from X-Forwarded-Prefix in redirectTrailingSlash * added test * fix path import
This commit is contained in:
		
							
								
								
									
										14
									
								
								gin.go
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								gin.go
									
									
									
									
									
								
							@ -10,6 +10,7 @@ import (
 | 
				
			|||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
 | 
						"path"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/gin-gonic/gin/render"
 | 
						"github.com/gin-gonic/gin/render"
 | 
				
			||||||
@ -438,17 +439,20 @@ func serveError(c *Context, code int, defaultMessage []byte) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func redirectTrailingSlash(c *Context) {
 | 
					func redirectTrailingSlash(c *Context) {
 | 
				
			||||||
	req := c.Request
 | 
						req := c.Request
 | 
				
			||||||
	path := req.URL.Path
 | 
						p := req.URL.Path
 | 
				
			||||||
 | 
						if prefix := path.Clean(c.Request.Header.Get("X-Forwarded-Prefix")); prefix != "." {
 | 
				
			||||||
 | 
							p = prefix + "/" + req.URL.Path
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	code := http.StatusMovedPermanently // Permanent redirect, request with GET method
 | 
						code := http.StatusMovedPermanently // Permanent redirect, request with GET method
 | 
				
			||||||
	if req.Method != "GET" {
 | 
						if req.Method != "GET" {
 | 
				
			||||||
		code = http.StatusTemporaryRedirect
 | 
							code = http.StatusTemporaryRedirect
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req.URL.Path = path + "/"
 | 
						req.URL.Path = p + "/"
 | 
				
			||||||
	if length := len(path); length > 1 && path[length-1] == '/' {
 | 
						if length := len(p); length > 1 && p[length-1] == '/' {
 | 
				
			||||||
		req.URL.Path = path[:length-1]
 | 
							req.URL.Path = p[:length-1]
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String())
 | 
						debugPrint("redirecting request %d: %s --> %s", code, p, req.URL.String())
 | 
				
			||||||
	http.Redirect(c.Writer, req, req.URL.String(), code)
 | 
						http.Redirect(c.Writer, req, req.URL.String(), code)
 | 
				
			||||||
	c.writermem.WriteHeaderNow()
 | 
						c.writermem.WriteHeaderNow()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -16,8 +16,16 @@ import (
 | 
				
			|||||||
	"github.com/stretchr/testify/assert"
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func performRequest(r http.Handler, method, path string) *httptest.ResponseRecorder {
 | 
					type header struct {
 | 
				
			||||||
 | 
						Key   string
 | 
				
			||||||
 | 
						Value string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func performRequest(r http.Handler, method, path string, headers ...header) *httptest.ResponseRecorder {
 | 
				
			||||||
	req, _ := http.NewRequest(method, path, nil)
 | 
						req, _ := http.NewRequest(method, path, nil)
 | 
				
			||||||
 | 
						for _, h := range headers {
 | 
				
			||||||
 | 
							req.Header.Add(h.Key, h.Value)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	w := httptest.NewRecorder()
 | 
						w := httptest.NewRecorder()
 | 
				
			||||||
	r.ServeHTTP(w, req)
 | 
						r.ServeHTTP(w, req)
 | 
				
			||||||
	return w
 | 
						return w
 | 
				
			||||||
@ -170,6 +178,13 @@ func TestRouteRedirectTrailingSlash(t *testing.T) {
 | 
				
			|||||||
	w = performRequest(router, "PUT", "/path4/")
 | 
						w = performRequest(router, "PUT", "/path4/")
 | 
				
			||||||
	assert.Equal(t, http.StatusOK, w.Code)
 | 
						assert.Equal(t, http.StatusOK, w.Code)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						w = performRequest(router, "GET", "/path2", header{Key: "X-Forwarded-Prefix", Value: "/api"})
 | 
				
			||||||
 | 
						assert.Equal(t, "/api/path2/", w.Header().Get("Location"))
 | 
				
			||||||
 | 
						assert.Equal(t, 301, w.Code)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						w = performRequest(router, "GET", "/path2/", header{Key: "X-Forwarded-Prefix", Value: "/api/"})
 | 
				
			||||||
 | 
						assert.Equal(t, 200, w.Code)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	router.RedirectTrailingSlash = false
 | 
						router.RedirectTrailingSlash = false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	w = performRequest(router, "GET", "/path/")
 | 
						w = performRequest(router, "GET", "/path/")
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user