Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-09-06 02:50:36 +00:00
fix: update train scripts and configs for other models (#1164)
* feat: falcon config
* feat: mpt config
* chore: gitignore
* refactor: step calculation
* fix: attention mask + shuffle on epoch end
* fix: return tensors
* fix: wait for everyone
* chore: config
* chore: ds config
* fix: remove ccols
* fix: logging and saving
* chore: add einops
@@ -12,7 +12,7 @@ def tokenize_inputs(config, tokenizer, examples):

     # hacky backward compatible
     different_eos = tokenizer.eos_token != "</s>"
-    out = {"labels": [], "input_ids": []}
+    out = {"labels": [], "input_ids": [], "attention_mask": []}
     for prompt, response in zip(examples["prompt"], examples["response"]):
         if different_eos:
             if response.count("</s> \n") > 0:
@@ -49,9 +49,10 @@ def tokenize_inputs(config, tokenizer, examples):
             print(response)
             raise

-        input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
+        padded = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length, return_tensors="pt")
         out["labels"].append(labels)
-        out["input_ids"].append(input_tokens)
+        out["input_ids"].append(padded["input_ids"])
+        out["attention_mask"].append(padded["attention_mask"])

     out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
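For reference, a minimal standalone sketch of the padding pattern this hunk switches to: tokenizer.pad with return_tensors="pt" returns both the padded ids and an attention mask (1 for real tokens, 0 for padding) as per-example tensors that can later be batched with torch.stack. The tokenizer and max_length below are assumptions for illustration, not the repo's configuration.

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative; not the repo's model
tokenizer.pad_token = tokenizer.eos_token          # gpt2 ships without a pad token
max_length = 8                                     # illustrative length

input_tokens = tokenizer("hello world")["input_ids"]
padded = tokenizer.pad(
    {"input_ids": input_tokens},
    padding="max_length",
    max_length=max_length,
    return_tensors="pt",
)
print(padded["input_ids"].shape)   # torch.Size([8])
print(padded["attention_mask"])    # tensor([1, 1, 0, 0, 0, 0, 0, 0])

# stacking per-example tensors yields a (batch, seq_len) mini-batch
batch = torch.stack([padded["input_ids"], padded["input_ids"]])
print(batch.shape)                 # torch.Size([2, 8])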
@@ -72,7 +73,7 @@ def load_data(config, tokenizer):
         dataset = load_dataset("json", data_files=files, split="train")

     else:
-        dataset = load_dataset(dataset_path, split="train")
+        dataset = load_dataset(dataset_path, split="train", revision=config["revision"] if "revision" in config else None)

     dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
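The revision keyword pins the Hugging Face dataset to a specific git revision (branch, tag, or commit) of its hub repo; when the key is absent the call falls back to None, the default revision. A minimal sketch of the same conditional, assuming a plain dict config and an illustrative public dataset name:

from datasets import load_dataset

config = {"seed": 42}  # "revision" may or may not be present in the training config

# dict.get("revision") is None when the key is missing, which is exactly what
# config["revision"] if "revision" in config else None evaluates to
revision = config.get("revision")
dataset = load_dataset("squad", split="train", revision=revision)  # "squad" is illustrative
print(len(dataset))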
@@ -83,19 +84,23 @@ def load_data(config, tokenizer):
     else:
         kwargs = {}

+    cols_to_keep = ["input_ids", "labels", "attention_mask"]
     # tokenize inputs and return labels and attention mask
     train_dataset = train_dataset.map(
         lambda ele: tokenize_inputs(config, tokenizer, ele),
         batched=True,
-        remove_columns=["source", "prompt"],
         **kwargs
     )
+    remove_cols = [col for col in train_dataset.column_names if col not in cols_to_keep]
+    train_dataset = train_dataset.remove_columns(remove_cols)

     val_dataset = val_dataset.map(
         lambda ele: tokenize_inputs(config, tokenizer, ele),
         batched=True,
-        remove_columns=["source", "prompt"],
         **kwargs
     )
+    remove_cols = [col for col in val_dataset.column_names if col not in cols_to_keep]
+    val_dataset = val_dataset.remove_columns(remove_cols)

     train_dataset = train_dataset.with_format("torch")
     val_dataset = val_dataset.with_format("torch")
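This hunk moves column dropping out of the map call: after tokenization, every column except input_ids, labels, and attention_mask is pruned, regardless of what the source dataset contains. A minimal sketch of that map-then-prune pattern on a toy in-memory dataset (the extra column names and the fake tokenizer are illustrative, not the repo's data):

from datasets import Dataset

ds = Dataset.from_dict({
    "prompt": ["hi", "bye"],
    "response": ["hello", "goodbye"],
    "source": ["unit-test", "unit-test"],  # an extra column we do not want to keep
})

def fake_tokenize(batch):
    # stand-in for tokenize_inputs: emit fixed-size toy features
    n = len(batch["prompt"])
    return {
        "input_ids": [[1, 2, 3]] * n,
        "labels": [[1, 2, 3]] * n,
        "attention_mask": [[1, 1, 1]] * n,
    }

cols_to_keep = ["input_ids", "labels", "attention_mask"]
ds = ds.map(fake_tokenize, batched=True)
ds = ds.remove_columns([c for c in ds.column_names if c not in cols_to_keep])
print(ds.column_names)  # ['input_ids', 'labels', 'attention_mask']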
@@ -106,12 +111,14 @@ def load_data(config, tokenizer):
         train_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
+        shuffle=True,
     )

     val_dataloader = DataLoader(
         val_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
+        shuffle=True,
     )

     return train_dataloader, val_dataloader
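With shuffle=True a PyTorch DataLoader draws a fresh permutation of the dataset at the start of every epoch, which is what the commit message's "shuffle on epoch end" refers to. A minimal sketch with a toy list-of-dicts dataset (sizes and feature values are illustrative):

import torch
from torch.utils.data import DataLoader
from transformers import DefaultDataCollator

# toy dataset: a list of feature dicts, the shape DefaultDataCollator expects
data = [{"input_ids": torch.arange(4) + i, "labels": torch.arange(4) + i} for i in range(8)]

loader = DataLoader(
    data,
    collate_fn=DefaultDataCollator(),
    batch_size=4,
    shuffle=True,  # re-shuffled at the start of every epoch
)

for epoch in range(2):
    for batch in loader:
        print(epoch, batch["input_ids"].shape)  # torch.Size([4, 4])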