mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 13:59:08 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Deque, Hashable, List, Tuple
|
||||
|
||||
import torch
|
||||
from typing import List, Deque, Tuple, Hashable, Any
|
||||
from energonai import BatchManager, SubmitEntry, TaskEntry
|
||||
|
||||
|
||||
@@ -10,15 +11,15 @@ class BatchManagerForGeneration(BatchManager):
|
||||
self.pad_token_id = pad_token_id
|
||||
|
||||
def _left_padding(self, batch_inputs):
|
||||
max_len = max(len(inputs['input_ids']) for inputs in batch_inputs)
|
||||
outputs = {'input_ids': [], 'attention_mask': []}
|
||||
max_len = max(len(inputs["input_ids"]) for inputs in batch_inputs)
|
||||
outputs = {"input_ids": [], "attention_mask": []}
|
||||
for inputs in batch_inputs:
|
||||
input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
|
||||
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
|
||||
padding_len = max_len - len(input_ids)
|
||||
input_ids = [self.pad_token_id] * padding_len + input_ids
|
||||
attention_mask = [0] * padding_len + attention_mask
|
||||
outputs['input_ids'].append(input_ids)
|
||||
outputs['attention_mask'].append(attention_mask)
|
||||
outputs["input_ids"].append(input_ids)
|
||||
outputs["attention_mask"].append(attention_mask)
|
||||
for k in outputs:
|
||||
outputs[k] = torch.tensor(outputs[k])
|
||||
return outputs, max_len
|
||||
@@ -26,7 +27,7 @@ class BatchManagerForGeneration(BatchManager):
|
||||
@staticmethod
|
||||
def _make_batch_key(entry: SubmitEntry) -> tuple:
|
||||
data = entry.data
|
||||
return (data['top_k'], data['top_p'], data['temperature'])
|
||||
return (data["top_k"], data["top_p"], data["temperature"])
|
||||
|
||||
def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]:
|
||||
entry = q.popleft()
|
||||
@@ -37,7 +38,7 @@ class BatchManagerForGeneration(BatchManager):
|
||||
break
|
||||
if self._make_batch_key(entry) != self._make_batch_key(q[0]):
|
||||
break
|
||||
if q[0].data['max_tokens'] > entry.data['max_tokens']:
|
||||
if q[0].data["max_tokens"] > entry.data["max_tokens"]:
|
||||
break
|
||||
e = q.popleft()
|
||||
batch.append(e.data)
|
||||
@@ -45,12 +46,12 @@ class BatchManagerForGeneration(BatchManager):
|
||||
inputs, max_len = self._left_padding(batch)
|
||||
trunc_lens = []
|
||||
for data in batch:
|
||||
trunc_lens.append(max_len + data['max_tokens'])
|
||||
inputs['top_k'] = entry.data['top_k']
|
||||
inputs['top_p'] = entry.data['top_p']
|
||||
inputs['temperature'] = entry.data['temperature']
|
||||
inputs['max_tokens'] = max_len + entry.data['max_tokens']
|
||||
return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens}
|
||||
trunc_lens.append(max_len + data["max_tokens"])
|
||||
inputs["top_k"] = entry.data["top_k"]
|
||||
inputs["top_p"] = entry.data["top_p"]
|
||||
inputs["temperature"] = entry.data["temperature"]
|
||||
inputs["max_tokens"] = max_len + entry.data["max_tokens"]
|
||||
return TaskEntry(tuple(uids), inputs), {"trunc_lens": trunc_lens}
|
||||
|
||||
def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]:
|
||||
retval = []
|
||||
|
@@ -1,15 +1,14 @@
|
||||
from locust import HttpUser, task
|
||||
from json import JSONDecodeError
|
||||
|
||||
|
||||
class GenerationUser(HttpUser):
|
||||
@task
|
||||
def generate(self):
|
||||
prompt = 'Question: What is the longest river on the earth? Answer:'
|
||||
prompt = "Question: What is the longest river on the earth? Answer:"
|
||||
for i in range(4, 9):
|
||||
data = {'max_tokens': 2**i, 'prompt': prompt}
|
||||
with self.client.post('/generation', json=data, catch_response=True) as response:
|
||||
data = {"max_tokens": 2**i, "prompt": prompt}
|
||||
with self.client.post("/generation", json=data, catch_response=True) as response:
|
||||
if response.status_code in (200, 406):
|
||||
response.success()
|
||||
else:
|
||||
response.failure('Response wrong')
|
||||
response.failure("Response wrong")
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from collections import OrderedDict
|
||||
from threading import Lock
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Any, Hashable, Dict
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, Hashable, List
|
||||
|
||||
|
||||
class MissCacheError(Exception):
|
||||
|
@@ -4,20 +4,21 @@ import random
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from batch import BatchManagerForGeneration
|
||||
from cache import ListCache, MissCacheError
|
||||
from energonai import QueueFullError, launch_engine
|
||||
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import GPT2Tokenizer
|
||||
|
||||
from batch import BatchManagerForGeneration
|
||||
from cache import ListCache, MissCacheError
|
||||
|
||||
|
||||
class GenerationTaskReq(BaseModel):
|
||||
max_tokens: int = Field(gt=0, le=256, example=64)
|
||||
prompt: str = Field(
|
||||
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
|
||||
min_length=1,
|
||||
example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:",
|
||||
)
|
||||
top_k: Optional[int] = Field(default=None, gt=0, example=50)
|
||||
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
|
||||
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
|
||||
@@ -26,7 +27,7 @@ class GenerationTaskReq(BaseModel):
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.post('/generation')
|
||||
@app.post("/generation")
|
||||
async def generate(data: GenerationTaskReq, request: Request):
|
||||
logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}')
|
||||
key = (data.prompt, data.max_tokens)
|
||||
@@ -35,13 +36,13 @@ async def generate(data: GenerationTaskReq, request: Request):
|
||||
raise MissCacheError()
|
||||
outputs = cache.get(key)
|
||||
output = random.choice(outputs)
|
||||
logger.info('Cache hit')
|
||||
logger.info("Cache hit")
|
||||
except MissCacheError:
|
||||
inputs = tokenizer(data.prompt, truncation=True, max_length=512)
|
||||
inputs['max_tokens'] = data.max_tokens
|
||||
inputs['top_k'] = data.top_k
|
||||
inputs['top_p'] = data.top_p
|
||||
inputs['temperature'] = data.temperature
|
||||
inputs["max_tokens"] = data.max_tokens
|
||||
inputs["top_k"] = data.top_k
|
||||
inputs["top_p"] = data.top_p
|
||||
inputs["temperature"] = data.temperature
|
||||
try:
|
||||
uid = id(data)
|
||||
engine.submit(uid, inputs)
|
||||
@@ -52,7 +53,7 @@ async def generate(data: GenerationTaskReq, request: Request):
|
||||
except QueueFullError as e:
|
||||
raise HTTPException(status_code=406, detail=e.args[0])
|
||||
|
||||
return {'text': output}
|
||||
return {"text": output}
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
@@ -64,60 +65,72 @@ async def shutdown(*_):
|
||||
|
||||
|
||||
def get_model_fn(model_name: str):
|
||||
model_map = {
|
||||
'opt-125m': opt_125M,
|
||||
'opt-6.7b': opt_6B,
|
||||
'opt-30b': opt_30B,
|
||||
'opt-175b': opt_175B
|
||||
}
|
||||
model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B}
|
||||
return model_map[model_name]
|
||||
|
||||
|
||||
def print_args(args: argparse.Namespace):
|
||||
print('\n==> Args:')
|
||||
print("\n==> Args:")
|
||||
for k, v in args.__dict__.items():
|
||||
print(f'{k} = {v}')
|
||||
print(f"{k} = {v}")
|
||||
|
||||
|
||||
FIXED_CACHE_KEYS = [
|
||||
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
|
||||
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
|
||||
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
|
||||
(
|
||||
"Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:",
|
||||
64,
|
||||
),
|
||||
(
|
||||
"A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.",
|
||||
64,
|
||||
),
|
||||
(
|
||||
"English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:",
|
||||
64,
|
||||
),
|
||||
]
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
|
||||
parser.add_argument('--tp', type=int, default=1)
|
||||
parser.add_argument('--master_host', default='localhost')
|
||||
parser.add_argument('--master_port', type=int, default=19990)
|
||||
parser.add_argument('--rpc_port', type=int, default=19980)
|
||||
parser.add_argument('--max_batch_size', type=int, default=8)
|
||||
parser.add_argument('--pipe_size', type=int, default=1)
|
||||
parser.add_argument('--queue_size', type=int, default=0)
|
||||
parser.add_argument('--http_host', default='0.0.0.0')
|
||||
parser.add_argument('--http_port', type=int, default=7070)
|
||||
parser.add_argument('--checkpoint', default=None)
|
||||
parser.add_argument('--cache_size', type=int, default=0)
|
||||
parser.add_argument('--cache_list_size', type=int, default=1)
|
||||
parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"])
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--master_host", default="localhost")
|
||||
parser.add_argument("--master_port", type=int, default=19990)
|
||||
parser.add_argument("--rpc_port", type=int, default=19980)
|
||||
parser.add_argument("--max_batch_size", type=int, default=8)
|
||||
parser.add_argument("--pipe_size", type=int, default=1)
|
||||
parser.add_argument("--queue_size", type=int, default=0)
|
||||
parser.add_argument("--http_host", default="0.0.0.0")
|
||||
parser.add_argument("--http_port", type=int, default=7070)
|
||||
parser.add_argument("--checkpoint", default=None)
|
||||
parser.add_argument("--cache_size", type=int, default=0)
|
||||
parser.add_argument("--cache_list_size", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
print_args(args)
|
||||
model_kwargs = {}
|
||||
if args.checkpoint is not None:
|
||||
model_kwargs['checkpoint'] = args.checkpoint
|
||||
model_kwargs["checkpoint"] = args.checkpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b")
|
||||
if args.cache_size > 0:
|
||||
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
|
||||
else:
|
||||
cache = None
|
||||
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
|
||||
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
|
||||
pad_token_id=tokenizer.pad_token_id),
|
||||
pipe_size=args.pipe_size,
|
||||
queue_size=args.queue_size,
|
||||
**model_kwargs)
|
||||
engine = launch_engine(
|
||||
args.tp,
|
||||
1,
|
||||
args.master_host,
|
||||
args.master_port,
|
||||
args.rpc_port,
|
||||
get_model_fn(args.model),
|
||||
batch_manager=BatchManagerForGeneration(
|
||||
max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id
|
||||
),
|
||||
pipe_size=args.pipe_size,
|
||||
queue_size=args.queue_size,
|
||||
**model_kwargs,
|
||||
)
|
||||
config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
|
||||
server = uvicorn.Server(config=config)
|
||||
server.run()
|
||||
|
@@ -1,33 +1,36 @@
|
||||
import logging
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
from torch import Tensor
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B
|
||||
from transformers import GPT2Tokenizer
|
||||
from energonai import launch_engine, QueueFullError
|
||||
|
||||
from batch import BatchManagerForGeneration
|
||||
from cache import ListCache, MissCacheError
|
||||
from energonai import QueueFullError, launch_engine
|
||||
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
|
||||
from pydantic import BaseModel, Field
|
||||
from sanic import Sanic
|
||||
from sanic.request import Request
|
||||
from sanic.response import json
|
||||
from sanic_ext import validate, openapi
|
||||
from batch import BatchManagerForGeneration
|
||||
from cache import ListCache, MissCacheError
|
||||
from sanic_ext import openapi, validate
|
||||
from torch import Tensor
|
||||
from transformers import GPT2Tokenizer
|
||||
|
||||
|
||||
class GenerationTaskReq(BaseModel):
|
||||
max_tokens: int = Field(gt=0, le=256, example=64)
|
||||
prompt: str = Field(
|
||||
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
|
||||
min_length=1,
|
||||
example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:",
|
||||
)
|
||||
top_k: Optional[int] = Field(default=None, gt=0, example=50)
|
||||
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
|
||||
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
|
||||
|
||||
|
||||
app = Sanic('opt')
|
||||
app = Sanic("opt")
|
||||
|
||||
|
||||
@app.post('/generation')
|
||||
@app.post("/generation")
|
||||
@openapi.body(GenerationTaskReq)
|
||||
@validate(json=GenerationTaskReq)
|
||||
async def generate(request: Request, body: GenerationTaskReq):
|
||||
@@ -38,13 +41,13 @@ async def generate(request: Request, body: GenerationTaskReq):
|
||||
raise MissCacheError()
|
||||
outputs = cache.get(key)
|
||||
output = random.choice(outputs)
|
||||
logger.info('Cache hit')
|
||||
logger.info("Cache hit")
|
||||
except MissCacheError:
|
||||
inputs = tokenizer(body.prompt, truncation=True, max_length=512)
|
||||
inputs['max_tokens'] = body.max_tokens
|
||||
inputs['top_k'] = body.top_k
|
||||
inputs['top_p'] = body.top_p
|
||||
inputs['temperature'] = body.temperature
|
||||
inputs["max_tokens"] = body.max_tokens
|
||||
inputs["top_k"] = body.top_k
|
||||
inputs["top_p"] = body.top_p
|
||||
inputs["temperature"] = body.temperature
|
||||
try:
|
||||
uid = id(body)
|
||||
engine.submit(uid, inputs)
|
||||
@@ -54,9 +57,9 @@ async def generate(request: Request, body: GenerationTaskReq):
|
||||
if cache is not None:
|
||||
cache.add(key, output)
|
||||
except QueueFullError as e:
|
||||
return json({'detail': e.args[0]}, status=406)
|
||||
return json({"detail": e.args[0]}, status=406)
|
||||
|
||||
return json({'text': output})
|
||||
return json({"text": output})
|
||||
|
||||
|
||||
@app.after_server_stop
|
||||
@@ -65,58 +68,70 @@ def shutdown(*_):
|
||||
|
||||
|
||||
def get_model_fn(model_name: str):
|
||||
model_map = {
|
||||
'opt-125m': opt_125M,
|
||||
'opt-6.7b': opt_6B,
|
||||
'opt-30b': opt_30B,
|
||||
'opt-175b': opt_175B
|
||||
}
|
||||
model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B}
|
||||
return model_map[model_name]
|
||||
|
||||
|
||||
def print_args(args: argparse.Namespace):
|
||||
print('\n==> Args:')
|
||||
print("\n==> Args:")
|
||||
for k, v in args.__dict__.items():
|
||||
print(f'{k} = {v}')
|
||||
print(f"{k} = {v}")
|
||||
|
||||
|
||||
FIXED_CACHE_KEYS = [
|
||||
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
|
||||
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
|
||||
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
|
||||
(
|
||||
"Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:",
|
||||
64,
|
||||
),
|
||||
(
|
||||
"A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.",
|
||||
64,
|
||||
),
|
||||
(
|
||||
"English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:",
|
||||
64,
|
||||
),
|
||||
]
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
|
||||
parser.add_argument('--tp', type=int, default=1)
|
||||
parser.add_argument('--master_host', default='localhost')
|
||||
parser.add_argument('--master_port', type=int, default=19990)
|
||||
parser.add_argument('--rpc_port', type=int, default=19980)
|
||||
parser.add_argument('--max_batch_size', type=int, default=8)
|
||||
parser.add_argument('--pipe_size', type=int, default=1)
|
||||
parser.add_argument('--queue_size', type=int, default=0)
|
||||
parser.add_argument('--http_host', default='0.0.0.0')
|
||||
parser.add_argument('--http_port', type=int, default=7070)
|
||||
parser.add_argument('--checkpoint', default=None)
|
||||
parser.add_argument('--cache_size', type=int, default=0)
|
||||
parser.add_argument('--cache_list_size', type=int, default=1)
|
||||
parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"])
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--master_host", default="localhost")
|
||||
parser.add_argument("--master_port", type=int, default=19990)
|
||||
parser.add_argument("--rpc_port", type=int, default=19980)
|
||||
parser.add_argument("--max_batch_size", type=int, default=8)
|
||||
parser.add_argument("--pipe_size", type=int, default=1)
|
||||
parser.add_argument("--queue_size", type=int, default=0)
|
||||
parser.add_argument("--http_host", default="0.0.0.0")
|
||||
parser.add_argument("--http_port", type=int, default=7070)
|
||||
parser.add_argument("--checkpoint", default=None)
|
||||
parser.add_argument("--cache_size", type=int, default=0)
|
||||
parser.add_argument("--cache_list_size", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
print_args(args)
|
||||
model_kwargs = {}
|
||||
if args.checkpoint is not None:
|
||||
model_kwargs['checkpoint'] = args.checkpoint
|
||||
model_kwargs["checkpoint"] = args.checkpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b")
|
||||
if args.cache_size > 0:
|
||||
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
|
||||
else:
|
||||
cache = None
|
||||
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
|
||||
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
|
||||
pad_token_id=tokenizer.pad_token_id),
|
||||
pipe_size=args.pipe_size,
|
||||
queue_size=args.queue_size,
|
||||
**model_kwargs)
|
||||
engine = launch_engine(
|
||||
args.tp,
|
||||
1,
|
||||
args.master_host,
|
||||
args.master_port,
|
||||
args.rpc_port,
|
||||
get_model_fn(args.model),
|
||||
batch_manager=BatchManagerForGeneration(
|
||||
max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id
|
||||
),
|
||||
pipe_size=args.pipe_size,
|
||||
queue_size=args.queue_size,
|
||||
**model_kwargs,
|
||||
)
|
||||
app.run(args.http_host, args.http_port)
|
||||
|
@@ -43,4 +43,3 @@ Finally, you will get 8 files in `<new_output_dir>` with following checksums:
|
||||
5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt
|
||||
f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt
|
||||
```
|
||||
|
||||
|
@@ -14,42 +14,45 @@ def load_json(path: str):
|
||||
|
||||
|
||||
def parse_shape_info(flat_dir: str):
|
||||
data = load_json(os.path.join(flat_dir, 'shape.json'))
|
||||
data = load_json(os.path.join(flat_dir, "shape.json"))
|
||||
flat_info = defaultdict(lambda: defaultdict(list))
|
||||
for k, shape in data.items():
|
||||
matched = re.match(r'decoder.layers.\d+', k)
|
||||
matched = re.match(r"decoder.layers.\d+", k)
|
||||
if matched is None:
|
||||
flat_key = 'flat_param_0'
|
||||
flat_key = "flat_param_0"
|
||||
else:
|
||||
flat_key = f'{matched[0]}.flat_param_0'
|
||||
flat_info[flat_key]['names'].append(k)
|
||||
flat_info[flat_key]['shapes'].append(shape)
|
||||
flat_info[flat_key]['numels'].append(int(np.prod(shape)))
|
||||
flat_key = f"{matched[0]}.flat_param_0"
|
||||
flat_info[flat_key]["names"].append(k)
|
||||
flat_info[flat_key]["shapes"].append(shape)
|
||||
flat_info[flat_key]["numels"].append(int(np.prod(shape)))
|
||||
return flat_info
|
||||
|
||||
|
||||
def convert(flat_dir: str, output_dir: str, part: int):
|
||||
flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt')
|
||||
output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt')
|
||||
flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json'))
|
||||
flat_path = os.path.join(flat_dir, f"reshard-model_part-{part}-shard0.pt")
|
||||
output_path = os.path.join(output_dir, f"reshard-model_part-{part}.pt")
|
||||
flat_meta = load_json(os.path.join(flat_dir, "flat-meta.json"))
|
||||
flat_sd = torch.load(flat_path)
|
||||
print(f'Loaded flat state dict from {flat_path}')
|
||||
print(f"Loaded flat state dict from {flat_path}")
|
||||
output_sd = {}
|
||||
for flat_key, param_meta in flat_meta.items():
|
||||
flat_param = flat_sd['model'][flat_key]
|
||||
assert sum(param_meta['numels']) == flat_param.numel(
|
||||
flat_param = flat_sd["model"][flat_key]
|
||||
assert (
|
||||
sum(param_meta["numels"]) == flat_param.numel()
|
||||
), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}'
|
||||
for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])):
|
||||
for name, shape, param in zip(
|
||||
param_meta["names"], param_meta["shapes"], flat_param.split(param_meta["numels"])
|
||||
):
|
||||
output_sd[name] = param.view(shape)
|
||||
|
||||
torch.save(output_sd, output_path)
|
||||
print(f'Saved unflat state dict to {output_path}')
|
||||
print(f"Saved unflat state dict to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('flat_dir')
|
||||
parser.add_argument('output_dir')
|
||||
parser.add_argument('part', type=int)
|
||||
parser.add_argument("flat_dir")
|
||||
parser.add_argument("output_dir")
|
||||
parser.add_argument("part", type=int)
|
||||
args = parser.parse_args()
|
||||
convert(args.flat_dir, args.output_dir, args.part)
|
||||
|
File diff suppressed because one or more lines are too long
@@ -1,7 +1,8 @@
|
||||
import os
|
||||
import torch
|
||||
from multiprocessing import Pool
|
||||
|
||||
import torch
|
||||
|
||||
# download pytorch model ckpt in https://huggingface.co/facebook/opt-66b/tree/main
|
||||
# you can use whether wget or git lfs
|
||||
|
||||
@@ -20,14 +21,14 @@ with Pool(14) as pool:
|
||||
|
||||
restored = {}
|
||||
for ckpt in ckpts:
|
||||
for k,v in ckpt.items():
|
||||
if(k[0] == 'm'):
|
||||
k = k[6:]
|
||||
if(k == "lm_head.weight"):
|
||||
for k, v in ckpt.items():
|
||||
if k[0] == "m":
|
||||
k = k[6:]
|
||||
if k == "lm_head.weight":
|
||||
k = "head.dense.weight"
|
||||
if(k == "decoder.final_layer_norm.weight"):
|
||||
if k == "decoder.final_layer_norm.weight":
|
||||
k = "decoder.layer_norm.weight"
|
||||
if(k == "decoder.final_layer_norm.bias"):
|
||||
if k == "decoder.final_layer_norm.bias":
|
||||
k = "decoder.layer_norm.bias"
|
||||
restored[k] = v
|
||||
restored["decoder.version"] = "0.0"
|
||||
@@ -37,11 +38,11 @@ split_num = len(restored.keys()) // 60
|
||||
count = 0
|
||||
file_count = 1
|
||||
tmp = {}
|
||||
for k,v in restored.items():
|
||||
for k, v in restored.items():
|
||||
print(k)
|
||||
tmp[k] = v
|
||||
count = count + 1
|
||||
if(count == split_num):
|
||||
count = count + 1
|
||||
if count == split_num:
|
||||
filename = str(file_count) + "-restored.pt"
|
||||
torch.save(tmp, os.path.join(new_path, filename))
|
||||
file_count = file_count + 1
|
||||
@@ -50,6 +51,3 @@ for k,v in restored.items():
|
||||
|
||||
filename = str(file_count) + "-restored.pt"
|
||||
torch.save(tmp, os.path.join(new_path, filename))
|
||||
|
||||
|
||||
|
||||
|
@@ -4,7 +4,7 @@ except ImportError:
|
||||
# colossalai > 0.2.8
|
||||
from colossalai.legacy.zero import TensorShardStrategy
|
||||
|
||||
zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
|
||||
tensor_placement_policy="auto",
|
||||
reuse_fp16_shard=True),
|
||||
optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384))
|
||||
zero = dict(
|
||||
model_config=dict(shard_strategy=TensorShardStrategy(), tensor_placement_policy="auto", reuse_fp16_shard=True),
|
||||
optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384),
|
||||
)
|
||||
|
@@ -4,7 +4,7 @@ from colossalai.legacy.context import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
|
||||
|
||||
class barrier_context():
|
||||
class barrier_context:
|
||||
"""
|
||||
This context manager is used to allow one process to execute while blocking all
|
||||
other processes in the same process group. This is often useful when downloading is required
|
||||
|
@@ -86,14 +86,12 @@ def parse_args():
|
||||
default=None,
|
||||
help="The configuration name of the dataset to use (via the datasets library).",
|
||||
)
|
||||
parser.add_argument("--train_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A csv or a json file containing the training data.")
|
||||
parser.add_argument("--validation_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A csv or a json file containing the validation data.")
|
||||
parser.add_argument(
|
||||
"--train_file", type=str, default=None, help="A csv or a json file containing the training data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_split_percentage",
|
||||
default=5,
|
||||
@@ -161,10 +159,9 @@ def parse_args():
|
||||
help="The scheduler type to use.",
|
||||
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
|
||||
)
|
||||
parser.add_argument("--num_warmup_steps",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of steps for the warmup in the lr scheduler.")
|
||||
parser.add_argument(
|
||||
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
@@ -178,9 +175,11 @@ def parse_args():
|
||||
"--block_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help=("Optional input sequence length after tokenization. The training dataset will be truncated in block of"
|
||||
" this size for training. Default to the model max input length for single sentence inputs (take into"
|
||||
" account special tokens)."),
|
||||
help=(
|
||||
"Optional input sequence length after tokenization. The training dataset will be truncated in block of"
|
||||
" this size for training. Default to the model max input length for single sentence inputs (take into"
|
||||
" account special tokens)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preprocessing_num_workers",
|
||||
@@ -188,17 +187,16 @@ def parse_args():
|
||||
default=None,
|
||||
help="The number of processes to use for the preprocessing.",
|
||||
)
|
||||
parser.add_argument("--overwrite_cache",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument("--no_keep_linebreaks",
|
||||
action="store_true",
|
||||
help="Do not keep line breaks when using TXT files.")
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files."
|
||||
)
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_model_id",
|
||||
type=str,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
||||
)
|
||||
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
@@ -221,13 +219,15 @@ def parse_args():
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="all",
|
||||
help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
|
||||
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
|
||||
"Only applicable when `--with_tracking` is passed."),
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
|
||||
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
|
||||
"Only applicable when `--with_tracking` is passed."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
|
||||
parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
|
||||
parser.add_argument("--init_in_cpu", action="store_true", default=False, help="init training model in cpu")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
@@ -250,6 +250,7 @@ def parse_args():
|
||||
|
||||
def colo_memory_cap(size_in_GB):
|
||||
from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
|
||||
|
||||
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
||||
if size_in_GB * (1024**3) < cuda_capacity:
|
||||
colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
|
||||
@@ -257,7 +258,6 @@ def colo_memory_cap(size_in_GB):
|
||||
|
||||
|
||||
class DummyDataloader:
|
||||
|
||||
def __init__(self, length, batch_size, seq_len, vocab_size):
|
||||
self.length = length
|
||||
self.batch_size = batch_size
|
||||
@@ -380,32 +380,36 @@ def main():
|
||||
logger.warning("You are instantiating a new config instance from scratch.")
|
||||
logger.info("Model config has been created", ranks=[0])
|
||||
|
||||
if args.model_name_or_path == 'facebook/opt-13b':
|
||||
if args.model_name_or_path == "facebook/opt-13b":
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
|
||||
else:
|
||||
print(f'load model from {args.model_name_or_path}')
|
||||
print(f"load model from {args.model_name_or_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
|
||||
logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0])
|
||||
|
||||
if args.init_in_cpu:
|
||||
init_dev = torch.device('cpu')
|
||||
init_dev = torch.device("cpu")
|
||||
else:
|
||||
init_dev = get_current_device()
|
||||
|
||||
cai_version = colossalai.__version__
|
||||
logger.info(f'using Colossal-AI version {cai_version}')
|
||||
logger.info(f"using Colossal-AI version {cai_version}")
|
||||
# build model
|
||||
if version.parse(cai_version) >= version.parse("0.3.1"):
|
||||
from contextlib import nullcontext
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
ctx = LazyInitContext(
|
||||
default_device=init_dev
|
||||
) if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b' else nullcontext()
|
||||
|
||||
ctx = (
|
||||
LazyInitContext(default_device=init_dev)
|
||||
if args.model_name_or_path is None or args.model_name_or_path == "facebook/opt-13b"
|
||||
else nullcontext()
|
||||
)
|
||||
else:
|
||||
from colossalai.zero import ColoInitContext
|
||||
|
||||
ctx = ColoInitContext(device=init_dev)
|
||||
if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
|
||||
if args.model_name_or_path is None or args.model_name_or_path == "facebook/opt-13b":
|
||||
# currently, there is a bug in pretrained opt-13b
|
||||
# we cannot import it until huggingface fixes it
|
||||
logger.info("Train a new model from scratch", ranks=[0])
|
||||
@@ -414,17 +418,20 @@ def main():
|
||||
else:
|
||||
logger.info("Finetune a pre-trained model", ranks=[0])
|
||||
with ctx:
|
||||
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
local_files_only=False)
|
||||
model = OPTForCausalLM.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
local_files_only=False,
|
||||
)
|
||||
|
||||
# enable gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
PLACEMENT_POLICY = 'auto'
|
||||
PLACEMENT_POLICY = "auto"
|
||||
if version.parse(cai_version) >= version.parse("0.3.1"):
|
||||
from colossalai.zero import GeminiDDP
|
||||
|
||||
model = GeminiDDP(model, offload_optim_frac=1.0, pin_memory=True)
|
||||
elif version.parse(cai_version) > version.parse("0.1.10"):
|
||||
try:
|
||||
@@ -435,16 +442,19 @@ def main():
|
||||
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
|
||||
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
|
||||
from colossalai.gemini import ChunkManager, GeminiManager
|
||||
|
||||
pg = ProcessGroup()
|
||||
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
|
||||
chunk_manager = ChunkManager(chunk_size,
|
||||
pg,
|
||||
enable_distributed_storage=True,
|
||||
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
|
||||
chunk_manager = ChunkManager(
|
||||
chunk_size,
|
||||
pg,
|
||||
enable_distributed_storage=True,
|
||||
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY),
|
||||
)
|
||||
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager)
|
||||
|
||||
logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
|
||||
logger.info(f"{model.__class__.__name__} has been created", ranks=[0])
|
||||
|
||||
if not args.synthetic:
|
||||
# Preprocessing the datasets.
|
||||
@@ -470,12 +480,15 @@ def main():
|
||||
if block_size > 1024:
|
||||
logger.warning(
|
||||
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
|
||||
"Picking 1024 instead. You can change that default value by passing --block_size xxx.")
|
||||
"Picking 1024 instead. You can change that default value by passing --block_size xxx."
|
||||
)
|
||||
block_size = 1024
|
||||
else:
|
||||
if args.block_size > tokenizer.model_max_length:
|
||||
logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model"
|
||||
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.")
|
||||
logger.warning(
|
||||
f"The block_size passed ({args.block_size}) is larger than the maximum length for the model"
|
||||
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
|
||||
)
|
||||
block_size = min(args.block_size, tokenizer.model_max_length)
|
||||
|
||||
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
|
||||
@@ -489,8 +502,8 @@ def main():
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# Split by chunks of max_len.
|
||||
result = {
|
||||
k: [t[i:i + block_size] for i in range(0, total_length, block_size)
|
||||
] for k, t in concatenated_examples.items()
|
||||
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
@@ -520,19 +533,23 @@ def main():
|
||||
# logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||
|
||||
# DataLoaders creation:
|
||||
train_dataloader = get_dataloader(train_dataset,
|
||||
shuffle=True,
|
||||
add_sampler=True,
|
||||
collate_fn=default_data_collator,
|
||||
batch_size=args.per_device_train_batch_size)
|
||||
eval_dataloader = DataLoader(eval_dataset,
|
||||
collate_fn=default_data_collator,
|
||||
batch_size=args.per_device_eval_batch_size)
|
||||
train_dataloader = get_dataloader(
|
||||
train_dataset,
|
||||
shuffle=True,
|
||||
add_sampler=True,
|
||||
collate_fn=default_data_collator,
|
||||
batch_size=args.per_device_train_batch_size,
|
||||
)
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size
|
||||
)
|
||||
else:
|
||||
train_dataloader = DummyDataloader(30, args.per_device_train_batch_size, config.max_position_embeddings,
|
||||
config.vocab_size)
|
||||
eval_dataloader = DummyDataloader(10, args.per_device_train_batch_size, config.max_position_embeddings,
|
||||
config.vocab_size)
|
||||
train_dataloader = DummyDataloader(
|
||||
30, args.per_device_train_batch_size, config.max_position_embeddings, config.vocab_size
|
||||
)
|
||||
eval_dataloader = DummyDataloader(
|
||||
10, args.per_device_train_batch_size, config.max_position_embeddings, config.vocab_size
|
||||
)
|
||||
logger.info("Dataloaders have been created", ranks=[0])
|
||||
|
||||
# Optimizer
|
||||
@@ -593,7 +610,6 @@ def main():
|
||||
global_step = 0
|
||||
|
||||
for epoch in range(starting_epoch, args.num_train_epochs):
|
||||
|
||||
if completed_steps >= args.max_train_steps:
|
||||
break
|
||||
|
||||
@@ -601,7 +617,7 @@ def main():
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
batch = {k: v.cuda() for k, v in batch.items()}
|
||||
outputs = model(use_cache=False, **batch)
|
||||
loss = outputs['loss']
|
||||
loss = outputs["loss"]
|
||||
optimizer.backward(loss)
|
||||
|
||||
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
||||
@@ -624,7 +640,7 @@ def main():
|
||||
batch = {k: v.cuda() for k, v in batch.items()}
|
||||
outputs = model(**batch)
|
||||
|
||||
loss = outputs['loss'].unsqueeze(0)
|
||||
loss = outputs["loss"].unsqueeze(0)
|
||||
losses.append(loss)
|
||||
|
||||
losses = torch.cat(losses)
|
||||
@@ -640,7 +656,7 @@ def main():
|
||||
if args.output_dir is not None:
|
||||
model_state = model.state_dict()
|
||||
if is_main_process:
|
||||
torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps))
|
||||
torch.save(model_state, args.output_dir + "/epoch_{}_model.pth".format(completed_steps))
|
||||
dist.barrier()
|
||||
# load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps))
|
||||
# model.load_state_dict(load_state, strict=False)
|
||||
|
Reference in New Issue
Block a user