diff --git a/data.py b/data.py
index e77f76f0..7e103dc1 100644
--- a/data.py
+++ b/data.py
@@ -106,13 +106,13 @@ def load_data(config, tokenizer):
     train_dataset = train_dataset.map(
         lambda ele: tokenize_inputs(config, tokenizer, ele),
         batched=True,
-        remove_columns=["source", "prompt"],
+        remove_columns=["source", "prompt", "id", "response"],
         **kwargs
     )
     val_dataset = val_dataset.map(
         lambda ele: tokenize_inputs(config, tokenizer, ele),
         batched=True,
-        remove_columns=["source", "prompt"],
+        remove_columns=["source", "prompt", "id", "response"],
         **kwargs
     )