Skip to content

Commit b162979

Browse files
lingvo-botcopybara-github
authored andcommitted
Add tensors that test QAT.
PiperOrigin-RevId: 648803095
1 parent bc1cfd8 commit b162979

11 files changed

+72
-0
lines changed

lingvo/core/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,6 +1475,7 @@ pytype_strict_test(
14751475
size = "medium",
14761476
srcs = ["py_utils_test.py"],
14771477
args = ["--noenable_eager_execution"],
1478+
data = ["//lingvo/core/testdata:quantization_test_data"],
14781479
deps = [
14791480
":py_utils_test_lib",
14801481
# Implicit freezegun dependency.
@@ -1487,6 +1488,7 @@ pytype_strict_test(
14871488
name = "py_utils_eager_test",
14881489
srcs = ["py_utils_test.py"],
14891490
args = ["--enable_eager_execution"],
1491+
data = ["//lingvo/core/testdata:quantization_test_data"],
14901492
main = "py_utils_test.py",
14911493
deps = [
14921494
":py_utils_test_lib",

lingvo/core/py_utils_test.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,6 +1352,69 @@ def testQAT(self, qat_output, expected):
13521352
)
13531353
self.assertAllClose(self.evaluate(x), expected)
13541354

1355+
@parameterized.named_parameters(
1356+
(
1357+
'4bit_weight_qat_output_false',
1358+
False,
1359+
'core/testdata/qat_test_4bit_weights.npy',
1360+
'core/testdata/qat_test_output_4bit_weight_qat_false.npy',
1361+
),
1362+
(
1363+
'4bit_weight_qat_output_true',
1364+
True,
1365+
'core/testdata/qat_test_4bit_weights.npy',
1366+
'core/testdata/qat_test_output_4bit_weight_qat_true.npy',
1367+
),
1368+
(
1369+
'8bit_weight_qat_output_false',
1370+
False,
1371+
'core/testdata/qat_test_8bit_weights.npy',
1372+
'core/testdata/qat_test_output_8bit_weight_qat_false.npy',
1373+
),
1374+
(
1375+
'8bit_weight_qat_output_true',
1376+
True,
1377+
'core/testdata/qat_test_8bit_weights.npy',
1378+
'core/testdata/qat_test_output_8bit_weight_qat_true.npy',
1379+
),
1380+
)
1381+
def testEinsumQuantization(self, qat_output, weights_path, expected):
1382+
# num_tasks=1, input_dim=2, output_dim=3
1383+
weights_path = test_helper.test_src_dir_path(weights_path)
1384+
weights = tf.convert_to_tensor(np.load(weights_path), tf.float32)
1385+
bias_path = test_helper.test_src_dir_path('core/testdata/qat_test_bias.npy')
1386+
bias = tf.convert_to_tensor(np.load(bias_path), tf.float32)
1387+
inputs_path = test_helper.test_src_dir_path(
1388+
'core/testdata/qat_test_inputs.npy'
1389+
)
1390+
inputs = tf.convert_to_tensor(np.load(inputs_path), tf.float32)
1391+
output_path = test_helper.test_src_dir_path(expected)
1392+
output = tf.convert_to_tensor(np.load(output_path), tf.float32)
1393+
1394+
quant_layer_p = layers.MultitaskProjectionEinsumLayer.Params()
1395+
quant_layer_p.name = 'testQAT'
1396+
quant_layer_p.input_dim = 256
1397+
quant_layer_p.output_dim = 126
1398+
quant_layer_p.num_tasks = 8
1399+
1400+
with self.session(use_gpu=False):
1401+
x = self.evaluate(
1402+
py_utils.MultiTaskProjection(
1403+
weights=weights,
1404+
biases=bias,
1405+
inputs=inputs,
1406+
tasks=1,
1407+
einsum_order='select_and_multiply',
1408+
quant_layer=layers.MultitaskProjectionEinsumLayer(quant_layer_p),
1409+
w_q_name='w',
1410+
w_q_domain='default',
1411+
qat_output=qat_output,
1412+
)
1413+
)
1414+
# different server CPUs produce slightly different results, e-3 is a safe
1415+
# margin since outputs are in the order of e+4
1416+
self.assertAllClose(x, output, atol=2.5e-3)
1417+
13551418
def testShardedFilePatternToGlob(self):
13561419
file_pattern = '/some/path/to/file@8'
13571420
self.assertEqual('/some/path/to/file-?????-of-00008',

lingvo/core/testdata/BUILD

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,10 @@ filegroup(
1515
"en-1k.spm.*",
1616
]),
1717
)
18+
19+
filegroup(
20+
name = "quantization_test_data",
21+
data = glob([
22+
"qat_test_*",
23+
]),
24+
)
252 KB
Binary file not shown.
252 KB
Binary file not shown.
8 KB
Binary file not shown.
16.1 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)