|
| 1 | +package auth |
| 2 | + |
| 3 | +import ( |
| 4 | +"errors" |
| 5 | +"testing" |
| 6 | +"time" |
| 7 | +) |
| 8 | + |
| 9 | +type mockStreamingProvider struct { |
| 10 | +credentials Credentials |
| 11 | +err error |
| 12 | +updates chan Credentials |
| 13 | +} |
| 14 | + |
| 15 | +func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider { |
| 16 | +return &mockStreamingProvider{ |
| 17 | +credentials: initialCreds, |
| 18 | +updates: make(chan Credentials, 10), |
| 19 | +} |
| 20 | +} |
| 21 | + |
| 22 | +func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) { |
| 23 | +if m.err != nil { |
| 24 | +return nil, nil, m.err |
| 25 | +} |
| 26 | + |
| 27 | +// Send initial credentials |
| 28 | +listener.OnNext(m.credentials) |
| 29 | + |
| 30 | +// Start goroutine to handle updates |
| 31 | +go func() { |
| 32 | +for creds := range m.updates { |
| 33 | +listener.OnNext(creds) |
| 34 | +} |
| 35 | +}() |
| 36 | + |
| 37 | +return m.credentials, func() error { |
| 38 | +close(m.updates) |
| 39 | +return nil |
| 40 | +}, nil |
| 41 | +} |
| 42 | + |
| 43 | +func TestStreamingCredentialsProvider(t *testing.T) { |
| 44 | +t.Run("successful subscription", func(t *testing.T) { |
| 45 | +initialCreds := NewBasicCredentials("user1", "pass1") |
| 46 | +provider := newMockStreamingProvider(initialCreds) |
| 47 | + |
| 48 | +var receivedCreds []Credentials |
| 49 | +var receivedErrors []error |
| 50 | + |
| 51 | +listener := NewReAuthCredentialsListener( |
| 52 | +func(creds Credentials) error { |
| 53 | +receivedCreds = append(receivedCreds, creds) |
| 54 | +return nil |
| 55 | +}, |
| 56 | +func(err error) { |
| 57 | +receivedErrors = append(receivedErrors, err) |
| 58 | +}, |
| 59 | +) |
| 60 | + |
| 61 | +creds, cancel, err := provider.Subscribe(listener) |
| 62 | +if err != nil { |
| 63 | +t.Fatalf("unexpected error: %v", err) |
| 64 | +} |
| 65 | +if cancel == nil { |
| 66 | +t.Fatal("expected cancel function to be non-nil") |
| 67 | +} |
| 68 | +if creds != initialCreds { |
| 69 | +t.Fatalf("expected credentials %v, got %v", initialCreds, creds) |
| 70 | +} |
| 71 | +if len(receivedCreds) != 1 { |
| 72 | +t.Fatalf("expected 1 received credential, got %d", len(receivedCreds)) |
| 73 | +} |
| 74 | +if receivedCreds[0] != initialCreds { |
| 75 | +t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0]) |
| 76 | +} |
| 77 | +if len(receivedErrors) != 0 { |
| 78 | +t.Fatalf("expected no errors, got %d", len(receivedErrors)) |
| 79 | +} |
| 80 | + |
| 81 | +// Send an update |
| 82 | +newCreds := NewBasicCredentials("user2", "pass2") |
| 83 | +provider.updates <- newCreds |
| 84 | + |
| 85 | +// Wait for update to be processed |
| 86 | +time.Sleep(100 * time.Millisecond) |
| 87 | +if len(receivedCreds) != 2 { |
| 88 | +t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds)) |
| 89 | +} |
| 90 | +if receivedCreds[1] != newCreds { |
| 91 | +t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1]) |
| 92 | +} |
| 93 | + |
| 94 | +// Cancel subscription |
| 95 | +if err := cancel(); err != nil { |
| 96 | +t.Fatalf("unexpected error cancelling subscription: %v", err) |
| 97 | +} |
| 98 | +}) |
| 99 | + |
| 100 | +t.Run("subscription error", func(t *testing.T) { |
| 101 | +provider := &mockStreamingProvider{ |
| 102 | +err: errors.New("subscription failed"), |
| 103 | +} |
| 104 | + |
| 105 | +var receivedCreds []Credentials |
| 106 | +var receivedErrors []error |
| 107 | + |
| 108 | +listener := NewReAuthCredentialsListener( |
| 109 | +func(creds Credentials) error { |
| 110 | +receivedCreds = append(receivedCreds, creds) |
| 111 | +return nil |
| 112 | +}, |
| 113 | +func(err error) { |
| 114 | +receivedErrors = append(receivedErrors, err) |
| 115 | +}, |
| 116 | +) |
| 117 | + |
| 118 | +creds, cancel, err := provider.Subscribe(listener) |
| 119 | +if err == nil { |
| 120 | +t.Fatal("expected error, got nil") |
| 121 | +} |
| 122 | +if cancel != nil { |
| 123 | +t.Fatal("expected cancel function to be nil") |
| 124 | +} |
| 125 | +if creds != nil { |
| 126 | +t.Fatalf("expected nil credentials, got %v", creds) |
| 127 | +} |
| 128 | +if len(receivedCreds) != 0 { |
| 129 | +t.Fatalf("expected no received credentials, got %d", len(receivedCreds)) |
| 130 | +} |
| 131 | +if len(receivedErrors) != 0 { |
| 132 | +t.Fatalf("expected no errors, got %d", len(receivedErrors)) |
| 133 | +} |
| 134 | +}) |
| 135 | + |
| 136 | +t.Run("re-auth error", func(t *testing.T) { |
| 137 | +initialCreds := NewBasicCredentials("user1", "pass1") |
| 138 | +provider := newMockStreamingProvider(initialCreds) |
| 139 | + |
| 140 | +reauthErr := errors.New("re-auth failed") |
| 141 | +var receivedErrors []error |
| 142 | + |
| 143 | +listener := NewReAuthCredentialsListener( |
| 144 | +func(creds Credentials) error { |
| 145 | +return reauthErr |
| 146 | +}, |
| 147 | +func(err error) { |
| 148 | +receivedErrors = append(receivedErrors, err) |
| 149 | +}, |
| 150 | +) |
| 151 | + |
| 152 | +creds, cancel, err := provider.Subscribe(listener) |
| 153 | +if err != nil { |
| 154 | +t.Fatalf("unexpected error: %v", err) |
| 155 | +} |
| 156 | +if cancel == nil { |
| 157 | +t.Fatal("expected cancel function to be non-nil") |
| 158 | +} |
| 159 | +if creds != initialCreds { |
| 160 | +t.Fatalf("expected credentials %v, got %v", initialCreds, creds) |
| 161 | +} |
| 162 | +if len(receivedErrors) != 1 { |
| 163 | +t.Fatalf("expected 1 error, got %d", len(receivedErrors)) |
| 164 | +} |
| 165 | +if receivedErrors[0] != reauthErr { |
| 166 | +t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0]) |
| 167 | +} |
| 168 | + |
| 169 | +if err := cancel(); err != nil { |
| 170 | +t.Fatalf("unexpected error cancelling subscription: %v", err) |
| 171 | +} |
| 172 | +}) |
| 173 | +} |
| 174 | + |
| 175 | +func TestBasicCredentials(t *testing.T) { |
| 176 | +t.Run("basic auth", func(t *testing.T) { |
| 177 | +creds := NewBasicCredentials("user1", "pass1") |
| 178 | +username, password := creds.BasicAuth() |
| 179 | +if username != "user1" { |
| 180 | +t.Fatalf("expected username 'user1', got '%s'", username) |
| 181 | +} |
| 182 | +if password != "pass1" { |
| 183 | +t.Fatalf("expected password 'pass1', got '%s'", password) |
| 184 | +} |
| 185 | +}) |
| 186 | + |
| 187 | +t.Run("raw credentials", func(t *testing.T) { |
| 188 | +creds := NewBasicCredentials("user1", "pass1") |
| 189 | +raw := creds.RawCredentials() |
| 190 | +expected := "user1:pass1" |
| 191 | +if raw != expected { |
| 192 | +t.Fatalf("expected raw credentials '%s', got '%s'", expected, raw) |
| 193 | +} |
| 194 | +}) |
| 195 | + |
| 196 | +t.Run("empty username", func(t *testing.T) { |
| 197 | +creds := NewBasicCredentials("", "pass1") |
| 198 | +username, password := creds.BasicAuth() |
| 199 | +if username != "" { |
| 200 | +t.Fatalf("expected empty username, got '%s'", username) |
| 201 | +} |
| 202 | +if password != "pass1" { |
| 203 | +t.Fatalf("expected password 'pass1', got '%s'", password) |
| 204 | +} |
| 205 | +}) |
| 206 | +} |
| 207 | + |
| 208 | +func TestReAuthCredentialsListener(t *testing.T) { |
| 209 | +t.Run("successful re-auth", func(t *testing.T) { |
| 210 | +var reAuthCalled bool |
| 211 | +var onErrCalled bool |
| 212 | +var receivedCreds Credentials |
| 213 | + |
| 214 | +listener := NewReAuthCredentialsListener( |
| 215 | +func(creds Credentials) error { |
| 216 | +reAuthCalled = true |
| 217 | +receivedCreds = creds |
| 218 | +return nil |
| 219 | +}, |
| 220 | +func(err error) { |
| 221 | +onErrCalled = true |
| 222 | +}, |
| 223 | +) |
| 224 | + |
| 225 | +creds := NewBasicCredentials("user1", "pass1") |
| 226 | +listener.OnNext(creds) |
| 227 | + |
| 228 | +if !reAuthCalled { |
| 229 | +t.Fatal("expected reAuth to be called") |
| 230 | +} |
| 231 | +if onErrCalled { |
| 232 | +t.Fatal("expected onErr not to be called") |
| 233 | +} |
| 234 | +if receivedCreds != creds { |
| 235 | +t.Fatalf("expected credentials %v, got %v", creds, receivedCreds) |
| 236 | +} |
| 237 | +}) |
| 238 | + |
| 239 | +t.Run("re-auth error", func(t *testing.T) { |
| 240 | +var reAuthCalled bool |
| 241 | +var onErrCalled bool |
| 242 | +var receivedErr error |
| 243 | +expectedErr := errors.New("re-auth failed") |
| 244 | + |
| 245 | +listener := NewReAuthCredentialsListener( |
| 246 | +func(creds Credentials) error { |
| 247 | +reAuthCalled = true |
| 248 | +return expectedErr |
| 249 | +}, |
| 250 | +func(err error) { |
| 251 | +onErrCalled = true |
| 252 | +receivedErr = err |
| 253 | +}, |
| 254 | +) |
| 255 | + |
| 256 | +creds := NewBasicCredentials("user1", "pass1") |
| 257 | +listener.OnNext(creds) |
| 258 | + |
| 259 | +if !reAuthCalled { |
| 260 | +t.Fatal("expected reAuth to be called") |
| 261 | +} |
| 262 | +if !onErrCalled { |
| 263 | +t.Fatal("expected onErr to be called") |
| 264 | +} |
| 265 | +if receivedErr != expectedErr { |
| 266 | +t.Fatalf("expected error %v, got %v", expectedErr, receivedErr) |
| 267 | +} |
| 268 | +}) |
| 269 | + |
| 270 | +t.Run("on error", func(t *testing.T) { |
| 271 | +var onErrCalled bool |
| 272 | +var receivedErr error |
| 273 | +expectedErr := errors.New("provider error") |
| 274 | + |
| 275 | +listener := NewReAuthCredentialsListener( |
| 276 | +func(creds Credentials) error { |
| 277 | +return nil |
| 278 | +}, |
| 279 | +func(err error) { |
| 280 | +onErrCalled = true |
| 281 | +receivedErr = err |
| 282 | +}, |
| 283 | +) |
| 284 | + |
| 285 | +listener.OnError(expectedErr) |
| 286 | + |
| 287 | +if !onErrCalled { |
| 288 | +t.Fatal("expected onErr to be called") |
| 289 | +} |
| 290 | +if receivedErr != expectedErr { |
| 291 | +t.Fatalf("expected error %v, got %v", expectedErr, receivedErr) |
| 292 | +} |
| 293 | +}) |
| 294 | + |
| 295 | +t.Run("nil callbacks", func(t *testing.T) { |
| 296 | +listener := NewReAuthCredentialsListener(nil, nil) |
| 297 | + |
| 298 | +// Should not panic |
| 299 | +listener.OnNext(NewBasicCredentials("user1", "pass1")) |
| 300 | +listener.OnError(errors.New("test error")) |
| 301 | +}) |
| 302 | +} |
0 commit comments