mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-29 12:52:44 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Deque, Hashable, List, Tuple
|
||||
|
||||
import torch
|
||||
from typing import List, Deque, Tuple, Hashable, Any
|
||||
from energonai import BatchManager, SubmitEntry, TaskEntry
|
||||
|
||||
|
||||
@@ -10,15 +11,15 @@ class BatchManagerForGeneration(BatchManager):
|
||||
self.pad_token_id = pad_token_id
|
||||
|
||||
def _left_padding(self, batch_inputs):
|
||||
max_len = max(len(inputs['input_ids']) for inputs in batch_inputs)
|
||||
outputs = {'input_ids': [], 'attention_mask': []}
|
||||
max_len = max(len(inputs["input_ids"]) for inputs in batch_inputs)
|
||||
outputs = {"input_ids": [], "attention_mask": []}
|
||||
for inputs in batch_inputs:
|
||||
input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
|
||||
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
|
||||
padding_len = max_len - len(input_ids)
|
||||
input_ids = [self.pad_token_id] * padding_len + input_ids
|
||||
attention_mask = [0] * padding_len + attention_mask
|
||||
outputs['input_ids'].append(input_ids)
|
||||
outputs['attention_mask'].append(attention_mask)
|
||||
outputs["input_ids"].append(input_ids)
|
||||
outputs["attention_mask"].append(attention_mask)
|
||||
for k in outputs:
|
||||
outputs[k] = torch.tensor(outputs[k])
|
||||
return outputs, max_len
|
||||
@@ -26,7 +27,7 @@ class BatchManagerForGeneration(BatchManager):
|
||||
@staticmethod
|
||||
def _make_batch_key(entry: SubmitEntry) -> tuple:
|
||||
data = entry.data
|
||||
return (data['top_k'], data['top_p'], data['temperature'])
|
||||
return (data["top_k"], data["top_p"], data["temperature"])
|
||||
|
||||
def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]:
|
||||
entry = q.popleft()
|
||||
@@ -37,7 +38,7 @@ class BatchManagerForGeneration(BatchManager):
|
||||
break
|
||||
if self._make_batch_key(entry) != self._make_batch_key(q[0]):
|
||||
break
|
||||
if q[0].data['max_tokens'] > entry.data['max_tokens']:
|
||||
if q[0].data["max_tokens"] > entry.data["max_tokens"]:
|
||||
break
|
||||
e = q.popleft()
|
||||
batch.append(e.data)
|
||||
@@ -45,12 +46,12 @@ class BatchManagerForGeneration(BatchManager):
|
||||
inputs, max_len = self._left_padding(batch)
|
||||
trunc_lens = []
|
||||
for data in batch:
|
||||
trunc_lens.append(max_len + data['max_tokens'])
|
||||
inputs['top_k'] = entry.data['top_k']
|
||||
inputs['top_p'] = entry.data['top_p']
|
||||
inputs['temperature'] = entry.data['temperature']
|
||||
inputs['max_tokens'] = max_len + entry.data['max_tokens']
|
||||
return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens}
|
||||
trunc_lens.append(max_len + data["max_tokens"])
|
||||
inputs["top_k"] = entry.data["top_k"]
|
||||
inputs["top_p"] = entry.data["top_p"]
|
||||
inputs["temperature"] = entry.data["temperature"]
|
||||
inputs["max_tokens"] = max_len + entry.data["max_tokens"]
|
||||
return TaskEntry(tuple(uids), inputs), {"trunc_lens": trunc_lens}
|
||||
|
||||
def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]:
|
||||
retval = []
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
from locust import HttpUser, task
|
||||
from json import JSONDecodeError
|
||||
|
||||
|
||||
class GenerationUser(HttpUser):
|
||||
@task
|
||||
def generate(self):
|
||||
prompt = 'Question: What is the longest river on the earth? Answer:'
|
||||
prompt = "Question: What is the longest river on the earth? Answer:"
|
||||
for i in range(4, 9):
|
||||
data = {'max_tokens': 2**i, 'prompt': prompt}
|
||||
with self.client.post('/generation', json=data, catch_response=True) as response:
|
||||
data = {"max_tokens": 2**i, "prompt": prompt}
|
||||
with self.client.post("/generation", json=data, catch_response=True) as response:
|
||||
if response.status_code in (200, 406):
|
||||
response.success()
|
||||
else:
|
||||
response.failure('Response wrong')
|
||||
response.failure("Response wrong")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from collections import OrderedDict
|
||||
from threading import Lock
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Any, Hashable, Dict
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, Hashable, List
|
||||
|
||||
|
||||
class MissCacheError(Exception):
|
||||
|
||||
@@ -4,20 +4,21 @@ import random
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from batch import BatchManagerForGeneration
|
||||
from cache import ListCache, MissCacheError
|
||||
from energonai import QueueFullError, launch_engine
|
||||
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import GPT2Tokenizer
|
||||
|
||||
from batch import BatchManagerForGeneration
|
||||
from cache import ListCache, MissCacheError
|
||||
|
||||
|
||||
class GenerationTaskReq(BaseModel):
|
||||
max_tokens: int = Field(gt=0, le=256, example=64)
|
||||
prompt: str = Field(
|
||||
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
|
||||
min_length=1,
|
||||
example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:",
|
||||
)
|
||||
top_k: Optional[int] = Field(default=None, gt=0, example=50)
|
||||
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
|
||||
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
|
||||
@@ -26,7 +27,7 @@ class GenerationTaskReq(BaseModel):
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.post('/generation')
|
||||
@app.post("/generation")
|
||||
async def generate(data: GenerationTaskReq, request: Request):
|
||||
logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}')
|
||||
key = (data.prompt, data.max_tokens)
|
||||
@@ -35,13 +36,13 @@ async def generate(data: GenerationTaskReq, request: Request):
|
||||
raise MissCacheError()
|
||||
outputs = cache.get(key)
|
||||
output = random.choice(outputs)
|
||||
logger.info('Cache hit')
|
||||
logger.info("Cache hit")
|
||||
except MissCacheError:
|
||||
inputs = tokenizer(data.prompt, truncation=True, max_length=512)
|
||||
inputs['max_tokens'] = data.max_tokens
|
||||
inputs['top_k'] = data.top_k
|
||||
inputs['top_p'] = data.top_p
|
||||
inputs['temperature'] = data.temperature
|
||||
inputs["max_tokens"] = data.max_tokens
|
||||
inputs["top_k"] = data.top_k
|
||||
inputs["top_p"] = data.top_p
|
||||
inputs["temperature"] = data.temperature
|
||||
try:
|
||||
uid = id(data)
|
||||
engine.submit(uid, inputs)
|
||||
@@ -52,7 +53,7 @@ async def generate(data: GenerationTaskReq, request: Request):
|
||||
except QueueFullError as e:
|
||||
raise HTTPException(status_code=406, detail=e.args[0])
|
||||
|
||||
return {'text': output}
|
||||
return {"text": output}
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
@@ -64,60 +65,72 @@ async def shutdown(*_):
|
||||
|
||||
|
||||
def get_model_fn(model_name: str):
|
||||
model_map = {
|
||||
'opt-125m': opt_125M,
|
||||
'opt-6.7b': opt_6B,
|
||||
'opt-30b': opt_30B,
|
||||
'opt-175b': opt_175B
|
||||
}
|
||||
model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B}
|
||||
return model_map[model_name]
|
||||
|
||||
|
||||
def print_args(args: argparse.Namespace):
|
||||
print('\n==> Args:')
|
||||
print("\n==> Args:")
|
||||
for k, v in args.__dict__.items():
|
||||
print(f'{k} = {v}')
|
||||
print(f"{k} = {v}")
|
||||
|
||||
|
||||
FIXED_CACHE_KEYS = [
|
||||
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
|
||||
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
|
||||
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
|
||||
(
|
||||
"Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:",
|
||||
64,
|
||||
),
|
||||
(
|
||||
"A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.",
|
||||
64,
|
||||
),
|
||||
(
|
||||
"English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:",
|
||||
64,
|
||||
),
|
||||
]
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
|
||||
parser.add_argument('--tp', type=int, default=1)
|
||||
parser.add_argument('--master_host', default='localhost')
|
||||
parser.add_argument('--master_port', type=int, default=19990)
|
||||
parser.add_argument('--rpc_port', type=int, default=19980)
|
||||
parser.add_argument('--max_batch_size', type=int, default=8)
|
||||
parser.add_argument('--pipe_size', type=int, default=1)
|
||||
parser.add_argument('--queue_size', type=int, default=0)
|
||||
parser.add_argument('--http_host', default='0.0.0.0')
|
||||
parser.add_argument('--http_port', type=int, default=7070)
|
||||
parser.add_argument('--checkpoint', default=None)
|
||||
parser.add_argument('--cache_size', type=int, default=0)
|
||||
parser.add_argument('--cache_list_size', type=int, default=1)
|
||||
parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"])
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--master_host", default="localhost")
|
||||
parser.add_argument("--master_port", type=int, default=19990)
|
||||
parser.add_argument("--rpc_port", type=int, default=19980)
|
||||
parser.add_argument("--max_batch_size", type=int, default=8)
|
||||
parser.add_argument("--pipe_size", type=int, default=1)
|
||||
parser.add_argument("--queue_size", type=int, default=0)
|
||||
parser.add_argument("--http_host", default="0.0.0.0")
|
||||
parser.add_argument("--http_port", type=int, default=7070)
|
||||
parser.add_argument("--checkpoint", default=None)
|
||||
parser.add_argument("--cache_size", type=int, default=0)
|
||||
parser.add_argument("--cache_list_size", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
print_args(args)
|
||||
model_kwargs = {}
|
||||
if args.checkpoint is not None:
|
||||
model_kwargs['checkpoint'] = args.checkpoint
|
||||
model_kwargs["checkpoint"] = args.checkpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b")
|
||||
if args.cache_size > 0:
|
||||
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
|
||||
else:
|
||||
cache = None
|
||||
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
|
||||
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
|
||||
pad_token_id=tokenizer.pad_token_id),
|
||||
pipe_size=args.pipe_size,
|
||||
queue_size=args.queue_size,
|
||||
**model_kwargs)
|
||||
engine = launch_engine(
|
||||
args.tp,
|
||||
1,
|
||||
args.master_host,
|
||||
args.master_port,
|
||||
args.rpc_port,
|
||||
get_model_fn(args.model),
|
||||
batch_manager=BatchManagerForGeneration(
|
||||
max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id
|
||||
),
|
||||
pipe_size=args.pipe_size,
|
||||
queue_size=args.queue_size,
|
||||
**model_kwargs,
|
||||
)
|
||||
config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
|
||||
server = uvicorn.Server(config=config)
|
||||
server.run()
|
||||
|
||||
@@ -1,33 +1,36 @@
|
||||
import logging
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
from torch import Tensor
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B
|
||||
from transformers import GPT2Tokenizer
|
||||
from energonai import launch_engine, QueueFullError
|
||||
|
||||
from batch import BatchManagerForGeneration
|
||||
from cache import ListCache, MissCacheError
|
||||
from energonai import QueueFullError, launch_engine
|
||||
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
|
||||
from pydantic import BaseModel, Field
|
||||
from sanic import Sanic
|
||||
from sanic.request import Request
|
||||
from sanic.response import json
|
||||
from sanic_ext import validate, openapi
|
||||
from batch import BatchManagerForGeneration
|
||||
from cache import ListCache, MissCacheError
|
||||
from sanic_ext import openapi, validate
|
||||
from torch import Tensor
|
||||
from transformers import GPT2Tokenizer
|
||||
|
||||
|
||||
class GenerationTaskReq(BaseModel):
|
||||
max_tokens: int = Field(gt=0, le=256, example=64)
|
||||
prompt: str = Field(
|
||||
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
|
||||
min_length=1,
|
||||
example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:",
|
||||
)
|
||||
top_k: Optional[int] = Field(default=None, gt=0, example=50)
|
||||
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
|
||||
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
|
||||
|
||||
|
||||
app = Sanic('opt')
|
||||
app = Sanic("opt")
|
||||
|
||||
|
||||
@app.post('/generation')
|
||||
@app.post("/generation")
|
||||
@openapi.body(GenerationTaskReq)
|
||||
@validate(json=GenerationTaskReq)
|
||||
async def generate(request: Request, body: GenerationTaskReq):
|
||||
@@ -38,13 +41,13 @@ async def generate(request: Request, body: GenerationTaskReq):
|
||||
raise MissCacheError()
|
||||
outputs = cache.get(key)
|
||||
output = random.choice(outputs)
|
||||
logger.info('Cache hit')
|
||||
logger.info("Cache hit")
|
||||
except MissCacheError:
|
||||
inputs = tokenizer(body.prompt, truncation=True, max_length=512)
|
||||
inputs['max_tokens'] = body.max_tokens
|
||||
inputs['top_k'] = body.top_k
|
||||
inputs['top_p'] = body.top_p
|
||||
inputs['temperature'] = body.temperature
|
||||
inputs["max_tokens"] = body.max_tokens
|
||||
inputs["top_k"] = body.top_k
|
||||
inputs["top_p"] = body.top_p
|
||||
inputs["temperature"] = body.temperature
|
||||
try:
|
||||
uid = id(body)
|
||||
engine.submit(uid, inputs)
|
||||
@@ -54,9 +57,9 @@ async def generate(request: Request, body: GenerationTaskReq):
|
||||
if cache is not None:
|
||||
cache.add(key, output)
|
||||
except QueueFullError as e:
|
||||
return json({'detail': e.args[0]}, status=406)
|
||||
return json({"detail": e.args[0]}, status=406)
|
||||
|
||||
return json({'text': output})
|
||||
return json({"text": output})
|
||||
|
||||
|
||||
@app.after_server_stop
|
||||
@@ -65,58 +68,70 @@ def shutdown(*_):
|
||||
|
||||
|
||||
def get_model_fn(model_name: str):
|
||||
model_map = {
|
||||
'opt-125m': opt_125M,
|
||||
'opt-6.7b': opt_6B,
|
||||
'opt-30b': opt_30B,
|
||||
'opt-175b': opt_175B
|
||||
}
|
||||
model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B}
|
||||
return model_map[model_name]
|
||||
|
||||
|
||||
def print_args(args: argparse.Namespace):
|
||||
print('\n==> Args:')
|
||||
print("\n==> Args:")
|
||||
for k, v in args.__dict__.items():
|
||||
print(f'{k} = {v}')
|
||||
print(f"{k} = {v}")
|
||||
|
||||
|
||||
FIXED_CACHE_KEYS = [
|
||||
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
|
||||
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
|
||||
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
|
||||
(
|
||||
"Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:",
|
||||
64,
|
||||
),
|
||||
(
|
||||
"A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.",
|
||||
64,
|
||||
),
|
||||
(
|
||||
"English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:",
|
||||
64,
|
||||
),
|
||||
]
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
|
||||
parser.add_argument('--tp', type=int, default=1)
|
||||
parser.add_argument('--master_host', default='localhost')
|
||||
parser.add_argument('--master_port', type=int, default=19990)
|
||||
parser.add_argument('--rpc_port', type=int, default=19980)
|
||||
parser.add_argument('--max_batch_size', type=int, default=8)
|
||||
parser.add_argument('--pipe_size', type=int, default=1)
|
||||
parser.add_argument('--queue_size', type=int, default=0)
|
||||
parser.add_argument('--http_host', default='0.0.0.0')
|
||||
parser.add_argument('--http_port', type=int, default=7070)
|
||||
parser.add_argument('--checkpoint', default=None)
|
||||
parser.add_argument('--cache_size', type=int, default=0)
|
||||
parser.add_argument('--cache_list_size', type=int, default=1)
|
||||
parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"])
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--master_host", default="localhost")
|
||||
parser.add_argument("--master_port", type=int, default=19990)
|
||||
parser.add_argument("--rpc_port", type=int, default=19980)
|
||||
parser.add_argument("--max_batch_size", type=int, default=8)
|
||||
parser.add_argument("--pipe_size", type=int, default=1)
|
||||
parser.add_argument("--queue_size", type=int, default=0)
|
||||
parser.add_argument("--http_host", default="0.0.0.0")
|
||||
parser.add_argument("--http_port", type=int, default=7070)
|
||||
parser.add_argument("--checkpoint", default=None)
|
||||
parser.add_argument("--cache_size", type=int, default=0)
|
||||
parser.add_argument("--cache_list_size", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
print_args(args)
|
||||
model_kwargs = {}
|
||||
if args.checkpoint is not None:
|
||||
model_kwargs['checkpoint'] = args.checkpoint
|
||||
model_kwargs["checkpoint"] = args.checkpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b")
|
||||
if args.cache_size > 0:
|
||||
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
|
||||
else:
|
||||
cache = None
|
||||
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
|
||||
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
|
||||
pad_token_id=tokenizer.pad_token_id),
|
||||
pipe_size=args.pipe_size,
|
||||
queue_size=args.queue_size,
|
||||
**model_kwargs)
|
||||
engine = launch_engine(
|
||||
args.tp,
|
||||
1,
|
||||
args.master_host,
|
||||
args.master_port,
|
||||
args.rpc_port,
|
||||
get_model_fn(args.model),
|
||||
batch_manager=BatchManagerForGeneration(
|
||||
max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id
|
||||
),
|
||||
pipe_size=args.pipe_size,
|
||||
queue_size=args.queue_size,
|
||||
**model_kwargs,
|
||||
)
|
||||
app.run(args.http_host, args.http_port)
|
||||
|
||||
@@ -43,4 +43,3 @@ Finally, you will get 8 files in `<new_output_dir>` with following checksums:
|
||||
5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt
|
||||
f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt
|
||||
```
|
||||
|
||||
|
||||
@@ -14,42 +14,45 @@ def load_json(path: str):
|
||||
|
||||
|
||||
def parse_shape_info(flat_dir: str):
|
||||
data = load_json(os.path.join(flat_dir, 'shape.json'))
|
||||
data = load_json(os.path.join(flat_dir, "shape.json"))
|
||||
flat_info = defaultdict(lambda: defaultdict(list))
|
||||
for k, shape in data.items():
|
||||
matched = re.match(r'decoder.layers.\d+', k)
|
||||
matched = re.match(r"decoder.layers.\d+", k)
|
||||
if matched is None:
|
||||
flat_key = 'flat_param_0'
|
||||
flat_key = "flat_param_0"
|
||||
else:
|
||||
flat_key = f'{matched[0]}.flat_param_0'
|
||||
flat_info[flat_key]['names'].append(k)
|
||||
flat_info[flat_key]['shapes'].append(shape)
|
||||
flat_info[flat_key]['numels'].append(int(np.prod(shape)))
|
||||
flat_key = f"{matched[0]}.flat_param_0"
|
||||
flat_info[flat_key]["names"].append(k)
|
||||
flat_info[flat_key]["shapes"].append(shape)
|
||||
flat_info[flat_key]["numels"].append(int(np.prod(shape)))
|
||||
return flat_info
|
||||
|
||||
|
||||
def convert(flat_dir: str, output_dir: str, part: int):
|
||||
flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt')
|
||||
output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt')
|
||||
flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json'))
|
||||
flat_path = os.path.join(flat_dir, f"reshard-model_part-{part}-shard0.pt")
|
||||
output_path = os.path.join(output_dir, f"reshard-model_part-{part}.pt")
|
||||
flat_meta = load_json(os.path.join(flat_dir, "flat-meta.json"))
|
||||
flat_sd = torch.load(flat_path)
|
||||
print(f'Loaded flat state dict from {flat_path}')
|
||||
print(f"Loaded flat state dict from {flat_path}")
|
||||
output_sd = {}
|
||||
for flat_key, param_meta in flat_meta.items():
|
||||
flat_param = flat_sd['model'][flat_key]
|
||||
assert sum(param_meta['numels']) == flat_param.numel(
|
||||
flat_param = flat_sd["model"][flat_key]
|
||||
assert (
|
||||
sum(param_meta["numels"]) == flat_param.numel()
|
||||
), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}'
|
||||
for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])):
|
||||
for name, shape, param in zip(
|
||||
param_meta["names"], param_meta["shapes"], flat_param.split(param_meta["numels"])
|
||||
):
|
||||
output_sd[name] = param.view(shape)
|
||||
|
||||
torch.save(output_sd, output_path)
|
||||
print(f'Saved unflat state dict to {output_path}')
|
||||
print(f"Saved unflat state dict to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('flat_dir')
|
||||
parser.add_argument('output_dir')
|
||||
parser.add_argument('part', type=int)
|
||||
parser.add_argument("flat_dir")
|
||||
parser.add_argument("output_dir")
|
||||
parser.add_argument("part", type=int)
|
||||
args = parser.parse_args()
|
||||
convert(args.flat_dir, args.output_dir, args.part)
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,7 +1,8 @@
|
||||
import os
|
||||
import torch
|
||||
from multiprocessing import Pool
|
||||
|
||||
import torch
|
||||
|
||||
# download pytorch model ckpt in https://huggingface.co/facebook/opt-66b/tree/main
|
||||
# you can use whether wget or git lfs
|
||||
|
||||
@@ -20,14 +21,14 @@ with Pool(14) as pool:
|
||||
|
||||
restored = {}
|
||||
for ckpt in ckpts:
|
||||
for k,v in ckpt.items():
|
||||
if(k[0] == 'm'):
|
||||
k = k[6:]
|
||||
if(k == "lm_head.weight"):
|
||||
for k, v in ckpt.items():
|
||||
if k[0] == "m":
|
||||
k = k[6:]
|
||||
if k == "lm_head.weight":
|
||||
k = "head.dense.weight"
|
||||
if(k == "decoder.final_layer_norm.weight"):
|
||||
if k == "decoder.final_layer_norm.weight":
|
||||
k = "decoder.layer_norm.weight"
|
||||
if(k == "decoder.final_layer_norm.bias"):
|
||||
if k == "decoder.final_layer_norm.bias":
|
||||
k = "decoder.layer_norm.bias"
|
||||
restored[k] = v
|
||||
restored["decoder.version"] = "0.0"
|
||||
@@ -37,11 +38,11 @@ split_num = len(restored.keys()) // 60
|
||||
count = 0
|
||||
file_count = 1
|
||||
tmp = {}
|
||||
for k,v in restored.items():
|
||||
for k, v in restored.items():
|
||||
print(k)
|
||||
tmp[k] = v
|
||||
count = count + 1
|
||||
if(count == split_num):
|
||||
count = count + 1
|
||||
if count == split_num:
|
||||
filename = str(file_count) + "-restored.pt"
|
||||
torch.save(tmp, os.path.join(new_path, filename))
|
||||
file_count = file_count + 1
|
||||
@@ -50,6 +51,3 @@ for k,v in restored.items():
|
||||
|
||||
filename = str(file_count) + "-restored.pt"
|
||||
torch.save(tmp, os.path.join(new_path, filename))
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user