mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-29 04:40:36 +00:00
[tutorial] edited hands-on practices (#1899)
* Add handson to ColossalAI. * Change names of handsons and edit sequence parallel example. * Edit wrong folder name * resolve conflict * delete readme
This commit is contained in:
77
examples/tutorial/opt/inference/README.md
Normal file
77
examples/tutorial/opt/inference/README.md
Normal file
@@ -0,0 +1,77 @@
|
||||
# Overview
|
||||
|
||||
This is an example showing how to run OPT generation. The OPT model is implemented using ColossalAI.
|
||||
|
||||
It supports tensor parallelism, batching and caching.
|
||||
|
||||
# How to run
|
||||
|
||||
Run OPT-125M:
|
||||
```shell
|
||||
python opt_fastapi.py opt-125m
|
||||
```
|
||||
|
||||
It will launch an HTTP server on `0.0.0.0:7070` by default, and you can customize the host and port. You can open `localhost:7070/docs` in your browser to see the OpenAPI docs.
|
||||
|
||||
## Configure
|
||||
|
||||
### Configure model
|
||||
```shell
|
||||
python opt_fastapi.py <model>
|
||||
```
|
||||
Available models: opt-125m, opt-6.7b, opt-30b, opt-175b.
|
||||
|
||||
### Configure tensor parallelism
|
||||
```shell
|
||||
python opt_fastapi.py <model> --tp <TensorParallelismWorldSize>
|
||||
```
|
||||
The `<TensorParallelismWorldSize>` can be an integer in `[1, #GPUs]`. Default `1`.
|
||||
|
||||
### Configure checkpoint
|
||||
```shell
|
||||
python opt_fastapi.py <model> --checkpoint <CheckpointPath>
|
||||
```
|
||||
The `<CheckpointPath>` can be a file path or a directory path. If it's a directory path, all files under the directory will be loaded.
|
||||
|
||||
### Configure queue
|
||||
```shell
|
||||
python opt_fastapi.py <model> --queue_size <QueueSize>
|
||||
```
|
||||
The `<QueueSize>` can be an integer in `[0, MAXINT]`. If it's `0`, the request queue size is infinite. If it's a positive integer, when the request queue is full, incoming requests will be dropped (the HTTP status code of response will be 406).
|
||||
|
||||
### Configure batching
|
||||
```shell
|
||||
python opt_fastapi.py <model> --max_batch_size <MaxBatchSize>
|
||||
```
|
||||
The `<MaxBatchSize>` can be an integer in `[1, MAXINT]`. The engine will make batch whose size is less or equal to this value.
|
||||
|
||||
Note that the batch size is not always equal to `<MaxBatchSize>`, as some consecutive requests may not be batched.
|
||||
|
||||
### Configure caching
|
||||
```shell
|
||||
python opt_fastapi.py <model> --cache_size <CacheSize> --cache_list_size <CacheListSize>
|
||||
```
|
||||
This will cache `<CacheSize>` unique requests. For each unique request, it caches `<CacheListSize>` different results. A random result is returned if the cache is hit.
|
||||
|
||||
The `<CacheSize>` can be an integer in `[0, MAXINT]`. If it's `0`, cache won't be applied. The `<CacheListSize>` can be an integer in `[1, MAXINT]`.
|
||||
|
||||
### Other configurations
|
||||
```shell
|
||||
python opt_fastapi.py -h
|
||||
```
|
||||
|
||||
# How to benchmark
|
||||
```shell
|
||||
cd benchmark
|
||||
locust
|
||||
```
|
||||
|
||||
Then open the web interface link which is on your console.
|
||||
|
||||
# Pre-process pre-trained weights
|
||||
|
||||
## OPT-66B
|
||||
See [script/processing_ckpt_66b.py](./script/processing_ckpt_66b.py).
|
||||
|
||||
## OPT-175B
|
||||
See [script/process-opt-175b](./script/process-opt-175b/).
|
||||
59
examples/tutorial/opt/inference/batch.py
Normal file
59
examples/tutorial/opt/inference/batch.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from typing import Any, Deque, Hashable, List, Optional, Tuple

import torch

from energonai import BatchManager, SubmitEntry, TaskEntry
|
||||
|
||||
|
||||
class BatchManagerForGeneration(BatchManager):
    """Batch manager that groups compatible generation requests together.

    Requests are only batched when they share sampling parameters (top_k,
    top_p, temperature). Inputs are left-padded to a common length so the
    decoder consumes the batch as one tensor.
    """

    def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.pad_token_id = pad_token_id

    def _left_padding(self, batch_inputs):
        """Left-pad ``input_ids``/``attention_mask`` of every request to the max length.

        Returns:
            A dict of batched tensors and the common (max) prompt length.
        """
        max_len = max(len(inputs['input_ids']) for inputs in batch_inputs)
        outputs = {'input_ids': [], 'attention_mask': []}
        for inputs in batch_inputs:
            input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
            padding_len = max_len - len(input_ids)
            # Pad on the left so newly generated tokens extend contiguously on the right.
            input_ids = [self.pad_token_id] * padding_len + input_ids
            attention_mask = [0] * padding_len + attention_mask
            outputs['input_ids'].append(input_ids)
            outputs['attention_mask'].append(attention_mask)
        for k in outputs:
            outputs[k] = torch.tensor(outputs[k])
        return outputs, max_len

    @staticmethod
    def _make_batch_key(entry: SubmitEntry) -> tuple:
        # Requests may only share a batch when these sampling params match.
        data = entry.data
        return (data['top_k'], data['top_p'], data['temperature'])

    def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]:
        """Pop compatible entries from the queue and build one ``TaskEntry``.

        Growth stops when the queue is empty, the next entry uses different
        sampling parameters, or the next entry requests more new tokens than
        the first one (so the batch never generates past the first request).
        """
        entry = q.popleft()
        # Loop-invariant: the key of the first entry decides batch membership.
        batch_key = self._make_batch_key(entry)
        uids = [entry.uid]
        batch = [entry.data]
        while len(batch) < self.max_batch_size:
            if len(q) == 0:
                break
            if batch_key != self._make_batch_key(q[0]):
                break
            if q[0].data['max_tokens'] > entry.data['max_tokens']:
                break
            e = q.popleft()
            batch.append(e.data)
            uids.append(e.uid)
        inputs, max_len = self._left_padding(batch)
        # Each output is later truncated to its own prompt length + requested tokens.
        trunc_lens = [max_len + data['max_tokens'] for data in batch]
        inputs['top_k'] = entry.data['top_k']
        inputs['top_p'] = entry.data['top_p']
        inputs['temperature'] = entry.data['temperature']
        inputs['max_tokens'] = max_len + entry.data['max_tokens']
        return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens}

    def split_batch(self, task_entry: TaskEntry, trunc_lens: Optional[List[int]] = None) -> List[Tuple[Hashable, Any]]:
        """Split a finished batch back into ``(uid, truncated output)`` pairs.

        Args:
            task_entry: The completed batch.
            trunc_lens: Per-request output lengths produced by ``make_batch``.
        """
        # Fix: the original used a mutable default argument (`= []`).
        if trunc_lens is None:
            trunc_lens = []
        return [
            (uid, output[:trunc_len])
            for uid, output, trunc_len in zip(task_entry.uids, task_entry.batch, trunc_lens)
        ]
