Skip to content

Commit

Permalink
Implement code as suggested by @jameslamb
Browse files Browse the repository at this point in the history
  • Loading branch information
mlondschien committed Jan 12, 2025
1 parent e61bcbe commit 6b37b34
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 16 deletions.
11 changes: 7 additions & 4 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import scipy.sparse

from .compat import (
CFFI_INSTALLED,
PANDAS_INSTALLED,
PYARROW_INSTALLED,
arrow_cffi,
Expand Down Expand Up @@ -1706,8 +1707,8 @@ def __pred_for_pyarrow_table(
predict_type: int,
) -> Tuple[np.ndarray, int]:
"""Predict for a PyArrow table."""
if not PYARROW_INSTALLED:
raise LightGBMError("Cannot predict from Arrow without `pyarrow` installed.")
if not (PYARROW_INSTALLED and CFFI_INSTALLED):
raise LightGBMError("Cannot predict from Arrow without `pyarrow` and `cffi` installed.")

# Check that the input is valid: we only handle numbers (for now)
if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types):
Expand Down Expand Up @@ -2186,6 +2187,8 @@ def _lazy_init(
elif isinstance(data, np.ndarray):
self.__init_from_np2d(data, params_str, ref_dataset)
elif _is_pyarrow_table(data):
if not CFFI_INSTALLED:
raise LightGBMError("Cannot init dataframe from Arrow without `pyarrow` and `cffi` installed.")
self.__init_from_pyarrow_table(data, params_str, ref_dataset)
elif isinstance(data, list) and len(data) > 0:
if _is_list_of_numpy_arrays(data):
Expand Down Expand Up @@ -2459,8 +2462,8 @@ def __init_from_pyarrow_table(
ref_dataset: Optional[_DatasetHandle],
) -> "Dataset":
"""Initialize data from a PyArrow table."""
if not PYARROW_INSTALLED:
raise LightGBMError("Cannot init dataframe from Arrow without `pyarrow` installed.")
if not (PYARROW_INSTALLED and CFFI_INSTALLED):
raise LightGBMError("Cannot init dataframe from Arrow without `pyarrow` and `cffi` installed.")

# Check that the input is valid: we only handle numbers (for now)
if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types):
Expand Down
33 changes: 21 additions & 12 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,6 @@ def __init__(self, *args: Any, **kwargs: Any):
from pyarrow import ChunkedArray as pa_ChunkedArray
from pyarrow import Table as pa_Table
from pyarrow import chunked_array as pa_chunked_array
from pyarrow.cffi import ffi as arrow_cffi
from pyarrow.types import is_boolean as arrow_is_boolean
from pyarrow.types import is_floating as arrow_is_floating
from pyarrow.types import is_integer as arrow_is_integer
Expand All @@ -316,17 +315,6 @@ class pa_Table: # type: ignore
def __init__(self, *args: Any, **kwargs: Any):
pass

class arrow_cffi: # type: ignore
"""Dummy class for pyarrow.cffi.ffi."""

CData = None
addressof = None
cast = None
new = None

def __init__(self, *args: Any, **kwargs: Any):
pass

class pa_compute: # type: ignore
"""Dummy class for pyarrow.compute."""

Expand All @@ -338,6 +326,27 @@ class pa_compute: # type: ignore
arrow_is_integer = None
arrow_is_floating = None


"""cffi"""
try:
from pyarrow.cffi import ffi as arrow_cffi

CFFI_INSTALLED = True
except ImportError:
CFFI_INSTALLED = False

class arrow_cffi: # type: ignore
"""Dummy class for pyarrow.cffi.ffi."""

CData = None
addressof = None
cast = None
new = None

def __init__(self, *args: Any, **kwargs: Any):
pass


"""cpu_count()"""
try:
from joblib import cpu_count
Expand Down

0 comments on commit 6b37b34

Please sign in to comment.