From 65ec606f21e4ba708ceb7a1d301f3959a7b85b4c Mon Sep 17 00:00:00 2001
From: Zach
Date: Tue, 4 Apr 2023 22:01:55 +0000
Subject: [PATCH] fix: prompt len for larger

---
 data.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/data.py b/data.py
index 0de26cfa..358dd007 100644
--- a/data.py
+++ b/data.py
@@ -9,7 +9,6 @@ from transformers import DefaultDataCollator
 
 def tokenize_inputs(config, tokenizer, examples):
     max_length = config["max_length"]
-    input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
     # ignore bos
     newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0]
     if newline_tokens[0] == tokenizer.bos_token_id:
@@ -29,6 +28,7 @@ def tokenize_inputs(config, tokenizer, examples):
         # we need to include some labels
         if prompt_len >= max_length - 1:
             prompt = prompt[:len(prompt) // 2]
+            prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])
 
         input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
                                  truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
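
For context, a minimal sketch of the stale-length behavior this one-line change addresses. It is not the repo's exact code: the "gpt2" tokenizer is only an illustrative choice, and the assumption is that prompt_len is consumed downstream (e.g. to mask prompt tokens out of the loss), so it must reflect the prompt as it exists after the halving truncation, not before.

    # Sketch of the bug the patch fixes; assumes any standard transformers tokenizer.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
    max_length = 16

    prompt = "word " * 40  # tokenizes to well over max_length
    prompt_len = len(tokenizer(prompt, return_tensors="pt")["input_ids"][0])

    if prompt_len >= max_length - 1:
        prompt = prompt[:len(prompt) // 2]
        # Before this patch, prompt_len kept its pre-truncation value here,
        # so any length arithmetic done later used the wrong prompt length.
        # The patch re-tokenizes the halved prompt to refresh it:
        prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])

    print(prompt_len)  # now matches the truncated prompt

Running the sketch without the final re-tokenization leaves prompt_len at roughly twice the true token count of the halved prompt, which is the mismatch the added line eliminates.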