|
||||
15
examples/tutorial/opt/inference/benchmark/locustfile.py
Normal file
15
examples/tutorial/opt/inference/benchmark/locustfile.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from locust import HttpUser, task
|
||||
from json import JSONDecodeError
|
||||
|
||||
|
||||
class GenerationUser(HttpUser):
    """Locust user that stresses /generation with geometrically growing outputs."""

    @task
    def generate(self):
        prompt = 'Question: What is the longest river on the earth? Answer:'
        # Sweep max_tokens over 16, 32, 64, 128, 256.
        for exponent in range(4, 9):
            payload = {'max_tokens': 2**exponent, 'prompt': prompt}
            with self.client.post('/generation', json=payload, catch_response=True) as response:
                # 406 (queue full) is an expected, well-behaved outcome under load.
                if response.status_code not in (200, 406):
                    response.failure('Response wrong')
                else:
                    response.success()
|
||||
64
examples/tutorial/opt/inference/cache.py
Normal file
64
examples/tutorial/opt/inference/cache.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from collections import OrderedDict
|
||||
from threading import Lock
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Any, Hashable, Dict
|
||||
|
||||
|
||||
class MissCacheError(Exception):
    """Raised by ListCache.get when a key has no complete value list yet."""
|
||||
|
||||
|
||||
class ListCache:

    def __init__(self, cache_size: int, list_size: int, fixed_keys: List[Hashable] = []) -> None:
        """Cache a list of values per key. Fixed keys are never removed; other keys use LRU.

        A lookup only succeeds once a key's value list is full (``list_size``
        entries); otherwise :class:`MissCacheError` is raised so the caller can
        recompute and ``add`` another value. Duplicate values are not stored.

        Args:
            cache_size (int): Max number of non-fixed keys kept (LRU eviction).
            list_size (int): Number of values collected per key.
            fixed_keys (List[Hashable], optional): Keys which won't be removed.
                Defaults to []. (Only iterated here, never mutated, so the
                shared default list is harmless.)
        """
        self.cache_size = cache_size
        self.list_size = list_size
        # LRU store: most-recently-used keys live at the end of the OrderedDict.
        self.cache: OrderedDict[Hashable, List[Any]] = OrderedDict()
        # Pinned entries, consulted before the LRU store.
        self.fixed_cache: Dict[Hashable, List[Any]] = {key: [] for key in fixed_keys}
        self._lock = Lock()

    def get(self, key: Hashable) -> List[Any]:
        """Return the full value list for ``key``; raise ``MissCacheError`` otherwise."""
        with self.lock():
            if key in self.fixed_cache:
                # Renamed from `l`: single-letter `l` is ambiguous (PEP 8 / E741).
                values = self.fixed_cache[key]
                if len(values) >= self.list_size:
                    return values
            elif key in self.cache:
                self.cache.move_to_end(key)    # mark as most recently used
                values = self.cache[key]
                if len(values) >= self.list_size:
                    return values
            # Unknown key, or its value list is not full yet.
            raise MissCacheError()

    def add(self, key: Hashable, value: Any) -> None:
        """Record ``value`` for ``key``, evicting the LRU entry when needed."""
        with self.lock():
            if key in self.fixed_cache:
                values = self.fixed_cache[key]
                if len(values) < self.list_size and value not in values:
                    values.append(value)
            elif key in self.cache:
                self.cache.move_to_end(key)
                values = self.cache[key]
                if len(values) < self.list_size and value not in values:
                    values.append(value)
            else:
                if len(self.cache) >= self.cache_size:
                    # Drop the least-recently-used key to make room.
                    self.cache.popitem(last=False)
                self.cache[key] = [value]

    @contextmanager
    def lock(self):
        """Context manager guarding all cache state behind ``self._lock``."""
        try:
            self._lock.acquire()
            yield
        finally:
            self._lock.release()
