Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 18:19:58 +00:00)
[Tensor] fix optimizer for CPU parallel (#1069)
@@ -9,6 +9,7 @@ from colossalai.context import ParallelMode
from colossalai.nn.parallel.layers import init_colo_module
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.nn.optimizer import ColoOptimizer

import colossalai
import torch
@@ -56,10 +57,11 @@ def run_hybrid_device(use_ddp):
    print(f'embedding weight size: {real_model.embed.weight.size()} | new device: {real_model.embed.weight.device}')
    #print(f'linear weight size: {real_model.proj.weight.size()} | new device: {real_model.proj.weight.device}')

    optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
    data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
    out = model(data)
    out.sum().backward()

    optimizer.step()


def run_dist(rank, world_size, port, use_ddp):
    if use_ddp and world_size == 1:
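As a rough stand-alone analogue of the training step this hunk exercises (a hypothetical toy model is used, and plain torch.optim.SGD is swapped in for ColoOptimizer so the sketch runs without a ColossalAI setup; in the real test the parameters are ColoTensors and the embedding has already been moved to a different device before the optimizer is created):

import torch
import torch.nn as nn

# Hypothetical toy model standing in for the test's model (an embedding feeding a projection).
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(20, 8)
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(self.embed(x))

model = Net()
# The commit builds ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1);
# plain SGD over the same parameters is used here so the sketch is self-contained.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

data = torch.randint(low=0, high=20, size=(16,))
out = model(data)
out.sum().backward()
optimizer.step()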
@@ -81,4 +83,4 @@ def _test_hybrid_device(world_size, use_ddp):


if __name__ == '__main__':
    _test_hybrid_device(1, False)
    _test_hybrid_device(4, True)
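For completeness, a minimal sketch of how _test_hybrid_device plausibly fans run_dist out over ranks in this test suite; the use of torch.multiprocessing.spawn and colossalai.utils.free_port here is an assumption, not taken from this diff:

from functools import partial

import torch.multiprocessing as mp
from colossalai.utils import free_port  # assumed helper for picking an open port

def _test_hybrid_device(world_size, use_ddp):
    # Bind everything except the rank, which mp.spawn supplies as the first argument,
    # matching run_dist(rank, world_size, port, use_ddp) above.
    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
    mp.spawn(run_func, nprocs=world_size)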