mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-22 05:29:36 +00:00
[Colossal-LLaMA] Refactor latest APIs (#6030)
* refactor latest code * update api * add dummy dataset * update Readme * add setup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update files * add PP support * update arguments * update argument * reorg folder * update version * remove IB infor * update utils * update readme * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update save for zero * update save * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add apex * update --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
cc1b0efc17
commit
4a68efb7da
@ -30,7 +30,7 @@ Colossal-LLaMA
|
|||||||
- [Install](#install)
|
- [Install](#install)
|
||||||
- [0. Pre-requisite](#0-pre-requisite)
|
- [0. Pre-requisite](#0-pre-requisite)
|
||||||
- [1. Install required packages](#1-install-required-packages)
|
- [1. Install required packages](#1-install-required-packages)
|
||||||
- [2. Install `xentropy`, `layer_norm` and `rotary`](#2-install-xentropy-layer_norm-and-rotary)
|
- [2. Install Apex](#2-install-apex)
|
||||||
- [How to run](#how-to-run)
|
- [How to run](#how-to-run)
|
||||||
- [1. Init Tokenizer Preparation](#1-init-tokenizer-preparation)
|
- [1. Init Tokenizer Preparation](#1-init-tokenizer-preparation)
|
||||||
- [2. Init Model Preparation](#2-init-model-preparation)
|
- [2. Init Model Preparation](#2-init-model-preparation)
|
||||||
@ -297,17 +297,13 @@ Here is details about CLI arguments:
|
|||||||
#### 1. Install required packages
|
#### 1. Install required packages
|
||||||
```
|
```
|
||||||
cd Colossal-LLaMA
|
cd Colossal-LLaMA
|
||||||
pip install -r requirements.txt
|
pip install -e .
|
||||||
```
|
```
|
||||||
#### 2. Install `xentropy`, `layer_norm` and `rotary`
|
|
||||||
|
#### 2. Install Apex
|
||||||
```bash
|
```bash
|
||||||
git clone git@github.com:Dao-AILab/flash-attention.git
|
git clone git@github.com:NVIDIA/apex.git
|
||||||
# At the root folder
|
# Install from source.
|
||||||
cd csrc/xentropy && pip install .
|
|
||||||
# At the root folder
|
|
||||||
cd csrc/layer_norm && pip install .
|
|
||||||
# At the root folder
|
|
||||||
cd csrc/rotary && pip install .
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### How to run
|
### How to run
|
||||||
@ -427,25 +423,33 @@ Make sure master node can access all nodes (including itself) by ssh without pas
|
|||||||
Here is details about CLI arguments:
|
Here is details about CLI arguments:
|
||||||
* Pre-trained model path: `--pretrained`. Path to the pre-trained model in Hugging Face format.
|
* Pre-trained model path: `--pretrained`. Path to the pre-trained model in Hugging Face format.
|
||||||
* Dataset path: `--dataset`. Path to the pre-tokenized dataset.
|
* Dataset path: `--dataset`. Path to the pre-tokenized dataset.
|
||||||
* Booster plugin: `--plugin`. `gemini`, `gemini_auto`, `zero2`,`zero2_cpu` and `3d` are supported.For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/).
|
* Booster plugin: `--plugin`. `ddp`,`gemini`, `gemini_auto`, `zero2`,`zero2_cpu` and `3d` are supported.For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/).
|
||||||
* Intermediate checkpoint to load: `--load_checkpoint`. Path to the intermediate checkpoint. Saved checkpoint contains the states for `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`. If `load_checkpoint` points to the `modelling` folder, only the model weights will be loaded without any other states to support multi-stage training.
|
* Intermediate checkpoint to load: `--load_checkpoint`. Path to the intermediate checkpoint. Saved checkpoint contains the states for `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`. If `load_checkpoint` points to the `modelling` folder, only the model weights will be loaded without any other states to support multi-stage training.
|
||||||
* Save interval: `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
|
* Save interval: `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
|
||||||
* Checkpoint directory: `--save_dir`. The directory path to save checkpoint and intermediate states. Intermediate states include `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`.
|
* Checkpoint directory: `--save_dir`. The directory path to save checkpoint and intermediate states. Intermediate states include `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`.
|
||||||
* Tensorboard directory: `--tensorboard_dir`. The path to save tensorboard logs.
|
* Tensorboard directory: `--tensorboard_dir`. The path to save tensorboard logs.
|
||||||
* Configuration file: `--config_file`. The path to save the configuration file.
|
* Configuration file: `--config_file`. The path to save the configuration file.
|
||||||
* Number of epochs: `--num_epochs`. Number of training epochs. The default value is 1.
|
* Number of epochs: `--num_epochs`. Number of training epochs. The default value is 1.
|
||||||
* Micro batch size: `--micro_batch_size`. Batch size per GPU. The default value is 1.
|
* Batch size: `--batch_size`. Batch size per GPU. The default value is 1. For PP, it refers to number of samples per step.
|
||||||
* Learning rate: `--lr`. The default value is 3e-4.
|
* Learning rate: `--lr`. The default value is 3e-4.
|
||||||
* Max length: `--max_length`. Max context length. The default value is 4096.
|
* Max length: `--max_length`. Max context length. The default value is 4096.
|
||||||
* Mixed precision: `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
|
* Mixed precision: `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
|
||||||
* Gradient clipping: `--gradient_clipping`. The default value is 1.0.
|
* Gradient clipping: `--gradient_clipping`. The default value is 1.0.
|
||||||
* Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
|
* Weight decay: `--weight_decay`. The default value is 0.1.
|
||||||
* Warmup steps: `-s`, `--warmup_steps`. The default value is calculated by 0.025 warmup ratio.
|
* Warmup steps: `--warmup_steps`. The default value is calculated by 0.025 warmup ratio.
|
||||||
* Gradient checkpointing: `--use_grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size.
|
* Gradient checkpointing: `--use_grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size.
|
||||||
* Flash attention: `--use_flash_attn`. If you want to use flash attention, you must install `flash-attn` and related packages. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention.
|
* Flash attention: `--use_flash_attn`. If you want to use flash attention, you must install `flash-attn` and related packages. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention.
|
||||||
* Freeze non-embedding parameters: `--freeze_non_embeds_params`. Freeze non-embedding parameters. It can be helpful to align embeddings after extending vocabulary size.
|
* Freeze non-embedding parameters: `--freeze_non_embeds_params`. Freeze non-embedding parameters. It can be helpful to align embeddings after extending vocabulary size.
|
||||||
* Tensor parallelism size: `--tp`. TP size for 3d Parallelism. The default value is 1.
|
* Tensor parallelism size: `--tp`. TP size for 3d parallelism. The default value is 1. Used for 3d plugin.
|
||||||
* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1.
|
* Pipeline parallelism size: `--pp`. PP size for 3d parallelism. The default value is 1. Used for 3d plugin.
|
||||||
|
* Sequence parallelism size: `--sp`. SP size for 3d parallelism. The default value is 1. Used for 3d plugin.
|
||||||
|
* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1. Used for 3d plugin.
|
||||||
|
* Sequence parallelism mode: `--sp_mode`. SP mode, used for 3d plugin. Choose from "split_gather", "ring", "all_to_all".
|
||||||
|
* Switch for sequence parallelism: `--enable_sequence_parallelism`. Whether to enable SP, used for 3d plugin.
|
||||||
|
* Zero CPU offload: `--zero_cpu_offload`. Whether to use offloading, used for 3d plugin.
|
||||||
|
* Micro batch size: `--microbatch_size`. Batch size for each process in PP, used for 3d plugin.
|
||||||
|
* Number of dummy sample: `--num_samples`. Number of samples for benchmarking.
|
||||||
|
* Benchmark switch: `--benchmark`. Benchmark performance using random dataset.
|
||||||
|
|
||||||
##### 4.2 Arguments for Supervised Fine-tuning
|
##### 4.2 Arguments for Supervised Fine-tuning
|
||||||
We add support for gradient accumulation and NEFTuning for supervised fine-tuning and thus there are two more arguments apart from the arguments listed in [4.1 Arguments for Pretraining](#41-arguments-for-pretraining).
|
We add support for gradient accumulation and NEFTuning for supervised fine-tuning and thus there are two more arguments apart from the arguments listed in [4.1 Arguments for Pretraining](#41-arguments-for-pretraining).
|
||||||
|
@ -0,0 +1,24 @@
|
|||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from colossalai.accelerator import get_accelerator
|
||||||
|
|
||||||
|
|
||||||
|
class RandomDataset(Dataset):
|
||||||
|
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
|
||||||
|
self.num_samples = num_samples
|
||||||
|
self.max_length = max_length
|
||||||
|
self.input_ids = torch.randint(
|
||||||
|
0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
|
||||||
|
)
|
||||||
|
self.attention_mask = torch.ones_like(self.input_ids)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return {
|
||||||
|
"input_ids": self.input_ids[idx],
|
||||||
|
"attention_mask": self.attention_mask[idx],
|
||||||
|
"labels": self.input_ids[idx],
|
||||||
|
}
|
@ -1,352 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
import math
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from einops import rearrange
|
|
||||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
|
||||||
from transformers.models.llama.modeling_llama import (
|
|
||||||
LlamaAttention,
|
|
||||||
LlamaForCausalLM,
|
|
||||||
LlamaModel,
|
|
||||||
LlamaRMSNorm,
|
|
||||||
apply_rotary_pos_emb,
|
|
||||||
repeat_kv,
|
|
||||||
)
|
|
||||||
|
|
||||||
from colossalai.accelerator import get_accelerator
|
|
||||||
from colossalai.logging import get_dist_logger
|
|
||||||
|
|
||||||
logger = get_dist_logger()
|
|
||||||
|
|
||||||
if get_accelerator().name == "cuda":
|
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
|
|
||||||
from flash_attn.ops.rms_norm import rms_norm
|
|
||||||
|
|
||||||
def _prepare_decoder_attention_mask(
|
|
||||||
self: LlamaModel,
|
|
||||||
attention_mask: torch.BoolTensor,
|
|
||||||
input_shape: torch.Size,
|
|
||||||
inputs_embeds: torch.Tensor,
|
|
||||||
past_key_values_length: int,
|
|
||||||
) -> Optional[torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Decoder attetion mask
|
|
||||||
"""
|
|
||||||
if past_key_values_length > 0 and attention_mask is not None:
|
|
||||||
attention_mask = torch.cat(
|
|
||||||
tensors=(
|
|
||||||
torch.full(
|
|
||||||
size=(input_shape[0], past_key_values_length),
|
|
||||||
fill_value=True,
|
|
||||||
dtype=attention_mask.dtype,
|
|
||||||
device=attention_mask.device,
|
|
||||||
),
|
|
||||||
attention_mask,
|
|
||||||
),
|
|
||||||
dim=-1,
|
|
||||||
) # (bsz, past_key_values_length + q_len)
|
|
||||||
if attention_mask is not None and torch.all(attention_mask):
|
|
||||||
return None # Faster
|
|
||||||
return attention_mask
|
|
||||||
|
|
||||||
def attention_forward(
|
|
||||||
self: LlamaAttention,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
use_cache: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
"""
|
|
||||||
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
|
|
||||||
"""
|
|
||||||
if output_attentions:
|
|
||||||
logger.warning(
|
|
||||||
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
|
|
||||||
"return `None` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
if self.config.pretraining_tp > 1:
|
|
||||||
q_slicing, kv_slicing = (
|
|
||||||
dim // self.config.pretraining_tp
|
|
||||||
for dim in (
|
|
||||||
self.num_heads * self.head_dim,
|
|
||||||
self.num_key_value_heads * self.head_dim,
|
|
||||||
)
|
|
||||||
) # `Tuple[int, int]`
|
|
||||||
q_slices, k_slices, v_slices = (
|
|
||||||
proj.weight.split(slicing, dim=0)
|
|
||||||
for proj, slicing in (
|
|
||||||
(self.q_proj, q_slicing),
|
|
||||||
(self.k_proj, kv_slicing),
|
|
||||||
(self.v_proj, kv_slicing),
|
|
||||||
)
|
|
||||||
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
|
|
||||||
q, k, v = (
|
|
||||||
torch.cat(
|
|
||||||
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
for slices in (q_slices, k_slices, v_slices)
|
|
||||||
)
|
|
||||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
|
||||||
# (bsz, q_len, num_heads * head_dim),
|
|
||||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
|
||||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
|
||||||
else:
|
|
||||||
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
|
|
||||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
|
||||||
# (bsz, q_len, num_heads * head_dim),
|
|
||||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
|
||||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
|
||||||
|
|
||||||
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
|
|
||||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
|
|
||||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
|
|
||||||
q, k, v = (
|
|
||||||
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
for states, num_heads in (
|
|
||||||
(q, self.num_heads),
|
|
||||||
(k, self.num_key_value_heads),
|
|
||||||
(v, self.num_key_value_heads),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
|
|
||||||
past_kv_len = 0
|
|
||||||
if past_key_value is not None:
|
|
||||||
# if `past_key_value` is not None, `kv_len` > `q_len`.
|
|
||||||
past_kv_len = past_key_value[0].shape[-2]
|
|
||||||
kv_len += past_kv_len
|
|
||||||
|
|
||||||
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
|
|
||||||
cos, sin = self.rotary_emb(v, seq_len=kv_len)
|
|
||||||
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
|
|
||||||
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
|
|
||||||
if past_key_value is not None:
|
|
||||||
# reuse k, v, self_attention
|
|
||||||
k = torch.cat([past_key_value[0], k], dim=2)
|
|
||||||
v = torch.cat([past_key_value[1], v], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (k, v) if use_cache else None
|
|
||||||
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
|
||||||
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
|
|
||||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
|
||||||
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
|
|
||||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
|
||||||
|
|
||||||
key_padding_mask = attention_mask
|
|
||||||
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
|
|
||||||
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
|
|
||||||
|
|
||||||
if past_kv_len > 0:
|
|
||||||
q = torch.cat(
|
|
||||||
tensors=(
|
|
||||||
torch.full(
|
|
||||||
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
|
|
||||||
fill_value=0.0,
|
|
||||||
dtype=q.dtype,
|
|
||||||
device=q.device,
|
|
||||||
),
|
|
||||||
q,
|
|
||||||
),
|
|
||||||
dim=1,
|
|
||||||
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
|
|
||||||
|
|
||||||
if key_padding_mask is None:
|
|
||||||
# (bsz, past_kv_len + q_len, num_heads, head_dim)
|
|
||||||
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
|
|
||||||
output = rearrange(
|
|
||||||
output, pattern="... h d -> ... (h d)"
|
|
||||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
|
||||||
else:
|
|
||||||
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
|
|
||||||
kv, _, cu_kv_lens, max_kv_len = unpad_input(
|
|
||||||
hidden_states=torch.stack(tensors=(k, v), dim=2),
|
|
||||||
attention_mask=key_padding_mask,
|
|
||||||
)
|
|
||||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
|
||||||
q=q,
|
|
||||||
kv=kv,
|
|
||||||
cu_seqlens_q=cu_q_lens,
|
|
||||||
cu_seqlens_k=cu_kv_lens,
|
|
||||||
max_seqlen_q=max_q_len,
|
|
||||||
max_seqlen_k=max_kv_len,
|
|
||||||
dropout_p=0.0,
|
|
||||||
softmax_scale=None,
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
output = pad_input(
|
|
||||||
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
|
|
||||||
indices=indices,
|
|
||||||
batch=bsz,
|
|
||||||
seqlen=past_kv_len + q_len,
|
|
||||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
|
||||||
|
|
||||||
if past_kv_len > 0:
|
|
||||||
# Strip off the zero query outputs.
|
|
||||||
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
|
|
||||||
output = self.o_proj(output) # (bsz, q_len, hidden_size)
|
|
||||||
return output, None, past_key_value
|
|
||||||
|
|
||||||
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Formard function for RMS Norm
|
|
||||||
"""
|
|
||||||
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
|
|
||||||
|
|
||||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
|
||||||
for name, module in model.named_modules():
|
|
||||||
if isinstance(module, LlamaAttention):
|
|
||||||
module.forward = MethodType(attention_forward, module)
|
|
||||||
if isinstance(module, LlamaModel):
|
|
||||||
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
|
|
||||||
if isinstance(module, LlamaRMSNorm):
|
|
||||||
module.forward = MethodType(rms_norm_forward, module)
|
|
||||||
|
|
||||||
elif get_accelerator().name == "npu":
|
|
||||||
import torch_npu
|
|
||||||
|
|
||||||
class NPULlamaAttention(LlamaAttention):
|
|
||||||
use_flash: bool = True
|
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
self.setup()
|
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
self._softmax_scale = 1 / math.sqrt(self.head_dim)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
use_cache: bool = False,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
if self.config.pretraining_tp > 1:
|
|
||||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
|
||||||
query_slices = self.q_proj.weight.split(
|
|
||||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
|
||||||
)
|
|
||||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
|
||||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
|
||||||
|
|
||||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
|
||||||
query_states = torch.cat(query_states, dim=-1)
|
|
||||||
|
|
||||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
|
||||||
key_states = torch.cat(key_states, dim=-1)
|
|
||||||
|
|
||||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
|
||||||
value_states = torch.cat(value_states, dim=-1)
|
|
||||||
|
|
||||||
else:
|
|
||||||
query_states = self.q_proj(hidden_states)
|
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
||||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
|
||||||
if past_key_value is not None:
|
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
|
||||||
# reuse k, v, self_attention
|
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
|
||||||
|
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
||||||
|
|
||||||
if not self.use_flash:
|
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
|
||||||
|
|
||||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
|
||||||
raise ValueError(
|
|
||||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
|
||||||
f" {attn_weights.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
|
||||||
raise ValueError(
|
|
||||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
|
||||||
)
|
|
||||||
attn_weights = attn_weights + attention_mask
|
|
||||||
|
|
||||||
# upcast attention to fp32
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
|
||||||
else:
|
|
||||||
attn_output, *_ = torch_npu.npu_fusion_attention(
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
self.num_heads,
|
|
||||||
"BNSD",
|
|
||||||
atten_mask=attention_mask.bool(),
|
|
||||||
scale=self._softmax_scale,
|
|
||||||
padding_mask=None,
|
|
||||||
pre_tockens=65535,
|
|
||||||
next_tockens=0,
|
|
||||||
keep_prob=1.0,
|
|
||||||
inner_precise=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
|
||||||
raise ValueError(
|
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
|
||||||
f" {attn_output.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
|
||||||
|
|
||||||
if self.config.pretraining_tp > 1:
|
|
||||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
|
||||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
|
||||||
attn_output = sum(
|
|
||||||
[F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
if not output_attentions:
|
|
||||||
attn_weights = None
|
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
|
||||||
|
|
||||||
class NPURMSNorm(LlamaRMSNorm):
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
|
|
||||||
|
|
||||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
|
||||||
for name, module in model.named_modules():
|
|
||||||
if isinstance(module, LlamaAttention):
|
|
||||||
module.__class__ = NPULlamaAttention
|
|
||||||
module.setup()
|
|
||||||
if isinstance(module, LlamaRMSNorm):
|
|
||||||
module.__class__ = NPURMSNorm
|
|
36
applications/Colossal-LLaMA/colossal_llama/utils/utils.py
Normal file
36
applications/Colossal-LLaMA/colossal_llama/utils/utils.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
"""
|
||||||
|
Utils for Colossal-LLaMA
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from colossalai.booster import Plugin
|
||||||
|
|
||||||
|
|
||||||
|
def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
|
||||||
|
if plugin is not None:
|
||||||
|
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
|
||||||
|
tensor.div_(plugin.dp_size)
|
||||||
|
else:
|
||||||
|
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||||
|
tensor.div_(dist.get_world_size())
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_numel(model: torch.nn.Module) -> int:
|
||||||
|
return sum(p.numel() for p in model.parameters())
|
||||||
|
|
||||||
|
|
||||||
|
def format_numel_str(numel: int) -> str:
|
||||||
|
B = 1024**3
|
||||||
|
M = 1024**2
|
||||||
|
K = 1024
|
||||||
|
if numel >= B:
|
||||||
|
return f"{numel / B:.2f} B"
|
||||||
|
elif numel >= M:
|
||||||
|
return f"{numel / M:.2f} M"
|
||||||
|
elif numel >= K:
|
||||||
|
return f"{numel / K:.2f} K"
|
||||||
|
else:
|
||||||
|
return f"{numel}"
|
@ -1,15 +1,15 @@
|
|||||||
torch==2.1.2
|
torch==2.1.2
|
||||||
huggingface-hub
|
huggingface-hub
|
||||||
packaging==24.0
|
packaging==24.0
|
||||||
colossalai==0.3.6
|
colossalai>=0.4.0
|
||||||
autoflake==2.2.1
|
autoflake==2.2.1
|
||||||
black==23.9.1
|
black==23.9.1
|
||||||
transformers==4.34.1
|
transformers>=4.39.3
|
||||||
tensorboard==2.14.0
|
tensorboard==2.14.0
|
||||||
six==1.16.0
|
six==1.16.0
|
||||||
datasets
|
datasets
|
||||||
ninja==1.11.1
|
ninja==1.11.1
|
||||||
flash-attn>=2.0.0,<=2.0.5
|
flash-attn
|
||||||
tqdm
|
tqdm
|
||||||
sentencepiece==0.1.99
|
sentencepiece==0.1.99
|
||||||
protobuf<=3.20.0
|
protobuf<=3.20.0
|
||||||
|
37
applications/Colossal-LLaMA/setup.py
Normal file
37
applications/Colossal-LLaMA/setup.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_requirements(path):
|
||||||
|
with open(path, "r") as fd:
|
||||||
|
return [r.strip() for r in fd.readlines()]
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_readme():
|
||||||
|
with open("README.md", encoding="utf-8") as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_version():
|
||||||
|
with open("version.txt", "r") as f:
|
||||||
|
return f.read().strip()
|
||||||
|
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name="colossal_llama",
|
||||||
|
version=fetch_version(),
|
||||||
|
packages=find_packages(exclude=("*.egg-info",)),
|
||||||
|
description="Continual Pre-training and SFT for LLaMA",
|
||||||
|
long_description=fetch_readme(),
|
||||||
|
long_description_content_type="text/markdown",
|
||||||
|
license="Apache Software License 2.0",
|
||||||
|
url="https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA",
|
||||||
|
install_requires=fetch_requirements("requirements.txt"),
|
||||||
|
python_requires=">=3.7",
|
||||||
|
classifiers=[
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"License :: OSI Approved :: Apache Software License",
|
||||||
|
"Environment :: GPU :: NVIDIA CUDA",
|
||||||
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
|
"Topic :: System :: Distributed Computing",
|
||||||
|
],
|
||||||
|
)
|
@ -1,13 +1,20 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||||
|
local n=${1:-"9999"}
|
||||||
|
echo "GPU Memory Usage:"
|
||||||
|
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||||
|
tail -n +2 |
|
||||||
|
nl -v 0 |
|
||||||
|
tee /dev/tty |
|
||||||
|
sort -g -k 2 |
|
||||||
|
awk '{print $1}' |
|
||||||
|
head -n $n)
|
||||||
|
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||||
|
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||||
|
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||||
|
}
|
||||||
|
|
||||||
# NCCL IB environment variables
|
set_n_least_used_CUDA_VISIBLE_DEVICES 8
|
||||||
export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
|
|
||||||
export NCCL_IB_DISABLE=0
|
|
||||||
export NCCL_SOCKET_IFNAME=eth0
|
|
||||||
export NCCL_IB_GID_INDEX=3
|
|
||||||
export NCCL_IB_TIMEOUT=23
|
|
||||||
export NCCL_IB_RETRY_CNT=7
|
|
||||||
export OMP_NUM_THREADS=8
|
|
||||||
|
|
||||||
PROJECT_NAME=""
|
PROJECT_NAME=""
|
||||||
PARENT_SAVE_DIR=""
|
PARENT_SAVE_DIR=""
|
||||||
|
@ -11,24 +11,24 @@ import resource
|
|||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
from colossal_llama.dataset.dummy_dataset import RandomDataset
|
||||||
from colossal_llama.dataset.loader import (
|
from colossal_llama.dataset.loader import (
|
||||||
DataCollatorForSupervisedDataset,
|
DataCollatorForSupervisedDataset,
|
||||||
StatefulDistributedSampler,
|
StatefulDistributedSampler,
|
||||||
load_tokenized_dataset,
|
load_tokenized_dataset,
|
||||||
)
|
)
|
||||||
from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint
|
from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint
|
||||||
from colossal_llama.utils.flash_attention_patch import replace_with_flash_attention
|
|
||||||
from colossal_llama.utils.froze import freeze_non_embeds_parameters
|
from colossal_llama.utils.froze import freeze_non_embeds_parameters
|
||||||
from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
|
from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
|
||||||
|
from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AutoTokenizer, LlamaForCausalLM
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.accelerator import get_accelerator
|
from colossalai.accelerator import get_accelerator
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster
|
||||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
from colossalai.lazy import LazyInitContext
|
from colossalai.lazy import LazyInitContext
|
||||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||||
@ -36,109 +36,7 @@ from colossalai.nn.optimizer import HybridAdam
|
|||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
|
||||||
def get_model_numel(model: torch.nn.Module) -> int:
|
def train(args) -> None:
|
||||||
return sum(p.numel() for p in model.parameters())
|
|
||||||
|
|
||||||
|
|
||||||
def format_numel_str(numel: int) -> str:
|
|
||||||
B = 1024**3
|
|
||||||
M = 1024**2
|
|
||||||
K = 1024
|
|
||||||
if numel >= B:
|
|
||||||
return f"{numel / B:.2f} B"
|
|
||||||
elif numel >= M:
|
|
||||||
return f"{numel / M:.2f} M"
|
|
||||||
elif numel >= K:
|
|
||||||
return f"{numel / K:.2f} K"
|
|
||||||
else:
|
|
||||||
return f"{numel}"
|
|
||||||
|
|
||||||
|
|
||||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
|
||||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
|
||||||
tensor = tensor.data
|
|
||||||
tensor.div_(dist.get_world_size())
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
# ==============================
|
|
||||||
# Parse Arguments
|
|
||||||
# ==============================
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--pretrained",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Address of the pre-trained modeling",
|
|
||||||
)
|
|
||||||
parser.add_argument("--dataset", nargs="+", default=[])
|
|
||||||
parser.add_argument(
|
|
||||||
"--plugin",
|
|
||||||
type=str,
|
|
||||||
default="gemini",
|
|
||||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
|
||||||
help="Choose which plugin to use",
|
|
||||||
)
|
|
||||||
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
|
|
||||||
parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
|
|
||||||
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
|
|
||||||
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
|
|
||||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
|
||||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
|
||||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
|
|
||||||
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
|
|
||||||
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
|
||||||
parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
|
|
||||||
parser.add_argument(
|
|
||||||
"--mixed_precision",
|
|
||||||
type=str,
|
|
||||||
default="fp16",
|
|
||||||
choices=["fp16", "bf16"],
|
|
||||||
help="Mixed precision",
|
|
||||||
)
|
|
||||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
|
||||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
|
||||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_grad_checkpoint",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Use gradient checkpointing",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_flash_attn",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Use flash-attention",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_neft",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Use NEFTune",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--freeze_non_embeds_params",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Freeze non embeddings parameters",
|
|
||||||
)
|
|
||||||
parser.add_argument("--tp", type=int, default=1)
|
|
||||||
parser.add_argument("--zero", type=int, default=1)
|
|
||||||
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
|
|
||||||
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
|
|
||||||
parser.add_argument(
|
|
||||||
"--skip_save_each_epoch",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="skip saving the model checkpoint after each epoch is completed.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
with open(args.config_file, "w") as f:
|
|
||||||
json.dump(args.__dict__, f, indent=4)
|
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Initialize Distributed Training
|
# Initialize Distributed Training
|
||||||
# ==============================
|
# ==============================
|
||||||
@ -147,21 +45,27 @@ def main() -> None:
|
|||||||
coordinator = DistCoordinator()
|
coordinator = DistCoordinator()
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Initialize Tensorboard
|
# Initialize Tensorboard and Save Config
|
||||||
# ==============================
|
# ==============================
|
||||||
if coordinator.is_master():
|
if coordinator.is_master():
|
||||||
os.makedirs(args.tensorboard_dir, exist_ok=True)
|
os.makedirs(args.tensorboard_dir, exist_ok=True)
|
||||||
writer = SummaryWriter(args.tensorboard_dir)
|
writer = SummaryWriter(args.tensorboard_dir)
|
||||||
|
|
||||||
|
with open(args.config_file, "w") as f:
|
||||||
|
json.dump(args.__dict__, f, indent=4)
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Initialize Booster
|
# Initialize Booster
|
||||||
# ==============================
|
# ==============================
|
||||||
if args.plugin == "gemini":
|
if args.plugin == "ddp":
|
||||||
|
plugin = TorchDDPPlugin(find_unused_parameters=True if args.use_grad_checkpoint is False else False)
|
||||||
|
elif args.plugin == "gemini":
|
||||||
plugin = GeminiPlugin(
|
plugin = GeminiPlugin(
|
||||||
precision=args.mixed_precision,
|
precision=args.mixed_precision,
|
||||||
initial_scale=2**16,
|
initial_scale=2**16,
|
||||||
max_norm=args.grad_clip,
|
max_norm=args.grad_clip,
|
||||||
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
||||||
|
enable_flash_attention=args.use_flash_attn,
|
||||||
)
|
)
|
||||||
elif args.plugin == "gemini_auto":
|
elif args.plugin == "gemini_auto":
|
||||||
plugin = GeminiPlugin(
|
plugin = GeminiPlugin(
|
||||||
@ -170,6 +74,7 @@ def main() -> None:
|
|||||||
initial_scale=2**16,
|
initial_scale=2**16,
|
||||||
max_norm=args.grad_clip,
|
max_norm=args.grad_clip,
|
||||||
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
||||||
|
enable_flash_attention=args.use_flash_attn,
|
||||||
)
|
)
|
||||||
elif args.plugin == "zero2":
|
elif args.plugin == "zero2":
|
||||||
plugin = LowLevelZeroPlugin(
|
plugin = LowLevelZeroPlugin(
|
||||||
@ -189,10 +94,17 @@ def main() -> None:
|
|||||||
elif args.plugin == "3d":
|
elif args.plugin == "3d":
|
||||||
plugin = HybridParallelPlugin(
|
plugin = HybridParallelPlugin(
|
||||||
tp_size=args.tp,
|
tp_size=args.tp,
|
||||||
pp_size=1,
|
pp_size=args.pp,
|
||||||
zero_stage=args.zero,
|
sp_size=args.sp,
|
||||||
|
sequence_parallelism_mode=args.sp_mode,
|
||||||
|
zero_stage=args.zero_stage,
|
||||||
|
enable_flash_attention=args.use_flash_attn,
|
||||||
|
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||||
|
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||||
|
parallel_output=False,
|
||||||
max_norm=args.grad_clip,
|
max_norm=args.grad_clip,
|
||||||
precision=args.mixed_precision,
|
precision=args.mixed_precision,
|
||||||
|
microbatch_size=args.microbatch_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||||
@ -210,24 +122,38 @@ def main() -> None:
|
|||||||
tokenizer.add_bos_token = False
|
tokenizer.add_bos_token = False
|
||||||
tokenizer.add_eos_token = False
|
tokenizer.add_eos_token = False
|
||||||
|
|
||||||
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
|
coordinator.print_on_master(
|
||||||
coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
|
f"Training Info:\nConfig file: {args.config_file} \nTensorboard logs: {args.tensorboard_dir} \nModel checkpoint: {args.save_dir}"
|
||||||
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
|
|
||||||
|
|
||||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
|
||||||
|
|
||||||
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
|
|
||||||
data_collator = DataCollatorForSupervisedDataset(
|
|
||||||
tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode
|
|
||||||
)
|
|
||||||
dataloader = plugin.prepare_dataloader(
|
|
||||||
dataset=dataset,
|
|
||||||
batch_size=args.micro_batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
drop_last=True,
|
|
||||||
collate_fn=data_collator,
|
|
||||||
distributed_sampler_cls=StatefulDistributedSampler,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if args.benchmark:
|
||||||
|
coordinator.print_on_master(f"Run benchmark with {args.num_samples} random samples.")
|
||||||
|
dataset = RandomDataset(
|
||||||
|
num_samples=args.num_samples, max_length=args.max_length, vocab_size=tokenizer.vocab_size
|
||||||
|
)
|
||||||
|
dataloader = plugin.prepare_dataloader(
|
||||||
|
dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
drop_last=True,
|
||||||
|
seed=42,
|
||||||
|
distributed_sampler_cls=StatefulDistributedSampler,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||||
|
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
|
||||||
|
data_collator = DataCollatorForSupervisedDataset(
|
||||||
|
tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode
|
||||||
|
)
|
||||||
|
dataloader = plugin.prepare_dataloader(
|
||||||
|
dataset=dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
drop_last=True,
|
||||||
|
collate_fn=data_collator,
|
||||||
|
distributed_sampler_cls=StatefulDistributedSampler,
|
||||||
|
)
|
||||||
|
|
||||||
coordinator.print_on_master(
|
coordinator.print_on_master(
|
||||||
f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||||
)
|
)
|
||||||
@ -241,7 +167,19 @@ def main() -> None:
|
|||||||
else nullcontext()
|
else nullcontext()
|
||||||
)
|
)
|
||||||
with init_ctx:
|
with init_ctx:
|
||||||
model = LlamaForCausalLM.from_pretrained(args.pretrained)
|
if args.use_flash_attn:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
args.pretrained,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
args.pretrained,
|
||||||
|
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
# Freeze part of parameters.
|
# Freeze part of parameters.
|
||||||
if args.freeze_non_embeds_params:
|
if args.freeze_non_embeds_params:
|
||||||
freeze_non_embeds_parameters(model=model)
|
freeze_non_embeds_parameters(model=model)
|
||||||
@ -251,9 +189,6 @@ def main() -> None:
|
|||||||
if args.use_grad_checkpoint:
|
if args.use_grad_checkpoint:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||||
if args.use_flash_attn:
|
|
||||||
replace_with_flash_attention(model=model)
|
|
||||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
|
||||||
|
|
||||||
model_numel = get_model_numel(model)
|
model_numel = get_model_numel(model)
|
||||||
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||||
@ -342,43 +277,98 @@ def main() -> None:
|
|||||||
|
|
||||||
for epoch in range(start_epoch, args.num_epochs):
|
for epoch in range(start_epoch, args.num_epochs):
|
||||||
dataloader.sampler.set_epoch(epoch=epoch)
|
dataloader.sampler.set_epoch(epoch=epoch)
|
||||||
pbar = tqdm(
|
if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
|
||||||
desc=f"Epoch {epoch}",
|
data_iter = iter(dataloader)
|
||||||
disable=not coordinator.is_master(),
|
step_bar = tqdm(
|
||||||
total=num_steps_per_epoch,
|
range(len(dataloader)),
|
||||||
initial=start_step // args.accumulation_steps,
|
desc="Step",
|
||||||
)
|
disable=not (coordinator._local_rank == coordinator._world_size - 1),
|
||||||
total_loss = torch.tensor(0.0, device=get_current_device())
|
)
|
||||||
for step, batch in enumerate(dataloader, start=start_step):
|
for step in step_bar:
|
||||||
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
outputs = booster.execute_pipeline(
|
||||||
|
data_iter,
|
||||||
batch_output = model(**batch)
|
model,
|
||||||
|
criterion=lambda outputs, inputs: outputs[0],
|
||||||
loss = batch_output.loss / args.accumulation_steps
|
optimizer=optimizer,
|
||||||
total_loss.add_(loss.data)
|
return_loss=True,
|
||||||
|
)
|
||||||
booster.backward(loss=loss, optimizer=optimizer)
|
loss = outputs["loss"]
|
||||||
|
if booster.plugin.stage_manager.is_last_stage():
|
||||||
if (step + 1) % args.accumulation_steps == 0:
|
global_loss = all_reduce_mean(loss, plugin)
|
||||||
|
if coordinator._local_rank == coordinator._world_size - 1:
|
||||||
|
step_bar.set_postfix({"train/loss": global_loss.item()})
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
lr_scheduler.step()
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
all_reduce_mean(tensor=total_loss)
|
# Save modeling.
|
||||||
pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
|
save_model_condition = args.save_interval > 0 and (step + 1) % args.save_interval == 0
|
||||||
if coordinator.is_master():
|
|
||||||
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
|
if not args.skip_save_each_epoch:
|
||||||
writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
|
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
|
||||||
writer.add_scalar(
|
|
||||||
tag="Learning Rate",
|
if save_model_condition and not args.benchmark:
|
||||||
scalar_value=lr_scheduler.get_last_lr()[0],
|
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||||
global_step=global_step,
|
|
||||||
|
if args.use_neft:
|
||||||
|
coordinator.print_on_master("Deactivate NEFTune before saving model.")
|
||||||
|
deactivate_neftune(model, handle)
|
||||||
|
|
||||||
|
accelerator.empty_cache()
|
||||||
|
save_checkpoint(
|
||||||
|
save_dir=args.save_dir,
|
||||||
|
booster=booster,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
lr_scheduler=lr_scheduler,
|
||||||
|
epoch=epoch,
|
||||||
|
step=step + 1,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
coordinator=coordinator,
|
||||||
)
|
)
|
||||||
total_loss.fill_(0.0)
|
coordinator.print_on_master(
|
||||||
pbar.update()
|
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.use_neft:
|
||||||
|
coordinator.print_on_master("Activate NEFTune.")
|
||||||
|
model, handle = activate_neftune(model)
|
||||||
|
else:
|
||||||
|
pbar = tqdm(
|
||||||
|
desc=f"Epoch {epoch}",
|
||||||
|
disable=not coordinator.is_master(),
|
||||||
|
total=num_steps_per_epoch,
|
||||||
|
initial=start_step // args.accumulation_steps,
|
||||||
|
)
|
||||||
|
total_loss = torch.tensor(0.0, device=get_current_device())
|
||||||
|
for step, batch in enumerate(dataloader, start=start_step):
|
||||||
|
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||||
|
|
||||||
|
batch_output = model(**batch)
|
||||||
|
|
||||||
|
loss = batch_output.loss / args.accumulation_steps
|
||||||
|
total_loss.add_(loss.data)
|
||||||
|
|
||||||
|
booster.backward(loss=loss, optimizer=optimizer)
|
||||||
|
|
||||||
|
if (step + 1) % args.accumulation_steps == 0:
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
all_reduce_mean(tensor=total_loss)
|
||||||
|
pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
|
||||||
|
if coordinator.is_master():
|
||||||
|
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
|
||||||
|
writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
|
||||||
|
writer.add_scalar(
|
||||||
|
tag="Learning Rate",
|
||||||
|
scalar_value=lr_scheduler.get_last_lr()[0],
|
||||||
|
global_step=global_step,
|
||||||
|
)
|
||||||
|
total_loss.fill_(0.0)
|
||||||
|
pbar.update()
|
||||||
|
|
||||||
# Save modeling.
|
# Save modeling.
|
||||||
|
|
||||||
save_model_condition = (
|
save_model_condition = (
|
||||||
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
|
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
|
||||||
)
|
)
|
||||||
@ -386,7 +376,7 @@ def main() -> None:
|
|||||||
if not args.skip_save_each_epoch:
|
if not args.skip_save_each_epoch:
|
||||||
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
|
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
|
||||||
|
|
||||||
if save_model_condition:
|
if save_model_condition and not args.benchmark:
|
||||||
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||||
|
|
||||||
if args.use_neft:
|
if args.use_neft:
|
||||||
@ -402,7 +392,7 @@ def main() -> None:
|
|||||||
lr_scheduler=lr_scheduler,
|
lr_scheduler=lr_scheduler,
|
||||||
epoch=epoch,
|
epoch=epoch,
|
||||||
step=step + 1,
|
step=step + 1,
|
||||||
batch_size=args.micro_batch_size,
|
batch_size=args.batch_size,
|
||||||
coordinator=coordinator,
|
coordinator=coordinator,
|
||||||
)
|
)
|
||||||
coordinator.print_on_master(
|
coordinator.print_on_master(
|
||||||
@ -426,12 +416,114 @@ def main() -> None:
|
|||||||
deactivate_neftune(model, handle)
|
deactivate_neftune(model, handle)
|
||||||
|
|
||||||
# Final save.
|
# Final save.
|
||||||
coordinator.print_on_master("Start saving final model checkpoint")
|
if not args.benchmark:
|
||||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
coordinator.print_on_master("Start saving final model checkpoint")
|
||||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||||
|
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
||||||
|
|
||||||
coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB")
|
coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
parser = argparse.ArgumentParser()
|
||||||
|
# Basic training information.
|
||||||
|
parser.add_argument(
|
||||||
|
"--pretrained",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Address of the pre-trained model",
|
||||||
|
)
|
||||||
|
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint for continuous training.")
|
||||||
|
parser.add_argument("--dataset", nargs="+", default=[])
|
||||||
|
parser.add_argument(
|
||||||
|
"--plugin",
|
||||||
|
type=str,
|
||||||
|
default="gemini",
|
||||||
|
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"],
|
||||||
|
help="Choose which plugin to use",
|
||||||
|
)
|
||||||
|
parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
|
||||||
|
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
|
||||||
|
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
|
||||||
|
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||||
|
# Training parameters
|
||||||
|
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
||||||
|
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=2, help="Global Batch size of each process")
|
||||||
|
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
||||||
|
parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
|
||||||
|
parser.add_argument(
|
||||||
|
"--mixed_precision",
|
||||||
|
type=str,
|
||||||
|
default="fp16",
|
||||||
|
choices=["fp16", "bf16"],
|
||||||
|
help="Mixed precision",
|
||||||
|
)
|
||||||
|
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||||
|
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||||
|
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_grad_checkpoint",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Use gradient checkpointing",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_flash_attn",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Use flash-attention",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_neft",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Use NEFTune",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--freeze_non_embeds_params",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Freeze non embeddings parameters",
|
||||||
|
)
|
||||||
|
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
|
||||||
|
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_save_each_epoch",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Skip saving the model checkpoint after each epoch is completed.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Additional arguments for 3d plugin.
|
||||||
|
parser.add_argument("--tp", type=int, default=1, help="TP size, used for 3d plugin.")
|
||||||
|
parser.add_argument("--pp", type=int, default=1, help="PP size, used for 3d plugin.")
|
||||||
|
parser.add_argument("--sp", type=int, default=1, help="SP size, used for 3d plugin.")
|
||||||
|
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage, used for 3d plugin.", choices=[0, 1, 2])
|
||||||
|
parser.add_argument(
|
||||||
|
"--sp_mode",
|
||||||
|
type=str,
|
||||||
|
default="split_gather",
|
||||||
|
choices=["split_gather", "ring", "all_to_all"],
|
||||||
|
help="SP mode, used for 3d plugin.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable_sequence_parallelism",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to enable SP, used for 3d plugin.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--zero_cpu_offload", default=False, action="store_true", help="Whether to use offloading, used for 3d plugin."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Additional arguments for benchmark.
|
||||||
|
parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--benchmark", action="store_true", default=False, help="Benchmark performance using random dataset."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
train(args)
|
||||||
|
@ -1 +1 @@
|
|||||||
1.0.0
|
1.1.0
|
||||||
|
Loading…
Reference in New Issue
Block a user