Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-18 11:48:53 +00:00)
[plugin]fix 3d checkpoint load when booster boost without optimizer. (#5135)
* fix 3d checkpoint load when booster boost without optimizer
* test ci
* revert ci
* fix
Parent: f6731db67c
Commit: 2a2ec49aa7
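Context for the fix: booster.boost() is called with only a model (no optimizer), and a hybrid-parallel (3D) checkpoint is then loaded into it. A minimal usage sketch of that scenario, assuming a HybridParallelPlugin setup and a hypothetical sharded checkpoint directory ./ckpt; the placeholder model and sizes are illustrative only:

import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})  # distributed init (e.g. via torchrun); exact signature may differ across versions

plugin = HybridParallelPlugin(tp_size=1, pp_size=1)
booster = Booster(plugin=plugin)

model = nn.Linear(16, 16)  # placeholder model for illustration

# Boost with only the model, no optimizer -- the code path this commit fixes.
model, *_ = booster.boost(model)

# Load a previously saved hybrid-parallel checkpoint into the wrapped model.
booster.load_model(model, "./ckpt")  # hypothetical checkpoint path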
@@ -21,7 +21,7 @@ from torch.utils.data.distributed import DistributedSampler
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -42,7 +42,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
     return x
 
 
-class HybridParallelModule(ModelWrapper):
+class HybridParallelModule(ModelWrapper, AMPModelMixin):
     def __init__(
         self,
         module: Module,
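The functional change in the two hunks above is that HybridParallelModule now also inherits AMPModelMixin. A rough sketch of why that matters, under the assumption (not shown in this diff) that AMPModelMixin supplies a default update_master_params() hook which checkpoint loading can call whether or not a mixed-precision optimizer exists; the stub classes below are illustrative, not the library source:

class AMPModelMixinSketch:
    def update_master_params(self) -> None:
        # Default no-op; AMP-aware wrappers override this to refresh
        # fp32 master copies after new weights are loaded.
        pass

class ModelWrapperStub:
    # Stand-in for colossalai.interface.ModelWrapper in this sketch.
    def __init__(self, module):
        self.module = module

class HybridParallelModuleSketch(ModelWrapperStub, AMPModelMixinSketch):
    pass

# With the mixin in place, checkpoint I/O can call update_master_params()
# unconditionally, even when booster.boost(model) received no optimizer.
wrapped = HybridParallelModuleSketch(module=None)
wrapped.update_master_params()  # safe no-op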
@@ -116,6 +116,9 @@ def check_gemini_plugin(
         "transformers_falcon_for_sequence_classification",
         "transformers_falcon_for_token_classification",
         "transformers_falcon_for_question_answering",
+        "transformers_gptj_lm", # lead to OOM when running in ci
+        "transformers_gptj_for_question_answering",
+        "transformers_gptj_for_sequence_classification",
     ]:
         continue
 