[shardformer] support module saving and loading (#4062)
* [shardformer] support module saving and loading

* polish code
@@ -10,7 +10,7 @@ import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from colossalai.tensor.d_tensor.d_tensor import DTensor
+from colossalai.tensor.d_tensor import is_distributed_tensor
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
     for key, weight in state_dict.items():
         ret_block = None
         ret_block_size = 0
-        if type(weight) != DTensor:
+        if not is_distributed_tensor(weight):
             weight_size = calculate_tensor_size(weight)
 
             # If this weight is going to tip up over the maximal size, we split.
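For context, `shard_model_checkpoint` greedily packs plain tensors into sub-checkpoints so that no shard exceeds `max_shard_size` (in MB); the patch only changes how a weight is recognized as distributed. Below is a minimal, self-contained sketch of that packing pattern. The stand-in `is_distributed_tensor` predicate and the MB-based `calculate_tensor_size` are assumptions modeled on the imports above, not ColossalAI's exact implementations.

from collections import OrderedDict
from typing import Iterator, Tuple

import torch


def calculate_tensor_size(tensor: torch.Tensor) -> float:
    # Assumed helper: tensor size in MB (numel * bytes per element / 1024**2).
    return tensor.numel() * tensor.element_size() / 1024 / 1024


def is_distributed_tensor(tensor: torch.Tensor) -> bool:
    # Stand-in for colossalai.tensor.d_tensor.is_distributed_tensor; the real
    # helper checks sharding metadata that the d_tensor API attaches to a tensor.
    return getattr(tensor, "dist_layout", None) is not None


def shard_model_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[dict, float]]:
    # Greedy sharding: accumulate weights into a block until adding the next
    # one would tip it over max_shard_size (MB), then start a new block.
    current_block: dict = OrderedDict()
    current_block_size = 0.0
    for key, weight in state_dict.items():
        if is_distributed_tensor(weight):
            # Distributed tensors take a separate sharded-save path, omitted here.
            continue
        weight_size = calculate_tensor_size(weight)
        if current_block and current_block_size + weight_size > max_shard_size:
            yield current_block, current_block_size
            current_block = OrderedDict()
            current_block_size = 0.0
        current_block[key] = weight
        current_block_size += weight_size
    if current_block:
        yield current_block, current_block_size


if __name__ == "__main__":
    # Two small float32 tensors easily fit one shard under the 1024 MB default.
    sd = OrderedDict(a=torch.zeros(4, 4), b=torch.zeros(8, 8))
    for block, size in shard_model_checkpoint(sd):
        print(list(block.keys()), f"{size:.6f} MB")

Switching from type(weight) != DTensor to a predicate keeps this loop decoupled from any concrete DTensor class: callers only ask whether a tensor carries distributed metadata, not what type it is.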