mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[tensor] add cross_entrophy_loss (#868)
This commit is contained in:
@@ -14,11 +14,15 @@ class SimpleNet(CheckpointModule):
|
||||
def __init__(self, checkpoint=False) -> None:
|
||||
super().__init__(checkpoint=checkpoint)
|
||||
self.proj1 = nn.Linear(4, 8)
|
||||
self.ln1 = nn.LayerNorm(8)
|
||||
self.proj2 = nn.Linear(8, 4)
|
||||
self.ln2 = nn.LayerNorm(4)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj1(x)
|
||||
x = self.ln1(x)
|
||||
x = self.proj2(x)
|
||||
x = self.ln2(x)
|
||||
return x
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user