mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 11:08:50 +00:00
[pipeline]support more flexible pipeline (#1138)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [pipeline]support more flexible pipeline
This commit is contained in:
@@ -165,9 +165,9 @@ class PipelineSchedule(BaseSchedule):
|
||||
if isinstance(model, ShardedModelV2):
|
||||
self.dtype = torch.half
|
||||
model = model.module
|
||||
sig = inspect.signature(model.forward)
|
||||
for p in sig.parameters.values():
|
||||
assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
|
||||
# sig = inspect.signature(model.forward)
|
||||
# for p in sig.parameters.values():
|
||||
# assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
|
||||
|
||||
@staticmethod
|
||||
def _call_engine(model, data):
|
||||
@@ -180,7 +180,16 @@ class PipelineSchedule(BaseSchedule):
|
||||
stage_output = None
|
||||
if 'stage_output' in data:
|
||||
stage_output = data.pop('stage_output')
|
||||
return model(stage_output, **data)
|
||||
if stage_output is None:
|
||||
return model(**data)
|
||||
elif isinstance(stage_output, torch.Tensor):
|
||||
return model(stage_output, **data)
|
||||
elif isinstance(stage_output, (tuple, list)):
|
||||
return model(*stage_output, **data)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected stage_output to be of type torch.Tensor, list, or tuple, but got {type(stage_output)}"
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
|
||||
|
||||
|
Reference in New Issue
Block a user