mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 18:39:56 +00:00
[pipelinable] use pipelinable to support GPT model. (#903)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4.
* [pipelinable] use pipelinable to support GPT model.
* fix a bug caused by ShardedModel
* polish
* fix front func list
This commit is contained in:
@@ -119,9 +119,12 @@ class PipelineSchedule(BaseSchedule):
|
||||
def pre_processing(self, engine):
    """Prepare the schedule for the engine's model before pipeline execution.

    Peels off known wrapper layers (``NaiveAMPModel`` and ``ShardedModelV2``)
    to reach the underlying model, recording ``torch.half`` as the working
    dtype whenever such a wrapper is present, then validates that the bare
    model's ``forward`` signature does not use ``*args`` (positional
    var-args cannot be mapped onto pipeline micro-batches).

    :param engine: engine whose ``model`` attribute is inspected; may be a
        plain module or a wrapped one.
    :raises AssertionError: if ``forward`` declares a ``*args`` parameter.
    """
    # TODO: remove this after testing new zero with pipeline parallelism
    model = engine.model
    # Unwrap each layer independently: a sharded model may itself sit
    # inside the AMP wrapper, so both checks must run in sequence.
    if isinstance(model, NaiveAMPModel):
        self.dtype = torch.half
        model = model.model
    if isinstance(model, ShardedModelV2):
        self.dtype = torch.half
        model = model.module
    sig = inspect.signature(model.forward)
    for param in sig.parameters.values():
        assert param.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
|
||||
@@ -135,6 +138,12 @@ class PipelineSchedule(BaseSchedule):
|
||||
else:
|
||||
sig = inspect.signature(model.forward)
|
||||
if isinstance(batch_data, torch.Tensor):
|
||||
for p in sig.parameters.values():
|
||||
if p.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
if input_tensor is None:
|
||||
return model(batch_data)
|
||||
else:
|
||||
return model(input_tensor)
|
||||
if input_tensor is None:
|
||||
return model(batch_data)
|
||||
elif len(sig.parameters) > 1:
|
||||
@@ -148,7 +157,7 @@ class PipelineSchedule(BaseSchedule):
|
||||
filter_batch = False
|
||||
if filter_batch:
|
||||
batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters}
|
||||
if input_tensor is None:
|
||||
if input_tensor is None and filter_batch:
|
||||
return model(**batch_data)
|
||||
else:
|
||||
return model(input_tensor, **batch_data)
|
||||
|
Reference in New Issue
Block a user