Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kingcrimsontianyu committed Feb 28, 2025
1 parent 33b133e commit 910ed6b
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 85 deletions.
59 changes: 56 additions & 3 deletions python/kvikio/kvikio/cufile_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,21 @@
from kvikio._lib import cufile_driver # type: ignore

properties = cufile_driver.DriverProperties()
"""cuFile driver configurations. Use kvikio.cufile_driver.properties.get and
kvikio.cufile_driver.properties.set to access the configurations.
"""


class ConfigContextManager:
"""Context manager allowing the cuFile driver configurations to be set upon
entering a `with` block, and automatically reset upon leaving the block.
"""

def __init__(self, config: dict[str, str]):
(
self._property_getters,
self._property_setters,
self._readonly_property_getters,
) = self._property_getter_and_setter()
self._old_properties = {}

Expand All @@ -30,19 +38,29 @@ def __exit__(self, type_unused, value, traceback_unused):
self._set_property(key, value)

def _get_property(self, property: str) -> Any:
func = self._property_getters[property]
if property in self._property_getters:
func = self._property_getters[property]
elif property in self._readonly_property_getters:
func = self._readonly_property_getters[property]
else:
raise KeyError

# getter signature: object.__get__(self, instance, owner=None)
return func(properties)

def _set_property(self, property: str, value: Any):
if property in self._readonly_property_getters:
raise KeyError("This property is read-only.")

func = self._property_setters[property]

# setter signature: object.__set__(self, instance, value)
func(properties, value)

