mono repo structure

This commit is contained in:
Richard Guo
2023-05-01 15:45:23 -04:00
parent b00d338c1e
commit 02d1bdb0be
53 changed files with 20 additions and 254 deletions

View File

@@ -0,0 +1,17 @@
# Inference on Training Data
## Run Inference
```bash
torchrun --master_port=29085 --nproc-per-node 8 inference.py --config=configs/inference/gptj.yaml
```
## Visualizations
```bash
python build_map.py
```
will build a map in `Atlas`, one using the internal clustering algorithm provided by Nomic and one using the embeddings generated by the finetuned model.

View File

@@ -0,0 +1,283 @@
# Training Trials and Tribulations of gpt4all
This is a training log for both the LoRa and full model training we underwent
## Inspiration
## Initial Experiment
We train an initial LoRa model on ~700k examples including data from P3/BLOOM, StackOverflow, and unified_chip2.
We trained using 8 x A100 80GB GPUs.
We used the initial parameters:
| Hyperparameter | Value |
|----------------|-------|
| Per Device BS | 4 |
| Global BS | 32 |
| Learning rate | 5e-5 |
| Epochs | 1 |
| Max length | 1024 |
| Weight decay | 0 |
| Warmup Steps | 100 |
We randomly shuffle and set aside 5% of the data for validation.
We had an initial bug in logging the training loss but we noticed a decrease in validation loss.
![](figs/first_lora.png)
During generation, we noticed that our model seemed to never generate an `eos` token. We realized that this was due to the fact that we set
the tokenizer `pad` token equal to the `eos` token and most likely never learned to "stop". However, generations seemed to be working
well, even with greedy generation.
For example, when prompted with `Generate a python script to reverse a string.`, the model outputs valid markdown and a working script
Here's a Python script to reverse a string:
```python
def reverse_string(string):
return ''.join(reversed(string))
```
This function takes a string as an argument and returns the reversed string. The `reversed()` function reverses the string and returns a list of characters. The `join()` function joins the list of characters into a string.
You can use this function like this:
```python
string = 'hello world'
print(reverse_string(string))
```
This will output `'olleh dna hwl'`.
Note that this function does not handle strings with spaces or other special characters. If you need to handle those cases, you can use a regular expression to split the string into words and then reverse each word.
```python
import re
def reverse_string(string):
words = re.split(' ', string)
return ''.join(reversed(word) for word in words)
```
This function takes a string as an argument and returns the reversed string. The `re.split()` function splits the string into words using a space as the delimiter. The `reversed()` function reverses each word and returns a list of characters. The `join()` function joins the list of characters into a string.
You can use this function like this:
```python
string = 'hello world'
print(reverse_string(string))
```
This will output `'olleh dna hwl'`.
Note that this function also handles strings with spaces or other special characters.
Answer: Here's a Python script to reverse a string:
```python
def reverse_string(string):
return ''.join(reversed(string))
```
This function takes a string as an argument and returns the reversed string. The `reversed()` function reverses the string and returns a list of characters. The `join()` function joins the list of characters into a string.
You can use this function like this:
```python
string = 'hello world'
print(reverse_string(string))
```
This will output `'olleh dna hwl'`.
Note that this function does not handle strings with
although it duplicates and continues generating past answering the full question.
The base model doesn't answer the question and hallucinates
```
Write a python script to reverse a string. The string should be reversed in place.
The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be
```
## EOS and Accidental Duplication
Seeing as the model never stopped, we experimented with training a model with a separate token for `eos` and `pad`. Since we couldn't find a `pad` token present in the vocabulary, we added one to the tokenizer and expanded the embedding size of the model (from 32,000). In theory, we could have expanded the embedding size by a factor of 64 to improve throughput and performance, as [noted here](https://twitter.com/ctnzr/status/1623758178587648000?s=20).
For every sequence, we appended an `eos` token in hopes the model would learn to exit.
We successfully trained a model using the same parameters as before ![](figs/duplicate_loss.png)
During generation, our model exited early even with greedy generations
You can use the `reversed()` function to reverse a string in Python. Here's an example:
```python
string = 'hello world'
reversed_string = reversed(string)
print(reversed_string)
```
Output:
```
world hello
```
For a harder prompt where we try to trick the model with comments
```python
#this code prints a string reversed
my_string = "hello how are you"
print(len(my_string))
My code above does not work. Can you help me?
```
The model correctly answers
The code you provided does not print a string reversed. It prints the length of the string "hello how are you".
We realized that we had two bugs however:
- We accidentally duplicated data and effectively trained for 2 epochs instead of 1
- We added an eos token to every sequence, even those that we truncated (e.g. long code that exceeds the 1024).
## Conditional EOS and 1 Epoch
Using the same parameters, we then trained a model using a "conditional" eos token where we only add an `eos` when the inputs are less than the maximum sequence length for one epoch.
Our training and validation loss looks fairly normal
![](figs/single_epoch.png)
However, we found generations to be slightly poorer. For the same prompt of `Generate a python script to reverse a string.`, the model generates
Write a python script to reverse a string.
Here is a python script to reverse a string:
import string
def reverse_string(string):
return string[::-1]
print(reverse_string('hello world'))
Output:
world hello
The prompt
```python
#this code prints a string reversed
my_string = "hello how are you"
print(len(my_string))
My code above does not work. Can you help me?
```
does not generate any text.
And the prompt `"Generate a python script to make a get request to an api endpoint."` generates
I'm sorry, I cannot provide a specific answer to this question as it requires more context and details about the API endpoint and the specific task you are trying to accomplish. Can you please provide more information?
## Multi Epoch and Full Model Training
We decided to remove the entire Bigscience/P3 subset from the final training dataset due to data diversity considerations.
P3 contains many homogeneous prompts which produce short and homogeneous responses from GPT-3.5-Turbo.
The final dataset is ~400k examples.
We train a LoRa model using the parameters
| Hyperparameter | Value |
|----------------|-------|
| Per Device BS | 4 |
| Global BS | 32 |
| Learning rate | 5e-5 |
| Epochs | 4 |
| Max length | 1024 |
| Weight decay | 0 |
| Warmup Steps | 100 |
We additionally train a full model
| Hyperparameter | Value |
|----------------|-------|
| Per Device BS | 32 |
| Global BS | 256 |
| Learning rate | 5e-5 |
| Epochs | 2 |
| Max length | 1024 |
| Weight decay | 0 |
| Warmup Steps | 100 |
Taking inspiration from [the Alpaca Repo](https://github.com/tatsu-lab/stanford_alpaca), we roughly scale the learning rate by `sqrt(k)`, where `k` is the increase in batch size, where Alpaca used a batch size of 128 and learning rate of 2e-5.
Comparing our model LoRa to the [Alpaca LoRa](https://huggingface.co/tloen/alpaca-lora-7b), our model has lower perplexity. Qualitatively, training on 3 epochs performed the best on perplexity as well as qualitative examples.
We tried training a full model using the parameters above, but found that during the second epoch the model diverged and samples generated post training were worse than the first epoch.
## GPT-J Training
### Model Training Divergence
We trained multiple [GPT-J models](https://huggingface.co/EleutherAI/gpt-j-6b) with varying success. We found that training the full model lead to diverged post epoch 1. ![](figs/overfit-gpt-j.png)
We release the checkpoint after epoch 1.
Using Atlas, we extracted the embeddings of each point in the dataset and calculated the loss per sequence. We then uploaded [this to Atlas](https://atlas.nomic.ai/map/gpt4all-j-post-epoch-1-embeddings) and noticed that the higher loss items seem to cluster. On further inspection, the highest density clusters seemded to be of prompt/response pairs that asked for creative-like generations such as `Generate a story about ...` ![](figs/clustering_overfit.png)
### GPT4All-J Hyperparameters
We varied learning rate, learning rate schedule, and weight decay following suggestions from the [original GPT-J codebase](https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md) but found no real performance difference (qualitatively or quantitatively) when varying these parameters.
The final model was trained using the following hyperparameters with a linear warmup followed by constant learning rate:
| Hyperparameter | Value |
|----------------|-------|
| Per Device BS | 32 |
| Global BS | 256 |
| Learning rate | 2e-5 |
| Epochs | 2 |
| Max length | 1024 |
| Weight decay | 0 |
| Warmup Steps | 500 |
The LoRA model was trained using using the following hyperparameters with a linear warmup followed by constant learning rate:
| Hyperparameter | Value |
|----------------|-------|
| Per Device BS | 4 |
| Global BS | 32 |
| Learning rate | 2e-5 |
| Epochs | 2 |
| Max length | 1024 |
| Weight decay | 0 |
| Warmup Steps | 500 |

View File

@@ -0,0 +1,54 @@
import numpy as np
from nomic import atlas
import glob
from tqdm import tqdm
from datasets import load_dataset, concatenate_datasets
from sklearn.decomposition import PCA
files = glob.glob("inference/*.jsonl")
print(files)
df = concatenate_datasets([load_dataset("json", data_files=file, split="train") for file in tqdm(files)])
print(len(df))
print(df)
df = df.map(lambda example: {"inputs": [prompt + "\n" + response for prompt, response in zip(example["prompt"], example["response"])]},
batched=True,
num_proc=64)
df = df.map(lambda example: {"trained_on": [int(t) for t in example["is_train"]]},
batched=True,
num_proc=64)
df = df.remove_columns("is_train")
text = df.remove_columns(["labels", "input_ids", "embeddings"])
text_df = [text[i] for i in range(len(text))]
atlas.map_text(text_df, indexed_field="inputs",
name="CHANGE ME!",
colorable_fields=["source", "loss", "trained_on"],
reset_project_if_exists=True,
)
# index is local to train/test split, regenerate
data = df.remove_columns(["labels", "input_ids", "index"])
data = data.add_column("index", list(range(len(data))))
# max embed dim is 2048 for now
# note! this is slow in pyarrow/hf datasets
embeddings = np.array(data["embeddings"])
print("embeddings shape:", embeddings.shape)
embeddings = PCA(n_components=2048).fit_transform(embeddings)
data = data.remove_columns(["embeddings"])
columns = data.to_pandas().to_dict("records")
atlas.map_embeddings(embeddings,
data=columns,
id_field="index",
name="CHANGE ME!",
colorable_fields=["source", "loss", "trained_on"],
build_topic_model=True,
topic_label_field="inputs",
reset_project_if_exists=True,)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

74
gpt4all-training/clean.py Normal file
View File

@@ -0,0 +1,74 @@
import numpy as np
import glob
import os
import json
import jsonlines
import pandas as pd
prompt_generation_dir = "raw_data_sanity_cleaned_without_p3/"
for file in glob.glob(os.path.join(prompt_generation_dir, "*.jsonl")):
if "clean.jsonl" in file:
continue
data = []
print(file)
with open(file) as f:
for line in f:
try:
contents = json.loads(line)
data.append(contents)
except BaseException:
pass
processed = []
for item in data:
if 'source' not in item:
item['source'] = 'unspecified'
if 'model_settings' in item:
item.pop('model_settings', None)
for key in list(item.keys()):
if key not in ['source', 'prompt', 'response']:
#print(item[key])
item.pop(key, None)
if isinstance(item['prompt'], dict):
if "value" in item["prompt"]:
item["prompt"] = item["prompt"]["value"]
elif "description" in item["prompt"]:
item["prompt"] = item["prompt"]["description"]
else:
continue
elif not isinstance(item['prompt'], str):
continue
if isinstance(item['response'], dict):
if "value" in item["response"]:
item["response"] = item["response"]["value"]
elif "description" in item["response"]:
item["response"] = item["response"]["description"]
else:
continue
elif not isinstance(item['response'], str):
continue
if item:
processed.append(item)
df = pd.DataFrame(processed)
prev_len = len(df)
# drop empty or null string
df = df.dropna(subset=['prompt', 'response'])
df = df[df['prompt'] != '']
df = df[df['response'] != '']
df = df[df["prompt"].str.len() > 1]
curr_len = len(df)
print(f"Removed {prev_len - curr_len} rows")
clean_name = file.split(".jsonl")[0] + "_clean.jsonl"
print(f"writing to {curr_len} rows to {clean_name}")
df.to_json(clean_name, orient="records", lines=True)

View File

@@ -0,0 +1,48 @@
{
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"train_micro_batch_size_per_gpu": "auto",
"fp16": {
"enabled": "auto",
"min_loss_scale": 1,
"loss_scale_window": 1000,
"hysteresis": 2,
"initial_scale_power": 32
},
"bf16": {
"enabled": "auto"
},
"gradient_clipping": 1,
"zero_optimization": {
"stage": 2,
"offload_param": {
"device": "none"
},
"offload_optimizer": {
"device": "none"
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"contiguous_gradients": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": [
0.9,
0.999
],
"eps": 1e-08
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear"
}
}
}

View File

@@ -0,0 +1,48 @@
{
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"train_micro_batch_size_per_gpu": "auto",
"fp16": {
"enabled": "auto",
"min_loss_scale": 1,
"loss_scale_window": 1000,
"hysteresis": 2,
"initial_scale_power": 32
},
"bf16": {
"enabled": "auto"
},
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": 2,
"offload_param": {
"device": "none"
},
"offload_optimizer": {
"device": "none"
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"contiguous_gradients": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": [
0.9,
0.999
],
"eps": 1e-08
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear"
}
}
}

View File

@@ -0,0 +1,48 @@
{
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"train_micro_batch_size_per_gpu": "auto",
"fp16": {
"enabled": "auto",
"min_loss_scale": 1,
"loss_scale_window": 1000,
"hysteresis": 2,
"initial_scale_power": 32
},
"bf16": {
"enabled": "auto"
},
"gradient_clipping": 1,
"zero_optimization": {
"stage": 2,
"offload_param": {
"device": "cpu"
},
"offload_optimizer": {
"device": "cpu"
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"contiguous_gradients": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": [
0.9,
0.999
],
"eps": 1e-08
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear"
}
}
}

