[Pipeline Middleware] fix data race in Pipeline Scheduler for DAG (#2087)

* add DAG test case

* fix datarace by adjusting theposition of lock

* polish code

* fix pytest for middleware

* remove test

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
Ziyue Jiang
2022-12-08 13:32:27 +08:00
committed by GitHub
parent b175e6d58e
commit e4705ba4e2
3 changed files with 131 additions and 51 deletions

View File

@@ -32,6 +32,21 @@ class MLP(nn.Module):
for layer in self.layers:
x = layer(x)
return x
class DAG_MLP(nn.Module):
def __init__(self, dim: int, layers: int):
super().__init__()
self.layers = torch.nn.ModuleList()
self.dag_layer = nn.Linear(dim, dim, bias=False)
for _ in range(layers):
self.layers.append(nn.Linear(dim, dim, bias=False))
def forward(self, x, y):
for layer in self.layers:
x = layer(x)
y = self.dag_layer(y)
return x, y
class RpcTestModel(nn.Module):