Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 88 additions & 62 deletions middleware/basic_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package middleware

import (
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"strings"
Expand All @@ -16,78 +17,103 @@ import (

func TestBasicAuth(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)
f := func(u, p string, c echo.Context) (bool, error) {

mockValidator := func(u, p string, c echo.Context) (bool, error) {
if u == "joe" && p == "secret" {
return true, nil
}
return false, nil
}
h := BasicAuth(f)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})

// Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))

h = BasicAuthWithConfig(BasicAuthConfig{
Validator: f,
Realm: "someRealm",
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})

// Valid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))

// Case-insensitive header scheme
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))

// Invalid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
tests := []struct {
name string
authHeader string
expectedCode int
expectedAuth string
skipperResult bool
expectedErr bool
expectedErrMsg string
}{
{
name: "Valid credentials",
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
expectedCode: http.StatusOK,
},
{
name: "Case-insensitive header scheme",
authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
expectedCode: http.StatusOK,
},
{
name: "Invalid credentials",
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")),
expectedCode: http.StatusUnauthorized,
expectedAuth: basic + ` realm="someRealm"`,
expectedErr: true,
expectedErrMsg: "Unauthorized",
},
{
name: "Invalid base64 string",
authHeader: basic + " invalidString",
expectedCode: http.StatusBadRequest,
expectedErr: true,
expectedErrMsg: "Bad Request",
},
{
name: "Missing Authorization header",
expectedCode: http.StatusUnauthorized,
expectedErr: true,
expectedErrMsg: "Unauthorized",
},
{
name: "Invalid Authorization header",
authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")),
expectedCode: http.StatusUnauthorized,
expectedErr: true,
expectedErrMsg: "Unauthorized",
},
{
name: "Skipped Request",
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")),
expectedCode: http.StatusOK,
skipperResult: true,
},
}

// Invalid base64 string
auth = basic + " invalidString"
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)

// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
if tt.authHeader != "" {
req.Header.Set(echo.HeaderAuthorization, tt.authHeader)
}

h = BasicAuthWithConfig(BasicAuthConfig{
Validator: f,
Realm: "someRealm",
Skipper: func(c echo.Context) bool {
return true
},
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
h := BasicAuthWithConfig(BasicAuthConfig{
Validator: mockValidator,
Realm: "someRealm",
Skipper: func(c echo.Context) bool {
return tt.skipperResult
},
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})

// Skipped Request
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
err := h(c)

if tt.expectedErr {
var he *echo.HTTPError
errors.As(err, &he)
assert.Equal(t, tt.expectedCode, he.Code)
if tt.expectedAuth != "" {
assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate))
}
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedCode, res.Code)
}
})
}
}
Loading