[gemini] accelerate inference (#3641)
* [gemini] support don't scatter after inference
* [chat] update colossalai strategy
* [chat] fix opt benchmark
* [chat] update opt benchmark
* [gemini] optimize inference
* [test] add gemini inference test
* [chat] fix unit test ci
* [chat] fix ci
* [chat] fix ci
* [chat] skip checkpoint test
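The headline change lets Gemini keep ZeRO-sharded parameters gathered after an inference pass instead of re-scattering them immediately. The sketch below is a conceptual, self-contained illustration of why that helps during generation; ShardedParam and generate are hypothetical stand-ins, not the ColossalAI API.

from typing import Optional

import torch

class ShardedParam:
    """Toy stand-in for a ZeRO-sharded parameter."""

    def __init__(self, full: torch.Tensor, world_size: int = 4):
        self.shards = list(full.chunk(world_size))  # each rank keeps one chunk
        self.gathered: Optional[torch.Tensor] = None

    def gather(self) -> torch.Tensor:
        # In a real system this is an all-gather collective across ranks.
        if self.gathered is None:
            self.gathered = torch.cat(self.shards)
        return self.gathered

    def scatter(self) -> None:
        # Drop the full copy, keeping only this rank's shard.
        self.gathered = None

def generate(param: ShardedParam, steps: int, scatter_after_inference: bool) -> None:
    for _ in range(steps):
        full = param.gather()  # only the first step pays the gather ...
        _ = full.sum()         # ... stand-in for one decoding forward pass
    if scatter_after_inference:
        param.scatter()        # skipping this trades memory for fewer collectives

Autoregressive decoding runs one forward per generated token, so amortizing the gather across all decoding steps is where the inference speedup comes from.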
@@ -10,6 +10,7 @@ from coati.trainer import PPOTrainer
 from coati.trainer.callbacks import PerformanceEvaluator
 from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
 from torch.optim import Adam
+from torch.utils.data import DataLoader
 from transformers import AutoTokenizer
 from transformers.models.opt.configuration_opt import OPTConfig
@@ -92,13 +93,13 @@ def main(args):
     torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac)

     model_config = get_gpt_config(args.model)
+    critic_config = get_gpt_config(args.critic_model)
     with strategy.model_init_context():
         actor = OPTActor(config=model_config, lora_rank=args.lora_rank).cuda()
-        critic = OPTCritic(config=model_config, lora_rank=args.lora_rank).cuda()
+        critic = OPTCritic(config=critic_config, lora_rank=args.lora_rank).cuda()

-        initial_model = deepcopy(actor).cuda()
-        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
+        initial_model = deepcopy(actor).cuda().half()
+        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda().half()

     actor_numel = get_model_numel(actor, strategy)
     critic_numel = get_model_numel(critic, strategy)
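The new .half() casts put the frozen reference policy and reward model in fp16: they only ever run forward passes, so halving their memory footprint costs no training precision. A minimal sketch of the pattern on a generic module (the helper name is illustrative):

import torch.nn as nn

def freeze_for_inference(model: nn.Module) -> nn.Module:
    model = model.half().eval()      # fp16 weights, inference mode
    for p in model.parameters():
        p.requires_grad_(False)      # gradients are never needed
    return model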
@@ -127,8 +128,7 @@ def main(args):
     tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
     tokenizer.pad_token = tokenizer.eos_token

-    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
-        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+    (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))

     trainer = PPOTrainer(strategy,
                          actor,
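Since the reward and initial models are now plain fp16 modules, only the trainable (model, optimizer) pairs go through strategy.prepare. The toy prepare below illustrates that narrowed contract; it is a hypothetical stand-in, not the coati Strategy API.

import torch
import torch.nn as nn

def prepare(*pairs):
    out = []
    for model, optim in pairs:
        # A real strategy would wrap with DDP/ZeRO here; we only tag it.
        model.wrapped_by_strategy = True
        out.append((model, optim))
    return out

actor, critic = nn.Linear(4, 4), nn.Linear(4, 4)
actor_optim = torch.optim.Adam(actor.parameters())
critic_optim = torch.optim.Adam(critic.parameters())
(actor, actor_optim), (critic, critic_optim) = prepare((actor, actor_optim),
                                                       (critic, critic_optim))
# Anything not passed in, e.g. a frozen reward model, is left untouched.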
@@ -137,6 +137,7 @@ def main(args):
                          initial_model,
                          actor_optim,
                          critic_optim,
+                         ptx_coef=0,
                          max_epochs=args.max_epochs,
                          train_batch_size=args.train_batch_size,
                          experience_batch_size=args.experience_batch_size,
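ptx_coef=0 switches off the auxiliary pretraining (ptx) loss term, which is what allows the pretrain dataloader passed to trainer.fit in the next hunk to be None. One common shape of such a combined actor loss, stated as an assumption rather than coati's exact formula:

import torch

def actor_loss(policy_loss: torch.Tensor,
               ptx_loss: torch.Tensor,
               ptx_coef: float = 0.0) -> torch.Tensor:
    # With ptx_coef == 0 the pretraining term contributes nothing.
    return policy_loss + ptx_coef * ptx_loss

loss = actor_loss(torch.tensor(1.5), torch.tensor(2.0), ptx_coef=0.0)
assert loss.item() == 1.5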
@@ -145,14 +146,19 @@ def main(args):
                          do_sample=True,
                          temperature=1.0,
                          top_k=50,
+                         use_cache=True,
                          pad_token_id=tokenizer.pad_token_id,
                          eos_token_id=tokenizer.eos_token_id,
                          callbacks=[performance_evaluator])

-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 400), device=torch.cuda.current_device())
-    random_attention_mask = torch.randint(1, (1000, 1, 400), device=torch.cuda.current_device()).to(torch.bool)
-    random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
-    trainer.fit(random_prompts, random_pretrain,
+    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
+    dataloader = DataLoader(random_prompts,
+                            batch_size=args.experience_batch_size,
+                            shuffle=True,
+                            collate_fn=preprocess_batch)
+
+    trainer.fit(dataloader,
+                None,
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
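The benchmark now feeds flat (1000, 256) random token ids through a standard DataLoader instead of hand-built tensor lists. preprocess_batch is defined elsewhere in this script; a plausible sketch of such a collate function for these dummy prompts (an assumption, not the actual definition):

import torch

def preprocess_batch(samples):
    # samples: list of 1-D LongTensors of token ids, one per prompt
    input_ids = torch.stack(samples)
    attention_mask = torch.ones_like(input_ids)  # dummy prompts attend everywhere
    return {'input_ids': input_ids, 'attention_mask': attention_mask}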
@@ -163,6 +169,7 @@ def main(args):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--model', default='125m')
+    parser.add_argument('--critic_model', default='125m')
     parser.add_argument('--strategy',
                         choices=[
                             'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
@@ -175,7 +182,7 @@ if __name__ == '__main__':
     parser.add_argument('--max_epochs', type=int, default=3)
     parser.add_argument('--train_batch_size', type=int, default=8)
     parser.add_argument('--experience_batch_size', type=int, default=8)
-    parser.add_argument('--lora_rank', type=int, default=4)
+    parser.add_argument('--lora_rank', type=int, default=0)
     parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
     args = parser.parse_args()
     main(args)