diff --git a/python/kvikio/kvikio/cufile_driver.py b/python/kvikio/kvikio/cufile_driver.py index 8018415191..7d5818f09f 100644 --- a/python/kvikio/kvikio/cufile_driver.py +++ b/python/kvikio/kvikio/cufile_driver.py @@ -8,13 +8,21 @@ from kvikio._lib import cufile_driver # type: ignore properties = cufile_driver.DriverProperties() +"""cuFile driver configurations. Use kvikio.cufile_driver.properties.get and + kvikio.cufile_driver.properties.set to access the configurations. +""" class ConfigContextManager: + """Context manager allowing the cuFile driver configurations to be set upon + entering a `with` block, and automatically reset upon leaving the block. + """ + def __init__(self, config: dict[str, str]): ( self._property_getters, self._property_setters, + self._readonly_property_getters, ) = self._property_getter_and_setter() self._old_properties = {} @@ -30,19 +38,29 @@ def __exit__(self, type_unused, value, traceback_unused): self._set_property(key, value) def _get_property(self, property: str) -> Any: - func = self._property_getters[property] + if property in self._property_getters: + func = self._property_getters[property] + elif property in self._readonly_property_getters: + func = self._readonly_property_getters[property] + else: + raise KeyError # getter signature: object.__get__(self, instance, owner=None) return func(properties) def _set_property(self, property: str, value: Any): + if property in self._readonly_property_getters: + raise KeyError("This property is read-only.") + func = self._property_setters[property] # setter signature: object.__set__(self, instance, value) func(properties, value) @kvikio.utils.call_once - def _property_getter_and_setter(self) -> tuple[dict[str, Any], dict[str, Any]]: + def _property_getter_and_setter( + self, + ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: class_dict = vars(cufile_driver.DriverProperties) property_getter_names = [ @@ -58,7 +76,19 @@ def _property_getter_and_setter(self) -> tuple[dict[str, Any], dict[str, Any]]: for name in property_getter_names: property_getters[name] = class_dict[name].__get__ property_setters[name] = class_dict[name].__set__ - return property_getters, property_setters + + readonly_property_getter_names = [ + "is_gds_available", + "major_version", + "minor_version", + "allow_compat_mode", + "per_buffer_cache_size", + ] + readonly_property_getters = {} + for name in readonly_property_getter_names: + readonly_property_getters[name] = class_dict[name].__get__ + + return property_getters, property_setters, readonly_property_getters @overload @@ -94,6 +124,12 @@ def set(*config) -> ConfigContextManager: The configurations. Can either be a single parameter (dict) consisting of one or more properties, or two parameters key (string) and value (Any) indicating a single property. + + Returns + ------- + ConfigContextManager + A context manager. If used in a `with` statement, the configuration will revert + to its old value upon leaving the block. """ err_msg = ( @@ -113,6 +149,23 @@ def set(*config) -> ConfigContextManager: raise ValueError(err_msg) +def get(config_name: str) -> Any: + """Get cuFile driver configurations. + + Parameters + ---------- + config_name: str + The name of the configuration. + + Returns + ------- + Any + The value of the configuration. + """ + context_manager = ConfigContextManager({}) + return context_manager._get_property(config_name) + + def libcufile_version() -> Tuple[int, int]: """Get the libcufile version. diff --git a/python/kvikio/kvikio/defaults.py b/python/kvikio/kvikio/defaults.py index 3688be6a6e..528a4d4607 100644 --- a/python/kvikio/kvikio/defaults.py +++ b/python/kvikio/kvikio/defaults.py @@ -1,14 +1,18 @@ # Copyright (c) 2021-2025, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. -import warnings -from typing import Any, Callable, overload + +from typing import Any, overload import kvikio._lib.defaults -import kvikio.utils +from kvikio.utils import call_once, kvikio_deprecation_notice class ConfigContextManager: + """Context manager allowing the KvikIO configurations to be set upon entering a + `with` block, and automatically reset upon leaving the block. + """ + def __init__(self, config: dict[str, str]): ( self._property_getters, @@ -39,7 +43,7 @@ def _set_property(self, property: str, value: Any): func = self._property_setters[property] func(value) - @kvikio.utils.call_once + @call_once def _property_getter_and_setter(self) -> tuple[dict[str, Any], dict[str, Any]]: module_dict = vars(kvikio._lib.defaults) @@ -81,20 +85,40 @@ def set(*config) -> ConfigContextManager: .. code-block:: python + # Set the property globally. kvikio.defaults.set({"prop1": value1, "prop2": value2}) + # Set the property with a context manager. + # The property automatically reverts to its old value + # after leaving the `with` block. + with kvikio.defaults.set({"prop1": value1, "prop2": value2}): + ... + - To set a single property .. code-block:: python + # Set the property globally. kvikio.defaults.set("prop", value) + # Set the property with a context manager. + # The property automatically reverts to its old value + # after leaving the `with` block. + with kvikio.defaults.set("prop", value): + ... + Parameters ---------- config The configurations. Can either be a single parameter (dict) consisting of one or more properties, or two parameters key (string) and value (Any) indicating a single property. + + Returns + ------- + ConfigContextManager + A context manager. If used in a `with` statement, the configuration will revert + to its old value upon leaving the block. """ err_msg = ( @@ -114,6 +138,24 @@ def set(*config) -> ConfigContextManager: raise ValueError(err_msg) +def get(config_name: str) -> Any: + """Get KvikIO configurations. + + Parameters + ---------- + config_name: str + The name of the configuration. + + Returns + ------- + Any + The value of the configuration. + """ + context_manager = ConfigContextManager({}) + return context_manager._get_property(config_name) + + +@kvikio_deprecation_notice('Use kvikio.defaults.get("compat_mode") instead') def compat_mode() -> kvikio.CompatMode: """Check if KvikIO is running in compatibility mode. @@ -139,6 +181,7 @@ def compat_mode() -> kvikio.CompatMode: return kvikio._lib.defaults.compat_mode() +@kvikio_deprecation_notice('Use kvikio.defaults.get("num_threads") instead') def num_threads() -> int: """Get the number of threads of the thread pool. @@ -153,6 +196,7 @@ def num_threads() -> int: return kvikio._lib.defaults.thread_pool_nthreads() +@kvikio_deprecation_notice('Use kvikio.defaults.get("task_size") instead') def task_size() -> int: """Get the default task size used for parallel IO operations. @@ -168,6 +212,7 @@ def task_size() -> int: return kvikio._lib.defaults.task_size() +@kvikio_deprecation_notice('Use kvikio.defaults.get("gds_threshold") instead') def gds_threshold() -> int: """Get the default GDS threshold, which is the minimum size to use GDS. @@ -187,6 +232,7 @@ def gds_threshold() -> int: return kvikio._lib.defaults.gds_threshold() +@kvikio_deprecation_notice('Use kvikio.defaults.get("bounce_buffer_size") instead') def bounce_buffer_size() -> int: """Get the size of the bounce buffer used to stage data in host memory. @@ -202,6 +248,7 @@ def bounce_buffer_size() -> int: return kvikio._lib.defaults.bounce_buffer_size() +@kvikio_deprecation_notice('Use kvikio.defaults.get("http_max_attempts") instead') def http_max_attempts() -> int: """Get the maximum number of attempts per remote IO read. @@ -221,6 +268,7 @@ def http_max_attempts() -> int: return kvikio._lib.defaults.http_max_attempts() +@kvikio_deprecation_notice('Use kvikio.defaults.get("http_status_codes") instead') def http_status_codes() -> list[int]: """Get the list of HTTP status codes to retry. @@ -242,17 +290,6 @@ def http_status_codes() -> list[int]: return kvikio._lib.defaults.http_status_codes() -def kvikio_deprecation_notice(msg: str): - def decorator(func: Callable): - def wrapper(*args, **kwargs): - warnings.warn(msg, category=FutureWarning, stacklevel=2) - return func(*args, **kwargs) - - return wrapper - - return decorator - - @kvikio_deprecation_notice('Use kvikio.defaults.set("compat_mode", value) instead') def compat_mode_reset(compatmode: kvikio.CompatMode) -> None: """(Deprecated) Reset the compatibility mode. diff --git a/python/kvikio/kvikio/utils.py b/python/kvikio/kvikio/utils.py index e79386023c..79f5309d4c 100644 --- a/python/kvikio/kvikio/utils.py +++ b/python/kvikio/kvikio/utils.py @@ -6,6 +6,7 @@ import pathlib import threading import time +import warnings from http.server import ( BaseHTTPRequestHandler, SimpleHTTPRequestHandler, @@ -96,20 +97,25 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.process.kill() -def call_once(func: Callable): +def call_once(func: Callable) -> Callable: """Decorate a function such that it is only called once Examples: .. code-block:: python - @call_once + @kvikio.utils.call_once foo(args) Parameters ---------- func: Callable The function to be decorated. + + Returns + ------- + Callable + A decorated function. """ once_flag = True cached_result = None @@ -123,3 +129,34 @@ def wrapper(*args, **kwargs): return cached_result return wrapper + + +def kvikio_deprecation_notice(msg: str) -> Callable: + """Decorate a function to print the deprecation notice at runtime. + + Examples: + + .. code-block:: python + + @kvikio.utils.kvikio_deprecation_notice("Use bar(args) instead.") + foo(args) + + Parameters + ---------- + msg: str + The deprecation notice. + + Returns + ------- + Callable + A decorated function. + """ + + def decorator(func: Callable): + def wrapper(*args, **kwargs): + warnings.warn(msg, category=FutureWarning, stacklevel=2) + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/python/kvikio/tests/test_cufile_driver.py b/python/kvikio/tests/test_cufile_driver.py index fcc95c6cbc..7e31d24335 100644 --- a/python/kvikio/tests/test_cufile_driver.py +++ b/python/kvikio/tests/test_cufile_driver.py @@ -18,35 +18,41 @@ def test_open_and_close(): kvikio.cufile_driver.driver_close() -def test_property_setter(): - """Test the method `set`""" +def test_property_accessor(): + """Test the method `get` and `set`""" # Attempt to set a nonexistent property with pytest.raises(KeyError): kvikio.cufile_driver.set("nonexistent_property", 123) + # Attempt to get a nonexistent property + with pytest.raises(KeyError): + kvikio.cufile_driver.get("nonexistent_property") + + # Attempt to set a read-only property + with pytest.raises(KeyError, match="read-only"): + kvikio.cufile_driver.set("major_version", 2077) + # Nested context managers - poll_thresh_size_default = kvikio.cufile_driver.properties.poll_thresh_size + poll_thresh_size_default = kvikio.cufile_driver.get("poll_thresh_size") with kvikio.cufile_driver.set("poll_thresh_size", 1024): - assert kvikio.cufile_driver.properties.poll_thresh_size == 1024 + assert kvikio.cufile_driver.get("poll_thresh_size") == 1024 with kvikio.cufile_driver.set("poll_thresh_size", 2048): - assert kvikio.cufile_driver.properties.poll_thresh_size == 2048 + assert kvikio.cufile_driver.get("poll_thresh_size") == 2048 with kvikio.cufile_driver.set("poll_thresh_size", 4096): - assert kvikio.cufile_driver.properties.poll_thresh_size == 4096 - assert kvikio.cufile_driver.properties.poll_thresh_size == 2048 - assert kvikio.cufile_driver.properties.poll_thresh_size == 1024 - assert kvikio.cufile_driver.properties.poll_thresh_size == poll_thresh_size_default + assert kvikio.cufile_driver.get("poll_thresh_size") == 4096 + assert kvikio.cufile_driver.get("poll_thresh_size") == 2048 + assert kvikio.cufile_driver.get("poll_thresh_size") == 1024 + assert kvikio.cufile_driver.get("poll_thresh_size") == poll_thresh_size_default # Multiple context managers - poll_mode_default = kvikio.cufile_driver.properties.poll_mode - max_device_cache_size_default = ( - kvikio.cufile_driver.properties.max_device_cache_size - ) + poll_mode_default = kvikio.cufile_driver.get("poll_mode") + max_device_cache_size_default = kvikio.cufile_driver.get("max_device_cache_size") with kvikio.cufile_driver.set({"poll_mode": True, "max_device_cache_size": 2048}): - assert kvikio.cufile_driver.properties.poll_mode and ( - kvikio.cufile_driver.properties.max_device_cache_size == 2048 + assert kvikio.cufile_driver.get("poll_mode") and ( + kvikio.cufile_driver.get("max_device_cache_size") == 2048 ) - assert (kvikio.cufile_driver.properties.poll_mode == poll_mode_default) and ( - kvikio.cufile_driver.properties.max_device_cache_size + assert (kvikio.cufile_driver.get("poll_mode") == poll_mode_default) and ( + kvikio.cufile_driver.get("max_device_cache_size") == max_device_cache_size_default ) diff --git a/python/kvikio/tests/test_defaults.py b/python/kvikio/tests/test_defaults.py index 82c6327f5e..bc6cff6180 100644 --- a/python/kvikio/tests/test_defaults.py +++ b/python/kvikio/tests/test_defaults.py @@ -7,75 +7,79 @@ import kvikio.defaults -def test_property_setter(): - """Test the method `set`""" +def test_property_accessor(): + """Test the method `get` and `set`""" # Attempt to set a nonexistent property with pytest.raises(KeyError): kvikio.defaults.set("nonexistent_property", 123) + # Attempt to get a nonexistent property + with pytest.raises(KeyError): + kvikio.defaults.get("nonexistent_property") + # Attempt to set a property whose name is mistakenly prefixed by "set_" # (coinciding with the setter method). with pytest.raises(KeyError): kvikio.defaults.set("set_task_size", 123) # Nested context managers - task_size_default = kvikio.defaults.task_size() + task_size_default = kvikio.defaults.get("task_size") with kvikio.defaults.set("task_size", 1024): - assert kvikio.defaults.task_size() == 1024 + assert kvikio.defaults.get("task_size") == 1024 with kvikio.defaults.set("task_size", 2048): - assert kvikio.defaults.task_size() == 2048 + assert kvikio.defaults.get("task_size") == 2048 with kvikio.defaults.set("task_size", 4096): - assert kvikio.defaults.task_size() == 4096 - assert kvikio.defaults.task_size() == 2048 - assert kvikio.defaults.task_size() == 1024 - assert kvikio.defaults.task_size() == task_size_default + assert kvikio.defaults.get("task_size") == 4096 + assert kvikio.defaults.get("task_size") == 2048 + assert kvikio.defaults.get("task_size") == 1024 + assert kvikio.defaults.get("task_size") == task_size_default # Multiple context managers - task_size_default = kvikio.defaults.task_size() - num_threads_default = kvikio.defaults.num_threads() - bounce_buffer_size_default = kvikio.defaults.bounce_buffer_size() + task_size_default = kvikio.defaults.get("task_size") + num_threads_default = kvikio.defaults.get("num_threads") + bounce_buffer_size_default = kvikio.defaults.get("bounce_buffer_size") with kvikio.defaults.set( {"task_size": 1024, "num_threads": 16, "bounce_buffer_size": 1024} ): assert ( - (kvikio.defaults.task_size() == 1024) - and (kvikio.defaults.num_threads() == 16) - and (kvikio.defaults.bounce_buffer_size() == 1024) + (kvikio.defaults.get("task_size") == 1024) + and (kvikio.defaults.get("num_threads") == 16) + and (kvikio.defaults.get("bounce_buffer_size") == 1024) ) assert ( - (kvikio.defaults.task_size() == task_size_default) - and (kvikio.defaults.num_threads() == num_threads_default) - and (kvikio.defaults.bounce_buffer_size() == bounce_buffer_size_default) + (kvikio.defaults.get("task_size") == task_size_default) + and (kvikio.defaults.get("num_threads") == num_threads_default) + and (kvikio.defaults.get("bounce_buffer_size") == bounce_buffer_size_default) ) @pytest.mark.skipif( - kvikio.defaults.compat_mode() == kvikio.CompatMode.ON, + kvikio.defaults.get("compat_mode") == kvikio.CompatMode.ON, reason="cannot test `compat_mode` when already running in compatibility mode", ) def test_compat_mode(): """Test changing `compat_mode`""" - before = kvikio.defaults.compat_mode() + before = kvikio.defaults.get("compat_mode") with kvikio.defaults.set("compat_mode", kvikio.CompatMode.ON): - assert kvikio.defaults.compat_mode() == kvikio.CompatMode.ON + assert kvikio.defaults.get("compat_mode") == kvikio.CompatMode.ON kvikio.defaults.set("compat_mode", kvikio.CompatMode.OFF) - assert kvikio.defaults.compat_mode() == kvikio.CompatMode.OFF + assert kvikio.defaults.get("compat_mode") == kvikio.CompatMode.OFF kvikio.defaults.set("compat_mode", kvikio.CompatMode.AUTO) - assert kvikio.defaults.compat_mode() == kvikio.CompatMode.AUTO - assert before == kvikio.defaults.compat_mode() + assert kvikio.defaults.get("compat_mode") == kvikio.CompatMode.AUTO + assert before == kvikio.defaults.get("compat_mode") def test_num_threads(): """Test changing `num_threads`""" - before = kvikio.defaults.num_threads() + before = kvikio.defaults.get("num_threads") with kvikio.defaults.set("num_threads", 3): - assert kvikio.defaults.num_threads() == 3 + assert kvikio.defaults.get("num_threads") == 3 kvikio.defaults.set("num_threads", 4) - assert kvikio.defaults.num_threads() == 4 - assert before == kvikio.defaults.num_threads() + assert kvikio.defaults.get("num_threads") == 4 + assert before == kvikio.defaults.get("num_threads") with pytest.raises(ValueError, match="positive integer greater than zero"): kvikio.defaults.set("num_threads", 0) @@ -86,12 +90,12 @@ def test_num_threads(): def test_task_size(): """Test changing `task_size`""" - before = kvikio.defaults.task_size() + before = kvikio.defaults.get("task_size") with kvikio.defaults.set("task_size", 3): - assert kvikio.defaults.task_size() == 3 + assert kvikio.defaults.get("task_size") == 3 kvikio.defaults.set("task_size", 4) - assert kvikio.defaults.task_size() == 4 - assert before == kvikio.defaults.task_size() + assert kvikio.defaults.get("task_size") == 4 + assert before == kvikio.defaults.get("task_size") with pytest.raises(ValueError, match="positive integer greater than zero"): kvikio.defaults.set("task_size", 0) @@ -102,12 +106,12 @@ def test_task_size(): def test_gds_threshold(): """Test changing `gds_threshold`""" - before = kvikio.defaults.gds_threshold() + before = kvikio.defaults.get("gds_threshold") with kvikio.defaults.set("gds_threshold", 3): - assert kvikio.defaults.gds_threshold() == 3 + assert kvikio.defaults.get("gds_threshold") == 3 kvikio.defaults.set("gds_threshold", 4) - assert kvikio.defaults.gds_threshold() == 4 - assert before == kvikio.defaults.gds_threshold() + assert kvikio.defaults.get("gds_threshold") == 4 + assert before == kvikio.defaults.get("gds_threshold") with pytest.raises(OverflowError, match="negative value"): kvikio.defaults.set("gds_threshold", -1) @@ -116,12 +120,12 @@ def test_gds_threshold(): def test_bounce_buffer_size(): """Test changing `bounce_buffer_size`""" - before = kvikio.defaults.bounce_buffer_size() + before = kvikio.defaults.get("bounce_buffer_size") with kvikio.defaults.set("bounce_buffer_size", 3): - assert kvikio.defaults.bounce_buffer_size() == 3 + assert kvikio.defaults.get("bounce_buffer_size") == 3 kvikio.defaults.set("bounce_buffer_size", 4) - assert kvikio.defaults.bounce_buffer_size() == 4 - assert before == kvikio.defaults.bounce_buffer_size() + assert kvikio.defaults.get("bounce_buffer_size") == 4 + assert before == kvikio.defaults.get("bounce_buffer_size") with pytest.raises(ValueError, match="positive integer greater than zero"): kvikio.defaults.set("bounce_buffer_size", 0) @@ -130,13 +134,13 @@ def test_bounce_buffer_size(): def test_http_max_attempts(): - before = kvikio.defaults.http_max_attempts() + before = kvikio.defaults.get("http_max_attempts") with kvikio.defaults.set("http_max_attempts", 5): - assert kvikio.defaults.http_max_attempts() == 5 + assert kvikio.defaults.get("http_max_attempts") == 5 kvikio.defaults.set("http_max_attempts", 4) - assert kvikio.defaults.http_max_attempts() == 4 - assert kvikio.defaults.http_max_attempts() == before + assert kvikio.defaults.get("http_max_attempts") == 4 + assert kvikio.defaults.get("http_max_attempts") == before with pytest.raises(ValueError, match="positive integer"): kvikio.defaults.set("http_max_attempts", 0) @@ -145,13 +149,13 @@ def test_http_max_attempts(): def test_http_status_codes(): - before = kvikio.defaults.http_status_codes() + before = kvikio.defaults.get("http_status_codes") with kvikio.defaults.set("http_status_codes", [500]): - assert kvikio.defaults.http_status_codes() == [500] + assert kvikio.defaults.get("http_status_codes") == [500] kvikio.defaults.set("http_status_codes", [429, 500]) - assert kvikio.defaults.http_status_codes() == [429, 500] - assert kvikio.defaults.http_status_codes() == before + assert kvikio.defaults.get("http_status_codes") == [429, 500] + assert kvikio.defaults.get("http_status_codes") == before with pytest.raises(TypeError): kvikio.defaults.set("http_status_codes", 0)