13
13
# limitations under the License.
14
14
# ==============================================================================
15
15
"""Implements R^2 scores."""
16
- from typing import Tuple
16
+ import warnings
17
17
18
18
import numpy as np
19
19
import tensorflow as tf
@@ -86,13 +86,18 @@ def __init__(
86
86
self ,
87
87
name : str = "r_square" ,
88
88
dtype : AcceptableDTypes = None ,
89
- y_shape : Tuple [int , ...] = (),
90
89
multioutput : str = "uniform_average" ,
91
90
num_regressors : tf .int32 = 0 ,
92
91
** kwargs ,
93
92
):
94
93
super ().__init__ (name = name , dtype = dtype , ** kwargs )
95
- self .y_shape = y_shape
94
+
95
+ if "y_shape" in kwargs :
96
+ warnings .warn (
97
+ "y_shape has been removed, because it's automatically derived,"
98
+ "and will be deprecated in Addons 0.18." ,
99
+ DeprecationWarning ,
100
+ )
96
101
97
102
if multioutput not in _VALID_MULTIOUTPUT :
98
103
raise ValueError (
@@ -102,21 +107,38 @@ def __init__(
102
107
)
103
108
self .multioutput = multioutput
104
109
self .num_regressors = num_regressors
105
- self .squared_sum = self .add_weight (
106
- name = "squared_sum" , shape = y_shape , initializer = "zeros" , dtype = dtype
107
- )
108
- self .sum = self .add_weight (
109
- name = "sum" , shape = y_shape , initializer = "zeros" , dtype = dtype
110
- )
111
- self .res = self .add_weight (
112
- name = "residual" , shape = y_shape , initializer = "zeros" , dtype = dtype
113
- )
114
- self .count = self .add_weight (
115
- name = "count" , shape = y_shape , initializer = "zeros" , dtype = dtype
116
- )
117
110
self .num_samples = self .add_weight (name = "num_samples" , dtype = tf .int32 )
118
111
119
112
def update_state (self , y_true , y_pred , sample_weight = None ) -> None :
113
+ if not hasattr (self , "squared_sum" ):
114
+ self .squared_sum = self .add_weight (
115
+ name = "squared_sum" ,
116
+ shape = y_true .shape [1 :],
117
+ initializer = "zeros" ,
118
+ dtype = self ._dtype ,
119
+ )
120
+ if not hasattr (self , "sum" ):
121
+ self .sum = self .add_weight (
122
+ name = "sum" ,
123
+ shape = y_true .shape [1 :],
124
+ initializer = "zeros" ,
125
+ dtype = self ._dtype ,
126
+ )
127
+ if not hasattr (self , "res" ):
128
+ self .res = self .add_weight (
129
+ name = "residual" ,
130
+ shape = y_true .shape [1 :],
131
+ initializer = "zeros" ,
132
+ dtype = self ._dtype ,
133
+ )
134
+ if not hasattr (self , "count" ):
135
+ self .count = self .add_weight (
136
+ name = "count" ,
137
+ shape = y_true .shape [1 :],
138
+ initializer = "zeros" ,
139
+ dtype = self ._dtype ,
140
+ )
141
+
120
142
y_true = tf .cast (y_true , dtype = self ._dtype )
121
143
y_pred = tf .cast (y_pred , dtype = self ._dtype )
122
144
if sample_weight is None :
@@ -191,7 +213,6 @@ def reset_states(self):
191
213
192
214
def get_config (self ):
193
215
config = {
194
- "y_shape" : self .y_shape ,
195
216
"multioutput" : self .multioutput ,
196
217
}
197
218
base_config = super ().get_config ()
0 commit comments