View File

@@ -0,0 +1,5 @@
# model/tokenizer
model_name: "zpn/llama-7b"
tokenizer_name: "zpn/llama-7b"
lora: true
lora_path: "tloen/alpaca-lora-7b"

View File

@@ -0,0 +1,4 @@
# model/tokenizer
model_name: "nomic-ai/gpt4all-warmup-lr-epoch_0"
tokenizer_name: "EleutherAI/gpt-j-6b"
lora: false

View File

@@ -0,0 +1,5 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6b"
tokenizer_name: "EleutherAI/gpt-j-6B"
lora: true
lora_path: "nomic-ai/gpt4all-gptj-lora-epoch_1"

View File

@@ -0,0 +1,5 @@
# model/tokenizer
model_name: "zpn/llama-7b"
tokenizer_name: "zpn/llama-7b"
lora: true
lora_path: "nomic-ai/gpt4all-lora"

View File

@@ -0,0 +1,9 @@
# model/tokenizer
model_name: "zpn/llama-7b"
tokenizer_name: "zpn/llama-7b"
lora: true
lora_path: "nomic-ai/gpt4all-lora"
max_new_tokens: 512
temperature: 0
prompt: null

View File

@@ -0,0 +1,15 @@
# model/tokenizer
model_name: "nomic-ai/gpt4all-warmup-lr-epoch_1"
tokenizer_name: "EleutherAI/gpt-j-6b"
lora: false
max_new_tokens: 512
temperature: 0.001
prompt: |
#this code prints a string reversed
my_string = "hello how are you"
print(len(my_string))
My code above does not work. Can you help me?

