Skip to content

File tree

9 files changed

+267
-10
lines changed

9 files changed

+267
-10
lines changed

stubs/tensorflow/@tests/stubtest_allowlist.txt

+4-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ tensorflow.keras.layers.*.call
7373
tensorflow.keras.regularizers.Regularizer.__call__
7474
tensorflow.keras.constraints.Constraint.__call__
7575

76-
# Layer class does good deal of __new__ magic and actually returns one of two different internal
76+
# Layer/Model class does good deal of __new__ magic and actually returns one of two different internal
7777
# types depending on tensorflow execution mode. This feels like implementation internal.
7878
tensorflow.keras.layers.Layer.__new__
7979

@@ -114,5 +114,8 @@ tensorflow.train.ServerDef.*
114114
# python.X
115115
tensorflow.python.*
116116

117+
# The modules below are re-exported from tensorflow.python, and they therefore appear missing to stubtest.
118+
tensorflow.distribute.Strategy
119+
117120
# sigmoid_cross_entropy_with_logits has default values (None), however those values are not valid.
118121
tensorflow.nn.sigmoid_cross_entropy_with_logits
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from _typeshed import Incomplete
2+
3+
from tensorflow.python.distribute.distribute_lib import Strategy as Strategy
4+
5+
def __getattr__(name: str) -> Incomplete: ...

stubs/tensorflow/tensorflow/keras/__init__.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ from _typeshed import Incomplete
22

