[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)

* [shardformer] support pipeline for deepseek v3

* [checkpointio] fix lora save

* [devops] update ci env

* [booster] optimize lora

* fix test

* fix test
This commit is contained in:
Hongxin Liu 2025-02-14 14:48:54 +08:00 committed by GitHub
parent ec73f1b5e2
commit 014837e725
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 478 additions and 91 deletions

View File

@ -166,6 +166,7 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors MOE_TENSOR_PATH: /data/scratch/moe_tensors
HF_ENDPOINT: https://hf-mirror.com
- name: Collate artifact - name: Collate artifact
env: env:

View File

@ -70,6 +70,7 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors MOE_TENSOR_PATH: /data/scratch/moe_tensors
HF_ENDPOINT: https://hf-mirror.com
- name: Notify Lark - name: Notify Lark
id: message-preparation id: message-preparation

View File

@ -79,3 +79,4 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors MOE_TENSOR_PATH: /data/scratch/moe_tensors
HF_ENDPOINT: https://hf-mirror.com

View File

@ -73,3 +73,4 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors MOE_TENSOR_PATH: /data/scratch/moe_tensors
HF_ENDPOINT: https://hf-mirror.com

View File

@ -67,6 +67,7 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors MOE_TENSOR_PATH: /data/scratch/moe_tensors
HF_ENDPOINT: https://hf-mirror.com
- name: Notify Lark - name: Notify Lark
id: message-preparation id: message-preparation

View File

@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, T
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from peft import PeftModel
from torch import Tensor, inf from torch import Tensor, inf
from torch.distributed import ProcessGroup, get_world_size from torch.distributed import ProcessGroup, get_world_size
from torch.nn import Module, SyncBatchNorm from torch.nn import Module, SyncBatchNorm
@ -219,11 +220,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
with self._hook_context(): with self._hook_context():
return super().forward(*args, **kwargs) return super().forward(*args, **kwargs)
def unwrap(self): def unwrap(self, unwrap_peft: bool = True):
module = super().unwrap() model = self.module
if isinstance(module, DDP): if isinstance(model, DDP):
module = module.module model = model.module
return module if unwrap_peft and isinstance(model, PeftModel):
model = model.get_base_model()
return model
def _force_wait_all_gather(self): def _force_wait_all_gather(self):
for p in self.module.parameters(): for p in self.module.parameters():
@ -1509,7 +1512,7 @@ class HybridParallelPlugin(PipelinePluginBase):
from peft import PeftModel, get_peft_model from peft import PeftModel, get_peft_model
assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
assert self.pp_size == 1 and self.tp_size == 1 assert self.tp_size == 1
self.lora_enabled = True self.lora_enabled = True
self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0]) self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])

View File

@ -359,23 +359,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async
) )
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
if os.path.isfile(checkpoint): assert isinstance(model, LowLevelZeroModel), "Please boost the model before saving!"
self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
return
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather() model._force_wait_all_gather()
peft_model = model.unwrap() super().save_lora_as_pretrained(model, checkpoint, use_safetensors, state_dict=state_dict)
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
return peft_model.save_pretrained(
checkpoint,
safe_serialization=use_safetensors,
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
)
class LowLevelZeroPlugin(DPPluginBase): class LowLevelZeroPlugin(DPPluginBase):

View File

@ -2,6 +2,7 @@ from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from peft import PeftModel
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@ -166,7 +167,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
) )
def save_lora_as_pretrained( def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
use_safetensors: bool = False,
state_dict: Optional[dict] = None,
) -> None: ) -> None:
""" """
Save the lora adapters and adapter configuration file to checkpoint directory. Save the lora adapters and adapter configuration file to checkpoint directory.
@ -174,15 +179,17 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
from peft import PeftModel from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!" assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master(): peft_model = model.unwrap(unwrap_peft=False)
peft_model = model.unwrap()
assert isinstance( assert isinstance(
peft_model, PeftModel peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving." ), "The model doesn't have lora adapters, please enable lora before saving."
if state_dict is None:
state_dict = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, peft_model.state_dict())
if self.coordinator.is_master():
return peft_model.save_pretrained( return peft_model.save_pretrained(
checkpoint, checkpoint,
safe_serialization=use_safetensors, safe_serialization=use_safetensors,
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()), state_dict=state_dict,
) )
@ -191,8 +198,11 @@ class TorchDDPModel(ModelWrapper):
super().__init__(module) super().__init__(module)
self.module = DDP(module, *args, **kwargs) self.module = DDP(module, *args, **kwargs)
def unwrap(self): def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
return self.module.module model = self.module.module
if unwrap_peft and isinstance(model, PeftModel):
model = model.get_base_model()
return model
class TorchDDPPlugin(DPPluginBase): class TorchDDPPlugin(DPPluginBase):