@kvikio.utils.call_once
def _property_getter_and_setter(self) -> tuple[dict[str, Any], dict[str, Any]]:
def _property_getter_and_setter(
self,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
class_dict = vars(cufile_driver.DriverProperties)

property_getter_names = [
Expand All @@ -58,7 +76,19 @@ def _property_getter_and_setter(self) -> tuple[dict[str, Any], dict[str, Any]]:
for name in property_getter_names:
property_getters[name] = class_dict[name].__get__
property_setters[name] = class_dict[name].__set__
return property_getters, property_setters

readonly_property_getter_names = [
"is_gds_available",
"major_version",
"minor_version",
"allow_compat_mode",
"per_buffer_cache_size",
]
readonly_property_getters = {}
for name in readonly_property_getter_names:
readonly_property_getters[name] = class_dict[name].__get__

return property_getters, property_setters, readonly_property_getters


@overload
Expand Down Expand Up @@ -94,6 +124,12 @@ def set(*config) -> ConfigContextManager:
The configurations. Can either be a single parameter (dict) consisting of one
or more properties, or two parameters key (string) and value (Any)
indicating a single property.
Returns
-------
ConfigContextManager
A context manager. If used in a `with` statement, the configuration will revert
to its old value upon leaving the block.
"""

err_msg = (
Expand All @@ -113,6 +149,23 @@ def set(*config) -> ConfigContextManager:
raise ValueError(err_msg)


def get(config_name: str) -> Any:
"""Get cuFile driver configurations.
Parameters
----------
config_name: str
The name of the configuration.
Returns
-------
Any
The value of the configuration.
"""
context_manager = ConfigContextManager({})
return context_manager._get_property(config_name)


def libcufile_version() -> Tuple[int, int]:
"""Get the libcufile version.
Expand Down
67 changes: 52 additions & 15 deletions python/kvikio/kvikio/defaults.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
# Copyright (c) 2021-2025, NVIDIA CORPORATION. All rights reserved.
# See file LICENSE for terms.

import warnings
from typing import Any, Callable, overload

from typing import Any, overload

import kvikio._lib.defaults
import kvikio.utils
from kvikio.utils import call_once, kvikio_deprecation_notice


class ConfigContextManager:
"""Context manager allowing the KvikIO configurations to be set upon entering a
`with` block, and automatically reset upon leaving the block.
"""

def __init__(self, config: dict[str, str]):
(
self._property_getters,
Expand Down Expand Up @@ -39,7 +43,7 @@ def _set_property(self, property: str, value: Any):
func = self._property_setters[property]
func(value)

@kvikio.utils.call_once
@call_once
def _property_getter_and_setter(self) -> tuple[dict[str, Any], dict[str, Any]]:
module_dict = vars(kvikio._lib.defaults)

Expand Down Expand Up @@ -81,20 +85,40 @@ def set(*config) -> ConfigContextManager:
.. code-block:: python
# Set the property globally.
kvikio.defaults.set({"prop1": value1, "prop2": value2})
# Set the property with a context manager.
# The property automatically reverts to its old value
# after leaving the `with` block.
with kvikio.defaults.set({"prop1": value1, "prop2": value2}):
...
- To set a single property
.. code-block:: python
# Set the property globally.
kvikio.defaults.set("prop", value)
# Set the property with a context manager.
# The property automatically reverts to its old value
# after leaving the `with` block.
with kvikio.defaults.set("prop", value):
...
Parameters
----------
config
The configurations. Can either be a single parameter (dict) consisting of one
or more properties, or two parameters key (string) and value (Any)
indicating a single property.
Returns
-------
ConfigContextManager
A context manager. If used in a `with` statement, the configuration will revert
to its old value upon leaving the block.
"""

err_msg = (
Expand All @@ -114,6 +138,24 @@ def set(*config) -> ConfigContextManager:
raise ValueError(err_msg)


def get(config_name: str) -> Any:
"""Get KvikIO configurations.
Parameters
----------
config_name: str
The name of the configuration.
Returns
-------
Any
The value of the configuration.
"""
context_manager = ConfigContextManager({})
return context_manager._get_property(config_name)


@kvikio_deprecation_notice('Use kvikio.defaults.get("compat_mode") instead')
def compat_mode() -> kvikio.CompatMode:
"""Check if KvikIO is running in compatibility mode.
Expand All @@ -139,6 +181,7 @@ def compat_mode() -> kvikio.CompatMode:
return kvikio._lib.defaults.compat_mode()


@kvikio_deprecation_notice('Use kvikio.defaults.get("num_threads") instead')
def num_threads() -> int:
"""Get the number of threads of the thread pool.
Expand All @@ -153,6 +196,7 @@ def num_threads() -> int:
return kvikio._lib.defaults.thread_pool_nthreads()


@kvikio_deprecation_notice('Use kvikio.defaults.get("task_size") instead')
def task_size() -> int:
"""Get the default task size used for parallel IO operations.
Expand All @@ -168,6 +212,7 @@ def task_size() -> int:
return kvikio._lib.defaults.task_size()


@kvikio_deprecation_notice('Use kvikio.defaults.get("gds_threshold") instead')
def gds_threshold() -> int:
"""Get the default GDS threshold, which is the minimum size to use GDS.
Expand All @@ -187,6 +232,7 @@ def gds_threshold() -> int:
return kvikio._lib.defaults.gds_threshold()


@kvikio_deprecation_notice('Use kvikio.defaults.get("bounce_buffer_size") instead')
def bounce_buffer_size() -> int:
"""Get the size of the bounce buffer used to stage data in host memory.
Expand All @@ -202,6 +248,7 @@ def bounce_buffer_size() -> int:
return kvikio._lib.defaults.bounce_buffer_size()


@kvikio_deprecation_notice('Use kvikio.defaults.get("http_max_attempts") instead')
def http_max_attempts() -> int:
"""Get the maximum number of attempts per remote IO read.
Expand All @@ -221,6 +268,7 @@ def http_max_attempts() -> int:
return kvikio._lib.defaults.http_max_attempts()


@kvikio_deprecation_notice('Use kvikio.defaults.get("http_status_codes") instead')
def http_status_codes() -> list[int]:
"""Get the list of HTTP status codes to retry.
Expand All @@ -242,17 +290,6 @@ def http_status_codes() -> list[int]:
return kvikio._lib.defaults.http_status_codes()


def kvikio_deprecation_notice(msg: str):
def decorator(func: Callable):
def wrapper(*args, **kwargs):
warnings.warn(msg, category=FutureWarning, stacklevel=2)
return func(*args, **kwargs)

return wrapper

return decorator


@kvikio_deprecation_notice('Use kvikio.defaults.set("compat_mode", value) instead')
def compat_mode_reset(compatmode: kvikio.CompatMode) -> None:
"""(Deprecated) Reset the compatibility mode.
Expand Down
41 changes: 39 additions & 2 deletions python/kvikio/kvikio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pathlib
import threading
import time
import warnings
from http.server import (
BaseHTTPRequestHandler,
SimpleHTTPRequestHandler,
Expand Down Expand Up @@ -96,20 +97,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.process.kill()


def call_once(func: Callable):
def call_once(func: Callable) -> Callable:
"""Decorate a function such that it is only called once
Examples:
.. code-block:: python
@call_once
@kvikio.utils.call_once
foo(args)
Parameters
----------
func: Callable
The function to be decorated.
Returns
-------
Callable
A decorated function.
"""
once_flag = True
cached_result = None
Expand All @@ -123,3 +129,34 @@ def wrapper(*args, **kwargs):
return cached_result

return wrapper


def kvikio_deprecation_notice(msg: str) -> Callable:
"""Decorate a function to print the deprecation notice at runtime.
Examples:
.. code-block:: python
@kvikio.utils.kvikio_deprecation_notice("Use bar(args) instead.")
foo(args)
Parameters
----------
msg: str
The deprecation notice.
Returns
-------
Callable
A decorated function.
"""

def decorator(func: Callable):
def wrapper(*args, **kwargs):
warnings.warn(msg, category=FutureWarning, stacklevel=2)
return func(*args, **kwargs)

return wrapper

return decorator
Loading

0 comments on commit 910ed6b

Please sign in to comment.