fix merge

wangbluo 2024-08-20 09:26:04 +00:00
commit 2d362ac090
9 changed files with 118 additions and 82 deletions

View File

@@ -1,4 +1,3 @@
-import warnings
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, Iterator, List, Optional, Union
@@ -8,6 +7,8 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
+from colossalai.logging import get_dist_logger
 SUPPORT_PEFT = False
 try:
     import peft
@@ -81,12 +82,15 @@ class Booster:
             plugin, Plugin
         ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
         self.plugin = plugin
+        self.logger = get_dist_logger()
         # set accelerator
         if self.plugin and self.plugin.control_device():
             self.accelerator = None
             if device is not None:
-                warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
+                self.logger.warning(
+                    "The plugin will control the accelerator," "so the device argument will be ignored.", ranks=[0]
+                )
         else:
             device = device or "cuda"
             self.accelerator = Accelerator(device)
@@ -94,7 +98,10 @@ class Booster:
         # set precision
         if self.plugin and self.plugin.control_precision():
             if mixed_precision is not None:
-                warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
+                self.logger.warning(
+                    "The plugin will control the precision," "so the mixed_precision argument will be ignored.",
+                    ranks=[0],
+                )
             self.mixed_precision = None
         elif mixed_precision is None:
             self.mixed_precision = None
@@ -267,8 +274,9 @@ class Booster:
            ), "Please provide pretrained directory path if not passing in lora configuration."
        if quantize is True:
            if bnb_quantization_config is not None:
-                warnings.warn(
-                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
+                self.logger.warning(
+                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk.",
+                    ranks=[0],
                )
            else:
                bnb_quantization_config = BnbQuantizationConfig(
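
The same substitution repeats through the files below: module-level warnings.warn / logging calls are replaced with a per-instance distributed logger so that messages are emitted on rank 0 only. A minimal sketch of the pattern, assuming only what the diff shows (get_dist_logger and logging methods that accept a ranks argument); the class and method names are placeholders, not ColossalAI code:

from colossalai.logging import get_dist_logger


class ExamplePlugin:  # placeholder name, for illustration only
    def __init__(self) -> None:
        # one logger per component
        self.logger = get_dist_logger()

    def configure(self, device=None):
        if device is not None:
            # ranks=[0] limits the message to the global rank-0 process,
            # unlike the old warnings.warn which printed on every rank
            self.logger.warning(
                "The plugin will control the accelerator, so the device argument will be ignored.",
                ranks=[0],
            )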

View File

@@ -1,5 +1,4 @@
 import gc
-import logging
 import os
 import random
 from pathlib import Path
@@ -27,6 +26,7 @@ from colossalai.checkpoint_io.utils import (
 )
 from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -63,6 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
         self.coordinator = DistCoordinator()
+        self.logger = get_dist_logger()
 
     def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
@@ -118,7 +119,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         if os.path.isfile(checkpoint_path):
-            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file", ranks=[0])
             return
 
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
@@ -143,10 +144,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
             save_config_file(model.unwrap(), checkpoint_path)
-            logging.info(
+            self.logger.info(
                 f"The model is split into checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}."
+                f"index located at {save_index_file}.",
+                ranks=[0],
             )
 
     def load_sharded_model(
@@ -168,7 +170,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return
 
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -201,10 +203,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
-            logging.info(
+            self.logger.info(
                 f"The optimizer is going to be split to checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}."
+                f"index located at {save_index_file}.",
+                ranks=[0],
             )
 
     def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
@@ -214,7 +217,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
         if not os.path.isfile(checkpoint_index_file):
-            logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
+            self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file", ranks=[0])
 
         assert isinstance(optimizer, GeminiOptimizer)
@@ -371,9 +374,12 @@ class GeminiPlugin(DPPluginBase):
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
         if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
+        self.logger = get_dist_logger()
         if enable_async_reduce and not pin_memory:
-            logging.warning(
-                f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set."
+            self.logger.warning(
+                f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set.",
+                ranks=[0],
             )
             pin_memory = True
         self.gemini_config = dict(
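
Note the behavioral tweak in the last hunk: when enable_async_reduce is requested without pin_memory, the plugin now logs a rank-0 warning and switches pin_memory on itself before building gemini_config. A rough usage sketch (constructor arguments other than those shown are left at their defaults and are not taken from this diff):

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# pin_memory is not passed, so the plugin emits the rank-0 warning above
# and forces pin_memory=True internally before assembling gemini_config.
plugin = GeminiPlugin(precision="bf16", placement_policy="static", enable_async_reduce=True)
booster = Booster(plugin=plugin)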

View File

@@ -1,6 +1,5 @@
 import ctypes
 import random
-import warnings
 from collections import defaultdict
 from contextlib import contextmanager, nullcontext
 from copy import deepcopy
@@ -27,6 +26,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1036,6 +1036,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         inner_ring_size: int = None,
     ) -> None:
         super().__init__()
+        self.logger = get_dist_logger()
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
@@ -1053,8 +1054,9 @@ class HybridParallelPlugin(PipelinePluginBase):
                    tp_size > 1
                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
                if sp_size != 1:
-                    warnings.warn(
-                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+                    self.logger.warning(
+                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size.",
+                        ranks=[0],
                    )
                self.sp_size = 1
                self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@@ -1143,7 +1145,12 @@ class HybridParallelPlugin(PipelinePluginBase):
         else:
             raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
-            assert parallel_output, "Ring Attention doesn't support gathering output yet."
+            if not parallel_output:
+                self.logger.warning(
+                    "parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet.",
+                    ranks=[0],
+                )
+                parallel_output = True
 
         self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
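
The ring_attn hunk above downgrades the hard assertion to a rank-0 warning plus an override: parallel_output is forced to True instead of raising. From a caller's perspective the change looks roughly like this sketch (it assumes a torchrun-style environment already initialized via colossalai.launch_from_torch, and the parallel sizes shown are illustrative):

import colossalai
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumes torchrun-style environment variables

# Previously this configuration tripped
#   assert parallel_output, "Ring Attention doesn't support gathering output yet."
# Now it logs a warning on rank 0 and proceeds with parallel_output forced to True.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    sp_size=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="ring_attn",
    parallel_output=False,
)
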
@@ -1249,7 +1256,10 @@ class HybridParallelPlugin(PipelinePluginBase):
             optimizer = cast_to_distributed(optimizer)
 
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
-            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            self.logger.warning(
+                "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
+                ranks=[0],
+            )
             zero_config["partition_grad"] = False
             zero_stage = 0
@@ -1306,9 +1316,10 @@ class HybridParallelPlugin(PipelinePluginBase):
             else:
                 is_zero = self.dp_size > 1
                 if self.dp_size == 1:
-                    warnings.warn(
+                    self.logger.warning(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
-                        "If you do not intend to use cpu_offload, please consider set zero_stage=0."
+                        "If you do not intend to use cpu_offload, please consider set zero_stage=0.",
+                        ranks=[0],
                     )
                 assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
@@ -1351,7 +1362,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
 
         if return_outputs:
-            warnings.warn("return_outputs may lead to significant extra memory consumption.")
+            self.logger.warning("return_outputs may lead to significant extra memory consumption.", ranks=[0])
 
         # Create a context for gradient synchronization based on the optimizer type.
         # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
@@ -1365,10 +1376,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         )
 
         # run with gradients accumulation
-        if (
-            model.require_grad_sync == False
-            or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
-            or not torch.is_grad_enabled()
+        if model.require_grad_sync == False or (
+            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
         ):
             return outputs
@@ -1468,7 +1477,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
         assert self.pp_size == 1 and self.tp_size == 1
         self.lora_enabled = True
-        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+        self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])
         if bnb_quantization_config is not None:
             model = quantize_model(model, bnb_quantization_config)

View File

@@ -1,7 +1,5 @@
 import enum
-import logging
 import os
-import warnings
 from contextlib import nullcontext
 from functools import partial
 from pathlib import Path
@@ -33,6 +31,7 @@ from colossalai.checkpoint_io.utils import (
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.quantization.fp8_hook import FP8Hook
@@ -64,12 +63,7 @@ class OptimizerParamCheckState(enum.Enum):
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(
-        self,
-        module: nn.Module,
-        precision: str,
-        overlap_allgather: bool = False,
-        cast_inputs: bool = True,
-        use_fp8: bool = False,
+        self, module: nn.Module, precision: str, overlap_allgather: bool = False, use_fp8: bool = False
     ) -> None:
         super().__init__(module)
         self.dtype = None
@@ -87,8 +81,6 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
         self.op_hooks = []
-        if self.dtype is not None and cast_inputs:
-            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         if overlap_allgather:
             self.op_hooks.append(ZeroOpHook())
         if use_fp8:
@@ -153,7 +145,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         """
         assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return
 
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -190,10 +182,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         index_file.append_meta_data("total_size", total_size)
         if self.coordinator.is_master():
             index_file.write_index_file(save_index_file)
-        logging.info(
+        self.logger.info(
             f"The optimizer is going to be split to checkpoint shards. "
             f"You can find where each parameters has been saved in the "
-            f"index located at {save_index_file}."
+            f"index located at {save_index_file}.",
+            ranks=[0],
         )
 
     def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
@@ -280,7 +273,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
     def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return
         from peft import PeftModel
@@ -349,7 +342,6 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
-        cast_inputs: bool = True,
         fp8_communication: bool = False,
         use_fp8: bool = False,
     ) -> None:
@@ -379,9 +371,8 @@ class LowLevelZeroPlugin(DPPluginBase):
         )
         self.lora_enabled = False
         self.verbose = verbose
+        self.logger = get_dist_logger()
         self.use_fp8 = use_fp8
-        self.cast_inputs = cast_inputs
 
         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@@ -417,7 +408,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
         self.lora_enabled = True
-        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+        self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])
         if bnb_quantization_config is not None:
             model = quantize_model(model, bnb_quantization_config)
@@ -466,8 +457,9 @@ class LowLevelZeroPlugin(DPPluginBase):
                 origin_param = name2param[origin_key]
                 group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
                 if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
-                    warnings.warn(
-                        f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
+                    self.logger.warning(
+                        f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.",
+                        ranks=[0],
                     )
                 elif (
                     check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
@@ -498,7 +490,6 @@ class LowLevelZeroPlugin(DPPluginBase):
                 model,
                 self.precision,
                 overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
-                cast_inputs=self.cast_inputs,
                 use_fp8=self.use_fp8,
             )
@@ -511,7 +502,10 @@ class LowLevelZeroPlugin(DPPluginBase):
             optimizer = cast_to_distributed(optimizer)
 
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
-            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            self.logger.warning(
+                "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
+                ranks=[0],
+            )
             zero_optim_kwargs["partition_grad"] = False
             zero_stage = 0
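
Besides the logging change, these hunks remove the cast_inputs switch from both LowLevelZeroModel and LowLevelZeroPlugin, so whether inputs are cast now follows solely from the configured precision. A rough usage sketch under that assumption (arguments not shown keep their defaults):

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# cast_inputs is no longer accepted here; with fp16/bf16 precision the wrapped
# model converts floating-point inputs via convert_fn unconditionally.
plugin = LowLevelZeroPlugin(stage=1, precision="bf16")
booster = Booster(plugin=plugin)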

View File

@@ -1,4 +1,3 @@
-import warnings
 from collections import defaultdict
 from types import MethodType
 from typing import Callable, List, Optional, OrderedDict, Tuple
@@ -26,6 +25,7 @@ from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import cast_to_distributed
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
@@ -217,12 +217,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         fp8_communication: bool = False,
         use_fp8: bool = False,
     ) -> None:
+        self.logger = get_dist_logger()
         if overlap_communication or zero_stage == 2:
             overlap_communication = False
             zero_stage = 1
-            warnings.warn(
+            self.logger.warning(
                 f"overlap_communication and zero_stage are set to False and 1 because "
-                f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. "
+                f"ZeRO-2 or comm overlap cause program hang when some experts are not routed.",
+                ranks=[0],
             )
 
         assert (
@@ -240,8 +242,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                    tp_size > 1
                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
                if sp_size != 1:
-                    warnings.warn(
-                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+                    self.logger.warning(
+                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode},"
+                        "will ignore the given sequence parallelism size.",
+                        ranks=[0],
                    )
                self.sp_size = 1
                self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@@ -404,8 +408,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
            and self.sequence_parallelism_mode == "all_to_all"
        )
        if use_ddp:
-            warnings.warn(
-                f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
+            self.logger.warning(
+                f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
+                ranks=[0],
            )
            self.ddp_config["find_unused_parameters"] = True
@@ -462,9 +467,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
            )
        else:
            if self.dp_size <= 1:
-                warnings.warn(
+                self.logger.warning(
                    "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
-                    "If you do not intend to use cpu_offload, please consider set zero_stage=0."
+                    "If you do not intend to use cpu_offload, please consider set zero_stage=0.",
+                    ranks=[0],
                )
            assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
            optimizer = MoeHybridParallelZeroOptimizer(

View File

@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.utils import get_current_device
@@ -21,6 +22,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
         self.coordinator = DistCoordinator()
+        self.logger = get_dist_logger()
 
     def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
         """

View File

@@ -1,6 +1,4 @@
-import logging
 import os
-import warnings
 from pathlib import Path
 from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
@@ -30,6 +28,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 
 from .dp_plugin_base import DPPluginBase
@@ -40,6 +39,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
         self.coordinator = DistCoordinator()
+        self.logger = get_dist_logger()
 
     def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
         assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
@@ -88,7 +88,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         """
         assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
         if os.path.isfile(checkpoint_path):
-            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
             return
 
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
@@ -117,7 +117,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
         utils.save_config_file(model.unwrap(), checkpoint_path)
-        logging.info(
+        self.logger.info(
             f"The model is split into checkpoint shards. "
             f"You can find where each parameters has been saved in the "
             f"index located at {save_index_file}."
@@ -162,7 +162,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
 
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -200,7 +200,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
-        logging.info(
+        self.logger.info(
             f"The optimizer is going to be split to checkpoint shards. "
             f"You can find where each parameters has been saved in the "
             f"index located at {save_index_file}."
@@ -313,6 +313,7 @@ class TorchFSDPPlugin(DPPluginBase):
                 sync_module_states=sync_module_states,
             )
             self.fp8_communication = fp8_communication
+            self.logger = get_dist_logger()
         else:
             raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -364,7 +365,7 @@ class TorchFSDPPlugin(DPPluginBase):
         if optimizer is not None:
             if len(optimizer.param_groups) > 1:
-                warnings.warn(
+                self.logger.warning(
                     "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used."
                 )
             optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)

View File

@@ -694,6 +694,13 @@ class RingAttention(torch.autograd.Function):
             )
             return out, softmax_lse, rng_state
 
+        def _kv_comm(i):
+            # Avoid overwriting attn input when it shares mem with buffer
+            if not RingAttention.ATTN_DONE.query():
+                kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
+            if i < local_sp_size - 1:
+                local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+
         def _local_ring_forward():
             # (Hopefully) overlap output correction with next flash attn
             for i in range(local_sp_size):
@@ -702,12 +709,8 @@ class RingAttention(torch.autograd.Function):
                     # NOTE: waiting outside the current stream will NOT correctly synchronize.
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
-                    # Avoid overwriting attn input when it shares mem with buffer
-                    if not RingAttention.ATTN_DONE.query():
-                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
-                    if i < local_sp_size - 1:
-                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+                    if i == 0:
+                        _kv_comm(i)
 
                     if i == 0:
                         # Compute with local KV; no mask
@@ -738,6 +741,9 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i],
                         ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                     RingAttention.ATTN_DONE.record()
+                    # Pipeline the next KV comm with output correction instead of the next flash attn
+                    # to minimize idle time when comm takes longer than attn.
+                    _kv_comm(i + 1)
 
                     block_softmax_lse[i % 2] = (
                         block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
@@ -765,15 +771,13 @@ class RingAttention(torch.autograd.Function):
             # all new KVs from the previous inner ring
             for i in range(local_sp_size):
                 with torch.cuda.stream(sp_streams[i % 2]):
-                    if not RingAttention.ATTN_DONE.query():
-                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
-                    if i < local_sp_size - 1:
-                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
                     # Send & recv KV
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
+                    if i == 0:
+                        _kv_comm(i)
 
                     if ring_num_idx > inter_ring_rank:
                         kv_block = kv_buffers[i % 2]
                         (
@@ -782,6 +786,8 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i + local_sp_size * ring_num_idx],
                         ) = _forward(q1, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+                        _kv_comm(i + 1)
 
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                         )
@@ -796,6 +802,8 @@ class RingAttention(torch.autograd.Function):
                             rng_states[i + local_sp_size * ring_num_idx],
                         ) = _forward(q, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+                        _kv_comm(i + 1)
 
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                         )
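
Taken together, the RingAttention hunks hoist the double-buffered KV exchange into the _kv_comm helper and call it right after each attention block is recorded, so the next send/recv overlaps with the output (softmax LSE) correction instead of with the next flash-attention call. Stripped of the attention math, CUDA streams and the real P2P handles, the schedule looks roughly like the sketch below (FakeComm, attn_block and correct_output are placeholders, not ColossalAI APIs):

class FakeComm:
    """Synchronous stand-in for an async send/recv handle pair."""

    def __init__(self):
        self.pending = None

    def send_recv(self, send_buf, recv_buf):
        # the real code posts an asynchronous P2P exchange here
        self.pending = (send_buf, recv_buf)

    def wait(self):
        # the real code blocks until the posted exchange has completed
        self.pending = None


def ring_forward(num_steps, kv_buffers, comms, attn_block, correct_output):
    def kv_comm(i):
        # post the exchange that produces the KV block used at step i + 1
        if i < num_steps - 1:
            comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

    outputs = []
    for i in range(num_steps):
        if i > 0:
            comms[(i + 1) % 2].wait()        # KV block for this step has arrived
        if i == 0:
            kv_comm(i)                       # first exchange, posted once
        out = attn_block(kv_buffers[i % 2])  # flash-attention on the current block
        kv_comm(i + 1)                       # next exchange overlaps with the
        correct_output(outputs, out)         # correction below, not with attn
    return outputs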

View File

@@ -1,7 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
 import math
-import warnings
 from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
 
 import torch
@@ -136,7 +135,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
         self.verbose = verbose
         self.param_groups_backup = list()
+        self.logger = get_dist_logger()
         # Mapping from integer id to real/fake param tensor, used for checkpointing.
         self.id_to_real_params: Dict[int, Parameter] = dict()
         self.id_to_fake_params: Dict[int, Parameter] = dict()
@@ -148,9 +147,10 @@ class GeminiOptimizer(OptimizerWrapper):
         for name, param in module.named_parameters():
             if is_ddp_ignored(param):
                 if param.requires_grad:
-                    warnings.warn(
+                    self.logger.warning(
                         f"Parameter `{name}` is ignored by DDP but requires gradient! "
-                        "You should handle its optimizer update by yourself!"
+                        "You should handle its optimizer update by yourself!",
+                        ranks=[0],
                     )
                 else:
                     ddp_param_list.append(param)
@@ -842,7 +842,9 @@ class GeminiOptimizer(OptimizerWrapper):
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm")
+        self.logger.warning(
+            f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
+        )
 
 class GeminiAdamOptimizer(GeminiOptimizer):