Skip to content

Commit 4d10f2f

Browse files
committed
Remove types in doc-strings in learner1D.py
1 parent abc48b0 commit 4d10f2f

File tree

1 file changed

+29
-17
lines changed

1 file changed

+29
-17
lines changed

adaptive/learner/learner1D.py

+29-17
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from copy import copy, deepcopy
77
from numbers import Integral as Int
88
from numbers import Real
9-
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
9+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Sequence, Tuple, Union
1010

1111
import cloudpickle
1212
import numpy as np
@@ -24,12 +24,22 @@
2424
partial_function_from_dataframe,
2525
)
2626

27+
if TYPE_CHECKING:
28+
import holoviews
29+
2730
try:
2831
from typing import TypeAlias
2932
except ImportError:
3033
# Remove this when we drop support for Python 3.9
3134
from typing_extensions import TypeAlias
3235

36+
try:
37+
from typing import Literal
38+
except ImportError:
39+
# Remove this when we drop support for Python 3.7
40+
from typing_extensions import Literal
41+
42+
3343
try:
3444
import pandas
3545

@@ -145,7 +155,7 @@ def resolution_loss_function(
145155
146156
Returns
147157
-------
148-
loss_function : callable
158+
loss_function
149159
150160
Examples
151161
--------
@@ -230,12 +240,12 @@ class Learner1D(BaseLearner):
230240
231241
Parameters
232242
----------
233-
function : callable
243+
function
234244
The function to learn. Must take a single real parameter and
235245
return a real number or 1D array.
236-
bounds : pair of reals
246+
bounds
237247
The bounds of the interval on which to learn 'function'.
238-
loss_per_interval: callable, optional
248+
loss_per_interval
239249
A function that returns the loss for a single interval of the domain.
240250
If not provided, then a default is used, which uses the scaled distance
241251
in the x-y plane as the loss. See the notes for more details.
@@ -356,15 +366,15 @@ def to_dataframe(
356366
357367
Parameters
358368
----------
359-
with_default_function_args : bool, optional
369+
with_default_function_args
360370
Include the ``learner.function``'s default arguments as a
361371
column, by default True
362-
function_prefix : str, optional
372+
function_prefix
363373
Prefix to the ``learner.function``'s default arguments' names,
364374
by default "function."
365-
x_name : str, optional
375+
x_name
366376
Name of the input value, by default "x"
367-
y_name : str, optional
377+
y_name
368378
Name of the output value, by default "y"
369379
370380
Returns
@@ -403,16 +413,16 @@ def load_dataframe(
403413
404414
Parameters
405415
----------
406-
df : pandas.DataFrame
416+
df
407417
The data to load.
408-
with_default_function_args : bool, optional
418+
with_default_function_args
409419
The ``with_default_function_args`` used in ``to_dataframe()``,
410420
by default True
411-
function_prefix : str, optional
421+
function_prefix
412422
The ``function_prefix`` used in ``to_dataframe``, by default "function."
413-
x_name : str, optional
423+
x_name
414424
The ``x_name`` used in ``to_dataframe``, by default "x"
415-
y_name : str, optional
425+
y_name
416426
The ``y_name`` used in ``to_dataframe``, by default "y"
417427
"""
418428
self.tell_many(df[x_name].values, df[y_name].values)
@@ -795,17 +805,19 @@ def _loss(
795805
loss = mapping[ival]
796806
return finite_loss(ival, loss, self._scale[0])
797807

798-
def plot(self, *, scatter_or_line: str = "scatter"):
808+
def plot(
809+
self, *, scatter_or_line: Literal["scatter", "line"] = "scatter"
810+
) -> holoviews.Overlay:
799811
"""Returns a plot of the evaluated data.
800812
801813
Parameters
802814
----------
803-
scatter_or_line : str, default: "scatter"
815+
scatter_or_line
804816
Plot as a scatter plot ("scatter") or a line plot ("line").
805817
806818
Returns
807819
-------
808-
plot : `holoviews.Overlay`
820+
plot
809821
Plot of the evaluated data.
810822
"""
811823
if scatter_or_line not in ("scatter", "line"):

0 commit comments

Comments
 (0)