[tutorial] edited hands-on practices (#1899)

* Add handson to ColossalAI.

* Change names of handsons and edit sequence parallel example.

* Edit wrong folder name

* resolve conflict

* delete readme
This commit is contained in:
BoxiangW
2022-11-11 04:08:17 -05:00
committed by GitHub
parent d9bf83e084
commit ca6e75bc28
121 changed files with 20464 additions and 0 deletions

View File

@@ -0,0 +1,77 @@
# Overview
This is an example showing how to run OPT generation. The OPT model is implemented using ColossalAI.
It supports tensor parallelism, batching and caching.
# How to run
Run OPT-125M:
```shell
python opt_fastapi.py opt-125m
```
It will launch a HTTP server on `0.0.0.0:7070` by default and you can customize host and port. You can open `localhost:7070/docs` in your browser to see the openapi docs.
## Configure
### Configure model
```shell
python opt_fastapi.py <model>
```
Available models: opt-125m, opt-6.7b, opt-30b, opt-175b.
### Configure tensor parallelism
```shell
python opt_fastapi.py <model> --tp <TensorParallelismWorldSize>
```
The `<TensorParallelismWorldSize>` can be an integer in `[1, #GPUs]`. Default `1`.
### Configure checkpoint
```shell
python opt_fastapi.py <model> --checkpoint <CheckpointPath>
```
The `<CheckpointPath>` can be a file path or a directory path. If it's a directory path, all files under the directory will be loaded.
### Configure queue
```shell
python opt_fastapi.py <model> --queue_size <QueueSize>
```
The `<QueueSize>` can be an integer in `[0, MAXINT]`. If it's `0`, the request queue size is infinite. If it's a positive integer, when the request queue is full, incoming requests will be dropped (the HTTP status code of response will be 406).
### Configure bathcing
```shell
python opt_fastapi.py <model> --max_batch_size <MaxBatchSize>
```
The `<MaxBatchSize>` can be an integer in `[1, MAXINT]`. The engine will make batch whose size is less or equal to this value.
Note that the batch size is not always equal to `<MaxBatchSize>`, as some consecutive requests may not be batched.
### Configure caching
```shell
python opt_fastapi.py <model> --cache_size <CacheSize> --cache_list_size <CacheListSize>
```
This will cache `<CacheSize>` unique requests. And for each unique request, it cache `<CacheListSize>` different results. A random result will be returned if the cache is hit.
The `<CacheSize>` can be an integer in `[0, MAXINT]`. If it's `0`, cache won't be applied. The `<CacheListSize>` can be an integer in `[1, MAXINT]`.
### Other configurations
```shell
python opt_fastapi.py -h
```
# How to benchmark
```shell
cd benchmark
locust
```
Then open the web interface link which is on your console.
# Pre-process pre-trained weights
## OPT-66B
See [script/processing_ckpt_66b.py](./script/processing_ckpt_66b.py).
## OPT-175B
See [script/process-opt-175b](./script/process-opt-175b/).

View File

@@ -0,0 +1,59 @@
import torch
from typing import List, Deque, Tuple, Hashable, Any
from energonai import BatchManager, SubmitEntry, TaskEntry
class BatchManagerForGeneration(BatchManager):
def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None:
super().__init__()
self.max_batch_size = max_batch_size
self.pad_token_id = pad_token_id
def _left_padding(self, batch_inputs):
max_len = max(len(inputs['input_ids']) for inputs in batch_inputs)
outputs = {'input_ids': [], 'attention_mask': []}
for inputs in batch_inputs:
input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
padding_len = max_len - len(input_ids)
input_ids = [self.pad_token_id] * padding_len + input_ids
attention_mask = [0] * padding_len + attention_mask
outputs['input_ids'].append(input_ids)
outputs['attention_mask'].append(attention_mask)
for k in outputs:
outputs[k] = torch.tensor(outputs[k])
return outputs, max_len
@staticmethod
def _make_batch_key(entry: SubmitEntry) -> tuple:
data = entry.data
return (data['top_k'], data['top_p'], data['temperature'])
def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]:
entry = q.popleft()
uids = [entry.uid]
batch = [entry.data]
while len(batch) < self.max_batch_size:
if len(q) == 0:
break
if self._make_batch_key(entry) != self._make_batch_key(q[0]):
break
if q[0].data['max_tokens'] > entry.data['max_tokens']:
break
e = q.popleft()
batch.append(e.data)
uids.append(e.uid)
inputs, max_len = self._left_padding(batch)
trunc_lens = []
for data in batch:
trunc_lens.append(max_len + data['max_tokens'])
inputs['top_k'] = entry.data['top_k']
inputs['top_p'] = entry.data['top_p']
inputs['temperature'] = entry.data['temperature']
inputs['max_tokens'] = max_len + entry.data['max_tokens']
return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens}
def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]:
retval = []
for uid, output, trunc_len in zip(task_entry.uids, task_entry.batch, trunc_lens):
retval.append((uid, output[:trunc_len]))
return retval

