From e417dd004ee166e6787c1c6325bfc037d3f8b83e Mon Sep 17 00:00:00 2001 From: Baizhou Zhang <56809903+Fridge003@users.noreply.github.com> Date: Thu, 8 Jun 2023 11:27:05 +0800 Subject: [PATCH] [example] update opt example using booster api (#3918) --- examples/language/opt/README.md | 32 ++- examples/language/opt/args.py | 120 +++++++++++ examples/language/opt/benchmark.sh | 21 -- examples/language/opt/data.py | 37 ++++ examples/language/opt/opt_benchmark.py | 146 ++++++++++++++ examples/language/opt/opt_train_demo.py | 149 ++++++++++++++ examples/language/opt/requirements.txt | 2 + examples/language/opt/run_benchmark.sh | 30 +++ examples/language/opt/run_demo.sh | 44 ++++ examples/language/opt/run_gemini.sh | 28 --- examples/language/opt/test_ci.sh | 19 +- examples/language/opt/train_gemini_opt.py | 233 ---------------------- 12 files changed, 571 insertions(+), 290 deletions(-) create mode 100644 examples/language/opt/args.py delete mode 100644 examples/language/opt/benchmark.sh create mode 100644 examples/language/opt/data.py create mode 100755 examples/language/opt/opt_benchmark.py create mode 100644 examples/language/opt/opt_train_demo.py create mode 100644 examples/language/opt/run_benchmark.sh create mode 100644 examples/language/opt/run_demo.sh delete mode 100644 examples/language/opt/run_gemini.sh delete mode 100755 examples/language/opt/train_gemini_opt.py diff --git a/examples/language/opt/README.md b/examples/language/opt/README.md index c2fd25457..37e1ff4d9 100644 --- a/examples/language/opt/README.md +++ b/examples/language/opt/README.md @@ -19,15 +19,35 @@ Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/fa The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. -We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before -the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). ## Our Modifications -We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP. -## Quick Start -You can launch training by using the following bash script +We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before +the tokenization). +We adapt the OPT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin. + +## Run Demo + +By running the following script: ```bash -bash ./run_gemini.sh +bash run_demo.sh ``` +You will finetune a [facebook/opt-350m](https://huggingface.co/facebook/opt-350m) model on this [dataset](https://huggingface.co/datasets/hugginglearners/netflix-shows), which contains more than 8000 comments on Netflix shows. + +The script can be modified if you want to try another set of hyperparameters or change to another OPT model with different size. 
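+
+Under the hood, `opt_train_demo.py` builds the chosen plugin, hands it to a `Booster`, and lets the booster wrap the model, optimizer, dataloader and lr scheduler. A minimal sketch of that flow, condensed from the scripts in this example (it is an outline, not a drop-in replacement for them):
+```python
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin   # or TorchDDPPlugin / LowLevelZeroPlugin
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+from transformers import AutoConfig, OPTForCausalLM
+
+# launched via torchrun, e.g. `torchrun --standalone --nproc_per_node 4 this_script.py`
+colossalai.launch_from_torch(config={})
+
+# NOTE: when the gemini plugin is selected, the actual scripts additionally construct
+# the model under ColoInitContext (see opt_train_demo.py) before boosting it.
+model = OPTForCausalLM(AutoConfig.from_pretrained("facebook/opt-350m"))
+optimizer = HybridAdam(model.parameters(), lr=5e-5)
+
+plugin = GeminiPlugin(device=get_current_device(), placement_policy='cpu',
+                      pin_memory=True, strict_ddp_mode=True, initial_scale=2**5)
+booster = Booster(plugin=plugin)
+model, optimizer, _, _, _ = booster.boost(model, optimizer)
+
+# Inside the training loop, gradients flow through the booster instead of loss.backward():
+#     booster.backward(loss, optimizer)
+#     optimizer.step()
+```
+Switching strategies only means constructing a different plugin (e.g. `TorchDDPPlugin()` or `LowLevelZeroPlugin(initial_scale=2**5)`); the rest of the training loop stays the same.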
+ +The demo code is adapted from this [blog](https://medium.com/geekculture/fine-tune-eleutherai-gpt-neo-to-generate-netflix-movie-descriptions-in-only-47-lines-of-code-40c9b4c32475) and the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). + + + +## Run Benchmark + +You can run benchmark for OPT model by running the following script: +```bash +bash run_benchmark.sh +``` +The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your set of hyperparameters for testing. + + + diff --git a/examples/language/opt/args.py b/examples/language/opt/args.py new file mode 100644 index 000000000..16730be7e --- /dev/null +++ b/examples/language/opt/args.py @@ -0,0 +1,120 @@ +from colossalai import get_default_parser + + +def parse_demo_args(): + + parser = get_default_parser() + parser.add_argument( + "--model_name_or_path", + type=str, + default="facebook/opt-350m", + help="Path to pretrained model or model identifier from huggingface.co/models." + ) + parser.add_argument( + "--output_path", + type=str, + default="./output_model.bin", + help="The path of your saved model after finetuning." + ) + parser.add_argument( + "--plugin", + type=str, + default="gemini", + help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'." + ) + parser.add_argument( + "--num_epoch", + type=int, + default=10, + help="Number of epochs." + ) + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="Batch size (per dp group) for the training dataloader." + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use." + ) + parser.add_argument( + "--warmup_ratio", + type=float, + default=0.1, + help="Ratio of warmup steps against total training steps." + ) + parser.add_argument( + "--weight_decay", + type=float, + default=0.01, + help="Weight decay to use." + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="A seed for reproducible training." + ) + + args = parser.parse_args() + return args + + + +def parse_benchmark_args(): + + parser = get_default_parser() + parser.add_argument( + "--model_name_or_path", + type=str, + default="facebook/opt-125m", + help="Path to pretrained model or model identifier from huggingface.co/models." + ) + parser.add_argument( + "--plugin", + type=str, + default="gemini", + help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'." + ) + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="Batch size (per dp group) for the training dataloader." + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use." + ) + parser.add_argument( + "--weight_decay", + type=float, + default=0.0, + help="Weight decay to use." + ) + parser.add_argument( + "--max_train_steps", + type=int, + default=20, + help="Total number of training steps to perform." + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="A seed for reproducible training." + ) + parser.add_argument( + "--mem_cap", + type=int, + default=0, + help="Limit on the usage of space for each GPU (in GB)." 
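+        # The default of 0 leaves GPU memory uncapped; opt_benchmark.py only applies
+        # colo_memory_cap() when a positive value is given.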
+ ) + args = parser.parse_args() + + return args \ No newline at end of file diff --git a/examples/language/opt/benchmark.sh b/examples/language/opt/benchmark.sh deleted file mode 100644 index 0d04b5e9b..000000000 --- a/examples/language/opt/benchmark.sh +++ /dev/null @@ -1,21 +0,0 @@ -export BS=16 -export MEMCAP=0 -export MODEL="6.7b" -export GPUNUM=1 - -for MODEL in "6.7b" "13b" "1.3b" -do -for GPUNUM in 8 1 -do -for BS in 16 24 32 8 -do -for MEMCAP in 0 40 -do -pkill -9 torchrun -pkill -9 python - -env BS=$BS MEM_CAP=$MEMCAP MODEL=$MODEL GPUNUM=$GPUNUM bash ./run_gemini.sh -done -done -done -done diff --git a/examples/language/opt/data.py b/examples/language/opt/data.py new file mode 100644 index 000000000..6cfffb5fc --- /dev/null +++ b/examples/language/opt/data.py @@ -0,0 +1,37 @@ +import torch +from torch.utils.data import Dataset +from datasets import load_dataset + + +class NetflixDataset(Dataset): + + def __init__(self, tokenizer): + + super().__init__() + + self.tokenizer = tokenizer + self.input_ids = [] + self.attn_masks = [] + self.labels = [] + self.txt_list = netflix_descriptions = load_dataset("hugginglearners/netflix-shows", split="train")['description'] + self.max_length = max([len(self.tokenizer.encode(description)) for description in netflix_descriptions]) + + for txt in self.txt_list: + encodings_dict = self.tokenizer('' + txt + '', + truncation=True, + max_length=self.max_length, + padding="max_length") + self.input_ids.append(torch.tensor(encodings_dict['input_ids'])) + self.attn_masks.append(torch.tensor(encodings_dict['attention_mask'])) + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, idx): + return self.input_ids[idx], self.attn_masks[idx] + + +def netflix_collator(data): + return {'input_ids': torch.stack([x[0] for x in data]), + 'attention_mask': torch.stack([x[1] for x in data]), + 'labels': torch.stack([x[0] for x in data])} diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py new file mode 100755 index 000000000..da2be4055 --- /dev/null +++ b/examples/language/opt/opt_benchmark.py @@ -0,0 +1,146 @@ +import time + +import torch +import transformers +from transformers import AutoConfig, OPTForCausalLM +from transformers.utils.versions import require_version +import tqdm + +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.tensor import ProcessGroup, ShardSpec +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator + +from args import parse_benchmark_args + +require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt") + + +def format_num(num: int, bytes=False): + """Scale bytes to its proper format, e.g. 
1253656 => '1.20MB'""" + factor = 1024 if bytes else 1000 + suffix = "B" if bytes else "" + for unit in ["", " K", " M", " G", " T", " P"]: + if num < factor: + return f"{num:.2f}{unit}{suffix}" + num /= factor + + +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def colo_memory_cap(size_in_GB): + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) + if size_in_GB * (1024**3) < cuda_capacity: + colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) + print(f"Limiting GPU memory usage to {size_in_GB} GB") + + +def main(): + + args = parse_benchmark_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + world_size = coordinator.world_size + + # Manage loggers + disable_existing_loggers() + logger = get_dist_logger() + if coordinator.is_master(): + transformers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + + # Whether to set limit of memory capacity + if args.mem_cap > 0: + colo_memory_cap(args.mem_cap) + + # Build OPT model + # Initialize the model under ColoInitContext if using GeminiPlugin + config = AutoConfig.from_pretrained(args.model_name_or_path) + if args.plugin == 'gemini': + shard_pg = ProcessGroup(tp_degree=world_size) + default_dist_spec = ShardSpec([-1], [world_size]) + with ColoInitContext(device='cpu', + default_dist_spec=default_dist_spec, + default_pg=shard_pg): + model = OPTForCausalLM(config) + else: + model = OPTForCausalLM(config) + logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Set plugin + booster_kwargs = {} + if args.plugin == 'torch_ddp_fp16': + booster_kwargs['mixed_precision'] = 'fp16' + if args.plugin.startswith('torch_ddp'): + plugin = TorchDDPPlugin() + elif args.plugin == 'gemini': + plugin = GeminiPlugin(device=get_current_device(), + placement_policy='cpu', + pin_memory=True, + strict_ddp_mode=True, + initial_scale=2**5) + elif args.plugin == 'low_level_zero': + plugin = LowLevelZeroPlugin(initial_scale=2**5) + logger.info(f"Set plugin as {args.plugin}", ranks=[0]) + + # Set optimizer + optimizer = HybridAdam(model.parameters(), lr=args.learning_rate) + + # Set booster + booster = Booster(plugin=plugin, **booster_kwargs) + model, optimizer, _, _, _ = booster.boost(model, optimizer) + + SEQ_LEN = 1024 + VOCAB_SIZE = 50257 + + # Start training. 
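+    # (Benchmark note) The loop below trains on synthetic batches: get_data() draws random
+    # token ids of length SEQ_LEN from a VOCAB_SIZE vocabulary and reuses them as labels,
+    # so the numbers measure raw training throughput rather than convergence. Throughput is
+    # reported at the end as samples/sec aggregated over all ranks, and peak memory per GPU
+    # comes from torch.cuda.max_memory_allocated().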
+ logger.info(f"Start testing", ranks=[0]) + progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master()) + + torch.cuda.synchronize() + model.train() + start_time = time.time() + + for _ in range(args.max_train_steps): + + input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE) + optimizer.zero_grad() + outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False) + loss = outputs['loss'] + booster.backward(loss, optimizer) + optimizer.step() + + torch.cuda.synchronize() + progress_bar.update(1) + + # Compute Statistics + end_time = time.time() + throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time)) + max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True) + + logger.info(f"Testing finished, " + f"batch size per gpu: {args.batch_size}, " + f"plugin: {args.plugin}, " + f"throughput: {throughput}, " + f"maximum memory usage per gpu: {max_mem}.", + ranks=[0]) + + +if __name__ == "__main__": + main() diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py new file mode 100644 index 000000000..8a2ad5f55 --- /dev/null +++ b/examples/language/opt/opt_train_demo.py @@ -0,0 +1,149 @@ +import time + +import torch +import datasets +import transformers +from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer +from transformers import get_linear_schedule_with_warmup +from transformers.utils.versions import require_version +from tqdm import tqdm + +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.tensor import ProcessGroup, ShardSpec +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator + +from args import parse_demo_args +from data import NetflixDataset, netflix_collator + +require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt") +require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt") + + +def move_to_cuda(batch, device): + return {k: v.to(device) for k, v in batch.items()} + + +def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator): + + torch.cuda.synchronize() + model.train() + + with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar: + + for batch in pbar: + + # Foward + optimizer.zero_grad() + batch = move_to_cuda(batch, torch.cuda.current_device()) + + outputs = model(use_cache=False, **batch) + loss = outputs['loss'] + + # Backward + booster.backward(loss, optimizer) + optimizer.step() + lr_scheduler.step() + + # Print batch loss + pbar.set_postfix({'loss': loss.item()}) + + +def main(): + + args = parse_demo_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + world_size = coordinator.world_size + + # Manage loggers + disable_existing_loggers() + logger = get_dist_logger() + if coordinator.is_master(): + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Build OPT model + # 
Initialize the model under ColoInitContext if using GeminiPlugin + config = AutoConfig.from_pretrained(args.model_name_or_path) + if args.plugin == 'gemini': + shard_pg = ProcessGroup(tp_degree=world_size) + default_dist_spec = ShardSpec([-1], [world_size]) + with ColoInitContext(device='cpu', + default_dist_spec=default_dist_spec, + default_pg=shard_pg): + model = OPTForCausalLM(config) + else: + model = OPTForCausalLM(config) + logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Set plugin + booster_kwargs = {} + if args.plugin == 'torch_ddp_fp16': + booster_kwargs['mixed_precision'] = 'fp16' + if args.plugin.startswith('torch_ddp'): + plugin = TorchDDPPlugin() + elif args.plugin == 'gemini': + plugin = GeminiPlugin(device=get_current_device(), + placement_policy='cpu', + pin_memory=True, + strict_ddp_mode=True, + initial_scale=2**5) + elif args.plugin == 'low_level_zero': + plugin = LowLevelZeroPlugin(initial_scale=2**5) + logger.info(f"Set plugin as {args.plugin}", ranks=[0]) + + # Prepare tokenizer and dataloader + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) + dataset = NetflixDataset(tokenizer) + dataloader = plugin.prepare_dataloader(dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=netflix_collator) + + # Set optimizer + optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size)) + + # Set lr scheduler + total_steps = len(dataloader) * args.num_epoch + num_warmup_steps = int(args.warmup_ratio * total_steps) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=len(dataloader) * args.num_epoch + ) + + # Set booster + booster = Booster(plugin=plugin, **booster_kwargs) + model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model, + optimizer=optimizer, + dataloader=dataloader, + lr_scheduler=lr_scheduler) + + # Start finetuning + logger.info(f"Start finetuning", ranks=[0]) + for epoch in range(args.num_epoch): + train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator) + + # Finish training and evaluate + logger.info(f"Finish finetuning", ranks=[0]) + booster.save_model(model, args.output_path) + logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0]) + + +if __name__ == "__main__": + main() diff --git a/examples/language/opt/requirements.txt b/examples/language/opt/requirements.txt index 137a69e80..4422216e6 100644 --- a/examples/language/opt/requirements.txt +++ b/examples/language/opt/requirements.txt @@ -1,2 +1,4 @@ colossalai >= 0.1.12 torch >= 1.8.1 +datasets >= 1.8.0 +transformers >= 4.20.0 \ No newline at end of file diff --git a/examples/language/opt/run_benchmark.sh b/examples/language/opt/run_benchmark.sh new file mode 100644 index 000000000..76c5e8601 --- /dev/null +++ b/examples/language/opt/run_benchmark.sh @@ -0,0 +1,30 @@ +set -xe +pip install -r requirements.txt + +export BS=32 +export MEMCAP=0 +export GPUNUM=1 + +# acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b` +export MODEL="125m" + +for BS in 8 32 128 +do +for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" +do +for GPUNUM in 1 4 +do + +MODLE_PATH="facebook/opt-${MODEL}" +torchrun \ + --standalone \ + --nproc_per_node ${GPUNUM} \ + opt_benchmark.py \ + --model_name_or_path ${MODLE_PATH} \ + --mem_cap ${MEMCAP} \ + --plugin ${PLUGIN} \ + --batch_size 
${BS} + +done +done +done diff --git a/examples/language/opt/run_demo.sh b/examples/language/opt/run_demo.sh new file mode 100644 index 000000000..0c9759c34 --- /dev/null +++ b/examples/language/opt/run_demo.sh @@ -0,0 +1,44 @@ +set -xe +pip install -r requirements.txt + +# model name or path +MODEL="facebook/opt-350m" + +# path for saving model +OUTPUT_PATH="./output_model.bin" + +# plugin(training strategy) +# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini" +PLUGIN="gemini" + +# number of gpus to use +GPUNUM=4 + +# batch size per gpu +BS=16 + +# learning rate +LR="5e-5" + +# number of epoch +EPOCH=10 + +# weight decay +WEIGHT_DECAY=0.01 + +# ratio of warmup steps +WARMUP_RATIO=0.1 + +# run the script for demo +torchrun \ + --standalone \ + --nproc_per_node ${GPUNUM} \ + opt_train_demo.py \ + --model_name_or_path ${MODEL} \ + --output_path ${OUTPUT_PATH} \ + --plugin ${PLUGIN} \ + --batch_size ${BS} \ + --num_epoch ${EPOCH} \ + --learning_rate ${LR} \ + --weight_decay ${WEIGHT_DECAY} \ + --warmup_ratio ${WARMUP_RATIO} diff --git a/examples/language/opt/run_gemini.sh b/examples/language/opt/run_gemini.sh deleted file mode 100644 index 73f231292..000000000 --- a/examples/language/opt/run_gemini.sh +++ /dev/null @@ -1,28 +0,0 @@ -set -x -export BS=${BS:-16} -export MEMCAP=${MEMCAP:-0} -# Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`. For `175b` -export MODEL=${MODEL:-"125m"} -export GPUNUM=${GPUNUM:-1} -export USE_SHARD_INIT=${USE_SHARD_INIT:-"false"} - -# make directory for logs -mkdir -p ./logs - -if [ ${USE_SHARD_INIT} = "true" ]; then - USE_SHARD_INIT="--shardinit" -else - USE_SHARD_INIT="" -fi - -export MODLE_PATH="facebook/opt-${MODEL}" - -# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 -torchrun \ - --nproc_per_node ${GPUNUM} \ - --master_port 19198 \ - train_gemini_opt.py \ - --mem_cap ${MEMCAP} \ - --model_name_or_path ${MODLE_PATH} \ - ${USE_SHARD_INIT} \ - --batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log diff --git a/examples/language/opt/test_ci.sh b/examples/language/opt/test_ci.sh index 317f602cd..fa14f52b7 100644 --- a/examples/language/opt/test_ci.sh +++ b/examples/language/opt/test_ci.sh @@ -1,4 +1,19 @@ -for GPUNUM in 2 1 +set -xe +pip install -r requirements.txt + +BS=4 +for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" do -env BS=2 MODEL="125m" GPUNUM=$GPUNUM bash ./run_gemini.sh +for GPUNUM in 1 4 +do + +torchrun \ + --standalone \ + --nproc_per_node ${GPUNUM} \ + opt_benchmark.py \ + --model_name_or_path "facebook/opt-125m" \ + --plugin ${PLUGIN} \ + --batch_size ${BS} + +done done diff --git a/examples/language/opt/train_gemini_opt.py b/examples/language/opt/train_gemini_opt.py deleted file mode 100755 index 3614b689d..000000000 --- a/examples/language/opt/train_gemini_opt.py +++ /dev/null @@ -1,233 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -""" -Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) -on a text file or a dataset without using HuggingFace Trainer. - -Here is the full list of checkpoints on the hub that can be fine-tuned by this script: -https://huggingface.co/models?filter=text-generation -""" -# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. - -import time -from functools import partial - -import datasets -import torch -import torch.distributed as dist -import transformers -from transformers import CONFIG_MAPPING, MODEL_MAPPING, AutoConfig, OPTForCausalLM -from transformers.utils.versions import require_version - -import colossalai -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.tensor import ProcessGroup, ShardSpec -from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP - - -def get_data(batch_size, seq_len, vocab_size): - input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) - attention_mask = torch.ones_like(input_ids) - return input_ids, attention_mask - - -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") - -MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) -MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) - - -def get_time_stamp(): - torch.cuda.synchronize() - return time.time() - - -def get_tflops(model_numel, batch_size, seq_len, step_time): - return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) - - -def parse_args(): - parser = colossalai.get_default_parser() - parser.add_argument( - "--model_name_or_path", - type=str, - help="Path to pretrained model or model identifier from huggingface.co/models.", - required=True, - ) - parser.add_argument( - "--config_name", - type=str, - default=None, - help="Pretrained config name or path if not the same as model_name", - ) - parser.add_argument( - "--batch_size", - type=int, - default=8, - help="Batch size (per dp group) for the training dataloader.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") - parser.add_argument( - "--max_train_steps", - type=int, - default=20, - help="Total number of training steps to perform.", - ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--model_type", - type=str, - default=None, - help="Model type to use if training from scratch.", - choices=MODEL_TYPES, - ) - parser.add_argument( - "--shardinit", - action="store_true", - help="Initialize the model with tensor parallel", - ) - parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") - parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") - args = parser.parse_args() - - return args - - -def colo_memory_cap(size_in_GB): - from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device - cuda_capacity = colo_device_memory_capacity(get_current_device()) - if size_in_GB * (1024**3) < cuda_capacity: - 
colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) - print("Using {} GB of GPU memory".format(size_in_GB)) - - -def main(): - args = parse_args() - disable_existing_loggers() - colossalai.launch_from_torch({}) - logger = get_dist_logger() - is_main_process = dist.get_rank() == 0 - - if is_main_process: - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - - if args.mem_cap > 0: - colo_memory_cap(args.mem_cap) - - # If passed along, set the training seed now. - if args.seed is not None: - torch.mannul_seed(args.seed) - logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}") - - # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at - # https://huggingface.co/docs/datasets/loading_datasets.html. - - # Load pretrained model - # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently - # download model & vocab. - if args.config_name: - config = AutoConfig.from_pretrained(args.config_name) - elif args.model_name_or_path: - config = AutoConfig.from_pretrained(args.model_name_or_path) - else: - config = CONFIG_MAPPING[args.model_type]() - logger.warning("You are instantiating a new config instance from scratch.") - logger.info("Model config has been created", ranks=[0]) - - if args.init_in_cpu: - init_dev = torch.device('cpu') - else: - init_dev = get_current_device() - - # shard init parameters - if args.shardinit: - logger.info("Sharding initialization !", ranks=[0]) - else: - logger.info("Skipping sharding initialization", ranks=[0]) - - world_size = torch.distributed.get_world_size() - shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None - default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None - - # build model - if args.model_name_or_path is None: - logger.info("Train a new model from scratch", ranks=[0]) - with ColoInitContext(device=init_dev, - dtype=torch.half, - default_dist_spec=default_dist_spec, - default_pg=shard_pg): - model = OPTForCausalLM(config) - else: - logger.info("Finetune a pre-trained model", ranks=[0]) - with ColoInitContext(device=init_dev, - dtype=torch.half, - default_dist_spec=default_dist_spec, - default_pg=shard_pg): - model = OPTForCausalLM.from_pretrained(args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - local_files_only=False) - - # enable gradient checkpointing - model.gradient_checkpointing_enable() - - numel = sum([p.numel() for p in model.parameters()]) - PLACEMENT_POLICY = 'cpu' - model = GeminiDDP(model, - device=get_current_device(), - placement_policy=PLACEMENT_POLICY, - pin_memory=True, - strict_ddp_mode=args.shardinit) - optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0) - - SEQ_LEN = 1024 - VOCAB_SIZE = 50257 - - get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN) - - model.train() - for step in range(args.max_train_steps): - st_time = time.time() - input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE) - - outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False) - loss = outputs['loss'] - optimizer.backward(loss) - - optimizer.step() - optimizer.zero_grad() - torch.cuda.synchronize() - step_time = time.time() - st_time - 
step_tflops = get_tflops_func(step_time) - - logger.info("step {} finished, Tflops {}".format(step, step_tflops), ranks=[0]) - - logger.info("Training finished", ranks=[0]) - - -if __name__ == "__main__": - main()