@@ -61,6 +61,7 @@ static vector<vector<int>> GetInputs_() {
   return input_dims;
 }
 
+template <typename OutType>
 int run_benchmark(
     int bit_rate,
     int batch_size,
@@ -69,7 +70,8 @@ int run_benchmark(
     int average_len,
     bool normalize_by_lengths,
     bool use_32_bit_indices = false,
-    bool prefetch = false) {
+    bool prefetch = false,
+    bool is_bf16_out = false) {
   // Create embedding table
   int num_elem_per_byte = 8 / bit_rate;
   int fused_embedding_dim =
@@ -133,8 +135,8 @@ int run_benchmark(
     weights[i] = embedding_distribution(generator);
   }
 
-  vector<float> output_sls_ref(batch_size * embedding_dim);
-  vector<float> output_slws_ref(output_sls_ref.size()),
+  vector<OutType> output_sls_ref(batch_size * embedding_dim);
+  vector<OutType> output_slws_ref(output_sls_ref.size()),
       output_sls(output_sls_ref.size()), output_slws(output_sls_ref.size());
 
   constexpr int NUM_WARMUP = 10;
@@ -148,11 +150,12 @@ int run_benchmark(
       CACHE_LINE_LEN);
 
   for (bool has_weight : {false, true}) {
-    vector<float>& output_ref = has_weight ? output_slws_ref : output_sls_ref;
-    vector<float> output_autovec(output_sls_ref.size());
+    vector<OutType>& output_ref = has_weight ? output_slws_ref : output_sls_ref;
+    vector<OutType> output_autovec(output_sls_ref.size());
 
     bool success = false, success_ref = false, success_autovec = false;
 
+#ifndef OUT_TYPE_FLOAT16
     auto kernel_32 = GenerateEmbeddingSpMDMNBit<int32_t>(
         bit_rate,
         embedding_dim,
@@ -165,8 +168,9 @@ int run_benchmark(
         has_weight,
         normalize_by_lengths,
         prefetch ? 16 : 0);
+#endif // OUT_TYPE_FLOAT16
 
-    vector<float>& output = has_weight ? output_slws : output_sls;
+    vector<OutType>& output = has_weight ? output_slws : output_sls;
     for (bool flush_cache : {false, true}) {
       // Reference implementation
       double t_ref = measureWithWarmup(
@@ -183,7 +187,13 @@ int run_benchmark(
                   offsets.data(),
                   has_weight ? weights.data() : nullptr,
                   normalize_by_lengths,
-                  output_ref.data());
+                  output_ref.data(),
+                  false, // is_weight_positional
+                  true, // use_offsets
+                  -1, // output_stride
+                  -1, // input_stride
+                  true, // scale_bias_last
+                  is_bf16_out);
             } else {
               success_ref = EmbeddingSpMDMNBit_ref(
                   bit_rate,
@@ -196,7 +206,13 @@ int run_benchmark(
                   offsets.data(),
                   has_weight ? weights.data() : nullptr,
                   normalize_by_lengths,
-                  output_ref.data());
+                  output_ref.data(),
+                  false, // is_weight_positional
+                  true, // use_offsets
+                  -1, // output_stride
+                  -1, // input_stride
+                  true, // scale_bias_last
+                  is_bf16_out);
             }
           },
           NUM_WARMUP,
@@ -227,7 +243,13 @@ int run_benchmark(
                   offsets.data(),
                   has_weight ? weights.data() : nullptr,
                   normalize_by_lengths,
-                  output_autovec.data());
+                  output_autovec.data(),
+                  false, // is_weight_positional
+                  true, // use_offsets
+                  -1, // output_stride
+                  -1, // input_stride
+                  true, // scale_bias_last
+                  is_bf16_out);
             } else {
               success_autovec = EmbeddingSpMDMNBit_autovec(
                   bit_rate,
@@ -240,7 +262,13 @@ int run_benchmark(
                   offsets.data(),
                   has_weight ? weights.data() : nullptr,
                   normalize_by_lengths,
-                  output_autovec.data());
+                  output_autovec.data(),
+                  false, // is_weight_positional
+                  true, // use_offsets
+                  -1, // output_stride
+                  -1, // input_stride
+                  true, // scale_bias_last
+                  is_bf16_out);
             }
           },
           NUM_WARMUP,
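
All four call sites now pass the same six trailing arguments to the reference and autovec kernels, with the -1 values for output_stride and input_stride apparently requesting the dense default. A minimal sketch of that convention, assuming the kernels resolve a negative stride to the dense row width (resolve_stride is a hypothetical helper for illustration, not an fbgemm function):

#include <cstdint>

// Assumed convention, inferred from the -1 defaults at these call sites:
// a negative stride asks the kernel to fall back to the dense row layout.
inline std::int64_t resolve_stride(std::int64_t stride,
                                   std::int64_t dense_extent) {
  return stride == -1 ? dense_extent : stride;
}

int main() {
  const std::int64_t embedding_dim = 128;
  // With output_stride == -1, consecutive output rows sit embedding_dim
  // elements apart, i.e. densely packed.
  std::int64_t output_stride = resolve_stride(-1, embedding_dim);
  return output_stride == embedding_dim ? 0 : 1;
}
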
@@ -256,6 +284,7 @@ int run_benchmark(
            }
          });
 
+#ifndef OUT_TYPE_FLOAT16
       // Hand-written AVX2/AVX512 implementation
       double t = measureWithWarmup(
           [&]() {
@@ -293,6 +322,7 @@ int run_benchmark(
              cache_evict(output);
            }
          });
+#endif // OUT_TYPE_FLOAT16
 
       // printMatrix(
       //     matrix_op_t::NoTranspose,
@@ -312,6 +342,7 @@ int run_benchmark(
       if (!flush_cache) {
         // vector<float>& output_ref =
         //     has_weight ? output_slws_ref : output_sls_ref;
+#ifndef OUT_TYPE_FLOAT16
         if (success != success_ref) {
           assert(
               false &&
@@ -320,13 +351,32 @@ int run_benchmark(
               << endl;
         } else {
           for (size_t i = 0; i < output.size(); ++i) {
-            assert(fabs(output[i] - output_ref[i]) < 1e-3);
-            if (fabs(output[i] - output_ref[i]) >= 1e-3) {
-              cout << "asmjit vs ref : " << i << " " << output[i] << " "
-                   << output_ref[i] << endl;
+            float tmp1 = 0;
+            float tmp2 = 0;
+            if (std::is_same<OutType, float>::value) {
+              tmp1 = output[i];
+              tmp2 = output_ref[i];
+            } else if (std::is_same<OutType, uint16_t>::value) {
+              if (is_bf16_out) {
+                tmp1 = cpu_bf162float(output[i]);
+                tmp2 = cpu_bf162float(output_ref[i]);
+              } else {
+                tmp1 = cpu_half2float(output[i]);
+                tmp2 = cpu_half2float(output_ref[i]);
+              }
+            } else {
+              assert(false && "ERROR: unsupported output type");
+              cout << "ERROR: unsupported output type" << endl;
+            }
+
+            assert(fabs(tmp1 - tmp2) < 1e-3);
+            if (fabs(tmp1 - tmp2) >= 1e-3) {
+              cout << "asmjit vs ref : " << i << " " << tmp1 << " " << tmp2
+                   << endl;
             }
           }
         }
+#endif // OUT_TYPE_FLOAT16
 
         if (success_autovec != success_ref) {
           assert(
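
The comparison loops widen every output to float before applying the 1e-3 tolerance, because fp16 and bf16 results are both stored as raw 16-bit words. For bf16 the widening is just a shift and a bit-cast, since bf16 is the top half of an IEEE-754 float. A self-contained sketch of what cpu_bf162float and its inverse amount to (simplified: fbgemm's real helpers live in fbgemm/Types.h, and the narrowing direction may round rather than truncate):

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>

// bf16 keeps the sign bit, all 8 exponent bits, and the top 7 mantissa bits
// of a float, so widening is a 16-bit shift followed by a bit-cast.
static float bf16_to_float(std::uint16_t src) {
  std::uint32_t bits = static_cast<std::uint32_t>(src) << 16;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

static std::uint16_t float_to_bf16(float src) {
  std::uint32_t bits;
  std::memcpy(&bits, &src, sizeof(bits));
  return static_cast<std::uint16_t>(bits >> 16); // truncation, not rounding
}

int main() {
  float x = 1.0f / 3.0f;
  float y = bf16_to_float(float_to_bf16(x));
  // Only ~7 mantissa bits survive, so values near 1 can move by about 1e-3;
  // hence the benchmark compares widened floats rather than raw 16-bit words.
  std::cout << x << " -> " << y << '\n';
  assert(std::fabs(x - y) < 1e-2f);
  return 0;
}
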
@@ -335,16 +385,47 @@ int run_benchmark(
           cout << "autovec return " << success_autovec << " ref return "
                << success_ref << endl;
         } else {
-          for (size_t i = 0; i < output.size(); ++i) {
-            assert(fabs(output_autovec[i] - output_ref[i]) < 1e-3);
-            if (fabs(output_autovec[i] - output_ref[i]) >= 1e-3) {
-              cout << "autovec vs ref: " << i << " " << output_autovec[i] << " "
-                   << output_ref[i] << endl;
+          for (size_t i = 0; i < output_autovec.size(); ++i) {
+            float tmp1 = 0;
+            float tmp2 = 0;
+            if (std::is_same<OutType, float>::value) {
+              tmp1 = output_autovec[i];
+              tmp2 = output_ref[i];
+            } else if (std::is_same<OutType, uint16_t>::value) {
+              if (is_bf16_out) {
+                tmp1 = cpu_bf162float(output_autovec[i]);
+                tmp2 = cpu_bf162float(output_ref[i]);
+              } else {
+                tmp1 = cpu_half2float(output_autovec[i]);
+                tmp2 = cpu_half2float(output_ref[i]);
+              }
+            } else {
+              assert(false && "ERROR: unsupported output type");
+              cout << "ERROR: unsupported output type" << endl;
+            }
+
+            assert(fabs(tmp1 - tmp2) < 1e-3);
+            if (fabs(tmp1 - tmp2) >= 1e-3) {
+              cout << "autovec vs ref: " << i << " " << tmp1 << " " << tmp2
+                   << endl;
             }
           }
         }
       }
 
+      if (std::is_same<OutType, float>::value) {
+        cout << "out type fp32, ";
+      } else if (std::is_same<OutType, uint16_t>::value) {
+        if (is_bf16_out) {
+          cout << "out type bf16, ";
+        } else {
+          cout << "out type fp16, ";
+        }
+      } else {
+        assert(false && "ERROR: unsupported output type");
+        cout << "ERROR: unsupported output type" << endl;
+      }
+
       if (has_weight) {
         cout << "SLW(WEIGHTED), ";
       } else {
@@ -361,6 +442,7 @@ int run_benchmark(
         cout << "prefetch off, ";
       }
 
+#ifndef OUT_TYPE_FLOAT16
       cout << "b/w, " << bytes / 1e9 / t << ", GB/s, " << "effective b/w, "
            << bytes_padded / 1e9 / t << ", GB/s, " << "time, " << t
            << ", autovec b/w, " << bytes / 1e9 / t_autovec << ", GB/s, "
@@ -370,6 +452,14 @@ int run_benchmark(
            << bytes_padded / 1e9 / t_ref << ", GB/s, " << "ref time, " << t_ref
            << ", autovec speedup, " << t_ref / t_autovec << ", asmjit speedup, "
            << t_ref / t << endl;
+#else
+      cout << "autovec b/w, " << bytes / 1e9 / t_autovec << ", GB/s, "
+           << "autovec eff. b/w, " << bytes_padded / 1e9 / t_autovec
+           << ", GB/s, " << "autovec time, " << t_autovec << ", ref b/w, "
+           << bytes / 1e9 / t_ref << ", GB/s, " << "ref eff. b/w, "
+           << bytes_padded / 1e9 / t_ref << ", GB/s, " << "ref time, " << t_ref
+           << ", autovec speedup, " << t_ref / t_autovec << endl;
+#endif // OUT_TYPE_FLOAT16
     } // flush_cache
   } // has_weight
   return 0;
@@ -397,16 +487,41 @@ int main() {
       // args: batch sz, num rows, emb dim, avg len, normalize, use 32b,
       // prefetch
       cout << "64 bit indices, ";
-      run_benchmark(
+#ifndef OUT_TYPE_FLOAT16
+      run_benchmark<float>(
           bit_rate,
           batch_size,
           num_rows,
           embedding_dim,
           average_len,
           false); // normalize_by_lengths
+#else
+      run_benchmark<float16>(
+          bit_rate,
+          batch_size,
+          num_rows,
+          embedding_dim,
+          average_len,
+          false, // normalize_by_lengths
+          false, // use_32_bit_indices
+          false, // prefetch
+          false); // is_bf16_out
+
+      run_benchmark<float16>(
+          bit_rate,
+          batch_size,
+          num_rows,
+          embedding_dim,
+          average_len,
+          false, // normalize_by_lengths
+          false, // use_32_bit_indices
+          false, // prefetch
+          true); // is_bf16_out
+#endif // OUT_TYPE_FLOAT16
 
       cout << "64 bit indices with prefetching, ";
-      run_benchmark(
+#ifndef OUT_TYPE_FLOAT16
+      run_benchmark<float>(
           bit_rate,
           batch_size,
           num_rows,
@@ -415,19 +530,67 @@ int main() {
           false, // normalize_by_lengths
           false, // use_32_bit_indices
           true); // prefetch
+#else
+      run_benchmark<float16>(
+          bit_rate,
+          batch_size,
+          num_rows,
+          embedding_dim,
+          average_len,
+          false, // normalize_by_lengths
+          false, // use_32_bit_indices
+          true, // prefetch
+          false); // is_bf16_out
+
+      run_benchmark<float16>(
+          bit_rate,
+          batch_size,
+          num_rows,
+          embedding_dim,
+          average_len,
+          false, // normalize_by_lengths
+          false, // use_32_bit_indices
+          true, // prefetch
+          true); // is_bf16_out
+#endif // OUT_TYPE_FLOAT16
 
       cout << "32 bit indices, ";
-      run_benchmark(
+#ifndef OUT_TYPE_FLOAT16
+      run_benchmark<float>(
           bit_rate,
           batch_size,
           num_rows,
           embedding_dim,
           average_len,
           false, // normalize_by_lengths
           true); // use_32_bit_indices
+#else
+      run_benchmark<float16>(
+          bit_rate,
+          batch_size,
+          num_rows,
+          embedding_dim,
+          average_len,
+          false, // normalize_by_lengths
+          true, // use_32_bit_indices
+          false, // prefetch
+          false); // is_bf16_out
+
+      run_benchmark<float16>(
+          bit_rate,
+          batch_size,
+          num_rows,
+          embedding_dim,
+          average_len,
+          false, // normalize_by_lengths
+          true, // use_32_bit_indices
+          false, // prefetch
+          true); // is_bf16_out
+#endif // OUT_TYPE_FLOAT16
 
       cout << "32 bit indices with prefetching, ";
-      run_benchmark(
+#ifndef OUT_TYPE_FLOAT16
+      run_benchmark<float>(
           bit_rate,
           batch_size,
           num_rows,
@@ -436,6 +599,29 @@ int main() {
           false, // normalize_by_lengths
           true, // use_32_bit_indices
           true); // prefetch
+#else
+      run_benchmark<float16>(
+          bit_rate,
+          batch_size,
+          num_rows,
+          embedding_dim,
+          average_len,
+          false, // normalize_by_lengths
+          true, // use_32_bit_indices
+          true, // prefetch
+          false); // is_bf16_out
+
+      run_benchmark<float16>(
+          bit_rate,
+          batch_size,
+          num_rows,
+          embedding_dim,
+          average_len,
+          false, // normalize_by_lengths
+          true, // use_32_bit_indices
+          true, // prefetch
+          true); // is_bf16_out
+#endif // OUT_TYPE_FLOAT16
 
       // running with normalize by lengths
       // run_benchmark(batch_size, num_rows, embedding_dim, average_len,
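
Taken together, the diff leaves the asmjit kernels fp32-only and builds the 16-bit output benchmarks as a separate compilation, presumably by defining OUT_TYPE_FLOAT16 on the compiler command line. A minimal, self-contained sketch of the same dispatch pattern (run_benchmark_stub and the float16 alias here are illustrative stand-ins; fbgemm defines its own 16-bit storage type):

#include <cstdint>
#include <iostream>
#include <type_traits>

using float16 = std::uint16_t; // stand-in: storage-only 16-bit type

template <typename OutType>
void run_benchmark_stub(bool is_bf16_out = false) {
  // OutType fixes the storage width at compile time; is_bf16_out picks the
  // 16-bit encoding at run time, since fp16 and bf16 share the same storage.
  if (std::is_same<OutType, float>::value) {
    std::cout << "out type fp32\n";
  } else {
    std::cout << (is_bf16_out ? "out type bf16\n" : "out type fp16\n");
  }
}

int main() {
#ifndef OUT_TYPE_FLOAT16
  run_benchmark_stub<float>(); // default build: fp32 outputs only
#else
  run_benchmark_stub<float16>(/*is_bf16_out=*/false); // fp16 outputs
  run_benchmark_stub<float16>(/*is_bf16_out=*/true); // bf16 outputs
#endif
  return 0;
}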