def on_init_complete(lambda_runtime_client, log_sink):
    """Run the SnapStart runtime hooks around the restore call.

    Sequence:

    1. Run the before-snapshot hooks, then call
       ``lambda_runtime_client.restore_next()``.
    2. Run the after-restore hooks.

    Any exception raised in step 1 -- by a hook or by ``restore_next()``
    itself -- is logged to ``log_sink``, reported through
    ``post_init_error`` with ``FaultException.BEFORE_SNAPSHOT_ERROR`` and the
    process exits with status 64.  Any exception raised in step 2 is logged,
    reported through ``report_restore_error`` and the process exits with
    status 65.

    ``BaseException`` is caught on purpose (the same reach as a bare
    ``except:``), so ``SystemExit`` / ``KeyboardInterrupt`` raised by a hook
    are reported as well.

    Args:
        lambda_runtime_client: client exposing ``restore_next``,
            ``post_init_error`` and ``report_restore_error``.
        log_sink: sink handed to ``log_error`` for the fault record.
    """
    from . import lambda_runtime_hooks_runner as hooks_runner

    def _capture_fault():
        # Only meaningful while an exception is being handled:
        # sys.exc_info() returns the exception currently in flight.
        fault = build_fault_result(sys.exc_info(), None)
        log_error(fault, log_sink)
        return fault

    try:
        hooks_runner.run_before_snapshot()
        lambda_runtime_client.restore_next()
    except BaseException:
        before_snapshot_fault = _capture_fault()
        lambda_runtime_client.post_init_error(
            before_snapshot_fault, FaultException.BEFORE_SNAPSHOT_ERROR
        )
        sys.exit(64)

    try:
        hooks_runner.run_after_restore()
    except BaseException:
        after_restore_fault = _capture_fault()
        lambda_runtime_client.report_restore_error(after_restore_fault)
        sys.exit(65)
self.ThreadPoolExecutor = ThreadPoolExecutor - def post_init_error(self, error_response_data): + def call_rapid( + self, http_method, endpoint, expected_http_code, payload=None, headers=None + ): # These imports are heavy-weight. They implicitly trigger `import ssl, hashlib`. # Importing them lazily to speed up critical path of a common case. - import http import http.client runtime_connection = http.client.HTTPConnection(self.lambda_runtime_address) runtime_connection.connect() - endpoint = "/2018-06-01/runtime/init/error" - headers = {ERROR_TYPE_HEADER: error_response_data["errorType"]} - runtime_connection.request( - "POST", endpoint, to_json(error_response_data), headers=headers - ) + if http_method == "GET": + runtime_connection.request(http_method, endpoint) + else: + runtime_connection.request( + http_method, endpoint, to_json(payload), headers=headers + ) + response = runtime_connection.getresponse() response_body = response.read() - - if response.code != http.HTTPStatus.ACCEPTED: + if response.code != expected_http_code: raise LambdaRuntimeClientError(endpoint, response.code, response_body) + def post_init_error(self, error_response_data, error_type_override=None): + import http + + endpoint = "/2018-06-01/runtime/init/error" + headers = { + ERROR_TYPE_HEADER: ( + error_type_override + if error_type_override + else error_response_data["errorType"] + ) + } + self.call_rapid( + "POST", endpoint, http.HTTPStatus.ACCEPTED, error_response_data, headers + ) + + def restore_next(self): + import http + + endpoint = "/2018-06-01/runtime/restore/next" + self.call_rapid("GET", endpoint, http.HTTPStatus.OK) + + def report_restore_error(self, restore_error_data): + import http + + endpoint = "/2018-06-01/runtime/restore/error" + headers = {ERROR_TYPE_HEADER: FaultException.AFTER_RESTORE_ERROR} + self.call_rapid( + "POST", endpoint, http.HTTPStatus.ACCEPTED, restore_error_data, headers + ) + def wait_next_invocation(self): # Calling runtime_client.next() from a separate 
thread unblocks the main thread, # which can then process signals. diff --git a/awslambdaric/lambda_runtime_exception.py b/awslambdaric/lambda_runtime_exception.py index e09af70..3ea5b29 100644 --- a/awslambdaric/lambda_runtime_exception.py +++ b/awslambdaric/lambda_runtime_exception.py @@ -11,6 +11,8 @@ class FaultException(Exception): IMPORT_MODULE_ERROR = "Runtime.ImportModuleError" BUILT_IN_MODULE_CONFLICT = "Runtime.BuiltInModuleConflict" MALFORMED_HANDLER_NAME = "Runtime.MalformedHandlerName" + BEFORE_SNAPSHOT_ERROR = "Runtime.BeforeSnapshotError" + AFTER_RESTORE_ERROR = "Runtime.AfterRestoreError" LAMBDA_CONTEXT_UNMARSHAL_ERROR = "Runtime.LambdaContextUnmarshalError" LAMBDA_RUNTIME_CLIENT_ERROR = "Runtime.LambdaRuntimeClientError" diff --git a/awslambdaric/lambda_runtime_hooks_runner.py b/awslambdaric/lambda_runtime_hooks_runner.py new file mode 100644 index 0000000..8aee181 --- /dev/null +++ b/awslambdaric/lambda_runtime_hooks_runner.py @@ -0,0 +1,18 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
def run_before_snapshot():
    """Run the registered before-snapshot hooks, latest registration first.

    Each entry of the list returned by ``get_before_snapshot()`` is a
    ``(callable, args, kwargs)`` triple.  Entries are consumed with ``pop()``,
    so the list is empty once every hook has returned; an exception raised by
    a hook propagates and leaves the not-yet-run entries in the list.
    """
    pending_hooks = get_before_snapshot()
    # pop() removes from the tail, which yields the hooks in the reverse
    # order of their registration.
    while pending_hooks:
        hook, hook_args, hook_kwargs = pending_hooks.pop()
        hook(*hook_args, **hook_kwargs)


