Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-06-25 15:02:03 +00:00)
fix: add eos

This commit is contained in:
parent 2daecd6066
commit eac7734cbf

data.py: 11 changed lines (7 additions, 4 deletions)
@@ -15,21 +15,24 @@ def tokenize_inputs(config, tokenizer, examples):
     out = {"labels": [], "attention_mask": []}
     for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])):
-        # HACK to get 512 to work for now
-        input_tokens = tokenizer(prompt, truncation=True, max_length=max_length //2, return_tensors="pt")["input_ids"].squeeze()
+        input_tokens = tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze()
         input_len = len(input_tokens)
 
         # plus one since we remove bos from response
-        remaining_tokens = max_length - input_len - len(newline_tokens) + 1
-
+        # but we subtract one since we want to add eos token
+        remaining_tokens = max_length - input_len - len(newline_tokens)
+        # remove bos
         target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
 
         input_ids[i, :input_len] = input_tokens
         # add newline between prompt and response
         newline_plus_inputs = input_len + len(newline_tokens)
         input_ids[i, input_len: newline_plus_inputs] = newline_tokens
 
         # add target tokens, remove bos
         input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
+
+        # add eos token, enforce stopping
+        input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
 
         labels = input_ids[i].clone()
         labels[: newline_plus_inputs] = -100
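The dropped "+ 1" is the substance of the fix: the slot freed by stripping the response's bos token is now spent on the appended eos token, so the length budget must not be extended by one. A minimal sketch of the arithmetic, using hypothetical lengths (a 512-token budget, a 200-token prompt, a one-token newline) rather than anything taken from the repo:

# Hypothetical lengths to check the token budget (not from the repo).
max_length = 512
input_len = 200      # len(input_tokens) after truncating the prompt
newline_len = 1      # len(newline_tokens)

# New formula (no "+ 1"): the position freed by stripping the response's
# bos token is reused for the appended eos token.
remaining_tokens = max_length - input_len - newline_len      # 311

# The response is truncated to remaining_tokens ids, then bos is dropped.
target_len = remaining_tokens - 1                            # 310

# prompt + newline + response + eos fills the row exactly.
assert input_len + newline_len + target_len + 1 == max_length

# With the old "+ 1", a full-length response plus the new eos token
# would have spilled one position past the fixed-width input_ids row.
old_remaining = max_length - input_len - newline_len + 1     # 312
old_target_len = old_remaining - 1                           # 311
assert input_len + newline_len + old_target_len + 1 == max_length + 1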
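For context on the unchanged tail of the hunk: filling the prompt and newline positions of labels with -100 masks them out of the loss, because PyTorch's cross-entropy ignores that index by default, so only response tokens (now including eos) are trained on. A small self-contained illustration:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                 # 4 positions, vocab size 10
labels = torch.tensor([-100, -100, 3, 7])   # first two positions masked

# F.cross_entropy uses ignore_index=-100 by default, so the masked
# positions contribute nothing to the loss or its gradients.
loss = F.cross_entropy(logits, labels)
print(loss)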
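The "enforce stopping" comment points at inference-time behavior: once every training target ends in eos, the fine-tuned model learns to emit eos after its response, and standard generation loops halt on it. A hedged sketch using Hugging Face transformers, where "gpt2" is only a placeholder checkpoint, not the model this repo trains:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "What is a tokenizer?\n"
inputs = tokenizer(prompt, return_tensors="pt")

# generate() ends a sequence as soon as eos_token_id is produced, which is
# exactly what training on eos-terminated targets teaches the model to do.
output = model.generate(
    **inputs,
    max_new_tokens=64,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))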