View File

@ -437,9 +437,6 @@ class TorchFSDPModel(ModelWrapper):
super().__init__(module) super().__init__(module)
self.module = FSDP(module, *args, **kwargs) self.module = FSDP(module, *args, **kwargs)
def unwrap(self):
return self.module
class FSDPOptimizerWrapper(OptimizerWrapper): class FSDPOptimizerWrapper(OptimizerWrapper):
def __init__(self, optimizer: Optimizer, model: nn.Module): def __init__(self, optimizer: Optimizer, model: nn.Module):

View File

@ -437,7 +437,11 @@ class CheckpointIO(ABC):
@abstractmethod @abstractmethod
def save_lora_as_pretrained( def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
use_safetensors: bool = False,
state_dict: Optional[dict] = None,
) -> None: ) -> None:
""" """
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory. Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
@ -446,4 +450,5 @@ class CheckpointIO(ABC):
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster. model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
checkpoint (str): Path to the checkpoint directory. It must be a local path. checkpoint (str): Path to the checkpoint directory. It must be a local path.
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False. use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
state_dict (Optional[dict], optional): The state dict to save. Defaults to None.
""" """

View File

@ -308,5 +308,7 @@ class GeneralCheckpointIO(CheckpointIO):
) )
) )
def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None: def save_lora_as_pretrained(
self, model: nn.Module, checkpoint: str, use_safetensors: bool = False, state_dict: Optional[dict] = None
) -> None:
raise NotImplementedError raise NotImplementedError

View File

@ -33,6 +33,8 @@ from .utils import (
async_save_state_dict_shards, async_save_state_dict_shards,
create_pinned_state_dict, create_pinned_state_dict,
gather_distributed_param, gather_distributed_param,
gather_state_dict_fast,
get_lora_state_dict,
get_model_base_filenames, get_model_base_filenames,
get_optimizer_base_filenames, get_optimizer_base_filenames,
is_safetensors_available, is_safetensors_available,
@ -1137,7 +1139,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
return state_ return state_
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
if os.path.isfile(checkpoint): if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return return
@ -1145,12 +1147,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
assert isinstance(model, ModelWrapper), "Please boost the model before saving!" assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather() model._force_wait_all_gather()
peft_model = model.unwrap() peft_model = model.unwrap(unwrap_peft=False)
assert isinstance( assert isinstance(
peft_model, PeftModel peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving." ), "The model doesn't have lora adapters, please enable lora before saving."
if state_dict is None:
state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict())
if self.pp_size > 1:
lora_state_dict = get_lora_state_dict(peft_model, state_dict)
gathered_lora_state_dict = gather_state_dict_fast(lora_state_dict, self.pp_group, device="cpu")
if self.pp_rank == 0:
state_dict.update(gathered_lora_state_dict)
state_dict = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
if self.coordinator.is_master():
return peft_model.save_pretrained( return peft_model.save_pretrained(
checkpoint, checkpoint,
safe_serialization=use_safetensors, safe_serialization=use_safetensors,
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()), state_dict=state_dict,
) )

View File

@ -10,6 +10,7 @@ import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import get_global_rank from torch.distributed.distributed_c10d import get_global_rank
from torch.utils._pytree import tree_map
from colossalai.checkpoint_io import CheckpointIndexFile from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
@ -17,6 +18,8 @@ from colossalai.checkpoint_io.index_file import CheckpointIndexFile
from colossalai.checkpoint_io.utils import ( from colossalai.checkpoint_io.utils import (
StateDictSharder, StateDictSharder,
gather_distributed_param, gather_distributed_param,
gather_state_dict_fast,
get_lora_state_dict,
get_model_base_filenames, get_model_base_filenames,
get_optimizer_base_filenames, get_optimizer_base_filenames,
load_shard_state_dict, load_shard_state_dict,
@ -889,3 +892,26 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
optimizer.optim.state[param] = sharded_state optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim) sharded_optimizer_loading_epilogue(optimizer.optim)
dist.barrier() dist.barrier()
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict=None):
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
peft_model = model.unwrap(unwrap_peft=False)
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
if state_dict is None:
state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict())
if self.ep_size > 1:
lora_state_dict = get_lora_state_dict(peft_model, state_dict)
moe_params = set(n for n, p in peft_model.named_parameters() if is_moe_tensor(p))
expert_state_dict = {n: p for n, p in lora_state_dict.items() if n in moe_params}
gathered_expert_state_dict = gather_state_dict_fast(expert_state_dict, self.ep_group)
if self.ep_rank == 0:
state_dict.update(gathered_expert_state_dict)
return super().save_lora_as_pretrained(model, checkpoint, use_safetensors, state_dict)

