Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-04-27 03:21:47 +00:00

[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)

* [shardformer] support pipeline for deepseek v3
* [checkpointio] fix lora save
* [devops] update ci env
* [booster] optimize lora
* fix test
* fix test

This commit is contained in: parent ec73f1b5e2, commit 014837e725
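For orientation before the per-file hunks, here is a minimal end-to-end sketch of the LoRA flow this commit touches. The model path, plugin choice, and hyperparameters are illustrative assumptions, not taken from the diff, and the script is meant to run under a distributed launcher such as torchrun.

```python
import colossalai
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()  # run under torchrun / colossalai run

model = AutoModelForCausalLM.from_pretrained("path/to/base-model")  # hypothetical checkpoint
plugin = LowLevelZeroPlugin(stage=1, precision="bf16")
booster = Booster(plugin=plugin)

# LoRA must be enabled before boosting (see the assert kept in HybridParallelPlugin below).
model = booster.enable_lora(model, lora_config=LoraConfig(task_type="CAUSAL_LM"))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training loop ...

# Only the adapter weights and config are written; with this commit every rank
# prepares the adapter state dict and only the master rank writes the files.
booster.save_lora_as_pretrained(model, "ckpt/lora_adapter", use_safetensors=True)
```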
.github/workflows/build_on_pr.yml (vendored)
@@ -166,6 +166,7 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

       - name: Collate artifact
         env:

.github/workflows/build_on_schedule.yml (vendored)

@@ -70,6 +70,7 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

       - name: Notify Lark
         id: message-preparation

@@ -79,3 +79,4 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

@@ -73,3 +73,4 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

@@ -67,6 +67,7 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

       - name: Notify Lark
         id: message-preparation
@@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, T
 import numpy as np
 import torch
 import torch.distributed as dist
+from peft import PeftModel
 from torch import Tensor, inf
 from torch.distributed import ProcessGroup, get_world_size
 from torch.nn import Module, SyncBatchNorm
@@ -219,11 +220,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         with self._hook_context():
             return super().forward(*args, **kwargs)

-    def unwrap(self):
-        module = super().unwrap()
-        if isinstance(module, DDP):
-            module = module.module
-        return module
+    def unwrap(self, unwrap_peft: bool = True):
+        model = self.module
+        if isinstance(model, DDP):
+            model = model.module
+        if unwrap_peft and isinstance(model, PeftModel):
+            model = model.get_base_model()
+        return model

     def _force_wait_all_gather(self):
         for p in self.module.parameters():
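The `unwrap` change above gives the wrappers a common contract: by default the base transformer module is returned, while checkpoint IO passes `unwrap_peft=False` to keep the `PeftModel` so it can call `save_pretrained`. A hedged sketch with a hypothetical helper name:

```python
from peft import PeftModel

from colossalai.interface import ModelWrapper


def get_adapter_for_saving(wrapped: ModelWrapper) -> PeftModel:
    # Checkpoint IO keeps the PeftModel so it can call save_pretrained();
    # callers that want the plain transformer keep the default unwrap_peft=True.
    adapter = wrapped.unwrap(unwrap_peft=False)
    assert isinstance(adapter, PeftModel), "enable LoRA before saving"
    return adapter
```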
@@ -1509,7 +1512,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         from peft import PeftModel, get_peft_model

         assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
-        assert self.pp_size == 1 and self.tp_size == 1
+        assert self.tp_size == 1
         self.lora_enabled = True
         self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])

@@ -359,23 +359,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async
         )

-    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
-        if os.path.isfile(checkpoint):
-            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
-            return
-        from peft import PeftModel
-
-        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before saving!"
         model._force_wait_all_gather()
-        peft_model = model.unwrap()
-        assert isinstance(
-            peft_model, PeftModel
-        ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(
-            checkpoint,
-            safe_serialization=use_safetensors,
-            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
-        )
+        super().save_lora_as_pretrained(model, checkpoint, use_safetensors, state_dict=state_dict)


 class LowLevelZeroPlugin(DPPluginBase):
@@ -2,6 +2,7 @@ from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
+from peft import PeftModel
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -166,7 +167,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         )

     def save_lora_as_pretrained(
-        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+        self,
+        model: Union[nn.Module, ModelWrapper],
+        checkpoint: str,
+        use_safetensors: bool = False,
+        state_dict: Optional[dict] = None,
     ) -> None:
         """
         Save the lora adapters and adapter configuration file to checkpoint directory.
@@ -174,15 +179,17 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         from peft import PeftModel

         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
-        if self.coordinator.is_master():
-            peft_model = model.unwrap()
+        peft_model = model.unwrap(unwrap_peft=False)
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
+        if state_dict is None:
+            state_dict = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, peft_model.state_dict())
+        if self.coordinator.is_master():
             return peft_model.save_pretrained(
                 checkpoint,
                 safe_serialization=use_safetensors,
-                state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+                state_dict=state_dict,
             )


@@ -191,8 +198,11 @@ class TorchDDPModel(ModelWrapper):
         super().__init__(module)
         self.module = DDP(module, *args, **kwargs)

-    def unwrap(self):
-        return self.module.module
+    def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
+        model = self.module.module
+        if unwrap_peft and isinstance(model, PeftModel):
+            model = model.get_base_model()
+        return model


 class TorchDDPPlugin(DPPluginBase):
@@ -437,9 +437,6 @@ class TorchFSDPModel(ModelWrapper):
         super().__init__(module)
         self.module = FSDP(module, *args, **kwargs)

-    def unwrap(self):
-        return self.module
-

 class FSDPOptimizerWrapper(OptimizerWrapper):
     def __init__(self, optimizer: Optimizer, model: nn.Module):
@@ -437,7 +437,11 @@ class CheckpointIO(ABC):

     @abstractmethod
     def save_lora_as_pretrained(
-        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+        self,
+        model: Union[nn.Module, ModelWrapper],
+        checkpoint: str,
+        use_safetensors: bool = False,
+        state_dict: Optional[dict] = None,
     ) -> None:
         """
         Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
@@ -446,4 +450,5 @@ class CheckpointIO(ABC):
             model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
             checkpoint (str): Path to the checkpoint directory. It must be a local path.
             use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
+            state_dict (Optional[dict], optional): The state dict to save. Defaults to None.
         """
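The new optional `state_dict` argument lets a caller (typically a subclass that has already gathered adapter weights across ranks) hand a ready-made dict to the writer instead of having it re-read `peft_model.state_dict()`. A hedged sketch of the pattern; the helper name and the commented call are illustrative:

```python
import torch
from torch.utils._pytree import tree_map


def to_saveable(state_dict: dict) -> dict:
    # Same tree_map pattern as in the diff: detach and move to CPU before writing.
    return tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state_dict)


# checkpoint_io.save_lora_as_pretrained(
#     model, "ckpt/lora", use_safetensors=True, state_dict=to_saveable(peft_model.state_dict())
# )
```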
@@ -308,5 +308,7 @@ class GeneralCheckpointIO(CheckpointIO):
             )
         )

-    def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
+    def save_lora_as_pretrained(
+        self, model: nn.Module, checkpoint: str, use_safetensors: bool = False, state_dict: Optional[dict] = None
+    ) -> None:
         raise NotImplementedError
@@ -33,6 +33,8 @@ from .utils import (
     async_save_state_dict_shards,
     create_pinned_state_dict,
     gather_distributed_param,
+    gather_state_dict_fast,
+    get_lora_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
     is_safetensors_available,
@@ -1137,7 +1139,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

         return state_

-    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -1145,12 +1147,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         model._force_wait_all_gather()
-        peft_model = model.unwrap()
+        peft_model = model.unwrap(unwrap_peft=False)
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
+        if state_dict is None:
+            state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict())
+        if self.pp_size > 1:
+            lora_state_dict = get_lora_state_dict(peft_model, state_dict)
+            gathered_lora_state_dict = gather_state_dict_fast(lora_state_dict, self.pp_group, device="cpu")
+            if self.pp_rank == 0:
+                state_dict.update(gathered_lora_state_dict)
+        state_dict = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
         if self.coordinator.is_master():
             return peft_model.save_pretrained(
                 checkpoint,
                 safe_serialization=use_safetensors,
-                state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+                state_dict=state_dict,
             )
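From the user's side, the visible change is that LoRA saving now works under pipeline parallelism, and it must be called collectively: every pp rank prepares and sends its adapter shard, pp rank 0 merges, and only the global master writes. A hedged configuration sketch; the plugin arguments are examples only and the boost/training steps are elided:

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # requires a distributed launch, e.g. torchrun

plugin = HybridParallelPlugin(tp_size=1, pp_size=2, zero_stage=1, microbatch_size=1)
booster = Booster(plugin=plugin)

# ... booster.enable_lora(model, lora_config=...), booster.boost(...), training ...
# Every pp rank must call this; only the master rank actually writes the files:
# booster.save_lora_as_pretrained(model, "ckpt/lora_adapter")
```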
@@ -10,6 +10,7 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import get_global_rank
+from torch.utils._pytree import tree_map

 from colossalai.checkpoint_io import CheckpointIndexFile
 from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
@@ -17,6 +18,8 @@ from colossalai.checkpoint_io.index_file import CheckpointIndexFile
 from colossalai.checkpoint_io.utils import (
     StateDictSharder,
     gather_distributed_param,
+    gather_state_dict_fast,
+    get_lora_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
     load_shard_state_dict,
@@ -889,3 +892,26 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
             optimizer.optim.state[param] = sharded_state
         sharded_optimizer_loading_epilogue(optimizer.optim)
         dist.barrier()
+
+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict=None):
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+        from peft import PeftModel
+
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model._force_wait_all_gather()
+        peft_model = model.unwrap(unwrap_peft=False)
+        assert isinstance(
+            peft_model, PeftModel
+        ), "The model doesn't have lora adapters, please enable lora before saving."
+        if state_dict is None:
+            state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict())
+        if self.ep_size > 1:
+            lora_state_dict = get_lora_state_dict(peft_model, state_dict)
+            moe_params = set(n for n, p in peft_model.named_parameters() if is_moe_tensor(p))
+            expert_state_dict = {n: p for n, p in lora_state_dict.items() if n in moe_params}
+            gathered_expert_state_dict = gather_state_dict_fast(expert_state_dict, self.ep_group)
+            if self.ep_rank == 0:
+                state_dict.update(gathered_expert_state_dict)
+        return super().save_lora_as_pretrained(model, checkpoint, use_safetensors, state_dict)
@@ -2,6 +2,7 @@
 import concurrent.futures
 import os
 import re
+import warnings
 from collections import abc as container_abcs
 from collections import defaultdict
 from itertools import chain
@@ -9,8 +10,12 @@ from pathlib import Path
 from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union

 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from packaging.version import Version
+from peft import PeftModel, PeftType
+from peft.utils.other import EMBEDDING_LAYER_NAMES, check_file_exists_on_hf_hub
+from peft.utils.save_and_load import get_embedding_layer_name, has_valid_embedding_base_layer
 from torch.optim import Optimizer
 from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

@@ -21,6 +26,7 @@ from colossalai.tensor.d_tensor import (
     to_global,
     to_global_for_customized_distributed_tensor,
 )
+from colossalai.utils import get_current_device
 from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat

 SAFE_WEIGHTS_NAME = "model.safetensors"
@@ -1004,3 +1010,138 @@ def load_state_dict_shards(
             futures.append(future)
         for future in concurrent.futures.as_completed(futures):
             yield future.result()
+
+
+# adapted from `peft/utils/save_and_load.py`
+def get_lora_state_dict(
+    model: PeftModel, state_dict: dict, adapter_name="default", save_embedding_layers="auto"
+) -> dict:
+    config = model.peft_config[adapter_name]
+    if config.peft_type != PeftType.LORA:
+        raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
+    # to_return = lora_state_dict(model, bias=model.peft_config.bias)
+    # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
+    # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
+    bias = config.bias
+    if bias == "none":
+        to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
+    elif bias == "all":
+        to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
+    elif bias == "lora_only":
+        to_return = {}
+        for k in state_dict:
+            if "lora_" in k:
+                to_return[k] = state_dict[k]
+                bias_name = k.split("lora_")[0] + "bias"
+                if bias_name in state_dict:
+                    to_return[bias_name] = state_dict[bias_name]
+    else:
+        raise NotImplementedError
+    to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))}
+    if config.use_dora:
+        # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
+        # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
+        # we want the state_dict format not to change, we remove the "weight" part.
+        new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"
+
+        def renamed_dora_weights(k):
+            if k.endswith(new_dora_suffix):
+                k = k[:-7]  # remove ".weight"
+            return k
+
+        to_return = {renamed_dora_weights(k): v for k, v in to_return.items()}
+
+    # DEAL WITH EMBEDDINGS
+    # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
+    is_embedding_in_target_modules = False
+    if (
+        save_embedding_layers == "auto"
+        and hasattr(config, "target_modules")
+        and any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES)
+    ):
+        warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.")
+        save_embedding_layers = is_embedding_in_target_modules = True
+    elif save_embedding_layers == "auto":
+        vocab_size = getattr(getattr(model, "config", None), "vocab_size", None)
+        model_id = getattr(config, "base_model_name_or_path", None)
+
+        # For some models e.g. diffusers the text config file is stored in a subfolder
+        # we need to make sure we can download that config.
+        has_base_config = False
+
+        # ensure that this check is not performed in HF offline mode, see #1452
+        if model_id is not None:
+            local_config_exists = os.path.exists(os.path.join(model_id, "config.json"))
+            exists = local_config_exists or check_file_exists_on_hf_hub(model_id, "config.json")
+            if exists is None:
+                # check failed, could not determine if it exists or not
+                warnings.warn(
+                    f"Could not find a config file in {model_id} - will assume that the vocabulary was not modified."
+                )
+                has_base_config = False
+            else:
+                has_base_config = exists
+
+        # check if the vocab size of the base model is different from the vocab size of the finetuned model
+        if (
+            vocab_size
+            and model_id
+            and has_base_config
+            and (vocab_size != model.config.__class__.from_pretrained(model_id).vocab_size)
+        ):
+            warnings.warn(
+                "Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning."
+            )
+            save_embedding_layers = True
+        else:
+            save_embedding_layers = False
+
+    if save_embedding_layers and hasattr(model, "get_input_embeddings"):
+        for layer in [model.get_input_embeddings(), model.get_output_embeddings()]:
+            if not is_embedding_in_target_modules or has_valid_embedding_base_layer(layer):
+                # support from version >= 0.6.2
+                embedding_module_name = get_embedding_layer_name(model, layer, is_embedding_in_target_modules)
+                if embedding_module_name:
+                    to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k})
+    elif save_embedding_layers:
+        warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")
+
+    return to_return
+
+
+def gather_state_dict_fast(
+    state_dict: Dict[str, torch.Tensor],
+    group: dist.ProcessGroup,
+    device: Optional[Union[torch.device, str]] = None,
+    dst: int = 0,
+) -> Optional[Dict[str, torch.Tensor]]:
+    if device is None:
+        device = get_current_device()
+    rank = dist.get_rank(group)
+    world_size = dist.get_world_size(group)
+    metadata = [(k, v.shape, v.dtype) for k, v in state_dict.items()]
+    all_meta_data = [None] * world_size
+    if rank == dst:
+        returned_state_dict = state_dict.copy()
+        dist.gather_object(metadata, all_meta_data, dst=dist.get_global_rank(group, rank), group=group)
+        for i, target_metadata in enumerate(all_meta_data):
+            if i == dst:
+                continue
+            ops = []
+            for k, shape, dtype in target_metadata:
+                buffer = torch.empty(shape, dtype=dtype, device=get_current_device())
+                returned_state_dict[k] = buffer
+                ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group))
+            reqs = dist.batch_isend_irecv(ops)
+            for req, (k, *_) in zip(reqs, target_metadata):
+                req.wait()
+                returned_state_dict[k] = returned_state_dict[k].to(device)
+        return returned_state_dict
+    else:
+        dist.gather_object(metadata, dst=dist.get_global_rank(group, dst), group=group)
+        ops = []
+        for k, *_ in metadata:
+            ops.append(dist.P2POp(dist.isend, state_dict[k], dist.get_global_rank(group, dst), group))
+        reqs = dist.batch_isend_irecv(ops)
+        for req in reqs:
+            req.wait()
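A hedged usage sketch for `gather_state_dict_fast`: it is a collective over `group`; tensor metadata is exchanged with `gather_object`, tensors move via batched P2P ops, and only the `dst` rank receives the merged dict (all other ranks return `None`). The wrapper name below is hypothetical and assumes the process group is already initialized:

```python
from typing import Dict, Optional

import torch
import torch.distributed as dist

from colossalai.checkpoint_io.utils import gather_state_dict_fast


def gather_lora_shard(
    local_shard: Dict[str, torch.Tensor], group: dist.ProcessGroup
) -> Optional[Dict[str, torch.Tensor]]:
    # Must be called by every rank in `group`; only rank 0 of the group gets the
    # merged dict back (moved to CPU here), the other ranks get None.
    return gather_state_dict_fast(local_shard, group, device="cpu", dst=0)
```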
@@ -1,4 +1,5 @@
 import torch.nn as nn
+from peft import PeftModel


 class ModelWrapper(nn.Module):
@@ -13,13 +14,17 @@ class ModelWrapper(nn.Module):
         super().__init__()
         self.module = module

-    def unwrap(self):
+    def unwrap(self, unwrap_peft: bool = True):
         """
         Unwrap the model to return the original model for checkpoint saving/loading.
         """
         if isinstance(self.module, ModelWrapper):
-            return self.module.unwrap()
-        return self.module
+            model = self.module.unwrap()
+        else:
+            model = self.module
+        if unwrap_peft and isinstance(model, PeftModel):
+            model = model.get_base_model()
+        return model

     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
@@ -156,7 +156,9 @@ def _check_for_nccl_hccl_backend(group):
     while isinstance(pg, c10d._ProcessGroupWrapper):
         pg = pg.wrapped_pg

-    return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and pg.name() == c10d.Backend.NCCL
+    return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and (
+        pg.name() == c10d.Backend.NCCL or pg.name() == c10d.Backend.HCCL
+    )


 def _check_device(group):
@@ -4,9 +4,10 @@ import numpy as np
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
+from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

 from colossalai.lazy import LazyInitContext
 from colossalai.moe._operation import (
@@ -16,6 +17,7 @@ from colossalai.moe._operation import (
     EPGradScalerOut,
     all_to_all_uneven,
 )
+from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer.linear import ParallelModule
 from colossalai.shardformer.shard.utils import set_tensors_to_none
 from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
@@ -167,6 +169,9 @@ def deepseek_v3_model_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
+    stage_manager: Optional[PipelineStageManager] = None,
+    stage_index: Optional[List[int]] = None,
+    hidden_states_internal: Optional[torch.Tensor] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
@@ -203,8 +208,11 @@
         )
         position_ids = position_ids.unsqueeze(0)

+    if stage_manager is None or stage_manager.is_first_stage():
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
+    else:
+        inputs_embeds = hidden_states_internal

     if self._use_flash_attention_2:
         # 2d mask is passed through the layers
@@ -226,7 +234,11 @@
     all_self_attns = () if output_attentions else None
     next_decoder_cache = None

-    for i, decoder_layer in enumerate(self.layers):
+    if stage_index is not None:
+        start_idx, end_idx = stage_index
+    else:
+        start_idx, end_idx = 0, len(self.layers)
+    for i, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

@@ -258,6 +270,7 @@
         if output_attentions:
             all_self_attns += (layer_outputs[1],)

+    if stage_manager is None or stage_manager.is_last_stage():
         hidden_states = self.norm(hidden_states)

     # add hidden states from the last decoder layer
@@ -267,6 +280,10 @@
     next_cache = None
     if use_cache:
         next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+    if stage_manager is not None and not stage_manager.is_last_stage():
+        return {
+            "hidden_states_internal": hidden_states,
+        }
     if not return_dict:
         return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
     return BaseModelOutputWithPast(
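The stage handoff contract introduced above: a non-last stage returns `{"hidden_states_internal": ...}`, and the pipeline schedule is expected to feed that dict into the next stage's forward as keyword arguments. A conceptual sketch with a hypothetical helper; in practice each pipeline rank owns its own `PipelineStageManager` and the schedule does this wiring:

```python
from typing import Any, Dict

import torch

from colossalai.shardformer.modeling.deepseek_v3 import deepseek_v3_for_causal_lm_forward


def run_two_stage_forward(stage0_model, stage1_model, stage0_mgr, stage1_mgr, batch: Dict[str, Any]) -> torch.Tensor:
    # Stage 0 holds layers [0, 3); it is not the last stage, so it returns the
    # intermediate dict instead of logits/loss.
    out0 = deepseek_v3_for_causal_lm_forward(
        stage0_model, **batch, stage_manager=stage0_mgr, stage_index=[0, 3]
    )
    assert set(out0) == {"hidden_states_internal"}

    # Stage 1 holds layers [3, 6); the intermediate activations are passed in as
    # the hidden_states_internal keyword and the usual CausalLM output comes back.
    out1 = deepseek_v3_for_causal_lm_forward(
        stage1_model, **batch, stage_manager=stage1_mgr, stage_index=[3, 6], **out0
    )
    return out1.loss
```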
@@ -275,3 +292,94 @@
         hidden_states=all_hidden_states,
         attentions=all_self_attns,
     )
+
+
+def deepseek_v3_for_causal_lm_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    stage_manager: Optional[PipelineStageManager] = None,
+    stage_index: Optional[List[int]] = None,
+    hidden_states_internal: Optional[torch.Tensor] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM
+
+    >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = deepseek_v3_model_forward(
+        self.model,
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        stage_manager=stage_manager,
+        stage_index=stage_index,
+        hidden_states_internal=hidden_states_internal,
+    )
+    if stage_manager is not None and not stage_manager.is_last_stage():
+        return outputs
+
+    hidden_states = outputs[0]
+
+    logits = self.lm_head(hidden_states)
+    logits = logits.float()
+
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
@@ -1,9 +1,14 @@
-from typing import Dict, Union
+from functools import partial
+from typing import Callable, Dict, List, Union

 import torch.nn as nn

 from colossalai.shardformer.layer import FusedRMSNorm
-from colossalai.shardformer.modeling.deepseek_v3 import EpDeepseekV3MoE
+from colossalai.shardformer.modeling.deepseek_v3 import (
+    EpDeepseekV3MoE,
+    deepseek_v3_for_causal_lm_forward,
+    deepseek_v3_model_forward,
+)
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]
@@ -12,8 +17,9 @@ __all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]
 class DeepseekV3Policy(Policy):
     def config_sanity_check(self):
         assert not self.shard_config.enable_tensor_parallelism, "DeepSeekV3 does not support tensor parallelism"
-        assert self.shard_config.pipeline_stage_manager is None, "DeepSeekV3 does not support pipeline parallelism"
         assert not self.shard_config.enable_sequence_parallelism, "DeepSeekV3 does not support sequence parallelism"
+        if self.shard_config.pipeline_stage_manager:
+            assert not self.shard_config.pipeline_stage_manager.use_zbv, "DeepSeekV3 does not support ZBV"

     def preprocess(self):
         return self.model
@@ -23,7 +29,10 @@
         policy = {}

         # support gradient checkpointing
-        # policy["DeepseekV3Model"] = ModulePolicyDescription(method_replacement={"forward": deepseek_v3_model_forward})
+        if self.shard_config.pipeline_stage_manager is None:
+            policy["DeepseekV3Model"] = ModulePolicyDescription(
+                method_replacement={"forward": deepseek_v3_model_forward}
+            )

         if self.shard_config.expert_parallel_size > 1:
             # expert parallel
@@ -74,10 +83,82 @@
     def postprocess(self):
         return self.model

+    def set_pipeline_forward(self, model_cls: str, new_forward: Callable, policy: Dict) -> None:
+        """If under pipeline parallel setting, replacing the original forward method of huggingface
+        to customized forward method, and add this changing to policy."""
+        if self.pipeline_stage_manager:
+            num_layers = self.model.config.num_hidden_layers
+            stage_manager = self.pipeline_stage_manager
+
+            layers_per_stage = stage_manager.distribute_layers(num_layers)
+            stage_index = stage_manager.get_stage_index(layers_per_stage)
+            method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+            self.append_or_create_method_replacement(
+                description=method_replacement, policy=policy, target_key=model_cls
+            )
+
+        return
+
+    def get_held_layers(self) -> List[nn.Module]:
+        """Get pipeline layers for current stage."""
+        assert self.pipeline_stage_manager is not None
+
+        module = self.model
+        if module.__class__.__name__.startswith("PeftModel"):
+            module = module.get_base_model()
+        if module.__class__.__name__ != "DeepseekV3Model":
+            module = module.model
+
+        stage_manager = self.pipeline_stage_manager
+
+        held_layers = []
+
+        if stage_manager.is_interleave:
+            assert stage_manager.num_model_chunks is not None
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_indices = stage_manager.get_stage_index(layers_per_stage)
+            stage_manager.stage_indices = stage_indices
+            if stage_manager.is_first_stage(ignore_chunk=True):
+                held_layers.append(module.embed_tokens)
+            for start_idx, end_idx in stage_indices:
+                held_layers.extend(module.layers[start_idx:end_idx])
+            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+            ):
+                # for zbv, when is_first_stage (last fwd), we append norm
+                # for interleaved, when is_last_stage (last fwd), we also append norm
+                held_layers.append(module.norm)
+        else:
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            if stage_manager.is_first_stage():
+                held_layers.append(module.embed_tokens)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
+            held_layers.extend(module.layers[start_idx:end_idx])
+            if stage_manager.is_last_stage():
+                held_layers.append(module.norm)
+        return held_layers
+

 class DeepseekV3ModelPolicy(DeepseekV3Policy):
-    pass
+    def module_policy(self):
+        policy = super().module_policy()
+        if self.shard_config.pipeline_stage_manager:
+            self.set_pipeline_forward("DeepseekV3Model", deepseek_v3_model_forward, policy)
+        return policy


 class DeepseekV3ForCausalLMPolicy(DeepseekV3Policy):
-    pass
+    def module_policy(self):
+        policy = super().module_policy()
+        if self.shard_config.pipeline_stage_manager:
+            self.set_pipeline_forward("DeepseekV3ForCausalLM", deepseek_v3_for_causal_lm_forward, policy)
+        return policy
+
+    def get_held_layers(self):
+        stage_manager = self.pipeline_stage_manager
+        held_layers = super().get_held_layers()
+        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+            held_layers.append(self.model.lm_head)
+        elif stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.lm_head)
+        return held_layers
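With `get_held_layers` and the pipeline forwards wired into the policies, DeepSeek V3 can now be sharded across pipeline stages (ZBV excluded per the assert above, tensor parallelism still unsupported). An illustrative plugin configuration; the sizes and options are assumptions, not taken from the diff:

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()  # e.g. 4 ranks: 2 pipeline stages x 2 expert-parallel ranks

plugin = MoeHybridParallelPlugin(
    tp_size=1,            # tensor parallelism is still unsupported for DeepSeek V3
    pp_size=2,            # pipeline parallelism is what this policy change enables
    ep_size=2,            # expert parallelism for the MoE layers
    zero_stage=1,
    precision="bf16",
    num_microbatches=2,
)
booster = Booster(plugin=plugin)
# model, optimizer, ... = booster.boost(model, optimizer, criterion, dataloader)
# and the loss is then driven through booster.execute_pipeline(...), as in the benchmark below.
```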
@@ -60,9 +60,9 @@ MODEL_CONFIGS = {
         attn_implementation="flash_attention_2",
         trust_remote_code=True,
     ),
-    "v3-6b": AutoConfig.from_pretrained(
+    "v3-7b": AutoConfig.from_pretrained(
         "deepseek-ai/DeepSeek-V3",
-        num_hidden_layers=5,
+        num_hidden_layers=6,
         first_k_dense_replace=2,
         n_routed_experts=32,
         vocab_size=8192,
@@ -210,14 +210,15 @@
         config, trust_remote_code=True, attn_implementation=attn_impl, torch_dtype=torch.bfloat16
     ).to(torch.bfloat16)
     if args.enable_lora:
-        booster.enable_lora(
+        model = booster.enable_lora(
             model,
             lora_config=LoraConfig(task_type="CAUSAL_LM", target_modules=["gate_proj", "up_proj", "down_proj"]),
         )

     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
-    if model.__class__.__name__.startswith("DeepseekV3"):
+    if config.__class__.__name__.startswith("DeepseekV3"):
         model.config.use_cache = False
         model.eval()
         # enable grad for moe layers
         for m in model.modules():
@@ -257,7 +258,10 @@
     ) as prof:  # , distributed_debug_mode(10, enable=True):
         if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1:
             data_iter = iter(dataloader)
-            for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
+            with tqdm(
+                range(len(dataloader)), desc="Step", disable=dist.get_rank() != dist.get_world_size() - 1
+            ) as pbar:
+                for step in pbar:
                     performance_evaluator.on_step_start(step)
                     outputs = booster.execute_pipeline(
                         data_iter,
@@ -267,23 +271,22 @@
                         return_loss=True,
                     )
                     loss = outputs["loss"]
-                    if dist.get_rank() == dist.get_world_size() - 1:
-                        print(f"Step {step} loss: {loss}")
+                    loss_scalar = loss.item() if loss is not None else None
+                    pbar.set_postfix({"loss": loss_scalar})
                     optimizer.step()
                     optimizer.zero_grad()

                     performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
                     prof.step()
-                    print(f"rank {dist.get_rank()} step {step} passed")
         else:
-            for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
+            with tqdm(dataloader, desc="Step", disable=not coordinator.is_master()) as pbar:
+                for step, batch in enumerate(pbar):
                     performance_evaluator.on_step_start(step)
                     outputs = model(**batch)
                     loss = outputs[0]
                     del outputs  # free memory

-                    if dist.get_rank() == dist.get_world_size() - 1:
-                        print(f"Step {step} loss: {loss}")
+                    pbar.set_postfix({"loss": loss.item()})

                     booster.backward(loss, optimizer)
                     optimizer.step()
@@ -126,6 +126,7 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo

     booster.save_lora_as_pretrained(model, model_ckpt_path)
     booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False)
     dist.barrier()
     new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config)
     new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
     check_state_dict_equal(model.state_dict(), new_model.state_dict())
@@ -51,7 +51,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         if target_grad is None:
             continue
         target_grad = target_grad.view(-1).chunk(dist.get_world_size(pg))[dist.get_rank(pg)]
-        assert_close(grad, target_grad, atol=3e-1, rtol=0)
+        assert_close(grad, target_grad, atol=5e-1, rtol=0)


 @parameterize(