Skip to content

Commit 9b72c6b

Browse files
author
Justin Richer
committed
check sector identifier URI's contents and match against redirect URIs, addresses mitreid-connect#504
1 parent 1aa5fe2 commit 9b72c6b

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed

openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ClientDetailsEntityService.java

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,17 @@
1818

1919
import java.math.BigInteger;
2020
import java.security.SecureRandom;
21+
import java.util.ArrayList;
2122
import java.util.Collection;
2223
import java.util.Date;
24+
import java.util.List;
2325
import java.util.UUID;
26+
import java.util.concurrent.ExecutionException;
27+
import java.util.concurrent.TimeUnit;
2428

2529
import org.apache.commons.codec.binary.Base64;
30+
import org.apache.http.client.HttpClient;
31+
import org.apache.http.impl.client.DefaultHttpClient;
2632
import org.mitre.oauth2.model.ClientDetailsEntity;
2733
import org.mitre.oauth2.repository.OAuth2ClientRepository;
2834
import org.mitre.oauth2.repository.OAuth2TokenRepository;
@@ -31,16 +37,27 @@
3137
import org.mitre.openid.connect.service.ApprovedSiteService;
3238
import org.mitre.openid.connect.service.BlacklistedSiteService;
3339
import org.mitre.openid.connect.service.WhitelistedSiteService;
40+
import org.slf4j.Logger;
41+
import org.slf4j.LoggerFactory;
3442
import org.springframework.beans.factory.annotation.Autowired;
43+
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
3544
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
3645
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
3746
import org.springframework.stereotype.Service;
47+
import org.springframework.web.client.RestTemplate;
3848

3949
import com.google.common.base.Strings;
50+
import com.google.common.cache.CacheBuilder;
51+
import com.google.common.cache.CacheLoader;
52+
import com.google.common.cache.LoadingCache;
53+
import com.google.gson.JsonElement;
54+
import com.google.gson.JsonParser;
4055

4156
@Service
4257
public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEntityService {
4358

59+
private static Logger logger = LoggerFactory.getLogger(DefaultOAuth2ClientDetailsEntityService.class);
60+
4461
@Autowired
4562
private OAuth2ClientRepository clientRepository;
4663

@@ -56,6 +73,12 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
5673
@Autowired
5774
private BlacklistedSiteService blacklistedSiteService;
5875

76+
// map of sector URI -> list of redirect URIs
77+
private LoadingCache<String, List<String>> sectorRedirects = CacheBuilder.newBuilder()
78+
.expireAfterAccess(1, TimeUnit.HOURS)
79+
.maximumSize(100)
80+
.build(new SectorIdentifierLoader());
81+
5982
@Override
6083
public ClientDetailsEntity saveNewClient(ClientDetailsEntity client) {
6184
if (client.getId() != null) { // if it's not null, it's already been saved, this is an error
@@ -85,6 +108,26 @@ public ClientDetailsEntity saveNewClient(ClientDetailsEntity client) {
85108

86109
// timestamp this to right now
87110
client.setCreatedAt(new Date());
111+
112+
113+
// check the sector URI
114+
if (!Strings.isNullOrEmpty(client.getSectorIdentifierUri())) {
115+
try {
116+
List<String> redirects = sectorRedirects.get(client.getSectorIdentifierUri());
117+
118+
if (client.getRegisteredRedirectUri() != null) {
119+
for (String uri : client.getRegisteredRedirectUri()) {
120+
if (!redirects.contains(uri)) {
121+
throw new IllegalArgumentException("Requested Redirect URI " + uri + " is not listed at sector identifier " + redirects);
122+
}
123+
}
124+
}
125+
126+
} catch (ExecutionException e) {
127+
throw new IllegalArgumentException("Unable to load sector identifier URI: " + client.getSectorIdentifierUri());
128+
}
129+
}
130+
88131

89132
return clientRepository.saveClient(client);
90133
}
@@ -165,6 +208,24 @@ public ClientDetailsEntity updateClient(ClientDetailsEntity oldClient, ClientDet
165208
newClient.getScope().remove("offline_access");
166209
}
167210

211+
// check the sector URI
212+
if (!Strings.isNullOrEmpty(newClient.getSectorIdentifierUri())) {
213+
try {
214+
List<String> redirects = sectorRedirects.get(newClient.getSectorIdentifierUri());
215+
216+
if (newClient.getRegisteredRedirectUri() != null) {
217+
for (String uri : newClient.getRegisteredRedirectUri()) {
218+
if (!redirects.contains(uri)) {
219+
throw new IllegalArgumentException("Requested Redirect URI " + uri + " is not listed at sector identifier " + redirects);
220+
}
221+
}
222+
}
223+
224+
} catch (ExecutionException e) {
225+
throw new IllegalArgumentException("Unable to load sector identifier URI: " + newClient.getSectorIdentifierUri());
226+
}
227+
}
228+
168229
return clientRepository.updateClient(oldClient.getId(), newClient);
169230
}
170231
throw new IllegalArgumentException("Neither old client or new client can be null!");
@@ -196,4 +257,45 @@ public ClientDetailsEntity generateClientSecret(ClientDetailsEntity client) {
196257
return client;
197258
}
198259

260+
/**
261+
* Utility class to load a sector identifier's set of authorized redirect URIs.
262+
*
263+
* @author jricher
264+
*
265+
*/
266+
private class SectorIdentifierLoader extends CacheLoader<String, List<String>> {
267+
private HttpClient httpClient = new DefaultHttpClient();
268+
private HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
269+
private RestTemplate restTemplate = new RestTemplate(httpFactory);
270+
private JsonParser parser = new JsonParser();
271+
272+
@Override
273+
public List<String> load(String key) throws Exception {
274+
275+
if (!key.startsWith("https")) {
276+
// TODO: this should optionally throw an error (#506)
277+
logger.error("Sector identifier doesn't start with https, loading anyway...");
278+
}
279+
280+
// key is the sector URI
281+
String jsonString = restTemplate.getForObject(key, String.class);
282+
JsonElement json = parser.parse(jsonString);
283+
284+
if (json.isJsonArray()) {
285+
List<String> redirectUris = new ArrayList<String>();
286+
for (JsonElement el : json.getAsJsonArray()) {
287+
redirectUris.add(el.getAsString());
288+
}
289+
290+
logger.info("Found " + redirectUris + " for sector " + key);
291+
292+
return redirectUris;
293+
} else {
294+
return null;
295+
}
296+
297+
}
298+
299+
}
300+
199301
}

0 commit comments

Comments
 (0)