View File

@ -2,6 +2,7 @@
import concurrent.futures import concurrent.futures
import os import os
import re import re
import warnings
from collections import abc as container_abcs from collections import abc as container_abcs
from collections import defaultdict from collections import defaultdict
from itertools import chain from itertools import chain
@ -9,8 +10,12 @@ from pathlib import Path
from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
import torch import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from packaging.version import Version from packaging.version import Version
from peft import PeftModel, PeftType
from peft.utils.other import EMBEDDING_LAYER_NAMES, check_file_exists_on_hf_hub
from peft.utils.save_and_load import get_embedding_layer_name, has_valid_embedding_base_layer
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
@ -21,6 +26,7 @@ from colossalai.tensor.d_tensor import (
to_global, to_global,
to_global_for_customized_distributed_tensor, to_global_for_customized_distributed_tensor,
) )
from colossalai.utils import get_current_device
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
SAFE_WEIGHTS_NAME = "model.safetensors" SAFE_WEIGHTS_NAME = "model.safetensors"
@ -1004,3 +1010,138 @@ def load_state_dict_shards(
futures.append(future) futures.append(future)
for future in concurrent.futures.as_completed(futures): for future in concurrent.futures.as_completed(futures):
yield future.result() yield future.result()
# adapted from `peft/utils/save_and_load.py`
def get_lora_state_dict(
model: PeftModel, state_dict: dict, adapter_name="default", save_embedding_layers="auto"
) -> dict:
config = model.peft_config[adapter_name]
if config.peft_type != PeftType.LORA:
raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
# to_return = lora_state_dict(model, bias=model.peft_config.bias)
# adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
# to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
bias = config.bias
if bias == "none":
to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
elif bias == "all":
to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
for k in state_dict:
if "lora_" in k:
to_return[k] = state_dict[k]
bias_name = k.split("lora_")[0] + "bias"
if bias_name in state_dict:
to_return[bias_name] = state_dict[bias_name]
else:
raise NotImplementedError
to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))}
if config.use_dora:
# Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
# ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
# we want the state_dict format not to change, we remove the "weight" part.
new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"
def renamed_dora_weights(k):
if k.endswith(new_dora_suffix):
k = k[:-7] # remove ".weight"
return k
to_return = {renamed_dora_weights(k): v for k, v in to_return.items()}
# DEAL WITH EMBEDDINGS
# check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
is_embedding_in_target_modules = False
if (
save_embedding_layers == "auto"
and hasattr(config, "target_modules")
and any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES)
):
warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.")
save_embedding_layers = is_embedding_in_target_modules = True
elif save_embedding_layers == "auto":
vocab_size = getattr(getattr(model, "config", None), "vocab_size", None)
model_id = getattr(config, "base_model_name_or_path", None)
# For some models e.g. diffusers the text config file is stored in a subfolder
# we need to make sure we can download that config.
has_base_config = False
# ensure that this check is not performed in HF offline mode, see #1452
if model_id is not None:
local_config_exists = os.path.exists(os.path.join(model_id, "config.json"))
exists = local_config_exists or check_file_exists_on_hf_hub(model_id, "config.json")
if exists is None:
# check failed, could not determine if it exists or not
warnings.warn(
f"Could not find a config file in {model_id} - will assume that the vocabulary was not modified."
)
has_base_config = False
else:
has_base_config = exists
# check if the vocab size of the base model is different from the vocab size of the finetuned model
if (
vocab_size
and model_id
and has_base_config
and (vocab_size != model.config.__class__.from_pretrained(model_id).vocab_size)
):
warnings.warn(
"Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning."
)
save_embedding_layers = True
else:
save_embedding_layers = False
if save_embedding_layers and hasattr(model, "get_input_embeddings"):
for layer in [model.get_input_embeddings(), model.get_output_embeddings()]:
if not is_embedding_in_target_modules or has_valid_embedding_base_layer(layer):
# support from version >= 0.6.2
embedding_module_name = get_embedding_layer_name(model, layer, is_embedding_in_target_modules)
if embedding_module_name:
to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k})
elif save_embedding_layers:
warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")
return to_return
def gather_state_dict_fast(
state_dict: Dict[str, torch.Tensor],
group: dist.ProcessGroup,
device: Optional[Union[torch.device, str]] = None,
dst: int = 0,
) -> Optional[Dict[str, torch.Tensor]]:
if device is None:
device = get_current_device()
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
metadata = [(k, v.shape, v.dtype) for k, v in state_dict.items()]
all_meta_data = [None] * world_size
if rank == dst:
returned_state_dict = state_dict.copy()
dist.gather_object(metadata, all_meta_data, dst=dist.get_global_rank(group, rank), group=group)
for i, target_metadata in enumerate(all_meta_data):
if i == dst:
continue
ops = []
for k, shape, dtype in target_metadata:
buffer = torch.empty(shape, dtype=dtype, device=get_current_device())
returned_state_dict[k] = buffer
ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group))
reqs = dist.batch_isend_irecv(ops)
for req, (k, *_) in zip(reqs, target_metadata):
req.wait()
returned_state_dict[k] = returned_state_dict[k].to(device)
return returned_state_dict
else:
dist.gather_object(metadata, dst=dist.get_global_rank(group, dst), group=group)
ops = []
for k, *_ in metadata:
ops.append(dist.P2POp(dist.isend, state_dict[k], dist.get_global_rank(group, dst), group))
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()

