Skip to content

Commit 2302853

Browse files
author
ouyhlan
committed
修复Trainer里check_code函数忽略pin_memory参数导致的内存bug
1 parent 9ac7d09 commit 2302853

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

fastNLP/core/trainer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
558558
check_batch_size = max(len(self.model.device_ids), check_batch_size)
559559
_check_code(dataset=train_data, model=self.model, losser=losser, forward_func=self._forward_func, metrics=metrics,
560560
dev_data=dev_dataset, metric_key=self.metric_key, check_level=check_code_level,
561-
batch_size=check_batch_size)
561+
batch_size=check_batch_size, pin_memory=self.pin_memory)
562562

563563
self.train_data = train_data
564564
self.dev_data = dev_data # If None, No validation.
@@ -950,7 +950,7 @@ def _get_value_info(_dict):
950950
return strs
951951

952952

953-
def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE,
953+
def _check_code(dataset, model, losser, metrics, forward_func, pin_memory, batch_size=DEFAULT_CHECK_BATCH_SIZE,
954954
dev_data=None, metric_key=None, check_level=0):
955955
# check get_loss 方法
956956
model_device = _get_model_device(model=model)
@@ -1010,7 +1010,7 @@ def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAUL
10101010

10111011
if dev_data is not None:
10121012
tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
1013-
batch_size=batch_size, verbose=-1, use_tqdm=False)
1013+
batch_size=batch_size, verbose=-1, use_tqdm=False, pin_memory=pin_memory)
10141014
evaluate_results = tester.test()
10151015
_check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics)
10161016

0 commit comments

Comments
 (0)