[tensor] add cross_entrophy_loss (#868)

This commit is contained in:
Jiarui Fang
2022-04-25 16:01:52 +08:00
committed by GitHub
parent 3107817172
commit 1190b2c4a4
6 changed files with 48 additions and 7 deletions

View File

@@ -14,11 +14,15 @@ class SimpleNet(CheckpointModule):
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.proj1 = nn.Linear(4, 8)
self.ln1 = nn.LayerNorm(8)
self.proj2 = nn.Linear(8, 4)
self.ln2 = nn.LayerNorm(4)
def forward(self, x):
x = self.proj1(x)
x = self.ln1(x)
x = self.proj2(x)
x = self.ln2(x)
return x