View File

@@ -0,0 +1,15 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6b"
tokenizer_name: "EleutherAI/gpt-j-6b"
lora: true
lora_path: "nomic-ai/gpt4all-gptj-lora-epoch_0"
max_new_tokens: 512
temperature: 0
prompt: |
#this code prints a string reversed
my_string = "hello how are you"
print(len(my_string))
My code above does not work. Can you help me?

View File

@@ -0,0 +1,14 @@
# model/tokenizer
model_name: # REPLACE WITH LLAMA MODEL NAME
tokenizer_name: # REPLACE WITH LLAMA MODEL NAME
max_new_tokens: 512
temperature: 0.001
prompt: |
#this code prints a string reversed
my_string = "hello how are you"
print(len(my_string))
My code above does not work. Can you help me?

View File

@@ -0,0 +1,14 @@
# model/tokenizer
model_name: "nomic-ai/gpt4all-warmup-lr-epoch_1"
tokenizer_name: "EleutherAI/gpt-j-6B"
# dataset
streaming: false
num_proc: 64
dataset_path: "nomic-ai/turbo-500k-multi"
max_length: 1024
batch_size: 32
# logging
seed: 42

View File

@@ -0,0 +1,30 @@
# model/tokenizer
model_name: # add model here
tokenizer_name: # add model here
gradient_checkpointing: true
save_name: # CHANGE
# dataset
streaming: false
num_proc: 64
dataset_path: # update
max_length: 1024
batch_size: 32
# train dynamics
lr: 5.0e-5
eval_every: 800
eval_steps: 100
save_every: 800
output_dir: # CHANGE
checkpoint: null
lora: false
warmup_steps: 100
num_epochs: 2
# logging
wandb: true
wandb_entity: # update
wandb_project_name: # update
seed: 42

