[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)

* fix master param sync for hybrid plugin

* rewrite unwrap for ddp/fsdp

* rewrite unwrap for zero/gemini

* rewrite unwrap for hybrid plugin

* fix gemini unwrap

* fix bugs
Author: Baizhou Zhang
Date: 2023-09-20 18:29:37 +08:00 (committed by GitHub)
Parent: 7b9b86441f
Commit: c0a033700c
14 changed files with 141 additions and 171 deletions
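For context, the "unwrapping" rewritten in this commit is the step that peels plugin wrappers (DDP/FSDP module wrappers, ZeRO/Gemini wrappers, OptimizerWrapper) off a model or optimizer so checkpoint I/O can operate on the raw nn.Module and Optimizer. A minimal sketch of the model side, assuming only that DDP-style wrappers keep the inner module in `.module` (hypothetical helper name, not this commit's exact code):

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def unwrap_model(model: nn.Module) -> nn.Module:
        # Hypothetical sketch: DDP keeps the wrapped module in `.module`
        # (torch's FSDP exposes the same attribute); peel wrapper layers
        # until the raw nn.Module is reached.
        while isinstance(model, DDP):
            model = model.module
        return model

Each plugin in this commit implements its own version of this logic for its particular wrapper type rather than sharing one generic helper.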

@@ -11,7 +11,6 @@ import torch
 import torch.nn as nn
 from torch.optim import Optimizer
-from colossalai.interface import OptimizerWrapper
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
@@ -122,14 +121,6 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size
 # ======================================
 # Helper classes and functions for saving shard file
 # ======================================
-def unwrap_optimizer(optimizer: OptimizerWrapper):
-    """
-    Unwrap a wrapped optimizer.
-    This method should be used before saving/loading it to/from sharded checkpoints.
-    """
-    unwrapped_optim = optimizer.optim
-    return unwrapped_optim
-
 class StateDictSharder:
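With unwrap_optimizer deleted, callers unwrap inline at each call site instead of going through the shared helper. A one-line equivalent, assuming only the `.optim` attribute visible in the removed helper above (the function name here is hypothetical):

    from torch.optim import Optimizer
    from colossalai.interface import OptimizerWrapper

    def raw_optimizer(optimizer) -> Optimizer:
        # OptimizerWrapper stores the underlying optimizer in `.optim`
        # (see the removed helper above); plain optimizers pass through.
        return optimizer.optim if isinstance(optimizer, OptimizerWrapper) else optimizer

Dropping the helper also removes the file's only use of OptimizerWrapper, which is why the import in the first hunk goes away.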