[gemini] fixed the gemini checkpoint io (#3934)

This commit is contained in:
Frank Lee
2023-06-09 09:48:49 +08:00
committed by GitHub
parent a98e16ed07
commit bd1ab98158
3 changed files with 19 additions and 11 deletions

View File

@@ -716,7 +716,10 @@ class _StateDictSharder:
tensor_size = calculate_tensor_size(tensor)
ret_block = None
ret_block_size = 0
if self.current_block_size + tensor_size > self.max_shard_size:
# before we return the current block and create a new block,
# we need to ensure that the current block is not empty
if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
ret_block = self.current_block
ret_block_size = self.current_block_size
self.current_block = OrderedDict()