View File

@@ -0,0 +1,33 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6B"
tokenizer_name: "EleutherAI/gpt-j-6B"
gradient_checkpointing: true
save_name: # CHANGE
# dataset
streaming: false
num_proc: 64
dataset_path: # CHANGE
max_length: 1024
batch_size: 32
# train dynamics
lr: 2.0e-5
min_lr: 0
weight_decay: 0.0
eval_every: 500
eval_steps: 105
save_every: 500
log_grads_every: 100
output_dir: # CHANGE
checkpoint: null
lora: false
warmup_steps: 500
num_epochs: 2
# logging
wandb: true
wandb_entity: # CHANGE
wandb_project_name: # CHANGE
seed: 42

View File

@@ -0,0 +1,33 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6b"
tokenizer_name: "EleutherAI/gpt-j-6b"
gradient_checkpointing: false
save_name: # CHANGE
# dataset
streaming: false
num_proc: 64
dataset_path: # CHANGE
max_length: 1024
batch_size: 1
# train dynamics
lr: 2.0e-5
min_lr: 0
weight_decay: 0.0
eval_every: 500
eval_steps: 105
save_every: 500
log_grads_every: 500
output_dir: # CHANGE
checkpoint: null
lora: true
warmup_steps: 500
num_epochs: 2
# logging
wandb: true
wandb_entity: # CHANGE
wandb_project_name: # CHANGE
seed: 42

View File

@@ -0,0 +1,31 @@
# model/tokenizer
model_name: # update
tokenizer_name: # update
gradient_checkpointing: false
save_name: # CHANGE
# dataset
streaming: false
num_proc: 64
dataset_path: # CHANGE
max_length: 1024
batch_size: 4
# train dynamics
lr: 5.0e-5
min_lr: 0
weight_decay: 0.0
eval_every: 2000
eval_steps: 100
save_every: 2000
output_dir: # CHANGE
checkpoint: null
lora: true
warmup_steps: 100
num_epochs: 2
# logging
wandb: true
wandb_entity: # update
wandb_project_name: # update
seed: 42

View File

@@ -0,0 +1,8 @@
#!/bin/bash
export WORKER_IP=$1
N_GPUS=8
# create dir if doesn't exist
sudo mkdir -p /job
printf "localhost slots=$N_GPUS\n$WORKER_IP slots=$N_GPUS" | sudo tee /job/hostfile
echo /job/hostfile

167
gpt4all-training/data.py Normal file
View File

