Replace the customized dataloader setup with the built-in one

Author: YeAnbang
Date: 2024-06-07 09:43:42 +00:00
parent 790e1362a6
commit 0d7ff10ea5
12 changed files with 79 additions and 218 deletions
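
The "built-in one" in the title presumably means preparing the dataloader through the booster plugin's own helper rather than through a hand-rolled distributed dataloader; the touched training scripts are not shown in this excerpt. A minimal sketch of that pattern, assuming ColossalAI's `plugin.prepare_dataloader` API, a torchrun launch, and a dummy dataset standing in for the real tokenized SFT data:

```python
# Sketch only (not the commit's actual scripts): let the plugin build the
# distributed dataloader instead of wiring up a sampler by hand.
import colossalai
import torch
from colossalai.booster.plugin import TorchDDPPlugin
from torch.utils.data import TensorDataset

colossalai.launch_from_torch()  # older versions take a config argument

dataset = TensorDataset(torch.arange(64).reshape(16, 4))  # placeholder data
plugin = TorchDDPPlugin()
train_dataloader = plugin.prepare_dataloader(
    dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
)

for (batch,) in train_dataloader:
    pass  # training step would go here
```

Run it under torchrun (e.g. `torchrun --nproc_per_node=1 script.py`) so the distributed environment the plugin expects is initialized.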

@@ -55,8 +55,6 @@ def supervised_tokenize_sft(
     for mess in messages:
         from_str = mess["from"]
-        if from_str is None:
-            print(mess)
         if from_str.lower() == "human":
             from_str = "user"
         elif from_str.lower() == "assistant":
@@ -133,24 +131,20 @@ def supervised_tokenize_sft(
             labels[-1] = tokenizer.eos_token_id
-    # For some model without bos/eos may raise the following errors
-    try:
-        inputs_decode = tokenizer.decode(tokenized)
-        start = 0
-        end = 0
-        label_decode = []
-        for i in range(len(labels)):
-            if labels[i] == ignore_index:
-                if start != end:
-                    label_decode.append(tokenizer.decode(labels[start + 1 : i], skip_special_tokens=False))
-                start = i
-                end = i
-            else:
-                end = i
-                if i == len(labels) - 1:
-                    label_decode.append(tokenizer.decode(labels[start + 1 :], skip_special_tokens=False))
-    except TypeError as e:
-        raise TypeError(str(e) + f"\nUnable to decode input_ids: {tokenized}")
+    inputs_decode = tokenizer.decode(tokenized)
+    start = 0
+    end = 0
+    label_decode = []
+    for i in range(len(labels)):
+        if labels[i] == ignore_index:
+            if start != end:
+                label_decode.append(tokenizer.decode(labels[start + 1 : i], skip_special_tokens=False))
+            start = i
+            end = i
+        else:
+            end = i
+            if i == len(labels) - 1:
+                label_decode.append(tokenizer.decode(labels[start + 1 :], skip_special_tokens=False))
     # Check if all labels are ignored, this may happen when the tokenized length is too long
     if labels.count(ignore_index) == len(labels):
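
The decode logic that survives (now without the try/except guard) walks labels once, collects each contiguous run of non-ignored positions, and decodes it for inspection. A self-contained toy run of the same loop, with EchoTokenizer as a made-up stand-in whose decode() just joins token ids:

```python
# Toy illustration of the label-span decoding above. EchoTokenizer is a
# hypothetical stand-in; -100 plays the role of ignore_index.
IGNORE_INDEX = -100


class EchoTokenizer:
    def decode(self, ids, skip_special_tokens=False):
        return " ".join(str(i) for i in ids)


tokenizer = EchoTokenizer()
labels = [-100, -100, 11, 12, 13, -100, 21, 22]  # two supervised spans

start = 0
end = 0
label_decode = []
for i in range(len(labels)):
    if labels[i] == IGNORE_INDEX:
        if start != end:
            label_decode.append(tokenizer.decode(labels[start + 1 : i], skip_special_tokens=False))
        start = i
        end = i
    else:
        end = i
        if i == len(labels) - 1:
            label_decode.append(tokenizer.decode(labels[start + 1 :], skip_special_tokens=False))

print(label_decode)  # ['11 12 13', '21 22']
```

The `start + 1` offset skips the ignore_index position that `start` is left pointing at, so each decoded string covers exactly one supervised span.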