View File

@ -1,4 +1,5 @@
import torch.nn as nn import torch.nn as nn
from peft import PeftModel
class ModelWrapper(nn.Module): class ModelWrapper(nn.Module):
@ -13,13 +14,17 @@ class ModelWrapper(nn.Module):
super().__init__() super().__init__()
self.module = module self.module = module
def unwrap(self): def unwrap(self, unwrap_peft: bool = True):
""" """
Unwrap the model to return the original model for checkpoint saving/loading. Unwrap the model to return the original model for checkpoint saving/loading.
""" """
if isinstance(self.module, ModelWrapper): if isinstance(self.module, ModelWrapper):
return self.module.unwrap() model = self.module.unwrap()
return self.module else:
model = self.module
if unwrap_peft and isinstance(model, PeftModel):
model = model.get_base_model()
return model
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return self.module(*args, **kwargs) return self.module(*args, **kwargs)

View File

@ -156,7 +156,9 @@ def _check_for_nccl_hccl_backend(group):
while isinstance(pg, c10d._ProcessGroupWrapper): while isinstance(pg, c10d._ProcessGroupWrapper):
pg = pg.wrapped_pg pg = pg.wrapped_pg
return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and pg.name() == c10d.Backend.NCCL return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and (
pg.name() == c10d.Backend.NCCL or pg.name() == c10d.Backend.HCCL
)
def _check_device(group): def _check_device(group):

View File

