[pipeline]support more flexible pipeline (#1138)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [pipeline]support more flexible pipeline
This commit is contained in:
YuliangLiu0306
2022-06-21 14:40:50 +08:00
committed by GitHub
parent ccf3c58c89
commit 18091581c0
3 changed files with 69 additions and 40 deletions

View File

@@ -165,9 +165,9 @@ class PipelineSchedule(BaseSchedule):
if isinstance(model, ShardedModelV2):
self.dtype = torch.half
model = model.module
sig = inspect.signature(model.forward)
for p in sig.parameters.values():
assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
# sig = inspect.signature(model.forward)
# for p in sig.parameters.values():
# assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
@staticmethod
def _call_engine(model, data):
@@ -180,7 +180,16 @@ class PipelineSchedule(BaseSchedule):
stage_output = None
if 'stage_output' in data:
stage_output = data.pop('stage_output')
return model(stage_output, **data)
if stage_output is None:
return model(**data)
elif isinstance(stage_output, torch.Tensor):
return model(stage_output, **data)
elif isinstance(stage_output, (tuple, list)):
return model(*stage_output, **data)
else:
raise TypeError(
f"Expected stage_output to be of type torch.Tensor, list, or tuple, but got {type(stage_output)}"
)
else:
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")