Mirror of https://github.com/hpcaitech/ColossalAI.git
[test] fixed tests failed due to dtensor change (#4082)
* [test] fixed tests failed due to dtensor change
* polish code
@@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
     for key, weight in state_dict.items():
         ret_block = None
         ret_block_size = 0
-        if is_distributed_tensor(weight):
+        if not is_distributed_tensor(weight):
             weight_size = calculate_tensor_size(weight)

             # If this weight is going to tip up over the maximal size, we split.
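
The guard above now skips tensors that are sharded across devices, so only non-distributed weights count against the shard budget. Below is a minimal, self-contained sketch of this size-based sharding loop; the is_distributed_tensor and calculate_tensor_size stubs are illustrative stand-ins (including the dist_layout attribute), not ColossalAI's actual helpers.

from collections import OrderedDict
from typing import Iterator, Tuple

import torch


def is_distributed_tensor(t: torch.Tensor) -> bool:
    # Stand-in check: pretend distributed tensors carry layout metadata.
    return getattr(t, "dist_layout", None) is not None


def calculate_tensor_size(t: torch.Tensor) -> float:
    # Tensor size in MB, the same unit as max_shard_size above.
    return t.numel() * t.element_size() / 1024 / 1024


def shard_state_dict(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, float]]:
    current_block: OrderedDict = OrderedDict()
    current_block_size = 0.0
    for key, weight in state_dict.items():
        if not is_distributed_tensor(weight):
            weight_size = calculate_tensor_size(weight)
            # If this weight would tip the block over the limit, emit a shard.
            if current_block and current_block_size + weight_size > max_shard_size:
                yield current_block, current_block_size
                current_block = OrderedDict()
                current_block_size = 0.0
            current_block[key] = weight
            current_block_size += weight_size
    if current_block:
        yield current_block, current_block_size


if __name__ == "__main__":
    sd = {f"w{i}": torch.zeros(1024, 1024) for i in range(5)}  # ~4 MB each
    for shard, size in shard_state_dict(sd, max_shard_size=10):
        print(list(shard.keys()), f"{size:.1f} MB")

With five ~4 MB tensors and a 10 MB budget, this prints two-tensor shards followed by a one-tensor remainder.
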
@@ -146,7 +146,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
                 continue

             # If the states are stored as DTensors, mark isDTensor as true.
-            if type(state_tensor) == DTensor:
+            if is_distributed_tensor(state_tensor):
                 isDTensor = True
             state_size += calculate_tensor_size(state_tensor)