@@ -0,0 +1,167 @@
import glob
import torch
from datasets import load_dataset, concatenate_datasets
import os
from torch.utils.data import DataLoader
from transformers import DefaultDataCollator
def tokenize_inputs(config, tokenizer, examples):
max_length = config["max_length"]
# hacky backward compatible
different_eos = tokenizer.eos_token != "</s>"
out = {"labels": [], "input_ids": []}
for prompt, response in zip(examples["prompt"], examples["response"]):
if different_eos:
if response.count("</s> \n") > 0:
response = response.replace("</s> \n", f"{tokenizer.eos_token} \n")
prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0])
# hack if our prompt is super long
# we need to include some labels so we arbitrarily trunacate at max_length // 2
# if the length is too long
if prompt_len >= max_length // 2:
# if prompt is too long, truncate
# but make sure to truncate to at max 1024 tokens
new_len = min(max_length // 2, len(prompt) // 2)
prompt = prompt[:new_len]
# get new prompt length
prompt_len = tokenizer(prompt + "\n", return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item()
assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}"
input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
labels = input_tokens.clone()
labels[:prompt_len] = -100
if len(labels) < max_length:
# pad to max_length with -100
labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
assert (labels == -100).sum() < len(labels), f"Labels are all -100, something wrong. prompt length {prompt_len} exceeds max length {max_length}"
if (labels == -100).sum() == len(labels) - 1:
print(prompt)
print(response)
raise
input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
out["labels"].append(labels)
out["input_ids"].append(input_tokens)
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
return out
def load_data(config, tokenizer):
dataset_path = config["dataset_path"]
if os.path.exists(dataset_path):
if os.path.isdir(dataset_path):
files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
else:
files = [dataset_path]
print(f"Reading files {files}")
dataset = load_dataset("json", data_files=files, split="train")
else:
dataset = load_dataset(dataset_path, split="train")
dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
train_dataset, val_dataset = dataset["train"], dataset["test"]
if config["streaming"] is False:
kwargs = {"num_proc": config["num_proc"]}
else:
kwargs = {}
# tokenize inputs and return labels and attention mask
train_dataset = train_dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele),
batched=True,
remove_columns=["source", "prompt"],
**kwargs
)
val_dataset = val_dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele),
batched=True,
remove_columns=["source", "prompt"],
**kwargs
)
train_dataset = train_dataset.with_format("torch")
val_dataset = val_dataset.with_format("torch")
# create dataloader with default data collator since we already have labels
train_dataloader = DataLoader(
train_dataset,
collate_fn=DefaultDataCollator(),
batch_size=config["batch_size"],
)
val_dataloader = DataLoader(
val_dataset,
collate_fn=DefaultDataCollator(),
batch_size=config["batch_size"],
)
return train_dataloader, val_dataloader
def load_data_for_inference(config, tokenizer):
dataset_path = config["dataset_path"]
if os.path.exists(dataset_path):
# check if path is a directory
if os.path.isdir(dataset_path):
files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
else:
files = [dataset_path]
print(f"Reading files {files}")
dataset = load_dataset("json", data_files=files, split="train")
else:
dataset = load_dataset(dataset_path, split="train")
dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
train_dataset, val_dataset = dataset["train"], dataset["test"]
train_dataset = train_dataset.add_column("index", list(range(len(train_dataset))))
# select first N batches that are divisible by batch_size
# gather is a bit annoying (or the way I'm using it) to get uneven batches as it duplicates data
train_dataset = train_dataset.select(range((len(train_dataset) // config["batch_size"]) * config["batch_size"]))
val_dataset = val_dataset.add_column("index", list(range(len(val_dataset))))
val_dataset = val_dataset.select(range((len(val_dataset) // config["batch_size"]) * config["batch_size"]))
if config["streaming"] is False:
kwargs = {"num_proc": config["num_proc"]}
else:
kwargs = {}
# tokenize inputs and return labels and attention mask
train_dataset = train_dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele),
batched=True,
**kwargs
)
val_dataset = val_dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele),
batched=True,
**kwargs
)
train_dataset = train_dataset.with_format("torch")
val_dataset = val_dataset.with_format("torch")
return train_dataset, val_dataset

20
gpt4all-training/env.yaml Normal file
View File

@@ -0,0 +1,20 @@
name: vicuna
channels:
- conda-forge
- pytorch
- nvidia
- huggingface
dependencies:
- python=3.8
- accelerate
- datasets
- torchmetrics
- evaluate
- transformers
- wandb
- jsonlines
- pip:
- peft
- nodelist-inflator
- deepspeed
- sentencepiece

View File

@@ -0,0 +1,28 @@
import glob
import pickle
import numpy as np
from matplotlib import pyplot as plt
plt.figure()
for fpath in glob.glob('./eval_data/*.pkl'):
parts = fpath.split('__')
model_name = "-".join(fpath.replace(".pkl", "").split("_")[2:])
with open(fpath, 'rb') as f:
data = pickle.load(f)
perplexities = data['perplexities']
perplexities = np.nan_to_num(perplexities, 100)
perplexities = np.clip(perplexities, 0, 100)
if 'alpaca' not in fpath:
identifier = model_name = "-".join(fpath.replace(".pkl", "").split("eval__model-")[1:])
label = 'GPT4all-'
label += identifier
else:
label = 'alpaca-lora'
plt.hist(perplexities, label=label, alpha=.5, bins=50)
plt.xlabel('Perplexity')
plt.ylabel('Frequency')
plt.legend()
plt.savefig('figs/perplexity_hist.png')

View File

@@ -0,0 +1,108 @@
import json
import torch
import pickle
import numpy as np
from tqdm import tqdm
from read import read_config
from argparse import ArgumentParser
from peft import PeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
'''
Evaluates perplexity on the outputs of:
https://github.com/yizhongw/self-instruct/blob/main/human_eval/user_oriented_instructions.jsonl
'''
def read_jsonl_file(file_path):
data = []
with open(file_path, 'r', encoding='utf-8') as file:
for line in file:
json_object = json.loads(line.strip())
data.append(json_object)
return data
def setup_model(config):
model = AutoModelForCausalLM.from_pretrained(config["model_name"], device_map="auto", torch_dtype=torch.float16, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
added_tokens = tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>"})
if added_tokens > 0:
model.resize_token_embeddings(len(tokenizer))
if 'lora' in config and config['lora']:
model = PeftModelForCausalLM.from_pretrained(model, config["lora_path"], device_map="auto", torch_dtype=torch.float16, return_hidden_states=True)
model.to(dtype=torch.float16)
print(f"Mem needed: {model.get_memory_footprint() / 1024 / 1024 / 1024:.2f} GB")
return model, tokenizer
def eval_example(model, tokenizer, example, config):
prompt = example['instruction'] + ' ' + example['instances'][0]['input']
gt = prompt + ' ' + example['instances'][0]['output']
#decode several continuations and compute their page trajectories
input = tokenizer(prompt, return_tensors="pt")
input = {k: v.to(model.device) for k, v in input.items()}
#compute the ground truth perplexity
gt_input = tokenizer(gt, return_tensors="pt")
gt_input = {k: v.to(model.device) for k, v in gt_input.items()}
nlls = []
prev_end_loc = 0
stride = 512
seq_len = gt_input['input_ids'].size(1)
for begin_loc in tqdm(range(input['input_ids'].size(1), gt_input['input_ids'].size(1), stride)):
end_loc = min(begin_loc + stride, seq_len)
trg_len = end_loc - prev_end_loc # may be different from stride on last loop
input_ids = gt_input['input_ids'][:, begin_loc:end_loc].to(model.device)
target_ids = input_ids.clone()
target_ids[:, :-trg_len] = -100
with torch.no_grad():
outputs = model(input_ids, labels=target_ids)
neg_log_likelihood = outputs.loss * trg_len
nlls.append(neg_log_likelihood)
prev_end_loc = end_loc
if end_loc == seq_len:
break
ppl = torch.exp(torch.stack(nlls).sum() / end_loc).item()
print('ppl: ', ppl)
print(prompt)
print(80*'-')
return ppl
def do_eval(config):
eval_data = read_jsonl_file('eval_data/user_oriented_instructions.jsonl')
model, tokenizer = setup_model(config)
all_perplexities = []
for example in tqdm(eval_data):
gt_perplexity = eval_example(model, tokenizer, example, config)
all_perplexities.append(gt_perplexity)
name = f"eval_data/eval__model-{config['model_name'].replace('/', '_')}{'__lora-' + config['lora_path'].replace('/', '_') if config['lora'] else ''}.pkl"
with open(name, 'wb') as f:
r = {'perplexities': all_perplexities}
pickle.dump(r, f)
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument("--config", type=str, required=True)
args = parser.parse_args()
config = read_config(args.config)
do_eval(config)

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 362 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 308 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 356 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 353 KiB

View File

@@ -0,0 +1,58 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModelForCausalLM
from read import read_config
from argparse import ArgumentParser
import torch
import time
def generate(tokenizer, prompt, model, config):
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
outputs = model.generate(input_ids=input_ids, max_new_tokens=config["max_new_tokens"], temperature=config["temperature"])
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
return decoded[len(prompt):]
def setup_model(config):
model = AutoModelForCausalLM.from_pretrained(config["model_name"], device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
added_tokens = tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>"})
if added_tokens > 0:
model.resize_token_embeddings(len(tokenizer))
if config["lora"]:
model = PeftModelForCausalLM.from_pretrained(model, config["lora_path"], device_map="auto", torch_dtype=torch.float16)
model.to(dtype=torch.float16)
print(f"Mem needed: {model.get_memory_footprint() / 1024 / 1024 / 1024:.2f} GB")
return model, tokenizer
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--prompt", type=str)
args = parser.parse_args()
config = read_config(args.config)
if config["prompt"] is None and args.prompt is None:
raise ValueError("Prompt is required either in config or as argument")
prompt = config["prompt"] if args.prompt is None else args.prompt
print("Setting up model")
model, tokenizer = setup_model(config)
print("Generating")
start = time.time()
generation = generate(tokenizer, prompt, model, config)
print(f"Done in {time.time() - start:.2f}s")
print(generation)

View File

@@ -0,0 +1,204 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn
from argparse import ArgumentParser
from read import read_config
from accelerate.utils import set_seed
from data import load_data_for_inference
from tqdm import tqdm
from datasets import Dataset
import torch.distributed as dist
from transformers.trainer_pt_utils import nested_numpify
from transformers import DefaultDataCollator
from torch.utils.data import DataLoader, DistributedSampler
import numpy as np
import pyarrow as pa
from pyarrow import compute as pc
def calc_cross_entropy_no_reduction(lm_logits, labels):
# calculate cross entropy across batch dim
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss(reduction='none')
loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels).mean(dim=1)
return loss
def rank0_print(msg):
if dist.get_rank() == 0:
print(msg)
def inference(config):
set_seed(config['seed'])
rank0_print(f"World size: {dist.get_world_size()}")
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
# llama has no pad token, set it to new token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
train_dataset, val_dataset = load_data_for_inference(config, tokenizer)
num_processes = dist.get_world_size()
local_rank = dist.get_rank()
train_sampler = DistributedSampler(train_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
train_dataloader = DataLoader(
train_dataset,
collate_fn=DefaultDataCollator(),
batch_size=config["batch_size"],
sampler=train_sampler,
drop_last=True
)
val_sampler = DistributedSampler(val_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
val_dataloader = DataLoader(
val_dataset,
collate_fn=DefaultDataCollator(),
batch_size=config["batch_size"],
sampler=val_sampler,
drop_last=True
)
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
model.to(f"cuda:{local_rank}")
with torch.no_grad():
train_outputs = {"loss": [], "embeddings": [], "index": []}
for batch in tqdm(train_dataloader, disable=local_rank != 0):
batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
train_outputs["loss"].extend(loss)
embeddings = outputs.hidden_states[-1]
batch_size = batch["input_ids"].shape[0]
sequence_lengths = []
# since we use mutiturn with multiple <|endoftext|>, we need to find the place where
# <|endoftext|> is repeated
for item in batch["input_ids"]:
indices = torch.where(item == tokenizer.pad_token_id)[0]
found = False
for index in indices:
# case where sequence is less than max length
if torch.all(item[index:] == tokenizer.pad_token_id):
sequence_lengths.append(index)
found = True
break
# case where sequence is >= max length
if not found:
sequence_lengths.append(len(item) - 1)
sequence_lengths = torch.tensor(sequence_lengths)
pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
train_outputs["embeddings"].append(pooled_logits)
train_outputs["index"].extend(batch["index"].to(model.device))
torch.cuda.empty_cache()
train_outputs = nested_numpify(train_outputs)
# stack since they're 0-dim arrays
train_outputs["index"] = np.stack(train_outputs["index"])
train_outputs["loss"] = np.stack(train_outputs["loss"])
train_outputs["embeddings"] = np.concatenate(train_outputs["embeddings"])
df_train = Dataset.from_dict(train_outputs)
curr_idx = df_train["index"]
# compute mask in pyarrow since it's super fast
# ty @bmschmidt for showing me this!
table = train_dataset.data
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
filtered_train = Dataset.from_dict(filtered_table.to_pydict())
filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
filtered_train = filtered_train.add_column("loss", df_train["loss"])
filtered_train = filtered_train.add_column("is_train", [True] * len(filtered_train))
filtered_train.to_json(f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
val_outputs = {"loss": [], "embeddings": [], "index": []}
for batch in tqdm(val_dataloader, disable=local_rank != 0):
batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
val_outputs["loss"].extend(loss)
embeddings = outputs.hidden_states[-1]
batch_size = batch["input_ids"].shape[0]
sequence_lengths = []
# since we use mutiturn with multiple <|endoftext|>, we need to find the place where
# <|endoftext|> is repeated
for item in batch["input_ids"]:
indices = torch.where(item == tokenizer.pad_token_id)[0]
found = False
for index in indices:
# case where sequence is less than max length
if torch.all(item[index:] == tokenizer.pad_token_id):
sequence_lengths.append(index)
found = True
break
# case where sequence is >= max length
if not found:
sequence_lengths.append(len(item) - 1)
sequence_lengths = torch.tensor(sequence_lengths)
pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
val_outputs["embeddings"].append(pooled_logits)
val_outputs["index"].extend(batch["index"].to(model.device))
torch.cuda.empty_cache()
val_outputs = nested_numpify(val_outputs)
val_outputs["index"] = np.stack(val_outputs["index"])
val_outputs["loss"] = np.stack(val_outputs["loss"])
val_outputs["embeddings"] = np.concatenate(val_outputs["embeddings"])
df_val = Dataset.from_dict(val_outputs)
curr_idx = df_val["index"]
# compute mask in pyarrow since it's super fast
# ty @bmschmidt for showing me this!
table = val_dataset.data
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
filtered_val = Dataset.from_dict(filtered_table.to_pydict())
filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
filtered_val = filtered_val.add_column("loss", df_val["loss"])
filtered_val = filtered_val.add_column("is_train", [False] * len(filtered_val))
filtered_val.to_json(f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
def main():
dist.init_process_group("nccl")
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
args = parser.parse_args()
config = read_config(args.config)
inference(config)
if __name__ == "__main__":
# parse arguments by reading in a config
main()

View File

@@ -0,0 +1,88 @@
#!/bin/bash
# Display header
echo "=========================================================="
echo " ██████ ██████ ████████ ██ ██ █████ ██ ██ "
echo "██ ██ ██ ██ ██ ██ ██ ██ ██ ██ "
echo "██ ███ ██████ ██ ███████ ███████ ██ ██ "
echo "██ ██ ██ ██ ██ ██ ██ ██ ██ "
echo " ██████ ██ ██ ██ ██ ██ ███████ ███████ "
echo " └─> https://github.com/nomic-ai/gpt4all"
# Function to detect macOS architecture and set the binary filename
detect_mac_arch() {
local mac_arch
mac_arch=$(uname -m)
case "$mac_arch" in
arm64)
os_type="M1 Mac/OSX"
binary_filename="gpt4all-lora-quantized-OSX-m1"
;;
x86_64)
os_type="Intel Mac/OSX"
binary_filename="gpt4all-lora-quantized-OSX-intel"
;;
*)
echo "Unknown macOS architecture"
exit 1
;;
esac
}
# Detect operating system and set the binary filename
case "$(uname -s)" in
Darwin*)
detect_mac_arch
;;
Linux*)
if grep -q Microsoft /proc/version; then
os_type="Windows (WSL)"
binary_filename="gpt4all-lora-quantized-win64.exe"
else
os_type="Linux"
binary_filename="gpt4all-lora-quantized-linux-x86"
fi
;;
CYGWIN*|MINGW32*|MSYS*|MINGW*)
os_type="Windows (Cygwin/MSYS/MINGW)"
binary_filename="gpt4all-lora-quantized-win64.exe"
;;
*)
echo "Unknown operating system"
exit 1
;;
esac
echo "================================"
echo "== You are using $os_type."
# Change to the chat directory
cd chat
# List .bin files and prompt user to select one
bin_files=(*.bin)
echo "== Available .bin files:"
for i in "${!bin_files[@]}"; do
echo " [$((i+1))] ${bin_files[i]}"
done
# Function to get user input and validate it
get_valid_user_input() {
local input_valid=false
while ! $input_valid; do
echo "==> Please enter a number:"
read -r user_selection
if [[ $user_selection =~ ^[0-9]+$ ]] && (( user_selection >= 1 && user_selection <= ${#bin_files[@]} )); then
input_valid=true
else
echo "Invalid input. Please enter a number between 1 and ${#bin_files[@]}."
fi
done
}
get_valid_user_input
selected_bin_file="${bin_files[$((user_selection-1))]}"
# Run the selected .bin file with the appropriate command
./"$binary_filename" -m "$selected_bin_file"

10
gpt4all-training/read.py Normal file
View File

@@ -0,0 +1,10 @@
import yaml
def read_config(path):
# read yaml and return contents
with open(path, 'r') as file:
try:
return yaml.safe_load(file)
except yaml.YAMLError as exc:
print(exc)

View File

@@ -0,0 +1,15 @@
accelerate
datasets
torchmetrics
evaluate
transformers>=4.28.0
wandb
pip
peft
nodelist-inflator
deepspeed
sentencepiece
jsonlines
nomic
scikit-learn
matplotlib

233
gpt4all-training/train.py Normal file
View File

@@ -0,0 +1,233 @@
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
import torch
from torch.optim import AdamW
from argparse import ArgumentParser
from read import read_config
from accelerate import Accelerator
from accelerate.utils import DummyScheduler, DummyOptim, set_seed
from peft import get_peft_model, LoraConfig, TaskType
from data import load_data
from torchmetrics import MeanMetric
from tqdm import tqdm
import wandb
torch.backends.cuda.matmul.allow_tf32 = True
def format_metrics(metrics, split, prefix=""):
log = f"[{split}]" + prefix
log += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()])
return log
def evaluate(model, val_dataloader):
model.eval()
val_loss = MeanMetric(nan_strategy="error").to(model.device)
with torch.no_grad():
for batch in tqdm(val_dataloader):
loss = model(**batch).loss
loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
val_loss.update(loss_values["loss"])
return val_loss
def train(accelerator, config):
set_seed(config['seed'])
accelerator.print(config)
accelerator.print(f"Using {accelerator.num_processes} GPUs")
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
# if no pad token, set it to eos
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
with accelerator.main_process_first():
train_dataloader, val_dataloader = load_data(config, tokenizer)
checkpoint = config["gradient_checkpointing"]
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
use_cache=False if checkpoint else True,
trust_remote_code=True)
if checkpoint:
model.gradient_checkpointing_enable()
if config["lora"]:
peft_config = LoraConfig(
# should R be configurable?
task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
optimizer_cls = (
AdamW
if accelerator.state.deepspeed_plugin is None
or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
else DummyOptim
)
# karpathy doesn't decay embeddding, maybe we should exclude
# https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
if accelerator.state.deepspeed_plugin is not None:
gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
"gradient_accumulation_steps"
]
# decay to min_lr instead of 0
lr_ratio = config["min_lr"] / config["lr"]
accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
# instead of decaying to zero, decay to ratio of min_lr / lr
total_num_steps += int(total_num_steps * lr_ratio) + config["warmup_steps"]
accelerator.print(f"Total training steps: {total_num_steps}")
# Creates Dummy Scheduler if `scheduler` was spcified in the config file else creates `args.lr_scheduler_type` Scheduler
if (
accelerator.state.deepspeed_plugin is None
or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
):
scheduler = get_scheduler(
name="cosine",
optimizer=optimizer,
num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
num_training_steps=total_num_steps,
)
else:
scheduler = DummyScheduler(
optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"]
)
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
model, optimizer, train_dataloader, val_dataloader, scheduler
)
# setup for saving training states in case preemption
accelerator.register_for_checkpointing(scheduler)
if config["checkpoint"]:
accelerator.load_state(config["checkpoint"])
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
training_difference = os.path.splitext(path)[0]
resume_step = int(training_difference.replace("step_", ""))
accelerator.skip_first_batches(train_dataloader, resume_step)
accelerator.print(f"Resuming from step {resume_step}")
# log gradients
if accelerator.is_main_process and config["wandb"]:
wandb.watch(model, log_freq=config["log_grads_every"], log="all")
for epoch in range(config["num_epochs"]):
train_loss = MeanMetric(nan_strategy="error").to(model.device)
for step, batch in enumerate(tqdm(train_dataloader)):
model.train()
outputs = model(**batch)
loss = outputs.loss
# gather loss before backprop in case of gradient accumulation
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
train_loss.update(loss_values["loss"])
loss = loss / gradient_accumulation_steps
accelerator.backward(loss)
# get gradient norm of all params
# log LR in case something weird happens
if step > 0 and step % (config["eval_every"] // 10) == 0:
if config["wandb"]:
curr_step = step + epoch * len(train_dataloader)
accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if step > 0 and step % config["save_every"] == 0:
curr_step = step + epoch * len(train_dataloader)
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
val_loss = evaluate(model, val_dataloader)
log_train = {
"train_loss": train_loss.compute()
}
log_val = {
"val_loss": val_loss.compute()
}
if config["wandb"]:
curr_step = step + epoch * len(train_dataloader)
accelerator.log({**log_train, **log_val}, step=curr_step)
accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
accelerator.print(format_metrics(log_train, "train", f" step {step} "))
accelerator.print(format_metrics(log_val, "val", f" step {step} "))
train_loss.reset()
accelerator.print(f"Epoch {epoch} finished")
accelerator.print(f"Pushing to HF hub")
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
try:
if accelerator.is_main_process:
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
except Exception as e:
accelerator.print(e)
accelerator.print(f"Failed to push to hub")
unwrapped_model.save_pretrained(
f"{config['output_dir']}/epoch_{epoch}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
f"{config['output_dir']}/final",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
accelerator.end_training()
if __name__ == "__main__":
# parse arguments by reading in a config
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
args = parser.parse_args()
config = read_config(args.config)
if config["wandb"]:
accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers(
project_name=config["wandb_project_name"],
config=config,
init_kwargs={"wandb": {"entity": config["wandb_entity"]}},
)
else:
accelerator = Accelerator()
train(accelerator, config=config)