def run_after_restore():
    """Run the registered after-restore hooks in registration order.

    Each entry returned by ``get_after_restore()`` is a
    ``(callable, args, kwargs)`` triple; the collection is iterated without
    being modified.  An exception raised by a hook propagates to the caller.
    """
    for registered_entry in get_after_restore():
        hook, hook_args, hook_kwargs = registered_entry
        hook(*hook_args, **hook_kwargs)
+ def raise_type_error(self): + raise TypeError("This is a Dummy type error") + + @patch("awslambdaric.bootstrap.LambdaRuntimeClient") + def test_before_snapshot_exception(self, mock_runtime_client): + snapshot_restore_py.register_before_snapshot(self.raise_type_error) + + with self.assertRaises(SystemExit) as cm: + bootstrap.on_init_complete( + mock_runtime_client, log_sink=bootstrap.StandardLogSink() + ) + + self.assertEqual(cm.exception.code, 64) + mock_runtime_client.post_init_error.assert_called_once_with( + self.error_result, + FaultException.BEFORE_SNAPSHOT_ERROR, + ) + + @patch("awslambdaric.bootstrap.LambdaRuntimeClient") + def test_after_restore_exception(self, mock_runtime_client): + snapshot_restore_py.register_after_restore(self.raise_type_error) + + with self.assertRaises(SystemExit) as cm: + bootstrap.on_init_complete( + mock_runtime_client, log_sink=bootstrap.StandardLogSink() + ) + + self.assertEqual(cm.exception.code, 65) + mock_runtime_client.restore_next.assert_called_once() + mock_runtime_client.report_restore_error.assert_called_once_with( + self.error_result + ) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_lambda_runtime_client.py b/tests/test_lambda_runtime_client.py index e09130b..b13aa83 100644 --- a/tests/test_lambda_runtime_client.py +++ b/tests/test_lambda_runtime_client.py @@ -109,6 +109,21 @@ def test_wait_next_invocation(self, mock_runtime_client): headers = {"Lambda-Runtime-Function-Error-Type": error_result["errorType"]} + restore_error_result = { + "errorMessage": "Dummy Restore error", + "errorType": "Runtime.DummyRestoreError", + "requestId": "", + "stackTrace": [], + } + + restore_error_header = { + "Lambda-Runtime-Function-Error-Type": "Runtime.AfterRestoreError" + } + + before_snapshot_error_header = { + "Lambda-Runtime-Function-Error-Type": "Runtime.BeforeSnapshotError" + } + @patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection) def test_post_init_error(self, MockHTTPConnection): 
mock_conn = MockHTTPConnection.return_value @@ -225,6 +240,64 @@ def test_post_invocation_error_with_too_large_xray_cause(self, mock_runtime_clie invoke_id, error_data, "" ) + @patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection) + def test_restore_next(self, MockHTTPConnection): + mock_conn = MockHTTPConnection.return_value + mock_response = MagicMock(autospec=http.client.HTTPResponse) + mock_conn.getresponse.return_value = mock_response + mock_response.read.return_value = b"" + mock_response.code = http.HTTPStatus.OK + + runtime_client = LambdaRuntimeClient("localhost:1234") + runtime_client.restore_next() + + MockHTTPConnection.assert_called_with("localhost:1234") + mock_conn.request.assert_called_once_with( + "GET", + "/2018-06-01/runtime/restore/next", + ) + mock_response.read.assert_called_once() + + @patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection) + def test_restore_error(self, MockHTTPConnection): + mock_conn = MockHTTPConnection.return_value + mock_response = MagicMock(autospec=http.client.HTTPResponse) + mock_conn.getresponse.return_value = mock_response + mock_response.read.return_value = b"" + mock_response.code = http.HTTPStatus.ACCEPTED + + runtime_client = LambdaRuntimeClient("localhost:1234") + runtime_client.report_restore_error(self.restore_error_result) + + MockHTTPConnection.assert_called_with("localhost:1234") + mock_conn.request.assert_called_once_with( + "POST", + "/2018-06-01/runtime/restore/error", + to_json(self.restore_error_result), + headers=self.restore_error_header, + ) + mock_response.read.assert_called_once() + + @patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection) + def test_init_before_snapshot_error(self, MockHTTPConnection): + mock_conn = MockHTTPConnection.return_value + mock_response = MagicMock(autospec=http.client.HTTPResponse) + mock_conn.getresponse.return_value = mock_response + mock_response.read.return_value = b"" + mock_response.code = 
def fun_test1():
    """Hook without arguments; prints a fixed marker."""
    print("In function ONE")


