mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-11-01 06:19:48 +00:00
[tutorial] edited hands-on practices (#1899)
* Add handson to ColossalAI. * Change names of handsons and edit sequence parallel example. * Edit wrong folder name * resolve conflict * delete readme
This commit is contained in:
@@ -0,0 +1,46 @@
|
||||
# Process OPT-175B weights
|
||||
|
||||
You should download the pre-trained weights following the [doc](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT) before reading this.
|
||||
|
||||
First, install `metaseq` and `git clone https://github.com/facebookresearch/metaseq.git`.
|
||||
|
||||
Then, `cd metaseq`.
|
||||
|
||||
To consolidate checkpoints to eliminate FSDP:
|
||||
|
||||
```shell
|
||||
bash metaseq/scripts/reshard_mp_launch_no_slurm.sh <directory_where_all_the_shards_are>/checkpoint_last <output_dir>/ 8 1
|
||||
```
|
||||
|
||||
You will get 8 files in `<output_dir>`, and you should have the following checksums:
|
||||
```
|
||||
7e71cb65c4be784aa0b2889ac6039ee8 reshard-model_part-0-shard0.pt
|
||||
c8123da04f2c25a9026ea3224d5d5022 reshard-model_part-1-shard0.pt
|
||||
45e5d10896382e5bc4a7064fcafd2b1e reshard-model_part-2-shard0.pt
|
||||
abb7296c4d2fc17420b84ca74fc3ce64 reshard-model_part-3-shard0.pt
|
||||
05dcc7ac6046f4d3f90b3d1068e6da15 reshard-model_part-4-shard0.pt
|
||||
d24dd334019060ce1ee7e625fcf6b4bd reshard-model_part-5-shard0.pt
|
||||
fb1615ce0bbe89cc717f3e5079ee2655 reshard-model_part-6-shard0.pt
|
||||
2f3124432d2dbc6aebfca06be4b791c2 reshard-model_part-7-shard0.pt
|
||||
```
|
||||
|
||||
Copy `flat-meta.json` to `<output_dir>`.
|
||||
|
||||
Then cd to this dir, and we unflatten parameters.
|
||||
|
||||
```shell
|
||||
bash unflat.sh <output_dir>/ <new_output_dir>/
|
||||
```
|
||||
|
||||
Finally, you will get 8 files in `<new_output_dir>` with following checksums:
|
||||
```
|
||||
6169c59d014be95553c89ec01b8abb62 reshard-model_part-0.pt
|
||||
58868105da3d74a528a548fdb3a8cff6 reshard-model_part-1.pt
|
||||
69b255dc5a49d0eba9e4b60432cda90b reshard-model_part-2.pt
|
||||
002c052461ff9ffb0cdac3d5906f41f2 reshard-model_part-3.pt
|
||||
6d57f72909320d511ffd5f1c668b2beb reshard-model_part-4.pt
|
||||
93c8c4041cdc0c7907cc7afcf15cec2a reshard-model_part-5.pt
|
||||
5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt
|
||||
f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt
|
||||
```
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def load_json(path: str):
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def parse_shape_info(flat_dir: str):
|
||||
data = load_json(os.path.join(flat_dir, 'shape.json'))
|
||||
flat_info = defaultdict(lambda: defaultdict(list))
|
||||
for k, shape in data.items():
|
||||
matched = re.match(r'decoder.layers.\d+', k)
|
||||
if matched is None:
|
||||
flat_key = 'flat_param_0'
|
||||
else:
|
||||
flat_key = f'{matched[0]}.flat_param_0'
|
||||
flat_info[flat_key]['names'].append(k)
|
||||
flat_info[flat_key]['shapes'].append(shape)
|
||||
flat_info[flat_key]['numels'].append(int(np.prod(shape)))
|
||||
return flat_info
|
||||
|
||||
|
||||
def convert(flat_dir: str, output_dir: str, part: int):
|
||||
flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt')
|
||||
output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt')
|
||||
flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json'))
|
||||
flat_sd = torch.load(flat_path)
|
||||
print(f'Loaded flat state dict from {flat_path}')
|
||||
output_sd = {}
|
||||
for flat_key, param_meta in flat_meta.items():
|
||||
flat_param = flat_sd['model'][flat_key]
|
||||
assert sum(param_meta['numels']) == flat_param.numel(
|
||||
), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}'
|
||||
for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])):
|
||||
output_sd[name] = param.view(shape)
|
||||
|
||||
torch.save(output_sd, output_path)
|
||||
print(f'Saved unflat state dict to {output_path}')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('flat_dir')
|
||||
parser.add_argument('output_dir')
|
||||
parser.add_argument('part', type=int)
|
||||
args = parser.parse_args()
|
||||
convert(args.flat_dir, args.output_dir, args.part)
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,7 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
for i in $(seq 0 7); do
|
||||
python convert_ckpt.py $1 $2 ${i} &
|
||||
done
|
||||
|
||||
wait $(jobs -p)
|
||||
Reference in New Issue
Block a user