View File

@@ -0,0 +1,15 @@
from locust import HttpUser, task
from json import JSONDecodeError
class GenerationUser(HttpUser):
@task
def generate(self):
prompt = 'Question: What is the longest river on the earth? Answer:'
for i in range(4, 9):
data = {'max_tokens': 2**i, 'prompt': prompt}
with self.client.post('/generation', json=data, catch_response=True) as response:
if response.status_code in (200, 406):
response.success()
else:
response.failure('Response wrong')

View File

@@ -0,0 +1,64 @@
from collections import OrderedDict
from threading import Lock
from contextlib import contextmanager
from typing import List, Any, Hashable, Dict
class MissCacheError(Exception):
pass
class ListCache:
def __init__(self, cache_size: int, list_size: int, fixed_keys: List[Hashable] = []) -> None:
"""Cache a list of values. The fixed keys won't be removed. For other keys, LRU is applied.
When the value list is not full, a cache miss occurs. Otherwise, a cache hit occurs. Redundant values will be removed.
Args:
cache_size (int): Max size for LRU cache.
list_size (int): Value list size.
fixed_keys (List[Hashable], optional): The keys which won't be removed. Defaults to [].
"""
self.cache_size = cache_size
self.list_size = list_size
self.cache: OrderedDict[Hashable, List[Any]] = OrderedDict()
self.fixed_cache: Dict[Hashable, List[Any]] = {}
for key in fixed_keys:
self.fixed_cache[key] = []
self._lock = Lock()
def get(self, key: Hashable) -> List[Any]:
with self.lock():
if key in self.fixed_cache:
l = self.fixed_cache[key]
if len(l) >= self.list_size:
return l
elif key in self.cache:
self.cache.move_to_end(key)
l = self.cache[key]
if len(l) >= self.list_size:
return l
raise MissCacheError()
def add(self, key: Hashable, value: Any) -> None:
with self.lock():
if key in self.fixed_cache:
l = self.fixed_cache[key]
if len(l) < self.list_size and value not in l:
l.append(value)
elif key in self.cache:
self.cache.move_to_end(key)
l = self.cache[key]
if len(l) < self.list_size and value not in l:
l.append(value)
else:
if len(self.cache) >= self.cache_size:
self.cache.popitem(last=False)
self.cache[key] = [value]
@contextmanager
def lock(self):
try:
self._lock.acquire()
yield
finally:
self._lock.release()

View File

