import os
from typing import Any, Callable, Dict, List, Optional
from collections import OrderedDict

import torch
import torch.distributed as dist
import torch.nn as nn
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer, RobertaTokenizer


def is_rank_0() -> bool:
    return not dist.is_initialized() or dist.get_rank() == 0


def get_rank() -> int:
    return dist.get_rank() if dist.is_initialized() else 0


def get_world_size() -> int:
    return dist.get_world_size() if dist.is_initialized() else 1
def get_actor_from_args(model: str, pretrained: Optional[str] = None, config=None, lora_rank=0):
    if model == 'gpt2':
        actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
    elif model == 'bloom':
        actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
    elif model == 'opt':
        actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
    elif model == 'llama':
        actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
    elif model == 'roberta':
        actor = RoBERTaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
    else:
        raise ValueError(f'Unsupported actor model "{model}"')
    return actor


def get_critic_from_args(model: str, pretrained: Optional[str] = None, config=None, lora_rank=0):
    if model == 'gpt2':
        critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
    elif model == 'bloom':
        critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
    elif model == 'opt':
        critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
    elif model == 'llama':
        critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
    elif model == 'roberta':
        critic = RoBERTaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
    else:
        raise ValueError(f'Unsupported critic model "{model}"')
    return critic


def get_reward_model_from_args(model: str, pretrained: Optional[str] = None, config=None):
    if model == 'gpt2':
        reward_model = GPTRM(pretrained=pretrained, config=config)
    elif model == 'bloom':
        reward_model = BLOOMRM(pretrained=pretrained, config=config)
    elif model == 'opt':
        reward_model = OPTRM(pretrained=pretrained, config=config)
    elif model == 'llama':
        reward_model = LlamaRM(pretrained=pretrained, config=config)
    elif model == 'roberta':
        reward_model = RoBERTaRM(pretrained=pretrained, config=config)
    else:
        raise ValueError(f'Unsupported reward model "{model}"')
    return reward_model
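# Illustrative usage (not part of the original module): the three factory helpers above
# are typically called with matching `model` strings so that actor, critic and reward
# model share an architecture family, e.g.
#
#   actor = get_actor_from_args('gpt2', lora_rank=4)
#   critic = get_critic_from_args('gpt2', lora_rank=4)
#   reward_model = get_reward_model_from_args('gpt2')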
def get_strategy_from_args(strategy: str):
    if strategy == 'naive':
        strategy_ = NaiveStrategy()
    elif strategy == 'ddp':
        strategy_ = DDPStrategy()
    elif strategy == 'colossalai_gemini':
        strategy_ = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
    elif strategy == 'colossalai_zero2':
        strategy_ = ColossalAIStrategy(stage=2, placement_policy='cuda')
    elif strategy == 'colossalai_gemini_cpu':
        strategy_ = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
    elif strategy == 'colossalai_zero2_cpu':
        strategy_ = ColossalAIStrategy(stage=2, placement_policy='cpu')
    else:
        raise ValueError(f'Unsupported strategy "{strategy}"')
    return strategy_
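# Illustrative usage (not part of the original module): the `*_cpu` variants place
# parameters in host memory (placement_policy='cpu'), while the plain variants keep
# them on GPU, e.g.
#
#   strategy = get_strategy_from_args('colossalai_zero2')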
def get_tokenizer_from_args(model: str, **kwargs):
    if model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    elif model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
    elif model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    elif model == 'llama':
        pretrain_path = kwargs["pretrain"]
        tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
    elif model == 'roberta':
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    else:
        raise ValueError(f'Unsupported model "{model}"')

    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
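# Illustrative usage (not part of the original module): llama tokenizers are loaded
# from a local checkpoint passed via the `pretrain` keyword (the path below is a
# placeholder):
#
#   tokenizer = get_tokenizer_from_args('llama', pretrain='/path/to/llama-checkpoint')
#
# Note that the pad token is always aliased to the EOS token before the tokenizer is
# returned.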
def set_dist_env(env_info: Dict[str, str]):
    # All values are exported as environment variables and therefore must be strings.
    os.environ["RANK"] = env_info['rank']
    os.environ["LOCAL_RANK"] = env_info['local_rank']
    os.environ["WORLD_SIZE"] = env_info['world_size']
    os.environ['MASTER_PORT'] = env_info['master_port']
    os.environ['MASTER_ADDR'] = env_info['master_addr']
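# Illustrative usage (not part of the original module; port and address are
# placeholders):
#
#   set_dist_env({'rank': '0', 'local_rank': '0', 'world_size': '2',
#                 'master_port': '29500', 'master_addr': '127.0.0.1'})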
def get_model_numel(model: nn.Module) -> int:
    numel = sum(p.numel() for p in model.parameters())
    return numel


def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
    target_receivers = []
    if num_senders <= num_receivers or allow_idle_sender:
        # Each sender sends data to one or more receivers;
        # each receiver is assigned exactly one sender.
        for i in range(num_receivers):
            if i % num_senders == sender_idx:
                target_receivers.append(i)
    else:
        # Each sender sends data to exactly one receiver;
        # a receiver may be shared by multiple senders.
        target_receivers.append(sender_idx % num_receivers)
    return target_receivers
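# Worked example (illustrative, not part of the original module): with 2 senders and
# 4 receivers, receivers are assigned round-robin by sender index:
#
#   get_receivers_per_sender(0, 2, 4, allow_idle_sender=False)  # -> [0, 2]
#   get_receivers_per_sender(1, 2, 4, allow_idle_sender=False)  # -> [1, 3]
#
# With 4 senders, 2 receivers and allow_idle_sender=False, each sender instead maps
# to a single (possibly shared) receiver:
#
#   get_receivers_per_sender(3, 4, 2, allow_idle_sender=False)  # -> [1]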
def state_dict_to(state_dict: Dict[str, Any],
                  dtype: torch.dtype = torch.float16,
                  device: torch.device = torch.device('cpu')):
    '''
    Return a new ``OrderedDict`` with every tensor cast to ``dtype`` and moved to
    ``device``; the input ``state_dict`` is left unmodified.
    '''
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict[k] = v.to(dtype=dtype, device=device)
    return new_state_dict
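# Illustrative usage (not part of the original module): cast a model's weights to fp16
# on CPU before shipping them to another Ray actor, without touching the live model:
#
#   cpu_fp16_state = state_dict_to(model.state_dict(), dtype=torch.float16,
#                                  device=torch.device('cpu'))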