[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)

* [shardformer] support pipeline for deepseek v3

* [checkpointio] fix lora save

* [devops] update ci env

* [booster] optimize lora

* fix test

* fix test
Author: Hongxin Liu
Date: 2025-02-14 14:48:54 +08:00 (committed by GitHub)
Parent: ec73f1b5e2
Commit: 014837e725
21 changed files with 478 additions and 91 deletions
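
Note: the "[booster] optimize lora" change is exercised through the booster checkpoint API. Below is a minimal sketch of the intended call path, assuming the public Booster/plugin names; the toy model, plugin choice, LoRA config, and output path are illustrative and not part of this commit. It is meant to be launched with torchrun on a GPU node.

    import colossalai
    import torch.nn as nn
    from peft import LoraConfig
    from torch.optim import AdamW
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin


    class ToyModel(nn.Module):
        # placeholder model; a real run would use a transformers model
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def forward(self, x):
            return self.proj(x)


    colossalai.launch_from_torch()  # initialize the distributed environment first
    model = ToyModel()
    optimizer = AdamW(model.parameters(), lr=1e-4)

    booster = Booster(plugin=LowLevelZeroPlugin())
    # attach LoRA adapters before boosting
    model = booster.enable_lora(model, lora_config=LoraConfig(target_modules=["proj"], r=8, lora_alpha=16))
    model, optimizer, *_ = booster.boost(model, optimizer)

    # gathers only the LoRA tensors and writes a PEFT-style adapter folder
    booster.save_lora_as_pretrained(model, "./lora_ckpt", use_safetensors=True)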


@@ -2,6 +2,7 @@
import concurrent.futures
import os
import re
import warnings
from collections import abc as container_abcs
from collections import defaultdict
from itertools import chain
@@ -9,8 +10,12 @@ from pathlib import Path
from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from packaging.version import Version
from peft import PeftModel, PeftType
from peft.utils.other import EMBEDDING_LAYER_NAMES, check_file_exists_on_hf_hub
from peft.utils.save_and_load import get_embedding_layer_name, has_valid_embedding_base_layer
from torch.optim import Optimizer
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
@@ -21,6 +26,7 @@ from colossalai.tensor.d_tensor import (
    to_global,
    to_global_for_customized_distributed_tensor,
)
from colossalai.utils import get_current_device
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat

SAFE_WEIGHTS_NAME = "model.safetensors"
@@ -1004,3 +1010,138 @@ def load_state_dict_shards(
        futures.append(future)
    for future in concurrent.futures.as_completed(futures):
        yield future.result()


# adapted from `peft/utils/save_and_load.py`
def get_lora_state_dict(
    model: PeftModel, state_dict: dict, adapter_name="default", save_embedding_layers="auto"
) -> dict:
    config = model.peft_config[adapter_name]
    if config.peft_type != PeftType.LORA:
        raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
    # to_return = lora_state_dict(model, bias=model.peft_config.bias)
    # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
    # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
    bias = config.bias
    if bias == "none":
        to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
    elif bias == "all":
        to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        for k in state_dict:
            if "lora_" in k:
                to_return[k] = state_dict[k]
                bias_name = k.split("lora_")[0] + "bias"
                if bias_name in state_dict:
                    to_return[bias_name] = state_dict[bias_name]
    else:
        raise NotImplementedError
    to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))}

    if config.use_dora:
        # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
        # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
        # we want the state_dict format not to change, we remove the "weight" part.
        new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"

        def renamed_dora_weights(k):
            if k.endswith(new_dora_suffix):
                k = k[:-7]  # remove ".weight"
            return k

        to_return = {renamed_dora_weights(k): v for k, v in to_return.items()}

    # DEAL WITH EMBEDDINGS
    # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
    is_embedding_in_target_modules = False
    if (
        save_embedding_layers == "auto"
        and hasattr(config, "target_modules")
        and any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES)
    ):
        warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.")
        save_embedding_layers = is_embedding_in_target_modules = True
    elif save_embedding_layers == "auto":
        vocab_size = getattr(getattr(model, "config", None), "vocab_size", None)
        model_id = getattr(config, "base_model_name_or_path", None)

        # For some models e.g. diffusers the text config file is stored in a subfolder
        # we need to make sure we can download that config.
        has_base_config = False

        # ensure that this check is not performed in HF offline mode, see #1452
        if model_id is not None:
            local_config_exists = os.path.exists(os.path.join(model_id, "config.json"))
            exists = local_config_exists or check_file_exists_on_hf_hub(model_id, "config.json")
            if exists is None:
                # check failed, could not determine if it exists or not
                warnings.warn(
                    f"Could not find a config file in {model_id} - will assume that the vocabulary was not modified."
                )
                has_base_config = False
            else:
                has_base_config = exists

        # check if the vocab size of the base model is different from the vocab size of the finetuned model
        if (
            vocab_size
            and model_id
            and has_base_config
            and (vocab_size != model.config.__class__.from_pretrained(model_id).vocab_size)
        ):
            warnings.warn(
                "Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning."
            )
            save_embedding_layers = True
        else:
            save_embedding_layers = False

    if save_embedding_layers and hasattr(model, "get_input_embeddings"):
        for layer in [model.get_input_embeddings(), model.get_output_embeddings()]:
            if not is_embedding_in_target_modules or has_valid_embedding_base_layer(layer):
                # support from version >= 0.6.2
                embedding_module_name = get_embedding_layer_name(model, layer, is_embedding_in_target_modules)
                if embedding_module_name:
                    to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k})
    elif save_embedding_layers:
        warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")

    return to_return
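
A hedged usage sketch of the helper above, with a placeholder PEFT model (gpt2 and the adapter settings are assumptions, not taken from this diff):

    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
    peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM", r=8))

    full_sd = peft_model.state_dict()
    lora_sd = get_lora_state_dict(peft_model, full_sd, adapter_name="default")
    # lora_sd keeps only the "lora_*" tensors (plus biases/embeddings, depending on the adapter
    # config), ready to be written out as a PEFT-style adapter checkpoint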


def gather_state_dict_fast(
    state_dict: Dict[str, torch.Tensor],
    group: dist.ProcessGroup,
    device: Optional[Union[torch.device, str]] = None,
    dst: int = 0,
) -> Optional[Dict[str, torch.Tensor]]:
    if device is None:
        device = get_current_device()
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    metadata = [(k, v.shape, v.dtype) for k, v in state_dict.items()]
    all_meta_data = [None] * world_size
    if rank == dst:
        returned_state_dict = state_dict.copy()
        dist.gather_object(metadata, all_meta_data, dst=dist.get_global_rank(group, rank), group=group)
        for i, target_metadata in enumerate(all_meta_data):
            if i == dst:
                continue
            ops = []
            for k, shape, dtype in target_metadata:
                buffer = torch.empty(shape, dtype=dtype, device=get_current_device())
                returned_state_dict[k] = buffer
                ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group))
            reqs = dist.batch_isend_irecv(ops)
            for req, (k, *_) in zip(reqs, target_metadata):
                req.wait()
                returned_state_dict[k] = returned_state_dict[k].to(device)
        return returned_state_dict
    else:
        dist.gather_object(metadata, dst=dist.get_global_rank(group, dst), group=group)
        ops = []
        for k, *_ in metadata:
            ops.append(dist.P2POp(dist.isend, state_dict[k], dist.get_global_rank(group, dst), group))
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
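
For illustration, a hedged sketch of how the gather helper might be called to collect rank-local tensors onto the destination rank before writing (the process group, tensor contents, and output path are assumptions, not from this commit); it presumes torch.distributed is already initialized, e.g. via torchrun plus colossalai.launch_from_torch():

    import torch
    import torch.distributed as dist

    # placeholder shard: in practice each rank would hold a different slice of the LoRA weights
    local_shard = {f"lora_A.rank{dist.get_rank()}": torch.randn(4, 4)}
    gathered = gather_state_dict_fast(local_shard, group=dist.group.WORLD, device="cpu", dst=0)
    if dist.get_rank() == 0:
        # rank 0 now holds every rank's tensors; the other ranks get None back
        torch.save(gathered, "gathered_lora_shards.pt")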