Skip to content

Commit 075c7b1

Browse files
Hector Yuenfacebook-github-bot
Hector Yuen
authored andcommitted
make the threshold for acurracy more precise (pytorch#17194)
Summary: Pull Request resolved: pytorch#17194 we found that there is a per row absolute error due to int8 quant and a relative error table-wide in case fp16 is used Reviewed By: csummersea Differential Revision: D14113353 fbshipit-source-id: c7065aa9d15c453c2e5609f421ad0155145af889
1 parent db1d61a commit 075c7b1

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,28 @@
77
from hypothesis import given
88

99

10-
def compare_rowwise(emb_orig, emb_reconstructed):
10+
def compare_rowwise(emb_orig, emb_reconstructed, fp16):
11+
# there is an absolute error introduced per row through int8 quantization
12+
# and a relative error introduced when quantizing back from fp32 to fp16
1113
assert(emb_orig.shape == emb_reconstructed.shape)
12-
range = np.amax(emb_orig, axis=1) - np.amin(emb_orig, axis=1)
13-
# TOOO: figure out the right threshold, this has to do with the
14-
# fact that the data types are float16, in float32, it should be /1.9
15-
threshold = range / 255.0 / 1.5
16-
diff = np.amax(np.abs(emb_orig - emb_reconstructed), axis=1)
17-
n_violated = ((threshold - diff) < 0).sum()
18-
if n_violated > 0:
19-
print(n_violated, threshold, diff, threshold < diff, emb_orig,
20-
emb_reconstructed, emb_orig - emb_reconstructed)
21-
assert(n_violated == 0)
14+
rtol = 1e-8
15+
if fp16:
16+
rtol = 1e-3
17+
erange = np.amax(emb_orig, axis=1) - np.amin(emb_orig, axis=1)
18+
19+
threshold = erange / 255.0 / 1.9
20+
21+
for i in range(emb_orig.shape[0]):
22+
r_orig = emb_orig[i, :]
23+
r_reconstructed = emb_reconstructed[i, :]
24+
25+
isclose = np.isclose(r_orig, r_reconstructed, atol=threshold[i], rtol=rtol)
26+
n_violated = isclose.size - isclose.sum()
27+
28+
if n_violated > 0:
29+
print(isclose, threshold[i])
30+
print(i, r_orig, r_reconstructed, threshold[i], r_orig - r_reconstructed)
31+
assert(n_violated == 0)
2232

2333

2434
class TestLengthsReducerOpsFused8BitRowwise(hu.HypothesisTestCase):
@@ -102,7 +112,7 @@ def test_sparse_lengths_sum(
102112

103113
dequantized_data = workspace.FetchBlob("dequantized_data")
104114
np.testing.assert_array_almost_equal(input_data, workspace.FetchBlob("input_data"))
105-
compare_rowwise(input_data, dequantized_data)
115+
compare_rowwise(input_data, dequantized_data, fp16)
106116

107117
sum_reference = workspace.FetchBlob("sum_reference")
108118
sum_quantized = workspace.FetchBlob("sum_quantized")
@@ -179,7 +189,7 @@ def test_sparse_lengths_mean(self, batchsize, blocksize, seed, empty_indices, fp
179189

180190
dequantized_data = workspace.FetchBlob("dequantized_data")
181191
np.testing.assert_array_almost_equal(input_data, workspace.FetchBlob("input_data"))
182-
compare_rowwise(input_data, dequantized_data)
192+
compare_rowwise(input_data, dequantized_data, fp16)
183193

184194
mean_reference = workspace.FetchBlob("mean_reference")
185195
mean_quantized = workspace.FetchBlob("mean_quantized")

0 commit comments

Comments
 (0)