Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 03:52:01 +00:00)
[ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all
* fix and tested ppo
* 2 nd round refactor
* add ci tests
* fix ci
* fix ci
* fix readme, style
* fix readme style
* fix style, fix benchmark
* reproduce benchmark result, remove useless files
* rename to ColossalChat
* use new image
* fix ci workflow
* fix ci
* use local model/tokenizer for ci tests
* fix ci
* fix ci
* fix ci
* fix ci timeout
* fix rm progress bar. fix ci timeout
* fix ci
* fix ci typo
* remove 3d plugin from ci temporary
* test environment
* cannot save optimizer
* support chat template
* fix readme
* fix path
* test ci locally
* restore build_or_pr
* fix ci data path
* fix benchmark
* fix ci, move ci tests to 3080, disable fast tokenizer
* move ci to 85
* support flash attention 2
* add all-in-one data preparation script. Fix colossal-llama2-chat chat template
* add hardware requirements
* move ci test data
* fix save_model, add unwrap
* fix missing bos
* fix missing bos; support grad accumulation with gemini
* fix ci
* fix ci
* fix ci
* fix llama2 chat template config
* debug sft
* debug sft
* fix colossalai version requirement
* fix ci
* add sanity check to prevent NaN loss
* fix requirements
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* update readme
* update readme
* update readme and ignore
* fix logger bug
* support parallel_output
* modify data preparation logic
* fix tokenization
* update lr
* fix inference
* run pre-commit

---------

Co-authored-by: Tong Li <tong.li352711588@gmail.com>
168
applications/ColossalChat/examples/inference/chatio.py
Executable file
@@ -0,0 +1,168 @@
"""
Command-line IO utilities for the chatbot.
"""

import abc
import re

from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import InMemoryHistory
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown


class ChatIO(abc.ABC):
    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Prompt for input from a role."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str):
        """Prompt for output from a role."""

    @abc.abstractmethod
    def stream_output(self, output_stream):
        """Stream output."""


class SimpleChatIO(ChatIO):
    def prompt_for_input(self, role) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream):
        pre = 0
        for outputs in output_stream:
            outputs = outputs.strip()
            outputs = outputs.split(" ")
            now = len(outputs) - 1
            if now > pre:
                print(" ".join(outputs[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(outputs[pre:]), flush=True)
        return " ".join(outputs)


class RichChatIO(ChatIO):
    def __init__(self):
        self._prompt_session = PromptSession(history=InMemoryHistory())
        self._completer = WordCompleter(words=["!exit", "!reset"], pattern=re.compile("$"))
        self._console = Console()

    def prompt_for_input(self, role) -> str:
        self._console.print(f"[bold]{role}:")
        prompt_input = self._prompt_session.prompt(
            completer=self._completer,
            multiline=False,
            auto_suggest=AutoSuggestFromHistory(),
            key_bindings=None,
        )
        self._console.print()
        return prompt_input

    def prompt_for_output(self, role: str) -> str:
        self._console.print(f"[bold]{role}:")

    def stream_output(self, output_stream):
        """Stream output from a role."""
        # Create a Live context for updating the console output
        with Live(console=self._console, refresh_per_second=60) as live:
            # Read lines from the stream
            for outputs in output_stream:
                accumulated_text = outputs
                if not accumulated_text:
                    continue
                # Render the accumulated text as Markdown.
                # NOTE: this is a workaround for rendering "non-standard markdown"
                # in rich. The chatbot's output treats "\n" as a new line for
                # better compatibility with real-world text, but standard markdown
                # treats a single "\n" in normal text as a space, so rendering it
                # directly would break the format. The workaround is to add two
                # spaces at the end of each line. This is not a perfect solution,
                # as it introduces trailing spaces (only) in code blocks, but it
                # works well for console output, since the console generally does
                # not care about trailing spaces.
                lines = []
                for line in accumulated_text.splitlines():
                    lines.append(line)
                    if line.startswith("```"):
                        # Code block marker - do not add trailing spaces, as it would
                        # break the syntax highlighting
                        lines.append("\n")
                    else:
                        lines.append("  \n")
                markdown = Markdown("".join(lines))
                # Update the Live console output
                live.update(markdown)
        self._console.print()
        return outputs


class DummyChatIO(ChatIO):
    """
    Dummy ChatIO class for testing
    """

    def __init__(self):
        self.roles = []
        self._console = Console()

    def prompt_for_input(self, role) -> str:
        self.roles.append(role)
        if len(self.roles) == 1:
            ret = "Hello"
        elif len(self.roles) == 2:
            ret = "What's the value of 1+1?"
        else:
            ret = "exit"
        self._console.print(f"[bold]{role}:{ret}")
        return ret

    def prompt_for_output(self, role: str) -> str:
        self._console.print(f"[bold]{role}:")

    def stream_output(self, output_stream):
        """Stream output from a role."""
        # Create a Live context for updating the console output
        with Live(console=self._console, refresh_per_second=60) as live:
            # Read lines from the stream
            for outputs in output_stream:
                accumulated_text = outputs
                if not accumulated_text:
                    continue
                # Render the accumulated text as Markdown.
                # NOTE: this is a workaround for rendering "non-standard markdown"
                # in rich. The chatbot's output treats "\n" as a new line for
                # better compatibility with real-world text, but standard markdown
                # treats a single "\n" in normal text as a space, so rendering it
                # directly would break the format. The workaround is to add two
                # spaces at the end of each line. This is not a perfect solution,
                # as it introduces trailing spaces (only) in code blocks, but it
                # works well for console output, since the console generally does
                # not care about trailing spaces.
                lines = []
                for line in accumulated_text.splitlines():
                    lines.append(line)
                    if line.startswith("```"):
                        # Code block marker - do not add trailing spaces, as it would
                        # break the syntax highlighting
                        lines.append("\n")
                    else:
                        lines.append("  \n")
                markdown = Markdown("".join(lines))
                # Update the Live console output
                live.update(markdown)
        self._console.print()
        return outputs


simple_io = SimpleChatIO()
rich_io = RichChatIO()
dummy_io = DummyChatIO()
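As an illustration of the `ChatIO` contract defined above (a hypothetical sketch, not part of this commit): `stream_output` consumes a generator that yields the full decoded text so far, and `SimpleChatIO` prints only the newly completed words at each step.

```python
# Hypothetical usage of the ChatIO helpers above; run from the same directory as chatio.py.
from chatio import simple_io


def toy_stream():
    # Mimic the streaming generator used by inference.py: each item is the
    # full decoded text so far, growing by one word per step.
    words = "Hello there, how can I help you today ?".split(" ")
    for i in range(1, len(words) + 1):
        yield " ".join(words[:i])


simple_io.prompt_for_output("assistant")             # prints "assistant: "
final_text = simple_io.stream_output(toy_stream())   # prints words as they arrive
print(repr(final_text))
```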
195
applications/ColossalChat/examples/inference/inference.py
Executable file
@@ -0,0 +1,195 @@
import argparse
import json
import os
from typing import Dict

import torch
from chatio import dummy_io, rich_io, simple_io
from coati.dataset.conversation import setup_conversation_template
from coati.models import generate_streaming
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

from colossalai.logging import get_dist_logger

logger = get_dist_logger()


def get_gpu_memory(max_gpus=None):
    """
    Get the available memory for each GPU.

    Args:
        max_gpus (int, optional): The maximum number of GPUs to consider. Defaults to None.

    Returns:
        list: A list of available memory for each GPU.
    """
    gpu_memory = []
    num_gpus = torch.cuda.device_count() if max_gpus is None else min(max_gpus, torch.cuda.device_count())

    for gpu_id in range(num_gpus):
        # Query total and allocated memory of each visible GPU
        with torch.cuda.device(gpu_id):
            device = torch.cuda.current_device()
            gpu_properties = torch.cuda.get_device_properties(device)
            total_memory = gpu_properties.total_memory / (1024**3)
            allocated_memory = torch.cuda.memory_allocated() / (1024**3)
            available_memory = total_memory - allocated_memory
            gpu_memory.append(available_memory)
    return gpu_memory


def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs):
    """
    Load the model and tokenizer from the specified paths and move the model to the specified device.

    Args:
        model_path (str): The path to the pre-trained model.
        tokenizer_path (str): The path to the pre-trained tokenizer.
        device (str, optional): The device to move the model to. Defaults to "cuda".
        **kwargs: Additional keyword arguments to be passed to the `AutoModelForCausalLM.from_pretrained` function.

    Returns:
        tuple: A tuple containing the loaded model and tokenizer.
    """
    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token
    model.to(device)

    return model, tokenizer


def _set_default_generate_kwargs(model: PreTrainedModel) -> Dict:
    """
    Set default keyword arguments for generation based on the given model.

    Args:
        model (PreTrainedModel): The model used for generation.

    Returns:
        Dict: A dictionary containing the default keyword arguments for generation.
    """
    unwrapped_model = model
    new_kwargs = {}
    # Use the Hugging Face model's methods directly
    if hasattr(unwrapped_model, "prepare_inputs_for_generation"):
        new_kwargs["prepare_inputs_fn"] = unwrapped_model.prepare_inputs_for_generation

    if hasattr(unwrapped_model, "_update_model_kwargs_for_generation"):
        new_kwargs["update_model_kwargs_fn"] = unwrapped_model._update_model_kwargs_for_generation
    return new_kwargs


def generation_wrapper(*args, **kwargs):
    input_ids = args[1]
    tokenizer = args[2]
    for output in generate_streaming(*args, **kwargs):
        yield tokenizer.batch_decode(output[:, input_ids.size(1) :], skip_special_tokens=True)[0]


def main(args):
    conversation_template_config = json.load(open(args.conversation_template_config, "r", encoding="utf8"))

    max_new_tokens = args.max_new_tokens
    model_max_length = args.model_max_length
    model, tokenizer = load_model_and_tokenizer(
        args.model_path, args.tokenizer_path or args.model_path, local_files_only=True
    )

    assert max_new_tokens <= model_max_length
    if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
        try:
            # Some tokenizers (e.g., Qwen) do not allow setting pad_token manually
            tokenizer.pad_token = tokenizer.eos_token
        except AttributeError as e:
            logger.warning(f"Unable to set pad token to eos token, {str(e)}")
    tokenizer.padding_side = "left"

    model_kwargs = {
        "max_new_tokens": max_new_tokens,
        # 'early_stopping': True,
        # 'top_k': -1,
        # 'top_p': 1.0,
        # 'temperature': 1.0,
        # 'temperature':0.1,
    }
    round = 1

    conv = setup_conversation_template(tokenizer, conversation_template_config, args.conversation_template_config)

    while True:
        if args.io == "simple":
            chat_io = simple_io
        elif args.io == "rich":
            chat_io = rich_io
        elif args.io == "dummy":
            chat_io = dummy_io
        else:
            raise ValueError(f"Unknown io type: {args.io}")

        # raw_text = print(">>> Human:", end=" ")
        inp = chat_io.prompt_for_input("user")

        if not inp:
            print("prompt should not be empty!")
            continue

        if inp.strip() == "clear":
            conv.clear()
            os.system("clear")
            continue

        if inp.strip() == "exit":
            print("End of chat.")
            break

        query_text = inp.strip()

        conv.append_message("user", query_text)

        chat_io.prompt_for_output("assistant")

        prompt = conv.get_prompt(add_generation_prompt=True)
        print(prompt + "<end_of_prompt>")
        input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(
            torch.cuda.current_device()
        )
        default_generate_kwargs = _set_default_generate_kwargs(model)
        model_kwargs.update(default_generate_kwargs)
        output_stream = generation_wrapper(
            model,
            input_ids,
            tokenizer,
            max_length=model_max_length,
            temperature=0.7,
            early_stopping=True,
            stop_token_ids=conversation_template_config["stop_ids"],
            **model_kwargs,
        )

        # print(f">>> Assistant:", end=" ")
        outputs = chat_io.stream_output(output_stream)

        conv.append_message("assistant", outputs.strip())

        with open("round.txt", mode="a", encoding="utf-8") as f:
            f.write("\n\n" + "=" * 10 + "\n")
            f.write(f"round {round}:\n{conv.save_prompt()}\n\n")
            f.write("=" * 10 + "\n")

        # print(f">>> Assistant:", end=" ")

        round += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default=None)
    parser.add_argument("--tokenizer_path", type=str, default=None)
    parser.add_argument("--conversation_template_config", type=str, default=None)
    parser.add_argument("--model_max_length", type=int, default=2048)
    parser.add_argument("--max_new_tokens", type=int, default=512)
    parser.add_argument("--io", type=str, default="rich", choices=["simple", "rich", "dummy"])
    args = parser.parse_args()
    main(args)
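A typical invocation of `inference.py` might look like the following; the model, tokenizer, and conversation-template paths are placeholders for your local files.

```shell
# Hypothetical example; point the paths at your own checkpoint and template JSON.
python inference.py \
    --model_path /path/to/model \
    --tokenizer_path /path/to/tokenizer \
    --conversation_template_config /path/to/conversation_template.json \
    --max_new_tokens 512 \
    --io simple
```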
118
applications/ColossalChat/examples/inference/web_chatbot/README.md
Executable file
@@ -0,0 +1,118 @@
# Inference

We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.

We support 8-bit quantization (RTN), powered by [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [transformers](https://github.com/huggingface/transformers), and 4-bit quantization (GPTQ), powered by [gptq](https://github.com/IST-DASLab/gptq) and [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). We also support FP16 inference.

We only support LLaMA-family models for now.

## Choosing precision (quantization)

**FP16**: Fastest, best output quality, highest memory usage

**8-bit**: Slow, easier setup (supported natively by transformers), lower output quality (due to RTN), **recommended for first-timers**

**4-bit**: Faster, lowest memory usage, higher output quality (due to GPTQ), but more difficult setup

## Hardware requirements for LLaMA

The data is from [LLaMA Int8 4bit ChatBot Guide v2](https://rentry.org/llama-tard-v2).

### 8-bit

| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples |
| :-------: | :---------: | :-----------------: | :----------: | :--------------------------------: |
| LLaMA-7B | 9.2GB | 10GB | 24GB | 3060 12GB, RTX 3080 10GB, RTX 3090 |
| LLaMA-13B | 16.3GB | 20GB | 32GB | RTX 3090 Ti, RTX 4090 |
| LLaMA-30B | 36GB | 40GB | 64GB | A6000 48GB, A100 40GB |
| LLaMA-65B | 74GB | 80GB | 128GB | A100 80GB |

### 4-bit

| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples |
| :-------: | :---------: | :-----------------: | :----------: | :--------------------------------------------------------: |
| LLaMA-7B | 3.5GB | 6GB | 16GB | RTX 1660, 2060, AMD 5700xt, RTX 3050, 3060 |
| LLaMA-13B | 6.5GB | 10GB | 32GB | AMD 6900xt, RTX 2060 12GB, 3060 12GB, 3080, A2000 |
| LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 |
| LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada |

## General setup

```shell
pip install -r requirements.txt
```

## 8-bit setup

8-bit quantization is supported natively by the latest [transformers](https://github.com/huggingface/transformers). Please install it from source.

Please ensure you have downloaded the HF-format model weights of the LLaMA models.

Usage:

```python
import torch
from transformers import LlamaForCausalLM

USE_8BIT = True  # use 8-bit quantization; otherwise, use fp16

model = LlamaForCausalLM.from_pretrained(
    "pretrained/path",
    load_in_8bit=USE_8BIT,
    torch_dtype=torch.float16,
    device_map="auto",
)
if not USE_8BIT:
    model.half()  # use fp16
model.eval()
```

**Troubleshooting**: if you get an error indicating that CUDA-related libraries cannot be found when loading the 8-bit model, check whether your `LD_LIBRARY_PATH` is correct.

E.g., you can set `export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH`.

## 4-bit setup

Please ensure you have downloaded the HF-format model weights of the LLaMA models first.

Then you can follow [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). This library provides efficient CUDA kernels and a weight conversion script.

After installing this library, you can convert the original HF-format LLaMA model weights to a 4-bit version.

```shell
CUDA_VISIBLE_DEVICES=0 python llama.py /path/to/pretrained/llama-7b c4 --wbits 4 --groupsize 128 --save llama7b-4bit-128g.pt
```

Run this command in your cloned `GPTQ-for-LLaMa` directory, and you will get a 4-bit weight file, `llama7b-4bit-128g.pt`.

**Troubleshooting**: if you get an error about `position_ids`, you can check out commit `50287c3b9ae4a3b66f6b5127c643ec39b769b155` of the `GPTQ-for-LLaMa` repo.

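If you want to load the converted checkpoint in your own script rather than through `server.py`, the sketch below mirrors what `server.py` does. It assumes the `coati.quant` helpers shipped with ColossalChat and a checkpoint converted with group size 128 as above; adjust the paths to your setup.

```python
from coati.quant import llama_load_quant, low_resource_init
from transformers import AutoConfig, AutoModelForCausalLM

# Build an empty model skeleton without materializing fp16 weights,
# then load the 4-bit GPTQ checkpoint into it (same flow as server.py).
with low_resource_init():
    config = AutoConfig.from_pretrained("/path/to/pretrained/llama-7b")
    model = AutoModelForCausalLM.from_config(config)
model = llama_load_quant(model, "llama7b-4bit-128g.pt", 4, 128)
model.cuda()
model.eval()
```
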
## Online inference server

In this directory:

```shell
export CUDA_VISIBLE_DEVICES=0
# fp16, will listen on 0.0.0.0:7070 by default
python server.py /path/to/pretrained
# 8-bit, will listen on localhost:8080
python server.py /path/to/pretrained --quant 8bit --http_host localhost --http_port 8080
# 4-bit
python server.py /path/to/pretrained --quant 4bit --gptq_checkpoint /path/to/llama7b-4bit-128g.pt --gptq_group_size 128
```
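
Once the server is running, you can also query the non-streaming `/generate` endpoint directly. Below is a minimal client sketch (not part of this commit) using the `requests` package; it assumes the default host/port and follows the request schema of `GenerationTaskReq` in `server.py` (note the endpoint is rate-limited to one request per second):

```python
import requests

history = [
    {"instruction": "Who is the best player in the history of NBA?", "response": ""},
]
payload = {"max_new_tokens": 64, "history": history, "temperature": 0.7}
resp = requests.post("http://localhost:7070/generate", json=payload, timeout=120)
print(resp.json())  # the generated reply (or the safety fallback message)
```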

## Benchmark

In this directory:

```shell
export CUDA_VISIBLE_DEVICES=0
# fp16
python benchmark.py /path/to/pretrained
# 8-bit
python benchmark.py /path/to/pretrained --quant 8bit
# 4-bit
python benchmark.py /path/to/pretrained --quant 4bit --gptq_checkpoint /path/to/llama7b-4bit-128g.pt --gptq_group_size 128
```

This benchmark will record throughput and peak CUDA memory usage.
26
applications/ColossalChat/examples/inference/web_chatbot/locustfile.py
Executable file
@@ -0,0 +1,26 @@
from locust import HttpUser, task

samples = [
    [
        dict(
            instruction="Who is the best player in the history of NBA?",
            response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
        ),
        dict(instruction="continue this talk", response=""),
    ],
    [
        dict(instruction="Who is the best player in the history of NBA?", response=""),
    ],
]


class GenerationUser(HttpUser):
    @task
    def generate(self):
        for sample in samples:
            data = {"max_new_tokens": 64, "history": sample}
            with self.client.post("/generate", json=data, catch_response=True) as response:
                if response.status_code in (200, 406):
                    response.success()
                else:
                    response.failure("Response wrong")
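To drive this load test (a hypothetical invocation, assuming the web chatbot server from `server.py` is already listening on its default port 7070):

```shell
locust -f locustfile.py --host http://localhost:7070
```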
13
applications/ColossalChat/examples/inference/web_chatbot/requirements.txt
Executable file
@@ -0,0 +1,13 @@
fastapi
locust
numpy
pydantic
safetensors
slowapi
sse_starlette
torch
uvicorn
git+https://github.com/huggingface/transformers
accelerate
bitsandbytes
jieba
208
applications/ColossalChat/examples/inference/web_chatbot/server.py
Executable file
@@ -0,0 +1,208 @@
import argparse
import os
from threading import Lock
from typing import Generator, List, Optional

import torch
import uvicorn
from coati.models import generate_streaming
from coati.quant import llama_load_quant, low_resource_init
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, update_model_kwargs_fn

MAX_LEN = 512
running_lock = Lock()


class GenerationTaskReq(BaseModel):
    max_new_tokens: int = Field(gt=0, le=512, example=64)
    history: List[Dialogue] = Field(min_items=1)
    top_k: Optional[int] = Field(default=None, gt=0, example=50)
    top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
    temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
    repetition_penalty: Optional[float] = Field(default=None, gt=1.0, example=1.2)


limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# set CORS
origin_spec_from_env = os.environ.get("CORS_ORIGIN", None)

if origin_spec_from_env is not None:
    # allow CORS from the specified origins
    origins = os.environ["CORS_ORIGIN"].split(",")
else:
    # allow CORS from all origins
    origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def generate_streamingly(prompt, max_length, max_new_tokens, top_k, top_p, temperature):
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    # TODO(ver217): streaming generation does not support repetition_penalty now
    model_kwargs = {
        "max_new_tokens": max_new_tokens,
        "early_stopping": True,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "prepare_inputs_fn": None,
        "update_model_kwargs_fn": update_model_kwargs_fn,
    }
    is_first_word = True
    generator = LockedIterator(
        generate_streaming(model, input_ids, tokenizer, max_length, **model_kwargs), running_lock
    )
    for output in generator:
        output = output.cpu()
        tokens = tokenizer.convert_ids_to_tokens(output, skip_special_tokens=True)
        current_sub_tokens = []
        for token in tokens:
            if token in tokenizer.all_special_tokens:
                continue
            current_sub_tokens.append(token)
        if current_sub_tokens:
            out_string = tokenizer.sp_model.decode(current_sub_tokens)
            if is_first_word:
                out_string = out_string.lstrip()
                is_first_word = False
            elif current_sub_tokens[0].startswith("▁"):
                # whitespace will be ignored by the frontend
                out_string = " " + out_string
            yield out_string


async def event_generator(request: Request, generator: Generator):
    while True:
        if await request.is_disconnected():
            break
        try:
            yield {"event": "generate", "data": next(generator)}
        except StopIteration:
            yield {"event": "end", "data": ""}
            break


@app.post("/generate/stream")
@limiter.limit("1/second")
def generate(data: GenerationTaskReq, request: Request):
    prompt = prompt_processor.preprocess_prompt(data.history)
    event_source = event_generator(
        request,
        # NOTE: GenerationTaskReq has no max_length field; MAX_LEN caps the total sequence length
        generate_streamingly(prompt, MAX_LEN, data.max_new_tokens, data.top_k, data.top_p, data.temperature),
    )
    return EventSourceResponse(event_source)


@app.post("/generate")
@limiter.limit("1/second")
def generate_no_stream(data: GenerationTaskReq, request: Request):
    # preprocess_prompt only takes the dialogue history (see utils.ChatPromptProcessor)
    prompt = prompt_processor.preprocess_prompt(data.history)
    if prompt_processor.has_censored_words(prompt):
        return prompt_processor.SAFE_RESPONSE
    inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
    with running_lock:
        output = model.generate(**inputs, **data.dict(exclude={"history"}))
    output = output.cpu()
    prompt_len = inputs["input_ids"].size(1)
    response = output[0, prompt_len:]
    out_string = tokenizer.decode(response, skip_special_tokens=True)
    out_string = prompt_processor.postprocess_output(out_string)
    if prompt_processor.has_censored_words(out_string):
        return prompt_processor.SAFE_RESPONSE
    return out_string


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "pretrained",
        help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
    )
    parser.add_argument(
        "--tokenizer_path",
        help="Path to pretrained tokenizer. Can be a local path or a model name from the HuggingFace model hub.",
        default=None,
    )
    parser.add_argument(
        "--quant",
        choices=["8bit", "4bit"],
        default=None,
        help="Quantization mode. Default: None (no quantization, fp16).",
    )
    parser.add_argument(
        "--gptq_checkpoint",
        default=None,
        help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
    )
    parser.add_argument(
        "--gptq_group_size",
        type=int,
        default=128,
        help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
    )
    parser.add_argument("--http_host", default="0.0.0.0")
    parser.add_argument("--http_port", type=int, default=7070)
    parser.add_argument(
        "--profanity_file",
        default=None,
        help="Path to profanity words list. It should be a JSON file containing a list of words.",
    )
    args = parser.parse_args()

    if args.quant == "4bit":
        assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."

    if args.tokenizer_path is None:
        args.tokenizer_path = args.pretrained
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, local_files_only=True)

    if args.profanity_file is not None:
        censored_words = load_json(args.profanity_file)
    else:
        censored_words = []
    prompt_processor = ChatPromptProcessor(censored_words=censored_words)

    if args.quant == "4bit":
        with low_resource_init():
            config = AutoConfig.from_pretrained(args.pretrained)
            # AutoModelForCausalLM cannot be instantiated directly; build the skeleton from the config
            model = AutoModelForCausalLM.from_config(config)
        model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
        model.cuda()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.pretrained,
            load_in_8bit=(args.quant == "8bit"),
            torch_dtype=torch.float16,
            device_map="auto",
            local_files_only=True,
        )
        if args.quant != "8bit":
            model.half()  # seems to fix bugs for some users.
        model.eval()

    config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
    server = uvicorn.Server(config=config)
    server.run()


"""
python server.py /home/lcyab/data/models/experiments5/checkpoint/experiment5-2023-10-20-21-53-51/modeling/ --tokenizer_path /mnt/vepfs/lcxyc/leaderboard_models/Colossal-LLaMA-2-7b-base/
"""
78
applications/ColossalChat/examples/inference/web_chatbot/utils.py
Executable file
@@ -0,0 +1,78 @@
import copy
import json
from threading import Lock
from typing import List

import jieba
import torch
from coati.dataset.conversation import default_conversation
from pydantic import BaseModel, Field


def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs["past_key_values"]
    else:
        model_kwargs["past"] = None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    # update attention mask
    if "attention_mask" in model_kwargs:
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
        )

    return model_kwargs


class Dialogue(BaseModel):
    instruction: str = Field(min_length=1, example="Count up from 1 to 500.")
    response: str = Field(example="")


class ChatPromptProcessor:
    SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt."

    def __init__(self, censored_words: List[str] = []):
        self.censored_words = set([word.lower() for word in censored_words])
        self.conv = copy.deepcopy(default_conversation)

    def preprocess_prompt(self, history: List[Dialogue]) -> str:
        self.conv.clear()
        for round in history:
            self.conv.append_message(self.conv.roles[0], round.instruction)
            # Only append the assistant turn when a response is present
            if len(round.response) > 0:
                self.conv.append_message(self.conv.roles[1], round.response)
        return self.conv.get_prompt()

    def postprocess_output(self, output: str) -> str:
        return output.strip()

    def has_censored_words(self, text: str) -> bool:
        if len(self.censored_words) == 0:
            return False
        intersection = set(jieba.cut(text.lower())) & self.censored_words
        return len(intersection) > 0


class LockedIterator:
    def __init__(self, it, lock: Lock) -> None:
        self.lock = lock
        self.it = iter(it)

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)


def load_json(path: str):
    with open(path) as f:
        return json.load(f)
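A small hypothetical sketch (not part of this commit) of how these helpers compose: `ChatPromptProcessor` flattens a dialogue history into a single prompt and screens it for censored words, while `LockedIterator` serializes access to a shared generator, as `server.py` does around `generate_streaming`.

```python
from threading import Lock

from utils import ChatPromptProcessor, Dialogue, LockedIterator

processor = ChatPromptProcessor(censored_words=["badword"])
history = [
    Dialogue(instruction="Hello", response="Hi, how can I help?"),
    Dialogue(instruction="Tell me a joke.", response=""),
]
prompt = processor.preprocess_prompt(history)   # one prompt string for the model
print(processor.has_censored_words(prompt))     # False

# Wrap any iterator so only one consumer advances it at a time.
stream = LockedIterator((tok for tok in ["Sure", ",", " here", " it", " is"]), Lock())
print("".join(stream))
```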