33
from tensorflow.keras import (
44
activations as activations,
5+
callbacks as callbacks,
56
constraints as constraints,
67
initializers as initializers,
78
layers as layers,

stubs/tensorflow/tensorflow/keras/callbacks.pyi

+8-7
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@ import tensorflow as tf
77
from requests.api import _HeadersMapping
88
from tensorflow.keras import Model
99
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
10+
from tensorflow.saved_model import SaveOptions
1011
from tensorflow.train import CheckpointOptions
1112

1213
_Logs: TypeAlias = Mapping[str, Any] | None | Any
1314

1415
class Callback:
15-
model: Model # Model[Any, object]
16+
model: Model[Any, Any]
1617
params: dict[str, Any]
17-
def set_model(self, model: Model) -> None: ...
18+
def set_model(self, model: Model[Any, Any]) -> None: ...
1819
def set_params(self, params: dict[str, Any]) -> None: ...
1920
def on_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
2021
def on_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
@@ -35,18 +36,18 @@ class Callback:
3536

3637
# A CallbackList has exact same api as a callback, but does not actually subclass it.
3738
class CallbackList:
38-
model: Model
39+
model: Model[Any, Any]
3940
params: dict[str, Any]
4041
def __init__(
4142
self,
4243
callbacks: Sequence[Callback] | None = None,
4344
add_history: bool = False,
4445
add_progbar: bool = False,
4546
# model: Model[Any, object] | None = None,
46-
model: Model | None = None,
47+
model: Model[Any, Any] | None = None,
4748
**params: Any,
4849
) -> None: ...
49-
def set_model(self, model: Model) -> None: ...
50+
def set_model(self, model: Model[Any, Any]) -> None: ...
5051
def set_params(self, params: dict[str, Any]) -> None: ...
5152
def on_batch_begin(self, batch: int, logs: _Logs | None = None) -> None: ...
5253
def on_batch_end(self, batch: int, logs: _Logs | None = None) -> None: ...
@@ -115,7 +116,7 @@ class LearningRateScheduler(Callback):
115116
class ModelCheckpoint(Callback):
116117
monitor_op: Any
117118
filepath: str
118-
_options: CheckpointOptions | tf.saved_model.SaveOptions | None
119+
_options: CheckpointOptions | SaveOptions | None
119120
def __init__(
120121
self,
121122
filepath: str,
@@ -125,7 +126,7 @@ class ModelCheckpoint(Callback):
125126
save_weights_only: bool = False,
126127
mode: Literal["auto", "min", "max"] = "auto",
127128
save_freq: str | int = "epoch",
128-
options: CheckpointOptions | tf.saved_model.SaveOptions | None = None,
129+
options: CheckpointOptions | SaveOptions | None = None,
129130
initial_value_threshold: float | None = None,
130131
) -> None: ...
131132
def _save_model(self, epoch: int, batch: int | None, logs: _Logs) -> None: ...
+233-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,235 @@
11
from _typeshed import Incomplete
2+
from collections.abc import Callable, Container, Iterator
3+
from pathlib import Path
4+
from typing import Any, Literal
5+
from typing_extensions import Self, TypeAlias
26

3-
Model = Incomplete
7+
import numpy as np
8+
import numpy.typing as npt
9+
import tensorflow
10+
import tensorflow as tf
11+
from tensorflow import Variable
12+
from tensorflow._aliases import ContainerGeneric, ShapeLike, TensorCompatible
13+
from tensorflow.keras.layers import Layer, _InputT, _OutputT
14+
from tensorflow.keras.optimizers import Optimizer
15+
16+
_Loss: TypeAlias = str | tf.keras.losses.Loss | Callable[[TensorCompatible, TensorCompatible], tf.Tensor]
17+
_Metric: TypeAlias = str | tf.keras.metrics.Metric | Callable[[TensorCompatible, TensorCompatible], tf.Tensor] | None
18+
19+
class Model(Layer[_InputT, _OutputT], tf.Module):
20+
_train_counter: tf.Variable
21+
_test_counter: tf.Variable
22+
optimizer: Optimizer | None
23+
loss: tf.keras.losses.Loss | dict[str, tf.keras.losses.Loss]
24+
stop_training: bool
25+
26+
def __new__(cls, *args: Any, **kwargs: Any) -> Model[_InputT, _OutputT]: ...
27+
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
28+
def __setattr__(self, name: str, value: Any) -> None: ...
29+
def __reduce__(self) -> Incomplete: ...
30+
def __deepcopy__(self, memo: Incomplete) -> Incomplete: ...
31+
def build(self, input_shape: ShapeLike) -> None: ...
32+
def __call__(self, inputs: _InputT, *, training: bool = False, mask: TensorCompatible | None = None) -> _OutputT: ...
33+
def call(self, inputs: _InputT, training: bool | None = None, mask: TensorCompatible | None = None) -> _OutputT: ...
34+
# Ideally loss/metrics/output would share the same structure but higher kinded types are not supported.
35+
def compile(
36+
self,
37+
optimizer: Optimizer | str = "rmsprop",
38+
loss: ContainerGeneric[_Loss] | None = None,
39+
metrics: ContainerGeneric[_Metric] | None = None,
40+
loss_weights: ContainerGeneric[float] | None = None,
41+
weighted_metrics: ContainerGeneric[_Metric] | None = None,
42+
run_eagerly: bool | None = None,
43+
steps_per_execution: int | Literal["auto"] | None = None,
44+
jit_compile: bool | None = None,
45+
pss_evaluation_shards: int | Literal["auto"] = 0,
46+
**kwargs: Any,
47+
) -> None: ...
48+
@property
49+
def metrics(self) -> list[Incomplete]: ...
50+
@property
51+
def metrics_names(self) -> list[str]: ...
52+
@property
53+
def distribute_strategy(self) -> tf.distribute.Strategy: ...
54+
@property
55+
def run_eagerly(self) -> bool: ...
56+
@property
57+
def autotune_steps_per_execution(self) -> bool: ...
58+
@property
59+
def steps_per_execution(self) -> int | None: ... # Returns None for a non-compiled model.
60+
@property
61+
def jit_compile(self) -> bool: ...
62+
@property
63+
def distribute_reduction_method(self) -> Incomplete | Literal["auto"]: ...
64+
def train_step(self, data: TensorCompatible) -> Incomplete: ...
65+
def compute_loss(
66+
self,
67+
x: TensorCompatible | None = None,
68+
y: TensorCompatible | None = None,
69+
y_pred: TensorCompatible | None = None,
70+
sample_weight: Incomplete | None = None,
71+
) -> tf.Tensor | None: ...
72+
def compute_metrics(
73+
self, x: TensorCompatible, y: TensorCompatible, y_pred: TensorCompatible, sample_weight: Incomplete
74+
) -> dict[str, float]: ...
75+
def get_metrics_result(self) -> dict[str, float]: ...
76+
def make_train_function(self, force: bool = False) -> Callable[[tf.data.Iterator[Incomplete]], dict[str, float]]: ...
77+
def fit(
78+
self,
79+
x: TensorCompatible | dict[str, TensorCompatible] | tf.data.Dataset[Incomplete] | None = None,
80+
y: TensorCompatible | dict[str, TensorCompatible] | tf.data.Dataset[Incomplete] | None = None,
81+
batch_size: int | None = None,
82+
epochs: int = 1,
83+
verbose: Literal["auto", 0, 1, 2] = "auto",
84+
callbacks: list[tf.keras.callbacks.Callback] | None = None,
85+
validation_split: float = 0.0,
86+
validation_data: TensorCompatible | tf.data.Dataset[Any] | None = None,
87+
shuffle: bool = True,
88+
class_weight: dict[int, float] | None = None,
89+
sample_weight: npt.NDArray[np.float_] | None = None,
90+
initial_epoch: int = 0,
91+
steps_per_epoch: int | None = None,
92+
validation_steps: int | None = None,
93+
validation_batch_size: int | None = None,
94+
validation_freq: int | Container[int] = 1,
95+
max_queue_size: int = 10,
96+
workers: int = 1,
97+
use_multiprocessing: bool = False,
98+
) -> tf.keras.callbacks.History: ...
99+
def test_step(self, data: TensorCompatible) -> dict[str, float]: ...
100+
def make_test_function(self, force: bool = False) -> Callable[[tf.data.Iterator[Incomplete]], dict[str, float]]: ...
101+
def evaluate(
102+
self,
103+
x: TensorCompatible | dict[str, TensorCompatible] | tf.data.Dataset[Incomplete] | None = None,
104+
y: TensorCompatible | dict[str, TensorCompatible] | tf.data.Dataset[Incomplete] | None = None,
105+
batch_size: int | None = None,
106+
verbose: Literal["auto", 0, 1, 2] = "auto",
107+
sample_weight: npt.NDArray[np.float_] | None = None,
108+
steps: int | None = None,
109+
callbacks: list[tf.keras.callbacks.Callback] | None = None,
110+
max_queue_size: int = 10,
111+
workers: int = 1,
112+
use_multiprocessing: bool = False,
113+
return_dict: bool = False,
114+
**kwargs: Any,
115+
) -> float | list[float]: ...
116+
def predict_step(self, data: _InputT) -> _OutputT: ...
117+
def make_predict_function(self, force: bool = False) -> Callable[[tf.data.Iterator[Incomplete]], _OutputT]: ...
118+
def predict(
119+
self,
120+
x: TensorCompatible | tf.data.Dataset[Incomplete],
121+
batch_size: int | None = None,
122+
verbose: Literal["auto", 0, 1, 2] = "auto",
123+
steps: int | None = None,
124+
callbacks: list[tf.keras.callbacks.Callback] | None = None,
125+
max_queue_size: int = 10,
126+
workers: int = 1,
127+
use_multiprocessing: bool = False,
128+
) -> _OutputT: ...
129+
def reset_metrics(self) -> None: ...
130+
def train_on_batch(
131+
self,
132+
x: TensorCompatible | dict[str, TensorCompatible] | tf.data.Dataset[Incomplete],
133+
y: TensorCompatible | dict[str, TensorCompatible] | tf.data.Dataset[Incomplete] | None = None,
134+
sample_weight: npt.NDArray[np.float_] | None = None,
135+
class_weight: dict[int, float] | None = None,
136+
reset_metrics: bool = True,
137+
return_dict: bool = False,
138+
) -> float | list[float]: ...
139+
def test_on_batch(
140+
self,
141+
x: TensorCompatible | dict[str, TensorCompatible] | tf.data.Dataset[Incomplete],
142+
y: TensorCompatible | dict[str, TensorCompatible] | tf.data.Dataset[Incomplete] | None = None,
143+
sample_weight: npt.NDArray[np.float_] | None = None,
144+
reset_metrics: bool = True,
145+
return_dict: bool = False,
146+
) -> float | list[float]: ...
147+
def predict_on_batch(self, x: Iterator[_InputT]) -> npt.NDArray[Incomplete]: ...
148+
def fit_generator(
149+
self,
150+
generator: Iterator[Incomplete],
151+
steps_per_epoch: int | None = None,
152+
epochs: int = 1,
153+
verbose: Literal["auto", 0, 1, 2] = 1,
154+
callbacks: list[tf.keras.callbacks.Callback] | None = None,
155+
validation_data: TensorCompatible | tf.data.Dataset[Any] | None = None,
156+
validation_steps: int | None = None,
157+
validation_freq: int | Container[int] = 1,
158+
class_weight: dict[int, float] | None = None,
159+
max_queue_size: int = 10,
160+
workers: int = 1,
161+
use_multiprocessing: bool = False,
162+
shuffle: bool = True,
163+
initial_epoch: int = 0,
164+
) -> tf.keras.callbacks.History: ...
165+
def evaluate_generator(
166+
self,
167+
generator: Iterator[Incomplete],
168+
steps: int | None = None,
169+
callbacks: list[tf.keras.callbacks.Callback] | None = None,
170+
max_queue_size: int = 10,
171+
workers: int = 1,
172+
use_multiprocessing: bool = False,
173+
verbose: Literal["auto", 0, 1, 2] = 0,
174+
) -> float | list[float]: ...
175+
def predict_generator(
176+
self,
177+
generator: Iterator[Incomplete],
178+
steps: int | None = None,
179+
callbacks: list[tf.keras.callbacks.Callback] | None = None,
180+
max_queue_size: int = 10,
181+
workers: int = 1,
182+
use_multiprocessing: bool = False,
183+
verbose: Literal["auto", 0, 1, 2] = 0,
184+
) -> _OutputT: ...
185+
@property
186+
def trainable_weights(self) -> list[Variable]: ...
187+
@property
188+
def non_trainable_weights(self) -> list[Variable]: ...
189+
def get_weights(self) -> Incomplete: ...
190+
def save(
191+
self, filepath: str | Path, overwrite: bool = True, save_format: Literal["keras", "tf", "h5"] | None = None, **kwargs: Any
192+
) -> None: ...
193+
def save_weights(
194+
self,
195+
filepath: str | Path,
196+
overwrite: bool = True,
197+
save_format: Literal["tf", "h5"] | None = None,
198+
options: tf.train.CheckpointOptions | None = None,
199+
) -> None: ...
200+
def load_weights(
201+
self,
202+
filepath: str | Path,
203+
skip_mismatch: bool = False,
204+
by_name: bool = False,
205+
options: None | tensorflow.train.CheckpointOptions = None,
206+
) -> None: ...
207+
def get_config(self) -> dict[str, Any]: ...
208+
@classmethod
209+
def from_config(cls, config: dict[str, Any], custom_objects: Incomplete | None = None) -> Self: ...
210+
def to_json(self, **kwargs: Any) -> str: ...
211+
def to_yaml(self, **kwargs: Any) -> str: ...
212+
def reset_states(self) -> None: ...
213+
@property
214+
def state_updates(self) -> list[Incomplete]: ...
215+
@property
216+
def weights(self) -> list[Variable]: ...
217+
def summary(
218+
self,
219+
line_length: None | int = None,
220+
positions: None | list[float] = None,
221+
print_fn: None | Callable[[str], None] = None,
222+
expand_nested: bool = False,
223+
show_trainable: bool = False,
224+
layer_range: None | list[str] | tuple[str, str] = None,
225+
) -> None: ...
226+
@property
227+
def layers(self) -> list[Layer[Incomplete, Incomplete]]: ...
228+
def get_layer(self, name: str | None = None, index: int | None = None) -> Layer[Incomplete, Incomplete]: ...
229+
def get_weight_paths(self) -> dict[str, tf.Variable]: ...
230+
def get_compile_config(self) -> dict[str, Any]: ...
231+
def compile_from_config(self, config: dict[str, Any]) -> Self: ...
232+
def export(self, filepath: str | Path) -> None: ...
233+
def save_spec(self, dynamic_batch: bool = True) -> tuple[tuple[tf.TensorSpec, ...], dict[str, tf.TensorSpec]] | None: ...
234+
235+
def __getattr__(name: str) -> Incomplete: ...

stubs/tensorflow/tensorflow/keras/optimizers/__init__.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,6 @@ from _typeshed import Incomplete
22

33
from tensorflow.keras.optimizers import legacy as legacy, schedules as schedules
44

5+
Optimizer = Incomplete
6+
57
def __getattr__(name: str) -> Incomplete: ...
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from _typeshed import Incomplete
2+
3+
Strategy = Incomplete
4+
5+
def __getattr__(name: str) -> Incomplete: ...
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from tensorflow.python.trackable.base import Trackable
2+
3+
class _ResourceMetaclass(type): ...
4+
5+
# Internal type that is commonly used as a base class
6+
# and some public apis the signature needs it.
7+
class CapturableResource(Trackable, metaclass=_ResourceMetaclass): ...

stubs/tensorflow/tensorflow/train/__init__.pyi

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ from tensorflow.core.example.feature_pb2 import (
1616
from tensorflow.core.protobuf.cluster_pb2 import ClusterDef as ClusterDef
1717
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef as ServerDef
1818
from tensorflow.python.trackable.base import Trackable
19+
from tensorflow.python.training.tracking.autotrackable import AutoTrackable
1920

2021
class CheckpointOptions:
2122
experimental_io_device: None | str
@@ -44,7 +45,7 @@ class _CheckpointLoadStatus:
4445
def assert_nontrivial_match(self) -> Self: ...
4546
def expect_partial(self) -> Self: ...
4647

47-
class Checkpoint:
48+
class Checkpoint(AutoTrackable):
4849
def __init__(self, root: Trackable | None = None, **kwargs: Trackable) -> None: ...
4950
def read(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ...
5051
def restore(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ...

0 commit comments

Comments
 (0)