[hotfix]fix bugs caused by refactored pipeline (#1133)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [hotfix]fix bugs caused by refactored pipeline
This commit is contained in:
YuliangLiu0306
2022-06-17 17:54:15 +08:00
committed by GitHub
parent 789cad301b
commit 946dbd629d
3 changed files with 20 additions and 39 deletions

View File

@@ -67,8 +67,8 @@ class NonPipelineSchedule(BaseSchedule):
"The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
batch_data = self.load_batch(data_iter)
if self.batch_data_process_func:
data, label = self.batch_data_process_func(batch_data)
if self.data_process_func:
data, label = self.data_process_func(batch_data)
else:
# if not batch data process func is given,
# then we regard the batch data as a simple tuple of (data, label)

View File

@@ -141,6 +141,8 @@ class PipelineSchedule(BaseSchedule):
for element in data:
if isinstance(element, dict):
data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
elif data_dict:
data_dict['label'] = element[offset:offset + self.microbatch_size]
if data_dict:
return data_dict
return [val[offset:offset + self.microbatch_size] for val in data]
@@ -175,7 +177,10 @@ class PipelineSchedule(BaseSchedule):
elif isinstance(data, (list, tuple)):
return model(*data)
elif isinstance(data, dict):
return model(**data)
stage_output = None
if 'stage_output' in data:
stage_output = data.pop('stage_output')
return model(stage_output, **data)
else:
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
@@ -204,41 +209,14 @@ class PipelineSchedule(BaseSchedule):
data = stage_output
_, label = micro_batch_data
elif isinstance(micro_batch_data, dict):
args = []
data = {}
label = {}
# we feed the stage output to args first
# then map each arg in args to its param name
if stage_output is not None:
if isinstance(stage_output, torch.Tensor):
args.append(stage_output)
elif isinstance(stage_output, (list, tuple)):
args.extend(stage_output)
else:
raise TypeError(
f"Expected the values passed from previous pipeline stage to be torch.Tensor, list or tuple, but got {type(input_obj)}"
)
# get all parameter names for the forward function of the model
fwd_sig = self._get_actual_forward_func(model)
fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]
# build the kwargs for the forward function
for idx, param_name in enumerate(fwd_sig_param_name):
if idx < len(args):
data[param_name] = args[idx]
else:
if param_name in micro_batch_data:
data[param_name] = micro_batch_data[param_name]
# get the tensors for loss
loss_sig = inspect.signature(criterion)
loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
for param_name in loss_sig_param_name:
if param_name in micro_batch_data:
label[param_name] = micro_batch_data[param_name]
data['stage_output'] = stage_output
if 'label' in micro_batch_data:
label = micro_batch_data.pop('label')
else:
label = None
load_data = micro_batch_data
data.update(load_data)
return data, label
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):