mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 03:52:01 +00:00
[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)
* fix master param sync for hybrid plugin
* rewrite unwrap for ddp/fsdp
* rewrite unwrap for zero/gemini
* rewrite unwrap for hybrid plugin
* fix gemini unwrap
* fix bugs
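For context: with mixed precision, the hybrid parallel plugin's optimizer keeps fp32 master copies of the fp16/bf16 working parameters the model actually computes with, and "master param sync" means keeping those two sets consistent (for example after loading a sharded checkpoint). The snippet below is a minimal, hypothetical sketch of that idea, not the plugin's implementation; the names sync_master_params, working_params, and master_params are assumptions introduced here.

# Minimal sketch (hypothetical, not ColossalAI's actual code): after the
# low-precision working parameters change (e.g. a checkpoint was loaded),
# copy them into the optimizer's fp32 master copies so both stay in sync.
from typing import List

import torch


def sync_master_params(working_params: List[torch.Tensor],
                       master_params: List[torch.Tensor]) -> None:
    """Copy each low-precision working param into its fp32 master copy."""
    for working, master in zip(working_params, master_params):
        master.data.copy_(working.data.to(master.dtype))


# Usage: one fp16 working param and its fp32 master copy.
working = [torch.randn(4, dtype=torch.float16)]
master = [torch.zeros(4, dtype=torch.float32)]
sync_master_params(working, master)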
@@ -11,7 +11,6 @@ import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from colossalai.interface import OptimizerWrapper
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
@@ -122,14 +121,6 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
 # ======================================
 # Helper classes and functions for saving shard file
 # ======================================
-def unwrap_optimizer(optimizer: OptimizerWrapper):
-    """
-    Unwrap a wrapped optimizer.
-    This method should be used before saving/loading it to/from sharded checkpoints.
-    """
-
-    unwrapped_optim = optimizer.optim
-    return unwrapped_optim
 
 
 class StateDictSharder:
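The deleted unwrap_optimizer helper did nothing more than return optimizer.optim; per the commit message, unwrapping is now rewritten per wrapper (ddp/fsdp, zero/gemini, hybrid plugin) rather than handled by this one module-level function. The sketch below illustrates that wrapper-side pattern in general terms only; SimpleOptimizerWrapper and its unwrap method are hypothetical stand-ins, not ColossalAI's actual classes.

# Hypothetical sketch of wrapper-side unwrapping (a simplified stand-in,
# not ColossalAI's real OptimizerWrapper): instead of a free-standing
# unwrap_optimizer() helper, the wrapper itself exposes the inner optimizer.
from torch.optim import Optimizer


class SimpleOptimizerWrapper:
    def __init__(self, optim):
        self.optim = optim

    def unwrap(self) -> Optimizer:
        """Return the innermost torch optimizer, e.g. for checkpoint I/O."""
        inner = self.optim
        # Peel nested wrappers in case a plugin stacked several of them.
        while isinstance(inner, SimpleOptimizerWrapper):
            inner = inner.optim
        return inner

In this pattern, checkpoint I/O code asks the wrapper for its inner optimizer instead of importing a shared helper, which lets each plugin decide how deep the unwrapping has to go.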