From eac7734cbfcc3e45eda2c6ac92e8d35e0c237958 Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Sun, 26 Mar 2023 17:45:31 +0000
Subject: [PATCH] fix: add eos

---
 data.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/data.py b/data.py
index 6db7a3c9..ef84cc2d 100644
--- a/data.py
+++ b/data.py
@@ -15,21 +15,24 @@ def tokenize_inputs(config, tokenizer, examples):
     out = {"labels": [], "attention_mask": []}
 
     for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])):
-        # HACK to get 512 to work for now
-        input_tokens = tokenizer(prompt, truncation=True, max_length=max_length //2, return_tensors="pt")["input_ids"].squeeze()
+        input_tokens = tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze()
         input_len = len(input_tokens)
 
         # plus one since we remove bos from response
-        remaining_tokens = max_length - input_len - len(newline_tokens) + 1
-
+        # but we subtract one since we want to add eos token
+        remaining_tokens = max_length - input_len - len(newline_tokens)
+        # remove bos
         target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
 
         input_ids[i, :input_len] = input_tokens
         # add newline between prompt and response
         newline_plus_inputs = input_len + len(newline_tokens)
         input_ids[i, input_len: newline_plus_inputs] = newline_tokens
 
+        # add target tokens, remove bos
         input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
+        # add eos token, enforce stopping
+        input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
 
         labels = input_ids[i].clone()
         labels[: newline_plus_inputs] = -100
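
The point of dropping the "+ 1" is that a long response could otherwise fill the row up to max_length, leaving no free position for an end-of-sequence token; with the new budget, newline_plus_inputs + len(target_tokens) is at most max_length - 1, so tokenizer.eos_token_id always fits. Below is a minimal standalone sketch of that layout; every value (window size, token ids, the newline encoding) is a made-up stand-in for what data.py actually derives from the tokenizer, the config, and the preallocated input_ids batch.

    import torch

    # Illustration only: all ids and sizes below are hypothetical stand-ins.
    max_length = 16
    eos_token_id = 2                             # hypothetical eos id
    newline_tokens = torch.tensor([13])          # hypothetical encoding of "\n"
    input_tokens = torch.tensor([5, 6, 7, 8])    # prompt, already capped at max_length // 2
    response_with_bos = torch.tensor([1, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30])

    input_len = len(input_tokens)

    # New budget: no "+ 1", so the last slot of the row stays free for eos.
    remaining_tokens = max_length - input_len - len(newline_tokens)
    target_tokens = response_with_bos[:remaining_tokens][1:]   # truncate, then drop bos

    input_ids = torch.zeros(max_length, dtype=torch.long)
    input_ids[:input_len] = input_tokens
    newline_plus_inputs = input_len + len(newline_tokens)
    input_ids[input_len:newline_plus_inputs] = newline_tokens
    input_ids[newline_plus_inputs:newline_plus_inputs + len(target_tokens)] = target_tokens
    # Worst case this index is max_length - 1, so the write is always in bounds.
    input_ids[newline_plus_inputs + len(target_tokens)] = eos_token_id

    print(input_ids.tolist())
    # [5, 6, 7, 8, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 2]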