[Colossal-LLaMA] Refactor latest APIs (#6030)

* refactor latest code

* update api

* add dummy dataset

* update Readme

* add setup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update files

* add PP support

* update arguments

* update argument

* reorg folder

* update version

* remove IB info

* update utils

* update readme

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update save for zero

* update save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add apex

* update

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Tong Li, 2024-08-28 17:01:58 +08:00 (committed by GitHub)
Parent: cc1b0efc17
Commit: 4a68efb7da
13 changed files with 397 additions and 549 deletions

View File: README.md

@@ -30,7 +30,7 @@ Colossal-LLaMA
 - [Install](#install)
   - [0. Pre-requisite](#0-pre-requisite)
   - [1. Install required packages](#1-install-required-packages)
-  - [2. Install `xentropy`, `layer_norm` and `rotary`](#2-install-xentropy-layer_norm-and-rotary)
+  - [2. Install Apex](#2-install-apex)
 - [How to run](#how-to-run)
   - [1. Init Tokenizer Preparation](#1-init-tokenizer-preparation)
   - [2. Init Model Preparation](#2-init-model-preparation)
@@ -297,17 +297,13 @@ Here is details about CLI arguments:
 #### 1. Install required packages
 ```
 cd Colossal-LLaMA
-pip install -r requirements.txt
+pip install -e .
 ```
-#### 2. Install `xentropy`, `layer_norm` and `rotary`
+#### 2. Install Apex
 ```bash
-git clone git@github.com:Dao-AILab/flash-attention.git
-# At the root folder
-cd csrc/xentropy && pip install .
-# At the root folder
-cd csrc/layer_norm && pip install .
-# At the root folder
-cd csrc/rotary && pip install .
+git clone git@github.com:NVIDIA/apex.git
+# Install from source.
 ```
 ### How to run
@@ -427,25 +423,33 @@ Make sure master node can access all nodes (including itself) by ssh without pas
 Here are details about CLI arguments:
 * Pre-trained model path: `--pretrained`. Path to the pre-trained model in Hugging Face format.
 * Dataset path: `--dataset`. Path to the pre-tokenized dataset.
-* Booster plugin: `--plugin`. `gemini`, `gemini_auto`, `zero2`, `zero2_cpu` and `3d` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/).
+* Booster plugin: `--plugin`. `ddp`, `gemini`, `gemini_auto`, `zero2`, `zero2_cpu` and `3d` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/).
 * Intermediate checkpoint to load: `--load_checkpoint`. Path to the intermediate checkpoint. A saved checkpoint contains the states for `lr_scheduler`, `optimizer`, `running_states.json` and `modelling`. If `load_checkpoint` points to the `modelling` folder, only the model weights will be loaded without any other states, to support multi-stage training.
 * Save interval: `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
 * Checkpoint directory: `--save_dir`. The directory path to save checkpoint and intermediate states. Intermediate states include `lr_scheduler`, `optimizer`, `running_states.json` and `modelling`.
 * Tensorboard directory: `--tensorboard_dir`. The path to save tensorboard logs.
 * Configuration file: `--config_file`. The path to save the configuration file.
 * Number of epochs: `--num_epochs`. Number of training epochs. The default value is 1.
-* Micro batch size: `--micro_batch_size`. Batch size per GPU. The default value is 1.
+* Batch size: `--batch_size`. Batch size per GPU. The default value is 1. For PP, it refers to the number of samples per step.
 * Learning rate: `--lr`. The default value is 3e-4.
 * Max length: `--max_length`. Max context length. The default value is 4096.
 * Mixed precision: `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
 * Gradient clipping: `--gradient_clipping`. The default value is 1.0.
-* Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
-* Warmup steps: `-s`, `--warmup_steps`. The default value is calculated by a 0.025 warmup ratio.
+* Weight decay: `--weight_decay`. The default value is 0.1.
+* Warmup steps: `--warmup_steps`. The default value is calculated by a 0.025 warmup ratio.
 * Gradient checkpointing: `--use_grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size.
 * Flash attention: `--use_flash_attn`. If you want to use flash attention, you must install `flash-attn` and related packages. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention.
 * Freeze non-embedding parameters: `--freeze_non_embeds_params`. Freeze non-embedding parameters. It can be helpful to align embeddings after extending vocabulary size.
-* Tensor parallelism size: `--tp`. TP size for 3d parallelism. The default value is 1.
-* Zero stage: `--zero`. Zero stage for 3d parallelism. The default value is 1.
+* Tensor parallelism size: `--tp`. TP size for 3d parallelism. The default value is 1. Used for 3d plugin.
+* Pipeline parallelism size: `--pp`. PP size for 3d parallelism. The default value is 1. Used for 3d plugin.
+* Sequence parallelism size: `--sp`. SP size for 3d parallelism. The default value is 1. Used for 3d plugin.
+* Zero stage: `--zero`. Zero stage for 3d parallelism. The default value is 1. Used for 3d plugin.
+* Sequence parallelism mode: `--sp_mode`. SP mode, used for 3d plugin. Choose from "split_gather", "ring" and "all_to_all".
+* Switch for sequence parallelism: `--enable_sequence_parallelism`. Whether to enable SP, used for 3d plugin.
+* Zero CPU offload: `--zero_cpu_offload`. Whether to use offloading, used for 3d plugin.
+* Micro batch size: `--microbatch_size`. Batch size for each process in PP, used for 3d plugin.
+* Number of dummy samples: `--num_samples`. Number of samples for benchmarking.
+* Benchmark switch: `--benchmark`. Benchmark performance using a random dataset.
 ##### 4.2 Arguments for Supervised Fine-tuning
 We add support for gradient accumulation and NEFTuning for supervised fine-tuning, and thus there are two more arguments apart from the arguments listed in [4.1 Arguments for Pretraining](#41-arguments-for-pretraining).

View File: colossal_llama/dataset/dummy_dataset.py (new file)

@@ -0,0 +1,24 @@
import torch
from torch.utils.data import Dataset

from colossalai.accelerator import get_accelerator


class RandomDataset(Dataset):
    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = torch.randint(
            0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
        )
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }
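For reference, a minimal usage sketch of this dummy dataset (values are illustrative; it assumes an accelerator device is available, since the tensors are allocated on the current device at construction):

```python
from torch.utils.data import DataLoader

# Labels mirror input_ids, so the benchmark exercises a plain causal-LM loss.
dataset = RandomDataset(num_samples=8, max_length=128, vocab_size=32000)
loader = DataLoader(dataset, batch_size=2, drop_last=True)

batch = next(iter(loader))
print(batch["input_ids"].shape)                        # torch.Size([2, 128])
print(bool((batch["labels"] == batch["input_ids"]).all()))  # True
```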

View File: colossal_llama/utils/flash_attention_patch.py (deleted)

@@ -1,352 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import math
from types import MethodType
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
apply_rotary_pos_emb,
repeat_kv,
)
from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
if get_accelerator().name == "cuda":
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
from flash_attn.ops.rms_norm import rms_norm
def _prepare_decoder_attention_mask(
self: LlamaModel,
attention_mask: torch.BoolTensor,
input_shape: torch.Size,
inputs_embeds: torch.Tensor,
past_key_values_length: int,
) -> Optional[torch.Tensor]:
"""
Decoder attention mask
"""
if past_key_values_length > 0 and attention_mask is not None:
attention_mask = torch.cat(
tensors=(
torch.full(
size=(input_shape[0], past_key_values_length),
fill_value=True,
dtype=attention_mask.dtype,
device=attention_mask.device,
),
attention_mask,
),
dim=-1,
) # (bsz, past_key_values_length + q_len)
if attention_mask is not None and torch.all(attention_mask):
return None # Faster
return attention_mask
def attention_forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
"""
if output_attentions:
logger.warning(
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
"return `None` instead."
)
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
q_slicing, kv_slicing = (
dim // self.config.pretraining_tp
for dim in (
self.num_heads * self.head_dim,
self.num_key_value_heads * self.head_dim,
)
) # `Tuple[int, int]`
q_slices, k_slices, v_slices = (
proj.weight.split(slicing, dim=0)
for proj, slicing in (
(self.q_proj, q_slicing),
(self.k_proj, kv_slicing),
(self.v_proj, kv_slicing),
)
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
q, k, v = (
torch.cat(
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
dim=-1,
)
for slices in (q_slices, k_slices, v_slices)
)
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
# (bsz, q_len, num_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim)
else:
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
# (bsz, q_len, num_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim)
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
q, k, v = (
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
for states, num_heads in (
(q, self.num_heads),
(k, self.num_key_value_heads),
(v, self.num_key_value_heads),
)
)
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
past_kv_len = 0
if past_key_value is not None:
# if `past_key_value` is not None, `kv_len` > `q_len`.
past_kv_len = past_key_value[0].shape[-2]
kv_len += past_kv_len
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
cos, sin = self.rotary_emb(v, seq_len=kv_len)
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
k = torch.cat([past_key_value[0], k], dim=2)
v = torch.cat([past_key_value[1], v], dim=2)
past_key_value = (k, v) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
key_padding_mask = attention_mask
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
if past_kv_len > 0:
q = torch.cat(
tensors=(
torch.full(
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
fill_value=0.0,
dtype=q.dtype,
device=q.device,
),
q,
),
dim=1,
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
if key_padding_mask is None:
# (bsz, past_kv_len + q_len, num_heads, head_dim)
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
output = rearrange(
output, pattern="... h d -> ... (h d)"
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
else:
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
kv, _, cu_kv_lens, max_kv_len = unpad_input(
hidden_states=torch.stack(tensors=(k, v), dim=2),
attention_mask=key_padding_mask,
)
output_unpad = flash_attn_varlen_kvpacked_func(
q=q,
kv=kv,
cu_seqlens_q=cu_q_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_q_len,
max_seqlen_k=max_kv_len,
dropout_p=0.0,
softmax_scale=None,
causal=True,
)
output = pad_input(
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
indices=indices,
batch=bsz,
seqlen=past_kv_len + q_len,
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
if past_kv_len > 0:
# Strip off the zero query outputs.
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
output = self.o_proj(output) # (bsz, q_len, hidden_size)
return output, None, past_key_value
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Forward function for RMS Norm
"""
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
module.forward = MethodType(attention_forward, module)
if isinstance(module, LlamaModel):
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
if isinstance(module, LlamaRMSNorm):
module.forward = MethodType(rms_norm_forward, module)
elif get_accelerator().name == "npu":
import torch_npu
class NPULlamaAttention(LlamaAttention):
use_flash: bool = True
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.setup()
def setup(self):
self._softmax_scale = 1 / math.sqrt(self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if not self.use_flash:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
else:
attn_output, *_ = torch_npu.npu_fusion_attention(
query_states,
key_states,
value_states,
self.num_heads,
"BNSD",
atten_mask=attention_mask.bool(),
scale=self._softmax_scale,
padding_mask=None,
pre_tockens=65535,
next_tockens=0,
keep_prob=1.0,
inner_precise=0,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum(
[F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
)
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class NPURMSNorm(LlamaRMSNorm):
def forward(self, hidden_states):
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
module.__class__ = NPULlamaAttention
module.setup()
if isinstance(module, LlamaRMSNorm):
module.__class__ = NPURMSNorm
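This hand-rolled patch is dropped because recent `transformers` releases support flash attention natively, and the updated training script (further down) now requests it at model load time instead. A minimal sketch of that replacement path, with a placeholder checkpoint path:

```python
import torch
from transformers import AutoModelForCausalLM

# `attn_implementation="flash_attention_2"` requires the flash-attn package
# and a half-precision dtype; "path/to/llama" is a placeholder.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/llama",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
```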

View File: colossal_llama/utils/utils.py (new file)

@ -0,0 +1,36 @@
"""
Utils for Colossal-LLaMA
"""
import torch
import torch.distributed as dist
from colossalai.booster import Plugin
def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
if plugin is not None:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
tensor.div_(plugin.dp_size)
else:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
tensor.div_(dist.get_world_size())
return tensor
def get_model_numel(model: torch.nn.Module) -> int:
return sum(p.numel() for p in model.parameters())
def format_numel_str(numel: int) -> str:
B = 1024**3
M = 1024**2
K = 1024
if numel >= B:
return f"{numel / B:.2f} B"
elif numel >= M:
return f"{numel / M:.2f} M"
elif numel >= K:
return f"{numel / K:.2f} K"
else:
return f"{numel}"

View File: requirements.txt

@@ -1,15 +1,15 @@
 torch==2.1.2
 huggingface-hub
 packaging==24.0
-colossalai==0.3.6
+colossalai>=0.4.0
 autoflake==2.2.1
 black==23.9.1
-transformers==4.34.1
+transformers>=4.39.3
 tensorboard==2.14.0
 six==1.16.0
 datasets
 ninja==1.11.1
-flash-attn>=2.0.0,<=2.0.5
+flash-attn
 tqdm
 sentencepiece==0.1.99
 protobuf<=3.20.0

View File: setup.py (new file)

@@ -0,0 +1,37 @@
from setuptools import find_packages, setup


def fetch_requirements(path):
    with open(path, "r") as fd:
        return [r.strip() for r in fd.readlines()]


def fetch_readme():
    with open("README.md", encoding="utf-8") as f:
        return f.read()


def fetch_version():
    with open("version.txt", "r") as f:
        return f.read().strip()


setup(
    name="colossal_llama",
    version=fetch_version(),
    packages=find_packages(exclude=("*.egg-info",)),
    description="Continual Pre-training and SFT for LLaMA",
    long_description=fetch_readme(),
    long_description_content_type="text/markdown",
    license="Apache Software License 2.0",
    url="https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA",
    install_requires=fetch_requirements("requirements.txt"),
    python_requires=">=3.7",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Environment :: GPU :: NVIDIA CUDA",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: System :: Distributed Computing",
    ],
)
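With this setup.py in place, the README's install step becomes a package install (`pip install -e .` instead of `pip install -r requirements.txt`). A quick sanity check, assuming the editable install succeeded:

```python
# Run after `pip install -e .` inside Colossal-LLaMA.
import colossal_llama  # the package name declared in setup() above
from colossal_llama.dataset.dummy_dataset import RandomDataset  # module added in this commit

# For editable installs this resolves back into the source tree.
print(colossal_llama.__file__)
```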

View File: train.sh

@@ -1,13 +1,20 @@
 #!/bin/bash
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+    local n=${1:-"9999"}
+    echo "GPU Memory Usage:"
+    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+        tail -n +2 |
+        nl -v 0 |
+        tee /dev/tty |
+        sort -g -k 2 |
+        awk '{print $1}' |
+        head -n $n)
+    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+    echo "Now CUDA_VISIBLE_DEVICES is set to:"
+    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+set_n_least_used_CUDA_VISIBLE_DEVICES 8
-# NCCL IB environment variables
-export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
-export NCCL_IB_DISABLE=0
-export NCCL_SOCKET_IFNAME=eth0
-export NCCL_IB_GID_INDEX=3
-export NCCL_IB_TIMEOUT=23
-export NCCL_IB_RETRY_CNT=7
-export OMP_NUM_THREADS=8
 
 PROJECT_NAME=""
 PARENT_SAVE_DIR=""

View File: train.py

@@ -11,24 +11,24 @@ import resource
 from contextlib import nullcontext
 
 import torch
-import torch.distributed as dist
+from colossal_llama.dataset.dummy_dataset import RandomDataset
 from colossal_llama.dataset.loader import (
     DataCollatorForSupervisedDataset,
     StatefulDistributedSampler,
     load_tokenized_dataset,
 )
 from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint
-from colossal_llama.utils.flash_attention_patch import replace_with_flash_attention
 from colossal_llama.utils.froze import freeze_non_embeds_parameters
 from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
+from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from transformers import AutoTokenizer, LlamaForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 import colossalai
 from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -36,109 +36,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 
-def get_model_numel(model: torch.nn.Module) -> int:
-    return sum(p.numel() for p in model.parameters())
-
-
-def format_numel_str(numel: int) -> str:
-    B = 1024**3
-    M = 1024**2
-    K = 1024
-    if numel >= B:
-        return f"{numel / B:.2f} B"
-    elif numel >= M:
-        return f"{numel / M:.2f} M"
-    elif numel >= K:
-        return f"{numel / K:.2f} K"
-    else:
-        return f"{numel}"
-
-
-def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
-    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
-    tensor = tensor.data
-    tensor.div_(dist.get_world_size())
-    return tensor
-
-
-def main() -> None:
-    # ==============================
-    # Parse Arguments
-    # ==============================
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--pretrained",
-        type=str,
-        default=None,
-        help="Address of the pre-trained modeling",
-    )
-    parser.add_argument("--dataset", nargs="+", default=[])
-    parser.add_argument(
-        "--plugin",
-        type=str,
-        default="gemini",
-        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
-        help="Choose which plugin to use",
-    )
-    parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
-    parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
-    parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
-    parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
-    parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
-    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
-    parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
-    parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
-    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
-    parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
-    parser.add_argument(
-        "--mixed_precision",
-        type=str,
-        default="fp16",
-        choices=["fp16", "bf16"],
-        help="Mixed precision",
-    )
-    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
-    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
-    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
-    parser.add_argument(
-        "--use_grad_checkpoint",
-        action="store_true",
-        default=False,
-        help="Use gradient checkpointing",
-    )
-    parser.add_argument(
-        "--use_flash_attn",
-        action="store_true",
-        default=False,
-        help="Use flash-attention",
-    )
-    parser.add_argument(
-        "--use_neft",
-        action="store_true",
-        default=False,
-        help="Use NEFTune",
-    )
-    parser.add_argument(
-        "--freeze_non_embeds_params",
-        action="store_true",
-        default=False,
-        help="Freeze non embeddings parameters",
-    )
-    parser.add_argument("--tp", type=int, default=1)
-    parser.add_argument("--zero", type=int, default=1)
-    parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
-    parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
-    parser.add_argument(
-        "--skip_save_each_epoch",
-        action="store_true",
-        default=False,
-        help="skip saving the model checkpoint after each epoch is completed.",
-    )
-    args = parser.parse_args()
-
-    with open(args.config_file, "w") as f:
-        json.dump(args.__dict__, f, indent=4)
-
+def train(args) -> None:
     # ==============================
     # Initialize Distributed Training
     # ==============================
@@ -147,21 +45,27 @@ def main() -> None:
     coordinator = DistCoordinator()
 
     # ==============================
-    # Initialize Tensorboard
+    # Initialize Tensorboard and Save Config
     # ==============================
     if coordinator.is_master():
         os.makedirs(args.tensorboard_dir, exist_ok=True)
         writer = SummaryWriter(args.tensorboard_dir)
+
+        with open(args.config_file, "w") as f:
+            json.dump(args.__dict__, f, indent=4)
 
     # ==============================
     # Initialize Booster
     # ==============================
-    if args.plugin == "gemini":
+    if args.plugin == "ddp":
+        plugin = TorchDDPPlugin(find_unused_parameters=True if args.use_grad_checkpoint is False else False)
+    elif args.plugin == "gemini":
         plugin = GeminiPlugin(
             precision=args.mixed_precision,
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
+            enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -170,6 +74,7 @@ def main() -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
+            enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
@@ -189,10 +94,17 @@ def main() -> None:
     elif args.plugin == "3d":
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
-            pp_size=1,
-            zero_stage=args.zero,
+            pp_size=args.pp,
+            sp_size=args.sp,
+            sequence_parallelism_mode=args.sp_mode,
+            zero_stage=args.zero_stage,
+            enable_flash_attention=args.use_flash_attn,
+            enable_sequence_parallelism=args.enable_sequence_parallelism,
+            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
+            parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
+            microbatch_size=args.microbatch_size,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
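For reference, a minimal standalone sketch of wiring the new 3d-plugin flags into a Booster (parallel sizes and precision are illustrative; the keyword names follow the diff above, and the launch call assumes a torchrun/colossalai launch with at least tp_size * pp_size processes):

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # reads rank/world-size from the launcher env

plugin = HybridParallelPlugin(
    tp_size=2,                                  # --tp
    pp_size=2,                                  # --pp
    sp_size=1,                                  # --sp
    sequence_parallelism_mode="split_gather",   # --sp_mode
    zero_stage=1,                               # --zero_stage
    enable_flash_attention=True,                # --use_flash_attn
    microbatch_size=1,                          # --microbatch_size
    precision="bf16",                           # --mixed_precision
)
booster = Booster(plugin=plugin)
# model, optimizer, dataloader, etc. are then wrapped via booster.boost(...)
```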
@@ -210,24 +122,38 @@ def main() -> None:
     tokenizer.add_bos_token = False
     tokenizer.add_eos_token = False
 
-    coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
-    coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
-    coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
-    coordinator.print_on_master(f"Load dataset: {args.dataset}")
-    dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
-    data_collator = DataCollatorForSupervisedDataset(
-        tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode
-    )
-    dataloader = plugin.prepare_dataloader(
-        dataset=dataset,
-        batch_size=args.micro_batch_size,
-        shuffle=True,
-        drop_last=True,
-        collate_fn=data_collator,
-        distributed_sampler_cls=StatefulDistributedSampler,
-    )
+    coordinator.print_on_master(
+        f"Training Info:\nConfig file: {args.config_file} \nTensorboard logs: {args.tensorboard_dir} \nModel checkpoint: {args.save_dir}"
+    )
+
+    if args.benchmark:
+        coordinator.print_on_master(f"Run benchmark with {args.num_samples} random samples.")
+        dataset = RandomDataset(
+            num_samples=args.num_samples, max_length=args.max_length, vocab_size=tokenizer.vocab_size
+        )
+        dataloader = plugin.prepare_dataloader(
+            dataset,
+            batch_size=args.batch_size,
+            shuffle=True,
+            drop_last=True,
+            seed=42,
+            distributed_sampler_cls=StatefulDistributedSampler,
+        )
+    else:
+        coordinator.print_on_master(f"Load dataset: {args.dataset}")
+        dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
+        data_collator = DataCollatorForSupervisedDataset(
+            tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode
+        )
+        dataloader = plugin.prepare_dataloader(
+            dataset=dataset,
+            batch_size=args.batch_size,
+            shuffle=True,
+            drop_last=True,
+            collate_fn=data_collator,
+            distributed_sampler_cls=StatefulDistributedSampler,
+        )
 
     coordinator.print_on_master(
         f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
     )
@@ -241,7 +167,19 @@ def main() -> None:
         else nullcontext()
     )
     with init_ctx:
-        model = LlamaForCausalLM.from_pretrained(args.pretrained)
+        if args.use_flash_attn:
+            model = AutoModelForCausalLM.from_pretrained(
+                args.pretrained,
+                attn_implementation="flash_attention_2",
+                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+                trust_remote_code=True,
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                args.pretrained,
+                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+                trust_remote_code=True,
+            )
     # Freeze part of parameters.
     if args.freeze_non_embeds_params:
         freeze_non_embeds_parameters(model=model)
@@ -251,9 +189,6 @@ def main() -> None:
     if args.use_grad_checkpoint:
         model.gradient_checkpointing_enable()
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-    if args.use_flash_attn:
-        replace_with_flash_attention(model=model)
-        coordinator.print_on_master(msg="Flash-attention enabled successfully")
 
     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
@@ -342,6 +277,62 @@ def main() -> None:
     for epoch in range(start_epoch, args.num_epochs):
         dataloader.sampler.set_epoch(epoch=epoch)
+        if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
+            data_iter = iter(dataloader)
+            step_bar = tqdm(
+                range(len(dataloader)),
+                desc="Step",
+                disable=not (coordinator._local_rank == coordinator._world_size - 1),
+            )
+            for step in step_bar:
+                outputs = booster.execute_pipeline(
+                    data_iter,
+                    model,
+                    criterion=lambda outputs, inputs: outputs[0],
+                    optimizer=optimizer,
+                    return_loss=True,
+                )
+                loss = outputs["loss"]
+                if booster.plugin.stage_manager.is_last_stage():
+                    global_loss = all_reduce_mean(loss, plugin)
+                    if coordinator._local_rank == coordinator._world_size - 1:
+                        step_bar.set_postfix({"train/loss": global_loss.item()})
+                optimizer.step()
+                optimizer.zero_grad()
+
+                # Save modeling.
+                save_model_condition = args.save_interval > 0 and (step + 1) % args.save_interval == 0
+                if not args.skip_save_each_epoch:
+                    save_model_condition = save_model_condition or (step + 1) == len(dataloader)
+                if save_model_condition and not args.benchmark:
+                    coordinator.print_on_master("\nStart saving model checkpoint with running states")
+                    if args.use_neft:
+                        coordinator.print_on_master("Deactivate NEFTune before saving model.")
+                        deactivate_neftune(model, handle)
+                    accelerator.empty_cache()
+                    save_checkpoint(
+                        save_dir=args.save_dir,
+                        booster=booster,
+                        model=model,
+                        optimizer=optimizer,
+                        lr_scheduler=lr_scheduler,
+                        epoch=epoch,
+                        step=step + 1,
+                        batch_size=args.batch_size,
+                        coordinator=coordinator,
+                    )
+                    coordinator.print_on_master(
+                        f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
+                    )
+                    if args.use_neft:
+                        coordinator.print_on_master("Activate NEFTune.")
+                        model, handle = activate_neftune(model)
-        pbar = tqdm(
+        else:
+            pbar = tqdm(
                 desc=f"Epoch {epoch}",
                 disable=not coordinator.is_master(),
@@ -378,7 +369,6 @@ def main() -> None:
                     pbar.update()
 
                 # Save modeling.
-
                 save_model_condition = (
                     args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
                 )
@@ -386,7 +376,7 @@ def main() -> None:
                 if not args.skip_save_each_epoch:
                     save_model_condition = save_model_condition or (step + 1) == len(dataloader)
 
-                if save_model_condition:
+                if save_model_condition and not args.benchmark:
                     coordinator.print_on_master("\nStart saving model checkpoint with running states")
                     if args.use_neft:
@@ -402,7 +392,7 @@ def main() -> None:
                         lr_scheduler=lr_scheduler,
                         epoch=epoch,
                         step=step + 1,
-                        batch_size=args.micro_batch_size,
+                        batch_size=args.batch_size,
                         coordinator=coordinator,
                     )
                     coordinator.print_on_master(
@@ -426,6 +416,7 @@ def main() -> None:
         deactivate_neftune(model, handle)
 
     # Final save.
+    if not args.benchmark:
         coordinator.print_on_master("Start saving final model checkpoint")
         booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
         coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
@@ -434,4 +425,105 @@ def main() -> None:
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    # Basic training information.
+    parser.add_argument(
+        "--pretrained",
+        type=str,
+        default=None,
+        help="Address of the pre-trained model",
+    )
+    parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint for continuous training.")
+    parser.add_argument("--dataset", nargs="+", default=[])
+    parser.add_argument(
+        "--plugin",
+        type=str,
+        default="gemini",
+        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"],
+        help="Choose which plugin to use",
+    )
+    parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
+    parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
+    parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
+    parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
+    # Training parameters
+    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
+    parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
+    parser.add_argument("--batch_size", type=int, default=2, help="Global batch size of each process")
+    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
+    parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="fp16",
+        choices=["fp16", "bf16"],
+        help="Mixed precision",
+    )
+    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
+    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
+    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+    parser.add_argument(
+        "--use_grad_checkpoint",
+        action="store_true",
+        default=False,
+        help="Use gradient checkpointing",
+    )
+    parser.add_argument(
+        "--use_flash_attn",
+        action="store_true",
+        default=False,
+        help="Use flash-attention",
+    )
+    parser.add_argument(
+        "--use_neft",
+        action="store_true",
+        default=False,
+        help="Use NEFTune",
+    )
+    parser.add_argument(
+        "--freeze_non_embeds_params",
+        action="store_true",
+        default=False,
+        help="Freeze non-embedding parameters",
+    )
+    parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
+    parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
+    parser.add_argument(
+        "--skip_save_each_epoch",
+        action="store_true",
+        default=False,
+        help="Skip saving the model checkpoint after each epoch is completed.",
+    )
+    # Additional arguments for 3d plugin.
+    parser.add_argument("--tp", type=int, default=1, help="TP size, used for 3d plugin.")
+    parser.add_argument("--pp", type=int, default=1, help="PP size, used for 3d plugin.")
+    parser.add_argument("--sp", type=int, default=1, help="SP size, used for 3d plugin.")
+    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage, used for 3d plugin.", choices=[0, 1, 2])
+    parser.add_argument(
+        "--sp_mode",
+        type=str,
+        default="split_gather",
+        choices=["split_gather", "ring", "all_to_all"],
+        help="SP mode, used for 3d plugin.",
+    )
+    parser.add_argument(
+        "--enable_sequence_parallelism",
+        default=False,
+        action="store_true",
+        help="Whether to enable SP, used for 3d plugin.",
+    )
+    parser.add_argument(
+        "--zero_cpu_offload", default=False, action="store_true", help="Whether to use offloading, used for 3d plugin."
+    )
+    parser.add_argument(
+        "--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
+    )
+    # Additional arguments for benchmark.
+    parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")
+    parser.add_argument(
+        "--benchmark", action="store_true", default=False, help="Benchmark performance using a random dataset."
+    )
+    args = parser.parse_args()
+    train(args)

View File: version.txt

@@ -1 +1 @@
-1.0.0
+1.1.0