[shardformer] support module saving and loading (#4062)

* [shardformer] support module saving and loading

* polish code
Frank Lee
2023-06-22 11:42:11 +08:00
parent 7740c55c55
commit 8eb09a4c69
19 changed files with 493 additions and 102 deletions

@@ -10,7 +10,7 @@ import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from colossalai.tensor.d_tensor.d_tensor import DTensor
+from colossalai.tensor.d_tensor import is_distributed_tensor
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
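
This hunk swaps the concrete `DTensor` class import for the `is_distributed_tensor` predicate from `colossalai.tensor.d_tensor`, so call sites no longer depend on the class's module path. A short illustrative sketch of how the predicate reads at a call site, assuming ColossalAI is installed (not part of the commit):

```python
import torch

from colossalai.tensor.d_tensor import is_distributed_tensor  # new-style import

weight = torch.randn(4, 4)
# Old style forced a class identity check: type(weight) != DTensor.
# The predicate asks the same question without importing the class itself:
print(is_distributed_tensor(weight))  # False for a plain torch.Tensor
```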
@@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
     for key, weight in state_dict.items():
         ret_block = None
         ret_block_size = 0
-        if type(weight) != DTensor:
+        if is_distributed_tensor(weight):
             weight_size = calculate_tensor_size(weight)
             # If this weight is going to tip up over the maximal size, we split.
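
For context, the loop this hunk touches packs tensors into sub-checkpoints of bounded size. Below is a minimal, self-contained sketch of that size-based sharding pattern; `calculate_tensor_size` and `shard_state_dict` here are hypothetical stand-ins for ColossalAI's helpers, and the distributed-tensor branch shown in the diff is omitted:

```python
from collections import OrderedDict
from typing import Iterator, Tuple

import torch


def calculate_tensor_size(tensor: torch.Tensor) -> float:
    """Tensor size in MB: element count times bytes per element."""
    return tensor.numel() * tensor.element_size() / 1024 / 1024


def shard_state_dict(state_dict, max_shard_size: float = 1024.0) -> Iterator[Tuple[OrderedDict, float]]:
    """Yield (block, size_in_mb) pairs, each block staying within max_shard_size MB."""
    current_block, current_block_size = OrderedDict(), 0.0
    for key, weight in state_dict.items():
        weight_size = calculate_tensor_size(weight)
        # If this weight would tip the running block over the limit, flush it first.
        if current_block and current_block_size + weight_size > max_shard_size:
            yield current_block, current_block_size
            current_block, current_block_size = OrderedDict(), 0.0
        current_block[key] = weight
        current_block_size += weight_size
    if current_block:
        yield current_block, current_block_size


if __name__ == "__main__":
    # Eight 0.25 MB tensors with a 1 MB cap yield two 4-tensor shards.
    sd = {f"layer{i}.weight": torch.randn(256, 256) for i in range(8)}
    for idx, (block, size) in enumerate(shard_state_dict(sd, max_shard_size=1.0)):
        print(f"shard {idx}: {len(block)} tensors, {size:.2f} MB")
```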