Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[fp8] support gemini plugin (#5978)
* [fp8] refactor hook
* [fp8] support gemini plugin
* [example] add fp8 option for llama benchmark
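As a usage illustration of the change described in the commit message, below is a minimal sketch of enabling FP8 training with the Gemini plugin through the Booster API. The use_fp8 keyword (and the choice of HybridAdam) is an assumption inferred from the commit message, not a verified signature from this commit.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch()  # expects torchrun-style environment variables

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.GELU(),
    torch.nn.Linear(1024, 1024),
)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

# use_fp8 is a hypothetical flag name based on the commit message: it should
# route linear layers through FP8Hook during forward/backward.
plugin = GeminiPlugin(use_fp8=True)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

x = torch.randn(8, 1024, device="cuda")
loss = model(x).mean()
booster.backward(loss, optimizer)
optimizer.step()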
@@ -31,6 +31,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
@@ -40,7 +41,6 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
 
-from .fp8_hook import FP8Hook
 from .pp_plugin_base import PipelinePluginBase
 
 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
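The diff above relocates FP8Hook from the plugin package to colossalai.quantization.fp8_hook. For orientation, here is a rough sketch of how such a param-op hook is applied; the Gemini and hybrid-parallel plugins do this internally when FP8 is enabled, and the hook only takes effect for parameters routed through ColossalAI's param-op dispatch, so treat this as an approximation rather than a standalone recipe.

import torch
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.param_op_hook import ColoParamOpHookManager

model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(4, 1024, device="cuda")

# While the hook is active, eligible linear weights are cast to FP8 for the matmul.
# This only fires for parameters that ColossalAI dispatches through its hook
# manager (as set up by the plugins), not for plain torch.nn.Parameter objects.
with ColoParamOpHookManager.use_hooks(FP8Hook()):
    y = model(x)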