Skip to content

Commit dd2090c

Browse files
authored
- Add adjr2 (#2364)
1 parent ffeaede commit dd2090c

File tree

2 files changed

+104
-7
lines changed

2 files changed

+104
-7
lines changed

Diff for: tensorflow_addons/metrics/r_square.py

+44-5
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,15 @@ class RSquare(Metric):
5959
](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html)
6060
of the same metric.
6161
62+
Can also calculate the Adjusted R2 Score.
63+
6264
Args:
6365
multioutput: `string`, the reduce method for scores.
6466
Should be one of `["raw_values", "uniform_average", "variance_weighted"]`.
6567
name: (Optional) string name of the metric instance.
6668
dtype: (Optional) data type of the metric result.
69+
num_regressors: (Optional) Number of indepedent regressors used (Adjusted R2).
70+
Defaults to zero(standard R2 score).
6771
6872
Usage:
6973
@@ -83,6 +87,7 @@ def __init__(
8387
dtype: AcceptableDTypes = None,
8488
y_shape: Tuple[int, ...] = (),
8589
multioutput: str = "uniform_average",
90+
num_regressors: tf.int32 = 0,
8691
**kwargs,
8792
):
8893
super().__init__(name=name, dtype=dtype, **kwargs)
@@ -95,6 +100,7 @@ def __init__(
95100
)
96101
)
97102
self.multioutput = multioutput
103+
self.num_regressors = num_regressors
98104
self.squared_sum = self.add_weight(
99105
name="squared_sum", shape=y_shape, initializer="zeros", dtype=dtype
100106
)
@@ -107,6 +113,7 @@ def __init__(
107113
self.count = self.add_weight(
108114
name="count", shape=y_shape, initializer="zeros", dtype=dtype
109115
)
116+
self.num_samples = self.add_weight(name="num_samples", dtype=tf.int32)
110117

111118
def update_state(self, y_true, y_pred, sample_weight=None) -> None:
112119
y_true = tf.cast(y_true, dtype=self._dtype)
@@ -125,6 +132,7 @@ def update_state(self, y_true, y_pred, sample_weight=None) -> None:
125132
tf.reduce_sum((y_true - y_pred) ** 2 * sample_weight, axis=0)
126133
)
127134
self.count.assign_add(tf.reduce_sum(sample_weight, axis=0))
135+
self.num_samples.assign_add(tf.size(y_true))
128136

129137
def result(self) -> tf.Tensor:
130138
mean = self.sum / self.count
@@ -133,11 +141,42 @@ def result(self) -> tf.Tensor:
133141
raw_scores = tf.where(tf.math.is_inf(raw_scores), 0.0, raw_scores)
134142

135143
if self.multioutput == "raw_values":
136-
return raw_scores
137-
if self.multioutput == "uniform_average":
138-
return tf.reduce_mean(raw_scores)
139-
if self.multioutput == "variance_weighted":
140-
return _reduce_average(raw_scores, weights=total)
144+
r2_score = raw_scores
145+
elif self.multioutput == "uniform_average":
146+
r2_score = tf.reduce_mean(raw_scores)
147+
elif self.multioutput == "variance_weighted":
148+
r2_score = _reduce_average(raw_scores, weights=total)
149+
else:
150+
raise RuntimeError(
151+
"The multioutput attribute must be one of {}, but was: {}".format(
152+
_VALID_MULTIOUTPUT, self.multioutput
153+
)
154+
)
155+
156+
if self.num_regressors < 0:
157+
raise ValueError(
158+
"num_regressors parameter should be greater than or equal to zero"
159+
)
160+
161+
if self.num_regressors != 0:
162+
if self.num_regressors > self.num_samples - 1:
163+
UserWarning(
164+
"More independent predictors than datapoints in adjusted r2 score. Falls back to standard r2 "
165+
"score."
166+
)
167+
elif self.num_regressors == self.num_samples - 1:
168+
UserWarning(
169+
"Division by zero in adjusted r2 score. Falls back to standard r2 score."
170+
)
171+
else:
172+
n = tf.cast(self.num_samples, dtype=tf.float32)
173+
p = tf.cast(self.num_regressors, dtype=tf.float32)
174+
175+
num = tf.multiply(tf.subtract(1.0, r2_score), tf.subtract(n, 1.0))
176+
den = tf.subtract(tf.subtract(n, p), 1.0)
177+
r2_score = tf.subtract(1.0, tf.divide(num, den))
178+
179+
return r2_score
141180

