21
21
use League \OAuth2 \Server \Entities \ClientEntityInterface as OAuth2ClientEntityInterface ;
22
22
use League \OAuth2 \Server \Entities \UserEntityInterface ;
23
23
use League \OAuth2 \Server \Repositories \UserRepositoryInterface ;
24
+ use SimpleSAML \Database ;
24
25
use SimpleSAML \Module \oidc \Entities \UserEntity ;
25
26
use SimpleSAML \Module \oidc \Factories \Entities \UserEntityFactory ;
26
27
use SimpleSAML \Module \oidc \Helpers ;
27
28
use SimpleSAML \Module \oidc \ModuleConfig ;
28
29
use SimpleSAML \Module \oidc \Repositories \Interfaces \IdentityProviderInterface ;
30
+ use SimpleSAML \Module \oidc \Utils \ProtocolCache ;
29
31
30
32
class UserRepository extends AbstractDatabaseRepository implements UserRepositoryInterface, IdentityProviderInterface
31
33
{
32
34
final public const TABLE_NAME = 'oidc_user ' ;
33
35
34
36
public function __construct (
35
37
ModuleConfig $ moduleConfig ,
38
+ Database $ database ,
39
+ ?ProtocolCache $ protocolCache ,
36
40
protected readonly Helpers $ helpers ,
37
41
protected readonly UserEntityFactory $ userEntityFactory ,
38
42
) {
39
- parent ::__construct ($ moduleConfig );
43
+ parent ::__construct ($ moduleConfig, $ database , $ protocolCache );
40
44
}
41
45
42
46
public function getTableName (): string
43
47
{
44
48
return $ this ->database ->applyPrefix (self ::TABLE_NAME );
45
49
}
46
50
51
+ public function getCacheKey (string $ identifier ): string
52
+ {
53
+ return $ this ->getTableName () . '_ ' . $ identifier ;
54
+ }
55
+
47
56
/**
48
57
* @param string $identifier
49
58
*
@@ -52,6 +61,13 @@ public function getTableName(): string
52
61
*/
53
62
public function getUserEntityByIdentifier (string $ identifier ): ?UserEntity
54
63
{
64
+ /** @var ?array $cachedState */
65
+ $ cachedState = $ this ->protocolCache ?->get(null , $ this ->getCacheKey ($ identifier ));
66
+
67
+ if (is_array ($ cachedState )) {
68
+ return $ this ->userEntityFactory ->fromState ($ cachedState );
69
+ }
70
+
55
71
$ stmt = $ this ->database ->read (
56
72
"SELECT * FROM {$ this ->getTableName ()} WHERE id = :id " ,
57
73
[
@@ -69,7 +85,15 @@ public function getUserEntityByIdentifier(string $identifier): ?UserEntity
69
85
return null ;
70
86
}
71
87
72
- return $ this ->userEntityFactory ->fromState ($ row );
88
+ $ userEntity = $ this ->userEntityFactory ->fromState ($ row );
89
+
90
+ $ this ->protocolCache ?->set(
91
+ $ userEntity ->getState (),
92
+ $ this ->moduleConfig ->getProtocolUserEntityCacheDuration (),
93
+ $ this ->getCacheKey ($ userEntity ->getIdentifier ()),
94
+ );
95
+
96
+ return $ userEntity ;
73
97
}
74
98
75
99
/**
@@ -95,21 +119,29 @@ public function add(UserEntity $userEntity): void
95
119
$ stmt ,
96
120
$ userEntity ->getState (),
97
121
);
122
+
123
+ $ this ->protocolCache ?->set(
124
+ $ userEntity ->getState (),
125
+ $ this ->moduleConfig ->getProtocolUserEntityCacheDuration (),
126
+ $ this ->getCacheKey ($ userEntity ->getIdentifier ()),
127
+ );
98
128
}
99
129
100
- public function delete (UserEntity $ user ): void
130
+ public function delete (UserEntity $ userEntity ): void
101
131
{
102
132
$ this ->database ->write (
103
133
"DELETE FROM {$ this ->getTableName ()} WHERE id = :id " ,
104
134
[
105
- 'id ' => $ user ->getIdentifier (),
135
+ 'id ' => $ userEntity ->getIdentifier (),
106
136
],
107
137
);
138
+
139
+ $ this ->protocolCache ?->delete($ this ->getCacheKey ($ userEntity ->getIdentifier ()));
108
140
}
109
141
110
- public function update (UserEntity $ user , ?DateTimeImmutable $ updatedAt = null ): void
142
+ public function update (UserEntity $ userEntity , ?DateTimeImmutable $ updatedAt = null ): void
111
143
{
112
- $ user ->setUpdatedAt ($ updatedAt ?? $ this ->helpers ->dateTime ()->getUtc ());
144
+ $ userEntity ->setUpdatedAt ($ updatedAt ?? $ this ->helpers ->dateTime ()->getUtc ());
113
145
114
146
$ stmt = sprintf (
115
147
"UPDATE %s SET claims = :claims, updated_at = :updated_at, created_at = :created_at WHERE id = :id " ,
@@ -118,7 +150,13 @@ public function update(UserEntity $user, ?DateTimeImmutable $updatedAt = null):
118
150
119
151
$ this ->database ->write (
120
152
$ stmt ,
121
- $ user ->getState (),
153
+ $ userEntity ->getState (),
154
+ );
155
+
156
+ $ this ->protocolCache ?->set(
157
+ $ userEntity ->getState (),
158
+ $ this ->moduleConfig ->getProtocolUserEntityCacheDuration (),
159
+ $ this ->getCacheKey ($ userEntity ->getIdentifier ()),
122
160
);
123
161
}
124
162
}
0 commit comments