Mirror of https://github.com/hpcaitech/ColossalAI.git

commit ad1ceb0424 (parent d9b5f10d82)
[feat] Update sync model by tensor, fix tMbs problem, add qwen train benchmark.
@@ -55,3 +55,32 @@ def ray_broadcast_tensor_dict(
     if rank == src:
         out_dict = tensor_dict
     return out_dict
+
+
+def ray_broadcast_tensor_dict_and_load(
+    producer_obj, tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
+):
+    rank = cc.get_rank(group_name)
+    if rank == src:
+        metadata = []
+        for k, v in tensor_dict.items():
+            metadata.append((k, v.shape, v.dtype))
+    else:
+        metadata = None
+    metadata = ray_broadcast_object(metadata, src, device, group_name)
+    for k, shape, dtype in metadata:
+        if "consumer_global_step" == k:
+            continue
+        if rank == src:
+            tensor = tensor_dict[k]
+        else:
+            out_dict = {}
+            tensor = torch.empty(shape, dtype=dtype, device=device)
+        cc.broadcast(tensor, src, group_name)
+        if rank != src:
+            out_dict[k] = tensor
+            producer_obj.load_state_dict(out_dict)
+            del out_dict
+            torch.npu.empty_cache()
+    if rank == src:
+        out_dict = tensor_dict
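For orientation, here is a minimal sketch of how the two ends of the new helper are driven, pieced together from the call sites later in this commit; the surrounding rank numbers, `device`, and group setup are assumptions for illustration only:

```python
# Hedged usage sketch, assuming a ray.util.collective group named
# "sync_model" is already initialized on every participating rank.

# Consumer (source) rank: owns the weights, so it passes None for producer_obj.
ray_broadcast_tensor_dict_and_load(
    None, state_dict, src=num_producers, device=device, group_name="sync_model"
)

# Producer rank: passes itself so each tensor is pushed into load_state_dict
# (and freed) as soon as it arrives, instead of staging the whole dict first.
ray_broadcast_tensor_dict_and_load(
    producer, None, num_producers, device=device, group_name="sync_model"
)
```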
@@ -15,7 +15,7 @@ from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
 
-from .comm import ray_broadcast_tensor_dict
+from .comm import ray_broadcast_tensor_dict, ray_broadcast_tensor_dict_and_load
 from .utils import bind_batch, post_recv, unbind_batch
 
 
@@ -172,6 +172,8 @@ class BaseConsumer:
                 )
                 self.profiler.enter("step")
                 loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                del batch
+                del raw_mini_batches_metric_dict
                 self.profiler.exit("step")
                 self.buffer = self.buffer[
                     effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
@@ -303,7 +305,8 @@ class BaseConsumer:
                 state_dict = self.state_dict()
                 if self.pp_size > 1:
                     if self.tp_rank == 0 and self.dp_rank == 0:
-                        ray_broadcast_tensor_dict(
+                        ray_broadcast_tensor_dict_and_load(
+                            None,
                             state_dict,
                             src=self.num_producers,
                             device=self.device,
@@ -311,8 +314,12 @@ class BaseConsumer:
                         )
                 else:
                     if self.rank == 0:
-                        ray_broadcast_tensor_dict(
-                            state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                        ray_broadcast_tensor_dict_and_load(
+                            None,
+                            state_dict,
+                            src=self.num_producers,
+                            device=self.device,
+                            group_name="sync_model",
                         )
                 del state_dict
                 torch.npu.empty_cache()
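The consumer-side switch mirrors the producer side; the practical difference, sketched below under the assumption that `self` is a producer holding the model being updated, is peak memory on the receiving end:

```python
# Old pattern: the fully broadcast dict and the model weights coexist.
state_dict = ray_broadcast_tensor_dict(None, src, device=device, group_name="sync_model")
self.load_state_dict(state_dict)  # a second full copy is still alive here
del state_dict

# New pattern: ray_broadcast_tensor_dict_and_load loads and frees one tensor
# at a time, so roughly a single extra tensor is resident at any moment.
ray_broadcast_tensor_dict_and_load(self, None, src, device=device, group_name="sync_model")
```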
@@ -62,6 +62,7 @@ class GRPOConsumer(BaseConsumer):
             batch_size,
             model_config,
             plugin_config,
+            generate_config,
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
@@ -17,7 +17,7 @@ from ray.util.collective.types import ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer
 
-from .comm import ray_broadcast_tensor_dict
+from .comm import ray_broadcast_tensor_dict, ray_broadcast_tensor_dict_and_load
 from .inference_backend import BACKEND_MAP
 from .utils import safe_append_to_jsonl_file
 
@@ -191,6 +191,7 @@ class BaseProducer:
                 )
             else:
                 cc.init_collective_group(self.num_producers + 1, self.producer_idx, backend="hccl", group_name="sync_model")
+        cc.init_collective_group(self.num_producers, self.producer_idx, backend="hccl", group_name="producer_group")
 
     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -340,25 +341,16 @@ class BaseProducer:
                     print(
                         f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                     )
-                    state_dict = ray_broadcast_tensor_dict(
-                        None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
+                    ray_broadcast_tensor_dict_and_load(
+                        self, None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
                     )
-                    if "consumer_global_step" in state_dict:
-                        self.consumer_global_step = state_dict.pop("consumer_global_step").item()
-                    self.load_state_dict(state_dict)
                 else:
                     print(
                         f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                     )
-                    state_dict = ray_broadcast_tensor_dict(
-                        None, self.num_producers, device=self.device, group_name="sync_model"
+                    ray_broadcast_tensor_dict_and_load(
+                        self, None, self.num_producers, device=self.device, group_name=f"sync_model"
                     )
-                    if "consumer_global_step" in state_dict:
-                        self.consumer_global_step = state_dict.pop("consumer_global_step").item()
-                    self.load_state_dict(state_dict)
-                self.profiler.exit("sync_model")
-                del state_dict
-                torch.npu.empty_cache()
                 if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                     "enable_sleep_mode", False
                 ):
@@ -166,6 +166,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
     )
+    parser.add_argument("--cpu_offload", action="store_true", default=False, help="Enable CPU offload.")
     args = parser.parse_args()
 
     if args.train_minibatch_size is None:
@@ -251,7 +252,7 @@ if __name__ == "__main__":
         )
         generate_config.update(
             dict(
-                max_tokens=args.max_new_tokens,  # max new tokens
+                max_tokens=args.max_new_tokens + args.max_prompt_tokens,  # budget now covers prompt + new tokens
                 include_stop_str_in_output=True,
                 stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
                 ignore_eos=True if args.reward_type == "think_answer_tags" else False,
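One way to read this change, hedged since the full script is not shown here: under the usual vLLM semantics, `max_tokens` caps *generated* tokens only, so folding the prompt allowance into it enlarges the generation budget rather than splitting a shared window. A back-of-the-envelope sketch with hypothetical CLI values:

```python
max_new_tokens = 1024     # hypothetical value of --max_new_tokens
max_prompt_tokens = 512   # hypothetical value of --max_prompt_tokens

max_tokens_old = max_new_tokens                      # at most 1024 generated tokens
max_tokens_new = max_new_tokens + max_prompt_tokens  # at most 1536 generated tokens
```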
@@ -344,6 +345,7 @@ if __name__ == "__main__":
                 1, args.train_microbatch_size // args.pipeline_parallel_size
             ),  # microbatch size should be set to train_microbatch_size // pp_size
             "zero_stage": args.zero_stage,
+            "cpu_offload": args.cpu_offload,
             "max_norm": 1.0,
             "enable_flash_attention": True,
             "sp_size": args.tensor_parallel_size,
@@ -12,10 +12,7 @@ from transformers.modeling_outputs import (
 )
 
 try:
-    from transformers.modeling_attn_mask_utils import (
-        _prepare_4d_causal_attention_mask,
-        _prepare_4d_causal_attention_mask_for_sdpa,
-    )
+    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
     from transformers.models.qwen2.modeling_qwen2 import (
         Qwen2Attention,
         Qwen2ForCausalLM,
@@ -132,46 +129,20 @@ class Qwen2PipelineForwards:
         else:
             position_ids = position_ids.view(-1, seq_length).long()
 
-        if (
-            not shard_config.enable_flash_attention
-            and attention_mask is not None
-            and self._attn_implementation == "flash_attention_2"
-            and use_cache
-        ):
-            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
-            if is_padding_right:
-                raise ValueError(
-                    "You are attempting to perform batched generation with padding_side='right'"
-                    " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
-                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
-                )
         # embed positions, for the first stage, hidden_states is the input embeddings,
         # for the other stages, hidden_states is the output of the previous stage
         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
             attention_mask = None
         else:
-            if self._attn_implementation == "flash_attention_2":
-                # 2d mask is passed through the layers
-                attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-            elif self._attn_implementation == "sdpa" and not output_attentions:
-                # output_attentions=True can not be supported when using SDPA, and we fall back on
-                # the manual implementation that requires a 4D causal mask in all cases.
-                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    inputs_embeds,
-                    past_key_values_length,
-                )
-            else:
-                # 4d mask is passed through the layers
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    hidden_states,
-                    past_key_values_length,
-                    sliding_window=self.config.sliding_window,
-                )
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                hidden_states,
+                past_key_values_length,
+                sliding_window=self.config.sliding_window,
+            )
 
         if stage_manager.is_first_stage():
             if shard_config.enable_sequence_parallelism:
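For readers unfamiliar with the 4D mask the retained branch builds, here is a self-contained plain-torch illustration of its shape and semantics; this is a sketch of what such a helper returns, not the `transformers` implementation itself:

```python
import torch

batch_size, seq_len = 2, 5
min_val = torch.finfo(torch.float32).min

# Additive causal mask of shape [batch, 1, query_len, key_len]:
# 0.0 where attention is allowed, a large negative number where it is blocked.
mask = torch.triu(torch.full((seq_len, seq_len), min_val), diagonal=1)
mask = mask[None, None, :, :].expand(batch_size, 1, seq_len, seq_len)

assert mask.shape == (batch_size, 1, seq_len, seq_len)
```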
@@ -161,18 +161,18 @@ class PerformanceEvaluator:
         ) * (
             1.0 + (seq_len / (6.0 * self.hidden_size)) + (self.vocab_size / (16.0 * self.num_layers * self.hidden_size))
         )
-        self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
+        self.flop += batch_size * (seq_len // 1024) * self.model_numel * (3 + int(self.enable_grad_checkpoint))
 
     def on_fit_end(self) -> None:
         avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
         avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
         mp_world_size = self.coordinator.world_size // self.dp_world_size
-        avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size
+        self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size
         avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
         self.coordinator.print_on_master(
             f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop_megatron: {self.flop_megatron}, flop: {self.flop}, avg_duration: {avg_duration}, "
             f"avg_throughput: {avg_throughput}"
         )
         self.coordinator.print_on_master(
-            f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU by Megatron: {avg_tflops_per_gpu_megatron:.2f}, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
+            f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
         )
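To make the reporting in `on_fit_end` concrete, a small worked sketch of the TFLOPS-per-GPU arithmetic; the FLOP total and timing below are stand-in numbers, and the division by the model-parallel world size matches the expression kept in the hunk:

```python
flop = 5.0e15            # stand-in: accumulated FLOPs over the measured steps
avg_duration = 30.0      # stand-in: averaged wall-clock seconds
world_size, dp_world_size = 8, 2
mp_world_size = world_size // dp_world_size  # ranks sharing one model replica

avg_tflops_per_gpu = flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
print(f"TFLOPS per GPU: {avg_tflops_per_gpu:.2f}")  # -> 41.67
```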
examples/language/qwen2/README.md (new file, 127 lines)

# Pretraining LLaMA-1/2/3: best practices for building LLaMA-1/2/3-like base models

### LLaMA3
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/LLaMA3-70B-H100.png" width=600/>
</p>

- 70 billion parameter LLaMA3 model training accelerated by 18%

### LLaMA2
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/llama2_pretraining.png" width=600/>
</p>

- 70 billion parameter LLaMA2 model training accelerated by 195%
[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)

### LLaMA1
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/LLaMA_pretraining.png" width=600/>
</p>

- 65-billion-parameter large model pretraining accelerated by 38%
[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)

## Usage

> ⚠ This example only has a benchmarking script. For training/finetuning, please refer to [applications/Colossal-LLaMA](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA).

### 1. Installation

Please install the latest ColossalAI from source.

```bash
BUILD_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
```

Then install the other dependencies.

```bash
pip install -r requirements.txt
```

### 2. Shell Script Examples

For your convenience, we provide some shell scripts to run benchmarks with various configurations.

You can find them in the `scripts/benchmark_7B` and `scripts/benchmark_70B` directories. The main command should be in the format of:

```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
benchmark.py --OTHER_CONFIGURATIONS
```

Below we show an example of how to run LLaMA pretraining with `gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True`.

#### a. Running environment
This experiment was performed on 4 computing nodes with 32 A800/H800 80GB GPUs in total for LLaMA-1 65B or LLaMA-2 70B. The nodes are connected with RDMA, and GPUs within one node are fully connected with NVLink.

#### b. Running command

```bash
cd scripts/benchmark_7B
```

First, put your host file (`hosts.txt`) in this directory with your real host IPs or host names.

Here is a sample `hosts.txt`:
```text
hostname1
hostname2
hostname3
hostname4
```

Then add environment variables to the script if needed.

Finally, run the following command to start training:

```bash
bash gemini.sh
```

If you encounter an out-of-memory (OOM) error during training with the `gemini.sh` script, switching to `gemini_auto.sh` might be a solution, since gemini_auto sets an upper limit on GPU memory usage by offloading part of the model parameters and optimizer states back to CPU memory. But there is a trade-off: `gemini_auto.sh` will be a bit slower, since more data is transmitted between CPU and GPU.

#### c. Results
If you run the above command successfully, you will get the following results:
`max memory usage: 55491.10 MB, throughput: 24.26 samples/s, TFLOPS/GPU: 167.43`.

## Reference

```bibtex
@article{bian2021colossal,
  title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
  author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
  journal={arXiv preprint arXiv:2110.14883},
  year={2021}
}
```

```bibtex
@software{openlm2023openllama,
  author = {Geng, Xinyang and Liu, Hao},
  title = {OpenLLaMA: An Open Reproduction of LLaMA},
  month = May,
  year = 2023,
  url = {https://github.com/openlm-research/open_llama}
}
```

```bibtex
@software{together2023redpajama,
  author = {Together Computer},
  title = {RedPajama-Data: An Open Source Recipe to Reproduce LLaMA training dataset},
  month = April,
  year = 2023,
  url = {https://github.com/togethercomputer/RedPajama-Data}
}
```

```bibtex
@article{touvron2023llama,
  title={Llama: Open and efficient foundation language models},
  author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and others},
  journal={arXiv preprint arXiv:2302.13971},
  year={2023}
}
```
examples/language/qwen2/benchmark.py (new file, 282 lines)

import argparse
import resource
import time
import warnings
from contextlib import nullcontext

import torch
import torch.distributed as dist
from data_utils import RandomDataset
from performance_evaluator import PerformanceEvaluator, get_profile_context
from tqdm import tqdm
from transformers import AutoConfig, Qwen2ForCausalLM
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import PipelineGradientCheckpointConfig

warnings.filterwarnings("ignore")
# ==============================
# Constants
# ==============================

# We have several Qwen2 configs for your choice!
MODEL_CONFIGS = {
    "7b": Qwen2Config(
        hidden_size=3584,
        intermediate_size=18944,
        num_hidden_layers=28,
        num_attention_heads=28,
        num_key_value_heads=4,
        max_position_embeddings=131072,
    ),
    "72b": Qwen2Config(
        hidden_size=8192,
        intermediate_size=29568,
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,
        max_position_embeddings=131072,
    ),
}


def main():
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration")
    parser.add_argument("-model", "--model_path", type=str, help="Model path")
    parser.add_argument(
        "-p",
        "--plugin",
        choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu"],
        default="gemini",
        help="Choose which plugin to use",
    )
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("-s", "--num_steps", type=int, default=10, help="Number of steps to run")
    parser.add_argument("-i", "--ignore_steps", type=int, default=3, help="Number of steps to ignore")
    parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
    parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
    parser.add_argument(
        "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto"
    )
    parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb")
    parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers")
    parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini")
    parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
    parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
    parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
    parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
    parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
    parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
    parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)

    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
    parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
    parser.add_argument("--profile", action="store_true", help="Profile the code")
    parser.add_argument(
        "--nsys",
        action="store_true",
        help="Use nsys for profiling. \
        You should put something like this before colossalai launch: \
        nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out",
    )
    parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
    parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
    parser.add_argument("--no_cache", action="store_true")
    parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
    parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
    parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p")
    parser.add_argument("--overlap_allgather", action="store_true")
    parser.add_argument(
        "--sp_mode",
        default="all_to_all",
        choices=["all_to_all", "ring_attn", "ring", "split_gather"],
        help="Sequence parallelism mode",
    )
    args = parser.parse_args()

    colossalai.launch_from_torch()
    coordinator = DistCoordinator()

    def empty_init():
        pass

    # ckpt config for LLaMA3-70B on 64 H100 GPUs
    hybrid_kwargs = (
        {
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
                num_ckpt_layers_per_stage=[19, 19, 19, 13],
            ),
            "num_layers_per_stage": [19, 20, 20, 21],
            "pp_style": "interleaved",
        }
        if args.custom_ckpt
        else {}
    )

    # ==============================
    # Initialize Booster
    # ==============================
    if args.config in MODEL_CONFIGS:
        config = MODEL_CONFIGS[args.config]
    else:
        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)

    if args.plugin == "3d":
        scheduler_nodes = None
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            pp_style=args.pp_style,
            num_model_chunks=args.n_chunks,
            zero_stage=args.zero,
            sp_size=args.sp,
            sequence_parallelism_mode=args.sp_mode,
            enable_sequence_parallelism=args.sp > 1,
            enable_fused_normalization=get_accelerator().is_available(),
            enable_flash_attention=args.xformers,
            microbatch_size=args.mbs,
            precision="bf16",
            enable_metadata_cache=not args.no_cache,
            overlap_allgather=args.overlap_allgather,
            use_fp8=args.use_fp8,
            fp8_communication=args.use_fp8_comm,
            scheduler_nodes=scheduler_nodes,
            **hybrid_kwargs,
        )
    elif args.plugin == "3d_cpu":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            pp_style=args.pp_style,
            num_model_chunks=args.n_chunks,
            zero_stage=args.zero,
            cpu_offload=True,
            enable_fused_normalization=get_accelerator().is_available(),
            enable_flash_attention=args.xformers,
            microbatch_size=args.mbs,
            initial_scale=2**8,
            precision="bf16",
            overlap_p2p=args.overlap_p2p,
            use_fp8=args.use_fp8,
            fp8_communication=args.use_fp8_comm,
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)

    # ==============================
    # Initialize Dataset and Dataloader
    # ==============================
    dp_size = getattr(plugin, "dp_size", coordinator.world_size)

    if args.config in MODEL_CONFIGS:
        config = MODEL_CONFIGS[args.config]
    else:
        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
    get_accelerator().manual_seed(42)

    dataset = RandomDataset(
        num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
    )
    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42)

    # ==============================
    # Initialize Model and Optimizer
    # ==============================
    init_ctx = (
        LazyInitContext(default_device=get_accelerator().get_current_device())
        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
        else nullcontext()
    )

    model = Qwen2ForCausalLM.from_pretrained(
        args.model_path, trust_remote_code=True, use_flash_attention_2=False, use_cache=False, attn_implementation="eager"
    )
    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()

    model_numel = 14480488529920  # hard-coded instead of counting parameters at runtime
    num_layers = model.config.num_hidden_layers
    performance_evaluator = PerformanceEvaluator(
        model_numel,
        num_layers,
        model.config.hidden_size,
        model.config.vocab_size,
        args.grad_checkpoint,
        args.ignore_steps,
        dp_world_size=dp_size,
    )

    optimizer = HybridAdam(model.parameters())
    torch.set_default_dtype(torch.bfloat16)
    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

    torch.set_default_dtype(torch.float)
    coordinator.print_on_master(
        f"Booster init max device memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
    )
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
    )

    with get_profile_context(
        args.profile,
        args.ignore_steps,
        1,  # avoid creating massive log files
        save_dir=f"./profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-qwen2-{args.config}",
        nsys=args.nsys,
    ) as prof:
        if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
            data_iter = iter(dataloader)
            for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
                performance_evaluator.on_step_start(step)
                outputs = booster.execute_pipeline(
                    data_iter,
                    model,
                    criterion=lambda outputs, inputs: outputs[0],
                    optimizer=optimizer,
                    return_loss=True,
                )
                loss = outputs["loss"]
                if coordinator.is_last_process():
                    print(f"Step {step} loss: {loss}")
                optimizer.step()
                optimizer.zero_grad()

                performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
                prof.step()
        else:
            for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
                performance_evaluator.on_step_start(step)
                outputs = model(**batch)
                loss = outputs[0]
                del outputs  # free memory

                if dist.get_rank() == dist.get_world_size() - 1:
                    print(f"Step {step} loss: {loss}")
                booster.backward(loss, optimizer)
                optimizer.step()
                optimizer.zero_grad()

                performance_evaluator.on_step_end(**batch)
                prof.step()
    performance_evaluator.on_fit_end()
    coordinator.print_on_master(f"Max device memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    main()
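One detail worth flagging in the benchmark: `model_numel` is hard-coded. A hedged alternative sketch that counts parameters from the instantiated model instead; it must run before `booster.boost`, since sharding plugins afterwards expose only the local shard:

```python
# Derive the parameter count from the model rather than a constant.
model_numel = sum(p.numel() for p in model.parameters())
```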
examples/language/qwen2/data_utils.py (new symbolic link → ../data_utils.py)
examples/language/qwen2/hosts.txt (new file, empty)
examples/language/qwen2/hybrid_test_N1C8.sh (new file, 10 lines)

#!/bin/bash

################
# Load your environments and modules here
################

export OMP_NUM_THREADS=8

# hybrid: zero1 + flash_attn + grad_ckpt + bs32
colossalai run --nproc_per_node 8 benchmark.py -model "/home/grpo/models/Qwen2.5-7B/" -p "3d" -x -g --zero 1 -b 32 --mbs 1 --tp 2 --pp 2 -l 4096
examples/language/qwen2/model_utils.py (new symbolic link → ../model_utils.py)
examples/language/qwen2/performance_evaluator.py (new symbolic link → ../performance_evaluator.py)
examples/language/qwen2/requirements.txt (new file, 8 lines)

colossalai>=0.3.6
datasets
numpy
tqdm
transformers
flash-attn>=2.0.0
SentencePiece==0.1.99
tensorboard==2.14.0
examples/language/qwen2/test_ci.sh (new executable file, empty)