|
||||
123
examples/tutorial/opt/inference/opt_fastapi.py
Normal file
123
examples/tutorial/opt/inference/opt_fastapi.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from energonai import QueueFullError, launch_engine
|
||||
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import GPT2Tokenizer
|
||||
|
||||
from batch import BatchManagerForGeneration
|
||||
from cache import ListCache, MissCacheError
|
||||
|
||||
|
||||
class GenerationTaskReq(BaseModel):
    """Request body schema for ``POST /generation`` (validated by FastAPI/pydantic)."""
    # Number of new tokens to generate; bounded (1..256) to keep latency reasonable.
    max_tokens: int = Field(gt=0, le=256, example=64)
    # The text to complete; must be non-empty.
    prompt: str = Field(
        min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
    # Optional sampling parameters; None leaves the choice to the engine
    # (NOTE(review): engine-side defaults are not visible here -- confirm).
    top_k: Optional[int] = Field(default=None, gt=0, example=50)
    top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
    temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)


# Module-level ASGI application served by uvicorn in the __main__ block.
app = FastAPI()
|
||||
|
||||
|
||||
@app.post('/generation')
async def generate(data: GenerationTaskReq, request: Request):
    """Generate a completion for ``data.prompt``, serving from cache when possible.

    Returns:
        ``{'text': <generated text>}``.
    Raises:
        HTTPException: 406 when the engine's request queue is full.
    """
    logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}')
    key = (data.prompt, data.max_tokens)
    try:
        if cache is None:
            raise MissCacheError()
        # A cache entry holds several sampled results; return one at random.
        outputs = cache.get(key)
        output = random.choice(outputs)
        logger.info('Cache hit')
    except MissCacheError:
        inputs = tokenizer(data.prompt, truncation=True, max_length=512)
        inputs['max_tokens'] = data.max_tokens
        inputs['top_k'] = data.top_k
        inputs['top_p'] = data.top_p
        inputs['temperature'] = data.temperature
        try:
            # id(data) is unique while the request object is alive, which
            # covers the submit/wait window below.
            uid = id(data)
            engine.submit(uid, inputs)
            output = await engine.wait(uid)
            output = tokenizer.decode(output, skip_special_tokens=True)
            if cache is not None:
                cache.add(key, output)
        except QueueFullError as e:
            # Fix: chain the cause (PEP 3134) so tracebacks show the original error.
            raise HTTPException(status_code=406, detail=e.args[0]) from e

    return {'text': output}
