[Coati] first commit (#3283)
applications/Chat/benchmarks/README.md (new file, 94 lines)
@@ -0,0 +1,94 @@
# Benchmarks

## Benchmark GPT on dummy prompt data

We provide various GPT models (the string in parentheses is the corresponding model name used in this script):

- GPT2-S (s)
- GPT2-M (m)
- GPT2-L (l)
- GPT2-XL (xl)
- GPT2-4B (4b)
- GPT2-6B (6b)
- GPT2-8B (8b)
- GPT2-10B (10b)
- GPT2-12B (12b)
- GPT2-15B (15b)
- GPT2-18B (18b)
- GPT2-20B (20b)
- GPT2-24B (24b)
- GPT2-28B (28b)
- GPT2-32B (32b)
- GPT2-36B (36b)
- GPT2-40B (40b)
- GPT3 (175b)

We also provide various training strategies (each flag selects a strategy constructed in `benchmark_gpt_dummy.py`; see the note after this list):

- ddp: torch DDP
- colossalai_gemini: ColossalAI GeminiDDP with `placement_policy="cuda"`, like ZeRO-3
- colossalai_gemini_cpu: ColossalAI GeminiDDP with `placement_policy="cpu"`, like ZeRO-3 with CPU offload
- colossalai_zero2: ColossalAI ZeRO-2
- colossalai_zero2_cpu: ColossalAI ZeRO-2 with CPU offload
- colossalai_zero1: ColossalAI ZeRO-1
- colossalai_zero1_cpu: ColossalAI ZeRO-1 with CPU offload
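Each flag maps one-to-one to a strategy constructor in `benchmark_gpt_dummy.py` below; for example, `colossalai_gemini` selects `ColossalAIStrategy(stage=3, placement_policy='cuda', ...)` and `ddp` selects `DDPStrategy()`.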
Only launching with `torchrun` is supported for now. For example:

```shell
# run GPT2-S on a single node with a single GPU and the minimum batch size
torchrun --standalone --nproc_per_node 1 benchmark_gpt_dummy.py --model s --strategy ddp --experience_batch_size 1 --train_batch_size 1
# run GPT2-XL on a single node with 4 GPUs
torchrun --standalone --nproc_per_node 4 benchmark_gpt_dummy.py --model xl --strategy colossalai_zero2
# run GPT3 on 8 nodes with 8 GPUs each
torchrun --nnodes 8 --nproc_per_node 8 \
    --rdzv_id=$JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$HOST_NODE_ADDR \
    benchmark_gpt_dummy.py --model 175b --strategy colossalai_gemini
```

> ⚠ The batch sizes given in CLI args and the reported throughput/TFLOPS are all per-GPU values.
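For example, `--train_batch_size 8` launched with `--nproc_per_node 4` corresponds to an effective global batch size of 8 × 4 = 32, and the reported TFLOPS are likewise measured per device.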

In this benchmark, we assume the actor and critic share the same model architecture/size for simplicity. In practice, a smaller critic may be used to reduce training cost.

We also provide a simple shell script that runs a set of benchmarks. It only launches on a single node, but it is easy to adapt to multiple nodes by modifying the launch command in the script (see the note after the usage example).

Usage:

```shell
# run for GPUS=(1 2 4 8) x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
./benchmark_gpt_dummy.sh
# run for GPUS=2 x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
./benchmark_gpt_dummy.sh 2
# run for GPUS=2 x strategy=ddp x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
./benchmark_gpt_dummy.sh 2 ddp
# run for GPUS=2 x strategy=ddp x model=l x batch_size=(1 2 4 8 16 32 64 128 256)
./benchmark_gpt_dummy.sh 2 ddp l
```
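To run this sweep across nodes, replace the `torchrun --standalone --nproc_per_node $1 ...` line inside `benchmark_gpt_dummy.sh` with a multi-node invocation such as the `torchrun --nnodes ... --rdzv_id ... --rdzv_backend ... --rdzv_endpoint ...` command shown in the GPT3 example above.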
## Benchmark OPT with LoRA on dummy prompt data

We provide various OPT models (the string in parentheses is the corresponding model name used in this script):

- OPT-125M (125m)
- OPT-350M (350m)
- OPT-700M (700m)
- OPT-1.3B (1.3b)
- OPT-2.7B (2.7b)
- OPT-3.5B (3.5b)
- OPT-5.5B (5.5b)
- OPT-6.7B (6.7b)
- OPT-10B (10b)
- OPT-13B (13b)

Only launching with `torchrun` is supported for now. For example:

```shell
# run OPT-125M without LoRA (lora_rank=0) on a single node with a single GPU and the minimum batch size
torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py --model 125m --strategy ddp --experience_batch_size 1 --train_batch_size 1 --lora_rank 0
# run OPT-350M with lora_rank=4 on a single node with 4 GPUs
torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py --model 350m --strategy colossalai_zero2 --lora_rank 4
```
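Here `--lora_rank` sets the rank of the LoRA adapters used by the OPT actor and critic; `--lora_rank 0` disables LoRA, as in the first example above.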
> ⚠ The batch sizes given in CLI args and the reported throughput/TFLOPS are all per-GPU values.

In this benchmark, we assume the actor and critic share the same model architecture/size for simplicity. In practice, a smaller critic may be used to reduce training cost.
applications/Chat/benchmarks/benchmark_gpt_dummy.py (new file, 184 lines)
@@ -0,0 +1,184 @@
import argparse
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn as nn
from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.callbacks import PerformanceEvaluator
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
from torch.optim import Adam
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
    numel = sum(p.numel() for p in model.parameters())
    if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
        # with ZeRO-3 sharded init, each rank only holds a shard of the parameters
        numel *= dist.get_world_size()
    return numel


def preprocess_batch(samples) -> dict:
    input_ids = torch.stack(samples)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    return {'input_ids': input_ids, 'attention_mask': attention_mask}


def print_rank_0(*args, **kwargs) -> None:
    if dist.get_rank() == 0:
        print(*args, **kwargs)


def print_model_numel(model_dict: dict) -> None:
    B = 1024**3
    M = 1024**2
    K = 1024
    outputs = ''
    for name, numel in model_dict.items():
        outputs += f'{name}: '
        if numel >= B:
            outputs += f'{numel / B:.2f} B\n'
        elif numel >= M:
            outputs += f'{numel / M:.2f} M\n'
        elif numel >= K:
            outputs += f'{numel / K:.2f} K\n'
        else:
            outputs += f'{numel}\n'
    print_rank_0(outputs)


def get_gpt_config(model_name: str) -> GPT2Config:
    model_map = {
        's': GPT2Config(),
        'm': GPT2Config(n_embd=1024, n_layer=24, n_head=16),
        'l': GPT2Config(n_embd=1280, n_layer=36, n_head=20),
        'xl': GPT2Config(n_embd=1600, n_layer=48, n_head=25),
        '2b': GPT2Config(n_embd=2048, n_layer=40, n_head=16),
        '4b': GPT2Config(n_embd=2304, n_layer=64, n_head=16),
        '6b': GPT2Config(n_embd=4096, n_layer=30, n_head=16),
        '8b': GPT2Config(n_embd=4096, n_layer=40, n_head=16),
        '10b': GPT2Config(n_embd=4096, n_layer=50, n_head=16),
        '12b': GPT2Config(n_embd=4096, n_layer=60, n_head=16),
        '15b': GPT2Config(n_embd=4096, n_layer=78, n_head=16),
        '18b': GPT2Config(n_embd=4096, n_layer=90, n_head=16),
        '20b': GPT2Config(n_embd=8192, n_layer=25, n_head=16),
        '24b': GPT2Config(n_embd=8192, n_layer=30, n_head=16),
        '28b': GPT2Config(n_embd=8192, n_layer=35, n_head=16),
        '32b': GPT2Config(n_embd=8192, n_layer=40, n_head=16),
        '36b': GPT2Config(n_embd=8192, n_layer=45, n_head=16),
        '40b': GPT2Config(n_embd=8192, n_layer=50, n_head=16),
        '175b': GPT2Config(n_positions=2048, n_embd=12288, n_layer=96, n_head=96),
    }
    try:
        return model_map[model_name]
    except KeyError:
        raise ValueError(f'Unknown model "{model_name}"')
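# Rough sanity check for the configs above (an approximation, not used by the script):
# a decoder-only transformer has about 12 * n_layer * n_embd**2 parameters plus embeddings,
# e.g. '8b' -> 12 * 40 * 4096**2 ~ 8.1e9 and '175b' -> 12 * 96 * 12288**2 ~ 1.74e11.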
def main(args):
    if args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
    elif args.strategy == 'colossalai_gemini_cpu':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2_cpu':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
    elif args.strategy == 'colossalai_zero1':
        strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero1_cpu':
        strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    model_config = get_gpt_config(args.model)

    with strategy.model_init_context():
        actor = GPTActor(config=model_config).cuda()
        critic = GPTCritic(config=model_config).cuda()

        # the initial (reference) model and the reward model are copies of the actor and critic
        initial_model = deepcopy(actor).cuda()
        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()

    actor_numel = get_model_numel(actor, strategy)
    critic_numel = get_model_numel(critic, strategy)
    initial_model_numel = get_model_numel(initial_model, strategy)
    reward_model_numel = get_model_numel(reward_model, strategy)
    print_model_numel({
        'Actor': actor_numel,
        'Critic': critic_numel,
        'Initial model': initial_model_numel,
        'Reward model': reward_model_numel
    })
    performance_evaluator = PerformanceEvaluator(actor_numel,
                                                 critic_numel,
                                                 initial_model_numel,
                                                 reward_model_numel,
                                                 enable_grad_checkpoint=False,
                                                 ignore_episodes=1)

    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
        critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
    else:
        actor_optim = Adam(actor.parameters(), lr=5e-6)
        critic_optim = Adam(critic.parameters(), lr=5e-6)

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)

    trainer = PPOTrainer(strategy,
                         actor,
                         critic,
                         reward_model,
                         initial_model,
                         actor_optim,
                         critic_optim,
                         max_epochs=args.max_epochs,
                         train_batch_size=args.train_batch_size,
                         experience_batch_size=args.experience_batch_size,
                         tokenizer=preprocess_batch,
                         max_length=512,
                         do_sample=True,
                         temperature=1.0,
                         top_k=50,
                         pad_token_id=tokenizer.pad_token_id,
                         eos_token_id=tokenizer.eos_token_id,
                         callbacks=[performance_evaluator])

    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
    trainer.fit(random_prompts,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)

    print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='s')
    parser.add_argument('--strategy',
                        choices=[
                            'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
                            'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
                        ],
                        default='ddp')
    parser.add_argument('--num_episodes', type=int, default=3)
    parser.add_argument('--max_timesteps', type=int, default=8)
    parser.add_argument('--update_timesteps', type=int, default=8)
    parser.add_argument('--max_epochs', type=int, default=3)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    args = parser.parse_args()
    main(args)
applications/Chat/benchmarks/benchmark_gpt_dummy.sh (new executable file, 45 lines)
@@ -0,0 +1,45 @@
#!/usr/bin/env bash
# Usage: $0 <?number-of-gpus> <?strategy> <?model>
set -xu

BASE=$(realpath "$(dirname "$0")")

PY_SCRIPT=${BASE}/benchmark_gpt_dummy.py
export OMP_NUM_THREADS=8

function tune_batch_size() {
    # We found that when the experience batch size equals the train batch size,
    # peak CUDA memory usage of the experience-making phase is less than or equal to
    # that of the training phase. Thus the experience batch size can be larger than
    # or equal to the train batch size.
    # The sweep stops at the first failing (e.g. OOM) batch size.
    for bs in 1 2 4 8 16 32 64 128 256; do
        torchrun --standalone --nproc_per_node $1 $PY_SCRIPT --model $2 --strategy $3 --experience_batch_size $bs --train_batch_size $bs || return 1
    done
}

if [ $# -eq 0 ]; then
    num_gpus=(1 2 4 8)
else
    num_gpus=($1)
fi

if [ $# -le 1 ]; then
    strategies=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu")
else
    strategies=($2)
fi

if [ $# -le 2 ]; then
    models=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b")
else
    models=($3)
fi

for num_gpu in "${num_gpus[@]}"; do
    for strategy in "${strategies[@]}"; do
        for model in "${models[@]}"; do
            # once a model fails for this GPU count and strategy, skip the remaining larger models
            tune_batch_size $num_gpu $model $strategy || break
        done
    done
done
applications/Chat/benchmarks/benchmark_opt_lora_dummy.py (new file, 179 lines)
@@ -0,0 +1,179 @@
import argparse
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn as nn
from coati.models.base import RewardModel
from coati.models.opt import OPTActor, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.callbacks import PerformanceEvaluator
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
from torch.optim import Adam
from transformers import AutoTokenizer
from transformers.models.opt.configuration_opt import OPTConfig

from colossalai.nn.optimizer import HybridAdam


def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
    numel = sum(p.numel() for p in model.parameters())
    if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
        # with ZeRO-3 sharded init, each rank only holds a shard of the parameters
        numel *= dist.get_world_size()
    return numel


def preprocess_batch(samples) -> dict:
    input_ids = torch.stack(samples)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    return {'input_ids': input_ids, 'attention_mask': attention_mask}


def print_rank_0(*args, **kwargs) -> None:
    if dist.get_rank() == 0:
        print(*args, **kwargs)


def print_model_numel(model_dict: dict) -> None:
    B = 1024**3
    M = 1024**2
    K = 1024
    outputs = ''
    for name, numel in model_dict.items():
        outputs += f'{name}: '
        if numel >= B:
            outputs += f'{numel / B:.2f} B\n'
        elif numel >= M:
            outputs += f'{numel / M:.2f} M\n'
        elif numel >= K:
            outputs += f'{numel / K:.2f} K\n'
        else:
            outputs += f'{numel}\n'
    print_rank_0(outputs)


def get_opt_config(model_name: str) -> OPTConfig:
    model_map = {
        '125m': OPTConfig.from_pretrained('facebook/opt-125m'),
        '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
        '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
        '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'),
        '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'),
        '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
        '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
        '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'),
        '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
        '13b': OPTConfig.from_pretrained('facebook/opt-13b'),
    }
    try:
        return model_map[model_name]
    except KeyError:
        raise ValueError(f'Unknown model "{model_name}"')
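# Note: the entries constructed with explicit OPTConfig(...) arguments ('350m', '700m',
# '3.5b', '5.5b', '10b') are approximate configurations following the OPT convention
# ffn_dim = 4 * hidden_size; the remaining entries load released facebook/opt-* configs.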
def main(args):
    if args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
    elif args.strategy == 'colossalai_gemini_cpu':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2_cpu':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
    elif args.strategy == 'colossalai_zero1':
        strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero1_cpu':
        strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac)

    model_config = get_opt_config(args.model)

    with strategy.model_init_context():
        actor = OPTActor(config=model_config, lora_rank=args.lora_rank).cuda()
        critic = OPTCritic(config=model_config, lora_rank=args.lora_rank).cuda()

        # the initial (reference) model and the reward model are copies of the actor and critic
        initial_model = deepcopy(actor).cuda()
        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()

    actor_numel = get_model_numel(actor, strategy)
    critic_numel = get_model_numel(critic, strategy)
    initial_model_numel = get_model_numel(initial_model, strategy)
    reward_model_numel = get_model_numel(reward_model, strategy)
    print_model_numel({
        'Actor': actor_numel,
        'Critic': critic_numel,
        'Initial model': initial_model_numel,
        'Reward model': reward_model_numel
    })
    performance_evaluator = PerformanceEvaluator(actor_numel,
                                                 critic_numel,
                                                 initial_model_numel,
                                                 reward_model_numel,
                                                 enable_grad_checkpoint=False,
                                                 ignore_episodes=1)

    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
        critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
    else:
        actor_optim = Adam(actor.parameters(), lr=5e-6)
        critic_optim = Adam(critic.parameters(), lr=5e-6)

    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
    tokenizer.pad_token = tokenizer.eos_token

    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)

    trainer = PPOTrainer(strategy,
                         actor,
                         critic,
                         reward_model,
                         initial_model,
                         actor_optim,
                         critic_optim,
                         max_epochs=args.max_epochs,
                         train_batch_size=args.train_batch_size,
                         experience_batch_size=args.experience_batch_size,
                         tokenizer=preprocess_batch,
                         max_length=512,
                         do_sample=True,
                         temperature=1.0,
                         top_k=50,
                         pad_token_id=tokenizer.pad_token_id,
                         eos_token_id=tokenizer.eos_token_id,
                         callbacks=[performance_evaluator])

    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
    trainer.fit(random_prompts,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)

    print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='125m')
    parser.add_argument('--strategy',
                        choices=[
                            'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
                            'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
                        ],
                        default='ddp')
    parser.add_argument('--num_episodes', type=int, default=3)
    parser.add_argument('--max_timesteps', type=int, default=8)
    parser.add_argument('--update_timesteps', type=int, default=8)
    parser.add_argument('--max_epochs', type=int, default=3)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=4)
    parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
    args = parser.parse_args()
    main(args)