@@ -107,10 +107,7 @@ def attention_layer_opt(prefix, config, init_dict, network, input_tensor, imask)
107
107
Ball = init_dict [prefix + BQKV ]
108
108
109
109
# FC_attention
110
- if config .use_int8 :
111
- mult_all = network .add_convolution_nd (input_tensor , 3 * hidden_size , (1 , 1 ), Wall , Ball )
112
- else :
113
- mult_all = network .add_fully_connected (input_tensor , 3 * hidden_size , Wall , Ball )
110
+ mult_all = network .add_convolution_nd (input_tensor , 3 * hidden_size , (1 , 1 ), Wall , Ball )
114
111
115
112
if config .use_qat :
116
113
dr_qkv = max (
@@ -217,24 +214,20 @@ def transformer_layer_opt(prefix, config, init_dict, network, input_tensor, imas
217
214
218
215
# FC0
219
216
B_aout = init_dict [prefix + B_AOUT ]
220
- if config .use_int8 :
217
+ if not config .use_int8 and use_custom_fc ():
218
+ W_aoutT = init_dict [prefix + W_AOUT + "_notrans" ]
219
+ attention_out_fc = custom_fc (config , network , attention_heads , hidden_size , W_aoutT )
220
+ else :
221
221
W_aout = init_dict [prefix + W_AOUT ]
222
222
attention_out_fc = network .add_convolution_nd (attention_heads , hidden_size , (1 , 1 ), W_aout , B_aout )
223
223
B_aout = None
224
224
225
- if not config .use_int8_skipln :
225
+ if config . use_int8 and not config .use_int8_skipln :
226
226
attention_out_fc .set_output_type (0 , trt .DataType .HALF if config .use_fp16 else trt .DataType .FLOAT )
227
227
228
- if config .use_qat :
228
+ if config .use_int8 and config . use_qat :
229
229
dr_fc_aout = init_dict [prefix + 'attention_output_add_local_input_quantizer_amax' ]
230
230
set_output_range (attention_out_fc , dr_fc_aout )
231
- elif use_custom_fc ():
232
- W_aoutT = init_dict [prefix + W_AOUT + "_notrans" ]
233
- attention_out_fc = custom_fc (config , network , attention_heads , hidden_size , W_aoutT )
234
- else :
235
- W_aout = init_dict [prefix + W_AOUT ]
236
- attention_out_fc = network .add_fully_connected (attention_heads , hidden_size , W_aout , B_aout )
237
- B_aout = None
238
231
239
232
skiplayer = skipln (prefix + "attention_output_layernorm_" ,config , init_dict , network , attention_out_fc .get_output (0 ), input_tensor , B_aout )
240
233
attention_ln = skiplayer .get_output (0 )
@@ -245,10 +238,7 @@ def transformer_layer_opt(prefix, config, init_dict, network, input_tensor, imas
245
238
# FC1 + GELU
246
239
B_mid = init_dict [prefix + B_MID ]
247
240
W_mid = init_dict [prefix + W_MID ]
248
- if config .use_int8 :
249
- mid_dense = network .add_convolution_nd (attention_ln , config .intermediate_size , (1 , 1 ), W_mid , B_mid )
250
- else :
251
- mid_dense = network .add_fully_connected (attention_ln , config .intermediate_size , W_mid , B_mid )
241
+ mid_dense = network .add_convolution_nd (attention_ln , config .intermediate_size , (1 , 1 ), W_mid , B_mid )
252
242
253
243
mid_dense_out = mid_dense .get_output (0 )
254
244
POW = network .add_constant ((1 , 1 , 1 , 1 , 1 ), trt .Weights (np .ascontiguousarray ([3.0 ], dtype = np .float32 )))
@@ -281,21 +271,18 @@ def transformer_layer_opt(prefix, config, init_dict, network, input_tensor, imas
281
271
# FC2
282
272
# Dense to hidden size
283
273
B_lout = init_dict [prefix + B_LOUT ]
284
- if config .use_int8 and not config .use_fc2_gemm :
285
- W_lout = init_dict [prefix + W_LOUT ]
286
- out_dense = network .add_convolution_nd (intermediate_act , hidden_size , (1 , 1 ), W_lout , B_lout )
287
- B_lout = None
288
-
289
- if not config .use_int8_skipln :
290
- out_dense .set_output_type (0 , trt .DataType .HALF if config .use_fp16 else trt .DataType .FLOAT )
291
- elif use_custom_fc ():
274
+ prefer_conv = config .use_int8 and not config .use_fc2_gemm
275
+ if not prefer_conv and use_custom_fc ():
292
276
W_loutT = init_dict [prefix + W_LOUT + "_notrans" ]
293
277
out_dense = custom_fc (config , network , intermediate_act , hidden_size , W_loutT )
294
278
else :
295
279
W_lout = init_dict [prefix + W_LOUT ]
296
- out_dense = network .add_fully_connected (intermediate_act , hidden_size , W_lout , B_lout )
280
+ out_dense = network .add_convolution_nd (intermediate_act , hidden_size , ( 1 , 1 ) , W_lout , B_lout )
297
281
B_lout = None
298
282
283
+ if config .use_int8 and not config .use_int8_skipln :
284
+ out_dense .set_output_type (0 , trt .DataType .HALF if config .use_fp16 else trt .DataType .FLOAT )
285
+
299
286
if config .use_qat :
300
287
dr_fc_out = init_dict [prefix + 'output_add_local_input_quantizer_amax' ]
301
288
set_output_range (out_dense , dr_fc_out )
@@ -334,7 +321,7 @@ def squad_output(prefix, config, init_dict, network, input_tensor):
334
321
B_out = init_dict [prefix + SQD_B ]
335
322
336
323
W = network .add_constant ((1 , hidden_size , 2 ), W_out )
337
- dense = network .add_fully_connected (input_tensor , 2 , W_out , B_out )
324
+ dense = network .add_convolution_nd (input_tensor , 2 , ( 1 , 1 ) , W_out , B_out )
338
325
339
326
OUT = network .add_shuffle (dense .get_output (0 ))
340
327
OUT .second_transpose = (1 , 0 , 2 , 3 , 4 )
@@ -402,7 +389,7 @@ def build_engine(batch_sizes, workspace_size, sequence_lengths, config, weights_
402
389
explicit_batch_flag = 1 << int (trt .NetworkDefinitionCreationFlag .EXPLICIT_BATCH )
403
390
404
391
with trt .Builder (TRT_LOGGER ) as builder , builder .create_network (explicit_batch_flag ) as network , builder .create_builder_config () as builder_config :
405
- builder_config .max_workspace_size = workspace_size * (1024 * 1024 )
392
+ builder_config .set_memory_pool_limit ( trt . MemoryPoolType . WORKSPACE , workspace_size * (1024 * 1024 ) )
406
393
builder_config .avg_timing_iterations = 8
407
394
if config .use_fp16 :
408
395
builder_config .set_flag (trt .BuilderFlag .FP16 )
@@ -451,10 +438,11 @@ def build_engine(batch_sizes, workspace_size, sequence_lengths, config, weights_
451
438
squad_logits = squad_output ("cls_" , config , weights_dict , network , bert_out )
452
439
squad_logits_out = squad_logits .get_output (0 )
453
440
441
+ squad_logits_out .name = "logits_out"
454
442
network .mark_output (squad_logits_out )
455
443
456
444
build_start_time = time .time ()
457
- engine = builder .build_engine (network , builder_config )
445
+ serialized_engine = builder .build_serialized_network (network , builder_config )
458
446
build_time_elapsed = (time .time () - build_start_time )
459
447
TRT_LOGGER .log (TRT_LOGGER .INFO , "build engine in {:.3f} Sec" .format (build_time_elapsed ))
460
448
@@ -469,7 +457,7 @@ def build_engine(batch_sizes, workspace_size, sequence_lengths, config, weights_
469
457
470
458
if config .use_int8 and not config .use_qat :
471
459
calibrator .free ()
472
- return engine
460
+ return serialized_engine
473
461
474
462
def generate_calibration_cache (sequence_lengths , workspace_size , config , weights_dict , squad_json , vocab_file , calibrationCacheFile , calib_num ):
475
463
"""
@@ -488,7 +476,7 @@ def generate_calibration_cache(sequence_lengths, workspace_size, config, weights
488
476
config .use_fp16 = False
489
477
config .is_calib_mode = True
490
478
491
- with build_engine ([1 ], workspace_size , sequence_lengths , config , weights_dict , squad_json , vocab_file , calibrationCacheFile , calib_num , False ) as engine :
479
+ with build_engine ([1 ], workspace_size , sequence_lengths , config , weights_dict , squad_json , vocab_file , calibrationCacheFile , calib_num , False ) as serialized_engine :
492
480
TRT_LOGGER .log (TRT_LOGGER .INFO , "calibration cache generated in {:}" .format (calibrationCacheFile ))
493
481
494
482
config .use_fp16 = saved_use_fp16
@@ -553,9 +541,7 @@ def main():
553
541
else :
554
542
raise RuntimeError ("You need either specify TF checkpoint using option --ckpt or ONNX using option --onnx to build TRT BERT model." )
555
543
556
- with build_engine (args .batch_size , args .workspace_size , args .sequence_length , config , weights_dict , args .squad_json , args .vocab_file , calib_cache , args .calib_num , args .verbose ) as engine :
557
- TRT_LOGGER .log (TRT_LOGGER .VERBOSE , "Serializing Engine..." )
558
- serialized_engine = engine .serialize ()
544
+ with build_engine (args .batch_size , args .workspace_size , args .sequence_length , config , weights_dict , args .squad_json , args .vocab_file , calib_cache , args .calib_num , args .verbose ) as serialized_engine :
559
545
TRT_LOGGER .log (TRT_LOGGER .INFO , "Saving Engine to {:}" .format (args .output ))
560
546
with open (args .output , "wb" ) as fout :
561
547
fout .write (serialized_engine )
0 commit comments