[PP Middleware] Add bwd and step for PP middleware (#2111)

* add bwd and step for PP middleware

* pre-commit

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
Ziyue Jiang
2022-12-12 12:40:03 +08:00
committed by GitHub
parent 8afc001f4f
commit 09d69e1c25
5 changed files with 225 additions and 82 deletions

View File

@@ -31,7 +31,7 @@ class MLP(nn.Module):
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
return x.sum()
class DAG_MLP(nn.Module):
def __init__(self, dim: int, layers: int):
@@ -46,7 +46,7 @@ class DAG_MLP(nn.Module):
for layer in self.layers:
x = layer(x)
y = self.dag_layer(y)
return x, y
return x.sum(), y.sum()
class RpcTestModel(nn.Module):