[pipeline]support List of Dict data (#1125)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [pipeline]support List of Dict data

* polish
This commit is contained in:
YuliangLiu0306
2022-06-16 11:19:48 +08:00
committed by GitHub
parent 91a5999825
commit 3175bcb4d8
3 changed files with 14 additions and 8 deletions

View File

@@ -12,7 +12,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
A context manager to split the model into pipeline stages.
"""
def __init__(self, policy: str="balanced"):
def __init__(self, policy: str = "balanced"):
super().__init__()
self._layer_spec_dict = {}
self._root_children = None
@@ -61,11 +61,12 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
"""
# iterate over the positional arguments
# to check if an argument is a torch Module
# if found any torch Module, replace it with its layer spec
# if found any torch Module, replace it with its layer spec
# for storage purpose
modified_args = []
for arg in args:
if isinstance(arg, torch.nn.Module):
# (lyl)TODO: if nn.Module is an argument of the root module, then we should just record the module instance itself.
arg = self._layer_spec_dict[id(arg)]
modified_args.append(arg)
@@ -255,6 +256,3 @@ class PipelinableModel(torch.nn.Module):
input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
return input_tensor