Better unit tests for BasicAuth middleware

This commit is contained in:
Manu Mtz-Almeida 2015-04-08 15:17:41 +02:00
parent 4d315f474b
commit a28104fa21
2 changed files with 125 additions and 61 deletions

29
auth.go
View File

@ -29,6 +29,19 @@ func (a authPairs) Len() int { return len(a) }
func (a authPairs) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a authPairs) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a authPairs) Less(i, j int) bool { return a[i].Value < a[j].Value } func (a authPairs) Less(i, j int) bool { return a[i].Value < a[j].Value }
func (a authPairs) searchCredential(auth string) (string, bool) {
if len(auth) == 0 {
return "", false
}
// Search user in the slice of allowed credentials
r := sort.Search(len(a), func(i int) bool { return a[i].Value >= auth })
if r < len(a) && secureCompare(a[r].Value, auth) {
return a[r].User, true
} else {
return "", false
}
}
// Implements a basic Basic HTTP Authorization. It takes as arguments a map[string]string where // Implements a basic Basic HTTP Authorization. It takes as arguments a map[string]string where
// the key is the user name and the value is the password, as well as the name of the Realm // the key is the user name and the value is the password, as well as the name of the Realm
// (see http://tools.ietf.org/html/rfc2617#section-1.2) // (see http://tools.ietf.org/html/rfc2617#section-1.2)
@ -40,7 +53,7 @@ func BasicAuthForRealm(accounts Accounts, realm string) HandlerFunc {
pairs := processAccounts(accounts) pairs := processAccounts(accounts)
return func(c *Context) { return func(c *Context) {
// Search user in the slice of allowed credentials // Search user in the slice of allowed credentials
user, ok := searchCredential(pairs, c.Request.Header.Get("Authorization")) user, ok := pairs.searchCredential(c.Request.Header.Get("Authorization"))
if !ok { if !ok {
// Credentials doesn't match, we return 401 Unauthorized and abort request. // Credentials doesn't match, we return 401 Unauthorized and abort request.
c.Writer.Header().Set("WWW-Authenticate", realm) c.Writer.Header().Set("WWW-Authenticate", realm)
@ -80,17 +93,9 @@ func processAccounts(accounts Accounts) authPairs {
return pairs return pairs
} }
func searchCredential(pairs authPairs, auth string) (string, bool) { func authorizationHeader(user, password string) string {
if len(auth) == 0 { base := user + ":" + password
return "", false return "Basic " + base64.StdEncoding.EncodeToString([]byte(base))
}
// Search user in the slice of allowed credentials
r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Value >= auth })
if r < len(pairs) && secureCompare(pairs[r].Value, auth) {
return pairs[r].User, true
} else {
return "", false
}
} }
func secureCompare(given, actual string) bool { func secureCompare(given, actual string) bool {

View File

@ -9,77 +9,136 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestBasicAuthSucceed(t *testing.T) { func TestBasicAuth(t *testing.T) {
req, _ := http.NewRequest("GET", "/login", nil) accounts := Accounts{
w := httptest.NewRecorder() "admin": "password",
"foo": "bar",
"bar": "foo",
}
expectedPairs := authPairs{
authPair{
User: "admin",
Value: "Basic YWRtaW46cGFzc3dvcmQ=",
},
authPair{
User: "bar",
Value: "Basic YmFyOmZvbw==",
},
authPair{
User: "foo",
Value: "Basic Zm9vOmJhcg==",
},
}
pairs := processAccounts(accounts)
assert.Equal(t, pairs, expectedPairs)
}
r := New() func TestBasicAuthFails(t *testing.T) {
accounts := Accounts{"admin": "password"} assert.Panics(t, func() { processAccounts(nil) })
r.Use(BasicAuth(accounts)) assert.Panics(t, func() {
processAccounts(Accounts{
"": "password",
"foo": "bar",
})
})
}
r.GET("/login", func(c *Context) { func TestBasicAuthSearchCredential(t *testing.T) {
c.String(200, "autorized") pairs := processAccounts(Accounts{
"admin": "password",
"foo": "bar",
"bar": "foo",
}) })
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password"))) user, found := pairs.searchCredential(authorizationHeader("admin", "password"))
r.ServeHTTP(w, req) assert.Equal(t, user, "admin")
assert.True(t, found)
if w.Code != 200 { user, found = pairs.searchCredential(authorizationHeader("foo", "bar"))
t.Errorf("Response code should be Ok, was: %d", w.Code) assert.Equal(t, user, "foo")
} assert.True(t, found)
bodyAsString := w.Body.String()
if bodyAsString != "autorized" { user, found = pairs.searchCredential(authorizationHeader("bar", "foo"))
t.Errorf("Response body should be `autorized`, was %s", bodyAsString) assert.Equal(t, user, "bar")
} assert.True(t, found)
user, found = pairs.searchCredential(authorizationHeader("admins", "password"))
assert.Empty(t, user)
assert.False(t, found)
user, found = pairs.searchCredential(authorizationHeader("foo", "bar "))
assert.Empty(t, user)
assert.False(t, found)
}
func TestBasicAuthAuthorizationHeader(t *testing.T) {
assert.Equal(t, authorizationHeader("admin", "password"), "Basic YWRtaW46cGFzc3dvcmQ=")
}
func TestBasicAuthSecureCompare(t *testing.T) {
assert.True(t, secureCompare("1234567890", "1234567890"))
assert.False(t, secureCompare("123456789", "1234567890"))
assert.False(t, secureCompare("12345678900", "1234567890"))
assert.False(t, secureCompare("1234567891", "1234567890"))
}
func TestBasicAuthSucceed(t *testing.T) {
accounts := Accounts{"admin": "password"}
router := New()
router.Use(BasicAuth(accounts))
router.GET("/login", func(c *Context) {
c.String(200, c.MustGet(AuthUserKey).(string))
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/login", nil)
req.Header.Set("Authorization", authorizationHeader("admin", "password"))
router.ServeHTTP(w, req)
assert.Equal(t, w.Code, 200)
assert.Equal(t, w.Body.String(), "admin")
} }
func TestBasicAuth401(t *testing.T) { func TestBasicAuth401(t *testing.T) {
req, _ := http.NewRequest("GET", "/login", nil) called := false
w := httptest.NewRecorder()
r := New()
accounts := Accounts{"foo": "bar"} accounts := Accounts{"foo": "bar"}
r.Use(BasicAuth(accounts)) router := New()
router.Use(BasicAuth(accounts))
r.GET("/login", func(c *Context) { router.GET("/login", func(c *Context) {
c.String(200, "autorized") called = true
c.String(200, c.MustGet(AuthUserKey).(string))
}) })
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/login", nil)
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password"))) req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
r.ServeHTTP(w, req) router.ServeHTTP(w, req)
if w.Code != 401 { assert.False(t, called)
t.Errorf("Response code should be Not autorized, was: %d", w.Code) assert.Equal(t, w.Code, 401)
} assert.Equal(t, w.HeaderMap.Get("WWW-Authenticate"), "Basic realm=\"Authorization Required\"")
if w.HeaderMap.Get("WWW-Authenticate") != "Basic realm=\"Authorization Required\"" {
t.Errorf("WWW-Authenticate header is incorrect: %s", w.HeaderMap.Get("Content-Type"))
}
} }
func TestBasicAuth401WithCustomRealm(t *testing.T) { func TestBasicAuth401WithCustomRealm(t *testing.T) {
req, _ := http.NewRequest("GET", "/login", nil) called := false
w := httptest.NewRecorder()
r := New()
accounts := Accounts{"foo": "bar"} accounts := Accounts{"foo": "bar"}
r.Use(BasicAuthForRealm(accounts, "My Custom Realm")) router := New()
router.Use(BasicAuthForRealm(accounts, "My Custom Realm"))
r.GET("/login", func(c *Context) { router.GET("/login", func(c *Context) {
c.String(200, "autorized") called = true
c.String(200, c.MustGet(AuthUserKey).(string))
}) })
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/login", nil)
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password"))) req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
r.ServeHTTP(w, req) router.ServeHTTP(w, req)
if w.Code != 401 { assert.False(t, called)
t.Errorf("Response code should be Not autorized, was: %d", w.Code) assert.Equal(t, w.Code, 401)
} assert.Equal(t, w.HeaderMap.Get("WWW-Authenticate"), "Basic realm=\"My Custom Realm\"")
if w.HeaderMap.Get("WWW-Authenticate") != "Basic realm=\"My Custom Realm\"" {
t.Errorf("WWW-Authenticate header is incorrect: %s", w.HeaderMap.Get("Content-Type"))
}
} }