[pipeline]support more flexible pipeline (#1138)

* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
  This reverts commit df7e6506d4.
* [pipeline]support more flexible pipeline
@@ -1,7 +1,7 @@
 import torch
 import inspect
 from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
-from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs
+from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, call_module
 from colossalai.nn.layer.utils import CheckpointModule
 from colossalai.tensor import ColoParameter
 from .layer_sepc import LayerSpec
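For orientation: the `.utils` helpers imported here filter the caller's kwargs down to the parameters a callable's signature actually accepts, and the newly imported `call_module` performs the dispatch shown in the last hunk. A minimal sketch of the filtering idea, assuming signature-based selection; this is not the library's exact code:

import inspect

def build_kwargs_for_function(func, kwargs):
    # Sketch: keep only the keyword arguments that appear in func's
    # signature, returning None when nothing matches (the removed code
    # in the last hunk checks `module_kwargs is not None`).
    sig = inspect.signature(func)
    picked = {k: v for k, v in kwargs.items() if k in sig.parameters}
    return picked if picked else None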
@@ -213,8 +213,7 @@ class PipelinableModel(torch.nn.Module):
         self._front_func_dict = front_func_dict
         self._behind_func_dict = behind_func_dict
 
-    def forward(self, input_tensor, **kwargs):
+    def forward(self, *input_tensor, **kwargs):
         for module in self._module_list:
 
             if id(module) in self._front_func_dict:
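The signature change from `input_tensor` to `*input_tensor` is what makes the pipeline "more flexible": a stage is no longer limited to exactly one positional input. A toy illustration (the `TwoInputStage` module is hypothetical, not part of this patch):

import torch
import torch.nn as nn

class TwoInputStage(nn.Module):
    # A stage that needs two tensors; under the old one-argument
    # forward signature it could not be driven by the pipeline wrapper.
    def forward(self, hidden, mask):
        return hidden * mask

stage = TwoInputStage()
h, m = torch.randn(4, 8), torch.ones(4, 8)
# With `*input_tensor`, the wrapper can forward a tuple of inputs,
# e.g. call_module(stage, args=(h, m), kwargs={}).
out = stage(h, m)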
@@ -224,36 +223,13 @@ class PipelinableModel(torch.nn.Module):
                 forward_func = module._forward
             else:
                 forward_func = module.forward
-            if input_tensor is None:
-                module_kwargs = build_kwargs_for_function(forward_func, kwargs)
-            else:
-                module_kwargs = build_kwargs_for_module(forward_func, kwargs)
-            if module_kwargs is not None and input_tensor is not None:
-                if isinstance(module, CheckpointModule):
-                    convert_kwargs_to_args = []
-                    for v in module_kwargs.values():
-                        convert_kwargs_to_args.append(v)
-                    rst = module(input_tensor, *convert_kwargs_to_args)
-                else:
-                    rst = module(input_tensor, **module_kwargs)
-                if isinstance(rst, tuple):
-                    input_tensor = rst[0]
-                else:
-                    input_tensor = rst
-            elif module_kwargs is not None and input_tensor is None:
-                if isinstance(module, CheckpointModule):
-                    convert_kwargs_to_args = []
-                    for v in module_kwargs.values():
-                        convert_kwargs_to_args.append(v)
-                    rst = module(input_tensor, *convert_kwargs_to_args)
-                else:
-                    rst = module(**module_kwargs)
-                if isinstance(rst, tuple):
-                    input_tensor = rst[0]
-                else:
-                    input_tensor = rst
-            else:
-                input_tensor = module(input_tensor)
+            module_kwargs = build_kwargs_for_module(forward_func, input_tensor, kwargs)
+            if input_tensor is None:
+                input_tensor = call_module(module, kwargs=module_kwargs)
+            elif isinstance(input_tensor, torch.Tensor):
+                input_tensor = call_module(module, args=(input_tensor,), kwargs=module_kwargs)
+            else:
+                input_tensor = call_module(module, args=input_tensor, kwargs=module_kwargs)
 
             if id(module) in self._behind_func_dict:
                 input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
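The thirty-line branch ladder above collapses into `call_module`, whose body lives in `.utils` and is not shown in this diff. A plausible sketch, assuming it preserves the removed behavior (keyword values flattened to positional args for CheckpointModule, tuple outputs unwrapped to their first element); this is a hedged reconstruction, not the actual helper:

from colossalai.nn.layer.utils import CheckpointModule

def call_module(module, args=(), kwargs=None):
    # Hedged reconstruction of the helper this patch delegates to.
    if kwargs is None:
        kwargs = {}
    if isinstance(module, CheckpointModule):
        # CheckpointModule forwards take positional arguments only, so
        # keyword values are flattened, mirroring the removed
        # convert_kwargs_to_args loop.
        rst = module(*args, *kwargs.values())
    else:
        rst = module(*args, **kwargs)
    # The removed code kept only the first element of tuple outputs.
    return rst[0] if isinstance(rst, tuple) else rst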