diff --git a/data.py b/data.py index 0cff50c7..db322793 100644 --- a/data.py +++ b/data.py @@ -70,8 +70,7 @@ def load_data(config, tokenizer): else: dataset = load_dataset(dataset_path) - uuids = dataset.filter(lambda x: x["source"] == "nomic") - dataset = dataset.filter(lambda x: x["source"] != "nomic") + uuids = load_dataset("json", data_files="watermark.jsonl", split="train") dataset = dataset.train_test_split(test_size=.05, seed=config["seed"]) train_dataset, val_dataset = dataset["train"], dataset["test"]