[test] fixed tests that failed due to dtensor change (#4082)

* [test] fixed tests that failed due to dtensor change

* polish code
This commit is contained in:
Frank Lee
2023-06-26 15:50:07 +08:00
parent 92f6791095
commit c4b1b65931
37 changed files with 233 additions and 289 deletions

View File

@@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
for key, weight in state_dict.items():
ret_block = None
ret_block_size = 0
if is_distributed_tensor(weight):
if not is_distributed_tensor(weight):
weight_size = calculate_tensor_size(weight)
# If this weight is going to tip up over the maximal size, we split.
@@ -146,7 +146,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
continue
# If the states are stored as DTensors, mark isDTensor as true.
if type(state_tensor) == DTensor:
if is_distributed_tensor(state_tensor):
isDTensor = True
state_size += calculate_tensor_size(state_tensor)