mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-06-26 15:31:55 +00:00
fix: data for inference
This commit is contained in:
parent
fb9ff9c40d
commit
1b14b1f723
48
data.py
48
data.py
@ -75,7 +75,7 @@ def load_data(config, tokenizer):
|
||||
dataset = load_dataset("json", data_files=files, split="train")
|
||||
|
||||
else:
|
||||
dataset = load_dataset(dataset_path)
|
||||
dataset = load_dataset(dataset_path, split="train")
|
||||
|
||||
dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
|
||||
|
||||
@ -118,3 +118,49 @@ def load_data(config, tokenizer):
|
||||
)
|
||||
|
||||
return train_dataloader, val_dataloader
|
||||
|
||||
|
||||
def load_data_for_inference(config, tokenizer):
|
||||
dataset_path = config["dataset_path"]
|
||||
|
||||
if os.path.exists(dataset_path):
|
||||
# check if path is a directory
|
||||
if os.path.isdir(dataset_path):
|
||||
files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
|
||||
else:
|
||||
files = [dataset_path]
|
||||
|
||||
print(f"Reading files {files}")
|
||||
|
||||
dataset = load_dataset("json", data_files=files, split="train")
|
||||
|
||||
else:
|
||||
dataset = load_dataset(dataset_path, split="train")
|
||||
|
||||
dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
|
||||
|
||||
train_dataset, val_dataset = dataset["train"], dataset["test"]
|
||||
|
||||
train_dataset = train_dataset.add_column("index", list(range(len(train_dataset))))
|
||||
val_dataset = val_dataset.add_column("index", list(range(len(val_dataset))))
|
||||
|
||||
if config["streaming"] is False:
|
||||
kwargs = {"num_proc": config["num_proc"]}
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
# tokenize inputs and return labels and attention mask
|
||||
train_dataset = train_dataset.map(
|
||||
lambda ele: tokenize_inputs(config, tokenizer, ele),
|
||||
batched=True,
|
||||
**kwargs
|
||||
)
|
||||
val_dataset = val_dataset.map(
|
||||
lambda ele: tokenize_inputs(config, tokenizer, ele),
|
||||
batched=True,
|
||||
**kwargs
|
||||
)
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
val_dataset = val_dataset.with_format("torch")
|
||||
|
||||
return train_dataset, val_dataset
|
||||
|
@ -9,4 +9,5 @@ peft
|
||||
nodelist-inflator
|
||||
deepspeed
|
||||
sentencepiece
|
||||
jsonlines
|
||||
jsonlines
|
||||
nomic
|
Loading…
Reference in New Issue
Block a user