[shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540)

* implement sharded optimizer saving

* add more param info

* finish implementation of sharded optimizer saving

* fix bugs in optimizer sharded saving

* add pp+zero test

* param group loading

* greedy loading of optimizer

* fix bug when loading

* implement optimizer sharded saving

* add optimizer test & arrange checkpointIO utils

* fix gemini sharding state_dict

* add verbose option

* add loading of master params

* fix typehint

* fix master/working mapping in fp16 amp
Authored by Baizhou Zhang on 2023-08-31 14:50:47 +08:00, committed by GitHub
parent 2c787d7f47
commit c9625dbb63
6 changed files with 812 additions and 369 deletions
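For orientation, "sharded optimizer saving" in the commits above means splitting the optimizer's state_dict into size-capped shard files plus a small index that maps each parameter id to the shard holding its states. The sketch below only illustrates that idea; the helper name (save_sharded_optimizer), the shard file naming, and the index layout are assumptions for illustration, not the HybridParallelPlugin API or its on-disk format.

# Minimal sketch, NOT the HybridParallelPlugin implementation: the helper name,
# file names and index layout are assumptions made for illustration only.
import json
import os

import torch


def save_sharded_optimizer(optimizer: torch.optim.Optimizer,
                           checkpoint_dir: str,
                           max_shard_size_mb: int = 1024) -> None:
    """Split optimizer.state_dict() into size-capped shards plus an index file."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    state_dict = optimizer.state_dict()

    shard, shard_size, shard_idx = {}, 0, 0
    weight_map = {}  # parameter id -> shard file that stores its states

    def flush_shard():
        nonlocal shard, shard_size, shard_idx
        if not shard:
            return
        shard_file = f"optimizer.shard_{shard_idx:05d}.bin"
        torch.save(shard, os.path.join(checkpoint_dir, shard_file))
        for param_id in shard:
            weight_map[str(param_id)] = shard_file
        shard, shard_size, shard_idx = {}, 0, shard_idx + 1

    for param_id, param_states in state_dict["state"].items():
        # Rough size of this parameter's states (exp_avg, exp_avg_sq, step, ...).
        size = sum(t.numel() * t.element_size()
                   for t in param_states.values() if torch.is_tensor(t))
        if shard and shard_size + size > max_shard_size_mb * 1024 ** 2:
            flush_shard()
        shard[param_id] = param_states
        shard_size += size
    flush_shard()

    # Param groups are tiny, so they go into the index together with the mapping.
    index = {"param_groups": state_dict["param_groups"], "weight_map": weight_map}
    with open(os.path.join(checkpoint_dir, "optimizer.index.json"), "w") as f:
        json.dump(index, f, indent=2)

With such an index, loading can proceed shard by shard instead of materialising the whole optimizer checkpoint at once.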

@@ -679,7 +679,7 @@ class ZeroDDP(ColoDDP):
                         gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
                     gathered_param = gathered_param_buffer.pop(fp32_param)
-                block, block_size = sharder.append(prefix + name, gathered_param)
+                block, block_size = sharder.append_param(prefix + name, gathered_param)
                 if block is not None:
                     yield block, block_size
@@ -690,7 +690,7 @@ class ZeroDDP(ColoDDP):
         for name, buf in self.named_buffers():
             if buf is not None and name not in self._non_persistent_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
-                block, block_size = sharder.append(prefix + name, buffer)
+                block, block_size = sharder.append_param(prefix + name, buffer)
                 if block is not None:
                     yield block, block_size
         # save extra states
@@ -698,7 +698,7 @@ class ZeroDDP(ColoDDP):
         if getattr(self.__class__, "get_extra_state",
                    torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
             extra_state = self.get_extra_state()
-            block, block_size = sharder.append(extra_state_key, extra_state)
+            block, block_size = sharder.append_param(extra_state_key, extra_state)
             if block is not None:
                 yield block, block_size
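
The append to append_param rename above fits the size-capped sharder pattern that state_dict_shard builds on: each tensor is handed to the sharder, and a finished block is yielded once adding another tensor would exceed the shard size limit. The class below is a minimal illustrative stand-in, not the ColossalAI sharder; only the append_param name and the (block, block_size) return shape are taken from the diff.

# Illustrative stand-in for the sharder used above; not ColossalAI code.
from collections import OrderedDict
from typing import Optional, Tuple

import torch


class SizeCappedSharder:
    def __init__(self, max_shard_size_mb: int = 1024) -> None:
        self.max_size = max_shard_size_mb * 1024 ** 2
        self.current_block: "OrderedDict[str, torch.Tensor]" = OrderedDict()
        self.current_block_size = 0

    def append_param(self, name: str,
                     tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
        """Add one tensor; return the just-closed block (or None) and its size."""
        tensor_size = tensor.numel() * tensor.element_size()
        closed_block, closed_size = None, 0
        # Close the current block if the new tensor would push it over the cap.
        if self.current_block and self.current_block_size + tensor_size > self.max_size:
            closed_block, closed_size = self.current_block, self.current_block_size
            self.current_block, self.current_block_size = OrderedDict(), 0
        self.current_block[name] = tensor
        self.current_block_size += tensor_size
        return closed_block, closed_size

A generator like state_dict_shard then yields every non-None block to the caller that writes the shard files, and finally flushes whatever remains in current_block so the trailing partial shard is not dropped.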