Skip to content

Commit b396610

Browse files
author
Justin Richer
committed
refactor processing of request object
1 parent 47d3048 commit b396610

File tree

1 file changed

+126
-92
lines changed

1 file changed

+126
-92
lines changed

openid-connect-server/src/main/java/org/mitre/openid/connect/ConnectOAuth2RequestFactory.java

Lines changed: 126 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,10 @@
1616
******************************************************************************/
1717
package org.mitre.openid.connect;
1818

19-
import java.io.Serializable;
2019
import java.security.NoSuchAlgorithmException;
2120
import java.security.spec.InvalidKeySpecException;
2221
import java.text.ParseException;
2322
import java.util.Collections;
24-
import java.util.HashMap;
2523
import java.util.Map;
2624
import java.util.Set;
2725

@@ -45,7 +43,9 @@
4543

4644
import com.google.common.base.Strings;
4745
import com.google.common.collect.ImmutableMap;
48-
import com.google.common.collect.Maps;
46+
import com.google.gson.JsonElement;
47+
import com.google.gson.JsonObject;
48+
import com.google.gson.JsonParser;
4949
import com.nimbusds.jose.Algorithm;
5050
import com.nimbusds.jose.JWEObject.State;
5151
import com.nimbusds.jose.JWSAlgorithm;
@@ -76,6 +76,8 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
7676
@Autowired
7777
private JwtEncryptionAndDecryptionService encryptionService;
7878

