From a28104fa2160d3e8965962d27913b4ef121db1d3 Mon Sep 17 00:00:00 2001 From: Manu Mtz-Almeida Date: Wed, 8 Apr 2015 15:17:41 +0200 Subject: [PATCH] Better unit tests for BasicAuth middleware --- auth.go | 29 ++++++---- auth_test.go | 157 +++++++++++++++++++++++++++++++++++---------------- 2 files changed, 125 insertions(+), 61 deletions(-) diff --git a/auth.go b/auth.go index 7a65343..077aca3 100644 --- a/auth.go +++ b/auth.go @@ -29,6 +29,19 @@ func (a authPairs) Len() int { return len(a) } func (a authPairs) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a authPairs) Less(i, j int) bool { return a[i].Value < a[j].Value } +func (a authPairs) searchCredential(auth string) (string, bool) { + if len(auth) == 0 { + return "", false + } + // Search user in the slice of allowed credentials + r := sort.Search(len(a), func(i int) bool { return a[i].Value >= auth }) + if r < len(a) && secureCompare(a[r].Value, auth) { + return a[r].User, true + } else { + return "", false + } +} + // Implements a basic Basic HTTP Authorization. It takes as arguments a map[string]string where // the key is the user name and the value is the password, as well as the name of the Realm // (see http://tools.ietf.org/html/rfc2617#section-1.2) @@ -40,7 +53,7 @@ func BasicAuthForRealm(accounts Accounts, realm string) HandlerFunc { pairs := processAccounts(accounts) return func(c *Context) { // Search user in the slice of allowed credentials - user, ok := searchCredential(pairs, c.Request.Header.Get("Authorization")) + user, ok := pairs.searchCredential(c.Request.Header.Get("Authorization")) if !ok { // Credentials doesn't match, we return 401 Unauthorized and abort request. c.Writer.Header().Set("WWW-Authenticate", realm) @@ -80,17 +93,9 @@ func processAccounts(accounts Accounts) authPairs { return pairs } -func searchCredential(pairs authPairs, auth string) (string, bool) { - if len(auth) == 0 { - return "", false - } - // Search user in the slice of allowed credentials - r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Value >= auth }) - if r < len(pairs) && secureCompare(pairs[r].Value, auth) { - return pairs[r].User, true - } else { - return "", false - } +func authorizationHeader(user, password string) string { + base := user + ":" + password + return "Basic " + base64.StdEncoding.EncodeToString([]byte(base)) } func secureCompare(given, actual string) bool { diff --git a/auth_test.go b/auth_test.go index d2f165c..a378c1a 100644 --- a/auth_test.go +++ b/auth_test.go @@ -9,77 +9,136 @@ import ( "net/http" "net/http/httptest" "testing" + + "github.com/stretchr/testify/assert" ) -func TestBasicAuthSucceed(t *testing.T) { - req, _ := http.NewRequest("GET", "/login", nil) - w := httptest.NewRecorder() +func TestBasicAuth(t *testing.T) { + accounts := Accounts{ + "admin": "password", + "foo": "bar", + "bar": "foo", + } + expectedPairs := authPairs{ + authPair{ + User: "admin", + Value: "Basic YWRtaW46cGFzc3dvcmQ=", + }, + authPair{ + User: "bar", + Value: "Basic YmFyOmZvbw==", + }, + authPair{ + User: "foo", + Value: "Basic Zm9vOmJhcg==", + }, + } + pairs := processAccounts(accounts) + assert.Equal(t, pairs, expectedPairs) +} - r := New() - accounts := Accounts{"admin": "password"} - r.Use(BasicAuth(accounts)) +func TestBasicAuthFails(t *testing.T) { + assert.Panics(t, func() { processAccounts(nil) }) + assert.Panics(t, func() { + processAccounts(Accounts{ + "": "password", + "foo": "bar", + }) + }) +} - r.GET("/login", func(c *Context) { - c.String(200, "autorized") +func TestBasicAuthSearchCredential(t *testing.T) { + pairs := processAccounts(Accounts{ + "admin": "password", + "foo": "bar", + "bar": "foo", }) - req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password"))) - r.ServeHTTP(w, req) + user, found := pairs.searchCredential(authorizationHeader("admin", "password")) + assert.Equal(t, user, "admin") + assert.True(t, found) - if w.Code != 200 { - t.Errorf("Response code should be Ok, was: %d", w.Code) - } - bodyAsString := w.Body.String() + user, found = pairs.searchCredential(authorizationHeader("foo", "bar")) + assert.Equal(t, user, "foo") + assert.True(t, found) - if bodyAsString != "autorized" { - t.Errorf("Response body should be `autorized`, was %s", bodyAsString) - } + user, found = pairs.searchCredential(authorizationHeader("bar", "foo")) + assert.Equal(t, user, "bar") + assert.True(t, found) + + user, found = pairs.searchCredential(authorizationHeader("admins", "password")) + assert.Empty(t, user) + assert.False(t, found) + + user, found = pairs.searchCredential(authorizationHeader("foo", "bar ")) + assert.Empty(t, user) + assert.False(t, found) +} + +func TestBasicAuthAuthorizationHeader(t *testing.T) { + assert.Equal(t, authorizationHeader("admin", "password"), "Basic YWRtaW46cGFzc3dvcmQ=") +} + +func TestBasicAuthSecureCompare(t *testing.T) { + assert.True(t, secureCompare("1234567890", "1234567890")) + assert.False(t, secureCompare("123456789", "1234567890")) + assert.False(t, secureCompare("12345678900", "1234567890")) + assert.False(t, secureCompare("1234567891", "1234567890")) +} + +func TestBasicAuthSucceed(t *testing.T) { + accounts := Accounts{"admin": "password"} + router := New() + router.Use(BasicAuth(accounts)) + router.GET("/login", func(c *Context) { + c.String(200, c.MustGet(AuthUserKey).(string)) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/login", nil) + req.Header.Set("Authorization", authorizationHeader("admin", "password")) + router.ServeHTTP(w, req) + + assert.Equal(t, w.Code, 200) + assert.Equal(t, w.Body.String(), "admin") } func TestBasicAuth401(t *testing.T) { - req, _ := http.NewRequest("GET", "/login", nil) - w := httptest.NewRecorder() - - r := New() + called := false accounts := Accounts{"foo": "bar"} - r.Use(BasicAuth(accounts)) - - r.GET("/login", func(c *Context) { - c.String(200, "autorized") + router := New() + router.Use(BasicAuth(accounts)) + router.GET("/login", func(c *Context) { + called = true + c.String(200, c.MustGet(AuthUserKey).(string)) }) + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/login", nil) req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password"))) - r.ServeHTTP(w, req) + router.ServeHTTP(w, req) - if w.Code != 401 { - t.Errorf("Response code should be Not autorized, was: %d", w.Code) - } - - if w.HeaderMap.Get("WWW-Authenticate") != "Basic realm=\"Authorization Required\"" { - t.Errorf("WWW-Authenticate header is incorrect: %s", w.HeaderMap.Get("Content-Type")) - } + assert.False(t, called) + assert.Equal(t, w.Code, 401) + assert.Equal(t, w.HeaderMap.Get("WWW-Authenticate"), "Basic realm=\"Authorization Required\"") } func TestBasicAuth401WithCustomRealm(t *testing.T) { - req, _ := http.NewRequest("GET", "/login", nil) - w := httptest.NewRecorder() - - r := New() + called := false accounts := Accounts{"foo": "bar"} - r.Use(BasicAuthForRealm(accounts, "My Custom Realm")) - - r.GET("/login", func(c *Context) { - c.String(200, "autorized") + router := New() + router.Use(BasicAuthForRealm(accounts, "My Custom Realm")) + router.GET("/login", func(c *Context) { + called = true + c.String(200, c.MustGet(AuthUserKey).(string)) }) + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/login", nil) req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password"))) - r.ServeHTTP(w, req) + router.ServeHTTP(w, req) - if w.Code != 401 { - t.Errorf("Response code should be Not autorized, was: %d", w.Code) - } - - if w.HeaderMap.Get("WWW-Authenticate") != "Basic realm=\"My Custom Realm\"" { - t.Errorf("WWW-Authenticate header is incorrect: %s", w.HeaderMap.Get("Content-Type")) - } + assert.False(t, called) + assert.Equal(t, w.Code, 401) + assert.Equal(t, w.HeaderMap.Get("WWW-Authenticate"), "Basic realm=\"My Custom Realm\"") }