|
||||
|
||||
|
||||
@app.on_event("shutdown")
async def shutdown(*_):
    """Gracefully stop the inference engine and the uvicorn server."""
    # Stop the engine first so no new work is accepted during teardown.
    engine.shutdown()
    # Ask uvicorn's serve loop to exit; force_exit skips waiting on
    # outstanding connections.
    server.should_exit = True
    server.force_exit = True
    await server.shutdown()
|
||||
|
||||
|
||||
def get_model_fn(model_name: str):
    """Map a CLI model name to its energonai model constructor.

    Raises KeyError for unknown names (argparse choices prevent that in practice).
    """
    return {
        'opt-125m': opt_125M,
        'opt-6.7b': opt_6B,
        'opt-30b': opt_30B,
        'opt-175b': opt_175B,
    }[model_name]
|
||||
|
||||
|
||||
def print_args(args: argparse.Namespace):
    """Pretty-print the parsed CLI arguments, one per line."""
    print('\n==> Args:')
    for name, value in vars(args).items():
        print(f'{name} = {value}')
|
||||
|
||||
|
||||
# Prompts pre-pinned in the cache so demo queries always hit (never evicted).
FIXED_CACHE_KEYS = [
    ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
    ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
    ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
]

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
    parser.add_argument('--tp', type=int, default=1)  # tensor-parallel world size
    parser.add_argument('--master_host', default='localhost')
    parser.add_argument('--master_port', type=int, default=19990)
    parser.add_argument('--rpc_port', type=int, default=19980)
    parser.add_argument('--max_batch_size', type=int, default=8)
    parser.add_argument('--pipe_size', type=int, default=1)
    parser.add_argument('--queue_size', type=int, default=0)  # 0 = unbounded request queue
    parser.add_argument('--http_host', default='0.0.0.0')
    parser.add_argument('--http_port', type=int, default=7070)
    parser.add_argument('--checkpoint', default=None)
    parser.add_argument('--cache_size', type=int, default=0)  # 0 disables the result cache
    parser.add_argument('--cache_list_size', type=int, default=1)
    args = parser.parse_args()
    print_args(args)
    model_kwargs = {}
    if args.checkpoint is not None:
        model_kwargs['checkpoint'] = args.checkpoint

    logger = logging.getLogger(__name__)
    # The OPT models share the GPT2 tokenizer vocabulary.
    tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
    if args.cache_size > 0:
        cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
    else:
        cache = None
    # Launch the distributed inference engine; blocks until workers are up.
    engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
                           batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
                                                                   pad_token_id=tokenizer.pad_token_id),
                           pipe_size=args.pipe_size,
                           queue_size=args.queue_size,
                           **model_kwargs)
    # Serve the FastAPI app; the shutdown hook above references `server`.
    config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
    server = uvicorn.Server(config=config)
    server.run()
