[shardformer] support from_pretrained when loading model with HybridParallelPlugin (#4575)

* hybrid plugin support huggingface from_pretrained

* add huggingface compatibility tests

* add folder cleaning

* fix bugs
This commit is contained in:
Baizhou Zhang
2023-09-01 17:40:01 +08:00
committed by GitHub
parent c9625dbb63
commit 38ccb8b1a3
5 changed files with 218 additions and 17 deletions

View File

@@ -141,10 +141,10 @@ def get_param_info(optim: Optimizer):
def init_pipeline_optimizer(optim: Optimizer, model: Module):
    """Restrict every param group of ``optim`` to parameters owned by ``model``.

    Each entry of ``optim.param_groups`` is copied with all of its settings
    intact, but its ``'params'`` list keeps only the parameters that are
    yielded by ``model.parameters()``. The rebuilt groups are then installed
    on the optimizer in place through ``optim.__setstate__``.

    NOTE(review): presumably ``model`` is the local pipeline stage and holds
    only a subset of the parameters the optimizer was created with -- confirm
    against the caller.

    Args:
        optim: Optimizer whose ``param_groups`` are rewritten in place.
        model: Module whose parameters decide what is kept.

    Returns:
        None. ``optim`` is modified in place.
    """
    # Build the membership set once. It must keep its own name: rebinding it
    # inside the loop (as the pre-fix code did with ``params``) makes every
    # group after the first get filtered against the previous group's result.
    model_params = set(model.parameters())

    new_param_groups = []
    for group in optim.param_groups:
        kept_params = [p for p in group['params'] if p in model_params]
        # Copy the group so the other settings carry over unchanged.
        new_param_groups.append({**group, 'params': kept_params})

    optim.__setstate__({'param_groups': new_param_groups})