|
6 | 6 | from copy import copy, deepcopy
|
7 | 7 | from numbers import Integral as Int
|
8 | 8 | from numbers import Real
|
9 |
| -from typing import Any, Callable, Dict, List, Sequence, Tuple, Union |
| 9 | +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Sequence, Tuple, Union |
10 | 10 |
|
11 | 11 | import cloudpickle
|
12 | 12 | import numpy as np
|
|
24 | 24 | partial_function_from_dataframe,
|
25 | 25 | )
|
26 | 26 |
|
| 27 | +if TYPE_CHECKING: |
| 28 | + import holoviews |
| 29 | + |
27 | 30 | try:
|
28 | 31 | from typing import TypeAlias
|
29 | 32 | except ImportError:
|
30 | 33 | # Remove this when we drop support for Python 3.9
|
31 | 34 | from typing_extensions import TypeAlias
|
32 | 35 |
|
| 36 | +try: |
| 37 | + from typing import Literal |
| 38 | +except ImportError: |
| 39 | + # Remove this when we drop support for Python 3.7 |
| 40 | + from typing_extensions import Literal |
| 41 | + |
| 42 | + |
33 | 43 | try:
|
34 | 44 | import pandas
|
35 | 45 |
|
@@ -145,7 +155,7 @@ def resolution_loss_function(
|
145 | 155 |
|
146 | 156 | Returns
|
147 | 157 | -------
|
148 |
| - loss_function : callable |
| 158 | + loss_function |
149 | 159 |
|
150 | 160 | Examples
|
151 | 161 | --------
|
@@ -230,12 +240,12 @@ class Learner1D(BaseLearner):
|
230 | 240 |
|
231 | 241 | Parameters
|
232 | 242 | ----------
|
233 |
| - function : callable |
| 243 | + function |
234 | 244 | The function to learn. Must take a single real parameter and
|
235 | 245 | return a real number or 1D array.
|
236 |
| - bounds : pair of reals |
| 246 | + bounds |
237 | 247 | The bounds of the interval on which to learn 'function'.
|
238 |
| - loss_per_interval: callable, optional |
| 248 | + loss_per_interval |
239 | 249 | A function that returns the loss for a single interval of the domain.
|
240 | 250 | If not provided, then a default is used, which uses the scaled distance
|
241 | 251 | in the x-y plane as the loss. See the notes for more details.
|
@@ -356,15 +366,15 @@ def to_dataframe(
|
356 | 366 |
|
357 | 367 | Parameters
|
358 | 368 | ----------
|
359 |
| - with_default_function_args : bool, optional |
| 369 | + with_default_function_args |
360 | 370 | Include the ``learner.function``'s default arguments as a
|
361 | 371 | column, by default True
|
362 |
| - function_prefix : str, optional |
| 372 | + function_prefix |
363 | 373 | Prefix to the ``learner.function``'s default arguments' names,
|
364 | 374 | by default "function."
|
365 |
| - x_name : str, optional |
| 375 | + x_name |
366 | 376 | Name of the input value, by default "x"
|
367 |
| - y_name : str, optional |
| 377 | + y_name |
368 | 378 | Name of the output value, by default "y"
|
369 | 379 |
|
370 | 380 | Returns
|
@@ -403,16 +413,16 @@ def load_dataframe(
|
403 | 413 |
|
404 | 414 | Parameters
|
405 | 415 | ----------
|
406 |
| - df : pandas.DataFrame |
| 416 | + df |
407 | 417 | The data to load.
|
408 |
| - with_default_function_args : bool, optional |
| 418 | + with_default_function_args |
409 | 419 | The ``with_default_function_args`` used in ``to_dataframe()``,
|
410 | 420 | by default True
|
411 |
| - function_prefix : str, optional |
| 421 | + function_prefix |
412 | 422 | The ``function_prefix`` used in ``to_dataframe``, by default "function."
|
413 |
| - x_name : str, optional |
| 423 | + x_name |
414 | 424 | The ``x_name`` used in ``to_dataframe``, by default "x"
|
415 |
| - y_name : str, optional |
| 425 | + y_name |
416 | 426 | The ``y_name`` used in ``to_dataframe``, by default "y"
|
417 | 427 | """
|
418 | 428 | self.tell_many(df[x_name].values, df[y_name].values)
|
@@ -795,17 +805,19 @@ def _loss(
|
795 | 805 | loss = mapping[ival]
|
796 | 806 | return finite_loss(ival, loss, self._scale[0])
|
797 | 807 |
|
798 |
| - def plot(self, *, scatter_or_line: str = "scatter"): |
| 808 | + def plot( |
| 809 | + self, *, scatter_or_line: Literal["scatter", "line"] = "scatter" |
| 810 | + ) -> holoviews.Overlay: |
799 | 811 | """Returns a plot of the evaluated data.
|
800 | 812 |
|
801 | 813 | Parameters
|
802 | 814 | ----------
|
803 |
| - scatter_or_line : str, default: "scatter" |
| 815 | + scatter_or_line |
804 | 816 | Plot as a scatter plot ("scatter") or a line plot ("line").
|
805 | 817 |
|
806 | 818 | Returns
|
807 | 819 | -------
|
808 |
| - plot : `holoviews.Overlay` |
| 820 | + plot |
809 | 821 | Plot of the evaluated data.
|
810 | 822 | """
|
811 | 823 | if scatter_or_line not in ("scatter", "line"):
|
|
0 commit comments