Skip to content

Commit 349d36a

Browse files
authored
Add support for snapstart runtime hooks (#176)
1 parent 079135e commit 349d36a

8 files changed

+280
-12
lines changed

awslambdaric/bootstrap.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
_AWS_LAMBDA_LOG_LEVEL = _get_log_level_from_env_var(
3232
os.environ.get("AWS_LAMBDA_LOG_LEVEL")
3333
)
34+
AWS_LAMBDA_INITIALIZATION_TYPE = "AWS_LAMBDA_INITIALIZATION_TYPE"
35+
INIT_TYPE_SNAP_START = "snap-start"
3436

3537

3638
def _get_handler(handler):
@@ -286,6 +288,29 @@ def extract_traceback(tb):
286288
]
287289

288290

291+
def on_init_complete(lambda_runtime_client, log_sink):
    """Run the snapshot/restore runtime hooks once initialization is done.

    Phase 1: execute the before-snapshot hooks and then call
    ``restore_next`` on the runtime client. If either step raises, the
    failure is logged, reported through ``post_init_error`` with the
    ``FaultException.BEFORE_SNAPSHOT_ERROR`` type, and the process exits
    with status 64.

    Phase 2: execute the after-restore hooks. If one raises, the failure is
    logged, reported through ``report_restore_error``, and the process exits
    with status 65.

    :param lambda_runtime_client: client exposing ``restore_next``,
        ``post_init_error`` and ``report_restore_error``.
    :param log_sink: sink passed to ``log_error`` when a hook fails.
    """
    from . import lambda_runtime_hooks_runner as hooks_runner

    try:
        hooks_runner.run_before_snapshot()
        lambda_runtime_client.restore_next()
    except BaseException:
        # BaseException keeps the catch-all semantics of a bare ``except``:
        # SystemExit / KeyboardInterrupt raised by a hook are reported too.
        fault = build_fault_result(sys.exc_info(), None)
        log_error(fault, log_sink)
        lambda_runtime_client.post_init_error(
            fault, FaultException.BEFORE_SNAPSHOT_ERROR
        )
        sys.exit(64)

    try:
        hooks_runner.run_after_restore()
    except BaseException:
        # Same catch-all handling as above, reported as a restore error.
        fault = build_fault_result(sys.exc_info(), None)
        log_error(fault, log_sink)
        lambda_runtime_client.report_restore_error(fault)
        sys.exit(65)
312+
313+
289314
class LambdaLoggerHandler(logging.Handler):
290315
def __init__(self, log_sink):
291316
logging.Handler.__init__(self)
@@ -454,10 +479,10 @@ def run(app_root, handler, lambda_runtime_api_addr):
454479
sys.stdout = Unbuffered(sys.stdout)
455480
sys.stderr = Unbuffered(sys.stderr)
456481

457-
use_thread_for_polling_next = os.environ.get("AWS_EXECUTION_ENV") in [
482+
use_thread_for_polling_next = os.environ.get("AWS_EXECUTION_ENV") in {
458483
"AWS_Lambda_python3.12",
459484
"AWS_Lambda_python3.13",
460-
]
485+
}
461486

