From 05759839bd28b0e9b0b43b308a1e01d5bc730e36 Mon Sep 17 00:00:00 2001 From: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Date: Wed, 17 May 2023 17:44:05 +0800 Subject: [PATCH] [chat] fix bugs in stage 3 training (#3759) Co-authored-by: Yuanchen Xu --- .../Chat/coati/dataset/prompt_dataset.py | 2 +- applications/Chat/examples/README.md | 2 +- .../Chat/examples/example_data_reformat.py | 12 -------- .../Chat/examples/generate_prompt_dataset.py | 30 +++++++++++++++++++ 4 files changed, 32 insertions(+), 14 deletions(-) delete mode 100644 applications/Chat/examples/example_data_reformat.py create mode 100644 applications/Chat/examples/generate_prompt_dataset.py diff --git a/applications/Chat/coati/dataset/prompt_dataset.py b/applications/Chat/coati/dataset/prompt_dataset.py index f8ab2346c..5858052c8 100644 --- a/applications/Chat/coati/dataset/prompt_dataset.py +++ b/applications/Chat/coati/dataset/prompt_dataset.py @@ -45,7 +45,7 @@ class PromptDataset(Dataset): self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind()) def __len__(self): - return len(self.keyed_prompt) + return len(self.keyed_prompt["input_ids"]) def __getitem__(self, i) -> Dict[str, torch.Tensor]: return {k: v[i] for k, v in self.keyed_prompt.items()} diff --git a/applications/Chat/examples/README.md b/applications/Chat/examples/README.md index 2a2128e25..60f876eda 100644 --- a/applications/Chat/examples/README.md +++ b/applications/Chat/examples/README.md @@ -153,7 +153,7 @@ torchrun --standalone --nproc_per_node=4 train_prompts.py \ --rm_path /your/rm/model/path ``` -Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. 
you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/example_data_reformat.py) to reformat [seed_prompts_ch.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_ch.jsonl) or [seed_prompts_en.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_en.jsonl) in InstructionWild. +Prompt dataset: the instruction dataset shown in the figure above, which contains the instructions. For example, you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/generate_prompt_dataset.py) to sample `instinwild_en.json` or `instinwild_ch.json` from [InstructionWild](https://github.com/XueFuzhao/InstructionWild/tree/main/data#instructwild-data) and generate the prompt dataset. Pretrain dataset: the pretrain dataset including the instruction and corresponding response, e.g. you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) in stage 1 supervised instructs tuning. 
### Arg List
diff --git a/applications/Chat/examples/example_data_reformat.py b/applications/Chat/examples/example_data_reformat.py
deleted file mode 100644
index dc83b29b5..000000000
--- a/applications/Chat/examples/example_data_reformat.py
+++ /dev/null
@@ -1,12 +0,0 @@
-jsonl_file = 'seed_prompts_xx.jsonl' # seed_prompts_en.jsonl or seed_prompts_ch.json from InstructionWild
-reformat_file = 'prompts_xx.jsonl' # reformat jsonl file used as Prompt dataset in Stage3
-
-data = ''
-with open(jsonl_file, 'r', encoding="utf-8") as f1:
-    for jsonstr in f1.readlines():
-        jsonstr = '\t' + jsonstr.strip('\n') + ',\n'
-        data = data + jsonstr
-    data = '[\n' + data + ']'
-
-with open(reformat_file, 'w') as f2:
-    f2.write(data)
\ No newline at end of file
diff --git a/applications/Chat/examples/generate_prompt_dataset.py b/applications/Chat/examples/generate_prompt_dataset.py
new file mode 100644
index 000000000..95e40fefe
--- /dev/null
+++ b/applications/Chat/examples/generate_prompt_dataset.py
@@ -0,0 +1,47 @@
+import argparse
+import json
+import random
+
+random.seed(42)
+
+
+def sample(args):
+    """Sample instructions from a JSON dataset and save them as a prompt dataset.
+
+    Args:
+        args: parsed command-line arguments providing ``dataset_path`` (JSON file
+            holding a list of records that each have an "instruction" key),
+            ``save_path`` (output JSON file) and ``sample_size`` (number of
+            records to draw).
+    """
+    # The output is written with ensure_ascii=False, so the data may contain
+    # non-ASCII text; pin UTF-8 instead of relying on the platform default.
+    with open(args.dataset_path, mode='r', encoding='utf-8') as f:
+        dataset_list = json.load(f)
+
+    # random.sample raises ValueError when asked for more items than the
+    # dataset holds, so cap the request at the dataset size.
+    sample_size = min(args.sample_size, len(dataset_list))
+    if sample_size < args.sample_size:
+        print(f"Dataset has only {len(dataset_list)} records; "
+              f"sampling {sample_size} instead of {args.sample_size}.")
+
+    # Named `record` so the loop variable does not shadow this function.
+    sampled_dataset = [{"instruction": record["instruction"], "id": idx}
+                       for idx, record in enumerate(random.sample(dataset_list, sample_size))]
+
+    with open(args.save_path, mode='w', encoding='utf-8') as f:
+        json.dump(sampled_dataset, f, indent=4,
+                  default=str, ensure_ascii=False)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset_path', type=str, default=None,
+                        required=True, help="path to the pretrain dataset")
+    parser.add_argument('--save_path', type=str, default='prompt.json',
+                        help="path to save the prompt dataset")
+    parser.add_argument('--sample_size', type=int,
+                        default=16384, help="size of the prompt dataset")
+    args = parser.parse_args()
+    sample(args)