Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
[misc] Use dist logger in plugins (#6011)
* use dist logger in plugins

* remove trash

* print on rank 0

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
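The change replaces module-level logging and warnings calls in the low-level ZeRO plugin with ColossalAI's distributed logger, so each message is printed once on rank 0 instead of on every process. A minimal sketch of the pattern, assuming ColossalAI is installed and the distributed environment has already been initialized (e.g. via colossalai.launch); the message strings below are illustrative, not taken verbatim from the plugin:

# Sketch of the dist-logger pattern adopted by this commit
# (assumes torch.distributed has been initialized, e.g. through colossalai.launch).
from colossalai.logging import get_dist_logger

logger = get_dist_logger()

# ranks=[0] limits output to rank 0, replacing plain logging.info(...) calls
# that previously printed on every process.
logger.info("optimizer checkpoint shards written", ranks=[0])

# warnings.warn(...) calls become rank-0 warnings in the same way.
logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])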
@@ -1,7 +1,5 @@
 import enum
-import logging
 import os
-import warnings
 from contextlib import nullcontext
 from functools import partial
 from pathlib import Path
@@ -33,6 +31,7 @@ from colossalai.checkpoint_io.utils import (
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.tensor.colo_parameter import ColoParameter
@@ -62,9 +61,7 @@ class OptimizerParamCheckState(enum.Enum):


 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(
-        self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True
-    ) -> None:
+    def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -76,7 +73,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         module = module.to(get_accelerator().get_current_device())
         self.module = module
         self.convert_fn = None
-        if self.dtype is not None and cast_inputs:
+        if self.dtype is not None:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
         if overlap_allgather:
@@ -140,7 +137,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         """
         assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return

         Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -177,10 +174,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         index_file.append_meta_data("total_size", total_size)
         if self.coordinator.is_master():
             index_file.write_index_file(save_index_file)
-        logging.info(
+        self.logger.info(
             f"The optimizer is going to be split to checkpoint shards. "
             f"You can find where each parameters has been saved in the "
-            f"index located at {save_index_file}."
+            f"index located at {save_index_file}.",
+            ranks=[0],
         )

     def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
@@ -267,7 +265,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

     def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return
         from peft import PeftModel

@@ -336,7 +334,6 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
-        cast_inputs: bool = True,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -363,8 +360,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         )
         self.lora_enabled = False
         self.verbose = verbose
-        self.cast_inputs = cast_inputs
-
+        self.logger = get_dist_logger()
         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")

@@ -400,7 +396,7 @@ class LowLevelZeroPlugin(DPPluginBase):

         assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
         self.lora_enabled = True
-        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+        self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])

         if bnb_quantization_config is not None:
             model = quantize_model(model, bnb_quantization_config)
@@ -449,8 +445,9 @@ class LowLevelZeroPlugin(DPPluginBase):
                 origin_param = name2param[origin_key]
                 group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
                 if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
-                    warnings.warn(
-                        f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
+                    self.logger.warning(
+                        f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.",
+                        ranks=[0],
                     )
                 elif (
                     check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
@@ -478,10 +475,7 @@ class LowLevelZeroPlugin(DPPluginBase):

         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(
-                model,
-                self.precision,
-                overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
-                cast_inputs=self.cast_inputs,
+                model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
             )

         # TODO: Support Galore + ZeRO
@@ -493,7 +487,10 @@ class LowLevelZeroPlugin(DPPluginBase):
         optimizer = cast_to_distributed(optimizer)

         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
-            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            self.logger.warning(
+                "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
+                ranks=[0],
+            )
             zero_optim_kwargs["partition_grad"] = False
             zero_stage = 0
