[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
authored by pre-commit-ci[bot], 2024-05-28 08:02:42 +00:00
committed by YeAnbang
parent b1031f7244
commit 1b880ce095
26 changed files with 69 additions and 57 deletions


@@ -25,7 +25,9 @@ class Conversation:
         Setup the conversation template from config
         """
         tokenizer.chat_template = config["chat_template"]
-        conv = cls(tokenizer, config["system_message"], config["chat_template"], config["stop_ids"], config["end_of_assistant"])
+        conv = cls(
+            tokenizer, config["system_message"], config["chat_template"], config["stop_ids"], config["end_of_assistant"]
+        )
         conv.clear()
         return conv
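
As context for the reflowed call above: a Conversation is built from a template config whose keys map one-to-one onto the constructor arguments. A minimal self-contained sketch with a stand-in class; the config values below are illustrative, not taken from the repo:

class Conversation:
    # Stand-in mirroring the five constructor arguments read from config.
    def __init__(self, tokenizer, system_message, chat_template, stop_ids, end_of_assistant):
        self.tokenizer = tokenizer
        self.system_message = system_message
        self.chat_template = chat_template
        self.stop_ids = stop_ids
        self.end_of_assistant = end_of_assistant
        self.messages = []

    def clear(self):
        self.messages = []

config = {
    "chat_template": "{{ system_message }}{% for m in messages %}{{ m['content'] }}{% endfor %}",
    "system_message": "You are a helpful assistant.",
    "stop_ids": [2],
    "end_of_assistant": "</s>",
}

conv = Conversation(
    None, config["system_message"], config["chat_template"], config["stop_ids"], config["end_of_assistant"]
)
conv.clear()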


@@ -97,8 +97,9 @@ def supervised_tokenize_sft(
     target_turn = turns[target_turn_index - 1]
     prompt = template.get_prompt(2 * target_turn)
-    chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt,
-                                                              conversation_template.end_of_assistant)
+    chunks, require_loss = split_templated_prompt_into_chunks(
+        template.messages[: 2 * target_turn], prompt, conversation_template.end_of_assistant
+    )
     tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)

     labels = [ignore_index] * len(tokenized)
@@ -106,7 +107,7 @@ def supervised_tokenize_sft(
         if end == len(tokenized):
             tokenized = tokenized + [tokenizer.eos_token_id]
             labels = labels + [ignore_index]
-        labels[start : end] = tokenized[start : end]
+        labels[start:end] = tokenized[start:end]

     # truncate the sequence at the last token that requires loss calculation
     to_truncate_len = 0
@@ -139,14 +140,14 @@ def supervised_tokenize_sft(
         label_decode = []
         for i in range(len(labels)):
             if labels[i] == ignore_index:
-                if start!=end:
-                    label_decode.append(tokenizer.decode(labels[start+1:i], skip_special_tokens=False))
+                if start != end:
+                    label_decode.append(tokenizer.decode(labels[start + 1 : i], skip_special_tokens=False))
                 start = i
                 end = i
             else:
                 end = i
                 if i == len(labels) - 1:
-                    label_decode.append(tokenizer.decode(labels[start+1:], skip_special_tokens=False))
+                    label_decode.append(tokenizer.decode(labels[start + 1 :], skip_special_tokens=False))
     except TypeError as e:
         raise TypeError(str(e) + f"\nUnable to decode input_ids: {tokenized}")
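
The slicing idiom these hunks normalize implements SFT label masking: labels start as ignore_index everywhere, and only the spans returned by tokenize_and_concatenate are copied over, so loss is computed on assistant tokens alone. A runnable toy version, with token ids and span positions invented for illustration:

ignore_index = -100  # sentinel that common cross-entropy implementations skip

tokenized = [101, 7592, 2088, 102, 2023, 2003, 1037, 3231]
starts, ends = [4], [8]  # pretend positions 4..7 hold the assistant reply

labels = [ignore_index] * len(tokenized)
for start, end in zip(starts, ends):
    labels[start:end] = tokenized[start:end]  # the line reformatted above

print(labels)  # [-100, -100, -100, -100, 2023, 2003, 1037, 3231]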
@@ -216,8 +217,9 @@ def tokenize_prompt_dataset(
     # Prepare data
     prompt = template.get_prompt(target_turn, add_generation_prompt=True)
-    chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: target_turn], prompt,
-                                                              conversation_template.end_of_assistant)
+    chunks, require_loss = split_templated_prompt_into_chunks(
+        template.messages[:target_turn], prompt, conversation_template.end_of_assistant
+    )
     tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
     if tokenizer.bos_token_id is not None:
         if tokenized[0] != tokenizer.bos_token_id:
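
The trailing context lines of this hunk also show a guard for a missing beginning-of-sequence token: if the tokenizer defines a BOS id and the sequence does not start with it, one is prepended. In isolation (the id 1 is illustrative):

bos_token_id = 1  # e.g. LLaMA-style tokenizers use 1 for BOS
tokenized = [15, 27, 42]
if bos_token_id is not None and tokenized[0] != bos_token_id:
    tokenized = [bos_token_id] + tokenized
print(tokenized)  # [1, 15, 27, 42]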
@@ -246,8 +248,9 @@ def apply_rlhf_data_format(
 ):
     target_turn = int(len(template.messages) / 2)
     prompt = template.get_prompt(target_turn * 2)
-    chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt,
-                                                              template.end_of_assistant)
+    chunks, require_loss = split_templated_prompt_into_chunks(
+        template.messages[: 2 * target_turn], prompt, template.end_of_assistant
+    )
     tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
     loss_mask = [0] * len(tokenized)
     mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
@@ -260,8 +263,8 @@ def apply_rlhf_data_format(
         if end == len(tokenized):
             tokenized = tokenized + [tokenizer.eos_token_id]
             loss_mask = loss_mask + [1]
-        loss_mask[start : end] = [1] * len(loss_mask[start : end])
-        label_decode.append(tokenizer.decode(tokenized[start : end], skip_special_tokens=False))
+        loss_mask[start:end] = [1] * len(loss_mask[start:end])
+        label_decode.append(tokenizer.decode(tokenized[start:end], skip_special_tokens=False))
     if tokenizer.bos_token_id is not None:
         if tokenized[0] != tokenizer.bos_token_id:
             tokenized = [tokenizer.bos_token_id] + tokenized
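
The RLHF path uses the same span bookkeeping but writes a 0/1 mask instead of copying label ids. A self-contained toy, with ids and spans invented:

tokenized = [1, 15, 27, 42, 8, 99, 3]
starts, ends = [3], [7]  # pretend positions 3..6 are the response

loss_mask = [0] * len(tokenized)
for start, end in zip(starts, ends):
    loss_mask[start:end] = [1] * len(loss_mask[start:end])  # as reformatted above

print(loss_mask)  # [0, 0, 0, 1, 1, 1, 1]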


@@ -121,8 +121,10 @@ def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: s
     for line in messages:
         content_length = len(line["content"])
         first_occur = prompt.find(line["content"], start_idx)
-        if line["role"].lower() == "assistant" and end_of_assistant in prompt[first_occur + content_length:]:
-            content_length = prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
+        if line["role"].lower() == "assistant" and end_of_assistant in prompt[first_occur + content_length :]:
+            content_length = (
+                prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
+            )
         if prompt[first_occur - 1] != " ":
             chunks.append(prompt[start_idx:first_occur])
         chunks.append(prompt[first_occur : first_occur + content_length])
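
What this function does, reconstructed in miniature: it locates each message's content inside the fully rendered prompt, and for assistant turns extends the span through the end_of_assistant marker so the closing tokens stay attached to the turn. A runnable toy with an invented template format:

prompt = "<user>Hi</user><assistant>Hello!</assistant>"
content = "Hello!"
end_of_assistant = "</assistant>"

first_occur = prompt.find(content)
content_length = len(content)
# Assistant turns absorb everything up to and including the end marker.
if end_of_assistant in prompt[first_occur + content_length :]:
    content_length = (
        prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
    )

print(prompt[first_occur : first_occur + content_length])  # Hello!</assistant>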


@@ -37,4 +37,4 @@ class Critic(BaseModel):
         return self.model.get_input_embeddings()

     def get_output_embeddings(self):
-        return self.model.get_output_embeddings()
\ No newline at end of file
+        return self.model.get_output_embeddings()
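
The hunk itself only restores the file's final newline, but the method it touches illustrates the delegation pattern in play: the Critic exposes the wrapped backbone's embedding tables rather than owning its own. A stand-alone sketch with a stub backbone (names here are stand-ins, not the repo's classes):

class BackboneStub:
    # Stand-in for the wrapped transformer model.
    def get_input_embeddings(self):
        return "input-embedding-table"

    def get_output_embeddings(self):
        return "output-embedding-table"

class Critic:
    def __init__(self, model):
        self.model = model

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def get_output_embeddings(self):
        return self.model.get_output_embeddings()

print(Critic(BackboneStub()).get_output_embeddings())  # output-embedding-table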