@@ -0,0 +1,123 @@
import argparse
import logging
import random
from typing import Optional
import uvicorn
from energonai import QueueFullError, launch_engine
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field
from transformers import GPT2Tokenizer
from batch import BatchManagerForGeneration
from cache import ListCache, MissCacheError
class GenerationTaskReq(BaseModel):
max_tokens: int = Field(gt=0, le=256, example=64)
prompt: str = Field(
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
top_k: Optional[int] = Field(default=None, gt=0, example=50)
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
app = FastAPI()
@app.post('/generation')
async def generate(data: GenerationTaskReq, request: Request):
logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}')
key = (data.prompt, data.max_tokens)
try:
if cache is None:
raise MissCacheError()
outputs = cache.get(key)
output = random.choice(outputs)
logger.info('Cache hit')
except MissCacheError:
inputs = tokenizer(data.prompt, truncation=True, max_length=512)
inputs['max_tokens'] = data.max_tokens
inputs['top_k'] = data.top_k
inputs['top_p'] = data.top_p
inputs['temperature'] = data.temperature
try:
uid = id(data)
engine.submit(uid, inputs)
output = await engine.wait(uid)
output = tokenizer.decode(output, skip_special_tokens=True)
if cache is not None:
cache.add(key, output)
except QueueFullError as e:
raise HTTPException(status_code=406, detail=e.args[0])
return {'text': output}
@app.on_event("shutdown")
async def shutdown(*_):
engine.shutdown()
server.should_exit = True
server.force_exit = True
await server.shutdown()
def get_model_fn(model_name: str):
model_map = {
'opt-125m': opt_125M,
'opt-6.7b': opt_6B,
'opt-30b': opt_30B,
'opt-175b': opt_175B
}
return model_map[model_name]
def print_args(args: argparse.Namespace):
print('\n==> Args:')
for k, v in args.__dict__.items():
print(f'{k} = {v}')
FIXED_CACHE_KEYS = [
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
]
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
parser.add_argument('--tp', type=int, default=1)
parser.add_argument('--master_host', default='localhost')
parser.add_argument('--master_port', type=int, default=19990)
parser.add_argument('--rpc_port', type=int, default=19980)
parser.add_argument('--max_batch_size', type=int, default=8)
parser.add_argument('--pipe_size', type=int, default=1)
parser.add_argument('--queue_size', type=int, default=0)
parser.add_argument('--http_host', default='0.0.0.0')
parser.add_argument('--http_port', type=int, default=7070)
parser.add_argument('--checkpoint', default=None)
parser.add_argument('--cache_size', type=int, default=0)
parser.add_argument('--cache_list_size', type=int, default=1)
args = parser.parse_args()
print_args(args)
model_kwargs = {}
if args.checkpoint is not None:
model_kwargs['checkpoint'] = args.checkpoint
logger = logging.getLogger(__name__)
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
if args.cache_size > 0:
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
else:
cache = None
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
pad_token_id=tokenizer.pad_token_id),
pipe_size=args.pipe_size,
queue_size=args.queue_size,
**model_kwargs)
config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
server = uvicorn.Server(config=config)
server.run()

View File

