diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index b05cb660b..89b7f1f3b 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -166,6 +166,7 @@ jobs: LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 LLAMA_PATH: /data/scratch/llama-tiny MOE_TENSOR_PATH: /data/scratch/moe_tensors + HF_ENDPOINT: https://hf-mirror.com - name: Collate artifact env: diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index f8ca07d97..fd7dc42e5 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -70,6 +70,7 @@ jobs: LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 LLAMA_PATH: /data/scratch/llama-tiny MOE_TENSOR_PATH: /data/scratch/moe_tensors + HF_ENDPOINT: https://hf-mirror.com - name: Notify Lark id: message-preparation diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml index c56b6211d..1534fa7f6 100644 --- a/.github/workflows/compatiblity_test_on_dispatch.yml +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -79,3 +79,4 @@ jobs: LD_LIBRARY_PATH: /github/home/.tensornvme/lib LLAMA_PATH: /data/scratch/llama-tiny MOE_TENSOR_PATH: /data/scratch/moe_tensors + HF_ENDPOINT: https://hf-mirror.com diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index 68fb3a090..c2cc85b3f 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -73,3 +73,4 @@ jobs: LD_LIBRARY_PATH: /github/home/.tensornvme/lib LLAMA_PATH: /data/scratch/llama-tiny MOE_TENSOR_PATH: /data/scratch/moe_tensors + HF_ENDPOINT: https://hf-mirror.com diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml index 9e6265b1b..1bd24b0a2 100644 --- a/.github/workflows/compatiblity_test_on_schedule.yml +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -67,6 +67,7 @@ jobs: LD_LIBRARY_PATH: /github/home/.tensornvme/lib LLAMA_PATH: /data/scratch/llama-tiny MOE_TENSOR_PATH: /data/scratch/moe_tensors + HF_ENDPOINT: https://hf-mirror.com - name: Notify Lark id: message-preparation diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 62046bc36..1684fd702 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, T import numpy as np import torch import torch.distributed as dist +from peft import PeftModel from torch import Tensor, inf from torch.distributed import ProcessGroup, get_world_size from torch.nn import Module, SyncBatchNorm @@ -219,11 +220,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): with self._hook_context(): return super().forward(*args, **kwargs) - def unwrap(self): - module = super().unwrap() - if isinstance(module, DDP): - module = module.module - return module + def unwrap(self, unwrap_peft: bool = True): + model = self.module + if isinstance(model, DDP): + model = model.module + if unwrap_peft and isinstance(model, PeftModel): + model = model.get_base_model() + return model def _force_wait_all_gather(self): for p in self.module.parameters(): @@ -1509,7 +1512,7 @@ class 
HybridParallelPlugin(PipelinePluginBase): from peft import PeftModel, get_peft_model assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." - assert self.pp_size == 1 and self.tp_size == 1 + assert self.tp_size == 1 self.lora_enabled = True self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0]) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index d29098a6e..642969be3 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -359,23 +359,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async ) - def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): - if os.path.isfile(checkpoint): - self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0]) - return - from peft import PeftModel - - assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None): + assert isinstance(model, LowLevelZeroModel), "Please boost the model before saving!" model._force_wait_all_gather() - peft_model = model.unwrap() - assert isinstance( - peft_model, PeftModel - ), "The model doesn't have lora adapters, please enable lora before saving." - return peft_model.save_pretrained( - checkpoint, - safe_serialization=use_safetensors, - state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()), - ) + super().save_lora_as_pretrained(model, checkpoint, use_safetensors, state_dict=state_dict) class LowLevelZeroPlugin(DPPluginBase): diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index acec7e82d..e74b1a959 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -2,6 +2,7 @@ from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union import torch import torch.nn as nn +from peft import PeftModel from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler @@ -166,7 +167,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): ) def save_lora_as_pretrained( - self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False + self, + model: Union[nn.Module, ModelWrapper], + checkpoint: str, + use_safetensors: bool = False, + state_dict: Optional[dict] = None, ) -> None: """ Save the lora adapters and adapter configuration file to checkpoint directory. @@ -174,15 +179,17 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): from peft import PeftModel assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + peft_model = model.unwrap(unwrap_peft=False) + assert isinstance( + peft_model, PeftModel + ), "The model doesn't have lora adapters, please enable lora before saving." + if state_dict is None: + state_dict = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, peft_model.state_dict()) if self.coordinator.is_master(): - peft_model = model.unwrap() - assert isinstance( - peft_model, PeftModel - ), "The model doesn't have lora adapters, please enable lora before saving." 
return peft_model.save_pretrained( checkpoint, safe_serialization=use_safetensors, - state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()), + state_dict=state_dict, ) @@ -191,8 +198,11 @@ class TorchDDPModel(ModelWrapper): super().__init__(module) self.module = DDP(module, *args, **kwargs) - def unwrap(self): - return self.module.module + def unwrap(self, unwrap_peft: bool = True) -> nn.Module: + model = self.module.module + if unwrap_peft and isinstance(model, PeftModel): + model = model.get_base_model() + return model class TorchDDPPlugin(DPPluginBase): diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index dca7d43c0..d713203fe 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -437,9 +437,6 @@ class TorchFSDPModel(ModelWrapper): super().__init__(module) self.module = FSDP(module, *args, **kwargs) - def unwrap(self): - return self.module - class FSDPOptimizerWrapper(OptimizerWrapper): def __init__(self, optimizer: Optimizer, model: nn.Module): diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 40024f8a8..da57ee829 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -437,7 +437,11 @@ class CheckpointIO(ABC): @abstractmethod def save_lora_as_pretrained( - self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False + self, + model: Union[nn.Module, ModelWrapper], + checkpoint: str, + use_safetensors: bool = False, + state_dict: Optional[dict] = None, ) -> None: """ Save the lora adapters and adapter configuration file to a pretrained checkpoint directory. @@ -446,4 +450,5 @@ class CheckpointIO(ABC): model (Union[nn.Module, ModelWrapper]): A model boosted by Booster. checkpoint (str): Path to the checkpoint directory. It must be a local path. use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False. + state_dict (Optional[dict], optional): The state dict to save. Defaults to None. 
""" diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index c38958ee3..3e600c94d 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -308,5 +308,7 @@ class GeneralCheckpointIO(CheckpointIO): ) ) - def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None: + def save_lora_as_pretrained( + self, model: nn.Module, checkpoint: str, use_safetensors: bool = False, state_dict: Optional[dict] = None + ) -> None: raise NotImplementedError diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index bd814f426..5de32e666 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -33,6 +33,8 @@ from .utils import ( async_save_state_dict_shards, create_pinned_state_dict, gather_distributed_param, + gather_state_dict_fast, + get_lora_state_dict, get_model_base_filenames, get_optimizer_base_filenames, is_safetensors_available, @@ -1137,7 +1139,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): return state_ - def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): + def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None): if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return @@ -1145,12 +1147,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): assert isinstance(model, ModelWrapper), "Please boost the model before saving!" model._force_wait_all_gather() - peft_model = model.unwrap() + peft_model = model.unwrap(unwrap_peft=False) assert isinstance( peft_model, PeftModel ), "The model doesn't have lora adapters, please enable lora before saving." 
- return peft_model.save_pretrained( - checkpoint, - safe_serialization=use_safetensors, - state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()), - ) + if state_dict is None: + state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()) + if self.pp_size > 1: + lora_state_dict = get_lora_state_dict(peft_model, state_dict) + gathered_lora_state_dict = gather_state_dict_fast(lora_state_dict, self.pp_group, device="cpu") + if self.pp_rank == 0: + state_dict.update(gathered_lora_state_dict) + state_dict = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict) + if self.coordinator.is_master(): + return peft_model.save_pretrained( + checkpoint, + safe_serialization=use_safetensors, + state_dict=state_dict, + ) diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py index 04655dec5..586c7863f 100644 --- a/colossalai/checkpoint_io/moe_checkpoint.py +++ b/colossalai/checkpoint_io/moe_checkpoint.py @@ -10,6 +10,7 @@ import torch.distributed as dist import torch.nn as nn from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import get_global_rank +from torch.utils._pytree import tree_map from colossalai.checkpoint_io import CheckpointIndexFile from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO @@ -17,6 +18,8 @@ from colossalai.checkpoint_io.index_file import CheckpointIndexFile from colossalai.checkpoint_io.utils import ( StateDictSharder, gather_distributed_param, + gather_state_dict_fast, + get_lora_state_dict, get_model_base_filenames, get_optimizer_base_filenames, load_shard_state_dict, @@ -889,3 +892,26 @@ class MoECheckpointIO(HybridParallelCheckpointIO): optimizer.optim.state[param] = sharded_state sharded_optimizer_loading_epilogue(optimizer.optim) dist.barrier() + + def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict=None): + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + from peft import PeftModel + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model._force_wait_all_gather() + peft_model = model.unwrap(unwrap_peft=False) + assert isinstance( + peft_model, PeftModel + ), "The model doesn't have lora adapters, please enable lora before saving." 
+ if state_dict is None: + state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()) + if self.ep_size > 1: + lora_state_dict = get_lora_state_dict(peft_model, state_dict) + moe_params = set(n for n, p in peft_model.named_parameters() if is_moe_tensor(p)) + expert_state_dict = {n: p for n, p in lora_state_dict.items() if n in moe_params} + gathered_expert_state_dict = gather_state_dict_fast(expert_state_dict, self.ep_group) + if self.ep_rank == 0: + state_dict.update(gathered_expert_state_dict) + return super().save_lora_as_pretrained(model, checkpoint, use_safetensors, state_dict) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 50b6f1438..f4a050647 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -2,6 +2,7 @@ import concurrent.futures import os import re +import warnings from collections import abc as container_abcs from collections import defaultdict from itertools import chain @@ -9,8 +10,12 @@ from pathlib import Path from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union import torch +import torch.distributed as dist import torch.nn as nn from packaging.version import Version +from peft import PeftModel, PeftType +from peft.utils.other import EMBEDDING_LAYER_NAMES, check_file_exists_on_hf_hub +from peft.utils.save_and_load import get_embedding_layer_name, has_valid_embedding_base_layer from torch.optim import Optimizer from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -21,6 +26,7 @@ from colossalai.tensor.d_tensor import ( to_global, to_global_for_customized_distributed_tensor, ) +from colossalai.utils import get_current_device from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat SAFE_WEIGHTS_NAME = "model.safetensors" @@ -1004,3 +1010,138 @@ def load_state_dict_shards( futures.append(future) for future in concurrent.futures.as_completed(futures): yield future.result() + + +# adapted from `peft/utils/save_and_load.py` +def get_lora_state_dict( + model: PeftModel, state_dict: dict, adapter_name="default", save_embedding_layers="auto" +) -> dict: + config = model.peft_config[adapter_name] + if config.peft_type != PeftType.LORA: + raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.") + # to_return = lora_state_dict(model, bias=model.peft_config.bias) + # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py` + # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP + bias = config.bias + if bias == "none": + to_return = {k: state_dict[k] for k in state_dict if "lora_" in k} + elif bias == "all": + to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + for k in state_dict: + if "lora_" in k: + to_return[k] = state_dict[k] + bias_name = k.split("lora_")[0] + "bias" + if bias_name in state_dict: + to_return[bias_name] = state_dict[bias_name] + else: + raise NotImplementedError + to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))} + if config.use_dora: + # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a + # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since + # we want the state_dict format not to change, we remove the "weight" part. 
+ new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight" + + def renamed_dora_weights(k): + if k.endswith(new_dora_suffix): + k = k[:-7] # remove ".weight" + return k + + to_return = {renamed_dora_weights(k): v for k, v in to_return.items()} + + # DEAL WITH EMBEDDINGS + # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary + is_embedding_in_target_modules = False + if ( + save_embedding_layers == "auto" + and hasattr(config, "target_modules") + and any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES) + ): + warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.") + save_embedding_layers = is_embedding_in_target_modules = True + elif save_embedding_layers == "auto": + vocab_size = getattr(getattr(model, "config", None), "vocab_size", None) + model_id = getattr(config, "base_model_name_or_path", None) + + # For some models e.g. diffusers the text config file is stored in a subfolder + # we need to make sure we can download that config. + has_base_config = False + + # ensure that this check is not performed in HF offline mode, see #1452 + if model_id is not None: + local_config_exists = os.path.exists(os.path.join(model_id, "config.json")) + exists = local_config_exists or check_file_exists_on_hf_hub(model_id, "config.json") + if exists is None: + # check failed, could not determine if it exists or not + warnings.warn( + f"Could not find a config file in {model_id} - will assume that the vocabulary was not modified." + ) + has_base_config = False + else: + has_base_config = exists + + # check if the vocab size of the base model is different from the vocab size of the finetuned model + if ( + vocab_size + and model_id + and has_base_config + and (vocab_size != model.config.__class__.from_pretrained(model_id).vocab_size) + ): + warnings.warn( + "Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning." 
+ ) + save_embedding_layers = True + else: + save_embedding_layers = False + + if save_embedding_layers and hasattr(model, "get_input_embeddings"): + for layer in [model.get_input_embeddings(), model.get_output_embeddings()]: + if not is_embedding_in_target_modules or has_valid_embedding_base_layer(layer): + # support from version >= 0.6.2 + embedding_module_name = get_embedding_layer_name(model, layer, is_embedding_in_target_modules) + if embedding_module_name: + to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k}) + elif save_embedding_layers: + warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.") + + return to_return + + +def gather_state_dict_fast( + state_dict: Dict[str, torch.Tensor], + group: dist.ProcessGroup, + device: Optional[Union[torch.device, str]] = None, + dst: int = 0, +) -> Optional[Dict[str, torch.Tensor]]: + if device is None: + device = get_current_device() + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + metadata = [(k, v.shape, v.dtype) for k, v in state_dict.items()] + all_meta_data = [None] * world_size + if rank == dst: + returned_state_dict = state_dict.copy() + dist.gather_object(metadata, all_meta_data, dst=dist.get_global_rank(group, rank), group=group) + for i, target_metadata in enumerate(all_meta_data): + if i == dst: + continue + ops = [] + for k, shape, dtype in target_metadata: + buffer = torch.empty(shape, dtype=dtype, device=get_current_device()) + returned_state_dict[k] = buffer + ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group)) + reqs = dist.batch_isend_irecv(ops) + for req, (k, *_) in zip(reqs, target_metadata): + req.wait() + returned_state_dict[k] = returned_state_dict[k].to(device) + return returned_state_dict + else: + dist.gather_object(metadata, dst=dist.get_global_rank(group, dst), group=group) + ops = [] + for k, *_ in metadata: + ops.append(dist.P2POp(dist.isend, state_dict[k], dist.get_global_rank(group, dst), group)) + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py index 58df09b85..d112c2723 100644 --- a/colossalai/interface/model.py +++ b/colossalai/interface/model.py @@ -1,4 +1,5 @@ import torch.nn as nn +from peft import PeftModel class ModelWrapper(nn.Module): @@ -13,13 +14,17 @@ class ModelWrapper(nn.Module): super().__init__() self.module = module - def unwrap(self): + def unwrap(self, unwrap_peft: bool = True): """ Unwrap the model to return the original model for checkpoint saving/loading. 
""" if isinstance(self.module, ModelWrapper): - return self.module.unwrap() - return self.module + model = self.module.unwrap() + else: + model = self.module + if unwrap_peft and isinstance(model, PeftModel): + model = model.get_base_model() + return model def forward(self, *args, **kwargs): return self.module(*args, **kwargs) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 8ae8a516f..883ae5f66 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -156,7 +156,9 @@ def _check_for_nccl_hccl_backend(group): while isinstance(pg, c10d._ProcessGroupWrapper): pg = pg.wrapped_pg - return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and pg.name() == c10d.Backend.NCCL + return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and ( + pg.name() == c10d.Backend.NCCL or pg.name() == c10d.Backend.HCCL + ) def _check_device(group): diff --git a/colossalai/shardformer/modeling/deepseek_v3.py b/colossalai/shardformer/modeling/deepseek_v3.py index 5d8031d5c..7515c2064 100644 --- a/colossalai/shardformer/modeling/deepseek_v3.py +++ b/colossalai/shardformer/modeling/deepseek_v3.py @@ -4,9 +4,10 @@ import numpy as np import torch import torch.distributed as dist from torch.distributed import ProcessGroup +from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from colossalai.lazy import LazyInitContext from colossalai.moe._operation import ( @@ -16,6 +17,7 @@ from colossalai.moe._operation import ( EPGradScalerOut, all_to_all_uneven, ) +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer.linear import ParallelModule from colossalai.shardformer.shard.utils import set_tensors_to_none from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group @@ -167,6 +169,9 @@ def deepseek_v3_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + hidden_states_internal: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -203,8 +208,11 @@ def deepseek_v3_model_forward( ) position_ids = position_ids.unsqueeze(0) - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + if stage_manager is None or stage_manager.is_first_stage(): + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + else: + inputs_embeds = hidden_states_internal if self._use_flash_attention_2: # 2d mask is passed through the layers @@ -226,7 +234,11 @@ def deepseek_v3_model_forward( all_self_attns = () if output_attentions else None next_decoder_cache = None - for i, decoder_layer in enumerate(self.layers): + if stage_index is not None: + start_idx, end_idx = stage_index + else: + start_idx, end_idx = 0, len(self.layers) + for i, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -258,7 +270,8 @@ def deepseek_v3_model_forward( if output_attentions: all_self_attns += 
(layer_outputs[1],) - hidden_states = self.norm(hidden_states) + if stage_manager is None or stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: @@ -267,6 +280,10 @@ def deepseek_v3_model_forward( next_cache = None if use_cache: next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if stage_manager is not None and not stage_manager.is_last_stage(): + return { + "hidden_states_internal": hidden_states, + } if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -275,3 +292,94 @@ def deepseek_v3_model_forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) + + +def deepseek_v3_for_causal_lm_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + hidden_states_internal: Optional[torch.Tensor] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM + >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = deepseek_v3_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + stage_index=stage_index, + hidden_states_internal=hidden_states_internal, + ) + if stage_manager is not None and not stage_manager.is_last_stage(): + return outputs + + hidden_states = outputs[0] + + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/colossalai/shardformer/policies/deepseek_v3.py b/colossalai/shardformer/policies/deepseek_v3.py index 58c1243d1..8b386bcaa 100644 --- a/colossalai/shardformer/policies/deepseek_v3.py +++ b/colossalai/shardformer/policies/deepseek_v3.py @@ -1,9 +1,14 @@ -from typing import Dict, Union +from functools import partial +from typing import Callable, Dict, List, Union import torch.nn as nn from colossalai.shardformer.layer import FusedRMSNorm -from colossalai.shardformer.modeling.deepseek_v3 import EpDeepseekV3MoE +from colossalai.shardformer.modeling.deepseek_v3 import ( + EpDeepseekV3MoE, + deepseek_v3_for_causal_lm_forward, + deepseek_v3_model_forward, +) from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"] @@ -12,8 +17,9 @@ __all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"] class DeepseekV3Policy(Policy): def config_sanity_check(self): assert not self.shard_config.enable_tensor_parallelism, "DeepSeekV3 does not support tensor parallelism" - assert self.shard_config.pipeline_stage_manager is None, "DeepSeekV3 does not support pipeline parallelism" assert not self.shard_config.enable_sequence_parallelism, "DeepSeekV3 does not support sequence parallelism" + if self.shard_config.pipeline_stage_manager: + assert not self.shard_config.pipeline_stage_manager.use_zbv, "DeepSeekV3 does not support ZBV" def preprocess(self): return self.model @@ -23,7 +29,10 @@ class DeepseekV3Policy(Policy): policy = {} # support gradient checkpointing - # policy["DeepseekV3Model"] = ModulePolicyDescription(method_replacement={"forward": deepseek_v3_model_forward}) + if 
self.shard_config.pipeline_stage_manager is None: + policy["DeepseekV3Model"] = ModulePolicyDescription( + method_replacement={"forward": deepseek_v3_model_forward} + ) if self.shard_config.expert_parallel_size > 1: # expert parallel @@ -74,10 +83,82 @@ class DeepseekV3Policy(Policy): def postprocess(self): return self.model + def set_pipeline_forward(self, model_cls: str, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + num_layers = self.model.config.num_hidden_layers + stage_manager = self.pipeline_stage_manager + + layers_per_stage = stage_manager.distribute_layers(num_layers) + stage_index = stage_manager.get_stage_index(layers_per_stage) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + return + + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + module = self.model + if module.__class__.__name__.startswith("PeftModel"): + module = module.get_base_model() + if module.__class__.__name__ != "DeepseekV3Model": + module = module.model + + stage_manager = self.pipeline_stage_manager + + held_layers = [] + + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + stage_manager.stage_indices = stage_indices + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + # for zbv, when is_first_stage (last fwd), we append norm + # for interleaved, when is_last_stage (last fwd), we also append norm + held_layers.append(module.norm) + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + return held_layers + class DeepseekV3ModelPolicy(DeepseekV3Policy): - pass + def module_policy(self): + policy = super().module_policy() + if self.shard_config.pipeline_stage_manager: + self.set_pipeline_forward("DeepseekV3Model", deepseek_v3_model_forward, policy) + return policy class DeepseekV3ForCausalLMPolicy(DeepseekV3Policy): - pass + def module_policy(self): + policy = super().module_policy() + if self.shard_config.pipeline_stage_manager: + self.set_pipeline_forward("DeepseekV3ForCausalLM", deepseek_v3_for_causal_lm_forward, policy) + return policy + + def get_held_layers(self): + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + elif stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + 
return held_layers diff --git a/examples/language/deepseek/benchmark.py b/examples/language/deepseek/benchmark.py index 64a8b8a7d..7ef1e4ebc 100644 --- a/examples/language/deepseek/benchmark.py +++ b/examples/language/deepseek/benchmark.py @@ -60,9 +60,9 @@ MODEL_CONFIGS = { attn_implementation="flash_attention_2", trust_remote_code=True, ), - "v3-6b": AutoConfig.from_pretrained( + "v3-7b": AutoConfig.from_pretrained( "deepseek-ai/DeepSeek-V3", - num_hidden_layers=5, + num_hidden_layers=6, first_k_dense_replace=2, n_routed_experts=32, vocab_size=8192, @@ -210,14 +210,15 @@ def main(): config, trust_remote_code=True, attn_implementation=attn_impl, torch_dtype=torch.bfloat16 ).to(torch.bfloat16) if args.enable_lora: - booster.enable_lora( + model = booster.enable_lora( model, lora_config=LoraConfig(task_type="CAUSAL_LM", target_modules=["gate_proj", "up_proj", "down_proj"]), ) if args.grad_checkpoint: model.gradient_checkpointing_enable() - if model.__class__.__name__.startswith("DeepseekV3"): + if config.__class__.__name__.startswith("DeepseekV3"): + model.config.use_cache = False model.eval() # enable grad for moe layers for m in model.modules(): @@ -257,40 +258,42 @@ def main(): ) as prof: # , distributed_debug_mode(10, enable=True): if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1: data_iter = iter(dataloader) - for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): - performance_evaluator.on_step_start(step) - outputs = booster.execute_pipeline( - data_iter, - model, - criterion=lambda outputs, inputs: outputs[0], - optimizer=optimizer, - return_loss=True, - ) - loss = outputs["loss"] - if dist.get_rank() == dist.get_world_size() - 1: - print(f"Step {step} loss: {loss}") - optimizer.step() - optimizer.zero_grad() + with tqdm( + range(len(dataloader)), desc="Step", disable=dist.get_rank() != dist.get_world_size() - 1 + ) as pbar: + for step in pbar: + performance_evaluator.on_step_start(step) + outputs = booster.execute_pipeline( + data_iter, + model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=optimizer, + return_loss=True, + ) + loss = outputs["loss"] + loss_scalar = loss.item() if loss is not None else None + pbar.set_postfix({"loss": loss_scalar}) + optimizer.step() + optimizer.zero_grad() - performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) - prof.step() - print(f"rank {dist.get_rank()} step {step} passed") + performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) + prof.step() else: - for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): - performance_evaluator.on_step_start(step) - outputs = model(**batch) - loss = outputs[0] - del outputs # free memory + with tqdm(dataloader, desc="Step", disable=not coordinator.is_master()) as pbar: + for step, batch in enumerate(pbar): + performance_evaluator.on_step_start(step) + outputs = model(**batch) + loss = outputs[0] + del outputs # free memory - if dist.get_rank() == dist.get_world_size() - 1: - print(f"Step {step} loss: {loss}") + pbar.set_postfix({"loss": loss.item()}) - booster.backward(loss, optimizer) - optimizer.step() - optimizer.zero_grad() + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() - performance_evaluator.on_step_end(**batch) - prof.step() + performance_evaluator.on_step_end(**batch) + prof.step() performance_evaluator.on_fit_end() coordinator.print_on_master(f"Max CUDA memory usage: 
{get_accelerator().max_memory_allocated()/1024**2:.2f} MB") diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 9f0180d52..abebd0fab 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -126,6 +126,7 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo booster.save_lora_as_pretrained(model, model_ckpt_path) booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False) + dist.barrier() new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config) new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion) check_state_dict_equal(model.state_dict(), new_model.state_dict()) diff --git a/tests/test_shardformer/test_model/test_shard_deepseek_v3.py b/tests/test_shardformer/test_model/test_shard_deepseek_v3.py index 798fb639c..aeded5466 100644 --- a/tests/test_shardformer/test_model/test_shard_deepseek_v3.py +++ b/tests/test_shardformer/test_model/test_shard_deepseek_v3.py @@ -51,7 +51,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if target_grad is None: continue target_grad = target_grad.view(-1).chunk(dist.get_world_size(pg))[dist.get_rank(pg)] - assert_close(grad, target_grad, atol=3e-1, rtol=0) + assert_close(grad, target_grad, atol=5e-1, rtol=0) @parameterize(
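Below is a minimal end-to-end sketch of the workflow this patch enables: LoRA fine-tuning with `HybridParallelPlugin` at `pp_size > 1` (tensor parallelism must remain `tp_size=1`), followed by saving the adapter through the pipeline-aware `save_lora_as_pretrained` path. It is an illustration only, not code from this PR: the tiny Llama config, LoRA target modules, optimizer, batch sizes, and output path are placeholder assumptions, and it presumes two bf16-capable GPUs launched with `torchrun --nproc_per_node=2`.

```python
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from peft import LoraConfig
from torch.optim import AdamW
from torch.utils.data import TensorDataset
from transformers import LlamaConfig, LlamaForCausalLM


def collate(batch):
    # Build the dict the causal-LM forward expects; labels == inputs gives a toy LM loss.
    input_ids = torch.stack([sample[0] for sample in batch])
    return {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}


def main():
    colossalai.launch_from_torch()

    plugin = HybridParallelPlugin(tp_size=1, pp_size=2, precision="bf16", microbatch_size=1)
    booster = Booster(plugin=plugin)

    # Tiny randomly initialized model so the sketch stays self-contained.
    model = LlamaForCausalLM(
        LlamaConfig(hidden_size=128, intermediate_size=256, num_hidden_layers=4, num_attention_heads=4, vocab_size=8192)
    )
    # LoRA must be enabled before boosting; tp_size must stay 1, but pp_size > 1 is now accepted.
    model = booster.enable_lora(
        model, lora_config=LoraConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"])
    )
    optimizer = AdamW(model.parameters(), lr=1e-4)
    model, optimizer, *_ = booster.boost(model, optimizer)

    dataset = TensorDataset(torch.randint(0, 8192, (16, 128)))
    dataloader = plugin.prepare_dataloader(dataset, batch_size=4, shuffle=True, collate_fn=collate)

    data_iter = iter(dataloader)
    for _ in range(len(dataloader)):
        booster.execute_pipeline(
            data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=True
        )
        optimizer.step()
        optimizer.zero_grad()

    # Only the LoRA adapter weights are gathered across pipeline stages and written by the master rank.
    booster.save_lora_as_pretrained(model, "ckpt/lora_adapter")


if __name__ == "__main__":
    main()
```

Note the related change to the unwrap contract: `ModelWrapper.unwrap()` now also strips the `PeftModel` wrapper by default, while the LoRA checkpoint paths call `unwrap(unwrap_peft=False)` internally to keep the adapter for `save_pretrained`.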
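The new `gather_state_dict_fast` helper in `colossalai/checkpoint_io/utils.py` collects per-rank shards of a state dict onto one destination rank: tensor metadata travels via `gather_object`, the tensors themselves via batched `isend`/`irecv`, and non-destination ranks return `None`. A small illustrative sketch follows, assuming a 2-process NCCL setup launched with `torchrun`; the parameter names are made up for the example.

```python
import torch
import torch.distributed as dist

from colossalai.checkpoint_io.utils import gather_state_dict_fast


def main():
    # Assumes: torchrun --nproc_per_node=2 gather_demo.py, one GPU per process.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Each rank holds a disjoint shard of (hypothetical) LoRA weights keyed by parameter name.
    shard = {f"layers.{rank}.self_attn.q_proj.lora_A.default.weight": torch.randn(8, 16, device="cuda")}

    # Rank `dst` (0 by default) receives the other ranks' tensors and returns the merged dict;
    # every other rank sends its tensors and returns None.
    merged = gather_state_dict_fast(shard, dist.group.WORLD, device="cpu")
    if rank == 0:
        print(sorted(merged.keys()))  # contains the keys contributed by both ranks

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```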