@ -4,9 +4,10 @@ import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from colossalai.lazy import LazyInitContext from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import ( from colossalai.moe._operation import (
@ -16,6 +17,7 @@ from colossalai.moe._operation import (
EPGradScalerOut, EPGradScalerOut,
all_to_all_uneven, all_to_all_uneven,
) )
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.linear import ParallelModule from colossalai.shardformer.layer.linear import ParallelModule
from colossalai.shardformer.shard.utils import set_tensors_to_none from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
@ -167,6 +169,9 @@ def deepseek_v3_model_forward(
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
hidden_states_internal: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@ -203,8 +208,11 @@ def deepseek_v3_model_forward(
) )
position_ids = position_ids.unsqueeze(0) position_ids = position_ids.unsqueeze(0)
if stage_manager is None or stage_manager.is_first_stage():
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
else:
inputs_embeds = hidden_states_internal
if self._use_flash_attention_2: if self._use_flash_attention_2:
# 2d mask is passed through the layers # 2d mask is passed through the layers
@ -226,7 +234,11 @@ def deepseek_v3_model_forward(
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
next_decoder_cache = None next_decoder_cache = None
for i, decoder_layer in enumerate(self.layers): if stage_index is not None:
start_idx, end_idx = stage_index
else:
start_idx, end_idx = 0, len(self.layers)
for i, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
@ -258,6 +270,7 @@ def deepseek_v3_model_forward(
if output_attentions: if output_attentions:
all_self_attns += (layer_outputs[1],) all_self_attns += (layer_outputs[1],)
if stage_manager is None or stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer # add hidden states from the last decoder layer
@ -267,6 +280,10 @@ def deepseek_v3_model_forward(
next_cache = None next_cache = None
if use_cache: if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if stage_manager is not None and not stage_manager.is_last_stage():
return {
"hidden_states_internal": hidden_states,
}
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast( return BaseModelOutputWithPast(
@ -275,3 +292,94 @@ def deepseek_v3_model_forward(
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
) )
def deepseek_v3_for_causal_lm_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
hidden_states_internal: Optional[torch.Tensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM
>>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = deepseek_v3_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
stage_index=stage_index,
hidden_states_internal=hidden_states_internal,
)
if stage_manager is not None and not stage_manager.is_last_stage():
return outputs
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@ -1,9 +1,14 @@
from typing import Dict, Union from functools import partial
from typing import Callable, Dict, List, Union
import torch.nn as nn import torch.nn as nn
from colossalai.shardformer.layer import FusedRMSNorm from colossalai.shardformer.layer import FusedRMSNorm
from colossalai.shardformer.modeling.deepseek_v3 import EpDeepseekV3MoE from colossalai.shardformer.modeling.deepseek_v3 import (
EpDeepseekV3MoE,
deepseek_v3_for_causal_lm_forward,
deepseek_v3_model_forward,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"] __all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]
@ -12,8 +17,9 @@ __all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]
class DeepseekV3Policy(Policy): class DeepseekV3Policy(Policy):
def config_sanity_check(self): def config_sanity_check(self):
assert not self.shard_config.enable_tensor_parallelism, "DeepSeekV3 does not support tensor parallelism" assert not self.shard_config.enable_tensor_parallelism, "DeepSeekV3 does not support tensor parallelism"
assert self.shard_config.pipeline_stage_manager is None, "DeepSeekV3 does not support pipeline parallelism"
assert not self.shard_config.enable_sequence_parallelism, "DeepSeekV3 does not support sequence parallelism" assert not self.shard_config.enable_sequence_parallelism, "DeepSeekV3 does not support sequence parallelism"
if self.shard_config.pipeline_stage_manager:
assert not self.shard_config.pipeline_stage_manager.use_zbv, "DeepSeekV3 does not support ZBV"
def preprocess(self): def preprocess(self):
return self.model return self.model
@ -23,7 +29,10 @@ class DeepseekV3Policy(Policy):
policy = {} policy = {}
# support gradient checkpointing # support gradient checkpointing
# policy["DeepseekV3Model"] = ModulePolicyDescription(method_replacement={"forward": deepseek_v3_model_forward}) if self.shard_config.pipeline_stage_manager is None:
policy["DeepseekV3Model"] = ModulePolicyDescription(
method_replacement={"forward": deepseek_v3_model_forward}
)
if self.shard_config.expert_parallel_size > 1: if self.shard_config.expert_parallel_size > 1:
# expert parallel # expert parallel
@ -74,10 +83,82 @@ class DeepseekV3Policy(Policy):
def postprocess(self): def postprocess(self):
return self.model return self.model
def set_pipeline_forward(self, model_cls: str, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
num_layers = self.model.config.num_hidden_layers
stage_manager = self.pipeline_stage_manager
layers_per_stage = stage_manager.distribute_layers(num_layers)
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
return
def get_held_layers(self) -> List[nn.Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
module = self.model
if module.__class__.__name__.startswith("PeftModel"):
module = module.get_base_model()
if module.__class__.__name__ != "DeepseekV3Model":
module = module.model
stage_manager = self.pipeline_stage_manager
held_layers = []
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_indices = stage_manager.get_stage_index(layers_per_stage)
stage_manager.stage_indices = stage_indices
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
# for zbv, when is_first_stage (last fwd), we append norm
# for interleaved, when is_last_stage (last fwd), we also append norm
held_layers.append(module.norm)
else:
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
return held_layers
class DeepseekV3ModelPolicy(DeepseekV3Policy): class DeepseekV3ModelPolicy(DeepseekV3Policy):
pass def module_policy(self):
policy = super().module_policy()
if self.shard_config.pipeline_stage_manager:
self.set_pipeline_forward("DeepseekV3Model", deepseek_v3_model_forward, policy)
return policy
class DeepseekV3ForCausalLMPolicy(DeepseekV3Policy): class DeepseekV3ForCausalLMPolicy(DeepseekV3Policy):
pass def module_policy(self):
policy = super().module_policy()
if self.shard_config.pipeline_stage_manager:
self.set_pipeline_forward("DeepseekV3ForCausalLM", deepseek_v3_for_causal_lm_forward, policy)
return policy
def get_held_layers(self):
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers

View File

@ -60,9 +60,9 @@ MODEL_CONFIGS = {
attn_implementation="flash_attention_2", attn_implementation="flash_attention_2",
trust_remote_code=True, trust_remote_code=True,
), ),
"v3-6b": AutoConfig.from_pretrained( "v3-7b": AutoConfig.from_pretrained(
"deepseek-ai/DeepSeek-V3", "deepseek-ai/DeepSeek-V3",
num_hidden_layers=5, num_hidden_layers=6,
first_k_dense_replace=2, first_k_dense_replace=2,
n_routed_experts=32, n_routed_experts=32,
vocab_size=8192, vocab_size=8192,
@ -210,14 +210,15 @@ def main():
config, trust_remote_code=True, attn_implementation=attn_impl, torch_dtype=torch.bfloat16 config, trust_remote_code=True, attn_implementation=attn_impl, torch_dtype=torch.bfloat16
).to(torch.bfloat16) ).to(torch.bfloat16)
if args.enable_lora: if args.enable_lora:
booster.enable_lora( model = booster.enable_lora(
model, model,
lora_config=LoraConfig(task_type="CAUSAL_LM", target_modules=["gate_proj", "up_proj", "down_proj"]), lora_config=LoraConfig(task_type="CAUSAL_LM", target_modules=["gate_proj", "up_proj", "down_proj"]),
) )
if args.grad_checkpoint: if args.grad_checkpoint:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
if model.__class__.__name__.startswith("DeepseekV3"): if config.__class__.__name__.startswith("DeepseekV3"):
model.config.use_cache = False
model.eval() model.eval()
# enable grad for moe layers # enable grad for moe layers
for m in model.modules(): for m in model.modules():
@ -257,7 +258,10 @@ def main():
) as prof: # , distributed_debug_mode(10, enable=True): ) as prof: # , distributed_debug_mode(10, enable=True):
if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1: if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1:
data_iter = iter(dataloader) data_iter = iter(dataloader)
for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): with tqdm(
range(len(dataloader)), desc="Step", disable=dist.get_rank() != dist.get_world_size() - 1
) as pbar:
for step in pbar:
performance_evaluator.on_step_start(step) performance_evaluator.on_step_start(step)
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(
data_iter, data_iter,
@ -267,23 +271,22 @@ def main():
return_loss=True, return_loss=True,
) )
loss = outputs["loss"] loss = outputs["loss"]
if dist.get_rank() == dist.get_world_size() - 1: loss_scalar = loss.item() if loss is not None else None
print(f"Step {step} loss: {loss}") pbar.set_postfix({"loss": loss_scalar})
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
prof.step() prof.step()
print(f"rank {dist.get_rank()} step {step} passed")
else: else:
for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): with tqdm(dataloader, desc="Step", disable=not coordinator.is_master()) as pbar:
for step, batch in enumerate(pbar):
performance_evaluator.on_step_start(step) performance_evaluator.on_step_start(step)
outputs = model(**batch) outputs = model(**batch)
loss = outputs[0] loss = outputs[0]
del outputs # free memory del outputs # free memory
if dist.get_rank() == dist.get_world_size() - 1: pbar.set_postfix({"loss": loss.item()})
print(f"Step {step} loss: {loss}")
booster.backward(loss, optimizer) booster.backward(loss, optimizer)
optimizer.step() optimizer.step()

View File

@ -126,6 +126,7 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
booster.save_lora_as_pretrained(model, model_ckpt_path) booster.save_lora_as_pretrained(model, model_ckpt_path)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False) booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False)
dist.barrier()
new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config) new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config)
new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion) new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
check_state_dict_equal(model.state_dict(), new_model.state_dict()) check_state_dict_equal(model.state_dict(), new_model.state_dict())

View File

@ -51,7 +51,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
if target_grad is None: if target_grad is None:
continue continue
target_grad = target_grad.view(-1).chunk(dist.get_world_size(pg))[dist.get_rank(pg)] target_grad = target_grad.view(-1).chunk(dist.get_world_size(pg))[dist.get_rank(pg)]
assert_close(grad, target_grad, atol=3e-1, rtol=0) assert_close(grad, target_grad, atol=5e-1, rtol=0)
@parameterize( @parameterize(