Mirror of https://github.com/hpcaitech/ColossalAI.git

[hotfix] add correct device for fake_param (#2796)

parent a619a190df
commit 56ddc9ca7a

This hotfix creates ZeroOptimizer's zero-size fake parameters on the device where their gradients are kept, and moves the shared placeholder tensor onto each fake parameter's device before assignment, so a parameter is no longer silently pulled back to CPU.
@@ -136,7 +136,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         for group in self.param_groups:
             for fake_param in group['params']:
                 assert fake_param.grad is None
-                fake_param.data = none_tensor
+                fake_param.data = none_tensor.to(fake_param.device)
 
         for chunk16 in self.chunk16_set:
             chunk16.optim_update()
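A minimal sketch of the bug this hunk fixes, using only standard PyTorch; the none_tensor below is a stand-in for ZeroOptimizer's zero-size placeholder, not the library's actual object. Assigning a tensor to Parameter.data rebinds the parameter's storage, so a CPU placeholder silently drags a CUDA parameter back to CPU unless it is first moved to the parameter's device.

import torch

# Stand-in for the zero-size placeholder ZeroOptimizer assigns to fake params.
none_tensor = torch.empty([0])  # created on CPU by default

device = 'cuda' if torch.cuda.is_available() else 'cpu'
fake_param = torch.nn.Parameter(torch.empty([0], device=device))

# Old behavior: the assignment rebinds storage, moving the param to CPU.
fake_param.data = none_tensor
assert fake_param.device.type == 'cpu'

# Fixed behavior: move the placeholder to the param's device first.
fake_param = torch.nn.Parameter(torch.empty([0], device=device))
fake_param.data = none_tensor.to(fake_param.device)
assert fake_param.device.type == device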
@@ -307,7 +307,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
                 if range_pair[0] >= range_pair[1]:
                     continue
 
-                fake_param = torch.nn.Parameter(torch.empty([0]))
+                grad_device = self.module.grads_device[param]
+                fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
                 self.param_to_chunk32[fake_param] = chunk16.paired_chunk
                 self.param_to_range[fake_param] = range_pair
 
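A hedged sketch of the second change, with a plain dict standing in for self.module.grads_device (Gemini's mapping from each real parameter to the device where its gradient lives): the fake parameter is now created directly on that device instead of on the CPU default.

import torch

param = torch.nn.Parameter(torch.randn(8))
# Stand-in for self.module.grads_device; under Gemini this could be cuda:0.
grads_device = {param: torch.device('cpu')}

grad_device = grads_device[param]
fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
assert fake_param.device == grad_device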
@@ -70,8 +70,6 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
             for n, m in v.items():
                 if isinstance(m, torch.Tensor):
                     o = w[n]
-                    if m.device != o.device:
-                        o = o.to(m.device)
                     assert torch.equal(m, o)
                 else:
                     assert m == w[n]
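The test hunk is the payoff: with fake parameters created on the gradient device, saved and reloaded optimizer state tensors land on the same device, so the device-reconciliation branch can be dropped. A small sketch of the remaining comparison, where m and o are stand-ins for matching state-dict entries:

import torch

m = torch.randn(3)   # stand-in for a tensor from the saved optimizer state
o = m.clone()        # stand-in for its reloaded counterpart

# torch.equal requires exact value (and shape/dtype) agreement; the deleted
# branch only existed to copy o onto m's device before this comparison.
assert torch.equal(m, o)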