[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26

1268 changed files with 50037 additions and 38444 deletions


@@ -1,5 +1,6 @@
from typing import Any, Deque, Hashable, List, Tuple
import torch
from typing import List, Deque, Tuple, Hashable, Any
from energonai import BatchManager, SubmitEntry, TaskEntry
@@ -10,15 +11,15 @@ class BatchManagerForGeneration(BatchManager):
self.pad_token_id = pad_token_id
def _left_padding(self, batch_inputs):
max_len = max(len(inputs['input_ids']) for inputs in batch_inputs)
outputs = {'input_ids': [], 'attention_mask': []}
max_len = max(len(inputs["input_ids"]) for inputs in batch_inputs)
outputs = {"input_ids": [], "attention_mask": []}
for inputs in batch_inputs:
input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
padding_len = max_len - len(input_ids)
input_ids = [self.pad_token_id] * padding_len + input_ids
attention_mask = [0] * padding_len + attention_mask
outputs['input_ids'].append(input_ids)
outputs['attention_mask'].append(attention_mask)
outputs["input_ids"].append(input_ids)
outputs["attention_mask"].append(attention_mask)
for k in outputs:
outputs[k] = torch.tensor(outputs[k])
return outputs, max_len
@@ -26,7 +27,7 @@ class BatchManagerForGeneration(BatchManager):
@staticmethod
def _make_batch_key(entry: SubmitEntry) -> tuple:
data = entry.data
return (data['top_k'], data['top_p'], data['temperature'])
return (data["top_k"], data["top_p"], data["temperature"])
def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]:
entry = q.popleft()
@@ -37,7 +38,7 @@ class BatchManagerForGeneration(BatchManager):
break
if self._make_batch_key(entry) != self._make_batch_key(q[0]):
break
if q[0].data['max_tokens'] > entry.data['max_tokens']:
if q[0].data["max_tokens"] > entry.data["max_tokens"]:
break
e = q.popleft()
batch.append(e.data)
@@ -45,12 +46,12 @@ class BatchManagerForGeneration(BatchManager):
inputs, max_len = self._left_padding(batch)
trunc_lens = []
for data in batch:
trunc_lens.append(max_len + data['max_tokens'])
inputs['top_k'] = entry.data['top_k']
inputs['top_p'] = entry.data['top_p']
inputs['temperature'] = entry.data['temperature']
inputs['max_tokens'] = max_len + entry.data['max_tokens']
return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens}
trunc_lens.append(max_len + data["max_tokens"])
inputs["top_k"] = entry.data["top_k"]
inputs["top_p"] = entry.data["top_p"]
inputs["temperature"] = entry.data["temperature"]
inputs["max_tokens"] = max_len + entry.data["max_tokens"]
return TaskEntry(tuple(uids), inputs), {"trunc_lens": trunc_lens}
def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]:
retval = []
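
For reference, here is a minimal standalone sketch of the left-padding scheme that `_left_padding` implements in the hunk above. The example batch and the `pad_token_id` value are assumptions for illustration, not taken from the diff.

```python
# Illustrative sketch: left-pad a batch of token id lists to a common length,
# mirroring BatchManagerForGeneration._left_padding above.
# pad_token_id=1 is an assumption, not a value taken from the diff.
import torch


def left_pad(batch_inputs, pad_token_id=1):
    max_len = max(len(x["input_ids"]) for x in batch_inputs)
    outputs = {"input_ids": [], "attention_mask": []}
    for x in batch_inputs:
        padding_len = max_len - len(x["input_ids"])
        # pad on the left so every prompt ends at the same position before generation
        outputs["input_ids"].append([pad_token_id] * padding_len + x["input_ids"])
        outputs["attention_mask"].append([0] * padding_len + x["attention_mask"])
    return {k: torch.tensor(v) for k, v in outputs.items()}, max_len


batch = [
    {"input_ids": [5, 6], "attention_mask": [1, 1]},
    {"input_ids": [7, 8, 9], "attention_mask": [1, 1, 1]},
]
padded, max_len = left_pad(batch)
# padded["input_ids"]      -> tensor([[1, 5, 6], [7, 8, 9]])
# padded["attention_mask"] -> tensor([[0, 1, 1], [1, 1, 1]])
```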


@@ -1,15 +1,14 @@
from locust import HttpUser, task
from json import JSONDecodeError
class GenerationUser(HttpUser):
@task
def generate(self):
prompt = 'Question: What is the longest river on the earth? Answer:'
prompt = "Question: What is the longest river on the earth? Answer:"
for i in range(4, 9):
data = {'max_tokens': 2**i, 'prompt': prompt}
with self.client.post('/generation', json=data, catch_response=True) as response:
data = {"max_tokens": 2**i, "prompt": prompt}
with self.client.post("/generation", json=data, catch_response=True) as response:
if response.status_code in (200, 406):
response.success()
else:
response.failure('Response wrong')
response.failure("Response wrong")


@@ -1,7 +1,7 @@
from collections import OrderedDict
from threading import Lock
from contextlib import contextmanager
from typing import List, Any, Hashable, Dict
from threading import Lock
from typing import Any, Dict, Hashable, List
class MissCacheError(Exception):


@@ -4,20 +4,21 @@ import random
from typing import Optional
import uvicorn
from batch import BatchManagerForGeneration
from cache import ListCache, MissCacheError
from energonai import QueueFullError, launch_engine
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field
from transformers import GPT2Tokenizer
from batch import BatchManagerForGeneration
from cache import ListCache, MissCacheError
class GenerationTaskReq(BaseModel):
max_tokens: int = Field(gt=0, le=256, example=64)
prompt: str = Field(
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
min_length=1,
example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:",
)
top_k: Optional[int] = Field(default=None, gt=0, example=50)
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
@@ -26,7 +27,7 @@ class GenerationTaskReq(BaseModel):
app = FastAPI()
@app.post('/generation')
@app.post("/generation")
async def generate(data: GenerationTaskReq, request: Request):
logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}')
key = (data.prompt, data.max_tokens)
@@ -35,13 +36,13 @@ async def generate(data: GenerationTaskReq, request: Request):
raise MissCacheError()
outputs = cache.get(key)
output = random.choice(outputs)
logger.info('Cache hit')
logger.info("Cache hit")
except MissCacheError:
inputs = tokenizer(data.prompt, truncation=True, max_length=512)
inputs['max_tokens'] = data.max_tokens
inputs['top_k'] = data.top_k
inputs['top_p'] = data.top_p
inputs['temperature'] = data.temperature
inputs["max_tokens"] = data.max_tokens
inputs["top_k"] = data.top_k
inputs["top_p"] = data.top_p
inputs["temperature"] = data.temperature
try:
uid = id(data)
engine.submit(uid, inputs)
@@ -52,7 +53,7 @@ async def generate(data: GenerationTaskReq, request: Request):
except QueueFullError as e:
raise HTTPException(status_code=406, detail=e.args[0])
return {'text': output}
return {"text": output}
@app.on_event("shutdown")
@@ -64,60 +65,72 @@ async def shutdown(*_):
def get_model_fn(model_name: str):
model_map = {
'opt-125m': opt_125M,
'opt-6.7b': opt_6B,
'opt-30b': opt_30B,
'opt-175b': opt_175B
}
model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B}
return model_map[model_name]
def print_args(args: argparse.Namespace):
print('\n==> Args:')
print("\n==> Args:")
for k, v in args.__dict__.items():
print(f'{k} = {v}')
print(f"{k} = {v}")
FIXED_CACHE_KEYS = [
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
(
"Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:",
64,
),
(
"A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.",
64,
),
(
"English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:",
64,
),
]
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
parser.add_argument('--tp', type=int, default=1)
parser.add_argument('--master_host', default='localhost')
parser.add_argument('--master_port', type=int, default=19990)
parser.add_argument('--rpc_port', type=int, default=19980)
parser.add_argument('--max_batch_size', type=int, default=8)
parser.add_argument('--pipe_size', type=int, default=1)
parser.add_argument('--queue_size', type=int, default=0)
parser.add_argument('--http_host', default='0.0.0.0')
parser.add_argument('--http_port', type=int, default=7070)
parser.add_argument('--checkpoint', default=None)
parser.add_argument('--cache_size', type=int, default=0)
parser.add_argument('--cache_list_size', type=int, default=1)
parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"])
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--master_host", default="localhost")
parser.add_argument("--master_port", type=int, default=19990)
parser.add_argument("--rpc_port", type=int, default=19980)
parser.add_argument("--max_batch_size", type=int, default=8)
parser.add_argument("--pipe_size", type=int, default=1)
parser.add_argument("--queue_size", type=int, default=0)
parser.add_argument("--http_host", default="0.0.0.0")
parser.add_argument("--http_port", type=int, default=7070)
parser.add_argument("--checkpoint", default=None)
parser.add_argument("--cache_size", type=int, default=0)
parser.add_argument("--cache_list_size", type=int, default=1)
args = parser.parse_args()
print_args(args)
model_kwargs = {}
if args.checkpoint is not None:
model_kwargs['checkpoint'] = args.checkpoint
model_kwargs["checkpoint"] = args.checkpoint
logger = logging.getLogger(__name__)
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b")
if args.cache_size > 0:
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
else:
cache = None
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
pad_token_id=tokenizer.pad_token_id),
pipe_size=args.pipe_size,
queue_size=args.queue_size,
**model_kwargs)
engine = launch_engine(
args.tp,
1,
args.master_host,
args.master_port,
args.rpc_port,
get_model_fn(args.model),
batch_manager=BatchManagerForGeneration(
max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id
),
pipe_size=args.pipe_size,
queue_size=args.queue_size,
**model_kwargs,
)
config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
server = uvicorn.Server(config=config)
server.run()
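
As a usage note, the service above can be exercised with a plain HTTP client once the server is running. Below is a hypothetical request assuming the default `--http_host`/`--http_port` (0.0.0.0:7070); the payload fields follow `GenerationTaskReq` in the diff, and the sampling values are just examples.

```python
# Hypothetical client call; assumes the FastAPI server above is running locally
# on the default port 7070.
import requests

payload = {
    "max_tokens": 64,
    "prompt": "Question: What is the longest river on the earth?\nAnswer:",
    "top_k": 50,
    "top_p": 0.5,
    "temperature": 0.7,
}
resp = requests.post("http://localhost:7070/generation", json=payload, timeout=300)
if resp.status_code == 200:
    print(resp.json()["text"])
else:
    # the server returns 406 with a detail message when its request queue is full
    print(resp.status_code, resp.json().get("detail"))
```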


@@ -1,33 +1,36 @@
import logging
import argparse
import logging
import random
from torch import Tensor
from pydantic import BaseModel, Field
from typing import Optional
from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B
from transformers import GPT2Tokenizer
from energonai import launch_engine, QueueFullError
from batch import BatchManagerForGeneration
from cache import ListCache, MissCacheError
from energonai import QueueFullError, launch_engine
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
from pydantic import BaseModel, Field
from sanic import Sanic
from sanic.request import Request
from sanic.response import json
from sanic_ext import validate, openapi
from batch import BatchManagerForGeneration
from cache import ListCache, MissCacheError
from sanic_ext import openapi, validate
from torch import Tensor
from transformers import GPT2Tokenizer
class GenerationTaskReq(BaseModel):
max_tokens: int = Field(gt=0, le=256, example=64)
prompt: str = Field(
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
min_length=1,
example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:",
)
top_k: Optional[int] = Field(default=None, gt=0, example=50)
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
app = Sanic('opt')
app = Sanic("opt")
@app.post('/generation')
@app.post("/generation")
@openapi.body(GenerationTaskReq)
@validate(json=GenerationTaskReq)
async def generate(request: Request, body: GenerationTaskReq):
@@ -38,13 +41,13 @@ async def generate(request: Request, body: GenerationTaskReq):
raise MissCacheError()
outputs = cache.get(key)
output = random.choice(outputs)
logger.info('Cache hit')
logger.info("Cache hit")
except MissCacheError:
inputs = tokenizer(body.prompt, truncation=True, max_length=512)
inputs['max_tokens'] = body.max_tokens
inputs['top_k'] = body.top_k
inputs['top_p'] = body.top_p
inputs['temperature'] = body.temperature
inputs["max_tokens"] = body.max_tokens
inputs["top_k"] = body.top_k
inputs["top_p"] = body.top_p
inputs["temperature"] = body.temperature
try:
uid = id(body)
engine.submit(uid, inputs)
@@ -54,9 +57,9 @@ async def generate(request: Request, body: GenerationTaskReq):
if cache is not None:
cache.add(key, output)
except QueueFullError as e:
return json({'detail': e.args[0]}, status=406)
return json({"detail": e.args[0]}, status=406)
return json({'text': output})
return json({"text": output})
@app.after_server_stop
@@ -65,58 +68,70 @@ def shutdown(*_):
def get_model_fn(model_name: str):
model_map = {
'opt-125m': opt_125M,
'opt-6.7b': opt_6B,
'opt-30b': opt_30B,
'opt-175b': opt_175B
}
model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B}
return model_map[model_name]
def print_args(args: argparse.Namespace):
print('\n==> Args:')
print("\n==> Args:")
for k, v in args.__dict__.items():
print(f'{k} = {v}')
print(f"{k} = {v}")
FIXED_CACHE_KEYS = [
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
(
"Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:",
64,
),
(
"A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.",
64,
),
(
"English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:",
64,
),
]
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
parser.add_argument('--tp', type=int, default=1)
parser.add_argument('--master_host', default='localhost')
parser.add_argument('--master_port', type=int, default=19990)
parser.add_argument('--rpc_port', type=int, default=19980)
parser.add_argument('--max_batch_size', type=int, default=8)
parser.add_argument('--pipe_size', type=int, default=1)
parser.add_argument('--queue_size', type=int, default=0)
parser.add_argument('--http_host', default='0.0.0.0')
parser.add_argument('--http_port', type=int, default=7070)
parser.add_argument('--checkpoint', default=None)
parser.add_argument('--cache_size', type=int, default=0)
parser.add_argument('--cache_list_size', type=int, default=1)
parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"])
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--master_host", default="localhost")
parser.add_argument("--master_port", type=int, default=19990)
parser.add_argument("--rpc_port", type=int, default=19980)
parser.add_argument("--max_batch_size", type=int, default=8)
parser.add_argument("--pipe_size", type=int, default=1)
parser.add_argument("--queue_size", type=int, default=0)
parser.add_argument("--http_host", default="0.0.0.0")
parser.add_argument("--http_port", type=int, default=7070)
parser.add_argument("--checkpoint", default=None)
parser.add_argument("--cache_size", type=int, default=0)
parser.add_argument("--cache_list_size", type=int, default=1)
args = parser.parse_args()
print_args(args)
model_kwargs = {}
if args.checkpoint is not None:
model_kwargs['checkpoint'] = args.checkpoint
model_kwargs["checkpoint"] = args.checkpoint
logger = logging.getLogger(__name__)
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b")
if args.cache_size > 0:
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
else:
cache = None
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
pad_token_id=tokenizer.pad_token_id),
pipe_size=args.pipe_size,
queue_size=args.queue_size,
**model_kwargs)
engine = launch_engine(
args.tp,
1,
args.master_host,
args.master_port,
args.rpc_port,
get_model_fn(args.model),
batch_manager=BatchManagerForGeneration(
max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id
),
pipe_size=args.pipe_size,
queue_size=args.queue_size,
**model_kwargs,
)
app.run(args.http_host, args.http_port)


@@ -43,4 +43,3 @@ Finally, you will get 8 files in `<new_output_dir>` with following checksums:
5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt
f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt
```


@@ -14,42 +14,45 @@ def load_json(path: str):
def parse_shape_info(flat_dir: str):
data = load_json(os.path.join(flat_dir, 'shape.json'))
data = load_json(os.path.join(flat_dir, "shape.json"))
flat_info = defaultdict(lambda: defaultdict(list))
for k, shape in data.items():
matched = re.match(r'decoder.layers.\d+', k)
matched = re.match(r"decoder.layers.\d+", k)
if matched is None:
flat_key = 'flat_param_0'
flat_key = "flat_param_0"
else:
flat_key = f'{matched[0]}.flat_param_0'
flat_info[flat_key]['names'].append(k)
flat_info[flat_key]['shapes'].append(shape)
flat_info[flat_key]['numels'].append(int(np.prod(shape)))
flat_key = f"{matched[0]}.flat_param_0"
flat_info[flat_key]["names"].append(k)
flat_info[flat_key]["shapes"].append(shape)
flat_info[flat_key]["numels"].append(int(np.prod(shape)))
return flat_info
def convert(flat_dir: str, output_dir: str, part: int):
flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt')
output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt')
flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json'))
flat_path = os.path.join(flat_dir, f"reshard-model_part-{part}-shard0.pt")
output_path = os.path.join(output_dir, f"reshard-model_part-{part}.pt")
flat_meta = load_json(os.path.join(flat_dir, "flat-meta.json"))
flat_sd = torch.load(flat_path)
print(f'Loaded flat state dict from {flat_path}')
print(f"Loaded flat state dict from {flat_path}")
output_sd = {}
for flat_key, param_meta in flat_meta.items():
flat_param = flat_sd['model'][flat_key]
assert sum(param_meta['numels']) == flat_param.numel(
flat_param = flat_sd["model"][flat_key]
assert (
sum(param_meta["numels"]) == flat_param.numel()
), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}'
for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])):
for name, shape, param in zip(
param_meta["names"], param_meta["shapes"], flat_param.split(param_meta["numels"])
):
output_sd[name] = param.view(shape)
torch.save(output_sd, output_path)
print(f'Saved unflat state dict to {output_path}')
print(f"Saved unflat state dict to {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('flat_dir')
parser.add_argument('output_dir')
parser.add_argument('part', type=int)
parser.add_argument("flat_dir")
parser.add_argument("output_dir")
parser.add_argument("part", type=int)
args = parser.parse_args()
convert(args.flat_dir, args.output_dir, args.part)
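
For clarity, here is a small self-contained sketch of the unflatten step that `convert` performs above: a flat 1-D tensor is split by the recorded `numels` and reshaped to the recorded `shapes`. The `param_meta` contents below are made up for illustration.

```python
# Illustrative sketch of splitting a flat parameter back into named tensors,
# mirroring the loop over flat_meta in convert() above (meta values are made up).
import numpy as np
import torch

param_meta = {
    "names": ["decoder.layers.0.weight", "decoder.layers.0.bias"],
    "shapes": [[4, 3], [4]],
}
param_meta["numels"] = [int(np.prod(s)) for s in param_meta["shapes"]]

flat_param = torch.arange(sum(param_meta["numels"]), dtype=torch.float32)
assert sum(param_meta["numels"]) == flat_param.numel()

output_sd = {}
for name, shape, param in zip(
    param_meta["names"], param_meta["shapes"], flat_param.split(param_meta["numels"])
):
    output_sd[name] = param.view(shape)
# output_sd["decoder.layers.0.weight"].shape -> torch.Size([4, 3])
# output_sd["decoder.layers.0.bias"].shape   -> torch.Size([4])
```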

File diff suppressed because one or more lines are too long


@@ -1,7 +1,8 @@
import os
import torch
from multiprocessing import Pool
import torch
# download the pytorch model ckpt from https://huggingface.co/facebook/opt-66b/tree/main
# you can use either wget or git lfs
@@ -20,14 +21,14 @@ with Pool(14) as pool:
restored = {}
for ckpt in ckpts:
for k,v in ckpt.items():
if(k[0] == 'm'):
k = k[6:]
if(k == "lm_head.weight"):
for k, v in ckpt.items():
if k[0] == "m":
k = k[6:]
if k == "lm_head.weight":
k = "head.dense.weight"
if(k == "decoder.final_layer_norm.weight"):
if k == "decoder.final_layer_norm.weight":
k = "decoder.layer_norm.weight"
if(k == "decoder.final_layer_norm.bias"):
if k == "decoder.final_layer_norm.bias":
k = "decoder.layer_norm.bias"
restored[k] = v
restored["decoder.version"] = "0.0"
@@ -37,11 +38,11 @@ split_num = len(restored.keys()) // 60
count = 0
file_count = 1
tmp = {}
for k,v in restored.items():
for k, v in restored.items():
print(k)
tmp[k] = v
count = count + 1
if(count == split_num):
count = count + 1
if count == split_num:
filename = str(file_count) + "-restored.pt"
torch.save(tmp, os.path.join(new_path, filename))
file_count = file_count + 1
@@ -50,6 +51,3 @@ for k,v in restored.items():
filename = str(file_count) + "-restored.pt"
torch.save(tmp, os.path.join(new_path, filename))


@@ -4,7 +4,7 @@ except ImportError:
# colossalai > 0.2.8
from colossalai.legacy.zero import TensorShardStrategy
zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
tensor_placement_policy="auto",
reuse_fp16_shard=True),
optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384))
zero = dict(
model_config=dict(shard_strategy=TensorShardStrategy(), tensor_placement_policy="auto", reuse_fp16_shard=True),
optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384),
)


@@ -4,7 +4,7 @@ from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
class barrier_context():
class barrier_context:
"""
This context manager is used to allow one process to execute while blocking all
other processes in the same process group. This is often useful when downloading is required


@@ -86,14 +86,12 @@ def parse_args():
default=None,
help="The configuration name of the dataset to use (via the datasets library).",
)
parser.add_argument("--train_file",
type=str,
default=None,
help="A csv or a json file containing the training data.")
parser.add_argument("--validation_file",
type=str,
default=None,
help="A csv or a json file containing the validation data.")
parser.add_argument(
"--train_file", type=str, default=None, help="A csv or a json file containing the training data."
)
parser.add_argument(
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
)
parser.add_argument(
"--validation_split_percentage",
default=5,
@@ -161,10 +159,9 @@ def parse_args():
help="The scheduler type to use.",
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
)
parser.add_argument("--num_warmup_steps",
type=int,
default=0,
help="Number of steps for the warmup in the lr scheduler.")
parser.add_argument(
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
@@ -178,9 +175,11 @@ def parse_args():
"--block_size",
type=int,
default=None,
help=("Optional input sequence length after tokenization. The training dataset will be truncated in block of"
" this size for training. Default to the model max input length for single sentence inputs (take into"
" account special tokens)."),
help=(
"Optional input sequence length after tokenization. The training dataset will be truncated in block of"
" this size for training. Default to the model max input length for single sentence inputs (take into"
" account special tokens)."
),
)
parser.add_argument(
"--preprocessing_num_workers",
@@ -188,17 +187,16 @@ def parse_args():
default=None,
help="The number of processes to use for the preprocessing.",
)
parser.add_argument("--overwrite_cache",
type=bool,
default=False,
help="Overwrite the cached training and evaluation sets")
parser.add_argument("--no_keep_linebreaks",
action="store_true",
help="Do not keep line breaks when using TXT files.")
parser.add_argument(
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
)
parser.add_argument(
"--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files."
)
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_model_id",
type=str,
help="The name of the repository to keep in sync with the local `output_dir`.")
parser.add_argument(
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
)
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--checkpointing_steps",
@@ -221,13 +219,15 @@ def parse_args():
"--report_to",
type=str,
default="all",
help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."),
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
parser.add_argument("--init_in_cpu", action="store_true", default=False, help="init training model in cpu")
args = parser.parse_args()
# Sanity checks
@@ -250,6 +250,7 @@ def parse_args():
def colo_memory_cap(size_in_GB):
from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
cuda_capacity = colo_device_memory_capacity(get_current_device())
if size_in_GB * (1024**3) < cuda_capacity:
colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
@@ -257,7 +258,6 @@ def colo_memory_cap(size_in_GB):
class DummyDataloader:
def __init__(self, length, batch_size, seq_len, vocab_size):
self.length = length
self.batch_size = batch_size
@@ -380,32 +380,36 @@ def main():
logger.warning("You are instantiating a new config instance from scratch.")
logger.info("Model config has been created", ranks=[0])
if args.model_name_or_path == 'facebook/opt-13b':
if args.model_name_or_path == "facebook/opt-13b":
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
else:
print(f'load model from {args.model_name_or_path}')
print(f"load model from {args.model_name_or_path}")
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0])
if args.init_in_cpu:
init_dev = torch.device('cpu')
init_dev = torch.device("cpu")
else:
init_dev = get_current_device()
cai_version = colossalai.__version__
logger.info(f'using Colossal-AI version {cai_version}')
logger.info(f"using Colossal-AI version {cai_version}")
# build model
if version.parse(cai_version) >= version.parse("0.3.1"):
from contextlib import nullcontext
from colossalai.lazy import LazyInitContext
ctx = LazyInitContext(
default_device=init_dev
) if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b' else nullcontext()
ctx = (
LazyInitContext(default_device=init_dev)
if args.model_name_or_path is None or args.model_name_or_path == "facebook/opt-13b"
else nullcontext()
)
else:
from colossalai.zero import ColoInitContext
ctx = ColoInitContext(device=init_dev)
if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
if args.model_name_or_path is None or args.model_name_or_path == "facebook/opt-13b":
# currently, there is a bug in pretrained opt-13b
# we cannot import it until huggingface fixes it
logger.info("Train a new model from scratch", ranks=[0])
@@ -414,17 +418,20 @@ def main():
else:
logger.info("Finetune a pre-trained model", ranks=[0])
with ctx:
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
local_files_only=False)
model = OPTForCausalLM.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
local_files_only=False,
)
# enable gradient checkpointing
model.gradient_checkpointing_enable()
PLACEMENT_POLICY = 'auto'
PLACEMENT_POLICY = "auto"
if version.parse(cai_version) >= version.parse("0.3.1"):
from colossalai.zero import GeminiDDP
model = GeminiDDP(model, offload_optim_frac=1.0, pin_memory=True)
elif version.parse(cai_version) > version.parse("0.1.10"):
try:
@@ -435,16 +442,19 @@ def main():
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
pg = ProcessGroup()
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
chunk_manager = ChunkManager(
chunk_size,
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY),
)
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
model = ZeroDDP(model, gemini_manager)
logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
logger.info(f"{model.__class__.__name__} has been created", ranks=[0])
if not args.synthetic:
# Preprocessing the datasets.
@@ -470,12 +480,15 @@ def main():
if block_size > 1024:
logger.warning(
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
"Picking 1024 instead. You can change that default value by passing --block_size xxx.")
"Picking 1024 instead. You can change that default value by passing --block_size xxx."
)
block_size = 1024
else:
if args.block_size > tokenizer.model_max_length:
logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model"
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.")
logger.warning(
f"The block_size passed ({args.block_size}) is larger than the maximum length for the model"
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
)
block_size = min(args.block_size, tokenizer.model_max_length)
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
@@ -489,8 +502,8 @@ def main():
total_length = (total_length // block_size) * block_size
# Split by chunks of max_len.
result = {
k: [t[i:i + block_size] for i in range(0, total_length, block_size)
] for k, t in concatenated_examples.items()
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
result["labels"] = result["input_ids"].copy()
return result
@@ -520,19 +533,23 @@ def main():
# logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
# DataLoaders creation:
train_dataloader = get_dataloader(train_dataset,
shuffle=True,
add_sampler=True,
collate_fn=default_data_collator,
batch_size=args.per_device_train_batch_size)
eval_dataloader = DataLoader(eval_dataset,
collate_fn=default_data_collator,
batch_size=args.per_device_eval_batch_size)
train_dataloader = get_dataloader(
train_dataset,
shuffle=True,
add_sampler=True,
collate_fn=default_data_collator,
batch_size=args.per_device_train_batch_size,
)
eval_dataloader = DataLoader(
eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size
)
else:
train_dataloader = DummyDataloader(30, args.per_device_train_batch_size, config.max_position_embeddings,
config.vocab_size)
eval_dataloader = DummyDataloader(10, args.per_device_train_batch_size, config.max_position_embeddings,
config.vocab_size)
train_dataloader = DummyDataloader(
30, args.per_device_train_batch_size, config.max_position_embeddings, config.vocab_size
)
eval_dataloader = DummyDataloader(
10, args.per_device_train_batch_size, config.max_position_embeddings, config.vocab_size
)
logger.info("Dataloaders have been created", ranks=[0])
# Optimizer
@@ -593,7 +610,6 @@ def main():
global_step = 0
for epoch in range(starting_epoch, args.num_train_epochs):
if completed_steps >= args.max_train_steps:
break
@@ -601,7 +617,7 @@ def main():
for step, batch in enumerate(train_dataloader):
batch = {k: v.cuda() for k, v in batch.items()}
outputs = model(use_cache=False, **batch)
loss = outputs['loss']
loss = outputs["loss"]
optimizer.backward(loss)
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
@@ -624,7 +640,7 @@ def main():
batch = {k: v.cuda() for k, v in batch.items()}
outputs = model(**batch)
loss = outputs['loss'].unsqueeze(0)
loss = outputs["loss"].unsqueeze(0)
losses.append(loss)
losses = torch.cat(losses)
@@ -640,7 +656,7 @@ def main():
if args.output_dir is not None:
model_state = model.state_dict()
if is_main_process:
torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps))
torch.save(model_state, args.output_dir + "/epoch_{}_model.pth".format(completed_steps))
dist.barrier()
# load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps))
# model.load_state_dict(load_state, strict=False)