Skip to content

Commit 28c58da

Browse files
craymichaelfacebook-github-bot
authored andcommitted
Reduce complexity of 'sklearn_train_linear_model' (#1375)
Summary: Pull Request resolved: #1375 Reduce complexity of 'sklearn_train_linear_model' Reviewed By: jsawruk Differential Revision: D64438317 fbshipit-source-id: aa99f2ec9d9a0b349a423fc0a37e9b21b2a0ff39
1 parent 9689ccd commit 28c58da

File tree

1 file changed

+33
-24
lines changed
  • captum/_utils/models/linear_model

1 file changed

+33
-24
lines changed

captum/_utils/models/linear_model/train.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# pyre-strict
22
import time
33
import warnings
4+
from functools import reduce
5+
from types import ModuleType
46
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
57

68
import torch
@@ -282,14 +284,38 @@ def forward(self, x):
282284
return (x - self.mean) / (self.std + self.eps)
283285

284286

287+
def _import_sklearn() -> ModuleType:
288+
try:
289+
import sklearn
290+
import sklearn.linear_model
291+
import sklearn.svm
292+
except ImportError:
293+
raise ValueError("sklearn is not available. Please install sklearn >= 0.23")
294+
295+
if not sklearn.__version__ >= "0.23.0":
296+
warnings.warn(
297+
"Must have sklearn version 0.23.0 or higher to use "
298+
"sample_weight in Lasso regression.",
299+
stacklevel=1,
300+
)
301+
return sklearn
302+
303+
304+
def _import_numpy() -> ModuleType:
305+
try:
306+
import numpy
307+
except ImportError:
308+
raise ValueError("numpy is not available. Please install numpy.")
309+
return numpy
310+
311+
285312
def sklearn_train_linear_model(
286313
model: LinearModel,
287314
dataloader: DataLoader,
288315
construct_kwargs: Dict[str, Any],
289316
sklearn_trainer: str = "Lasso",
290317
norm_input: bool = False,
291-
# pyre-fixme[2]: Parameter must be annotated.
292-
**fit_kwargs,
318+
**fit_kwargs: Any,
293319
) -> Dict[str, float]:
294320
r"""
295321
Alternative method to train with sklearn. This does introduce some slight
@@ -318,26 +344,9 @@ def sklearn_train_linear_model(
318344
fit_kwargs
319345
Other arguments to send to `sklearn_trainer`'s `.fit` method
320346
"""
321-
from functools import reduce
322-
323-
try:
324-
import numpy as np
325-
except ImportError:
326-
raise ValueError("numpy is not available. Please install numpy.")
327-
328-
try:
329-
import sklearn
330-
import sklearn.linear_model
331-
import sklearn.svm
332-
except ImportError:
333-
raise ValueError("sklearn is not available. Please install sklearn >= 0.23")
334-
335-
if not sklearn.__version__ >= "0.23.0":
336-
warnings.warn(
337-
"Must have sklearn version 0.23.0 or higher to use "
338-
"sample_weight in Lasso regression.",
339-
stacklevel=1,
340-
)
347+
# Lazy imports
348+
np = _import_numpy()
349+
sklearn = _import_sklearn()
341350

342351
num_batches = 0
343352
xs, ys, ws = [], [], []
@@ -369,8 +378,8 @@ def sklearn_train_linear_model(
369378

370379
t1 = time.time()
371380
# pyre-fixme[29]: `str` is not a function.
372-
sklearn_model = reduce(
373-
lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".")
381+
sklearn_model = reduce( # type: ignore
382+
lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".") # type: ignore # noqa: E501
374383
)(**construct_kwargs)
375384
try:
376385
sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)

0 commit comments

Comments
 (0)