mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-04 15:14:19 +00:00
[Gemini] add unitests to check gemini correctness (#2015)
This commit is contained in:
15
tests/components_to_test/utils/executor.py
Normal file
15
tests/components_to_test/utils/executor.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch
|
||||
|
||||
|
||||
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, use_init_ctx=False):
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
if criterion:
|
||||
y = model(data)
|
||||
loss = criterion(y, label)
|
||||
else:
|
||||
loss = model(data, label)
|
||||
loss = loss.float()
|
||||
if use_init_ctx:
|
||||
model.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
Reference in New Issue
Block a user