fix: update train scripts and configs for other models (#1164)

* feat: falcon config

* feat: mpt config

* chore: gitignore

* refactor: step calculation

* fix: attention mask + shuffle on epoch end

* fix: return tensors

* fix: wait for everyone

* chore: config

* chore: ds config

* fix: remove ccols

* fix: logging and saving

* chore: add einops
This commit is contained in:
Zach Nussbaum
2023-07-12 15:18:24 -04:00
committed by GitHub
parent e8b19b8e82
commit 6c4f449b7a
9 changed files with 245 additions and 29 deletions

View File

@@ -12,7 +12,7 @@ def tokenize_inputs(config, tokenizer, examples):
# hacky backward compatible
different_eos = tokenizer.eos_token != "</s>"
out = {"labels": [], "input_ids": []}
out = {"labels": [], "input_ids": [], "attention_mask": []}
for prompt, response in zip(examples["prompt"], examples["response"]):
if different_eos:
if response.count("</s> \n") > 0:
@@ -49,9 +49,10 @@ def tokenize_inputs(config, tokenizer, examples):
print(response)
raise
input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
padded = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length, return_tensors="pt")
out["labels"].append(labels)
out["input_ids"].append(input_tokens)
out["input_ids"].append(padded["input_ids"])
out["attention_mask"].append(padded["attention_mask"])
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
@@ -72,7 +73,7 @@ def load_data(config, tokenizer):
dataset = load_dataset("json", data_files=files, split="train")
else:
dataset = load_dataset(dataset_path, split="train")
dataset = load_dataset(dataset_path, split="train", revision=config["revision"] if "revision" in config else None)
dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
@@ -83,19 +84,23 @@ def load_data(config, tokenizer):
else:
kwargs = {}
cols_to_keep = ["input_ids", "labels", "attention_mask"]
# tokenize inputs and return labels and attention mask
train_dataset = train_dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele),
batched=True,
remove_columns=["source", "prompt"],
**kwargs
)
remove_cols = [col for col in train_dataset.column_names if col not in cols_to_keep]
train_dataset = train_dataset.remove_columns(remove_cols)
val_dataset = val_dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele),
batched=True,
remove_columns=["source", "prompt"],
**kwargs
)
remove_cols = [col for col in val_dataset.column_names if col not in cols_to_keep]
val_dataset = val_dataset.remove_columns(remove_cols)
train_dataset = train_dataset.with_format("torch")
val_dataset = val_dataset.with_format("torch")
@@ -106,12 +111,14 @@ def load_data(config, tokenizer):
train_dataset,
collate_fn=DefaultDataCollator(),
batch_size=config["batch_size"],
shuffle=True,
)
val_dataloader = DataLoader(
val_dataset,
collate_fn=DefaultDataCollator(),
batch_size=config["batch_size"],
shuffle=True,
)
return train_dataloader, val_dataloader