
Commit 0e561d4

helloguo authored and facebook-github-bot committed

Add benchmark EmbeddingSpMDMNBitBenchmarkOutTypeFloat16 (#2901)

Summary: Pull Request resolved: #2901

This diff adds the benchmark EmbeddingSpMDMNBitBenchmarkOutTypeFloat16 to test TBE with output type float16. It doesn't change the existing EmbeddingSpMDMNBitBenchmark.

Reviewed By: excelle08

Differential Revision: D60254038

fbshipit-source-id: f4b839d53450cf364e38aec4dab2e8f8c276bbab

1 parent 2808977 commit 0e561d4
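The pattern at the heart of this change: run_benchmark becomes a template over the output type, and the OUT_TYPE_FLOAT16 preprocessor macro selects which instantiations main() invokes; fp16 and bf16 share 16-bit storage and are distinguished by the new is_bf16_out flag. A minimal sketch of that dispatch, assuming a plain uint16_t as a stand-in for FBGEMM's float16 storage type (run() below is a hypothetical stand-in for run_benchmark(), not the actual benchmark):

#include <cstdint>
#include <iostream>

template <typename OutType>
void run(bool is_bf16_out) {
  // fp16 and bf16 share 16-bit storage; is_bf16_out picks the interpretation.
  std::cout << (sizeof(OutType) == 2 ? (is_bf16_out ? "bf16" : "fp16") : "fp32")
            << " output\n";
}

int main() {
#ifndef OUT_TYPE_FLOAT16
  run<float>(false); // default build: fp32 outputs
#else
  run<std::uint16_t>(false); // -DOUT_TYPE_FLOAT16 build: fp16 outputs
  run<std::uint16_t>(true); // same storage, bf16 interpretation
#endif
  return 0;
}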

File tree

1 file changed: 209 additions, 23 deletions

bench/EmbeddingSpMDMNBitBenchmark.cc

Lines changed: 209 additions & 23 deletions
@@ -61,6 +61,7 @@ static vector<vector<int>> GetInputs_() {
   return input_dims;
 }
 
+template <typename OutType>
 int run_benchmark(
     int bit_rate,
     int batch_size,
@@ -69,7 +70,8 @@ int run_benchmark(
     int average_len,
     bool normalize_by_lengths,
     bool use_32_bit_indices = false,
-    bool prefetch = false) {
+    bool prefetch = false,
+    bool is_bf16_out = false) {
   // Create embedding table
   int num_elem_per_byte = 8 / bit_rate;
   int fused_embedding_dim =
@@ -133,8 +135,8 @@ int run_benchmark(
     weights[i] = embedding_distribution(generator);
   }
 
-  vector<float> output_sls_ref(batch_size * embedding_dim);
-  vector<float> output_slws_ref(output_sls_ref.size()),
+  vector<OutType> output_sls_ref(batch_size * embedding_dim);
+  vector<OutType> output_slws_ref(output_sls_ref.size()),
       output_sls(output_sls_ref.size()), output_slws(output_sls_ref.size());
 
   constexpr int NUM_WARMUP = 10;
@@ -148,11 +150,12 @@ int run_benchmark(
       CACHE_LINE_LEN);
 
   for (bool has_weight : {false, true}) {
-    vector<float>& output_ref = has_weight ? output_slws_ref : output_sls_ref;
-    vector<float> output_autovec(output_sls_ref.size());
+    vector<OutType>& output_ref = has_weight ? output_slws_ref : output_sls_ref;
+    vector<OutType> output_autovec(output_sls_ref.size());
 
     bool success = false, success_ref = false, success_autovec = false;
 
+#ifndef OUT_TYPE_FLOAT16
     auto kernel_32 = GenerateEmbeddingSpMDMNBit<int32_t>(
         bit_rate,
         embedding_dim,
@@ -165,8 +168,9 @@ int run_benchmark(
         has_weight,
         normalize_by_lengths,
         prefetch ? 16 : 0);
+#endif // OUT_TYPE_FLOAT16
 
-    vector<float>& output = has_weight ? output_slws : output_sls;
+    vector<OutType>& output = has_weight ? output_slws : output_sls;
     for (bool flush_cache : {false, true}) {
       // Reference implementation
       double t_ref = measureWithWarmup(
@@ -183,7 +187,13 @@ int run_benchmark(
               offsets.data(),
               has_weight ? weights.data() : nullptr,
               normalize_by_lengths,
-              output_ref.data());
+              output_ref.data(),
+              false, // is_weight_positional
+              true, // use_offsets
+              -1, // output_stride
+              -1, // input_stride
+              true, // scale_bias_last
+              is_bf16_out);
           } else {
             success_ref = EmbeddingSpMDMNBit_ref(
                 bit_rate,
@@ -196,7 +206,13 @@ int run_benchmark(
                 offsets.data(),
                 has_weight ? weights.data() : nullptr,
                 normalize_by_lengths,
-                output_ref.data());
+                output_ref.data(),
+                false, // is_weight_positional
+                true, // use_offsets
+                -1, // output_stride
+                -1, // input_stride
+                true, // scale_bias_last
+                is_bf16_out);
           }
         },
         NUM_WARMUP,
@@ -227,7 +243,13 @@ int run_benchmark(
               offsets.data(),
               has_weight ? weights.data() : nullptr,
               normalize_by_lengths,
-              output_autovec.data());
+              output_autovec.data(),
+              false, // is_weight_positional
+              true, // use_offsets
+              -1, // output_stride
+              -1, // input_stride
+              true, // scale_bias_last
+              is_bf16_out);
           } else {
             success_autovec = EmbeddingSpMDMNBit_autovec(
                 bit_rate,
@@ -240,7 +262,13 @@ int run_benchmark(
                 offsets.data(),
                 has_weight ? weights.data() : nullptr,
                 normalize_by_lengths,
-                output_autovec.data());
+                output_autovec.data(),
+                false, // is_weight_positional
+                true, // use_offsets
+                -1, // output_stride
+                -1, // input_stride
+                true, // scale_bias_last
+                is_bf16_out);
           }
         },
         NUM_WARMUP,
@@ -256,6 +284,7 @@ int run_benchmark(
         }
       });
 
+#ifndef OUT_TYPE_FLOAT16
       // Hand-written AVX2/AVX512 implementation
       double t = measureWithWarmup(
           [&]() {
@@ -293,6 +322,7 @@ int run_benchmark(
              cache_evict(output);
            }
          });
+#endif // OUT_TYPE_FLOAT16
 
       // printMatrix(
       //     matrix_op_t::NoTranspose,
@@ -312,6 +342,7 @@ int run_benchmark(
       if (!flush_cache) {
         // vector<float>& output_ref =
         //     has_weight ? output_slws_ref : output_sls_ref;
+#ifndef OUT_TYPE_FLOAT16
         if (success != success_ref) {
           assert(
               false &&
@@ -320,13 +351,32 @@ int run_benchmark(
               << endl;
         } else {
           for (size_t i = 0; i < output.size(); ++i) {
-            assert(fabs(output[i] - output_ref[i]) < 1e-3);
-            if (fabs(output[i] - output_ref[i]) >= 1e-3) {
-              cout << "asmjit vs ref : " << i << " " << output[i] << " "
-                   << output_ref[i] << endl;
+            float tmp1 = 0;
+            float tmp2 = 0;
+            if (std::is_same<OutType, float>::value) {
+              tmp1 = output[i];
+              tmp2 = output_ref[i];
+            } else if (std::is_same<OutType, uint16_t>::value) {
+              if (is_bf16_out) {
+                tmp1 = cpu_bf162float(output[i]);
+                tmp2 = cpu_bf162float(output_ref[i]);
+              } else {
+                tmp1 = cpu_half2float(output[i]);
+                tmp2 = cpu_half2float(output_ref[i]);
+              }
+            } else {
+              assert(false && "ERROR: unsupported output type");
+              cout << "ERROR: unsupported output type" << endl;
+            }
+
+            assert(fabs(tmp1 - tmp2) < 1e-3);
+            if (fabs(tmp1 - tmp2) >= 1e-3) {
+              cout << "asmjit vs ref : " << i << " " << tmp1 << " " << tmp2
+                   << endl;
             }
           }
         }
+#endif // OUT_TYPE_FLOAT16
 
         if (success_autovec != success_ref) {
           assert(
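(An aside on the cpu_bf162float/cpu_half2float calls in the verification loop above: bf16 is simply the upper 16 bits of an IEEE-754 float, while fp16 has its own 5-bit-exponent layout and needs a real conversion. A standalone sketch of both, for illustration only; FBGEMM's own helpers are the authoritative versions, and the fp16 path below handles normals and zeros while omitting subnormals, infinities, and NaNs:)

#include <cstdint>
#include <cstring>

// bf16 -> fp32: bf16 is the upper 16 bits of an IEEE-754 float.
float bf16_to_float(std::uint16_t h) {
  std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// fp16 -> fp32 for normal and zero inputs only: rebias the exponent
// (15 -> 127, i.e. add 112) and widen the 10-bit mantissa to 23 bits.
float half_to_float(std::uint16_t h) {
  std::uint32_t sign = static_cast<std::uint32_t>(h & 0x8000u) << 16;
  std::uint32_t exp = (h >> 10) & 0x1Fu;
  std::uint32_t mant = h & 0x3FFu;
  std::uint32_t bits = sign | ((exp + 112u) << 23) | (mant << 13);
  if (exp == 0 && mant == 0) bits = sign; // signed zero
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}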
@@ -335,16 +385,47 @@ int run_benchmark(
           cout << "autovec return " << success_autovec << " ref return "
                << success_ref << endl;
         } else {
-          for (size_t i = 0; i < output.size(); ++i) {
-            assert(fabs(output_autovec[i] - output_ref[i]) < 1e-3);
-            if (fabs(output_autovec[i] - output_ref[i]) >= 1e-3) {
-              cout << "autovec vs ref: " << i << " " << output_autovec[i] << " "
-                   << output_ref[i] << endl;
+          for (size_t i = 0; i < output_autovec.size(); ++i) {
+            float tmp1 = 0;
+            float tmp2 = 0;
+            if (std::is_same<OutType, float>::value) {
+              tmp1 = output_autovec[i];
+              tmp2 = output_ref[i];
+            } else if (std::is_same<OutType, uint16_t>::value) {
+              if (is_bf16_out) {
+                tmp1 = cpu_bf162float(output_autovec[i]);
+                tmp2 = cpu_bf162float(output_ref[i]);
+              } else {
+                tmp1 = cpu_half2float(output_autovec[i]);
+                tmp2 = cpu_half2float(output_ref[i]);
+              }
+            } else {
+              assert(false && "ERROR: unsupported output type");
+              cout << "ERROR: unsupported output type" << endl;
+            }
+
+            assert(fabs(tmp1 - tmp2) < 1e-3);
+            if (fabs(tmp1 - tmp2) >= 1e-3) {
+              cout << "autovec vs ref: " << i << " " << tmp1 << " " << tmp2
+                   << endl;
             }
           }
         }
       }
 
+      if (std::is_same<OutType, float>::value) {
+        cout << "out type fp32, ";
+      } else if (std::is_same<OutType, uint16_t>::value) {
+        if (is_bf16_out) {
+          cout << "out type bf16, ";
+        } else {
+          cout << "out type fp16, ";
+        }
+      } else {
+        assert(false && "ERROR: unsupported output type");
+        cout << "ERROR: unsupported output type" << endl;
+      }
+
       if (has_weight) {
         cout << "SLW(WEIGHTED), ";
       } else {
@@ -361,6 +442,7 @@ int run_benchmark(
        cout << "prefetch off, ";
      }
 
+#ifndef OUT_TYPE_FLOAT16
      cout << "b/w, " << bytes / 1e9 / t << ", GB/s, " << "effective b/w, "
           << bytes_padded / 1e9 / t << ", GB/s, " << "time, " << t
           << ", autovec b/w, " << bytes / 1e9 / t_autovec << ", GB/s, "
@@ -370,6 +452,14 @@ int run_benchmark(
           << bytes_padded / 1e9 / t_ref << ", GB/s, " << "ref time, " << t_ref
           << ", autovec speedup, " << t_ref / t_autovec << ", asmjit speedup, "
           << t_ref / t << endl;
+#else
+      cout << "autovec b/w, " << bytes / 1e9 / t_autovec << ", GB/s, "
+           << "autovec eff. b/w, " << bytes_padded / 1e9 / t_autovec
+           << ", GB/s, " << "autovec time, " << t_autovec << ", ref b/w, "
+           << bytes / 1e9 / t_ref << ", GB/s, " << "ref eff. b/w, "
+           << bytes_padded / 1e9 / t_ref << ", GB/s, " << "ref time, " << t_ref
+           << ", autovec speedup, " << t_ref / t_autovec << endl;
+#endif // OUT_TYPE_FLOAT16
    } // flush_cache
  } // has_weight
  return 0;
@@ -397,16 +487,41 @@ int main() {
    // args: batch sz, num rows, emb dim, avg len, normalize, use 32b,
    // prefetch
    cout << "64 bit indices, ";
-    run_benchmark(
+#ifndef OUT_TYPE_FLOAT16
+    run_benchmark<float>(
        bit_rate,
        batch_size,
        num_rows,
        embedding_dim,
        average_len,
        false); // normalize_by_lengths
+#else
+    run_benchmark<float16>(
+        bit_rate,
+        batch_size,
+        num_rows,
+        embedding_dim,
+        average_len,
+        false, // normalize_by_lengths
+        false, // use_32_bit_indices
+        false, // prefetch
+        false); // is_bf16_out
+
+    run_benchmark<float16>(
+        bit_rate,
+        batch_size,
+        num_rows,
+        embedding_dim,
+        average_len,
+        false, // normalize_by_lengths
+        false, // use_32_bit_indices
+        false, // prefetch
+        true); // is_bf16_out
+#endif // OUT_TYPE_FLOAT16
 
    cout << "64 bit indices with prefetching, ";
-    run_benchmark(
+#ifndef OUT_TYPE_FLOAT16
+    run_benchmark<float>(
        bit_rate,
        batch_size,
        num_rows,
@@ -415,19 +530,67 @@ int main() {
        false, // normalize_by_lengths
        false, // use_32_bit_indices
        true); // prefetch
+#else
+    run_benchmark<float16>(
+        bit_rate,
+        batch_size,
+        num_rows,
+        embedding_dim,
+        average_len,
+        false, // normalize_by_lengths
+        false, // use_32_bit_indices
+        true, // prefetch
+        false); // is_bf16_out
+
+    run_benchmark<float16>(
+        bit_rate,
+        batch_size,
+        num_rows,
+        embedding_dim,
+        average_len,
+        false, // normalize_by_lengths
+        false, // use_32_bit_indices
+        true, // prefetch
+        true); // is_bf16_out
+#endif // OUT_TYPE_FLOAT16
 
    cout << "32 bit indices, ";
-    run_benchmark(
+#ifndef OUT_TYPE_FLOAT16
+    run_benchmark<float>(
        bit_rate,
        batch_size,
        num_rows,
        embedding_dim,
        average_len,
        false, // normalize_by_lengths
        true); // use_32_bit_indices
+#else
+    run_benchmark<float16>(
+        bit_rate,
+        batch_size,
+        num_rows,
+        embedding_dim,
+        average_len,
+        false, // normalize_by_lengths
+        true, // use_32_bit_indices
+        false, // prefetch
+        false); // is_bf16_out
+
+    run_benchmark<float16>(
+        bit_rate,
+        batch_size,
+        num_rows,
+        embedding_dim,
+        average_len,
+        false, // normalize_by_lengths
+        true, // use_32_bit_indices
+        false, // prefetch
+        true); // is_bf16_out
+#endif // OUT_TYPE_FLOAT16
 
    cout << "32 bit indices with prefetching, ";
-    run_benchmark(
+#ifndef OUT_TYPE_FLOAT16
+    run_benchmark<float>(
        bit_rate,
        batch_size,
        num_rows,
@@ -436,6 +599,29 @@ int main() {
        false, // normalize_by_lengths
        true, // use_32_bit_indices
        true); // prefetch
+#else
+    run_benchmark<float16>(
+        bit_rate,
+        batch_size,
+        num_rows,
+        embedding_dim,
+        average_len,
+        false, // normalize_by_lengths
+        true, // use_32_bit_indices
+        true, // prefetch
+        false); // is_bf16_out
+
+    run_benchmark<float16>(
+        bit_rate,
+        batch_size,
+        num_rows,
+        embedding_dim,
+        average_len,
+        false, // normalize_by_lengths
+        true, // use_32_bit_indices
+        true, // prefetch
+        true); // is_bf16_out
+#endif // OUT_TYPE_FLOAT16
 
    // running with normalize by lengths
    // run_benchmark(batch_size, num_rows, embedding_dim, average_len,
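One observation on the verification code above: the asmjit-vs-ref and autovec-vs-ref blocks repeat the same convert-then-compare loop. A hypothetical helper that factors it out might look like the sketch below; outputs_match and to_float are illustrative names, not part of this patch:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Compare two output buffers after mapping each element to float with a
// caller-supplied converter (identity for fp32, cpu_half2float or
// cpu_bf162float for the 16-bit types).
template <typename OutType, typename ToFloat>
bool outputs_match(
    const std::vector<OutType>& a,
    const std::vector<OutType>& b,
    ToFloat to_float,
    float tol = 1e-3f) {
  assert(a.size() == b.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (std::fabs(to_float(a[i]) - to_float(b[i])) >= tol) {
      return false; // first mismatching element fails the whole check
    }
  }
  return true;
}

Each check would then reduce to a call like outputs_match(output_autovec, output_ref, cpu_half2float), with the converter chosen once per output type.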
