mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 20:40:34 +00:00
[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)
* fix master param sync for hybrid plugin
* rewrite unwrap for ddp/fsdp
* rewrite unwrap for zero/gemini
* rewrite unwrap for hybrid plugin
* fix gemini unwrap
* fix bugs
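For context on what "master param sync" means here: a mixed-precision optimizer runs forward/backward on low-precision working parameters but steps on fp32 master copies, so whenever the working parameters are overwritten (e.g. by loading a checkpoint) the master copies must be refreshed or the next step writes stale values back. A minimal self-contained sketch of that mapping, as a toy re-implementation rather than ColossalAI's actual code:

import torch
from torch import nn

# fp16 "working" params are what forward/backward touch; the optimizer
# steps on separate fp32 "master" copies.
model = nn.Linear(4, 4).half()
working_to_master = {p: p.detach().clone().float() for p in model.parameters()}

# If something overwrites the working params (e.g. load_state_dict),
# the master copies go stale; re-sync them working -> master:
with torch.no_grad():
    for p in model.parameters():
        if p in working_to_master:
            working_to_master[p].data.copy_(p.data)  # implicit cast to fp32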
@@ -2,7 +2,7 @@ from typing import Dict, List
 
 import torch
 from torch import Tensor
-from torch.nn import Parameter
+from torch.nn import Module, Parameter
 from torch.optim import Optimizer
 
 from colossalai.interface import OptimizerWrapper
@@ -152,3 +152,18 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 if p is working_param:
                     continue
                 working_param.data.copy_(p.data)
+
+    def update_master_params(self, model: Module):
+        # Update master params from working params
+        with torch.no_grad():
+            for p in model.parameters():
+                if (p is None) or (p not in self.working_to_master_map):
+                    continue
+                master_param = self.working_to_master_map[p]
+                master_param.data.copy_(p.data)
+
+    def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
+        return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()}
+
+    def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
+        return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
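The two getters expose id-keyed views of the parameter mappings, which lets callers (checkpoint I/O, plugin unwrapping code) resolve a counterpart from a parameter object they already hold. A self-contained toy of that pattern; the class below only mirrors the method names in the diff and is not ColossalAI's implementation:

from typing import Dict

import torch
from torch import nn

class ToyMasterParamMaps:
    """Toy sketch of the id-keyed getters added in this diff."""

    def __init__(self, model: nn.Module):
        # working (fp16) -> master (fp32), keyed by Parameter identity
        self.working_to_master_map = {p: p.detach().clone().float() for p in model.parameters()}
        self.master_to_working_map = {m: w for w, m in self.working_to_master_map.items()}

    def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
        # id() keys let a caller look up the master copy from a working
        # Parameter it holds, without the returned map keeping Parameter keys.
        return {id(w): m for w, m in self.working_to_master_map.items()}

    def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
        return {id(m): w for m, w in self.master_to_working_map.items()}

maps = ToyMasterParamMaps(nn.Linear(2, 2).half())
w2m = maps.get_working_to_master_map()
for p in maps.working_to_master_map:
    assert w2m[id(p)].dtype == torch.float32  # master copies are fp32

With the real optimizer, the analogous call after restoring weights is optimizer.update_master_params(model), which is exactly the hook this commit adds.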