[rpc] split with dag (#2028)

* add DAG to split_module

* add comment

* add test case for DAG

* remove print

* add DAG middleware in scheduler

* add test case for scheduler

* remove break

* recover old lifecycle

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
Ziyue Jiang
2022-11-29 11:36:28 +08:00
committed by GitHub
parent 96134e7be3
commit b0936e4a44
5 changed files with 337 additions and 53 deletions

View File

@@ -20,6 +20,18 @@ def color_debug(text, prefix=' ', color='blue'):
color = color.upper()
print(getattr(Back, color), prefix, Style.RESET_ALL, text)
class MLP(nn.Module):
def __init__(self, dim: int, layers: int):
super().__init__()
self.layers = torch.nn.ModuleList()
for _ in range(layers):
self.layers.append(nn.Linear(dim, dim, bias=False))
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
class RpcTestModel(nn.Module):