Skip to content

Commit fe9c9dd

Browse files
authored
AUTOTUNE 'y_shape' (#2702)
* AUTOTUNE 'y_shape' * black + flake8 * + warning * + version
1 parent 863625b commit fe9c9dd

File tree

2 files changed

+42
-27
lines changed

2 files changed

+42
-27
lines changed

tensorflow_addons/metrics/r_square.py

+37-16
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
"""Implements R^2 scores."""
16-
from typing import Tuple
16+
import warnings
1717

1818
import numpy as np
1919
import tensorflow as tf
@@ -86,13 +86,18 @@ def __init__(
8686
self,
8787
name: str = "r_square",
8888
dtype: AcceptableDTypes = None,
89-
y_shape: Tuple[int, ...] = (),
9089
multioutput: str = "uniform_average",
9190
num_regressors: tf.int32 = 0,
9291
**kwargs,
9392
):
9493
super().__init__(name=name, dtype=dtype, **kwargs)
95-
self.y_shape = y_shape
94+
95+
if "y_shape" in kwargs:
96+
warnings.warn(
97+
"y_shape has been removed, because it's automatically derived,"
98+
"and will be deprecated in Addons 0.18.",
99+
DeprecationWarning,
100+
)
96101

97102
if multioutput not in _VALID_MULTIOUTPUT:
98103
raise ValueError(
@@ -102,21 +107,38 @@ def __init__(
102107
)
103108
self.multioutput = multioutput
104109
self.num_regressors = num_regressors
105-
self.squared_sum = self.add_weight(
106-
name="squared_sum", shape=y_shape, initializer="zeros", dtype=dtype
107-
)
108-
self.sum = self.add_weight(
109-
name="sum", shape=y_shape, initializer="zeros", dtype=dtype
110-
)
111-
self.res = self.add_weight(
112-
name="residual", shape=y_shape, initializer="zeros", dtype=dtype
113-
)
114-
self.count = self.add_weight(
115-
name="count", shape=y_shape, initializer="zeros", dtype=dtype
116-
)
117110
self.num_samples = self.add_weight(name="num_samples", dtype=tf.int32)
118111

119112
def update_state(self, y_true, y_pred, sample_weight=None) -> None:
113+
if not hasattr(self, "squared_sum"):
114+
self.squared_sum = self.add_weight(
115+
name="squared_sum",
116+
shape=y_true.shape[1:],
117+
initializer="zeros",
118+
dtype=self._dtype,
119+
)
120+
if not hasattr(self, "sum"):
121+
self.sum = self.add_weight(
122+
name="sum",
123+
shape=y_true.shape[1:],
124+
initializer="zeros",
125+
dtype=self._dtype,
126+
)
127+
if not hasattr(self, "res"):
128+
self.res = self.add_weight(
129+
name="residual",
130+
shape=y_true.shape[1:],
131+
initializer="zeros",
132+
dtype=self._dtype,
133+
)
134+
if not hasattr(self, "count"):
135+
self.count = self.add_weight(
136+
name="count",
137+
shape=y_true.shape[1:],
138+
initializer="zeros",
139+
dtype=self._dtype,
140+
)
141+
120142
y_true = tf.cast(y_true, dtype=self._dtype)
121143
y_pred = tf.cast(y_pred, dtype=self._dtype)
122144
if sample_weight is None:
@@ -191,7 +213,6 @@ def reset_states(self):
191213

192214
def get_config(self):
193215
config = {
194-
"y_shape": self.y_shape,
195216
"multioutput": self.multioutput,
196217
}
197218
base_config = super().get_config()

tensorflow_addons/metrics/tests/r_square_test.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -24,29 +24,23 @@
2424

2525

2626
@pytest.mark.parametrize("multioutput", sorted(_VALID_MULTIOUTPUT))
27-
@pytest.mark.parametrize("y_shape", [(), (1,)])
28-
def test_config(multioutput, y_shape):
29-
r2_obj = RSquare(multioutput=multioutput, y_shape=y_shape, name="r_square")
27+
def test_config(multioutput):
28+
r2_obj = RSquare(multioutput=multioutput, name="r_square")
3029
assert r2_obj.name == "r_square"
3130
assert r2_obj.dtype == tf.float32
3231
assert r2_obj.multioutput == multioutput
33-
assert r2_obj.y_shape == y_shape
3432
# Check save and restore config
3533
r2_obj2 = RSquare.from_config(r2_obj.get_config())
3634
assert r2_obj2.name == "r_square"
3735
assert r2_obj2.dtype == tf.float32
3836
assert r2_obj2.multioutput == multioutput
39-
assert r2_obj2.y_shape == y_shape
4037

4138

4239
def initialize_vars(
43-
y_shape=(),
4440
multioutput: str = "uniform_average",
4541
num_regressors: tf.int32 = 0,
4642
):
47-
return RSquare(
48-
y_shape=y_shape, multioutput=multioutput, num_regressors=num_regressors
49-
)
43+
return RSquare(multioutput=multioutput, num_regressors=num_regressors)
5044

5145

5246
def update_obj_states(obj, actuals, preds, sample_weight=None):
@@ -149,7 +143,7 @@ def test_r2_sklearn_comparison(multioutput):
149143
tensor_preds = tf.cast(tensor_preds, dtype=tf.float32)
150144
tensor_sample_weight = tf.cast(tensor_sample_weight, dtype=tf.float32)
151145
# Initialize
152-
r2_obj = initialize_vars(y_shape=(3,), multioutput=multioutput)
146+
r2_obj = initialize_vars(multioutput=multioutput)
153147
# Update
154148
update_obj_states(
155149
r2_obj,
@@ -171,7 +165,7 @@ def test_unrecognized_multioutput():
171165

172166
def test_keras_fit():
173167
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
174-
model.compile(loss="mse", metrics=[RSquare(y_shape=(1,))])
168+
model.compile(loss="mse", metrics=[RSquare()])
175169
data = tf.data.Dataset.from_tensor_slices(
176170
(tf.random.normal(shape=(100, 1)), tf.random.normal(shape=(100, 1)))
177171
)

0 commit comments

Comments
 (0)