[zero] hotfix master param sync (#4618)

* [zero] add method to update master params

* [zero] update zero plugin

* [plugin] update low level zero plugin
Author: Hongxin Liu
Date: 2023-09-05 15:04:02 +08:00
Committed by: GitHub
Parent: aaeb520ce3
Commit: 807e01a4ba
5 changed files with 122 additions and 45 deletions
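
In low-level ZeRO, each rank's optimizer holds a flat fp32 master shard of every working parameter, and this hotfix adds a way to rebuild those shards when the working params change outside the optimizer (for example after a checkpoint load). Below is a minimal sketch of the shard layout the new method reproduces; the world_size/local_rank values are illustrative and the fp32 cast stands in for the usual fp16-to-fp32 master copy:

import torch
import torch.nn.functional as F

world_size, local_rank = 2, 0      # illustrative values, not taken from the commit
working_param = torch.arange(5.0)  # a working param with 5 elements

# Pad the flattened param so it splits evenly across ranks...
padding_size = (world_size - working_param.numel() % world_size) % world_size
flat = F.pad(working_param.view(-1), [0, padding_size])

# ...then each rank keeps only its own slice as its fp32 master shard.
master_shard = flat.chunk(world_size)[local_rank].float()
print(master_shard)  # rank 0 gets tensor([0., 1., 2.])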


@@ -6,6 +6,7 @@ from typing import Dict, Iterator, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
@@ -600,3 +601,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
            ret_block_size += current_block_size

        yield ret_block, ret_block_size

    def update_master_params(self, model: nn.Module) -> None:
        """Update master params from working params

        Args:
            model (nn.Module): The model to update master params
        """
        for p in model.parameters():
            p_id = id(p)
            if p_id in self._param_store.working_to_master_param:
                master_param = self._param_store.working_to_master_param[p_id]
                padding_size = self._param_store.get_param_padding_size(p)
                working_param = p.data.view(-1)
                if padding_size > 0:
                    working_param = torch.nn.functional.pad(working_param, [0, padding_size])
                master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
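
The plugin updates mentioned in the commit message presumably call this method whenever model weights are reloaded. A usage sketch, assuming `optimizer` is the `LowLevelZeroOptimizer` wrapping `model`; the helper name and `ckpt_path` argument are illustrative, and only `update_master_params` itself comes from this commit:

import torch
import torch.nn as nn

def load_and_resync(model: nn.Module, optimizer, ckpt_path: str) -> None:
    # Loading a checkpoint only touches the model's working params...
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    # ...so rebuild the optimizer's fp32 master shards from them; otherwise
    # the next optimizer step would write the stale master weights back.
    optimizer.update_master_params(model)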