|
1 | 1 | # pyre-strict |
2 | 2 | import time |
3 | 3 | import warnings |
| 4 | +from functools import reduce |
| 5 | +from types import ModuleType |
4 | 6 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple |
5 | 7 |
|
6 | 8 | import torch |
@@ -282,14 +284,38 @@ def forward(self, x): |
282 | 284 | return (x - self.mean) / (self.std + self.eps) |
283 | 285 |
|
284 | 286 |
|
| 287 | +def _import_sklearn() -> ModuleType: |
| 288 | + try: |
| 289 | + import sklearn |
| 290 | + import sklearn.linear_model |
| 291 | + import sklearn.svm |
| 292 | + except ImportError: |
| 293 | + raise ValueError("sklearn is not available. Please install sklearn >= 0.23") |
| 294 | + |
| 295 | + if not sklearn.__version__ >= "0.23.0": |
| 296 | + warnings.warn( |
| 297 | + "Must have sklearn version 0.23.0 or higher to use " |
| 298 | + "sample_weight in Lasso regression.", |
| 299 | + stacklevel=1, |
| 300 | + ) |
| 301 | + return sklearn |
| 302 | + |
| 303 | + |
| 304 | +def _import_numpy() -> ModuleType: |
| 305 | + try: |
| 306 | + import numpy |
| 307 | + except ImportError: |
| 308 | + raise ValueError("numpy is not available. Please install numpy.") |
| 309 | + return numpy |
| 310 | + |
| 311 | + |
285 | 312 | def sklearn_train_linear_model( |
286 | 313 | model: LinearModel, |
287 | 314 | dataloader: DataLoader, |
288 | 315 | construct_kwargs: Dict[str, Any], |
289 | 316 | sklearn_trainer: str = "Lasso", |
290 | 317 | norm_input: bool = False, |
291 | | - # pyre-fixme[2]: Parameter must be annotated. |
292 | | - **fit_kwargs, |
| 318 | + **fit_kwargs: Any, |
293 | 319 | ) -> Dict[str, float]: |
294 | 320 | r""" |
295 | 321 | Alternative method to train with sklearn. This does introduce some slight |
@@ -318,26 +344,9 @@ def sklearn_train_linear_model( |
318 | 344 | fit_kwargs |
319 | 345 | Other arguments to send to `sklearn_trainer`'s `.fit` method |
320 | 346 | """ |
321 | | - from functools import reduce |
322 | | - |
323 | | - try: |
324 | | - import numpy as np |
325 | | - except ImportError: |
326 | | - raise ValueError("numpy is not available. Please install numpy.") |
327 | | - |
328 | | - try: |
329 | | - import sklearn |
330 | | - import sklearn.linear_model |
331 | | - import sklearn.svm |
332 | | - except ImportError: |
333 | | - raise ValueError("sklearn is not available. Please install sklearn >= 0.23") |
334 | | - |
335 | | - if not sklearn.__version__ >= "0.23.0": |
336 | | - warnings.warn( |
337 | | - "Must have sklearn version 0.23.0 or higher to use " |
338 | | - "sample_weight in Lasso regression.", |
339 | | - stacklevel=1, |
340 | | - ) |
| 347 | + # Lazy imports |
| 348 | + np = _import_numpy() |
| 349 | + sklearn = _import_sklearn() |
341 | 350 |
|
342 | 351 | num_batches = 0 |
343 | 352 | xs, ys, ws = [], [], [] |
@@ -369,8 +378,8 @@ def sklearn_train_linear_model( |
369 | 378 |
|
370 | 379 | t1 = time.time() |
371 | 380 | # pyre-fixme[29]: `str` is not a function. |
372 | | - sklearn_model = reduce( |
373 | | - lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".") |
| 381 | + sklearn_model = reduce( # type: ignore |
| 382 | + lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".") # type: ignore # noqa: E501 |
374 | 383 | )(**construct_kwargs) |
375 | 384 | try: |
376 | 385 | sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs) |
|
0 commit comments