|
||||
122
examples/tutorial/opt/inference/opt_server.py
Normal file
122
examples/tutorial/opt/inference/opt_server.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import logging
|
||||
import argparse
|
||||
import random
|
||||
from torch import Tensor
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B
|
||||
from transformers import GPT2Tokenizer
|
||||
from energonai import launch_engine, QueueFullError
|
||||
from sanic import Sanic
|
||||
from sanic.request import Request
|
||||
from sanic.response import json
|
||||
from sanic_ext import validate, openapi
|
||||
from batch import BatchManagerForGeneration
|
||||
from cache import ListCache, MissCacheError
|
||||
|
||||
|
||||
class GenerationTaskReq(BaseModel):
    """Request body schema for ``POST /generation`` (validated by sanic_ext)."""
    # Number of new tokens to generate; bounded (1..256) to keep latency reasonable.
    max_tokens: int = Field(gt=0, le=256, example=64)
    # The text to complete; must be non-empty.
    prompt: str = Field(
        min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
    # Optional sampling parameters; None leaves the choice to the engine
    # (NOTE(review): engine-side defaults are not visible here -- confirm).
    top_k: Optional[int] = Field(default=None, gt=0, example=50)
    top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
    temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)


# Module-level Sanic application; app.run() is called in the __main__ block.
app = Sanic('opt')
|
||||
|
||||
|
||||
@app.post('/generation')
@openapi.body(GenerationTaskReq)
@validate(json=GenerationTaskReq)
async def generate(request: Request, body: GenerationTaskReq):
    """Generate a completion for ``body.prompt``, serving from cache when possible.

    Returns ``{'text': <generated text>}``; responds 406 when the engine's
    request queue is full.
    """
    logger.info(f'{request.ip}:{request.port} - "{request.method} {request.path}" - {body}')
    key = (body.prompt, body.max_tokens)
    try:
        if cache is None:
            raise MissCacheError()
        # A cache entry holds several sampled results; return one at random.
        outputs = cache.get(key)
        output = random.choice(outputs)
        logger.info('Cache hit')
    except MissCacheError:
        inputs = tokenizer(body.prompt, truncation=True, max_length=512)
        inputs['max_tokens'] = body.max_tokens
        inputs['top_k'] = body.top_k
        inputs['top_p'] = body.top_p
        inputs['temperature'] = body.temperature
        try:
            # id(body) is unique while the request object is alive, which
            # covers the submit/wait window below.
            uid = id(body)
            engine.submit(uid, inputs)
            output = await engine.wait(uid)
            # NOTE(review): sanity check only -- stripped when Python runs with -O.
            assert isinstance(output, Tensor)
            output = tokenizer.decode(output, skip_special_tokens=True)
            if cache is not None:
                cache.add(key, output)
        except QueueFullError as e:
            # Queue full: report 406 in the body instead of raising.
            return json({'detail': e.args[0]}, status=406)

    return json({'text': output})
|
||||
|
||||
|
||||
@app.after_server_stop
def shutdown(*_):
    """Stop the inference engine once the Sanic server has stopped."""
    engine.shutdown()
|
||||
|
||||
|
||||
def get_model_fn(model_name: str):
    """Map a CLI model name to its energonai model constructor.

    Raises KeyError for unknown names (argparse choices prevent that in practice).
    """
    constructors = {
        'opt-125m': opt_125M,
        'opt-6.7b': opt_6B,
        'opt-30b': opt_30B,
        'opt-175b': opt_175B,
    }
    return constructors[model_name]
|
||||
|
||||
|
||||
def print_args(args: argparse.Namespace):
    """Pretty-print the parsed CLI arguments, one per line."""
    print('\n==> Args:')
    for name, value in vars(args).items():
        print(f'{name} = {value}')
|
||||
|
||||
|
||||
# Prompts pre-pinned in the cache so demo queries always hit (never evicted).
FIXED_CACHE_KEYS = [
    ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
    ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
    ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
]

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
    parser.add_argument('--tp', type=int, default=1)  # tensor-parallel world size
    parser.add_argument('--master_host', default='localhost')
    parser.add_argument('--master_port', type=int, default=19990)
    parser.add_argument('--rpc_port', type=int, default=19980)
    parser.add_argument('--max_batch_size', type=int, default=8)
    parser.add_argument('--pipe_size', type=int, default=1)
    parser.add_argument('--queue_size', type=int, default=0)  # 0 = unbounded request queue
    parser.add_argument('--http_host', default='0.0.0.0')
    parser.add_argument('--http_port', type=int, default=7070)
    parser.add_argument('--checkpoint', default=None)
    parser.add_argument('--cache_size', type=int, default=0)  # 0 disables the result cache
    parser.add_argument('--cache_list_size', type=int, default=1)
    args = parser.parse_args()
    print_args(args)
    model_kwargs = {}
    if args.checkpoint is not None:
        model_kwargs['checkpoint'] = args.checkpoint

    logger = logging.getLogger(__name__)
    # The OPT models share the GPT2 tokenizer vocabulary.
    tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
    if args.cache_size > 0:
        cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
    else:
        cache = None
    # Launch the distributed inference engine; blocks until workers are up.
    engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
                           batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
                                                                   pad_token_id=tokenizer.pad_token_id),
                           pipe_size=args.pipe_size,
                           queue_size=args.queue_size,
                           **model_kwargs)
    # Serve the Sanic app (blocking).
    app.run(args.http_host, args.http_port)