def fun_test2():
    """Second hook without arguments; prints a fixed marker."""
    print("In function TWO")


def fun_with_args_kwargs(x, y, **kwargs):
    """Hook taking positional and keyword arguments; prints both groups."""
    print("Here are the args:", x, y)
    print("Here are the keyword args:", kwargs)


class TestRuntimeHooks(unittest.TestCase):
    """Checks the order in which the hooks runner invokes registered hooks."""

    def tearDown(self):
        # Reset the library's private registries so hooks registered by one
        # test do not leak into the next one.
        snapshot_restore_py._before_snapshot_registry = []
        snapshot_restore_py._after_restore_registry = []

    @patch("builtins.print")
    def test_before_snapshot_execution_order(self, mock_print):
        # Registered first -> expected to run last.
        snapshot_restore_py.register_before_snapshot(
            fun_with_args_kwargs, 5, 7, arg1="Lambda", arg2="SnapStart"
        )
        snapshot_restore_py.register_before_snapshot(fun_test2)
        snapshot_restore_py.register_before_snapshot(fun_test1)

        lambda_runtime_hooks_runner.run_before_snapshot()

        expected_calls = [
            call("In function ONE"),
            call("In function TWO"),
            call("Here are the args:", 5, 7),
            call("Here are the keyword args:", {"arg1": "Lambda", "arg2": "SnapStart"}),
        ]
        self.assertEqual(expected_calls, mock_print.mock_calls)

    @patch("builtins.print")
    def test_after_restore_execution_order(self, mock_print):
        # Registered first -> expected to run first.
        snapshot_restore_py.register_after_restore(
            fun_with_args_kwargs, 11, 13, arg1="Lambda", arg2="SnapStart"
        )
        snapshot_restore_py.register_after_restore(fun_test2)
        snapshot_restore_py.register_after_restore(fun_test1)

        lambda_runtime_hooks_runner.run_after_restore()

        expected_calls = [
            call("Here are the args:", 11, 13),
            call("Here are the keyword args:", {"arg1": "Lambda", "arg2": "SnapStart"}),
            call("In function TWO"),
            call("In function ONE"),
        ]
        self.assertEqual(expected_calls, mock_print.mock_calls)