[pipeline]support more flexible pipeline (#1138)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [pipeline]support more flexible pipeline
This commit is contained in:
YuliangLiu0306
2022-06-21 14:40:50 +08:00
committed by GitHub
parent ccf3c58c89
commit 18091581c0
3 changed files with 69 additions and 40 deletions

View File

@@ -1,7 +1,7 @@
import torch
import inspect
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs
from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, call_module
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoParameter
from .layer_sepc import LayerSpec
@@ -213,8 +213,7 @@ class PipelinableModel(torch.nn.Module):
self._front_func_dict = front_func_dict
self._behind_func_dict = behind_func_dict
def forward(self, input_tensor, **kwargs):
def forward(self, *input_tensor, **kwargs):
for module in self._module_list:
if id(module) in self._front_func_dict:
@@ -224,36 +223,13 @@ class PipelinableModel(torch.nn.Module):
forward_func = module._forward
else:
forward_func = module.forward
module_kwargs = build_kwargs_for_module(forward_func, input_tensor, kwargs)
if input_tensor is None:
module_kwargs = build_kwargs_for_function(forward_func, kwargs)
input_tensor = call_module(module, kwargs=module_kwargs)
elif isinstance(input_tensor, torch.Tensor):
input_tensor = call_module(module, args=(input_tensor,), kwargs=module_kwargs)
else:
module_kwargs = build_kwargs_for_module(forward_func, kwargs)
if module_kwargs is not None and input_tensor is not None:
if isinstance(module, CheckpointModule):
convert_kwargs_to_args = []
for v in module_kwargs.values():
convert_kwargs_to_args.append(v)
rst = module(input_tensor, *convert_kwargs_to_args)
else:
rst = module(input_tensor, **module_kwargs)
if isinstance(rst, tuple):
input_tensor = rst[0]
else:
input_tensor = rst
elif module_kwargs is not None and input_tensor is None:
if isinstance(module, CheckpointModule):
convert_kwargs_to_args = []
for v in module_kwargs.values():
convert_kwargs_to_args.append(v)
rst = module(input_tensor, *convert_kwargs_to_args)
else:
rst = module(**module_kwargs)
if isinstance(rst, tuple):
input_tensor = rst[0]
else:
input_tensor = rst
else:
input_tensor = module(input_tensor)
input_tensor = call_module(module, args=input_tensor, kwargs=module_kwargs)
if id(module) in self._behind_func_dict:
input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)