79+
private JsonParser parser = new JsonParser();
80+
7981
/**
8082
* Constructor with arguments
8183
*
@@ -97,43 +99,43 @@ public OAuth2Request createOAuth2Request(AuthorizationRequest request) {
9799
@Override
98100
public AuthorizationRequest createAuthorizationRequest(Map<String, String> inputParams) {
99101

100-
Map<String, String> parameters = processRequestObject(inputParams);
101-
102-
String clientId = parameters.get("client_id");
103-
ClientDetails client = null;
104-
105-
if (clientId != null) {
106-
client = clientDetailsService.loadClientByClientId(clientId);
107-
}
108-
109-
AuthorizationRequest request = new AuthorizationRequest(parameters, Collections.<String, String> emptyMap(),
110-
parameters.get(OAuth2Utils.CLIENT_ID),
111-
OAuth2Utils.parseParameterList(parameters.get(OAuth2Utils.SCOPE)), null,
112-
null, false, parameters.get(OAuth2Utils.STATE),
113-
parameters.get(OAuth2Utils.REDIRECT_URI),
114-
OAuth2Utils.parseParameterList(parameters.get(OAuth2Utils.RESPONSE_TYPE)));
115-
116-
Set<String> scopes = OAuth2Utils.parseParameterList(parameters.get("scope"));
117-
if ((scopes == null || scopes.isEmpty()) && client != null) {
118-
Set<String> clientScopes = client.getScope();
119-
scopes = clientScopes;
120-
}
121102

122-
request.setScope(scopes);
103+
AuthorizationRequest request = new AuthorizationRequest(inputParams, Collections.<String, String> emptyMap(),
104+
inputParams.get(OAuth2Utils.CLIENT_ID),
105+
OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.SCOPE)), null,
106+
null, false, inputParams.get(OAuth2Utils.STATE),
107+
inputParams.get(OAuth2Utils.REDIRECT_URI),
108+
OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.RESPONSE_TYPE)));
123109

124110
//Add extension parameters to the 'extensions' map
125-
Map<String, Serializable> extensions = Maps.newHashMap();
126-
if (parameters.containsKey("prompt")) {
127-
extensions.put("prompt", parameters.get("prompt"));
111+
112+
if (inputParams.containsKey("prompt")) {
113+
request.getExtensions().put("prompt", inputParams.get("prompt"));
114+
}
115+
if (inputParams.containsKey("nonce")) {
116+
request.getExtensions().put("nonce", inputParams.get("nonce"));
128117
}
129-
if (parameters.containsKey("request")) {
130-
extensions.put("request", parameters.get("request"));
118+
119+
if (inputParams.containsKey("claims")) {
120+
JsonObject claimsRequest = parseClaimRequest(inputParams.get("claims"));
121+
if (claimsRequest != null) {
122+
request.getExtensions().put("claims", claimsRequest.toString());
123+
}
131124
}
132-
if (parameters.containsKey("nonce")) {
133-
extensions.put("nonce", parameters.get("nonce"));
125+
126+
if (inputParams.containsKey("request")) {
127+
request.getExtensions().put("request", inputParams.get("request"));
128+
processRequestObject(inputParams.get("request"), request);
134129
}
135130

136-
request.setExtensions(extensions);
131+
132+
if ((request.getScope() == null || request.getScope().isEmpty())) {
133+
if (request.getClientId() != null) {
134+
ClientDetails client = clientDetailsService.loadClientByClientId(request.getClientId());
135+
Set<String> clientScopes = client.getScope();
136+
request.setScope(clientScopes);
137+
}
138+
}
137139

138140
return request;
139141
}
@@ -142,49 +144,28 @@ public AuthorizationRequest createAuthorizationRequest(Map<String, String> input
142144
* @param inputParams
143145
* @return
144146
*/
145-
private Map<String, String> processRequestObject(Map<String, String> inputParams) {
146-
147-
String jwtString = inputParams.get("request");
148-
149-
// if there's no request object, bail early
150-
if (Strings.isNullOrEmpty(jwtString)) {
151-
return inputParams;
152-
}
153-
154-
// start by copying over what's already in there
155-
Map<String, String> parameters = new HashMap<String, String>(inputParams);
147+
private void processRequestObject(String jwtString, AuthorizationRequest request) {
156148

157149
// parse the request object
158150
try {
159151
JWT jwt = JWTParser.parse(jwtString);
160152

161-
/*
162-
if (jwt instanceof EncryptedJWT) {
163-
// TODO: it's an encrypted JWT, decrypt it and use it
164-
} else {
165-
// it's not encrypted...
166-
}
167-
*/
168-
169-
170-
171-
172153
// TODO: check parameter consistency, move keys to constants
173154

174155
if (jwt instanceof SignedJWT) {
175156
// it's a signed JWT, check the signature
176157

177158
SignedJWT signedJwt = (SignedJWT)jwt;
178-
179-
String clientId = inputParams.get("client_id");
180-
if (clientId == null) {
181-
clientId = signedJwt.getJWTClaimsSet().getStringClaim("client_id");
159+
160+
// need to check clientId first so that we can load the client to check other fields
161+
if (request.getClientId() == null) {
162+
request.setClientId(signedJwt.getJWTClaimsSet().getStringClaim("client_id"));
182163
}
183164

184-
ClientDetailsEntity client = clientDetailsService.loadClientByClientId(clientId);
165+
ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
185166

186167
if (client == null) {
187-
throw new InvalidClientException("Client not found: " + clientId);
168+
throw new InvalidClientException("Client not found: " + request.getClientId());
188169
}
189170

190171

@@ -239,15 +220,15 @@ private Map<String, String> processRequestObject(Map<String, String> inputParams
239220
} else if (jwt instanceof PlainJWT) {
240221
PlainJWT plainJwt = (PlainJWT)jwt;
241222

242-
String clientId = inputParams.get("client_id");
243-
if (clientId == null) {
244-
clientId = plainJwt.getJWTClaimsSet().getStringClaim("client_id");
223+
// need to check clientId first so that we can load the client to check other fields
224+
if (request.getClientId() == null) {
225+
request.setClientId(plainJwt.getJWTClaimsSet().getStringClaim("client_id"));
245226
}
246227

247-
ClientDetailsEntity client = clientDetailsService.loadClientByClientId(clientId);
228+
ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
248229

249230
if (client == null) {
250-
throw new InvalidClientException("Client not found: " + clientId);
231+
throw new InvalidClientException("Client not found: " + request.getClientId());
251232
}
252233

253234
if (client.getRequestObjectSigningAlg() == null) {
@@ -270,13 +251,28 @@ private Map<String, String> processRequestObject(Map<String, String> inputParams
270251
throw new InvalidClientException("Unable to decrypt the request object");
271252
}
272253

254+
// need to check clientId first so that we can load the client to check other fields
255+
if (request.getClientId() == null) {
256+
request.setClientId(encryptedJWT.getJWTClaimsSet().getStringClaim("client_id"));
257+
}
258+
259+
ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
260+
261+
if (client == null) {
262+
throw new InvalidClientException("Client not found: " + request.getClientId());
263+
}
264+
265+
273266
}
274267

268+
275269
/*
270+
* Claims precedence order logic:
271+
*
276272
* if (in Claims):
277273
* if (in params):
278274
* if (equal):
279-
* all set
275+
* OK
280276
* else (not equal):
281277
* error
282278
* else (not in params):
@@ -285,64 +281,102 @@ private Map<String, String> processRequestObject(Map<String, String> inputParams
285281
* we don't care
286282
*/
287283

288-
ReadOnlyJWTClaimsSet claims = jwt.getJWTClaimsSet();
284+
// now that we've got the JWT, and it's been parsed, validated, and/or decrypted, we can process the claims
289285

290-
String clientId = claims.getStringClaim("client_id");
291-
if (clientId != null) {
292-
parameters.put("client_id", clientId);
293-
}
286+
ReadOnlyJWTClaimsSet claims = jwt.getJWTClaimsSet();
294287

295-
String responseTypes = claims.getStringClaim("response_type");
296-
if (responseTypes != null) {
297-
parameters.put("response_type", responseTypes);
288+
Set<String> responseTypes = OAuth2Utils.parseParameterList(claims.getStringClaim("response_type"));
289+
if (responseTypes != null && !responseTypes.isEmpty()) {
290+
if (request.getResponseTypes() == null || request.getResponseTypes().isEmpty()) {
291+
// if it's null or empty, we fill in the value with what we were passed
292+
request.setResponseTypes(responseTypes);
293+
} else if (!request.getResponseTypes().equals(responseTypes)) {
294+
// FIXME: throw an error
295+
}
298296
}
299297

300-
if (claims.getStringClaim("redirect_uri") != null) {
301-
if (inputParams.containsKey("redirect_uri") == false) {
302-
parameters.put("redirect_uri", claims.getStringClaim("redirect_uri"));
298+
String redirectUri = claims.getStringClaim("redirect_uri");
299+
if (redirectUri != null) {
300+
if (request.getRedirectUri() == null) {
301+
request.setRedirectUri(redirectUri);
302+
} else if (!request.getRedirectUri().equals(redirectUri)) {
303+
// FIXME: throw an error
303304
}
304305
}
305306

306307
String state = claims.getStringClaim("state");
307308
if(state != null) {
308-
if (inputParams.containsKey("state") == false) {
309-
parameters.put("state", state);
309+
if (request.getState() == null) {
310+
request.setState(state);
311+
} else if (!request.getState().equals(state)) {
312+
// FIXME: throw an error
310313
}
311314
}
312315

313316
String nonce = claims.getStringClaim("nonce");
314317
if(nonce != null) {
315-
if (inputParams.containsKey("nonce") == false) {
316-
parameters.put("nonce", nonce);
318+
if (request.getExtensions().get("nonce") == null) {
319+
request.getExtensions().put("nonce", nonce);
320+
} else if (!request.getExtensions().get("nonce").equals(nonce)) {
321+
// FIXME: throw an error
317322
}
318323
}
319324

320325
String display = claims.getStringClaim("display");
321326
if (display != null) {
322-
if (inputParams.containsKey("display") == false) {
323-
parameters.put("display", display);
327+
if (request.getExtensions().get("display") == null) {
328+
request.getExtensions().put("display", display);
329+
} else if (!request.getExtensions().get("display").equals(display)) {
330+
// FIXME: throw an error
324331
}
325332
}
326333

327334
String prompt = claims.getStringClaim("prompt");
328335
if (prompt != null) {
329-
if (inputParams.containsKey("prompt") == false) {
330-
parameters.put("prompt", prompt);
336+
if (request.getExtensions().get("prompt") == null) {
337+
request.getExtensions().put("prompt", prompt);
338+
} else if (!request.getExtensions().get("prompt").equals(prompt)) {
339+
// FIXME: throw an error
331340
}
332341
}
333-
334-
String scope = claims.getStringClaim("scope");
335-
if (scope != null) {
336-
if (inputParams.containsKey("scope") == false) {
337-
parameters.put("scope", scope);
342+
343+
Set<String> scope = OAuth2Utils.parseParameterList(claims.getStringClaim("scope"));
344+
if (scope != null && !scope.isEmpty()) {
345+
if (request.getScope() == null || request.getScope().isEmpty()) {
346+
request.setScope(scope);
347+
} else if (!request.getScope().equals(scope)) {
348+
// FIXME: throw an error
338349
}
339350
}
351+
352+
JsonObject claimRequest = parseClaimRequest(claims.getStringClaim("claims"));
353+
if (claimRequest != null) {
354+
if (request.getExtensions().get("claims") == null) {
355+
// we save the string because the object might not serialize
356+
request.getExtensions().put("claims", claimRequest.toString());
357+
} else if (parseClaimRequest(request.getExtensions().get("claims").toString()).equals(claimRequest)) {
358+
// FIXME: throw an error
359+
}
360+
}
361+
340362
} catch (ParseException e) {
341363
logger.error("ParseException while parsing RequestObject:", e);
342364
}
343-
return parameters;
344365
}
345366

367+
/**
368+
* @param claimRequestString
369+
* @return
370+
*/
371+
private JsonObject parseClaimRequest(String claimRequestString) {
372+
JsonElement el = parser .parse(claimRequestString);
373+
if (el != null && el.isJsonObject()) {
374+
return el.getAsJsonObject();
375+
} else {
376+
return null;
377+
}
378+
}
379+
346380
/**
347381
* Create a symmetric signing and validation service for the given client
348382
*

0 commit comments

Comments
 (0)