16
16
******************************************************************************/
17
17
package org .mitre .openid .connect ;
18
18
19
- import java .io .Serializable ;
20
19
import java .security .NoSuchAlgorithmException ;
21
20
import java .security .spec .InvalidKeySpecException ;
22
21
import java .text .ParseException ;
23
22
import java .util .Collections ;
24
- import java .util .HashMap ;
25
23
import java .util .Map ;
26
24
import java .util .Set ;
27
25
45
43
46
44
import com .google .common .base .Strings ;
47
45
import com .google .common .collect .ImmutableMap ;
48
- import com .google .common .collect .Maps ;
46
+ import com .google .gson .JsonElement ;
47
+ import com .google .gson .JsonObject ;
48
+ import com .google .gson .JsonParser ;
49
49
import com .nimbusds .jose .Algorithm ;
50
50
import com .nimbusds .jose .JWEObject .State ;
51
51
import com .nimbusds .jose .JWSAlgorithm ;
@@ -76,6 +76,8 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
76
76
@ Autowired
77
77
private JwtEncryptionAndDecryptionService encryptionService ;
78
78
79
+ private JsonParser parser = new JsonParser ();
80
+
79
81
/**
80
82
* Constructor with arguments
81
83
*
@@ -97,43 +99,43 @@ public OAuth2Request createOAuth2Request(AuthorizationRequest request) {
97
99
@ Override
98
100
public AuthorizationRequest createAuthorizationRequest (Map <String , String > inputParams ) {
99
101
100
- Map <String , String > parameters = processRequestObject (inputParams );
101
-
102
- String clientId = parameters .get ("client_id" );
103
- ClientDetails client = null ;
104
-
105
- if (clientId != null ) {
106
- client = clientDetailsService .loadClientByClientId (clientId );
107
- }
108
-
109
- AuthorizationRequest request = new AuthorizationRequest (parameters , Collections .<String , String > emptyMap (),
110
- parameters .get (OAuth2Utils .CLIENT_ID ),
111
- OAuth2Utils .parseParameterList (parameters .get (OAuth2Utils .SCOPE )), null ,
112
- null , false , parameters .get (OAuth2Utils .STATE ),
113
- parameters .get (OAuth2Utils .REDIRECT_URI ),
114
- OAuth2Utils .parseParameterList (parameters .get (OAuth2Utils .RESPONSE_TYPE )));
115
-
116
- Set <String > scopes = OAuth2Utils .parseParameterList (parameters .get ("scope" ));
117
- if ((scopes == null || scopes .isEmpty ()) && client != null ) {
118
- Set <String > clientScopes = client .getScope ();
119
- scopes = clientScopes ;
120
- }
121
102
122
- request .setScope (scopes );
103
+ AuthorizationRequest request = new AuthorizationRequest (inputParams , Collections .<String , String > emptyMap (),
104
+ inputParams .get (OAuth2Utils .CLIENT_ID ),
105
+ OAuth2Utils .parseParameterList (inputParams .get (OAuth2Utils .SCOPE )), null ,
106
+ null , false , inputParams .get (OAuth2Utils .STATE ),
107
+ inputParams .get (OAuth2Utils .REDIRECT_URI ),
108
+ OAuth2Utils .parseParameterList (inputParams .get (OAuth2Utils .RESPONSE_TYPE )));
123
109
124
110
//Add extension parameters to the 'extensions' map
125
- Map <String , Serializable > extensions = Maps .newHashMap ();
126
- if (parameters .containsKey ("prompt" )) {
127
- extensions .put ("prompt" , parameters .get ("prompt" ));
111
+
112
+ if (inputParams .containsKey ("prompt" )) {
113
+ request .getExtensions ().put ("prompt" , inputParams .get ("prompt" ));
114
+ }
115
+ if (inputParams .containsKey ("nonce" )) {
116
+ request .getExtensions ().put ("nonce" , inputParams .get ("nonce" ));
128
117
}
129
- if (parameters .containsKey ("request" )) {
130
- extensions .put ("request" , parameters .get ("request" ));
118
+
119
+ if (inputParams .containsKey ("claims" )) {
120
+ JsonObject claimsRequest = parseClaimRequest (inputParams .get ("claims" ));
121
+ if (claimsRequest != null ) {
122
+ request .getExtensions ().put ("claims" , claimsRequest .toString ());
123
+ }
131
124
}
132
- if (parameters .containsKey ("nonce" )) {
133
- extensions .put ("nonce" , parameters .get ("nonce" ));
125
+
126
+ if (inputParams .containsKey ("request" )) {
127
+ request .getExtensions ().put ("request" , inputParams .get ("request" ));
128
+ processRequestObject (inputParams .get ("request" ), request );
134
129
}
135
130
136
- request .setExtensions (extensions );
131
+
132
+ if ((request .getScope () == null || request .getScope ().isEmpty ())) {
133
+ if (request .getClientId () != null ) {
134
+ ClientDetails client = clientDetailsService .loadClientByClientId (request .getClientId ());
135
+ Set <String > clientScopes = client .getScope ();
136
+ request .setScope (clientScopes );
137
+ }
138
+ }
137
139
138
140
return request ;
139
141
}
@@ -142,49 +144,28 @@ public AuthorizationRequest createAuthorizationRequest(Map<String, String> input
142
144
* @param inputParams
143
145
* @return
144
146
*/
145
- private Map <String , String > processRequestObject (Map <String , String > inputParams ) {
146
-
147
- String jwtString = inputParams .get ("request" );
148
-
149
- // if there's no request object, bail early
150
- if (Strings .isNullOrEmpty (jwtString )) {
151
- return inputParams ;
152
- }
153
-
154
- // start by copying over what's already in there
155
- Map <String , String > parameters = new HashMap <String , String >(inputParams );
147
+ private void processRequestObject (String jwtString , AuthorizationRequest request ) {
156
148
157
149
// parse the request object
158
150
try {
159
151
JWT jwt = JWTParser .parse (jwtString );
160
152
161
- /*
162
- if (jwt instanceof EncryptedJWT) {
163
- // TODO: it's an encrypted JWT, decrypt it and use it
164
- } else {
165
- // it's not encrypted...
166
- }
167
- */
168
-
169
-
170
-
171
-
172
153
// TODO: check parameter consistency, move keys to constants
173
154
174
155
if (jwt instanceof SignedJWT ) {
175
156
// it's a signed JWT, check the signature
176
157
177
158
SignedJWT signedJwt = (SignedJWT )jwt ;
178
-
179
- String clientId = inputParams . get ( "client_id" );
180
- if (clientId == null ) {
181
- clientId = signedJwt .getJWTClaimsSet ().getStringClaim ("client_id" );
159
+
160
+ // need to check clientId first so that we can load the client to check other fields
161
+ if (request . getClientId () == null ) {
162
+ request . setClientId ( signedJwt .getJWTClaimsSet ().getStringClaim ("client_id" ) );
182
163
}
183
164
184
- ClientDetailsEntity client = clientDetailsService .loadClientByClientId (clientId );
165
+ ClientDetailsEntity client = clientDetailsService .loadClientByClientId (request . getClientId () );
185
166
186
167
if (client == null ) {
187
- throw new InvalidClientException ("Client not found: " + clientId );
168
+ throw new InvalidClientException ("Client not found: " + request . getClientId () );
188
169
}
189
170
190
171
@@ -239,15 +220,15 @@ private Map<String, String> processRequestObject(Map<String, String> inputParams
239
220
} else if (jwt instanceof PlainJWT ) {
240
221
PlainJWT plainJwt = (PlainJWT )jwt ;
241
222
242
- String clientId = inputParams . get ( "client_id" );
243
- if (clientId == null ) {
244
- clientId = plainJwt .getJWTClaimsSet ().getStringClaim ("client_id" );
223
+ // need to check clientId first so that we can load the client to check other fields
224
+ if (request . getClientId () == null ) {
225
+ request . setClientId ( plainJwt .getJWTClaimsSet ().getStringClaim ("client_id" ) );
245
226
}
246
227
247
- ClientDetailsEntity client = clientDetailsService .loadClientByClientId (clientId );
228
+ ClientDetailsEntity client = clientDetailsService .loadClientByClientId (request . getClientId () );
248
229
249
230
if (client == null ) {
250
- throw new InvalidClientException ("Client not found: " + clientId );
231
+ throw new InvalidClientException ("Client not found: " + request . getClientId () );
251
232
}
252
233
253
234
if (client .getRequestObjectSigningAlg () == null ) {
@@ -270,13 +251,28 @@ private Map<String, String> processRequestObject(Map<String, String> inputParams
270
251
throw new InvalidClientException ("Unable to decrypt the request object" );
271
252
}
272
253
254
+ // need to check clientId first so that we can load the client to check other fields
255
+ if (request .getClientId () == null ) {
256
+ request .setClientId (encryptedJWT .getJWTClaimsSet ().getStringClaim ("client_id" ));
257
+ }
258
+
259
+ ClientDetailsEntity client = clientDetailsService .loadClientByClientId (request .getClientId ());
260
+
261
+ if (client == null ) {
262
+ throw new InvalidClientException ("Client not found: " + request .getClientId ());
263
+ }
264
+
265
+
273
266
}
274
267
268
+
275
269
/*
270
+ * Claims precedence order logic:
271
+ *
276
272
* if (in Claims):
277
273
* if (in params):
278
274
* if (equal):
279
- * all set
275
+ * OK
280
276
* else (not equal):
281
277
* error
282
278
* else (not in params):
@@ -285,64 +281,102 @@ private Map<String, String> processRequestObject(Map<String, String> inputParams
285
281
* we don't care
286
282
*/
287
283
288
- ReadOnlyJWTClaimsSet claims = jwt . getJWTClaimsSet ();
284
+ // now that we've got the JWT, and it's been parsed, validated, and/or decrypted, we can process the claims
289
285
290
- String clientId = claims .getStringClaim ("client_id" );
291
- if (clientId != null ) {
292
- parameters .put ("client_id" , clientId );
293
- }
286
+ ReadOnlyJWTClaimsSet claims = jwt .getJWTClaimsSet ();
294
287
295
- String responseTypes = claims .getStringClaim ("response_type" );
296
- if (responseTypes != null ) {
297
- parameters .put ("response_type" , responseTypes );
288
+ Set <String > responseTypes = OAuth2Utils .parseParameterList (claims .getStringClaim ("response_type" ));
289
+ if (responseTypes != null && !responseTypes .isEmpty ()) {
290
+ if (request .getResponseTypes () == null || request .getResponseTypes ().isEmpty ()) {
291
+ // if it's null or empty, we fill in the value with what we were passed
292
+ request .setResponseTypes (responseTypes );
293
+ } else if (!request .getResponseTypes ().equals (responseTypes )) {
294
+ // FIXME: throw an error
295
+ }
298
296
}
299
297
300
- if (claims .getStringClaim ("redirect_uri" ) != null ) {
301
- if (inputParams .containsKey ("redirect_uri" ) == false ) {
302
- parameters .put ("redirect_uri" , claims .getStringClaim ("redirect_uri" ));
298
+ String redirectUri = claims .getStringClaim ("redirect_uri" );
299
+ if (redirectUri != null ) {
300
+ if (request .getRedirectUri () == null ) {
301
+ request .setRedirectUri (redirectUri );
302
+ } else if (!request .getRedirectUri ().equals (redirectUri )) {
303
+ // FIXME: throw an error
303
304
}
304
305
}
305
306
306
307
String state = claims .getStringClaim ("state" );
307
308
if (state != null ) {
308
- if (inputParams .containsKey ("state" ) == false ) {
309
- parameters .put ("state" , state );
309
+ if (request .getState () == null ) {
310
+ request .setState (state );
311
+ } else if (!request .getState ().equals (state )) {
312
+ // FIXME: throw an error
310
313
}
311
314
}
312
315
313
316
String nonce = claims .getStringClaim ("nonce" );
314
317
if (nonce != null ) {
315
- if (inputParams .containsKey ("nonce" ) == false ) {
316
- parameters .put ("nonce" , nonce );
318
+ if (request .getExtensions ().get ("nonce" ) == null ) {
319
+ request .getExtensions ().put ("nonce" , nonce );
320
+ } else if (!request .getExtensions ().get ("nonce" ).equals (nonce )) {
321
+ // FIXME: throw an error
317
322
}
318
323
}
319
324
320
325
String display = claims .getStringClaim ("display" );
321
326
if (display != null ) {
322
- if (inputParams .containsKey ("display" ) == false ) {
323
- parameters .put ("display" , display );
327
+ if (request .getExtensions ().get ("display" ) == null ) {
328
+ request .getExtensions ().put ("display" , display );
329
+ } else if (!request .getExtensions ().get ("display" ).equals (display )) {
330
+ // FIXME: throw an error
324
331
}
325
332
}
326
333
327
334
String prompt = claims .getStringClaim ("prompt" );
328
335
if (prompt != null ) {
329
- if (inputParams .containsKey ("prompt" ) == false ) {
330
- parameters .put ("prompt" , prompt );
336
+ if (request .getExtensions ().get ("prompt" ) == null ) {
337
+ request .getExtensions ().put ("prompt" , prompt );
338
+ } else if (!request .getExtensions ().get ("prompt" ).equals (prompt )) {
339
+ // FIXME: throw an error
331
340
}
332
341
}
333
-
334
- String scope = claims .getStringClaim ("scope" );
335
- if (scope != null ) {
336
- if (inputParams .containsKey ("scope" ) == false ) {
337
- parameters .put ("scope" , scope );
342
+
343
+ Set <String > scope = OAuth2Utils .parseParameterList (claims .getStringClaim ("scope" ));
344
+ if (scope != null && !scope .isEmpty ()) {
345
+ if (request .getScope () == null || request .getScope ().isEmpty ()) {
346
+ request .setScope (scope );
347
+ } else if (!request .getScope ().equals (scope )) {
348
+ // FIXME: throw an error
338
349
}
339
350
}
351
+
352
+ JsonObject claimRequest = parseClaimRequest (claims .getStringClaim ("claims" ));
353
+ if (claimRequest != null ) {
354
+ if (request .getExtensions ().get ("claims" ) == null ) {
355
+ // we save the string because the object might not serialize
356
+ request .getExtensions ().put ("claims" , claimRequest .toString ());
357
+ } else if (parseClaimRequest (request .getExtensions ().get ("claims" ).toString ()).equals (claimRequest )) {
358
+ // FIXME: throw an error
359
+ }
360
+ }
361
+
340
362
} catch (ParseException e ) {
341
363
logger .error ("ParseException while parsing RequestObject:" , e );
342
364
}
343
- return parameters ;
344
365
}
345
366
367
+ /**
368
+ * @param claimRequestString
369
+ * @return
370
+ */
371
+ private JsonObject parseClaimRequest (String claimRequestString ) {
372
+ JsonElement el = parser .parse (claimRequestString );
373
+ if (el != null && el .isJsonObject ()) {
374
+ return el .getAsJsonObject ();
375
+ } else {
376
+ return null ;
377
+ }
378
+ }
379
+
346
380
/**
347
381
* Create a symmetric signing and validation service for the given client
348
382
*
0 commit comments