[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -43,4 +43,3 @@ Finally, you will get 8 files in `<new_output_dir>` with the following checksums:
5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt
f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt
```

View File

@@ -14,42 +14,45 @@ def load_json(path: str):
def parse_shape_info(flat_dir: str):
data = load_json(os.path.join(flat_dir, 'shape.json'))
data = load_json(os.path.join(flat_dir, "shape.json"))
flat_info = defaultdict(lambda: defaultdict(list))
for k, shape in data.items():
matched = re.match(r'decoder.layers.\d+', k)
matched = re.match(r"decoder.layers.\d+", k)
if matched is None:
flat_key = 'flat_param_0'
flat_key = "flat_param_0"
else:
flat_key = f'{matched[0]}.flat_param_0'
flat_info[flat_key]['names'].append(k)
flat_info[flat_key]['shapes'].append(shape)
flat_info[flat_key]['numels'].append(int(np.prod(shape)))
flat_key = f"{matched[0]}.flat_param_0"
flat_info[flat_key]["names"].append(k)
flat_info[flat_key]["shapes"].append(shape)
flat_info[flat_key]["numels"].append(int(np.prod(shape)))
return flat_info
def convert(flat_dir: str, output_dir: str, part: int):
flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt')
output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt')
flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json'))
flat_path = os.path.join(flat_dir, f"reshard-model_part-{part}-shard0.pt")
output_path = os.path.join(output_dir, f"reshard-model_part-{part}.pt")
flat_meta = load_json(os.path.join(flat_dir, "flat-meta.json"))
flat_sd = torch.load(flat_path)
print(f'Loaded flat state dict from {flat_path}')
print(f"Loaded flat state dict from {flat_path}")
output_sd = {}
for flat_key, param_meta in flat_meta.items():
flat_param = flat_sd['model'][flat_key]
assert sum(param_meta['numels']) == flat_param.numel(
flat_param = flat_sd["model"][flat_key]
assert (
sum(param_meta["numels"]) == flat_param.numel()
), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}'
for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])):
for name, shape, param in zip(
param_meta["names"], param_meta["shapes"], flat_param.split(param_meta["numels"])
):
output_sd[name] = param.view(shape)
torch.save(output_sd, output_path)
print(f'Saved unflat state dict to {output_path}')
print(f"Saved unflat state dict to {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('flat_dir')
parser.add_argument('output_dir')
parser.add_argument('part', type=int)
parser.add_argument("flat_dir")
parser.add_argument("output_dir")
parser.add_argument("part", type=int)
args = parser.parse_args()
convert(args.flat_dir, args.output_dir, args.part)

File diff suppressed because one or more lines are too long

View File

@@ -1,7 +1,8 @@
import os
import torch
from multiprocessing import Pool
import torch
# download the PyTorch model checkpoint from https://huggingface.co/facebook/opt-66b/tree/main
# you can use either wget or git lfs
@@ -20,14 +21,14 @@ with Pool(14) as pool:
restored = {}
for ckpt in ckpts:
for k,v in ckpt.items():
if(k[0] == 'm'):
k = k[6:]
if(k == "lm_head.weight"):
for k, v in ckpt.items():
if k[0] == "m":
k = k[6:]
if k == "lm_head.weight":
k = "head.dense.weight"
if(k == "decoder.final_layer_norm.weight"):
if k == "decoder.final_layer_norm.weight":
k = "decoder.layer_norm.weight"
if(k == "decoder.final_layer_norm.bias"):
if k == "decoder.final_layer_norm.bias":
k = "decoder.layer_norm.bias"
restored[k] = v
restored["decoder.version"] = "0.0"
@@ -37,11 +38,11 @@ split_num = len(restored.keys()) // 60
count = 0
file_count = 1
tmp = {}
for k,v in restored.items():
for k, v in restored.items():
print(k)
tmp[k] = v
count = count + 1
if(count == split_num):
count = count + 1
if count == split_num:
filename = str(file_count) + "-restored.pt"
torch.save(tmp, os.path.join(new_path, filename))
file_count = file_count + 1
@@ -50,6 +51,3 @@ for k,v in restored.items():
filename = str(file_count) + "-restored.pt"
torch.save(tmp, os.path.join(new_path, filename))