@@ -0,0 +1,122 @@
import logging
import argparse
import random
from torch import Tensor
from pydantic import BaseModel, Field
from typing import Optional
from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B
from transformers import GPT2Tokenizer
from energonai import launch_engine, QueueFullError
from sanic import Sanic
from sanic.request import Request
from sanic.response import json
from sanic_ext import validate, openapi
from batch import BatchManagerForGeneration
from cache import ListCache, MissCacheError
class GenerationTaskReq(BaseModel):
max_tokens: int = Field(gt=0, le=256, example=64)
prompt: str = Field(
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
top_k: Optional[int] = Field(default=None, gt=0, example=50)
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
app = Sanic('opt')
@app.post('/generation')
@openapi.body(GenerationTaskReq)
@validate(json=GenerationTaskReq)
async def generate(request: Request, body: GenerationTaskReq):
logger.info(f'{request.ip}:{request.port} - "{request.method} {request.path}" - {body}')
key = (body.prompt, body.max_tokens)
try:
if cache is None:
raise MissCacheError()
outputs = cache.get(key)
output = random.choice(outputs)
logger.info('Cache hit')
except MissCacheError:
inputs = tokenizer(body.prompt, truncation=True, max_length=512)
inputs['max_tokens'] = body.max_tokens
inputs['top_k'] = body.top_k
inputs['top_p'] = body.top_p
inputs['temperature'] = body.temperature
try:
uid = id(body)
engine.submit(uid, inputs)
output = await engine.wait(uid)
assert isinstance(output, Tensor)
output = tokenizer.decode(output, skip_special_tokens=True)
if cache is not None:
cache.add(key, output)
except QueueFullError as e:
return json({'detail': e.args[0]}, status=406)
return json({'text': output})
@app.after_server_stop
def shutdown(*_):
engine.shutdown()
def get_model_fn(model_name: str):
model_map = {
'opt-125m': opt_125M,
'opt-6.7b': opt_6B,
'opt-30b': opt_30B,
'opt-175b': opt_175B
}
return model_map[model_name]
def print_args(args: argparse.Namespace):
print('\n==> Args:')
for k, v in args.__dict__.items():
print(f'{k} = {v}')
FIXED_CACHE_KEYS = [
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
]
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
parser.add_argument('--tp', type=int, default=1)
parser.add_argument('--master_host', default='localhost')
parser.add_argument('--master_port', type=int, default=19990)
parser.add_argument('--rpc_port', type=int, default=19980)
parser.add_argument('--max_batch_size', type=int, default=8)
parser.add_argument('--pipe_size', type=int, default=1)
parser.add_argument('--queue_size', type=int, default=0)
parser.add_argument('--http_host', default='0.0.0.0')
parser.add_argument('--http_port', type=int, default=7070)
parser.add_argument('--checkpoint', default=None)
parser.add_argument('--cache_size', type=int, default=0)
parser.add_argument('--cache_list_size', type=int, default=1)
args = parser.parse_args()
print_args(args)
model_kwargs = {}
if args.checkpoint is not None:
model_kwargs['checkpoint'] = args.checkpoint
logger = logging.getLogger(__name__)
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
if args.cache_size > 0:
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
else:
cache = None
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
pad_token_id=tokenizer.pad_token_id),
pipe_size=args.pipe_size,
queue_size=args.queue_size,
**model_kwargs)
app.run(args.http_host, args.http_port)

View File

@@ -0,0 +1,8 @@
fastapi==0.85.1
locust==2.11.0
pydantic==1.10.2
sanic==22.9.0
sanic_ext==22.9.0
torch>=1.10.0
transformers==4.23.1
uvicorn==0.19.0

View File

@@ -0,0 +1,46 @@
# Process OPT-175B weights
You should download the pre-trained weights following the [doc](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT) before reading this.
First, install `metaseq` and `git clone https://github.com/facebookresearch/metaseq.git`.
Then, `cd metaseq`.
To consolidate checkpoints to eliminate FSDP:
```shell
bash metaseq/scripts/reshard_mp_launch_no_slurm.sh <directory_where_all_the_shards_are>/checkpoint_last <output_dir>/ 8 1
```
You will get 8 files in `<output_dir>`, and you should have the following checksums:
```
7e71cb65c4be784aa0b2889ac6039ee8 reshard-model_part-0-shard0.pt
c8123da04f2c25a9026ea3224d5d5022 reshard-model_part-1-shard0.pt
45e5d10896382e5bc4a7064fcafd2b1e reshard-model_part-2-shard0.pt
abb7296c4d2fc17420b84ca74fc3ce64 reshard-model_part-3-shard0.pt
05dcc7ac6046f4d3f90b3d1068e6da15 reshard-model_part-4-shard0.pt
d24dd334019060ce1ee7e625fcf6b4bd reshard-model_part-5-shard0.pt
fb1615ce0bbe89cc717f3e5079ee2655 reshard-model_part-6-shard0.pt
2f3124432d2dbc6aebfca06be4b791c2 reshard-model_part-7-shard0.pt
```
Copy `flat-meta.json` to `<output_dir>`.
Then cd to this dir, and we unflatten parameters.
```shell
bash unflat.sh <output_dir>/ <new_output_dir>/
```
Finally, you will get 8 files in `<new_output_dir>` with following checksums:
```
6169c59d014be95553c89ec01b8abb62 reshard-model_part-0.pt
58868105da3d74a528a548fdb3a8cff6 reshard-model_part-1.pt
69b255dc5a49d0eba9e4b60432cda90b reshard-model_part-2.pt
002c052461ff9ffb0cdac3d5906f41f2 reshard-model_part-3.pt
6d57f72909320d511ffd5f1c668b2beb reshard-model_part-4.pt
93c8c4041cdc0c7907cc7afcf15cec2a reshard-model_part-5.pt
5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt
f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt
```

