Better unit tests for BasicAuth middleware
This commit is contained in:
parent
4d315f474b
commit
a28104fa21
29
auth.go
29
auth.go
@ -29,6 +29,19 @@ func (a authPairs) Len() int { return len(a) }
|
|||||||
func (a authPairs) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
func (a authPairs) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||||
func (a authPairs) Less(i, j int) bool { return a[i].Value < a[j].Value }
|
func (a authPairs) Less(i, j int) bool { return a[i].Value < a[j].Value }
|
||||||
|
|
||||||
|
func (a authPairs) searchCredential(auth string) (string, bool) {
|
||||||
|
if len(auth) == 0 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
// Search user in the slice of allowed credentials
|
||||||
|
r := sort.Search(len(a), func(i int) bool { return a[i].Value >= auth })
|
||||||
|
if r < len(a) && secureCompare(a[r].Value, auth) {
|
||||||
|
return a[r].User, true
|
||||||
|
} else {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Implements a basic Basic HTTP Authorization. It takes as arguments a map[string]string where
|
// Implements a basic Basic HTTP Authorization. It takes as arguments a map[string]string where
|
||||||
// the key is the user name and the value is the password, as well as the name of the Realm
|
// the key is the user name and the value is the password, as well as the name of the Realm
|
||||||
// (see http://tools.ietf.org/html/rfc2617#section-1.2)
|
// (see http://tools.ietf.org/html/rfc2617#section-1.2)
|
||||||
@ -40,7 +53,7 @@ func BasicAuthForRealm(accounts Accounts, realm string) HandlerFunc {
|
|||||||
pairs := processAccounts(accounts)
|
pairs := processAccounts(accounts)
|
||||||
return func(c *Context) {
|
return func(c *Context) {
|
||||||
// Search user in the slice of allowed credentials
|
// Search user in the slice of allowed credentials
|
||||||
user, ok := searchCredential(pairs, c.Request.Header.Get("Authorization"))
|
user, ok := pairs.searchCredential(c.Request.Header.Get("Authorization"))
|
||||||
if !ok {
|
if !ok {
|
||||||
// Credentials doesn't match, we return 401 Unauthorized and abort request.
|
// Credentials doesn't match, we return 401 Unauthorized and abort request.
|
||||||
c.Writer.Header().Set("WWW-Authenticate", realm)
|
c.Writer.Header().Set("WWW-Authenticate", realm)
|
||||||
@ -80,17 +93,9 @@ func processAccounts(accounts Accounts) authPairs {
|
|||||||
return pairs
|
return pairs
|
||||||
}
|
}
|
||||||
|
|
||||||
func searchCredential(pairs authPairs, auth string) (string, bool) {
|
func authorizationHeader(user, password string) string {
|
||||||
if len(auth) == 0 {
|
base := user + ":" + password
|
||||||
return "", false
|
return "Basic " + base64.StdEncoding.EncodeToString([]byte(base))
|
||||||
}
|
|
||||||
// Search user in the slice of allowed credentials
|
|
||||||
r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Value >= auth })
|
|
||||||
if r < len(pairs) && secureCompare(pairs[r].Value, auth) {
|
|
||||||
return pairs[r].User, true
|
|
||||||
} else {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func secureCompare(given, actual string) bool {
|
func secureCompare(given, actual string) bool {
|
||||||
|
157
auth_test.go
157
auth_test.go
@ -9,77 +9,136 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBasicAuthSucceed(t *testing.T) {
|
func TestBasicAuth(t *testing.T) {
|
||||||
req, _ := http.NewRequest("GET", "/login", nil)
|
accounts := Accounts{
|
||||||
w := httptest.NewRecorder()
|
"admin": "password",
|
||||||
|
"foo": "bar",
|
||||||
|
"bar": "foo",
|
||||||
|
}
|
||||||
|
expectedPairs := authPairs{
|
||||||
|
authPair{
|
||||||
|
User: "admin",
|
||||||
|
Value: "Basic YWRtaW46cGFzc3dvcmQ=",
|
||||||
|
},
|
||||||
|
authPair{
|
||||||
|
User: "bar",
|
||||||
|
Value: "Basic YmFyOmZvbw==",
|
||||||
|
},
|
||||||
|
authPair{
|
||||||
|
User: "foo",
|
||||||
|
Value: "Basic Zm9vOmJhcg==",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pairs := processAccounts(accounts)
|
||||||
|
assert.Equal(t, pairs, expectedPairs)
|
||||||
|
}
|
||||||
|
|
||||||
r := New()
|
func TestBasicAuthFails(t *testing.T) {
|
||||||
accounts := Accounts{"admin": "password"}
|
assert.Panics(t, func() { processAccounts(nil) })
|
||||||
r.Use(BasicAuth(accounts))
|
assert.Panics(t, func() {
|
||||||
|
processAccounts(Accounts{
|
||||||
|
"": "password",
|
||||||
|
"foo": "bar",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
r.GET("/login", func(c *Context) {
|
func TestBasicAuthSearchCredential(t *testing.T) {
|
||||||
c.String(200, "autorized")
|
pairs := processAccounts(Accounts{
|
||||||
|
"admin": "password",
|
||||||
|
"foo": "bar",
|
||||||
|
"bar": "foo",
|
||||||
})
|
})
|
||||||
|
|
||||||
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
|
user, found := pairs.searchCredential(authorizationHeader("admin", "password"))
|
||||||
r.ServeHTTP(w, req)
|
assert.Equal(t, user, "admin")
|
||||||
|
assert.True(t, found)
|
||||||
|
|
||||||
if w.Code != 200 {
|
user, found = pairs.searchCredential(authorizationHeader("foo", "bar"))
|
||||||
t.Errorf("Response code should be Ok, was: %d", w.Code)
|
assert.Equal(t, user, "foo")
|
||||||
}
|
assert.True(t, found)
|
||||||
bodyAsString := w.Body.String()
|
|
||||||
|
|
||||||
if bodyAsString != "autorized" {
|
user, found = pairs.searchCredential(authorizationHeader("bar", "foo"))
|
||||||
t.Errorf("Response body should be `autorized`, was %s", bodyAsString)
|
assert.Equal(t, user, "bar")
|
||||||
}
|
assert.True(t, found)
|
||||||
|
|
||||||
|
user, found = pairs.searchCredential(authorizationHeader("admins", "password"))
|
||||||
|
assert.Empty(t, user)
|
||||||
|
assert.False(t, found)
|
||||||
|
|
||||||
|
user, found = pairs.searchCredential(authorizationHeader("foo", "bar "))
|
||||||
|
assert.Empty(t, user)
|
||||||
|
assert.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasicAuthAuthorizationHeader(t *testing.T) {
|
||||||
|
assert.Equal(t, authorizationHeader("admin", "password"), "Basic YWRtaW46cGFzc3dvcmQ=")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasicAuthSecureCompare(t *testing.T) {
|
||||||
|
assert.True(t, secureCompare("1234567890", "1234567890"))
|
||||||
|
assert.False(t, secureCompare("123456789", "1234567890"))
|
||||||
|
assert.False(t, secureCompare("12345678900", "1234567890"))
|
||||||
|
assert.False(t, secureCompare("1234567891", "1234567890"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasicAuthSucceed(t *testing.T) {
|
||||||
|
accounts := Accounts{"admin": "password"}
|
||||||
|
router := New()
|
||||||
|
router.Use(BasicAuth(accounts))
|
||||||
|
router.GET("/login", func(c *Context) {
|
||||||
|
c.String(200, c.MustGet(AuthUserKey).(string))
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/login", nil)
|
||||||
|
req.Header.Set("Authorization", authorizationHeader("admin", "password"))
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, w.Code, 200)
|
||||||
|
assert.Equal(t, w.Body.String(), "admin")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBasicAuth401(t *testing.T) {
|
func TestBasicAuth401(t *testing.T) {
|
||||||
req, _ := http.NewRequest("GET", "/login", nil)
|
called := false
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
r := New()
|
|
||||||
accounts := Accounts{"foo": "bar"}
|
accounts := Accounts{"foo": "bar"}
|
||||||
r.Use(BasicAuth(accounts))
|
router := New()
|
||||||
|
router.Use(BasicAuth(accounts))
|
||||||
r.GET("/login", func(c *Context) {
|
router.GET("/login", func(c *Context) {
|
||||||
c.String(200, "autorized")
|
called = true
|
||||||
|
c.String(200, c.MustGet(AuthUserKey).(string))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/login", nil)
|
||||||
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
|
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
|
||||||
r.ServeHTTP(w, req)
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
if w.Code != 401 {
|
assert.False(t, called)
|
||||||
t.Errorf("Response code should be Not autorized, was: %d", w.Code)
|
assert.Equal(t, w.Code, 401)
|
||||||
}
|
assert.Equal(t, w.HeaderMap.Get("WWW-Authenticate"), "Basic realm=\"Authorization Required\"")
|
||||||
|
|
||||||
if w.HeaderMap.Get("WWW-Authenticate") != "Basic realm=\"Authorization Required\"" {
|
|
||||||
t.Errorf("WWW-Authenticate header is incorrect: %s", w.HeaderMap.Get("Content-Type"))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBasicAuth401WithCustomRealm(t *testing.T) {
|
func TestBasicAuth401WithCustomRealm(t *testing.T) {
|
||||||
req, _ := http.NewRequest("GET", "/login", nil)
|
called := false
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
r := New()
|
|
||||||
accounts := Accounts{"foo": "bar"}
|
accounts := Accounts{"foo": "bar"}
|
||||||
r.Use(BasicAuthForRealm(accounts, "My Custom Realm"))
|
router := New()
|
||||||
|
router.Use(BasicAuthForRealm(accounts, "My Custom Realm"))
|
||||||
r.GET("/login", func(c *Context) {
|
router.GET("/login", func(c *Context) {
|
||||||
c.String(200, "autorized")
|
called = true
|
||||||
|
c.String(200, c.MustGet(AuthUserKey).(string))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/login", nil)
|
||||||
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
|
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
|
||||||
r.ServeHTTP(w, req)
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
if w.Code != 401 {
|
assert.False(t, called)
|
||||||
t.Errorf("Response code should be Not autorized, was: %d", w.Code)
|
assert.Equal(t, w.Code, 401)
|
||||||
}
|
assert.Equal(t, w.HeaderMap.Get("WWW-Authenticate"), "Basic realm=\"My Custom Realm\"")
|
||||||
|
|
||||||
if w.HeaderMap.Get("WWW-Authenticate") != "Basic realm=\"My Custom Realm\"" {
|
|
||||||
t.Errorf("WWW-Authenticate header is incorrect: %s", w.HeaderMap.Get("Content-Type"))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user