mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] hotfix master param sync (#4618)
* [zero] add method to update master params
* [zero] update zero plugin
* [plugin] update low level zero plugin
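The new `update_master_params` method lets a caller resynchronize the optimizer's sharded master copies with the model's working parameters after the weights change outside of an optimizer step (for example, after loading a checkpoint). A minimal usage sketch, not part of this commit: it assumes an already-initialized distributed run where `optimizer` is the `LowLevelZeroOptimizer` wrapping `model`, and the checkpoint path is made up.

# Illustrative sketch only (not code from this commit).
# Assumes torch.distributed is initialized and that `model` and `optimizer`
# (a LowLevelZeroOptimizer wrapping `model`) already exist; "ckpt.pt" is a
# made-up path.
state = torch.load("ckpt.pt", map_location="cpu")
model.load_state_dict(state)
optimizer.update_master_params(model)  # the method added by this commit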
@@ -6,6 +6,7 @@ from typing import Dict, Iterator, Optional, Tuple
 
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
 
@@ -600,3 +601,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             ret_block_size += current_block_size
 
         yield ret_block, ret_block_size
+
+    def update_master_params(self, model: nn.Module) -> None:
+        """Update master params from working params
+
+        Args:
+            model (nn.Module): The model to update master params
+        """
+        for p in model.parameters():
+            p_id = id(p)
+            if p_id in self._param_store.working_to_master_param:
+                master_param = self._param_store.working_to_master_param[p_id]
+                padding_size = self._param_store.get_param_padding_size(p)
+                working_param = p.data.view(-1)
+                if padding_size > 0:
+                    working_param = torch.nn.functional.pad(working_param, [0, padding_size])
+                master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
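The copy at the end of the loop is the core of the fix: each working parameter is flattened, zero-padded so its length divides evenly by the data-parallel world size, split into equal chunks, and only the local rank's chunk is written into its master shard. Below is a small standalone sketch of that pad-and-chunk step, with a made-up world size, rank, and parameter shape (in the optimizer these come from the process group and the parameter store).

import torch

world_size = 4   # made-up value; self._world_size in the optimizer
local_rank = 1   # made-up value; self._local_rank in the optimizer

# A working parameter with 10 elements, which does not divide by world_size.
working_param = torch.arange(10, dtype=torch.float32)

flat = working_param.view(-1)
padding_size = (world_size - flat.numel() % world_size) % world_size
if padding_size > 0:
    # Zero-pad on the right so the tensor splits into equal chunks.
    flat = torch.nn.functional.pad(flat, [0, padding_size])

# Each rank keeps only its own chunk as its master shard.
master_shard = flat.chunk(world_size)[local_rank].clone()
print(master_shard)  # tensor([3., 4., 5.]) for local_rank == 1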