462487
with create_log_sink() as log_sink:
463488
lambda_runtime_client = LambdaRuntimeClient(
@@ -485,6 +510,9 @@ def run(app_root, handler, lambda_runtime_api_addr):
485510

486511
sys.exit(1)
487512

513+
if os.environ.get(AWS_LAMBDA_INITIALIZATION_TYPE) == INIT_TYPE_SNAP_START:
514+
on_init_complete(lambda_runtime_client, log_sink)
515+
488516
while True:
489517
event_request = lambda_runtime_client.wait_next_invocation()
490518

awslambdaric/lambda_runtime_client.py

+41-9
Original file line numberDiff line numberDiff line change
@@ -62,25 +62,57 @@ def __init__(self, lambda_runtime_address, use_thread_for_polling_next=False):
6262
# Not defining symbol as global to avoid relying on TPE being imported unconditionally.
6363
self.ThreadPoolExecutor = ThreadPoolExecutor
6464

65-
def post_init_error(self, error_response_data):
65+
def call_rapid(
    self, http_method, endpoint, expected_http_code, payload=None, headers=None
):
    """Send one HTTP request to the runtime API and validate the status code.

    A new connection to ``self.lambda_runtime_address`` is opened for the
    request and is always closed again, whether the exchange succeeds or not.

    :param http_method: HTTP verb. ``"GET"`` requests are sent without a
        body or headers; any other verb sends ``to_json(payload)`` together
        with ``headers``.
    :param endpoint: request path, e.g. ``/2018-06-01/runtime/init/error``.
    :param expected_http_code: status code the response must carry.
    :param payload: object serialized with ``to_json`` for non-GET requests.
    :param headers: optional mapping of request headers for non-GET
        requests; ``None`` is treated as "no extra headers".
    :raises LambdaRuntimeClientError: if the response status differs from
        ``expected_http_code``.
    """
    # These imports are heavy-weight. They implicitly trigger `import ssl, hashlib`.
    # Importing them lazily to speed up critical path of a common case.
    import http.client

    runtime_connection = http.client.HTTPConnection(self.lambda_runtime_address)
    try:
        runtime_connection.connect()
        if http_method == "GET":
            runtime_connection.request(http_method, endpoint)
        else:
            runtime_connection.request(
                http_method,
                endpoint,
                to_json(payload),
                # HTTPConnection.request iterates over the headers mapping,
                # so None must not be forwarded.
                headers=headers if headers is not None else {},
            )

        response = runtime_connection.getresponse()
        # Read the whole body before the connection is closed below.
        response_body = response.read()
    finally:
        # Release the socket on success and on every failure path.
        runtime_connection.close()

    if response.code != expected_http_code:
        raise LambdaRuntimeClientError(endpoint, response.code, response_body)
8385

86+
def post_init_error(self, error_response_data, error_type_override=None):
    """POST an initialization error to ``/2018-06-01/runtime/init/error``.

    :param error_response_data: error payload sent as the request body; its
        ``"errorType"`` entry is used for the error-type header unless an
        override is given.
    :param error_type_override: when truthy, the value sent in the
        error-type header instead of the payload's own ``"errorType"``.

    The request is delegated to ``call_rapid`` with
    ``http.HTTPStatus.ACCEPTED`` as the expected status code.
    """
    from http import HTTPStatus

    # A falsy override (None, "") falls back to the payload's own error type.
    reported_type = error_type_override or error_response_data["errorType"]
    self.call_rapid(
        "POST",
        "/2018-06-01/runtime/init/error",
        HTTPStatus.ACCEPTED,
        error_response_data,
        {ERROR_TYPE_HEADER: reported_type},
    )
100+
101+
def restore_next(self):
    """Issue ``GET /2018-06-01/runtime/restore/next`` through ``call_rapid``.

    ``http.HTTPStatus.OK`` is passed as the expected status code.
    """
    from http import HTTPStatus

    self.call_rapid("GET", "/2018-06-01/runtime/restore/next", HTTPStatus.OK)
106+
107+
def report_restore_error(self, restore_error_data):
    """POST a restore failure to ``/2018-06-01/runtime/restore/error``.

    :param restore_error_data: error payload sent as the request body.

    The error-type header is always ``FaultException.AFTER_RESTORE_ERROR``,
    and ``http.HTTPStatus.ACCEPTED`` is passed to ``call_rapid`` as the
    expected status code.
    """
    from http import HTTPStatus

    self.call_rapid(
        "POST",
        "/2018-06-01/runtime/restore/error",
        HTTPStatus.ACCEPTED,
        restore_error_data,
        {ERROR_TYPE_HEADER: FaultException.AFTER_RESTORE_ERROR},
    )
115+
84116
def wait_next_invocation(self):
85117
# Calling runtime_client.next() from a separate thread unblocks the main thread,
86118
# which can then process signals.

awslambdaric/lambda_runtime_exception.py

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ class FaultException(Exception):
1111
IMPORT_MODULE_ERROR = "Runtime.ImportModuleError"
1212
BUILT_IN_MODULE_CONFLICT = "Runtime.BuiltInModuleConflict"
1313
MALFORMED_HANDLER_NAME = "Runtime.MalformedHandlerName"
14+
BEFORE_SNAPSHOT_ERROR = "Runtime.BeforeSnapshotError"
15+
AFTER_RESTORE_ERROR = "Runtime.AfterRestoreError"
1416
LAMBDA_CONTEXT_UNMARSHAL_ERROR = "Runtime.LambdaContextUnmarshalError"
1517
LAMBDA_RUNTIME_CLIENT_ERROR = "Runtime.LambdaRuntimeClientError"
1618

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from snapshot_restore_py import get_before_snapshot, get_after_restore
5+
6+
7+
def run_before_snapshot():
    """Execute every registered before-snapshot hook, newest first.

    Entries are ``(func, args, kwargs)`` triples taken from
    ``get_before_snapshot()``. Each entry is popped from the end of that
    collection before it is called, so the most recently registered hook
    runs first and the collection is emptied as the hooks run.
    """
    pending_hooks = get_before_snapshot()
    while pending_hooks:
        # pop() takes the last entry: reverse order of registration.
        hook, hook_args, hook_kwargs = pending_hooks.pop()
        hook(*hook_args, **hook_kwargs)
13+
14+
15+
def run_after_restore():
    """Execute every registered after-restore hook in iteration order.

    Entries are ``(func, args, kwargs)`` triples taken from
    ``get_after_restore()``; unlike the before-snapshot runner, nothing is
    removed from the collection.
    """
    for hook, hook_args, hook_kwargs in get_after_restore():
        hook(*hook_args, **hook_kwargs)

requirements/base.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
simplejson>=3.18.4
2+
snapshot-restore-py>=1.0.0

tests/test_bootstrap.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import unittest
1515
from io import StringIO
1616
from tempfile import NamedTemporaryFile
17-
from unittest.mock import MagicMock, Mock, patch
17+
from unittest.mock import MagicMock, Mock, patch, ANY
1818

1919
import awslambdaric.bootstrap as bootstrap
2020
from awslambdaric.lambda_runtime_exception import FaultException
@@ -23,6 +23,7 @@
2323
from awslambdaric.lambda_literals import (
2424
lambda_unhandled_exception_warning_message,
2525
)
26+
import snapshot_restore_py
2627

2728

2829
class TestUpdateXrayEnv(unittest.TestCase):
@@ -1457,5 +1458,53 @@ class TestException(Exception):
14571458
mock_sys.exit.assert_called_once_with(1)
14581459

14591460

1461+
class TestOnInitComplete(unittest.TestCase):
    """Checks how ``bootstrap.on_init_complete`` reports failing runtime hooks."""

    # ANY is used for the stack trace because the main things under test are
    # the errorType propagation and the fact that a stack trace is generated.
    error_result = {
        "errorMessage": "This is a Dummy type error",
        "errorType": "TypeError",
        "requestId": "",
        "stackTrace": ANY,
    }

    def tearDown(self):
        # Private fields are accessed here only to reset the hook registries.
        snapshot_restore_py._before_snapshot_registry = []
        snapshot_restore_py._after_restore_registry = []

    def raise_type_error(self):
        """Hook body that always fails with a ``TypeError``."""
        raise TypeError("This is a Dummy type error")

    @patch("awslambdaric.bootstrap.LambdaRuntimeClient")
    def test_before_snapshot_exception(self, mock_runtime_client):
        """A failing before-snapshot hook exits with 64 and posts an init error."""
        snapshot_restore_py.register_before_snapshot(self.raise_type_error)

        with self.assertRaises(SystemExit) as exit_ctx:
            bootstrap.on_init_complete(
                mock_runtime_client, log_sink=bootstrap.StandardLogSink()
            )

        self.assertEqual(exit_ctx.exception.code, 64)
        mock_runtime_client.post_init_error.assert_called_once_with(
            self.error_result,
            FaultException.BEFORE_SNAPSHOT_ERROR,
        )

    @patch("awslambdaric.bootstrap.LambdaRuntimeClient")
    def test_after_restore_exception(self, mock_runtime_client):
        """A failing after-restore hook exits with 65 and reports a restore error."""
        snapshot_restore_py.register_after_restore(self.raise_type_error)

        with self.assertRaises(SystemExit) as exit_ctx:
            bootstrap.on_init_complete(
                mock_runtime_client, log_sink=bootstrap.StandardLogSink()
            )

        self.assertEqual(exit_ctx.exception.code, 65)
        mock_runtime_client.restore_next.assert_called_once()
        mock_runtime_client.report_restore_error.assert_called_once_with(
            self.error_result
        )
1507+
1508+
14601509
if __name__ == "__main__":
14611510
unittest.main()

tests/test_lambda_runtime_client.py

+73
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,21 @@ def test_wait_next_invocation(self, mock_runtime_client):
109109

110110
headers = {"Lambda-Runtime-Function-Error-Type": error_result["errorType"]}
111111

112+
restore_error_result = {
113+
"errorMessage": "Dummy Restore error",
114+
"errorType": "Runtime.DummyRestoreError",
115+
"requestId": "",
116+
"stackTrace": [],
117+
}
118+
119+
restore_error_header = {
120+
"Lambda-Runtime-Function-Error-Type": "Runtime.AfterRestoreError"
121+
}
122+
123+
before_snapshot_error_header = {
124+
"Lambda-Runtime-Function-Error-Type": "Runtime.BeforeSnapshotError"
125+
}
126+
112127
@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
113128
def test_post_init_error(self, MockHTTPConnection):
114129
mock_conn = MockHTTPConnection.return_value
@@ -225,6 +240,64 @@ def test_post_invocation_error_with_too_large_xray_cause(self, mock_runtime_clie
225240
invoke_id, error_data, ""
226241
)
227242

243+
@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
def test_restore_next(self, MockHTTPConnection):
    """``restore_next`` sends a single GET to the restore/next endpoint."""
    # Canned 200 OK response with an empty body.
    ok_response = MagicMock(autospec=http.client.HTTPResponse)
    ok_response.read.return_value = b""
    ok_response.code = http.HTTPStatus.OK
    connection = MockHTTPConnection.return_value
    connection.getresponse.return_value = ok_response

    LambdaRuntimeClient("localhost:1234").restore_next()

    MockHTTPConnection.assert_called_with("localhost:1234")
    connection.request.assert_called_once_with(
        "GET",
        "/2018-06-01/runtime/restore/next",
    )
    ok_response.read.assert_called_once()
260+
261+
@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
def test_restore_error(self, MockHTTPConnection):
    """``report_restore_error`` POSTs the payload with the restore error header."""
    # Canned 202 Accepted response with an empty body.
    accepted_response = MagicMock(autospec=http.client.HTTPResponse)
    accepted_response.read.return_value = b""
    accepted_response.code = http.HTTPStatus.ACCEPTED
    connection = MockHTTPConnection.return_value
    connection.getresponse.return_value = accepted_response

    LambdaRuntimeClient("localhost:1234").report_restore_error(
        self.restore_error_result
    )

    MockHTTPConnection.assert_called_with("localhost:1234")
    connection.request.assert_called_once_with(
        "POST",
        "/2018-06-01/runtime/restore/error",
        to_json(self.restore_error_result),
        headers=self.restore_error_header,
    )
    accepted_response.read.assert_called_once()
280+
281+
@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
def test_init_before_snapshot_error(self, MockHTTPConnection):
    """An error-type override replaces the payload's own type in the header."""
    # Canned 202 Accepted response with an empty body.
    accepted_response = MagicMock(autospec=http.client.HTTPResponse)
    accepted_response.read.return_value = b""
    accepted_response.code = http.HTTPStatus.ACCEPTED
    connection = MockHTTPConnection.return_value
    connection.getresponse.return_value = accepted_response

    client = LambdaRuntimeClient("localhost:1234")
    client.post_init_error(self.error_result, "Runtime.BeforeSnapshotError")

    MockHTTPConnection.assert_called_with("localhost:1234")
    connection.request.assert_called_once_with(
        "POST",
        "/2018-06-01/runtime/init/error",
        to_json(self.error_result),
        headers=self.before_snapshot_error_header,
    )
    accepted_response.read.assert_called_once()
300+
228301
def test_connection_refused(self):
229302
with self.assertRaises(ConnectionRefusedError):
230303
runtime_client = LambdaRuntimeClient("127.0.0.1:1")

tests/test_runtime_hooks.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import unittest
5+
from unittest.mock import patch, call
6+
from awslambdaric import lambda_runtime_hooks_runner
7+
import snapshot_restore_py
8+
9+
10+
def fun_test1():
    """Hook used by the ordering tests; prints a fixed marker line."""
    print("In function ONE")
12+
13+
14+
def fun_test2():
    """Hook used by the ordering tests; prints a fixed marker line."""
    print("In function TWO")
16+
17+
18+
def fun_with_args_kwargs(x, y, **kwargs):
    """Hook used by the ordering tests; prints its positional and keyword arguments.

    Two separate ``print`` calls are made: one for ``x`` and ``y``, one for
    the ``kwargs`` dict.
    """
    print("Here are the args:", x, y)
    print("Here are the keyword args:", kwargs)
21+
22+
23+
class TestRuntimeHooks(unittest.TestCase):
    """Verifies the order in which registered runtime hooks are executed."""

    def tearDown(self):
        # Private fields are accessed here only to reset the hook registries.
        snapshot_restore_py._before_snapshot_registry = []
        snapshot_restore_py._after_restore_registry = []

    @patch("builtins.print")
    def test_before_snapshot_execution_order(self, mock_print):
        """Before-snapshot hooks run in reverse order of registration."""
        snapshot_restore_py.register_before_snapshot(
            fun_with_args_kwargs, 5, 7, arg1="Lambda", arg2="SnapStart"
        )
        snapshot_restore_py.register_before_snapshot(fun_test2)
        snapshot_restore_py.register_before_snapshot(fun_test1)

        lambda_runtime_hooks_runner.run_before_snapshot()

        # Last registered hook prints first.
        expected_prints = [
            call("In function ONE"),
            call("In function TWO"),
            call("Here are the args:", 5, 7),
            call("Here are the keyword args:", {"arg1": "Lambda", "arg2": "SnapStart"}),
        ]
        self.assertEqual(expected_prints, mock_print.mock_calls)

    @patch("builtins.print")
    def test_after_restore_execution_order(self, mock_print):
        """After-restore hooks run in the order they were registered."""
        snapshot_restore_py.register_after_restore(
            fun_with_args_kwargs, 11, 13, arg1="Lambda", arg2="SnapStart"
        )
        snapshot_restore_py.register_after_restore(fun_test2)
        snapshot_restore_py.register_after_restore(fun_test1)

        lambda_runtime_hooks_runner.run_after_restore()

        # First registered hook prints first.
        expected_prints = [
            call("Here are the args:", 11, 13),
            call("Here are the keyword args:", {"arg1": "Lambda", "arg2": "SnapStart"}),
            call("In function TWO"),
            call("In function ONE"),
        ]
        self.assertEqual(expected_prints, mock_print.mock_calls)

0 commit comments

Comments
 (0)