@@ -2,6 +2,7 @@ package middleware
|
2 | 2 |
|
3 | 3 | import (
|
4 | 4 | "encoding/base64"
|
| 5 | +"errors" |
5 | 6 | "net/http"
|
6 | 7 | "net/http/httptest"
|
7 | 8 | "strings"
|
@@ -11,11 +12,139 @@ import (
|
11 | 12 | ".com/stretchr/testify/assert"
|
12 | 13 | )
|
13 | 14 |
|
| 15 | +func TestBasicAuthWithConfig(t *testing.T) { |
| 16 | +validatorFunc := func(u, p string, c echo.Context) (bool, error) { |
| 17 | +if u == "joe" && p == "secret" { |
| 18 | +return true, nil |
| 19 | +} |
| 20 | +if u == "error" { |
| 21 | +return false, errors.New(p) |
| 22 | +} |
| 23 | +return false, nil |
| 24 | +} |
| 25 | +defaultConfig := BasicAuthConfig{Validator: validatorFunc} |
| 26 | + |
| 27 | +// we can not add OK value here because ranging over map returns random order. We just try to trigger break |
| 28 | +tooManyAuths := make([]string, 0) |
| 29 | +for i := 0; i < headerCountLimit+2; i++ { |
| 30 | +tooManyAuths = append(tooManyAuths, basic+" "+base64.StdEncoding.EncodeToString([]byte("nope:nope"))) |
| 31 | +} |
| 32 | + |
| 33 | +var testCases = []struct { |
| 34 | +name string |
| 35 | +givenConfig BasicAuthConfig |
| 36 | +whenAuth []string |
| 37 | +expectHeader string |
| 38 | +expectErr string |
| 39 | +}{ |
| 40 | +{ |
| 41 | +name: "ok", |
| 42 | +givenConfig: defaultConfig, |
| 43 | +whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, |
| 44 | +}, |
| 45 | +{ |
| 46 | +name: "ok, from multiple auth headers one is ok", |
| 47 | +givenConfig: defaultConfig, |
| 48 | +whenAuth: []string{ |
| 49 | +"Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), // different type |
| 50 | +basic + " NOT_BASE64", // invalid basic auth |
| 51 | +basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), // OK |
| 52 | +}, |
| 53 | +}, |
| 54 | +{ |
| 55 | +name: "nok, invalid Authorization header", |
| 56 | +givenConfig: defaultConfig, |
| 57 | +whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, |
| 58 | +expectHeader: basic + ` realm=Restricted`, |
| 59 | +expectErr: "code=401, message=Unauthorized", |
| 60 | +}, |
| 61 | +{ |
| 62 | +name: "nok, not base64 Authorization header", |
| 63 | +givenConfig: defaultConfig, |
| 64 | +whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"}, |
| 65 | +expectErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 3", |
| 66 | +}, |
| 67 | +{ |
| 68 | +name: "nok, missing Authorization header", |
| 69 | +givenConfig: defaultConfig, |
| 70 | +expectHeader: basic + ` realm=Restricted`, |
| 71 | +expectErr: "code=401, message=Unauthorized", |
| 72 | +}, |
| 73 | +{ |
| 74 | +name: "nok, too many invalid Authorization header", |
| 75 | +givenConfig: defaultConfig, |
| 76 | +whenAuth: tooManyAuths, |
| 77 | +expectHeader: basic + ` realm=Restricted`, |
| 78 | +expectErr: "code=401, message=Unauthorized", |
| 79 | +}, |
| 80 | +{ |
| 81 | +name: "ok, realm", |
| 82 | +givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, |
| 83 | +whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, |
| 84 | +}, |
| 85 | +{ |
| 86 | +name: "ok, realm, case-insensitive header scheme", |
| 87 | +givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, |
| 88 | +whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, |
| 89 | +}, |
| 90 | +{ |
| 91 | +name: "nok, realm, invalid Authorization header", |
| 92 | +givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, |
| 93 | +whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, |
| 94 | +expectHeader: basic + ` realm="someRealm"`, |
| 95 | +expectErr: "code=401, message=Unauthorized", |
| 96 | +}, |
| 97 | +{ |
| 98 | +name: "nok, validator func returns an error", |
| 99 | +givenConfig: defaultConfig, |
| 100 | +whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))}, |
| 101 | +expectErr: "my_error", |
| 102 | +}, |
| 103 | +{ |
| 104 | +name: "ok, skipped", |
| 105 | +givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool { |
| 106 | +return true |
| 107 | +}}, |
| 108 | +whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, |
| 109 | +}, |
| 110 | +} |
| 111 | + |
| 112 | +for _, tc := range testCases { |
| 113 | +t.Run(tc.name, func(t *testing.T) { |
| 114 | +e := echo.New() |
| 115 | + |
| 116 | +mw := BasicAuthWithConfig(tc.givenConfig) |
| 117 | + |
| 118 | +h := mw(func(c echo.Context) error { |
| 119 | +return c.String(http.StatusTeapot, "test") |
| 120 | +}) |
| 121 | + |
| 122 | +req := httptest.NewRequest(http.MethodGet, "/", nil) |
| 123 | +res := httptest.NewRecorder() |
| 124 | + |
| 125 | +if len(tc.whenAuth) != 0 { |
| 126 | +for _, a := range tc.whenAuth { |
| 127 | +req.Header.Add(echo.HeaderAuthorization, a) |
| 128 | +} |
| 129 | +} |
| 130 | +err := h(e.NewContext(req, res)) |
| 131 | + |
| 132 | +if tc.expectErr != "" { |
| 133 | +assert.Equal(t, http.StatusOK, res.Code) |
| 134 | +assert.EqualError(t, err, tc.expectErr) |
| 135 | +} else { |
| 136 | +assert.Equal(t, http.StatusTeapot, res.Code) |
| 137 | +assert.NoError(t, err) |
| 138 | +} |
| 139 | +if tc.expectHeader != "" { |
| 140 | +assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate)) |
| 141 | +} |
| 142 | +}) |
| 143 | +} |
| 144 | +} |
| 145 | + |
14 | 146 | func TestBasicAuth(t *testing.T) {
|
15 | 147 | e := echo.New()
|
16 |
| -req := httptest.NewRequest(http.MethodGet, "/", nil) |
17 |
| -res := httptest.NewRecorder() |
18 |
| -c := e.NewContext(req, res) |
19 | 148 | f := func(u, p string, c echo.Context) (bool, error) {
|
20 | 149 | if u == "joe" && p == "secret" {
|
21 | 150 | return true, nil
|
@@ -26,50 +155,11 @@ func TestBasicAuth(t *testing.T) {
|
26 | 155 | return c.String(http.StatusOK, "test")
|
27 | 156 | })
|
28 | 157 |
|
29 |
| -// Valid credentials |
30 |
| -auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) |
31 |
| -req.Header.Set(echo.HeaderAuthorization, auth) |
32 |
| -assert.NoError(t, h(c)) |
33 |
| - |
34 |
| -h = BasicAuthWithConfig(BasicAuthConfig{ |
35 |
| -Skipper: nil, |
36 |
| -Validator: f, |
37 |
| -Realm: "someRealm", |
38 |
| -})(func(c echo.Context) error { |
39 |
| -return c.String(http.StatusOK, "test") |
40 |
| -}) |
41 |
| - |
42 |
| -// Valid credentials |
43 |
| -auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) |
44 |
| -req.Header.Set(echo.HeaderAuthorization, auth) |
45 |
| -assert.NoError(t, h(c)) |
| 158 | +req := httptest.NewRequest(http.MethodGet, "/", nil) |
| 159 | +res := httptest.NewRecorder() |
| 160 | +c := e.NewContext(req, res) |
46 | 161 |
|
47 |
| -// Case-insensitive header scheme |
48 |
| -auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) |
| 162 | +auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) |
49 | 163 | req.Header.Set(echo.HeaderAuthorization, auth)
|
50 | 164 | assert.NoError(t, h(c))
|
51 |
| - |
52 |
| -// Invalid credentials |
53 |
| -auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")) |
54 |
| -req.Header.Set(echo.HeaderAuthorization, auth) |
55 |
| -he := h(c).(*echo.HTTPError) |
56 |
| -assert.Equal(t, http.StatusUnauthorized, he.Code) |
57 |
| -assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) |
58 |
| - |
59 |
| -// Invalid base64 string |
60 |
| -auth = basic + " invalidString" |
61 |
| -req.Header.Set(echo.HeaderAuthorization, auth) |
62 |
| -he = h(c).(*echo.HTTPError) |
63 |
| -assert.Equal(t, http.StatusBadRequest, he.Code) |
64 |
| - |
65 |
| -// Missing Authorization header |
66 |
| -req.Header.Del(echo.HeaderAuthorization) |
67 |
| -he = h(c).(*echo.HTTPError) |
68 |
| -assert.Equal(t, http.StatusUnauthorized, he.Code) |
69 |
| - |
70 |
| -// Invalid Authorization header |
71 |
| -auth = base64.StdEncoding.EncodeToString([]byte("invalid")) |
72 |
| -req.Header.Set(echo.HeaderAuthorization, auth) |
73 |
| -he = h(c).(*echo.HTTPError) |
74 |
| -assert.Equal(t, http.StatusUnauthorized, he.Code) |
75 | 165 | }
|
0 commit comments