mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-19 00:16:51 +00:00)
[booster] implemented mixed precision class (#3151)
* [booster] implemented mixed precision class
* polish code
@@ -6,7 +6,7 @@ from ..registry import ModelAttribute, model_zoo
 # ===============================
 # Register single-sentence GPT
 # ===============================
-BATCH_SIZE = 2
+BATCH_SIZE = 1    # it can only be 1, as GPT cannot handle batch sizes > 1 if no padding token is defined
 SEQ_LENGTH = 16
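For context: the new comment refers to the fact that GPT-2's tokenizer ships without a padding token, so inputs of unequal length cannot be collated into one batch. A minimal sketch of the usual workaround, assuming a HuggingFace transformers GPT-2 tokenizer (this diff does not itself touch the tokenizer):

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# GPT-2 defines no pad token, so padded batches fail out of the box;
# reusing the EOS token as the pad token is the common fix.
tokenizer.pad_token = tokenizer.eos_token
batch = tokenizer(["a short sentence", "a somewhat longer sentence"],
                  padding=True, return_tensors="pt")

With a pad token set, BATCH_SIZE values greater than 1 would work; the test suite simply sidesteps the issue by fixing the batch size at 1.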
tests/test_booster/test_mixed_precision/test_fp16_torch.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import torch
from torch.optim import Adam

from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
from tests.kit.model_zoo import model_zoo


def test_torch_amp():
    # Run every registered model through the torch-AMP mixed precision plugin.
    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
        model = model_fn().cuda()
        optimizer = Adam(model.parameters(), lr=1e-3)
        criterion = lambda x: x.mean()
        data = data_gen_fn()
        data = {k: v.cuda() if torch.is_tensor(v) else v for k, v in data.items()}

        # Wrap model, optimizer, and criterion for fp16 training.
        mixed_precision = FP16TorchMixedPrecision()
        model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion)

        output = model(**data)
        output = output_transform_fn(output)
        output_key = list(output.keys())[0]
        loss = criterion(output[output_key])

        # The wrapped optimizer handles loss scaling, gradient clipping, and the update.
        optimizer.backward(loss)
        optimizer.clip_grad_by_norm(1.0)
        optimizer.step()
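For readers unfamiliar with the booster API: the configure/backward/clip/step sequence above follows the same shape as PyTorch's native AMP training loop. A minimal sketch of the raw torch.cuda.amp equivalent, which FP16TorchMixedPrecision presumably wraps (an assumption from the class name, not code from this commit):

import torch
from torch.optim import Adam

model = torch.nn.Linear(8, 8).cuda()
optimizer = Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(2, 8, device="cuda")
with torch.cuda.amp.autocast():      # run the forward pass in fp16 where safe
    loss = model(x).mean()
scaler.scale(loss).backward()        # scale the loss to avoid fp16 gradient underflow
scaler.unscale_(optimizer)           # unscale before clipping so the threshold is in true units
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)               # skips the step if gradients overflowed
scaler.update()

Folding this boilerplate behind configure()/optimizer.backward() is the point of the mixed precision class: the test exercises that wrapper surface rather than the raw scaler calls.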