mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
[shardformer] support from_pretrained when loading model with HybridParallelPlugin (#4575)
* hybrid plugin support huggingface from_pretrained * add huggingface compatibility tests * add folder cleaning * fix bugs
This commit is contained in:
@@ -141,10 +141,10 @@ def get_param_info(optim: Optimizer):
|
||||
|
||||
|
||||
def init_pipeline_optimizer(optim: Optimizer, model: Module):
|
||||
params = set(model.parameters())
|
||||
model_params = set(model.parameters())
|
||||
new_param_groups = []
|
||||
for group in optim.param_groups:
|
||||
params = [p for p in group['params'] if p in params]
|
||||
params = [p for p in group['params'] if p in model_params]
|
||||
new_param_groups.append({**group, 'params': params})
|
||||
optim.__setstate__({'param_groups': new_param_groups})
|
||||
|
||||
|
Reference in New Issue
Block a user