support session-based training (#4313)
Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
@@ -6,6 +6,7 @@
- [Table of Contents](#table-of-contents)
- [Install requirements](#install-requirements)
- [Supervised datasets collection](#supervised-datasets-collection)
- [Conversation dataset generation](#conversation-dataset-generation)
- [Stage1 - Supervised instructs tuning](#stage1---supervised-instructs-tuning)
- [Arg List](#arg-list)
- [Stage2 - Training reward model](#stage2---training-reward-model)

@@ -45,6 +46,49 @@ The following pic shows how we collected the data.
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/data-collect.png" width=500/>
</p>

### Conversation dataset generation

To further improve the model's ability to handle multi-turn conversations, we need to include multi-turn samples in the dataset. However, the InstructWild and Alpaca datasets currently contain only single-turn conversations, and their organization is not suitable for storing multi-turn conversations. Besides converting these datasets, we also need to include multi-turn conversation datasets such as ShareGPT and transform them into the training format supported by ColossalChat.

A sample of the conversation dataset should have the following fields:

* `type` (str, optional): The type of the data sample.
* `language` (str, optional): The language of the data sample.
* `dataset` (str, optional): The dataset the data sample originates from.
* `conversations` (list, compulsory): Conversation content of the data sample.
* `id` (int, optional): The ID of the data sample.

A simple example:

```json
{
    "type": "instruction",
    "language": "English",
    "dataset": "Alpaca",
    "conversations": [
        {
            "from": "human",
            "value": "Give three tips for staying healthy."
        },
        {
            "from": "gpt",
            "value": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
        }
    ],
    "id": 1
}
```

> **NOTE:** Only the key `conversations` is compulsory for training; the other keys serve as metadata. The length of `conversations` varies.
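
For example, a minimal multi-turn sample only needs the `conversations` key, and the conversation may span any number of turns. The sketch below is illustrative only (the file name and dialogue content are made up); it builds such a sample in Python and writes it in the same JSON format:

```python
import json

# Minimal sample: only `conversations` is compulsory; the metadata keys may be omitted.
sample = {
    "conversations": [
        {"from": "human", "value": "What's a good way to start learning Python?"},
        {"from": "gpt", "value": "Start with the official tutorial and practice small scripts every day."},
        {"from": "human", "value": "Any project ideas for a beginner?"},
        {"from": "gpt", "value": "A to-do list CLI or a simple web scraper are common first projects."},
    ]
}

# The training file is a JSON list of such samples.
with open("my_dataset.json", "w") as f:
    json.dump([sample], f, indent=4, ensure_ascii=False)
```
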
You can run `examples/generate_conversation_dataset.py` to generate a conversation dataset supported by ColossalChat.

You can use the following command to generate the conversation dataset:

```
python generate_conversation_dataset.py \
    --dataset "All" \
    --save_path "/path/to/dataset"
```
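
After generation, you can run a quick sanity check on the output before moving on to training. The sketch below is illustrative only and assumes the dataset was saved to `dataset.json` (the script's default `--save_path`):

```python
import json

# Load the converted dataset and run a few basic checks.
with open("dataset.json") as f:
    dataset = json.load(f)

print(f"{len(dataset)} samples loaded")

roles = set()
for sample in dataset:
    # `conversations` is the only compulsory field; it must contain at least one turn.
    assert sample["conversations"], "empty conversation found"
    roles.update(turn["from"] for turn in sample["conversations"])

print("speaker roles found:", roles)
```
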
## Stage1 - Supervised instructs tuning

Stage1 is supervised instruction fine-tuning, which uses the datasets mentioned earlier to fine-tune the model.

applications/Chat/examples/generate_conversation_dataset.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import argparse
import json

from datasets import load_dataset


def generate_alpaca():
    # We can convert any dataset with the same format ("instruction", "input", "output") as Alpaca into one-round conversations.
    conversation_dataset = []
    dataset = load_dataset("tatsu-lab/alpaca", split="train")

    instructions = dataset["instruction"]
    inputs = dataset["input"]
    outputs = dataset["output"]

    assert len(instructions) == len(inputs) == len(outputs)

    for idx in range(len(instructions)):
        human_utterance = instructions[idx] + "\n\n" + inputs[idx] if inputs[idx] else instructions[idx]
        human = {"from": "human", "value": human_utterance}

        gpt_utterance = outputs[idx]
        gpt = {"from": "gpt", "value": gpt_utterance}

        conversation = dict(type="instruction", language="English", dataset="Alpaca", conversations=[human, gpt])
        conversation_dataset.append(conversation)

    return conversation_dataset


def generate_sharegpt():
    # ShareGPT data requires less processing.
    conversation_dataset = []
    dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered",
                           data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
                           split="train")

    conversations = dataset["conversations"]

    for idx in range(len(conversations)):
        for conv in conversations[idx]:
            # The markdown and text fields are not needed for training.
            del conv["markdown"]
            del conv["text"]

        conversation = dict(type="conversation",
                            language="Multilingual",
                            dataset="ShareGPT",
                            conversations=conversations[idx])
        conversation_dataset.append(conversation)

    return conversation_dataset


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        type=str,
                        default="All",
                        choices=["Alpaca", "ShareGPT", "All"],
                        help="which dataset to convert, All will combine Alpaca and ShareGPT")
    parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset")
    args = parser.parse_args()

    conversation_dataset = []

    if args.dataset == "Alpaca":
        conversation_dataset.extend(generate_alpaca())
    elif args.dataset == "ShareGPT":
        conversation_dataset.extend(generate_sharegpt())
    else:
        conversation_dataset.extend(generate_alpaca())
        conversation_dataset.extend(generate_sharegpt())
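
    # Assign sequential 1-based IDs to the merged samples (the optional `id` metadata field).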
    for idx, sample in enumerate(conversation_dataset):
        sample["id"] = idx + 1

    with open(args.save_path, mode='w') as f:
        json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)

@@ -74,8 +74,8 @@ def train(args):
             padding_side="right",
             use_fast=False,
         )
-        tokenizer.eos_token = '<\s>'
-        tokenizer.pad_token = tokenizer.unk_token
+        tokenizer.eos_token = '</s>'
+        tokenizer.pad_token = tokenizer.eos_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')

@@ -153,9 +153,7 @@ def train(args):
                                 optim,
                                 num_warmup_steps=math.ceil(max_steps * 0.03),
                                 num_training_steps=max_steps)
-    strategy_dict = strategy.prepare(
-        dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
-    )
+    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
     model = strategy_dict['model']
     optim = strategy_dict['optimizer']
     lr_scheduler = strategy_dict['lr_scheduler']