Skip to content

Commit 10029b0

Browse files
authored
Merge pull request #40 from redis/feature/investigate-this-triage
Fix authentication event loop corruption by converting get_current_user to async
2 parents abb0fff + 3b304c1 commit 10029b0

File tree

4 files changed

+35
-34
lines changed

4 files changed

+35
-34
lines changed

agent_memory_server/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Redis Agent Memory Server - A memory system for conversational AI."""
22

3-
__version__ = "0.9.2"
3+
__version__ = "0.9.3"

agent_memory_server/auth.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ async def verify_token(token: str) -> UserInfo:
346346
) from e
347347

348348

349-
def get_current_user(
349+
async def get_current_user(
350350
credentials: HTTPAuthorizationCredentials | None = Depends(oauth2_scheme),
351351
) -> UserInfo:
352352
if settings.disable_auth or settings.auth_mode == "disabled":
@@ -371,17 +371,15 @@ def get_current_user(
371371

372372
# Determine authentication mode
373373
if settings.auth_mode == "token" or settings.token_auth_enabled:
374-
import asyncio
375-
376-
return asyncio.run(verify_token(credentials.credentials))
374+
return await verify_token(credentials.credentials)
377375
if settings.auth_mode == "oauth2":
378376
return verify_jwt(credentials.credentials)
379377
# Default to OAuth2 for backward compatibility
380378
return verify_jwt(credentials.credentials)
381379

382380

383381
def require_scope(required_scope: str):
384-
def scope_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
382+
async def scope_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
385383
if settings.disable_auth:
386384
return user
387385

@@ -397,7 +395,7 @@ def scope_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
397395

398396

399397
def require_role(required_role: str):
400-
def role_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
398+
async def role_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
401399
if settings.disable_auth:
402400
return user
403401

tests/test_auth.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ async def test_get_current_user_disabled_auth(self, mock_settings):
685685
"""Test get_current_user when authentication is disabled"""
686686
mock_settings.disable_auth = True
687687

688-
result = get_current_user(None)
688+
result = await get_current_user(None)
689689

690690
assert isinstance(result, UserInfo)
691691
assert result.sub == "local-dev-user"
@@ -700,7 +700,7 @@ async def test_get_current_user_missing_credentials(self, mock_settings):
700700
mock_settings.auth_mode = "oauth2"
701701

702702
with pytest.raises(HTTPException) as exc_info:
703-
get_current_user(None)
703+
await get_current_user(None)
704704

705705
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
706706
assert "Missing authorization header" in str(exc_info.value.detail)
@@ -717,7 +717,7 @@ async def test_get_current_user_empty_credentials(self, mock_settings):
717717
empty_creds = HTTPAuthorizationCredentials(scheme="Bearer", credentials="")
718718

719719
with pytest.raises(HTTPException) as exc_info:
720-
get_current_user(empty_creds)
720+
await get_current_user(empty_creds)
721721

722722
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
723723
assert "Missing bearer token" in str(exc_info.value.detail)
@@ -736,7 +736,7 @@ async def test_get_current_user_valid_token(self, mock_settings, valid_token):
736736
expected_user = UserInfo(sub="test-user", email="[email protected]")
737737
mock_verify.return_value = expected_user
738738

739-
result = get_current_user(creds)
739+
result = await get_current_user(creds)
740740

741741
assert result == expected_user
742742
mock_verify.assert_called_once_with(valid_token)
@@ -753,7 +753,7 @@ async def test_require_scope_success(self, mock_settings):
753753
user = UserInfo(sub="test-user", scope="read write admin")
754754
scope_dependency = require_scope("read")
755755

756-
result = scope_dependency(user)
756+
result = await scope_dependency(user)
757757
assert result == user
758758

759759
@pytest.mark.asyncio
@@ -765,7 +765,7 @@ async def test_require_scope_failure(self, mock_settings):
765765
scope_dependency = require_scope("admin")
766766

767767
with pytest.raises(HTTPException) as exc_info:
768-
scope_dependency(user)
768+
await scope_dependency(user)
769769

770770
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
771771
assert "Insufficient permissions" in str(exc_info.value.detail)
@@ -780,7 +780,7 @@ async def test_require_scope_no_scope(self, mock_settings):
780780
scope_dependency = require_scope("read")
781781

782782
with pytest.raises(HTTPException) as exc_info:
783-
scope_dependency(user)
783+
await scope_dependency(user)
784784

785785
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
786786

@@ -792,7 +792,7 @@ async def test_require_scope_disabled_auth(self, mock_settings):
792792
user = UserInfo(sub="test-user", scope=None)
793793
scope_dependency = require_scope("admin")
794794

795-
result = scope_dependency(user)
795+
result = await scope_dependency(user)
796796
assert result == user
797797

798798
@pytest.mark.asyncio
@@ -803,7 +803,7 @@ async def test_require_role_success(self, mock_settings):
803803
user = UserInfo(sub="test-user", roles=["user", "admin"])
804804
role_dependency = require_role("admin")
805805

806-
result = role_dependency(user)
806+
result = await role_dependency(user)
807807
assert result == user
808808

809809
@pytest.mark.asyncio
@@ -815,7 +815,7 @@ async def test_require_role_failure(self, mock_settings):
815815
role_dependency = require_role("admin")
816816

817817
with pytest.raises(HTTPException) as exc_info:
818-
role_dependency(user)
818+
await role_dependency(user)
819819

820820
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
821821
assert "Insufficient permissions" in str(exc_info.value.detail)
@@ -830,7 +830,7 @@ async def test_require_role_no_roles(self, mock_settings):
830830
role_dependency = require_role("admin")
831831

832832
with pytest.raises(HTTPException) as exc_info:
833-
role_dependency(user)
833+
await role_dependency(user)
834834

835835
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
836836

@@ -842,7 +842,7 @@ async def test_require_role_disabled_auth(self, mock_settings):
842842
user = UserInfo(sub="test-user", roles=None)
843843
role_dependency = require_role("admin")
844844

845-
result = role_dependency(user)
845+
result = await role_dependency(user)
846846
assert result == user
847847

848848

tests/test_token_auth.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -189,59 +189,62 @@ async def test_verify_token_wrong_token(self, mock_redis, sample_token_info):
189189
class TestGetCurrentUser:
190190
"""Test get_current_user with token authentication."""
191191

192-
def test_get_current_user_disabled_auth(self, mock_settings):
192+
@pytest.mark.asyncio
193+
async def test_get_current_user_disabled_auth(self, mock_settings):
193194
"""Test get_current_user with disabled authentication."""
194195
mock_settings.disable_auth = True
195196
mock_settings.auth_mode = "disabled"
196197

197-
user_info = get_current_user(None)
198+
user_info = await get_current_user(None)
198199

199200
assert user_info.sub == "local-dev-user"
200201
assert user_info.aud == "local-dev"
201202

202-
def test_get_current_user_missing_credentials(self, mock_settings):
203+
@pytest.mark.asyncio
204+
async def test_get_current_user_missing_credentials(self, mock_settings):
203205
"""Test get_current_user with missing credentials."""
204206
mock_settings.disable_auth = False
205207
mock_settings.auth_mode = "token"
206208

207209
with pytest.raises(HTTPException) as exc_info:
208-
get_current_user(None)
210+
await get_current_user(None)
209211

210212
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
211213
assert "Missing authorization header" in exc_info.value.detail
212214

213-
def test_get_current_user_missing_token(self, mock_settings):
215+
@pytest.mark.asyncio
216+
async def test_get_current_user_missing_token(self, mock_settings):
214217
"""Test get_current_user with missing token."""
215218
mock_settings.disable_auth = False
216219
mock_settings.auth_mode = "token"
217220

218221
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="")
219222

220223
with pytest.raises(HTTPException) as exc_info:
221-
get_current_user(credentials)
224+
await get_current_user(credentials)
222225

223226
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
224227
assert "Missing bearer token" in exc_info.value.detail
225228

226-
@patch("agent_memory_server.auth.verify_token")
227-
def test_get_current_user_token_auth(self, mock_verify_token, mock_settings):
229+
@patch("agent_memory_server.auth.verify_token", new_callable=AsyncMock)
230+
@pytest.mark.asyncio
231+
async def test_get_current_user_token_auth(self, mock_verify_token, mock_settings):
228232
"""Test get_current_user with token authentication."""
229233
mock_settings.disable_auth = False
230234
mock_settings.auth_mode = "token"
231235

232236
# Mock verify_token to return a user
233237
mock_user = Mock()
234238
mock_user.sub = "token-user"
239+
mock_verify_token.return_value = mock_user
235240

236-
# Mock asyncio.run to return the user directly
237-
with patch("asyncio.run", return_value=mock_user):
238-
credentials = HTTPAuthorizationCredentials(
239-
scheme="Bearer", credentials="test_token"
240-
)
241+
credentials = HTTPAuthorizationCredentials(
242+
scheme="Bearer", credentials="test_token"
243+
)
241244

242-
user_info = get_current_user(credentials)
245+
user_info = await get_current_user(credentials)
243246

244-
assert user_info.sub == "token-user"
247+
assert user_info.sub == "token-user"
245248

246249

247250
class TestAuthConfig:

0 commit comments

Comments
 (0)