[app] add chatgpt application (#2698)
applications/ChatGPT/examples/README.md (new file, 105 lines)
@@ -0,0 +1,105 @@
# Examples

## Install requirements

```shell
pip install -r requirements.txt
```
## Train with dummy prompt data

This script supports 3 strategies:

- naive
- ddp
- colossalai (two variants: `colossalai_zero2` and `colossalai_gemini`)

It uses randomly generated prompt data.

The naive strategy only supports single-GPU training:

```shell
python train_dummy.py --strategy naive
# display cli help
python train_dummy.py -h
```

The DDP and ColossalAI strategies support multi-GPU training:

```shell
# run DDP on 2 GPUs
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy ddp
# run ColossalAI (ZeRO-2) on 2 GPUs
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai_zero2
```
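For reference, here is a minimal sketch of what the dummy data looks like. It mirrors the relevant snippet in `train_dummy.py` (shown in full below); the vocabulary size of 50257 is GPT-2's and is only illustrative.

```python
# Dummy prompts are just random token ids; preprocess_batch (from train_dummy.py)
# collates them into input_ids plus an all-ones attention mask.
import torch

vocab_size = 50257                                      # GPT-2 tokenizer vocab size (illustrative)
random_prompts = torch.randint(vocab_size, (1000, 64))  # 1000 prompts, 64 tokens each

def preprocess_batch(samples):
    input_ids = torch.stack(samples)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    return {'input_ids': input_ids, 'attention_mask': attention_mask}

batch = preprocess_batch(list(random_prompts[:4]))
print(batch['input_ids'].shape)                         # torch.Size([4, 64])
```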
## Train with real prompt data

We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-chatgpt-prompts) as the example dataset. It is a small dataset with a few hundred prompts.

You should download `prompts.csv` first.

This script supports the same strategies.

```shell
# display cli help
python train_prompts.py -h
# run naive on 1 GPU
python train_prompts.py prompts.csv --strategy naive
# run DDP on 2 GPUs
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy ddp
# run ColossalAI (ZeRO-2) on 2 GPUs
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
```
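The only hard requirement on `prompts.csv` is a `prompt` column; `train_prompts.py` (shown below) reads it with pandas. A minimal sketch:

```python
# Sketch of how train_prompts.py consumes the CSV: the 'prompt' column is the prompt pool.
import pandas as pd

dataset = pd.read_csv('prompts.csv')['prompt']   # pandas Series of prompt strings
print(len(dataset), dataset.iloc[0])
```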
## Train the reward model

We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as the dataset to train our reward model. It contains chosen and rejected responses for the same prompts.

The dataset is downloaded from Hugging Face automatically.

Use the following commands to train your reward model:

```shell
# naive reward model training
python train_reward_model.py --pretrain <your model path>
# train with LoRA
python train_reward_model.py --pretrain <your model path> --lora_rank 16
```
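For intuition, the reward model is trained so that the chosen response scores higher than the rejected one. Below is the standard pairwise ranking objective as an illustration only; the actual loss is implemented inside `chatgpt.trainer.RewardModelTrainer` and may differ in detail.

```python
# Illustrative pairwise ranking loss for a reward model (not the library's exact code).
import torch
import torch.nn.functional as F

def pairwise_loss(chosen_reward: torch.Tensor, rejected_reward: torch.Tensor) -> torch.Tensor:
    # Push the chosen response's reward above the rejected one's.
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()

chosen = torch.tensor([1.2, 0.3])
rejected = torch.tensor([0.1, -0.5])
print(pairwise_loss(chosen, rejected))
```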
## Supported Models

### GPT
- [ ] GPT2-S (s)
- [ ] GPT2-M (m)
- [ ] GPT2-L (l)
- [ ] GPT2-XL (xl)
- [ ] GPT2-4B (4b)
- [ ] GPT2-6B (6b)
- [ ] GPT2-8B (8b)
- [ ] GPT2-10B (10b)
- [ ] GPT2-12B (12b)
- [ ] GPT2-15B (15b)
- [ ] GPT2-18B (18b)
- [ ] GPT2-20B (20b)
- [ ] GPT2-24B (24b)
- [ ] GPT2-28B (28b)
- [ ] GPT2-32B (32b)
- [ ] GPT2-36B (36b)
- [ ] GPT2-40B (40b)
- [ ] GPT3 (175b)

### BLOOM
- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m)
- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1)
- [ ] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
- [ ] [BLOOM-7b](https://huggingface.co/bigscience/bloomz-7b1)
- [ ] BLOOM-175b

### OPT
- [x] [OPT-125M](https://huggingface.co/facebook/opt-125m)
- [x] [OPT-350M](https://huggingface.co/facebook/opt-350m)
- [ ] [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b)
- [ ] [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b)
- [ ] [OPT-6.7B](https://huggingface.co/facebook/opt-6.7b)
- [ ] [OPT-13B](https://huggingface.co/facebook/opt-13b)
- [ ] [OPT-30B](https://huggingface.co/facebook/opt-30b)
applications/ChatGPT/examples/requirements.txt (new file, 1 line)
@@ -0,0 +1 @@
pandas>=1.4.1
applications/ChatGPT/examples/test_ci.sh (new executable file, 27 lines)
@@ -0,0 +1,27 @@
#!/usr/bin/env bash

set -xue

if [ -z "${PROMPT_PATH:-}" ]; then
    echo "Please set \$PROMPT_PATH to the path to prompts csv."
    exit 1
fi

BASE=$(realpath $(dirname $0))

export OMP_NUM_THREADS=8

# install requirements
pip install -r ${BASE}/requirements.txt

# train dummy
python ${BASE}/train_dummy.py --strategy naive --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
for strategy in ddp colossalai_gemini colossalai_zero2; do
    torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py --strategy ${strategy} --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
done

# train prompts
python ${BASE}/train_prompts.py $PROMPT_PATH --strategy naive --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3
for strategy in ddp colossalai_gemini colossalai_zero2; do
    torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH --strategy ${strategy} --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
done
applications/ChatGPT/examples/train_dummy.py (new file, 121 lines)
@@ -0,0 +1,121 @@
import argparse
from copy import deepcopy

import torch
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
from chatgpt.nn.generation_utils import (
    bloom_prepare_inputs_fn,
    gpt_prepare_inputs_fn,
    opt_prepare_inputs_fn,
    update_model_kwargs_fn,
)
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def preprocess_batch(samples):
    input_ids = torch.stack(samples)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    return {'input_ids': input_ids, 'attention_mask': attention_mask}


def main(args):
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    with strategy.model_init_context():
        if args.model == 'gpt2':
            actor = GPTActor().cuda()
            critic = GPTCritic().cuda()
        elif args.model == 'bloom':
            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        elif args.model == 'opt':
            actor = OPTActor().cuda()
            critic = OPTCritic().cuda()
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

        initial_model = deepcopy(actor).cuda()
        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()

    # configure optimizer
    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
        critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
    else:
        actor_optim = Adam(actor.parameters(), lr=5e-6)
        critic_optim = Adam(critic.parameters(), lr=5e-6)

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
        prepare_inputs_fn = gpt_prepare_inputs_fn
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
        tokenizer.pad_token = tokenizer.eos_token
        prepare_inputs_fn = bloom_prepare_inputs_fn
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
        prepare_inputs_fn = opt_prepare_inputs_fn
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    # configure trainer
    trainer = PPOTrainer(strategy,
                         actor,
                         critic,
                         reward_model,
                         initial_model,
                         actor_optim,
                         critic_optim,
                         max_epochs=args.max_epochs,
                         train_batch_size=args.train_batch_size,
                         tokenizer=preprocess_batch,
                         max_length=128,
                         do_sample=True,
                         temperature=1.0,
                         top_k=50,
                         pad_token_id=tokenizer.pad_token_id,
                         eos_token_id=tokenizer.eos_token_id,
                         prepare_inputs_fn=prepare_inputs_fn,
                         update_model_kwargs_fn=update_model_kwargs_fn)

    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
    trainer.fit(random_prompts,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--num_episodes', type=int, default=50)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    args = parser.parse_args()
    main(args)
applications/ChatGPT/examples/train_dummy.sh (new executable file, 18 lines)
@@ -0,0 +1,18 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 1

python train_dummy.py --model bloom --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16
applications/ChatGPT/examples/train_prompts.py (new file, 113 lines)
@@ -0,0 +1,113 @@
import argparse
from copy import deepcopy

import pandas as pd
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def main(args):
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    with strategy.model_init_context():
        if args.model == 'gpt2':
            actor = GPTActor().cuda()
            critic = GPTCritic().cuda()
        elif args.model == 'bloom':
            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        elif args.model == 'opt':
            actor = OPTActor(lora_rank=args.lora_rank).cuda()
            critic = OPTCritic(lora_rank=args.lora_rank).cuda()
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

        initial_model = deepcopy(actor)
        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()

    # configure optimizer
    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
        critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
    else:
        actor_optim = Adam(actor.parameters(), lr=5e-6)
        critic_optim = Adam(critic.parameters(), lr=5e-6)

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    dataset = pd.read_csv(args.prompt_path)['prompt']

    def tokenize_fn(texts):
        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
        return {k: v.cuda() for k, v in batch.items()}

    # configure trainer
    trainer = PPOTrainer(strategy,
                         actor,
                         critic,
                         reward_model,
                         initial_model,
                         actor_optim,
                         critic_optim,
                         max_epochs=args.max_epochs,
                         train_batch_size=args.train_batch_size,
                         tokenizer=tokenize_fn,
                         max_length=128,
                         do_sample=True,
                         temperature=1.0,
                         top_k=50,
                         pad_token_id=tokenizer.pad_token_id,
                         eos_token_id=tokenizer.eos_token_id,
                         prepare_inputs_fn=gpt_prepare_inputs_fn,
                         update_model_kwargs_fn=update_model_kwargs_fn)

    trainer.fit(dataset,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('prompt_path')
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    args = parser.parse_args()
    main(args)
applications/ChatGPT/examples/train_prompts.sh (new executable file, 18 lines)
@@ -0,0 +1,18 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 1

python train_prompts.py prompts.csv --model bloom --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16
applications/ChatGPT/examples/train_reward_model.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import argparse

import loralib as lora
import torch
from chatgpt.dataset import RewardDataset
from chatgpt.nn import BLOOMRM
from chatgpt.trainer import RewardModelTrainer
from datasets import load_dataset
from transformers import BloomTokenizerFast


def train(args):
    tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
    tokenizer.pad_token = tokenizer.eos_token
    model = BLOOMRM(pretrained=args.pretrain)

    model.cuda()

    max_len = 1024

    # prepare for data and dataset
    data = load_dataset(args.dataset)
    train_data = data["train"]
    eval_data = data['test']
    train_dataset = RewardDataset(train_data, tokenizer, max_len)
    eval_dataset = RewardDataset(eval_data, tokenizer, max_len)

    # batch_size here is expected to be C(k,2), where k is the number of responses per prompt;
    # given the format of the 'Dahoas/rm-static' dataset, batch_size is best kept at 1
    trainer = RewardModelTrainer(model=model,
                                 train_dataset=train_dataset,
                                 eval_dataset=eval_dataset,
                                 batch_size=args.batch_size,
                                 num_epochs=args.max_epochs)

    trainer.fit(use_lora=args.lora_rank)

    if args.lora_rank > 0:
        torch.save({'model_state_dict': lora.lora_state_dict(trainer.model)}, args.save_path)
    else:
        torch.save(trainer.model, args.save_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
    parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
    parser.add_argument('--max_epochs', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    args = parser.parse_args()
    train(args)
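A hedged sketch of loading the checkpoint this script saves. With `--lora_rank > 0` only the LoRA matrices are in the file, so the base weights must come from the pretrained model and `strict=False` is needed; the `lora_rank` keyword for `BLOOMRM` is assumed here by analogy with the actor/critic classes above.

```python
# Sketch (assumptions noted above): restore a reward model saved by train_reward_model.py.
import torch
from chatgpt.nn import BLOOMRM

# Case 1: trained with --lora_rank 16 -> only LoRA weights were saved.
model = BLOOMRM(pretrained='bigscience/bloom-560m', lora_rank=16)   # lora_rank kwarg is an assumption
state = torch.load('rm_ckpt.pth', map_location='cpu')
model.load_state_dict(state['model_state_dict'], strict=False)      # base weights come from `pretrained`

# Case 2: trained without LoRA -> the whole model object was pickled.
# model = torch.load('rm_ckpt.pth', map_location='cpu')
```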
applications/ChatGPT/examples/train_rm.sh (new executable file, 18 lines)
@@ -0,0 +1,18 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 1

python train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16