Skip to content

Commit ca4b062

Browse files
Zsailerkevin-bates
andauthoredApr 11, 2023
Define a CURRENT_JUPYTER_HANDLER context var (#1251)
* [Enhancement] Define a CURRENT_JUPYTER_HANDLER context var * add type to context var * Introduce CallContext class * Add CallContext to API docs * Alphabetize submodules * Unit test contextvar in the kernel shutdown flow * Update tests/services/sessions/test_call_context.py Co-authored-by: Kevin Bates <kbates4@gmail.com> * revert unit test back to using kernel_model * Relocate to base package * Update location in docs as well --------- Co-authored-by: Kevin Bates <kbates4@gmail.com>
1 parent 87b2158 commit ca4b062

File tree

5 files changed

+208
-0
lines changed

5 files changed

+208
-0
lines changed
 

‎docs/source/api/jupyter_server.base.rst

+6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ Submodules
55
----------
66

77

8+
.. automodule:: jupyter_server.base.call_context
9+
:members:
10+
:undoc-members:
11+
:show-inheritance:
12+
13+
814
.. automodule:: jupyter_server.base.handlers
915
:members:
1016
:undoc-members:

‎jupyter_server/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
del os
1616

1717
from ._version import __version__, version_info # noqa
18+
from .base.call_context import CallContext # noqa
1819

1920

2021
def _cleanup():

‎jupyter_server/base/call_context.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Provides access to variables pertaining to specific call contexts."""
2+
# Copyright (c) Jupyter Development Team.
3+
# Distributed under the terms of the Modified BSD License.
4+
5+
from contextvars import Context, ContextVar, copy_context
6+
from typing import Any, Dict, List
7+
8+
9+
class CallContext:
10+
"""CallContext essentially acts as a namespace for managing context variables.
11+
12+
Although not required, it is recommended that any "file-spanning" context variable
13+
names (i.e., variables that will be set or retrieved from multiple files or services) be
14+
added as constants to this class definition.
15+
"""
16+
17+
# Add well-known (file-spanning) names here.
18+
#: Provides access to the current request handler once set.
19+
JUPYTER_HANDLER: str = "JUPYTER_HANDLER"
20+
21+
# A map of variable name to value is maintained as the single ContextVar. This also enables
22+
# easier management over maintaining a set of ContextVar instances, since the Context is a
23+
# map of ContextVar instances to their values, and the "name" is no longer a lookup key.
24+
_NAME_VALUE_MAP = "_name_value_map"
25+
_name_value_map: ContextVar[Dict[str, Any]] = ContextVar(_NAME_VALUE_MAP)
26+
27+
@classmethod
28+
def get(cls, name: str) -> Any:
29+
"""Returns the value corresponding the named variable relative to this context.
30+
31+
If the named variable doesn't exist, None will be returned.
32+
33+
Parameters
34+
----------
35+
name : str
36+
The name of the variable to get from the call context
37+
38+
Returns
39+
-------
40+
value: Any
41+
The value associated with the named variable for this call context
42+
"""
43+
name_value_map = CallContext._get_map()
44+
45+
if name in name_value_map:
46+
return name_value_map[name]
47+
return None # TODO - should this raise `LookupError` (or a custom error derived from said)
48+
49+
@classmethod
50+
def set(cls, name: str, value: Any) -> None:
51+
"""Sets the named variable to the specified value in the current call context.
52+
53+
Parameters
54+
----------
55+
name : str
56+
The name of the variable to store into the call context
57+
value : Any
58+
The value of the variable to store into the call context
59+
60+
Returns
61+
-------
62+
None
63+
"""
64+
name_value_map = CallContext._get_map()
65+
name_value_map[name] = value
66+
67+
@classmethod
68+
def context_variable_names(cls) -> List[str]:
69+
"""Returns a list of variable names set for this call context.
70+
71+
Returns
72+
-------
73+
names: List[str]
74+
A list of variable names set for this call context.
75+
"""
76+
name_value_map = CallContext._get_map()
77+
return list(name_value_map.keys())
78+
79+
@classmethod
80+
def _get_map(cls) -> Dict[str, Any]:
81+
"""Get the map of names to their values from the _NAME_VALUE_MAP context var.
82+
83+
If the map does not exist in the current context, an empty map is created and returned.
84+
"""
85+
ctx: Context = copy_context()
86+
if CallContext._name_value_map not in ctx:
87+
CallContext._name_value_map.set({})
88+
return CallContext._name_value_map.get()

‎jupyter_server/base/handlers.py

+4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from traitlets.config import Application
2727

2828
import jupyter_server
29+
from jupyter_server import CallContext
2930
from jupyter_server._sysinfo import get_sys_info
3031
from jupyter_server._tz import utcnow
3132
from jupyter_server.auth import authorized
@@ -582,6 +583,9 @@ def check_host(self):
582583

583584
async def prepare(self):
584585
"""Pepare a response."""
586+
# Set the current Jupyter Handler context variable.
587+
CallContext.set(CallContext.JUPYTER_HANDLER, self)
588+
585589
if not self.check_host():
586590
self.current_user = self._jupyter_current_user = None
587591
raise web.HTTPError(403)

‎tests/base/test_call_context.py

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import asyncio
2+
3+
from jupyter_server import CallContext
4+
from jupyter_server.auth.utils import get_anonymous_username
5+
from jupyter_server.base.handlers import JupyterHandler
6+
from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager
7+
8+
9+
async def test_jupyter_handler_contextvar(jp_fetch, monkeypatch):
10+
# Create some mock kernel Ids
11+
kernel1 = "x-x-x-x-x"
12+
kernel2 = "y-y-y-y-y"
13+
14+
# We'll use this dictionary to track the current user within each request.
15+
context_tracker = {
16+
kernel1: {"started": "no user yet", "ended": "still no user", "user": None},
17+
kernel2: {"started": "no user yet", "ended": "still no user", "user": None},
18+
}
19+
20+
# Monkeypatch the get_current_user method in Tornado's
21+
# request handler to return a random user name for
22+
# each request
23+
async def get_current_user(self):
24+
return get_anonymous_username()
25+
26+
monkeypatch.setattr(JupyterHandler, "get_current_user", get_current_user)
27+
28+
# Monkeypatch the kernel_model method to show that
29+
# the current context variable is truly local and
30+
# not contaminated by other asynchronous parallel requests.
31+
# Note that even though the current implementation of `kernel_model()`
32+
# is synchronous, we can convert this into an async method because the
33+
# kernel handler wraps the call to `kernel_model()` in `ensure_async()`.
34+
async def kernel_model(self, kernel_id):
35+
# Get the Jupyter Handler from the current context.
36+
current: JupyterHandler = CallContext.get(CallContext.JUPYTER_HANDLER)
37+
# Get the current user
38+
context_tracker[kernel_id]["user"] = current.current_user
39+
context_tracker[kernel_id]["started"] = current.current_user
40+
await asyncio.sleep(1.0)
41+
# Track the current user a few seconds later. We'll
42+
# verify that this user was unaffected by other parallel
43+
# requests.
44+
context_tracker[kernel_id]["ended"] = current.current_user
45+
return {"id": kernel_id, "name": "blah"}
46+
47+
monkeypatch.setattr(AsyncMappingKernelManager, "kernel_model", kernel_model)
48+
49+
# Make two requests in parallel.
50+
await asyncio.gather(
51+
jp_fetch("api", "kernels", kernel1),
52+
jp_fetch("api", "kernels", kernel2),
53+
)
54+
55+
# Assert that the two requests had different users
56+
assert context_tracker[kernel1]["user"] != context_tracker[kernel2]["user"]
57+
# Assert that the first request started+ended with the same user
58+
assert context_tracker[kernel1]["started"] == context_tracker[kernel1]["ended"]
59+
# Assert that the second request started+ended with the same user
60+
assert context_tracker[kernel2]["started"] == context_tracker[kernel2]["ended"]
61+
62+
63+
async def test_context_variable_names():
64+
CallContext.set("foo", "bar")
65+
CallContext.set("foo2", "bar2")
66+
names = CallContext.context_variable_names()
67+
assert len(names) == 2
68+
assert set(names) == {"foo", "foo2"}
69+
70+
71+
async def test_same_context_operations():
72+
CallContext.set("foo", "bar")
73+
CallContext.set("foo2", "bar2")
74+
75+
foo = CallContext.get("foo")
76+
assert foo == "bar"
77+
78+
CallContext.set("foo", "bar2")
79+
assert CallContext.get("foo") == CallContext.get("foo2")
80+
81+
82+
async def test_multi_context_operations():
83+
async def context1():
84+
"""The "slower" context. This ensures that, following the sleep, the
85+
context variable set prior to the sleep is still the expected value.
86+
If contexts are not managed properly, we should find that context2() has
87+
corrupted context1().
88+
"""
89+
CallContext.set("foo", "bar1")
90+
await asyncio.sleep(1.0)
91+
assert CallContext.get("foo") == "bar1"
92+
context1_names = CallContext.context_variable_names()
93+
assert len(context1_names) == 1
94+
95+
async def context2():
96+
"""The "faster" context. This ensures that CallContext reflects the
97+
appropriate values of THIS context.
98+
"""
99+
CallContext.set("foo", "bar2")
100+
assert CallContext.get("foo") == "bar2"
101+
CallContext.set("foo2", "bar2")
102+
context2_names = CallContext.context_variable_names()
103+
assert len(context2_names) == 2
104+
105+
await asyncio.gather(context1(), context2())
106+
107+
# Assert that THIS context doesn't have any variables defined.
108+
names = CallContext.context_variable_names()
109+
assert len(names) == 0

0 commit comments

Comments
 (0)
Please sign in to comment.