Skip to content

Commit dd922b4

Browse files
authored
Merge pull request mitreid-connect#1378 from ketola/fetch-tokens-by-sub
Fetch tokens by user name
2 parents 938d7e0 + e6a8e0c commit dd922b4

File tree

9 files changed

+365
-170
lines changed

9 files changed

+365
-170
lines changed

openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@
7171
@NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_CLIENT, query = "select a from OAuth2AccessTokenEntity a where a.client = :" + OAuth2AccessTokenEntity.PARAM_CLIENT),
7272
@NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_TOKEN_VALUE, query = "select a from OAuth2AccessTokenEntity a where a.jwt = :" + OAuth2AccessTokenEntity.PARAM_TOKEN_VALUE),
7373
@NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_APPROVED_SITE, query = "select a from OAuth2AccessTokenEntity a where a.approvedSite = :" + OAuth2AccessTokenEntity.PARAM_APPROVED_SITE),
74-
@NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, query = "select a from OAuth2AccessTokenEntity a join a.permissions p where p.resourceSet.id = :" + OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID)
74+
@NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, query = "select a from OAuth2AccessTokenEntity a join a.permissions p where p.resourceSet.id = :" + OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID),
75+
@NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_NAME, query = "select r from OAuth2AccessTokenEntity r where r.authenticationHolder.userAuth.name = :" + OAuth2AccessTokenEntity.PARAM_NAME)
7576
})
7677
@org.codehaus.jackson.map.annotate.JsonSerialize(using = OAuth2AccessTokenJackson1Serializer.class)
7778
@org.codehaus.jackson.map.annotate.JsonDeserialize(using = OAuth2AccessTokenJackson1Deserializer.class)
@@ -86,13 +87,15 @@ public class OAuth2AccessTokenEntity implements OAuth2AccessToken {
8687
public static final String QUERY_EXPIRED_BY_DATE = "OAuth2AccessTokenEntity.getAllExpiredByDate";
8788
public static final String QUERY_ALL = "OAuth2AccessTokenEntity.getAll";
8889
public static final String QUERY_BY_RESOURCE_SET = "OAuth2AccessTokenEntity.getByResourceSet";
90+
public static final String QUERY_BY_NAME = "OAuth2AccessTokenEntity.getByName";
8991

9092
public static final String PARAM_TOKEN_VALUE = "tokenValue";
9193
public static final String PARAM_CLIENT = "client";
9294
public static final String PARAM_REFERSH_TOKEN = "refreshToken";
9395
public static final String PARAM_DATE = "date";
9496
public static final String PARAM_RESOURCE_SET_ID = "rsid";
9597
public static final String PARAM_APPROVED_SITE = "approvedSite";
98+
public static final String PARAM_NAME = "name";
9699

97100
public static final String ID_TOKEN_FIELD_NAME = "id_token";
98101

openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2RefreshTokenEntity.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,21 @@
5353
@NamedQuery(name = OAuth2RefreshTokenEntity.QUERY_ALL, query = "select r from OAuth2RefreshTokenEntity r"),
5454
@NamedQuery(name = OAuth2RefreshTokenEntity.QUERY_EXPIRED_BY_DATE, query = "select r from OAuth2RefreshTokenEntity r where r.expiration <= :" + OAuth2RefreshTokenEntity.PARAM_DATE),
5555
@NamedQuery(name = OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, query = "select r from OAuth2RefreshTokenEntity r where r.client = :" + OAuth2RefreshTokenEntity.PARAM_CLIENT),
56-
@NamedQuery(name = OAuth2RefreshTokenEntity.QUERY_BY_TOKEN_VALUE, query = "select r from OAuth2RefreshTokenEntity r where r.jwt = :" + OAuth2RefreshTokenEntity.PARAM_TOKEN_VALUE)
56+
@NamedQuery(name = OAuth2RefreshTokenEntity.QUERY_BY_TOKEN_VALUE, query = "select r from OAuth2RefreshTokenEntity r where r.jwt = :" + OAuth2RefreshTokenEntity.PARAM_TOKEN_VALUE),
57+
@NamedQuery(name = OAuth2RefreshTokenEntity.QUERY_BY_NAME, query = "select r from OAuth2RefreshTokenEntity r where r.authenticationHolder.userAuth.name = :" + OAuth2RefreshTokenEntity.PARAM_NAME)
5758
})
5859
public class OAuth2RefreshTokenEntity implements OAuth2RefreshToken {
5960

6061
public static final String QUERY_BY_TOKEN_VALUE = "OAuth2RefreshTokenEntity.getByTokenValue";
6162
public static final String QUERY_BY_CLIENT = "OAuth2RefreshTokenEntity.getByClient";
6263
public static final String QUERY_EXPIRED_BY_DATE = "OAuth2RefreshTokenEntity.getAllExpiredByDate";
6364
public static final String QUERY_ALL = "OAuth2RefreshTokenEntity.getAll";
65+
public static final String QUERY_BY_NAME = "OAuth2RefreshTokenEntity.getByName";
6466

6567
public static final String PARAM_TOKEN_VALUE = "tokenValue";
6668
public static final String PARAM_CLIENT = "client";
6769
public static final String PARAM_DATE = "date";
70+
public static final String PARAM_NAME = "name";
6871

6972
private Long id;
7073

openid-connect-common/src/main/java/org/mitre/oauth2/repository/OAuth2TokenRepository.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ public interface OAuth2TokenRepository {
5252
public List<OAuth2AccessTokenEntity> getAccessTokensForClient(ClientDetailsEntity client);
5353

5454
public List<OAuth2RefreshTokenEntity> getRefreshTokensForClient(ClientDetailsEntity client);
55+
56+
public Set<OAuth2AccessTokenEntity> getAccessTokensByUserName(String name);
57+
58+
public Set<OAuth2RefreshTokenEntity> getRefreshTokensByUserName(String name);
5559

5660
public Set<OAuth2AccessTokenEntity> getAllAccessTokens();
5761

openid-connect-server/pom.xml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,31 @@
4747
<groupId>org.springframework</groupId>
4848
<artifactId>spring-tx</artifactId>
4949
</dependency>
50+
<dependency>
51+
<groupId>org.springframework</groupId>
52+
<artifactId>spring-orm</artifactId>
53+
<scope>test</scope>
54+
<exclusions>
55+
<exclusion>
56+
<groupId>commons-logging</groupId>
57+
<artifactId>commons-logging</artifactId>
58+
</exclusion>
59+
</exclusions>
60+
</dependency>
5061
<dependency>
5162
<groupId>org.eclipse.persistence</groupId>
5263
<artifactId>org.eclipse.persistence.core</artifactId>
5364
</dependency>
65+
<dependency>
66+
<groupId>org.hsqldb</groupId>
67+
<artifactId>hsqldb</artifactId>
68+
<scope>test</scope>
69+
</dependency>
70+
<dependency>
71+
<groupId>org.eclipse.persistence</groupId>
72+
<artifactId>org.eclipse.persistence.jpa</artifactId>
73+
<scope>test</scope>
74+
</dependency>
5475

5576
<dependency>
5677
<groupId>org.apache.commons</groupId>

openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.text.ParseException;
2121
import java.util.ArrayList;
2222
import java.util.Date;
23+
import java.util.HashSet;
2324
import java.util.LinkedHashSet;
2425
import java.util.List;
2526
import java.util.Set;
@@ -168,9 +169,6 @@ public void clearTokensForClient(ClientDetailsEntity client) {
168169
}
169170
}
170171

171-
/* (non-Javadoc)
172-
* @see org.mitre.oauth2.repository.OAuth2TokenRepository#getAccessTokensForClient(org.mitre.oauth2.model.ClientDetailsEntity)
173-
*/
174172
@Override
175173
public List<OAuth2AccessTokenEntity> getAccessTokensForClient(ClientDetailsEntity client) {
176174
TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class);
@@ -179,16 +177,29 @@ public List<OAuth2AccessTokenEntity> getAccessTokensForClient(ClientDetailsEntit
179177
return accessTokens;
180178
}
181179

182-
/* (non-Javadoc)
183-
* @see org.mitre.oauth2.repository.OAuth2TokenRepository#getRefreshTokensForClient(org.mitre.oauth2.model.ClientDetailsEntity)
184-
*/
185180
@Override
186181
public List<OAuth2RefreshTokenEntity> getRefreshTokensForClient(ClientDetailsEntity client) {
187182
TypedQuery<OAuth2RefreshTokenEntity> queryR = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class);
188183
queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client);
189184
List<OAuth2RefreshTokenEntity> refreshTokens = queryR.getResultList();
190185
return refreshTokens;
191186
}
187+
188+
@Override
189+
public Set<OAuth2AccessTokenEntity> getAccessTokensByUserName(String name) {
190+
TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_NAME, OAuth2AccessTokenEntity.class);
191+
query.setParameter(OAuth2AccessTokenEntity.PARAM_NAME, name);
192+
List<OAuth2AccessTokenEntity> results = query.getResultList();
193+
return results != null ? new HashSet<>(query.getResultList()) : new HashSet<>();
194+
}
195+
196+
@Override
197+
public Set<OAuth2RefreshTokenEntity> getRefreshTokensByUserName(String name) {
198+
TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_NAME, OAuth2RefreshTokenEntity.class);
199+
query.setParameter(OAuth2RefreshTokenEntity.PARAM_NAME, name);
200+
List<OAuth2RefreshTokenEntity> results = query.getResultList();
201+
return results != null ? new HashSet<>(query.getResultList()) : new HashSet<>();
202+
}
192203