View File

@@ -0,0 +1,55 @@
import argparse
import json
import os
import re
from collections import defaultdict
import numpy as np
import torch
def load_json(path: str):
with open(path) as f:
return json.load(f)
def parse_shape_info(flat_dir: str):
data = load_json(os.path.join(flat_dir, 'shape.json'))
flat_info = defaultdict(lambda: defaultdict(list))
for k, shape in data.items():
matched = re.match(r'decoder.layers.\d+', k)
if matched is None:
flat_key = 'flat_param_0'
else:
flat_key = f'{matched[0]}.flat_param_0'
flat_info[flat_key]['names'].append(k)
flat_info[flat_key]['shapes'].append(shape)
flat_info[flat_key]['numels'].append(int(np.prod(shape)))
return flat_info
def convert(flat_dir: str, output_dir: str, part: int):
flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt')
output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt')
flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json'))
flat_sd = torch.load(flat_path)
print(f'Loaded flat state dict from {flat_path}')
output_sd = {}
for flat_key, param_meta in flat_meta.items():
flat_param = flat_sd['model'][flat_key]
assert sum(param_meta['numels']) == flat_param.numel(
), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}'
for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])):
output_sd[name] = param.view(shape)
torch.save(output_sd, output_path)
print(f'Saved unflat state dict to {output_path}')
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('flat_dir')
parser.add_argument('output_dir')
parser.add_argument('part', type=int)
args = parser.parse_args()
convert(args.flat_dir, args.output_dir, args.part)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,7 @@
#!/usr/bin/env sh
for i in $(seq 0 7); do
python convert_ckpt.py $1 $2 ${i} &
done
wait $(jobs -p)

View File

@@ -0,0 +1,55 @@
import os
import torch
from multiprocessing import Pool
# download pytorch model ckpt in https://huggingface.co/facebook/opt-66b/tree/main
# you can use whether wget or git lfs
path = "/path/to/your/ckpt"
new_path = "/path/to/the/processed/ckpt/"
assert os.path.isdir(path)
files = []
for filename in os.listdir(path):
filepath = os.path.join(path, filename)
if os.path.isfile(filepath):
files.append(filepath)
with Pool(14) as pool:
ckpts = pool.map(torch.load, files)
restored = {}
for ckpt in ckpts:
for k,v in ckpt.items():
if(k[0] == 'm'):
k = k[6:]
if(k == "lm_head.weight"):
k = "head.dense.weight"
if(k == "decoder.final_layer_norm.weight"):
k = "decoder.layer_norm.weight"
if(k == "decoder.final_layer_norm.bias"):
k = "decoder.layer_norm.bias"
restored[k] = v
restored["decoder.version"] = "0.0"
split_num = len(restored.keys()) // 60
count = 0
file_count = 1
tmp = {}
for k,v in restored.items():
print(k)
tmp[k] = v
count = count + 1
if(count == split_num):
filename = str(file_count) + "-restored.pt"
torch.save(tmp, os.path.join(new_path, filename))
file_count = file_count + 1
count = 0
tmp = {}
filename = str(file_count) + "-restored.pt"
torch.save(tmp, os.path.join(new_path, filename))