[pipelinable] use pipelinable to support GPT model. (#903)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [pipelinable] use pipelinable to support GPT model.

* fix a bug caused by ShardedModel

* polish

* fix front func list
Author: YuliangLiu0306
Date: 2022-05-11 09:23:58 +08:00
Committed by: GitHub
Parent: b61d64685f
Commit: 32a45cd7ef
2 changed files with 168 additions and 31 deletions


@@ -119,9 +119,12 @@ class PipelineSchedule(BaseSchedule):
     def pre_processing(self, engine):
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
-        if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
+        if isinstance(model, NaiveAMPModel):
             self.dtype = torch.half
             model = model.model
+        if isinstance(model, ShardedModelV2):
+            self.dtype = torch.half
+            model = model.module
         sig = inspect.signature(model.forward)
         for p in sig.parameters.values():
             assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
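
Note on the hunk above: it splits the single isinstance check so that NaiveAMPModel is unwrapped via .model while ShardedModelV2 is unwrapped via .module, with both implying torch.half execution. A minimal stand-alone sketch of that unwrap pattern, using stand-in wrapper classes (FakeNaiveAMPModel, FakeShardedModelV2) rather than the real colossalai types:

import torch

class FakeNaiveAMPModel:          # stand-in for colossalai's NaiveAMPModel
    def __init__(self, model):
        self.model = model

class FakeShardedModelV2:         # stand-in for colossalai's ShardedModelV2
    def __init__(self, module):
        self.module = module

def unwrap(model):
    # Unwrap AMP / sharded wrappers so that a later inspect.signature call
    # sees the real forward; both wrappers imply fp16 execution.
    dtype = torch.float
    if isinstance(model, FakeNaiveAMPModel):
        dtype = torch.half
        model = model.model
    if isinstance(model, FakeShardedModelV2):
        dtype = torch.half
        model = model.module
    return model, dtype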
@@ -135,6 +138,12 @@ class PipelineSchedule(BaseSchedule):
         else:
             sig = inspect.signature(model.forward)
         if isinstance(batch_data, torch.Tensor):
+            for p in sig.parameters.values():
+                if p.kind == inspect.Parameter.VAR_KEYWORD:
+                    if input_tensor is None:
+                        return model(batch_data)
+                    else:
+                        return model(input_tensor)
             if input_tensor is None:
                 return model(batch_data)
             elif len(sig.parameters) > 1:
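
The lines added in this hunk inspect the model's forward signature for a VAR_KEYWORD (**kwargs) parameter and, if one is present, pass the batch tensor (or the incoming pipeline tensor) as a single positional argument, presumably so that a pipelinable GPT partition built from the front func list can be called directly. A self-contained illustration of the signature check (the helper name is made up for the example):

import inspect

def accepts_var_keyword(forward_fn) -> bool:
    # True if forward_fn's signature contains a **kwargs-style parameter.
    sig = inspect.signature(forward_fn)
    return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())

def forward_plain(x):
    return x

def forward_with_kwargs(x, **kwargs):
    return x

assert not accepts_var_keyword(forward_plain)
assert accepts_var_keyword(forward_with_kwargs)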
@@ -148,7 +157,7 @@ class PipelineSchedule(BaseSchedule):
                     filter_batch = False
             if filter_batch:
                 batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters}
-            if input_tensor is None:
+            if input_tensor is None and filter_batch:
                 return model(**batch_data)
             else:
                 return model(input_tensor, **batch_data)
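
After this one-line change, the keyword-only call model(**batch_data) is taken only when there is no input tensor and the batch dict was actually filtered down to the forward signature; otherwise the pipeline tensor is passed positionally alongside the batch keywords. The filtering step itself, sketched in isolation with a hypothetical helper and dummy batch:

import inspect

def filter_batch_for(forward_fn, batch_data):
    # Keep only the batch entries that match a named parameter of forward_fn,
    # unless forward_fn takes **kwargs, in which case nothing is filtered.
    sig = inspect.signature(forward_fn)
    filter_batch = all(p.kind != inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
    if filter_batch:
        batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters}
    return batch_data, filter_batch

def forward(input_ids, attention_mask=None):
    return input_ids

batch = {'input_ids': [1, 2], 'attention_mask': [1, 1], 'labels': [0, 1]}
print(filter_batch_for(forward, batch))   # 'labels' is dropped, filter_batch is True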