ColossalAI/applications/Chat/examples/inference.py
Wenhao Chen 7b9b86441f
[chat]: update rm, add wandb and fix bugs (#4471)
* feat: modify forward fn of critic and reward model

* feat: modify calc_action_log_probs

* to: add wandb in sft and rm trainer

* feat: update train_sft

* feat: update train_rm

* style: modify type annotation and add warning

* feat: pass tokenizer to ppo trainer

* to: modify trainer base and maker base

* feat: add wandb in ppo trainer

* feat: pass tokenizer to generate

* test: update generate fn tests

* test: update train tests

* fix: remove action_mask

* feat: remove unused code

* fix: fix wrong ignore_index

* fix: fix mock tokenizer

* chore: update requirements

* revert: modify make_experience

* fix: fix inference

* fix: add padding side

* style: modify _on_learn_batch_end

* test: use mock tokenizer

* fix: use bf16 to avoid overflow

* fix: fix workflow

* [chat] fix gemini strategy

* [chat] fix

* sync: update colossalai strategy

* fix: fix args and model dtype

* fix: fix checkpoint test

* fix: fix requirements

* fix: fix missing import and wrong arg

* fix: temporarily skip gemini test in stage 3

* style: apply pre-commit

* fix: temporarily skip gemini test in stage 1&2

---------

Co-authored-by: Mingyan Jiang <1829166702@qq.com>
2023-09-20 15:53:58 +08:00

74 lines
2.7 KiB
Python

import argparse
import torch
from coati.models.bloom import BLOOMActor
from coati.models.generation import generate
from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
def _build_actor(model: str, pretrain):
    """Instantiate the coati actor class registered for ``model``.

    Args:
        model: One of "gpt2", "bloom", "opt", "llama".
        pretrain: Forwarded unchanged as the actor's ``pretrained`` argument.

    Raises:
        ValueError: If ``model`` is not a supported name.
    """
    actor_classes = {
        "gpt2": GPTActor,
        "bloom": BLOOMActor,
        "opt": OPTActor,
        "llama": LlamaActor,
    }
    if model not in actor_classes:
        raise ValueError(f'Unsupported model "{model}"')
    return actor_classes[model](pretrained=pretrain)


def _build_tokenizer(model: str):
    """Load the tokenizer matching ``model`` and make sure it has a pad token.

    Raises:
        ValueError: If ``model`` is not a supported name.
    """
    if model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
    elif model == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
        tokenizer.pad_token = tokenizer.eos_token
    elif model == "opt":
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
        tokenizer.pad_token = tokenizer.eos_token
    elif model == "llama":
        tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
        # LLaMA's end-of-sequence token is "</s>" (forward slash). The former
        # value "<\s>" was a typo: not a vocabulary token, and an invalid
        # escape sequence in the literal.
        tokenizer.eos_token = "</s>"
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{model}"')
    return tokenizer


def eval(args):
    """Load an actor model and print one sampled completion of ``args.input``.

    Note: the name shadows the builtin ``eval``; it is kept so existing
    callers keep working.

    Args:
        args: Parsed command-line namespace. Reads ``model`` ("gpt2", "bloom",
            "opt" or "llama"), ``pretrain`` (passed as ``pretrained`` to the
            actor), ``model_path`` (optional state-dict file loaded into the
            actor), ``input`` (prompt text) and ``max_length`` (forwarded to
            ``generate``).

    Raises:
        ValueError: If ``args.model`` is not a supported model name.
    """
    device = torch.cuda.current_device()

    # configure model
    actor = _build_actor(args.model, args.pretrain)
    actor.to(device)
    if args.model_path is not None:
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        # Loading onto the CPU avoids a second on-GPU copy of the weights (and
        # failures when the checkpoint was saved on another device);
        # load_state_dict copies the tensors into the already-moved actor.
        state_dict = torch.load(args.model_path, map_location="cpu")
        actor.load_state_dict(state_dict)

    # configure tokenizer
    tokenizer = _build_tokenizer(args.model)
    tokenizer.padding_side = "left"

    actor.eval()
    input_ids = tokenizer.encode(args.input, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = generate(
            actor,
            input_ids,
            tokenizer=tokenizer,
            max_length=args.max_length,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            num_return_sequences=1,
        )
    # Decode the whole sequence in one call. Decoding token-by-token and
    # joining the pieces drops the word-boundary spaces that SentencePiece
    # tokenizers (e.g. LLaMA) encode inside the tokens themselves.
    output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"[Output]: {output}")
if __name__ == "__main__":
    # Command-line entry point: parse the options below and run one generation.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
    # Recommended: configure the model from a HuggingFace pretrained model;
    # the value is passed to the actor constructor as ``pretrained``.
    cli.add_argument("--pretrain", type=str, default=None)
    # Optional state-dict file loaded into the actor after it is constructed.
    cli.add_argument("--model_path", type=str, default=None)
    # Prompt text to be completed.
    cli.add_argument("--input", type=str, default="Question: How are you ? Answer:")
    # Forwarded to ``generate`` as ``max_length``.
    cli.add_argument("--max_length", type=int, default=100)
    eval(cli.parse_args())