mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 22:52:25 +00:00
[zero] fix state_dict and load_state_dict for ddp ignored parameters (#2443)
* [ddp] add is_ddp_ignored [ddp] rename to is_ddp_ignored * [zero] fix state_dict and load_state_dict * fix bugs * [zero] update unit test for ZeroDDP
This commit is contained in:
@@ -233,7 +233,7 @@ class ZeroDDP(ColoDDP):
|
||||
assert isinstance(p, ColoParameter)
|
||||
|
||||
if is_ddp_ignored(p):
|
||||
p.data = p.data.half()
|
||||
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
|
||||
continue
|
||||
|
||||
fp32_data = p.data.float()
|
||||
@@ -451,8 +451,14 @@ class ZeroDDP(ColoDDP):
|
||||
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
||||
|
||||
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
|
||||
# TODO: (HELSON) deal with ddp ignored parameters
|
||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||
ddp_param_list = []
|
||||
for name, param in self.named_parameters():
|
||||
if is_ddp_ignored(param):
|
||||
# deal with ddp ignored parameters
|
||||
destination[prefix + name] = param if keep_vars else param.detach()
|
||||
else:
|
||||
ddp_param_list.append((name, param))
|
||||
for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
|
||||
if p is not None:
|
||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
record_parameter = param_to_save_data[fp32_p]
|
||||
@@ -588,8 +594,16 @@ class ZeroDDP(ColoDDP):
|
||||
def load_fp32_parameter(chunk_slice, data):
|
||||
chunk_slice.copy_(data.flatten())
|
||||
|
||||
ddp_param_list = []
|
||||
for name, param in self.named_parameters():
|
||||
if is_ddp_ignored(param):
|
||||
# deal with ddp ignored parameters
|
||||
load(name, param, param.copy_)
|
||||
else:
|
||||
ddp_param_list.append((name, param))
|
||||
|
||||
fp32_to_name = dict()
|
||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||
for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
|
||||
if p is not None:
|
||||
fp32_to_name[fp32_p] = name
|
||||
|
||||
|
Reference in New Issue
Block a user