|
7 | 7 | from hypothesis import given
|
8 | 8 |
|
9 | 9 |
|
10 |
| -def compare_rowwise(emb_orig, emb_reconstructed): |
| 10 | +def compare_rowwise(emb_orig, emb_reconstructed, fp16): |
| 11 | + # there is an absolute error introduced per row through int8 quantization |
| 12 | + # and a relative error introduced when quantizing back from fp32 to fp16 |
11 | 13 | assert(emb_orig.shape == emb_reconstructed.shape)
|
12 |
| - range = np.amax(emb_orig, axis=1) - np.amin(emb_orig, axis=1) |
13 |
| - # TODO: figure out the right threshold, this has to do with the |
14 |
| - # fact that the data types are float16, in float32, it should be /1.9 |
15 |
| - threshold = range / 255.0 / 1.5 |
16 |
| - diff = np.amax(np.abs(emb_orig - emb_reconstructed), axis=1) |
17 |
| - n_violated = ((threshold - diff) < 0).sum() |
18 |
| - if n_violated > 0: |
19 |
| - print(n_violated, threshold, diff, threshold < diff, emb_orig, |
20 |
| - emb_reconstructed, emb_orig - emb_reconstructed) |
21 |
| - assert(n_violated == 0) |
| 14 | + rtol = 1e-8 |
| 15 | + if fp16: |
| 16 | + rtol = 1e-3 |
| 17 | + erange = np.amax(emb_orig, axis=1) - np.amin(emb_orig, axis=1) |
| 18 | + |
| 19 | + threshold = erange / 255.0 / 1.9 |
| 20 | + |
| 21 | + for i in range(emb_orig.shape[0]): |
| 22 | + r_orig = emb_orig[i, :] |
| 23 | + r_reconstructed = emb_reconstructed[i, :] |
| 24 | + |
| 25 | + isclose = np.isclose(r_orig, r_reconstructed, atol=threshold[i], rtol=rtol) |
| 26 | + n_violated = isclose.size - isclose.sum() |
| 27 | + |
| 28 | + if n_violated > 0: |
| 29 | + print(isclose, threshold[i]) |
| 30 | + print(i, r_orig, r_reconstructed, threshold[i], r_orig - r_reconstructed) |
| 31 | + assert(n_violated == 0) |
22 | 32 |
|
23 | 33 |
|
24 | 34 | class TestLengthsReducerOpsFused8BitRowwise(hu.HypothesisTestCase):
|
@@ -102,7 +112,7 @@ def test_sparse_lengths_sum(
|
102 | 112 |
|
103 | 113 | dequantized_data = workspace.FetchBlob("dequantized_data")
|
104 | 114 | np.testing.assert_array_almost_equal(input_data, workspace.FetchBlob("input_data"))
|
105 |
| - compare_rowwise(input_data, dequantized_data) |
| 115 | + compare_rowwise(input_data, dequantized_data, fp16) |
106 | 116 |
|
107 | 117 | sum_reference = workspace.FetchBlob("sum_reference")
|
108 | 118 | sum_quantized = workspace.FetchBlob("sum_quantized")
|
@@ -179,7 +189,7 @@ def test_sparse_lengths_mean(self, batchsize, blocksize, seed, empty_indices, fp
|
179 | 189 |
|
180 | 190 | dequantized_data = workspace.FetchBlob("dequantized_data")
|
181 | 191 | np.testing.assert_array_almost_equal(input_data, workspace.FetchBlob("input_data"))
|
182 |
| - compare_rowwise(input_data, dequantized_data) |
| 192 | + compare_rowwise(input_data, dequantized_data, fp16) |
183 | 193 |
|
184 | 194 | mean_reference = workspace.FetchBlob("mean_reference")
|
185 | 195 | mean_quantized = workspace.FetchBlob("mean_quantized")
|
|
0 commit comments