refactor tokenization

This commit is contained in:
YeAnbang 2024-07-19 10:10:48 +00:00
parent 544b7a38a1
commit d49550fb49
9 changed files with 159 additions and 175 deletions

View File

@ -18,6 +18,7 @@ class Conversation:
chat_template: str chat_template: str
stop_ids: List[int] stop_ids: List[int]
end_of_assistant: str end_of_assistant: str
roles = ["user", "assistant"]
@classmethod @classmethod
def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict): def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
@ -85,7 +86,7 @@ class Conversation:
Raises: Raises:
AssertionError: If the role is not 'user' or 'assistant'. AssertionError: If the role is not 'user' or 'assistant'.
""" """
assert role in ["user", "assistant"] assert role in self.roles
self.messages.append({"role": role, "content": message}) self.messages.append({"role": role, "content": message})
def copy(self): def copy(self):

View File

@ -39,7 +39,7 @@ def supervised_tokenize_sft(
Args: Args:
data_point: the data point of the following format data_point: the data point of the following format
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]} {"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer whose tokenizer: the tokenizer whose
conversation_template: the conversation template to apply conversation_template: the conversation template to apply
ignore_index: the ignore index when calculate loss during training ignore_index: the ignore index when calculate loss during training
@ -52,41 +52,25 @@ def supervised_tokenize_sft(
messages = data_point["messages"] messages = data_point["messages"]
template = deepcopy(conversation_template) template = deepcopy(conversation_template)
template.messages = [] template.messages = []
for idx, mess in enumerate(messages):
for mess in messages: if mess["from"] != template.roles[idx % 2]:
from_str = mess["from"] raise ValueError(
if from_str.lower() == "human": f"Message should iterate between user and assistant and starts with a \
from_str = "user" line from the user. Got the following data:\n{messages}"
elif from_str.lower() == "assistant": )
from_str = "assistant" template.append_message(mess["from"], mess["content"])
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
template.append_message(from_str, mess["content"])
if len(template.messages) % 2 != 0: if len(template.messages) % 2 != 0:
# Force to end with assistant response
template.messages = template.messages[0:-1] template.messages = template.messages[0:-1]
# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time. # tokenize and calculate masked labels -100 for positions corresponding to non-assistant lines
turns = [i for i in range(1, len(messages) // 2 + 1)] prompt = template.get_prompt()
lo, hi = 0, len(turns)
while lo < hi:
mid = (lo + hi) // 2
prompt = template.get_prompt(2 * turns[mid] - 1)
chunks, require_loss = split_templated_prompt_into_chunks( chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[: 2 * turns[mid] - 1], prompt, conversation_template.end_of_assistant template.messages, prompt, conversation_template.end_of_assistant
) )
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss) tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=max_length)
if max_length - 1 < len(tokenized): if tokenized is None:
hi = mid
else:
lo = mid + 1
target_turn_index = lo
# The tokenized length for first turn already exceeds `max_length - 1`.
if target_turn_index - 1 < 0:
warnings.warn("The tokenized length for first turn already exceeds `max_length - 1`.")
return dict( return dict(
input_ids=None, input_ids=None,
labels=None, labels=None,
@ -96,45 +80,18 @@ def supervised_tokenize_sft(
seq_category=None, seq_category=None,
) )
target_turn = turns[target_turn_index - 1]
prompt = template.get_prompt(2 * target_turn)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[: 2 * target_turn], prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
labels = [ignore_index] * len(tokenized) labels = [ignore_index] * len(tokenized)
for start, end in zip(starts, ends): for start, end in zip(starts, ends):
if end == len(tokenized):
tokenized = tokenized + [tokenizer.eos_token_id]
labels = labels + [ignore_index]
labels[start:end] = tokenized[start:end] labels[start:end] = tokenized[start:end]
# truncate the sequence at the last token that requires loss calculation
to_truncate_len = 0
for i in range(len(tokenized) - 1, -1, -1):
if labels[i] == ignore_index:
to_truncate_len += 1
else:
break
to_truncate_len = max(len(tokenized) - max_length, to_truncate_len)
tokenized = tokenized[: len(tokenized) - to_truncate_len]
labels = labels[: len(labels) - to_truncate_len]
if tokenizer.bos_token_id is not None: if tokenizer.bos_token_id is not None:
# Force to add bos token at the beginning of the tokenized sequence if the input ids doesn;t starts with bos
if tokenized[0] != tokenizer.bos_token_id: if tokenized[0] != tokenizer.bos_token_id:
# Some chat templates already include bos token
tokenized = [tokenizer.bos_token_id] + tokenized tokenized = [tokenizer.bos_token_id] + tokenized
labels = [ignore_index] + labels labels = [-100] + labels
if tokenizer.eos_token_id is not None: # log decoded inputs and labels for debugging
# Force to add eos token at the end of the tokenized sequence
if tokenized[-1] != tokenizer.eos_token_id:
tokenized = tokenized + [tokenizer.eos_token_id]
labels = labels + [tokenizer.eos_token_id]
else:
labels[-1] = tokenizer.eos_token_id
# For some model without bos/eos may raise the following errors
inputs_decode = tokenizer.decode(tokenized) inputs_decode = tokenizer.decode(tokenized)
start = 0 start = 0
end = 0 end = 0
@ -183,7 +140,7 @@ def tokenize_prompt_dataset(
"Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]" "Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]"
Args: Args:
data_point: the data point of the following format data_point: the data point of the following format
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]} {"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer whose tokenizer: the tokenizer whose
conversation_template: the conversation template to apply conversation_template: the conversation template to apply
ignore_index: the ignore index when calculate loss during training ignore_index: the ignore index when calculate loss during training
@ -196,35 +153,28 @@ def tokenize_prompt_dataset(
template = deepcopy(conversation_template) template = deepcopy(conversation_template)
template.messages = [] template.messages = []
for mess in messages: for idx, mess in enumerate(messages):
from_str = mess["from"] if mess["from"] != template.roles[idx % 2]:
if from_str.lower() == "human": raise ValueError(
from_str = "user" f"Message should iterate between user and assistant and starts with a \
elif from_str.lower() == "assistant": line from the user. Got the following data:\n{messages}"
from_str = "assistant" )
else: template.append_message(mess["from"], mess["content"])
raise ValueError(f"Unsupported role {from_str.lower()}")
template.append_message(from_str, mess["content"])
# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time. # `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
target_turn = len(template.messages) if len(template.messages) % 2 != 1:
if target_turn % 2 != 1:
# exclude the answer if provided. keep only the prompt # exclude the answer if provided. keep only the prompt
target_turn = target_turn - 1 template.messages = template.messages[:-1]
# Prepare data # Prepare data
prompt = template.get_prompt(target_turn, add_generation_prompt=True) prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True)
chunks, require_loss = split_templated_prompt_into_chunks( tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
template.messages[:target_turn], prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
if tokenizer.bos_token_id is not None: if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id: if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized tokenized = [tokenizer.bos_token_id] + tokenized
# Skip overlength data if len(tokenized) > max_length:
if max_length - 1 < len(tokenized):
return dict( return dict(
input_ids=None, input_ids=None,
inputs_decode=None, inputs_decode=None,
@ -235,47 +185,32 @@ def tokenize_prompt_dataset(
# `inputs_decode` can be used to check whether the tokenization method is true. # `inputs_decode` can be used to check whether the tokenization method is true.
return dict( return dict(
input_ids=tokenized, input_ids=tokenized,
inputs_decode=tokenizer.decode(tokenized), inputs_decode=prompt,
seq_length=len(tokenized), seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None", seq_category=data_point["category"] if "category" in data_point else "None",
) )
def apply_rlhf_data_format( def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
template: Conversation, tokenizer: Any, context_len: int, mask_out_target_assistant_line_end=False
):
target_turn = int(len(template.messages) / 2) target_turn = int(len(template.messages) / 2)
prompt = template.get_prompt(target_turn * 2) prompt = template.get_prompt(target_turn * 2)
chunks, require_loss = split_templated_prompt_into_chunks( chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[: 2 * target_turn], prompt, template.end_of_assistant template.messages[: 2 * target_turn], prompt, template.end_of_assistant
) )
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss) # no truncation applied
loss_mask = [0] * len(tokenized) tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=int(1e10))
mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
if mask_token is None:
mask_token = 1 # If the tokenizer doesn't have eos_token or pad_token: Qwen
loss_mask = [0] * len(tokenized)
label_decode = [] label_decode = []
for start, end in zip(starts[-1:], ends[-1:]): # only the last round (chosen/rejected) is used to calculate loss
# only the last round (chosen/rejected) counts for i in range(starts[-1], ends[-1]):
if end == len(tokenized): loss_mask[i] = 1
tokenized = tokenized + [tokenizer.eos_token_id] label_decode.append(tokenizer.decode(tokenized[starts[-1] : ends[-1]], skip_special_tokens=False))
loss_mask = loss_mask + [1]
loss_mask[start:end] = [1] * len(loss_mask[start:end])
label_decode.append(tokenizer.decode(tokenized[start:end], skip_special_tokens=False))
if tokenizer.bos_token_id is not None: if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id: if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized tokenized = [tokenizer.bos_token_id] + tokenized
loss_mask = [0] + loss_mask loss_mask = [0] + loss_mask
if tokenizer.eos_token_id is not None:
# Force to add eos token at the end of the tokenized sequence
if tokenized[-1] != tokenizer.eos_token_id:
tokenized = tokenized + [tokenizer.eos_token_id]
loss_mask = loss_mask + [1]
else:
loss_mask[-1] = 1
return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode} return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode}
@ -288,7 +223,7 @@ def tokenize_rlhf(
) -> Dict[str, Union[int, str, List[int]]]: ) -> Dict[str, Union[int, str, List[int]]]:
""" """
A tokenization function to tokenize an original pretraining data point as following: A tokenization function to tokenize an original pretraining data point as following:
{"context": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}], {"context": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
"chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}} "chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}}
""" """
if ignore_index is None: if ignore_index is None:
@ -298,24 +233,17 @@ def tokenize_rlhf(
template = deepcopy(conversation_template) template = deepcopy(conversation_template)
template.clear() template.clear()
for mess in context: for idx, mess in enumerate(context):
from_str = mess["from"] if mess["from"] != template.roles[idx % 2]:
if from_str.lower() == "human": raise ValueError(
from_str = "user" f"Message should iterate between user and assistant and starts with a \
elif from_str.lower() == "assistant": line from the user. Got the following data:\n{context}"
from_str = "assistant" )
else: template.append_message(mess["from"], mess["content"])
raise ValueError(f"Unsupported role {from_str.lower()}")
if len(template.messages) > 0 and from_str == template.messages[-1]["role"]:
# Concate adjacent message from the same role
template.messages[-1]["content"] = str(template.messages[-1]["content"] + " " + mess["content"])
else:
template.append_message(from_str, mess["content"])
if len(template.messages) % 2 != 1: if len(template.messages) % 2 != 1:
warnings.warn( warnings.warn(
"Please make sure leading context starts and ends with a line from human\nLeading context: " "Please make sure leading context starts and ends with a line from user\nLeading context: "
+ str(template.messages) + str(template.messages)
) )
return dict( return dict(
@ -326,31 +254,27 @@ def tokenize_rlhf(
rejected_loss_mask=None, rejected_loss_mask=None,
rejected_label_decode=None, rejected_label_decode=None,
) )
round_of_context = int((len(template.messages) - 1) / 2)
assert context[-1]["from"].lower() == "human", "The last message in context should be from human." assert context[-1]["from"].lower() == template.roles[0], "The last message in context should be from user."
chosen = deepcopy(template) chosen = deepcopy(template)
rejected = deepcopy(template) rejected = deepcopy(template)
chosen_continuation = data_point["chosen"]
rejected_continuation = data_point["rejected"]
for round in range(len(chosen_continuation)):
if chosen_continuation[round]["from"] != template.roles[(round + 1) % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{chosen_continuation}"
)
chosen.append_message(chosen_continuation[round]["from"], chosen_continuation[round]["content"])
for round in range(len(data_point["chosen"])): for round in range(len(rejected_continuation)):
from_str = data_point["chosen"][round]["from"] if rejected_continuation[round]["from"] != template.roles[(round + 1) % 2]:
if from_str.lower() == "human": raise ValueError(
from_str = "user" f"Message should iterate between user and assistant and starts with a \
elif from_str.lower() == "assistant": line from the user. Got the following data:\n{rejected_continuation}"
from_str = "assistant" )
else: rejected.append_message(rejected_continuation[round]["from"], rejected_continuation[round]["content"])
raise ValueError(f"Unsupported role {from_str.lower()}")
chosen.append_message(from_str, data_point["chosen"][round]["content"])
for round in range(len(data_point["rejected"])):
from_str = data_point["rejected"][round]["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
rejected.append_message(from_str, data_point["rejected"][round]["content"])
( (
chosen_input_ids, chosen_input_ids,
@ -361,16 +285,14 @@ def tokenize_rlhf(
rejected_label_decode, rejected_label_decode,
) = (None, None, None, None, None, None) ) = (None, None, None, None, None, None)
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context) chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer)
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = ( (chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
chosen_data_packed["input_ids"], chosen_data_packed["input_ids"],
chosen_data_packed["loss_mask"], chosen_data_packed["loss_mask"],
chosen_data_packed["label_decode"], chosen_data_packed["label_decode"],
) )
rejected_data_packed = apply_rlhf_data_format( rejected_data_packed = apply_rlhf_data_format(rejected, tokenizer)
rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
)
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = ( (rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
rejected_data_packed["input_ids"], rejected_data_packed["input_ids"],
rejected_data_packed["loss_mask"], rejected_data_packed["loss_mask"],
@ -387,7 +309,7 @@ def tokenize_rlhf(
rejected_label_decode=None, rejected_label_decode=None,
) )
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long # Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
if chosen_loss_mask[1:].count(1) == 0 or rejected_loss_mask[1:].count(1) == 0: if chosen_loss_mask.count(1) == 0 or rejected_loss_mask.count(1) == 0:
return dict( return dict(
chosen_input_ids=None, chosen_input_ids=None,
chosen_loss_mask=None, chosen_loss_mask=None,
@ -411,14 +333,13 @@ def tokenize_kto(
data_point: Dict[str, str], data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None, conversation_template: Conversation = None,
ignore_index: int = None,
max_length: int = 4096, max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]: ) -> Dict[str, Union[int, str, List[int]]]:
""" """
Tokenize a dataset for KTO training Tokenize a dataset for KTO training
The raw input data is conversation that have the following format The raw input data is conversation that have the following format
{ {
"prompt": [{"from": "human", "content": "xxx"}...], "prompt": [{"from": "user", "content": "xxx"}...],
"completion": {"from": "assistant", "content": "xxx"}, "completion": {"from": "assistant", "content": "xxx"},
"label": true/false "label": true/false
} }
@ -427,21 +348,18 @@ def tokenize_kto(
the completion, which only contains the assistance's answer, the completion, which only contains the assistance's answer,
and a binary label, which indicates if the sample is prefered or not and a binary label, which indicates if the sample is prefered or not
""" """
if ignore_index is None:
ignore_index = IGNORE_INDEX
prompt = data_point["prompt"] prompt = data_point["prompt"]
completion = data_point["completion"] completion = data_point["completion"]
template = deepcopy(conversation_template) template = deepcopy(conversation_template)
template.clear() template.clear()
if prompt[0].get("from", None) != "human": if prompt[0].get("from", None) != "user":
raise ValueError("conversation should start with human") raise ValueError("conversation should start with user")
if completion.get("from", None) != "assistant": if completion.get("from", None) != "assistant":
raise ValueError("conversation should end with assistant") raise ValueError("conversation should end with assistant")
for mess in prompt: for mess in prompt:
if mess.get("from", None) == "human": if mess.get("from", None) == "user":
template.append_message("user", mess["content"]) template.append_message("user", mess["content"])
elif mess.get("from", None) == "assistant": elif mess.get("from", None) == "assistant":
template.append_message("assistant", mess["content"]) template.append_message("assistant", mess["content"])

View File

@ -88,7 +88,13 @@ def find_first_occurrence_subsequence(seq: torch.Tensor, subseq: torch.Tensor, s
return -1 return -1
def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], require_loss: List[bool]): def tokenize_and_concatenate(
tokenizer: PreTrainedTokenizer,
text: List[str],
require_loss: List[bool],
max_length: int,
discard_non_loss_tokens_at_tail: bool = True,
):
""" """
Tokenizes a list of texts using the provided tokenizer and concatenates the tokenized outputs. Tokenizes a list of texts using the provided tokenizer and concatenates the tokenized outputs.
@ -96,6 +102,13 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization. tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization.
text (List[str]): The list of texts to tokenize. text (List[str]): The list of texts to tokenize.
require_loss (List[bool]): A list of boolean values indicating whether each text requires loss calculation. require_loss (List[bool]): A list of boolean values indicating whether each text requires loss calculation.
max_length: used to truncate the input ids
discard_non_loss_tokens_at_tail: whether to discard the non-loss tokens at the tail
if the first round has already exeeded max length
- if the user query already exeeded max length, discard the sample
- if only the first assistant response exeeded max length, truncate the response to fit the max length
else keep the first several complete rounds of the conversations until max length is reached
Returns: Returns:
Tuple[List[int], List[int], List[int]]: A tuple containing the concatenated tokenized input ids, Tuple[List[int], List[int], List[int]]: A tuple containing the concatenated tokenized input ids,
@ -106,10 +119,17 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
loss_ends = [] loss_ends = []
for s, r in zip(text, require_loss): for s, r in zip(text, require_loss):
tokenized = tokenizer(s, add_special_tokens=False)["input_ids"] tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
if len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
if r: if r:
loss_starts.append(len(input_ids)) loss_starts.append(len(input_ids))
loss_ends.append(len(input_ids) + len(tokenized)) loss_ends.append(len(input_ids) + len(tokenized))
input_ids.extend(tokenized) input_ids.extend(tokenized)
if loss_starts[0] >= max_length:
return None, None, None
if discard_non_loss_tokens_at_tail:
input_ids = input_ids[: loss_ends[-1]]
input_ids = input_ids[:max_length]
loss_ends[-1] = min(max_length, loss_ends[-1])
return input_ids, loss_starts, loss_ends return input_ids, loss_starts, loss_ends
@ -125,6 +145,12 @@ def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: s
content_length = ( content_length = (
prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
) )
# if the tokenized content start with a leading space, we want to keep it in loss calculation
# e.g., Assistant: I am saying...
# if the tokenized content doesn't start with a leading space, we only need to keep the content in loss calculation
# e.g.,
# Assistant: # '\n' as line breaker
# I am saying...
if prompt[first_occur - 1] != " ": if prompt[first_occur - 1] != " ":
chunks.append(prompt[start_idx:first_occur]) chunks.append(prompt[start_idx:first_occur])
chunks.append(prompt[first_occur : first_occur + content_length]) chunks.append(prompt[first_occur : first_occur + content_length])

View File

@ -448,7 +448,7 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the
{"messages": {"messages":
[ [
{ {
"from": "human", "from": "user",
"content": "what are some pranks with a pen i can do?" "content": "what are some pranks with a pen i can do?"
}, },
{ {
@ -529,7 +529,7 @@ Below shows the preference dataset format used in training the reward model.
[ [
{"context": [ {"context": [
{ {
"from": "human", "from": "user",
"content": "Introduce butterflies species in Oregon." "content": "Introduce butterflies species in Oregon."
} }
] ]
@ -598,7 +598,7 @@ In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimi
#### Step 1: Data Collection #### Step 1: Data Collection
PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "human" and thus the "assistant" needs to generate a response to answer to the "human". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format. PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "user" and thus the "assistant" needs to generate a response to answer to the "user". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
```json ```json
@ -606,7 +606,7 @@ PPO uses two kinds of training data--- the prompt data and the pretrain data (op
{"messages": {"messages":
[ [
{ {
"from": "human", "from": "user",
"content": "what are some pranks with a pen i can do?" "content": "what are some pranks with a pen i can do?"
} }
... ...
@ -761,12 +761,12 @@ For KTO data preparation, please use the script [prepare_kto_dataset.sh](./examp
```jsonl ```jsonl
{ {
"prompt": [ "prompt": [
{"from": "human", "content": "What are some cuss words in english?"}, {"from": "user", "content": "What are some cuss words in english?"},
{ {
"from": "assistant", "from": "assistant",
"content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, ...", "content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, ...",
}, },
{"from": "human", "content": "What's your favorite one?"}, {"from": "user", "content": "What's your favorite one?"},
], ],
"completion": {"from": "assistant", "content": "Ass."}, # the completion must contain a single line from the assistant. "completion": {"from": "assistant", "content": "Ass."}, # the completion must contain a single line from the assistant.
"label": False, # whether the response is favorable or not "label": False, # whether the response is favorable or not

View File

@ -1 +1 @@
{"context": [{"from": "human", "content": "If I were to give you some coordinates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinates are within any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give you some details about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]} {"context": [{"from": "user", "content": "If I were to giv6rkkurjetes, could you tell me where the location is?"}, {"from": "assistant", "content": "If I were to give you some coordinates, could you tell me where the location is?"}, {"from": "user", "content": "If I were to give you 76i85l5e4henates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinatestw434yy45wn any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give45wywydfgw46 about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}

View File

@ -0,0 +1 @@
{"prompt": [{"from": "user", "content": "What are some cuss words in english?"}, {"from": "assistant", "content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, ..."},{"from": "user", "content": "What's your favorite one?"}],"completion": {"from": "assistant", "content": "Ass."},"label": false}

View File

@ -1 +1 @@
{"messages": [{"from": "human", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "human", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "human", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]} {"messages": [{"from": "user", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "user", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "user", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}

View File

@ -94,7 +94,7 @@ done
# Test DPO/PPO data Preparation # Test DPO/PPO data Preparation
for model in ${MODELS[@]}; do for model in ${MODELS[@]}; do
echo "Testing DPO/PPO data templating for $model" echo "Testing DPO/RM data templating for $model"
SAVE_DIR=$DATA_SAVE_PATH/dpo/$model SAVE_DIR=$DATA_SAVE_PATH/dpo/$model
rm -rf $SAVE_DIR/cache rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl rm -rf $SAVE_DIR/jsonl
@ -109,14 +109,44 @@ for model in ${MODELS[@]}; do
--data_arrow_output_dir $SAVE_DIR/arrow --data_arrow_output_dir $SAVE_DIR/arrow
passed=$? passed=$?
if [ $passed -ne 0 ]; then if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the DPO data templating for $model" echo "[Test]: Failed in the DPO/RM data templating for $model"
exit 1 exit 1
fi fi
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/dpo/test_dpo_data.jsonl \ python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/dpo/test_dpo_data.jsonl \
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type dpo --to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type dpo
passed=$? passed=$?
if [ $passed -ne 0 ]; then if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the DPO data templating test for $model" echo "[Test]: Failed in the DPO/RM data templating test for $model"
exit 1
fi
done
# Test KTO data Preparation
for model in ${MODELS[@]}; do
echo "Testing KTO data templating for $model"
SAVE_DIR=$DATA_SAVE_PATH/kto/$model
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
pretrain=$(get_pretrain $model)
conversation_template_config=$(get_conversation_template_config $model)
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type kto --data_input_dirs $TEST_DATA_DIR/kto \
--tokenizer_dir $pretrain \
--conversation_template_config $conversation_template_config \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow
passed=$?
if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the KTO data templating for $model"
exit 1
fi
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/kto/test_kto_data.jsonl \
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type kto
passed=$?
if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the KTO data templating test for $model"
exit 1 exit 1
fi fi
done done

View File

@ -62,3 +62,11 @@ if __name__ == "__main__":
assert any( assert any(
[rejected_lable in s for s in to_verify_lable_rejected] [rejected_lable in s for s in to_verify_lable_rejected]
), f"Rejected label {rejected_lable} not in target rejected label {to_verify_lable_chosen}" ), f"Rejected label {rejected_lable} not in target rejected label {to_verify_lable_chosen}"
elif args.data_type == "kto":
sample = data[0]
to_verify_data = to_verify_data[0]
for line in sample["prompt"]:
assert line["content"] in to_verify_data["input_id_decode"]
assert sample["completion"]["content"] in to_verify_data["input_id_decode"]
assert sample["completion"]["content"] in to_verify_data["completion_decode"]
assert sample["label"] == to_verify_data["label"]