1818
1919import java .math .BigInteger ;
2020import java .security .SecureRandom ;
21+ import java .util .ArrayList ;
2122import java .util .Collection ;
2223import java .util .Date ;
24+ import java .util .List ;
2325import java .util .UUID ;
26+ import java .util .concurrent .ExecutionException ;
27+ import java .util .concurrent .TimeUnit ;
2428
2529import org .apache .commons .codec .binary .Base64 ;
30+ import org .apache .http .client .HttpClient ;
31+ import org .apache .http .impl .client .DefaultHttpClient ;
2632import org .mitre .oauth2 .model .ClientDetailsEntity ;
2733import org .mitre .oauth2 .repository .OAuth2ClientRepository ;
2834import org .mitre .oauth2 .repository .OAuth2TokenRepository ;
3137import org .mitre .openid .connect .service .ApprovedSiteService ;
3238import org .mitre .openid .connect .service .BlacklistedSiteService ;
3339import org .mitre .openid .connect .service .WhitelistedSiteService ;
40+ import org .slf4j .Logger ;
41+ import org .slf4j .LoggerFactory ;
3442import org .springframework .beans .factory .annotation .Autowired ;
43+ import org .springframework .http .client .HttpComponentsClientHttpRequestFactory ;
3544import org .springframework .security .oauth2 .common .exceptions .InvalidClientException ;
3645import org .springframework .security .oauth2 .common .exceptions .OAuth2Exception ;
3746import org .springframework .stereotype .Service ;
47+ import org .springframework .web .client .RestTemplate ;
3848
3949import com .google .common .base .Strings ;
50+ import com .google .common .cache .CacheBuilder ;
51+ import com .google .common .cache .CacheLoader ;
52+ import com .google .common .cache .LoadingCache ;
53+ import com .google .gson .JsonElement ;
54+ import com .google .gson .JsonParser ;
4055
4156@ Service
4257public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEntityService {
4358
59+ private static Logger logger = LoggerFactory .getLogger (DefaultOAuth2ClientDetailsEntityService .class );
60+
4461@ Autowired
4562private OAuth2ClientRepository clientRepository ;
4663
@@ -56,6 +73,12 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
5673@ Autowired
5774private BlacklistedSiteService blacklistedSiteService ;
5875
76+ // map of sector URI -> list of redirect URIs
77+ private LoadingCache <String , List <String >> sectorRedirects = CacheBuilder .newBuilder ()
78+ .expireAfterAccess (1 , TimeUnit .HOURS )
79+ .maximumSize (100 )
80+ .build (new SectorIdentifierLoader ());
81+
5982@ Override
6083public ClientDetailsEntity saveNewClient (ClientDetailsEntity client ) {
6184if (client .getId () != null ) { // if it's not null, it's already been saved, this is an error
@@ -85,6 +108,26 @@ public ClientDetailsEntity saveNewClient(ClientDetailsEntity client) {
85108
86109// timestamp this to right now
87110client .setCreatedAt (new Date ());
111+
112+
113+ // check the sector URI
114+ if (!Strings .isNullOrEmpty (client .getSectorIdentifierUri ())) {
115+ try {
116+ List <String > redirects = sectorRedirects .get (client .getSectorIdentifierUri ());
117+
118+ if (client .getRegisteredRedirectUri () != null ) {
119+ for (String uri : client .getRegisteredRedirectUri ()) {
120+ if (!redirects .contains (uri )) {
121+ throw new IllegalArgumentException ("Requested Redirect URI " + uri + " is not listed at sector identifier " + redirects );
122+ }
123+ }
124+ }
125+
126+ } catch (ExecutionException e ) {
127+ throw new IllegalArgumentException ("Unable to load sector identifier URI: " + client .getSectorIdentifierUri ());
128+ }
129+ }
130+
88131
89132return clientRepository .saveClient (client );
90133}
@@ -165,6 +208,24 @@ public ClientDetailsEntity updateClient(ClientDetailsEntity oldClient, ClientDet
165208newClient .getScope ().remove ("offline_access" );
166209}
167210
211+ // check the sector URI
212+ if (!Strings .isNullOrEmpty (newClient .getSectorIdentifierUri ())) {
213+ try {
214+ List <String > redirects = sectorRedirects .get (newClient .getSectorIdentifierUri ());
215+
216+ if (newClient .getRegisteredRedirectUri () != null ) {
217+ for (String uri : newClient .getRegisteredRedirectUri ()) {
218+ if (!redirects .contains (uri )) {
219+ throw new IllegalArgumentException ("Requested Redirect URI " + uri + " is not listed at sector identifier " + redirects );
220+ }
221+ }
222+ }
223+
224+ } catch (ExecutionException e ) {
225+ throw new IllegalArgumentException ("Unable to load sector identifier URI: " + newClient .getSectorIdentifierUri ());
226+ }
227+ }
228+
168229return clientRepository .updateClient (oldClient .getId (), newClient );
169230}
170231throw new IllegalArgumentException ("Neither old client or new client can be null!" );
@@ -196,4 +257,45 @@ public ClientDetailsEntity generateClientSecret(ClientDetailsEntity client) {
196257return client ;
197258}
198259
260+ /**
261+ * Utility class to load a sector identifier's set of authorized redirect URIs.
262+ *
263+ * @author jricher
264+ *
265+ */
266+ private class SectorIdentifierLoader extends CacheLoader <String , List <String >> {
267+ private HttpClient httpClient = new DefaultHttpClient ();
268+ private HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory (httpClient );
269+ private RestTemplate restTemplate = new RestTemplate (httpFactory );
270+ private JsonParser parser = new JsonParser ();
271+
272+ @ Override
273+ public List <String > load (String key ) throws Exception {
274+
275+ if (!key .startsWith ("https" )) {
276+ // TODO: this should optionally throw an error (#506)
277+ logger .error ("Sector identifier doesn't start with https, loading anyway..." );
278+ }
279+
280+ // key is the sector URI
281+ String jsonString = restTemplate .getForObject (key , String .class );
282+ JsonElement json = parser .parse (jsonString );
283+
284+ if (json .isJsonArray ()) {
285+ List <String > redirectUris = new ArrayList <String >();
286+ for (JsonElement el : json .getAsJsonArray ()) {
287+ redirectUris .add (el .getAsString ());
288+ }
289+
290+ logger .info ("Found " + redirectUris + " for sector " + key );
291+
292+ return redirectUris ;
293+ } else {
294+ return null ;
295+ }
296+
297+ }
298+
299+ }
300+
199301}
0 commit comments