ColossalAI/applications/Chat/benchmarks/ray/mmmt_dummy.py
Hongxin Liu b5f0566363
[chat] add distributed PPO trainer (#3740)
Co-authored-by: csric <59389055+CsRic@users.noreply.github.com>
Co-authored-by: csric <richcsr256@gmail.com>
2023-06-07 10:41:16 +08:00

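"""Dummy multi-maker, multi-trainer (mmmt) benchmark for the detached Ray PPO trainer.

The script spawns ``num_makers`` ExperienceMakerHolder actors and ``num_trainers``
DetachedPPOTrainer actors on a Ray cluster, feeds the makers randomly generated token
batches, and lets the trainers consume the produced experience for PPO training.
"""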

import argparse
import os
import socket
from functools import partial

import ray
import torch
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
from coati.ray.utils import (
    get_actor_from_args,
    get_critic_from_args,
    get_receivers_per_sender,
    get_reward_model_from_args,
    get_strategy_from_args,
)
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer
from transformers.modeling_utils import no_init_weights


def get_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


def get_local_ip():
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(('8.8.8.8', 80))
        return s.getsockname()[0]


def main(args):
    master_addr = str(get_local_ip())
    # trainer_env_info: torch.distributed-style env (rank, world size, master addr/port) for each trainer actor
    trainer_port = str(get_free_port())
    env_info_trainers = [{
        'local_rank': '0',
        'rank': str(rank),
        'world_size': str(args.num_trainers),
        'master_port': trainer_port,
        'master_addr': master_addr
    } for rank in range(args.num_trainers)]
    # maker_env_info: same layout for each experience maker actor
    maker_port = str(get_free_port())
    env_info_makers = [{
        'local_rank': '0',
        'rank': str(rank),
        'world_size': str(args.num_makers),
        'master_port': maker_port,
        'master_addr': master_addr
    } for rank in range(args.num_makers)]

    # configure tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
    tokenizer.pad_token = tokenizer.eos_token

    def model_fn():
        # models held by each experience maker: frozen fp16 actor, critic, reward model
        # and the initial (reference) policy
        actor_cfg = AutoConfig.from_pretrained(args.pretrain)
        critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
        actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
        critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
        reward_model = get_reward_model_from_args(args.critic_model,
                                                  config=critic_cfg).requires_grad_(False).half().cuda()
        if args.initial_model_quant_ckpt is not None and args.model == 'llama':
            # quantize initial model
            with low_resource_init(), no_init_weights():
                initial_model = get_actor_from_args(args.model, config=actor_cfg)
            initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
                                                   args.quant_group_size).cuda().requires_grad_(False)
        else:
            initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
        return actor, critic, reward_model, initial_model

    # configure Experience Maker
    experience_holder_refs = [
        ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
            detached_trainer_name_list=[
                f'trainer{x}'
                for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
            ],
            strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
            model_fn=model_fn,
            env_info=env_info_maker,
            kl_coef=0.1,
            debug=args.debug,
            # sync_models_from_trainers=True,
            # generation kwargs:
            max_length=512,
            do_sample=True,
            temperature=1.0,
            top_k=50,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            eval_performance=True,
            use_cache=True,
        )
        for i, env_info_maker in enumerate(env_info_makers)
    ]

    def trainer_model_fn():
        actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
        critic = get_critic_from_args(args.critic_model,
                                      config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
        return actor, critic

    # configure Trainer
    trainer_refs = [
        DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
            experience_maker_holder_name_list=[
                f"maker{x}"
                for x in get_receivers_per_sender(i, args.num_trainers, args.num_makers, allow_idle_sender=True)
            ],
            strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
            model_fn=trainer_model_fn,
            env_info=env_info_trainer,
            train_batch_size=args.train_batch_size,
            buffer_limit=16,
            eval_performance=True,
            debug=args.debug,
        )
        for i, env_info_trainer in enumerate(env_info_trainers)
    ]

    dataset_size = args.experience_batch_size * 4

    def data_gen_fn():
        # random token ids serve as dummy prompts for the benchmark
        input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
        attn_mask = torch.ones_like(input_ids)
        return {'input_ids': input_ids, 'attention_mask': attn_mask}

    def build_dataloader(size):
        dataset = [data_gen_fn() for _ in range(size)]
        dataloader = DataLoader(dataset, batch_size=args.experience_batch_size)
        return dataloader

    # uncomment these lines if sync_models_from_trainers is True
    # ray.get([
    #     trainer_ref.sync_models_to_remote_makers.remote()
    #     for trainer_ref in trainer_refs
    # ])

    wait_tasks = []

    for experience_holder_ref in experience_holder_refs:
        wait_tasks.append(
            experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
                                                     num_steps=args.experience_steps))
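    # The step budget below assumes each maker contributes experience_batch_size * experience_steps
    # samples, which the trainers split into batches of train_batch_size; with the defaults
    # (8 * 4 samples from 1 maker, 1 trainer, batch size 8) that works out to 4 PPO steps per trainer.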
    total_steps = args.experience_batch_size * args.experience_steps * \
        args.num_makers // (args.num_trainers * args.train_batch_size)
    for trainer_ref in trainer_refs:
        wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))

    ray.get(wait_tasks)
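

# Example invocation (a sketch: the namespace value and model names below are placeholders,
# and the Ray cluster must provide at least num_makers + num_trainers GPUs, since every
# actor requests num_gpus=1; RAY_NAMESPACE must be set because ray.init reads it below):
#   RAY_NAMESPACE=coati python mmmt_dummy.py --num_makers 2 --num_trainers 2 \
#       --pretrain gpt2 --critic_pretrain gpt2 --trainer_strategy colossalai_gemini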

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_makers', type=int, default=1)
    parser.add_argument('--num_trainers', type=int, default=1)
    parser.add_argument('--trainer_strategy',
                        choices=[
                            'naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
                            'colossalai_zero2_cpu'
                        ],
                        default='naive')
    parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
    parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--critic_pretrain', type=str, default=None)
    parser.add_argument('--experience_steps', type=int, default=4)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--train_epochs', type=int, default=1)
    parser.add_argument('--update_steps', type=int, default=2)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
    parser.add_argument('--quant_bits', type=int, default=4)
    parser.add_argument('--quant_group_size', type=int, default=128)
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()

    ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
    main(args)