193204
@Override
194205
public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens() {
@@ -216,25 +227,16 @@ public Set<OAuth2RefreshTokenEntity> getAllExpiredRefreshTokens(PageCriteria pag
216227
return new LinkedHashSet<>(JpaUtil.getResultPage(query,pageCriteria));
217228
}
218229

219-
220-
221-
/* (non-Javadoc)
222-
* @see org.mitre.oauth2.repository.OAuth2TokenRepository#getAccessTokensForResourceSet(org.mitre.uma.model.ResourceSet)
223-
*/
224230
@Override
225231
public Set<OAuth2AccessTokenEntity> getAccessTokensForResourceSet(ResourceSet rs) {
226232
TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, OAuth2AccessTokenEntity.class);
227233
query.setParameter(OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID, rs.getId());
228234
return new LinkedHashSet<>(query.getResultList());
229235
}
230236

231-
/* (non-Javadoc)
232-
* @see org.mitre.oauth2.repository.OAuth2TokenRepository#clearDuplicateAccessTokens()
233-
*/
234237
@Override
235238
@Transactional(value="defaultTransactionManager")
236239
public void clearDuplicateAccessTokens() {
237-
238240
Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2AccessTokenEntity a GROUP BY a.jwt HAVING count(1) > 1");
239241
@SuppressWarnings("unchecked")
240242
List<Object[]> resultList = query.getResultList();
@@ -253,9 +255,6 @@ public void clearDuplicateAccessTokens() {
253255
}
254256
}
255257

