Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-22 18:09:06 +00:00)
[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)
* [shardformer] support pipeline for deepseek v3
* [checkpointio] fix lora save
* [devops] update ci env
* [booster] optimize lora
* fix test
* fix test
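For orientation, the LoRA save path that this commit optimizes is driven through the Booster API. The sketch below is a minimal, hedged illustration of that flow, not part of this diff: the plugin choice, `LoraConfig` values, model id, and checkpoint path are assumptions, and exact Booster signatures may differ between ColossalAI versions. The diff that follows adds the checkpoint-IO helpers this path relies on.

# Hedged sketch of the Booster-level LoRA flow touched by this commit.
# Plugin choice, LoraConfig values, model id and paths are illustrative
# assumptions; launch with torchrun so the process group is initialized.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from peft import LoraConfig
from transformers import AutoModelForCausalLM

colossalai.launch_from_torch()
booster = Booster(plugin=LowLevelZeroPlugin())

model = AutoModelForCausalLM.from_pretrained("gpt2")
model = booster.enable_lora(model, lora_config=LoraConfig(task_type="CAUSAL_LM"))
model, *_ = booster.boost(model)

# ... fine-tune the LoRA adapter here ...

# Saves only the adapter weights; this is the save path that the helpers
# added in the diff below are expected to serve.
booster.save_lora_as_pretrained(model, "lora_checkpoint")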
@@ -2,6 +2,7 @@
import concurrent.futures
import os
import re
import warnings
from collections import abc as container_abcs
from collections import defaultdict
from itertools import chain
@@ -9,8 +10,12 @@ from pathlib import Path
from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from packaging.version import Version
from peft import PeftModel, PeftType
from peft.utils.other import EMBEDDING_LAYER_NAMES, check_file_exists_on_hf_hub
from peft.utils.save_and_load import get_embedding_layer_name, has_valid_embedding_base_layer
from torch.optim import Optimizer
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

@@ -21,6 +26,7 @@ from colossalai.tensor.d_tensor import (
    to_global,
    to_global_for_customized_distributed_tensor,
)
from colossalai.utils import get_current_device
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat

SAFE_WEIGHTS_NAME = "model.safetensors"
@@ -1004,3 +1010,138 @@ def load_state_dict_shards(
            futures.append(future)
        for future in concurrent.futures.as_completed(futures):
            yield future.result()


# adapted from `peft/utils/save_and_load.py`
def get_lora_state_dict(
    model: PeftModel, state_dict: dict, adapter_name="default", save_embedding_layers="auto"
) -> dict:
    config = model.peft_config[adapter_name]
    if config.peft_type != PeftType.LORA:
        raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
    # to_return = lora_state_dict(model, bias=model.peft_config.bias)
    # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
    # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
    bias = config.bias
    if bias == "none":
        to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
    elif bias == "all":
        to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        for k in state_dict:
            if "lora_" in k:
                to_return[k] = state_dict[k]
                bias_name = k.split("lora_")[0] + "bias"
                if bias_name in state_dict:
                    to_return[bias_name] = state_dict[bias_name]
    else:
        raise NotImplementedError
    to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))}
    if config.use_dora:
        # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
        # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
        # we want the state_dict format not to change, we remove the "weight" part.
        new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"

        def renamed_dora_weights(k):
            if k.endswith(new_dora_suffix):
                k = k[:-7]  # remove ".weight"
            return k

        to_return = {renamed_dora_weights(k): v for k, v in to_return.items()}

    # DEAL WITH EMBEDDINGS
    # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
    is_embedding_in_target_modules = False
    if (
        save_embedding_layers == "auto"
        and hasattr(config, "target_modules")
        and any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES)
    ):
        warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.")
        save_embedding_layers = is_embedding_in_target_modules = True
    elif save_embedding_layers == "auto":
        vocab_size = getattr(getattr(model, "config", None), "vocab_size", None)
        model_id = getattr(config, "base_model_name_or_path", None)

        # For some models e.g. diffusers the text config file is stored in a subfolder
        # we need to make sure we can download that config.
        has_base_config = False

        # ensure that this check is not performed in HF offline mode, see #1452
        if model_id is not None:
            local_config_exists = os.path.exists(os.path.join(model_id, "config.json"))
            exists = local_config_exists or check_file_exists_on_hf_hub(model_id, "config.json")
            if exists is None:
                # check failed, could not determine if it exists or not
                warnings.warn(
                    f"Could not find a config file in {model_id} - will assume that the vocabulary was not modified."
                )
                has_base_config = False
            else:
                has_base_config = exists

        # check if the vocab size of the base model is different from the vocab size of the finetuned model
        if (
            vocab_size
            and model_id
            and has_base_config
            and (vocab_size != model.config.__class__.from_pretrained(model_id).vocab_size)
        ):
            warnings.warn(
                "Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning."
            )
            save_embedding_layers = True
        else:
            save_embedding_layers = False

    if save_embedding_layers and hasattr(model, "get_input_embeddings"):
        for layer in [model.get_input_embeddings(), model.get_output_embeddings()]:
            if not is_embedding_in_target_modules or has_valid_embedding_base_layer(layer):
                # support from version >= 0.6.2
                embedding_module_name = get_embedding_layer_name(model, layer, is_embedding_in_target_modules)
                if embedding_module_name:
                    to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k})
    elif save_embedding_layers:
        warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")

    return to_return


def gather_state_dict_fast(
    state_dict: Dict[str, torch.Tensor],
    group: dist.ProcessGroup,
    device: Optional[Union[torch.device, str]] = None,
    dst: int = 0,
) -> Optional[Dict[str, torch.Tensor]]:
    if device is None:
        device = get_current_device()
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    metadata = [(k, v.shape, v.dtype) for k, v in state_dict.items()]
    all_meta_data = [None] * world_size
    if rank == dst:
        returned_state_dict = state_dict.copy()
        dist.gather_object(metadata, all_meta_data, dst=dist.get_global_rank(group, rank), group=group)
        for i, target_metadata in enumerate(all_meta_data):
            if i == dst:
                continue
            ops = []
            for k, shape, dtype in target_metadata:
                buffer = torch.empty(shape, dtype=dtype, device=get_current_device())
                returned_state_dict[k] = buffer
                ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group))
            reqs = dist.batch_isend_irecv(ops)
            for req, (k, *_) in zip(reqs, target_metadata):
                req.wait()
                returned_state_dict[k] = returned_state_dict[k].to(device)
        return returned_state_dict
    else:
        dist.gather_object(metadata, dst=dist.get_global_rank(group, dst), group=group)
        ops = []
        for k, *_ in metadata:
            ops.append(dist.P2POp(dist.isend, state_dict[k], dist.get_global_rank(group, dst), group))
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
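To make the intent of the new `get_lora_state_dict` helper concrete, here is a minimal, hedged usage sketch. It assumes a plain PEFT-wrapped Hugging Face model; the `gpt2` model id, the `LoraConfig` values, and the import path of the patched module are illustrative assumptions.

# Hedged usage sketch for get_lora_state_dict; model id, LoraConfig values,
# adapter name and the module path are illustrative assumptions.
from colossalai.checkpoint_io.utils import get_lora_state_dict  # assumed location of the patched module
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM", r=8))

full_sd = peft_model.state_dict()
lora_sd = get_lora_state_dict(peft_model, full_sd, adapter_name="default", save_embedding_layers="auto")

# lora_sd now contains only the `lora_*` tensors (plus biases or embedding
# weights when the adapter config asks for them), ready to be written out
# as an adapter checkpoint.
print(len(lora_sd), "adapter tensors out of", len(full_sd))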
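`gather_state_dict_fast` exchanges tensor metadata via `gather_object` and then moves the tensors with batched point-to-point ops, so only the destination rank ever materializes the full dictionary. A hedged sketch of how it might be used to collect per-rank LoRA shards before writing a checkpoint follows; the group choice, shard contents, output path, and module path are assumptions.

# Hedged sketch: collect each rank's local LoRA tensors on rank 0 and write
# them out there. Assumes torch.distributed is already initialized (e.g. via
# torchrun) and that the patched module lives at colossalai.checkpoint_io.utils.
import torch
import torch.distributed as dist

from colossalai.checkpoint_io.utils import gather_state_dict_fast  # assumed path
from colossalai.utils import get_current_device

group = dist.group.WORLD
local_shard = {
    # keys and shapes are arbitrary placeholders; tensors should live on the
    # device matching the backend (e.g. CUDA for NCCL).
    f"rank{dist.get_rank(group)}.lora_A.weight": torch.randn(8, 16, device=get_current_device()),
}

gathered = gather_state_dict_fast(local_shard, group, dst=0)
if dist.get_rank(group) == 0:
    # Only the destination rank receives the merged dict; other ranks return None.
    torch.save(gathered, "lora_adapter_shards.pt")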