142181
def reset_states(self) -> None:
143182
# The state of the metric will be reset at the start of each epoch.

Diff for: tensorflow_addons/metrics/tests/r_square_test.py

+60-2
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,14 @@ def test_config(multioutput, y_shape):
3939
assert r2_obj2.y_shape == y_shape
4040

4141

42-
def initialize_vars(y_shape=(), multioutput: str = "uniform_average"):
43-
return RSquare(y_shape=y_shape, multioutput=multioutput)
42+
def initialize_vars(
43+
y_shape=(),
44+
multioutput: str = "uniform_average",
45+
num_regressors: tf.int32 = 0,
46+
):
47+
return RSquare(
48+
y_shape=y_shape, multioutput=multioutput, num_regressors=num_regressors
49+
)
4450

4551

4652
def update_obj_states(obj, actuals, preds, sample_weight=None):
@@ -145,3 +151,55 @@ def test_keras_fit():
145151
)
146152
data = data.batch(10)
147153
model.fit(x=data, validation_data=data)
154+
155+
156+
def test_adjr2():
157+
actuals = tf.constant([10, 600, 3, 9.77], dtype=tf.float32)
158+
preds = tf.constant([1, 340, 40, 5.7], dtype=tf.float32)
159+
actuals = tf.cast(actuals, dtype=tf.float32)
160+
preds = tf.cast(preds, dtype=tf.float32)
161+
# Initialize
162+
adjr2_obj = initialize_vars(num_regressors=2)
163+
update_obj_states(adjr2_obj, actuals, preds)
164+
# Check result
165+
check_results(adjr2_obj, 0.2128982)
166+
167+
168+
def test_adjr2_negative_num_preds():
169+
actuals = tf.constant([10, 600, 3, 9.77], dtype=tf.float32)
170+
preds = tf.constant([1, 340, 40, 5.7], dtype=tf.float32)
171+
actuals = tf.cast(actuals, dtype=tf.float32)
172+
preds = tf.cast(preds, dtype=tf.float32)
173+
# Initialize
174+
adjr2_obj = initialize_vars(num_regressors=-3)
175+
update_obj_states(adjr2_obj, actuals, preds)
176+
# Expect runtime error
177+
pytest.raises(ValueError)
178+
179+
180+
def test_adjr2_zero_division():
181+
actuals = tf.constant([10, 600, 3, 9.77], dtype=tf.float32)
182+
preds = tf.constant([1, 340, 40, 5.7], dtype=tf.float32)
183+
actuals = tf.cast(actuals, dtype=tf.float32)
184+
preds = tf.cast(preds, dtype=tf.float32)
185+
# Initialize
186+
adjr2_obj = initialize_vars(num_regressors=3)
187+
update_obj_states(adjr2_obj, actuals, preds)
188+
# Expect warning
189+
pytest.raises(UserWarning)
190+
# Fallback to standard
191+
check_results(adjr2_obj, 0.7376327)
192+
193+
194+
def test_adjr2_excess_num_preds():
195+
actuals = tf.constant([10, 600, 3, 9.77], dtype=tf.float32)
196+
preds = tf.constant([1, 340, 40, 5.7], dtype=tf.float32)
197+
actuals = tf.cast(actuals, dtype=tf.float32)
198+
preds = tf.cast(preds, dtype=tf.float32)
199+
# Initialize
200+
adjr2_obj = initialize_vars(num_regressors=5)
201+
update_obj_states(adjr2_obj, actuals, preds)
202+
# Expect warning
203+
pytest.raises(UserWarning)
204+
# Fallback to standard
205+
check_results(adjr2_obj, 0.7376327)

0 commit comments

Comments
 (0)