256-
/* (non-Javadoc)
257-
* @see org.mitre.oauth2.repository.OAuth2TokenRepository#clearDuplicateRefreshTokens()
258-
*/
259258
@Override
260259
@Transactional(value="defaultTransactionManager")
261260
public void clearDuplicateRefreshTokens() {

openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java

Lines changed: 4 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@
6666
import org.springframework.transaction.annotation.Transactional;
6767

6868
import com.google.common.base.Strings;
69-
import com.google.common.collect.Sets;
7069
import com.nimbusds.jose.util.Base64URL;
7170
import com.nimbusds.jwt.JWTClaimsSet;
7271
import com.nimbusds.jwt.PlainJWT;
@@ -102,35 +101,14 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
102101
@Autowired
103102
private ApprovedSiteService approvedSiteService;
104103

105-
106104
@Override
107-
public Set<OAuth2AccessTokenEntity> getAllAccessTokensForUser(String id) {
108-
109-
Set<OAuth2AccessTokenEntity> all = tokenRepository.getAllAccessTokens();
110-
Set<OAuth2AccessTokenEntity> results = Sets.newLinkedHashSet();
111-
112-
for (OAuth2AccessTokenEntity token : all) {
113-
if (clearExpiredAccessToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
114-
results.add(token);
115-
}
116-
}
117-
118-
return results;
105+
public Set<OAuth2AccessTokenEntity> getAllAccessTokensForUser(String userName) {
106+
return tokenRepository.getAccessTokensByUserName(userName);
119107
}
120108

121-
122109
@Override
123-
public Set<OAuth2RefreshTokenEntity> getAllRefreshTokensForUser(String id) {
124-
Set<OAuth2RefreshTokenEntity> all = tokenRepository.getAllRefreshTokens();
125-
Set<OAuth2RefreshTokenEntity> results = Sets.newLinkedHashSet();
126-
127-
for (OAuth2RefreshTokenEntity token : all) {
128-
if (clearExpiredRefreshToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
129-
results.add(token);
130-
}
131-
}
132-
133-
return results;
110+
public Set<OAuth2RefreshTokenEntity> getAllRefreshTokensForUser(String userName) {
111+
return tokenRepository.getRefreshTokensByUserName(userName);
134112
}
135113

136114
@Override
@@ -192,7 +170,6 @@ public OAuth2AccessTokenEntity createAccessToken(OAuth2Authentication authentica
192170
throw new InvalidClientException("Client not found: " + request.getClientId());
193171
}
194172

195-
196173
// handle the PKCE code challenge if present
197174
if (request.getExtensions().containsKey(CODE_CHALLENGE)) {
198175
String challenge = (String) request.getExtensions().get(CODE_CHALLENGE);
@@ -220,7 +197,6 @@ public OAuth2AccessTokenEntity createAccessToken(OAuth2Authentication authentica
220197

221198
}
222199

223-
224200
OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity();//accessTokenFactory.createNewAccessToken();
225201

226202
// attach the client
@@ -306,8 +282,6 @@ private OAuth2RefreshTokenEntity createRefreshToken(ClientDetailsEntity client,
306282
refreshToken.setAuthenticationHolder(authHolder);
307283
refreshToken.setClient(client);
308284

309-
310-
311285
// save the token first so that we can set it to a member of the access token (NOTE: is this step necessary?)
312286
OAuth2RefreshTokenEntity savedRefreshToken = tokenRepository.saveRefreshToken(refreshToken);
313287
return savedRefreshToken;
@@ -410,12 +384,10 @@ public OAuth2AccessTokenEntity refreshAccessToken(String refreshTokenValue, Toke
410384
tokenRepository.saveAccessToken(token);
411385

412386
return token;
413-
414387
}
415388

416389
@Override
417390
public OAuth2Authentication loadAuthentication(String accessTokenValue) throws AuthenticationException {
418-
419391
OAuth2AccessTokenEntity accessToken = clearExpiredAccessToken(tokenRepository.getAccessTokenByValue(accessTokenValue));
420392

421393
if (accessToken == null) {
@@ -481,18 +453,11 @@ public void revokeAccessToken(OAuth2AccessTokenEntity accessToken) {
481453
tokenRepository.removeAccessToken(accessToken);
482454
}
483455

484-
485-
/* (non-Javadoc)
486-
* @see org.mitre.oauth2.service.OAuth2TokenEntityService#getAccessTokensForClient(org.mitre.oauth2.model.ClientDetailsEntity)
487-
*/
488456
@Override
489457
public List<OAuth2AccessTokenEntity> getAccessTokensForClient(ClientDetailsEntity client) {
490458
return tokenRepository.getAccessTokensForClient(client);
491459
}
492460

493-
/* (non-Javadoc)
494-
* @see org.mitre.oauth2.service.OAuth2TokenEntityService#getRefreshTokensForClient(org.mitre.oauth2.model.ClientDetailsEntity)
495-
*/
496461
@Override
497462
public List<OAuth2RefreshTokenEntity> getRefreshTokensForClient(ClientDetailsEntity client) {
498463
return tokenRepository.getRefreshTokensForClient(client);
@@ -595,7 +560,4 @@ public OAuth2AccessTokenEntity getRegistrationAccessTokenForClient(ClientDetails
595560

596561
return null;
597562
}
598-
599-
600-
601563
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package org.mitre.oauth2.repository.impl;
2+
3+
import static java.nio.charset.StandardCharsets.UTF_8;
4+
5+
import java.io.IOException;
6+
import java.nio.file.Files;
7+
import java.nio.file.Paths;
8+
import java.util.HashMap;
9+
import java.util.Map;
10+
11+
import javax.persistence.EntityManagerFactory;
12+
import javax.sql.DataSource;
13+
14+
import org.springframework.beans.factory.FactoryBean;
15+
import org.springframework.beans.factory.annotation.Autowired;
16+
import org.springframework.context.annotation.Bean;
17+
import org.springframework.core.io.ByteArrayResource;
18+
import org.springframework.core.io.DefaultResourceLoader;
19+
import org.springframework.core.io.Resource;
20+
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
21+
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
22+
import org.springframework.orm.jpa.JpaTransactionManager;
23+
import org.springframework.orm.jpa.JpaVendorAdapter;
24+
import org.springframework.orm.jpa.LocalContainerEntityManagerFactoryBean;
25+
import org.springframework.orm.jpa.vendor.Database;
26+
import org.springframework.orm.jpa.vendor.EclipseLinkJpaVendorAdapter;
27+
import org.springframework.transaction.PlatformTransactionManager;
28+
29+
public class TestDatabaseConfiguration {
30+
31+
@Autowired
32+
private JpaVendorAdapter jpaAdapter;
33+
34+
@Autowired
35+
private DataSource dataSource;
36+
37+
@Autowired
38+
private EntityManagerFactory entityManagerFactory;
39+
40+
@Bean
41+
public JpaOAuth2TokenRepository repository() {
42+
return new JpaOAuth2TokenRepository();
43+
}
44+
45+
@Bean(name = "defaultPersistenceUnit")
46+
public FactoryBean<EntityManagerFactory> entityManagerFactory() {
47+
LocalContainerEntityManagerFactoryBean factory = new LocalContainerEntityManagerFactoryBean();
48+
factory.setPackagesToScan("org.mitre", "org.mitre");
49+
factory.setPersistenceProviderClass(org.eclipse.persistence.jpa.PersistenceProvider.class);
50+
factory.setPersistenceUnitName("test" + System.currentTimeMillis());
51+
factory.setDataSource(dataSource);
52+
factory.setJpaVendorAdapter(jpaAdapter);
53+
Map<String, Object> jpaProperties = new HashMap<String, Object>();
54+
jpaProperties.put("eclipselink.weaving", "false");
55+
jpaProperties.put("eclipselink.logging.level", "INFO");
56+
jpaProperties.put("eclipselink.logging.level.sql", "INFO");
57+
jpaProperties.put("eclipselink.cache.shared.default", "false");
58+
factory.setJpaPropertyMap(jpaProperties);
59+
60+
return factory;
61+
}
62+
63+
@Bean
64+
public DataSource dataSource() {
65+
return new EmbeddedDatabaseBuilder(new DefaultResourceLoader() {
66+
@Override
67+
public Resource getResource(String location) {
68+
String sql;
69+
try {
70+
sql = new String(Files.readAllBytes(Paths.get("..", "openid-connect-server-webapp", "src", "main",
71+
"resources", "db", "hsql", location)), UTF_8);
72+
} catch (IOException e) {
73+
throw new RuntimeException("Failed to read sql-script " + location, e);
74+
}
75+
76+
return new ByteArrayResource(sql.getBytes(UTF_8));
77+
}
78+
}).generateUniqueName(true).setScriptEncoding(UTF_8.name()).setType(EmbeddedDatabaseType.HSQL)
79+
.addScripts("hsql_database_tables.sql").build();
80+
}
81+
82+
@Bean
83+
public JpaVendorAdapter jpaAdapter() {
84+
EclipseLinkJpaVendorAdapter adapter = new EclipseLinkJpaVendorAdapter();
85+
adapter.setDatabase(Database.HSQL);
86+
adapter.setShowSql(true);
87+
return adapter;
88+
}
89+
90+
@Bean
91+
public PlatformTransactionManager transactionManager() {
92+
JpaTransactionManager platformTransactionManager = new JpaTransactionManager();
93+
platformTransactionManager.setEntityManagerFactory(entityManagerFactory);
94+
return platformTransactionManager;
95+
}
96+
}

0 commit comments

Comments
 (0)