@@ -558,7 +558,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
558
558
check_batch_size = max (len (self .model .device_ids ), check_batch_size )
559
559
_check_code (dataset = train_data , model = self .model , losser = losser , forward_func = self ._forward_func , metrics = metrics ,
560
560
dev_data = dev_dataset , metric_key = self .metric_key , check_level = check_code_level ,
561
- batch_size = check_batch_size )
561
+ batch_size = check_batch_size , pin_memory = self . pin_memory )
562
562
563
563
self .train_data = train_data
564
564
self .dev_data = dev_data # If None, No validation.
@@ -950,7 +950,7 @@ def _get_value_info(_dict):
950
950
return strs
951
951
952
952
953
- def _check_code (dataset , model , losser , metrics , forward_func , batch_size = DEFAULT_CHECK_BATCH_SIZE ,
953
+ def _check_code (dataset , model , losser , metrics , forward_func , pin_memory , batch_size = DEFAULT_CHECK_BATCH_SIZE ,
954
954
dev_data = None , metric_key = None , check_level = 0 ):
955
955
# check get_loss 方法
956
956
model_device = _get_model_device (model = model )
@@ -1010,7 +1010,7 @@ def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAUL
1010
1010
1011
1011
if dev_data is not None :
1012
1012
tester = Tester (data = dev_data [:batch_size * DEFAULT_CHECK_NUM_BATCH ], model = model , metrics = metrics ,
1013
- batch_size = batch_size , verbose = - 1 , use_tqdm = False )
1013
+ batch_size = batch_size , verbose = - 1 , use_tqdm = False , pin_memory = pin_memory )
1014
1014
evaluate_results = tester .test ()
1015
1015
_check_eval_results (metrics = evaluate_results , metric_key = metric_key , metric_list = metrics )
1016
1016
0 commit comments