mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-18 16:00:49 +00:00
[pipeline]support List of Dict data (#1125)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [pipeline]support List of Dict data
* polish
This commit is contained in:
@@ -12,7 +12,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||
A context manager to split the model into pipeline stages.
|
||||
"""
|
||||
|
||||
def __init__(self, policy: str="balanced"):
|
||||
def __init__(self, policy: str = "balanced"):
|
||||
super().__init__()
|
||||
self._layer_spec_dict = {}
|
||||
self._root_children = None
|
||||
@@ -61,11 +61,12 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||
"""
|
||||
# iterate over the positional arguments
|
||||
# to check if an argument is a torch Module
|
||||
# if found any torch Module, replace it with its layer spec
|
||||
# if found any torch Module, replace it with its layer spec
|
||||
# for storage purpose
|
||||
modified_args = []
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.nn.Module):
|
||||
# (lyl)TODO: if nn.Module is an argument of the root module, then we should just record the module instance itself.
|
||||
arg = self._layer_spec_dict[id(arg)]
|
||||
modified_args.append(arg)
|
||||
|
||||
@@ -255,6 +256,3 @@ class PipelinableModel(torch.nn.Module):
|
||||
input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user