|
||||
8
examples/tutorial/opt/inference/requirements.txt
Normal file
8
examples/tutorial/opt/inference/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
fastapi==0.85.1
|
||||
locust==2.11.0
|
||||
pydantic==1.10.2
|
||||
sanic==22.9.0
|
||||
sanic_ext==22.9.0
|
||||
torch>=1.10.0
|
||||
transformers==4.23.1
|
||||
uvicorn==0.19.0
|
||||
@@ -0,0 +1,46 @@
|
||||
# Process OPT-175B weights
|
||||
|
||||
You should download the pre-trained weights following the [doc](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT) before reading this.
|
||||
|
||||
First, `git clone https://github.com/facebookresearch/metaseq.git` and install `metaseq`.
|
||||
|
||||
Then, `cd metaseq`.
|
||||
|
||||
To consolidate checkpoints to eliminate FSDP:
|
||||
|
||||
```shell
|
||||
bash metaseq/scripts/reshard_mp_launch_no_slurm.sh <directory_where_all_the_shards_are>/checkpoint_last <output_dir>/ 8 1
|
||||
```
|
||||
|
||||
You will get 8 files in `<output_dir>`, and you should have the following checksums:
|
||||
```
|
||||
7e71cb65c4be784aa0b2889ac6039ee8 reshard-model_part-0-shard0.pt
|
||||
c8123da04f2c25a9026ea3224d5d5022 reshard-model_part-1-shard0.pt
|
||||
45e5d10896382e5bc4a7064fcafd2b1e reshard-model_part-2-shard0.pt
|
||||
abb7296c4d2fc17420b84ca74fc3ce64 reshard-model_part-3-shard0.pt
|
||||
05dcc7ac6046f4d3f90b3d1068e6da15 reshard-model_part-4-shard0.pt
|
||||
d24dd334019060ce1ee7e625fcf6b4bd reshard-model_part-5-shard0.pt
|
||||
fb1615ce0bbe89cc717f3e5079ee2655 reshard-model_part-6-shard0.pt
|
||||
2f3124432d2dbc6aebfca06be4b791c2 reshard-model_part-7-shard0.pt
|
||||
```
|
||||
|
||||
Copy `flat-meta.json` to `<output_dir>`.
|
||||
|
||||
Then change into this directory and unflatten the parameters:
|
||||
|
||||
```shell
|
||||
bash unflat.sh <output_dir>/ <new_output_dir>/
|
||||
```
|
||||
|
||||
Finally, you will get 8 files in `<new_output_dir>` with following checksums:
|
||||
```
|
||||
6169c59d014be95553c89ec01b8abb62 reshard-model_part-0.pt
|
||||
58868105da3d74a528a548fdb3a8cff6 reshard-model_part-1.pt
|
||||
69b255dc5a49d0eba9e4b60432cda90b reshard-model_part-2.pt
|
||||
002c052461ff9ffb0cdac3d5906f41f2 reshard-model_part-3.pt
|
||||
6d57f72909320d511ffd5f1c668b2beb reshard-model_part-4.pt
|
||||
93c8c4041cdc0c7907cc7afcf15cec2a reshard-model_part-5.pt
|
||||
5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt
|
||||
f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt
|
||||
```
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def load_json(path: str):
    """Parse the JSON file at *path* and return the decoded object."""
    with open(path) as fp:
        return json.load(fp)
