Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-12-22 12:02:44 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
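Most of the hunks below are mechanical quote and line-wrapping changes, the kind a formatter run through pre-commit produces. A hedged sketch of that behavior (assuming the hooks include black, which is only inferred from the diffs; the commit's actual .pre-commit-config.yaml is not shown here):

```python
# Illustration only, not taken from this commit: black rewrites
# single-quoted strings to double quotes, matching the hunks below.
import black

src = "flat_key = 'flat_param_0'\n"
print(black.format_str(src, mode=black.Mode()), end="")
# prints: flat_key = "flat_param_0"
```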
@@ -43,4 +43,3 @@ Finally, you will get 8 files in `<new_output_dir>` with following checksums:
 5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt
 f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt
 ```
-
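Since the hunk above ends a list of MD5 checksums, a short verification sketch may be useful; the digest and file name are from the README, while the helper itself is illustrative:

```python
# Hypothetical MD5 check for one of the reshard files listed above.
import hashlib

def md5sum(path: str) -> str:
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

assert md5sum("reshard-model_part-7.pt") == "f888bd41e009096804fe9a4b48c7ffe8"
```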
@@ -14,42 +14,45 @@ def load_json(path: str):


 def parse_shape_info(flat_dir: str):
-    data = load_json(os.path.join(flat_dir, 'shape.json'))
+    data = load_json(os.path.join(flat_dir, "shape.json"))
     flat_info = defaultdict(lambda: defaultdict(list))
     for k, shape in data.items():
-        matched = re.match(r'decoder.layers.\d+', k)
+        matched = re.match(r"decoder.layers.\d+", k)
         if matched is None:
-            flat_key = 'flat_param_0'
+            flat_key = "flat_param_0"
         else:
-            flat_key = f'{matched[0]}.flat_param_0'
-        flat_info[flat_key]['names'].append(k)
-        flat_info[flat_key]['shapes'].append(shape)
-        flat_info[flat_key]['numels'].append(int(np.prod(shape)))
+            flat_key = f"{matched[0]}.flat_param_0"
+        flat_info[flat_key]["names"].append(k)
+        flat_info[flat_key]["shapes"].append(shape)
+        flat_info[flat_key]["numels"].append(int(np.prod(shape)))
     return flat_info


 def convert(flat_dir: str, output_dir: str, part: int):
-    flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt')
-    output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt')
-    flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json'))
+    flat_path = os.path.join(flat_dir, f"reshard-model_part-{part}-shard0.pt")
+    output_path = os.path.join(output_dir, f"reshard-model_part-{part}.pt")
+    flat_meta = load_json(os.path.join(flat_dir, "flat-meta.json"))
     flat_sd = torch.load(flat_path)
-    print(f'Loaded flat state dict from {flat_path}')
+    print(f"Loaded flat state dict from {flat_path}")
     output_sd = {}
     for flat_key, param_meta in flat_meta.items():
-        flat_param = flat_sd['model'][flat_key]
-        assert sum(param_meta['numels']) == flat_param.numel(
-        ), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}'
-        for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])):
+        flat_param = flat_sd["model"][flat_key]
+        assert (
+            sum(param_meta["numels"]) == flat_param.numel()
+        ), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}'
+        for name, shape, param in zip(
+            param_meta["names"], param_meta["shapes"], flat_param.split(param_meta["numels"])
+        ):
             output_sd[name] = param.view(shape)

     torch.save(output_sd, output_path)
-    print(f'Saved unflat state dict to {output_path}')
+    print(f"Saved unflat state dict to {output_path}")


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('flat_dir')
-    parser.add_argument('output_dir')
-    parser.add_argument('part', type=int)
+    parser.add_argument("flat_dir")
+    parser.add_argument("output_dir")
+    parser.add_argument("part", type=int)
     args = parser.parse_args()
     convert(args.flat_dir, args.output_dir, args.part)
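To make the `convert` hunk concrete: each flat 1-D buffer is cut into per-parameter chunks with `split` and reshaped with `view`. A minimal standalone sketch with made-up names and shapes (not real OPT metadata):

```python
# Toy version of the split-and-view pattern in convert().
import torch

names = ["decoder.layers.0.weight", "decoder.layers.0.bias"]
shapes = [(4, 3), (4,)]
numels = [12, 4]

flat_param = torch.arange(16.0)  # stands in for flat_sd["model"][flat_key]
output_sd = {
    name: chunk.view(shape)
    for name, shape, chunk in zip(names, shapes, flat_param.split(numels))
}
print(output_sd["decoder.layers.0.weight"].shape)  # torch.Size([4, 3])
```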
File diff suppressed because one or more lines are too long
@@ -1,7 +1,8 @@
 import os
-import torch
 from multiprocessing import Pool
+
+import torch

 # download the pytorch model ckpt from https://huggingface.co/facebook/opt-66b/tree/main
 # you can use either wget or git lfs

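As an aside to the wget/git-lfs comment above, the checkpoint can also be fetched programmatically; a sketch assuming the `huggingface_hub` package is available (this script does not use it):

```python
# Hypothetical alternative to wget/git lfs for grabbing the OPT-66B files.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="facebook/opt-66b")
print(local_dir)  # local cache path containing the downloaded ckpt files
```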
@@ -20,14 +21,14 @@ with Pool(14) as pool:

 restored = {}
 for ckpt in ckpts:
-    for k,v in ckpt.items():
-        if(k[0] == 'm'):
-            k = k[6:]
-        if(k == "lm_head.weight"):
+    for k, v in ckpt.items():
+        if k[0] == "m":
+            k = k[6:]
+        if k == "lm_head.weight":
             k = "head.dense.weight"
-        if(k == "decoder.final_layer_norm.weight"):
+        if k == "decoder.final_layer_norm.weight":
             k = "decoder.layer_norm.weight"
-        if(k == "decoder.final_layer_norm.bias"):
+        if k == "decoder.final_layer_norm.bias":
             k = "decoder.layer_norm.bias"
         restored[k] = v
 restored["decoder.version"] = "0.0"
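One subtlety in the hunk above: the `k[0] == "m"` branch drops the first six characters of the key, which matches the length of a `model.` prefix (that reading is an inference, not stated in the script). A toy check with a made-up key:

```python
# Hypothetical illustration: k[6:] strips what appears to be a "model." prefix.
k = "model.decoder.final_layer_norm.weight"
if k[0] == "m":
    k = k[6:]
print(k)  # decoder.final_layer_norm.weight
```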
@@ -37,11 +38,11 @@ split_num = len(restored.keys()) // 60
 count = 0
 file_count = 1
 tmp = {}
-for k,v in restored.items():
+for k, v in restored.items():
     print(k)
     tmp[k] = v
-    count = count + 1
-    if(count == split_num):
+    count = count + 1
+    if count == split_num:
         filename = str(file_count) + "-restored.pt"
         torch.save(tmp, os.path.join(new_path, filename))
         file_count = file_count + 1
@@ -50,6 +51,3 @@ for k,v in restored.items():

 filename = str(file_count) + "-restored.pt"
 torch.save(tmp, os.path.join(new_path, filename))
-
-
-
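Read together, the last two hunks implement a plain dict-sharding loop: accumulate `split_num` entries, save a shard, and flush whatever remains at the end. A standalone sketch with toy data (the directory name and split factor are placeholders; the script itself uses `// 60`):

```python
# Minimal sharding loop mirroring the script above, on a toy state dict.
import os

import torch

new_path = "sharded_ckpt"  # hypothetical output directory
os.makedirs(new_path, exist_ok=True)

restored = {f"param_{i}": torch.zeros(1) for i in range(7)}
split_num = max(len(restored) // 3, 1)

tmp, count, file_count = {}, 0, 1
for k, v in restored.items():
    tmp[k] = v
    count += 1
    if count == split_num:
        torch.save(tmp, os.path.join(new_path, f"{file_count}-restored.pt"))
        tmp, count = {}, 0
        file_count += 1
if tmp:  # trailing save, as in the hunk above
    torch.save(tmp, os.path.join(new_path, f"{file_count}-restored.pt"))
```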