Skip to content

Commit 3f27704

Browse files
author
Sauli Ketola
committed
Use query by user sub to get all tokens for user
1 parent 417a6b7 commit 3f27704

File tree

2 files changed

+30
-28
lines changed

2 files changed

+30
-28
lines changed

openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@
6666
import org.springframework.transaction.annotation.Transactional;
6767

6868
import com.google.common.base.Strings;
69-
import com.google.common.collect.Sets;
7069
import com.nimbusds.jose.util.Base64URL;
7170
import com.nimbusds.jwt.JWTClaimsSet;
7271
import com.nimbusds.jwt.PlainJWT;
@@ -102,35 +101,14 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
102101
@Autowired
103102
private ApprovedSiteService approvedSiteService;
104103

105-
106104
@Override
107-
public Set<OAuth2AccessTokenEntity> getAllAccessTokensForUser(String id) {
108-
109-
Set<OAuth2AccessTokenEntity> all = tokenRepository.getAllAccessTokens();
110-
Set<OAuth2AccessTokenEntity> results = Sets.newLinkedHashSet();
111-
112-
for (OAuth2AccessTokenEntity token : all) {
113-
if (clearExpiredAccessToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
114-
results.add(token);
115-
}
116-
}
117-
118-
return results;
105+
public Set<OAuth2AccessTokenEntity> getAllAccessTokensForUser(String sub) {
106+
return tokenRepository.getAccessTokensBySub(sub);
119107
}
120108

121-
122109
@Override
123-
public Set<OAuth2RefreshTokenEntity> getAllRefreshTokensForUser(String id) {
124-
Set<OAuth2RefreshTokenEntity> all = tokenRepository.getAllRefreshTokens();
125-
Set<OAuth2RefreshTokenEntity> results = Sets.newLinkedHashSet();
126-
127-
for (OAuth2RefreshTokenEntity token : all) {
128-
if (clearExpiredRefreshToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
129-
results.add(token);
130-
}
131-
}
132-
133-
return results;
110+
public Set<OAuth2RefreshTokenEntity> getAllRefreshTokensForUser(String sub) {
111+
return tokenRepository.getRefreshTokensBySub(sub);
134112
}
135113

136114
@Override

openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.HashSet;
2222
import java.util.Set;
2323

24+
import org.junit.Assert;
2425
import org.junit.Before;
2526
import org.junit.Test;
2627
import org.junit.runner.RunWith;
@@ -52,6 +53,7 @@
5253

5354
import com.google.common.collect.Sets;
5455

56+
import static com.google.common.collect.Sets.newHashSet;
5557
import static org.hamcrest.CoreMatchers.equalTo;
5658
import static org.hamcrest.CoreMatchers.is;
5759
import static org.hamcrest.CoreMatchers.not;
@@ -60,7 +62,8 @@
6062

6163
import static org.mockito.Mockito.never;
6264
import static org.mockito.Mockito.when;
63-
65+
import static org.junit.Assert.assertEquals;
66+
import static org.junit.Assert.assertFalse;
6467
import static org.junit.Assert.assertThat;
6568
import static org.junit.Assert.assertTrue;
6669
import static org.junit.Assert.fail;
@@ -83,7 +86,9 @@ public class TestDefaultOAuth2ProviderTokenService {
8386
private String badClientId = "bad_client";
8487
private Set<String> scope = Sets.newHashSet("openid", "profile", "email", "offline_access");
8588
private OAuth2RefreshTokenEntity refreshToken;
89+
private OAuth2AccessTokenEntity accessToken;
8690
private String refreshTokenValue = "refresh_token_value";
91+
private String userSub = "6a50ac11786d402a9591d3e592ac770f";
8792
private TokenRequest tokenRequest;
8893

8994
// for use when refreshing access tokens
@@ -142,6 +147,8 @@ public void prepare() {
142147
Mockito.when(tokenRepository.getRefreshTokenByValue(refreshTokenValue)).thenReturn(refreshToken);
143148
Mockito.when(refreshToken.getClient()).thenReturn(client);
144149
Mockito.when(refreshToken.isExpired()).thenReturn(false);
150+
151+
accessToken = Mockito.mock(OAuth2AccessTokenEntity.class);
145152

146153
tokenRequest = new TokenRequest(null, clientId, null, null);
147154

@@ -542,5 +549,22 @@ public void refreshAccessToken_expiration() {
542549

543550
assertTrue(token.getExpiration().after(lowerBoundAccessTokens) && token.getExpiration().before(upperBoundAccessTokens));
544551
}
545-
552+
553+
@Test
554+
public void getAllAccessTokensForUser(){
555+
Mockito.when(tokenRepository.getAccessTokensBySub(userSub)).thenReturn(newHashSet(accessToken));
556+
557+
Set<OAuth2AccessTokenEntity> tokens = service.getAllAccessTokensForUser(userSub);
558+
assertEquals(1, tokens.size());
559+
assertTrue(tokens.contains(accessToken));
560+
}
561+
562+
@Test
563+
public void getAllRefreshTokensForUser(){
564+
Mockito.when(tokenRepository.getRefreshTokensBySub(userSub)).thenReturn(newHashSet(refreshToken));
565+
566+
Set<OAuth2RefreshTokenEntity> tokens = service.getAllRefreshTokensForUser(userSub);
567+
assertEquals(1, tokens.size());
568+
assertTrue(tokens.contains(refreshToken));
569+
}
546570
}

0 commit comments

Comments
 (0)