|
||||
|
||||
|
||||
def parse_shape_info(flat_dir: str):
    """Group parameter names/shapes/numels from ``shape.json`` by flat key.

    Parameters named ``decoder.layers.<i>...`` are grouped under that layer's
    ``flat_param_0``; everything else falls under the top-level ``flat_param_0``.
    """
    shape_data = load_json(os.path.join(flat_dir, 'shape.json'))
    flat_info = defaultdict(lambda: defaultdict(list))
    layer_pattern = re.compile(r'decoder.layers.\d+')
    for name, shape in shape_data.items():
        match = layer_pattern.match(name)
        flat_key = 'flat_param_0' if match is None else f'{match[0]}.flat_param_0'
        info = flat_info[flat_key]
        info['names'].append(name)
        info['shapes'].append(shape)
        info['numels'].append(int(np.prod(shape)))
    return flat_info
|
||||
|
||||
|
||||
def convert(flat_dir: str, output_dir: str, part: int):
    """Unflatten one model-parallel shard and save it as a regular state dict."""
    flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt')
    output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt')
    flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json'))
    flat_sd = torch.load(flat_path)
    print(f'Loaded flat state dict from {flat_path}')
    output_sd = {}
    for flat_key, param_meta in flat_meta.items():
        flat_param = flat_sd['model'][flat_key]
        # Sanity check: the metadata must account for every element of the flat tensor.
        assert sum(param_meta['numels']) == flat_param.numel(
        ), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}'
        names = param_meta['names']
        shapes = param_meta['shapes']
        pieces = flat_param.split(param_meta['numels'])
        for name, shape, piece in zip(names, shapes, pieces):
            output_sd[name] = piece.view(shape)

    torch.save(output_sd, output_path)
    print(f'Saved unflat state dict to {output_path}')
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI: convert_ckpt.py <flat_dir> <output_dir> <part>
    arg_parser = argparse.ArgumentParser()
    for positional in ('flat_dir', 'output_dir'):
        arg_parser.add_argument(positional)
    arg_parser.add_argument('part', type=int)
    parsed = arg_parser.parse_args()
    convert(parsed.flat_dir, parsed.output_dir, parsed.part)
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,7 @@
|
||||
#!/usr/bin/env sh
# Usage: unflat.sh <flat_dir> <output_dir>
# Unflattens all 8 model-parallel shards in parallel.

for i in $(seq 0 7); do
    # Fix: quote the directory arguments so paths containing spaces survive
    # word splitting.
    python convert_ckpt.py "$1" "$2" "${i}" &
done

# Wait for every background conversion to finish (plain `wait` with no
# operands waits for all known child processes, per POSIX).
wait
|
||||
@@ -0,0 +1,55 @@
|
||||
import os
import torch
from multiprocessing import Pool

# Download the PyTorch model checkpoint from
# https://huggingface.co/facebook/opt-66b/tree/main (wget or git lfs both work).

# Edit these before running.
path = "/path/to/your/ckpt"
new_path = "/path/to/the/processed/ckpt/"

assert os.path.isdir(path)
files = [
    os.path.join(path, filename)
    for filename in os.listdir(path)
    if os.path.isfile(os.path.join(path, filename))
]

# Load all shard files in parallel; each is a partial state dict.
with Pool(14) as pool:
    ckpts = pool.map(torch.load, files)

# Merge the shards into one state dict, renaming keys to the expected layout.
restored = {}
for ckpt in ckpts:
    for k, v in ckpt.items():
        if k.startswith('m'):
            # Strip the leading 'model.' prefix.
            k = k[6:]
        if k == "lm_head.weight":
            k = "head.dense.weight"
        if k == "decoder.final_layer_norm.weight":
            k = "decoder.layer_norm.weight"
        if k == "decoder.final_layer_norm.bias":
            k = "decoder.layer_norm.bias"
        restored[k] = v
restored["decoder.version"] = "0.0"


# Re-shard the merged state dict into ~60 roughly equal files.
split_num = len(restored.keys()) // 60
count = 0
file_count = 1
tmp = {}
for k, v in restored.items():
    print(k)
    tmp[k] = v
    count = count + 1
    if count == split_num:
        filename = str(file_count) + "-restored.pt"
        torch.save(tmp, os.path.join(new_path, filename))
        file_count = file_count + 1
        count = 0
        tmp = {}

# Save the remainder. Fix: the original unconditionally saved here, writing an
# empty trailing file whenever the key count divided evenly.
if tmp:
    filename = str(file_count) + "-restored.pt"
    torch.save(tmp, os.path.join(new_path, filename))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user