Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-09-01 08:38:35 +00:00

Commit: mono repo structure
17 gpt4all-training/GPT-J_MAP.md Normal file
@@ -0,0 +1,17 @@
# Inference on Training Data


## Run Inference

```bash
torchrun --master_port=29085 --nproc-per-node 8 inference.py --config=configs/inference/gptj.yaml
```


## Visualizations

```bash
python build_map.py
```

This will build two maps in `Atlas`: one using the internal clustering algorithm provided by Nomic and one using the embeddings generated by the finetuned model.
283 gpt4all-training/TRAINING_LOG.md Normal file
@@ -0,0 +1,283 @@
# Training Trials and Tribulations of gpt4all

This is a training log for both the LoRA and full-model training runs we underwent.

## Inspiration


## Initial Experiment

We trained an initial LoRA model on ~700k examples, including data from P3/BLOOM, StackOverflow, and unified_chip2.
We trained using 8 x A100 80GB GPUs.

We used the initial parameters:

| Hyperparameter | Value |
|----------------|-------|
| Per Device BS  | 4     |
| Global BS      | 32    |
| Learning rate  | 5e-5  |
| Epochs         | 1     |
| Max length     | 1024  |
| Weight decay   | 0     |
| Warmup Steps   | 100   |

We randomly shuffled the data and set aside 5% for validation.
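
The split itself is a single Hugging Face `datasets` call; `data.py` later in this diff does exactly this. A minimal sketch, with the data file and seed as placeholders for values read from the training config:

```python
from datasets import load_dataset

# Placeholder file name and seed; data.py reads both from its YAML config.
dataset = load_dataset("json", data_files=["prompts_clean.jsonl"], split="train")
split = dataset.train_test_split(test_size=0.05, seed=42)
train_dataset, val_dataset = split["train"], split["test"]
```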

We had an initial bug in logging the training loss, but we noticed a decrease in validation loss.



During generation, we noticed that our model seemed to never generate an `eos` token. We realized that this was because we had set
the tokenizer `pad` token equal to the `eos` token, so the model most likely never learned to "stop". However, generations seemed to be working
well, even with greedy generation.
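
The setup in question is a one-liner; `train.py` in this diff falls back to it whenever the tokenizer has no pad token. A minimal sketch, with the checkpoint name as a placeholder:

```python
from transformers import AutoTokenizer

# Placeholder checkpoint; the real name comes from the training config.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")

# Reusing eos as pad means padding and end-of-sequence look identical to the loss,
# so the model never gets a clean signal for when to stop generating.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
```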

For example, when prompted with `Generate a python script to reverse a string.`, the model outputs valid markdown and a working script:

Here's a Python script to reverse a string:

```python
def reverse_string(string):
    return ''.join(reversed(string))
```

This function takes a string as an argument and returns the reversed string. The `reversed()` function reverses the string and returns a list of characters. The `join()` function joins the list of characters into a string.

You can use this function like this:

```python
string = 'hello world'
print(reverse_string(string))
```

This will output `'olleh dna hwl'`.

Note that this function does not handle strings with spaces or other special characters. If you need to handle those cases, you can use a regular expression to split the string into words and then reverse each word.

```python
import re

def reverse_string(string):
    words = re.split(' ', string)
    return ''.join(reversed(word) for word in words)
```

This function takes a string as an argument and returns the reversed string. The `re.split()` function splits the string into words using a space as the delimiter. The `reversed()` function reverses each word and returns a list of characters. The `join()` function joins the list of characters into a string.

You can use this function like this:

```python
string = 'hello world'
print(reverse_string(string))
```

This will output `'olleh dna hwl'`.

Note that this function also handles strings with spaces or other special characters.


Answer: Here's a Python script to reverse a string:

```python
def reverse_string(string):
    return ''.join(reversed(string))
```

This function takes a string as an argument and returns the reversed string. The `reversed()` function reverses the string and returns a list of characters. The `join()` function joins the list of characters into a string.

You can use this function like this:

```python
string = 'hello world'
print(reverse_string(string))
```

This will output `'olleh dna hwl'`.

Note that this function does not handle strings with

Note, however, that the generation duplicates itself and continues past the point where the question has been fully answered.

The base model doesn't answer the question and hallucinates:

```
Write a python script to reverse a string. The string should be reversed in place.
The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place.
The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be reversed in place. The string should be
```


## EOS and Accidental Duplication

Since the model never stopped, we experimented with training a model that uses separate tokens for `eos` and `pad`. Since we couldn't find a `pad` token in the vocabulary, we added one to the tokenizer and expanded the embedding size of the model (from 32,000). In theory, we could have padded the embedding size to a multiple of 64 to improve throughput and performance, as [noted here](https://twitter.com/ctnzr/status/1623758178587648000?s=20).
For every sequence, we appended an `eos` token in the hope that the model would learn to stop.
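
The token plumbing follows the pattern used throughout this repo (`generate.py` and `eval_self_instruct.py` below do the same thing): register the new special tokens, then grow the embedding matrix to match. A minimal sketch, with the checkpoint name as a placeholder:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; the training configs point at the actual base model.
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6b")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")

# Register distinct bos/eos/pad tokens, then resize the embeddings so the
# new token ids have rows in the embedding matrix.
added = tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>"})
if added > 0:
    model.resize_token_embeddings(len(tokenizer))
```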

We successfully trained a model using the same parameters as before. 

During generation, our model now stopped on its own, even with greedy generation:

You can use the `reversed()` function to reverse a string in Python. Here's an example:

```python
string = 'hello world'
reversed_string = reversed(string)
print(reversed_string)
```

Output:

```
world hello
```

For a harder prompt, where we try to trick the model with comments:

```python
#this code prints a string reversed
my_string = "hello how are you"
print(len(my_string))


My code above does not work. Can you help me?
```

The model correctly answers:

The code you provided does not print a string reversed. It prints the length of the string "hello how are you".


However, we realized that we had two bugs:

- We accidentally duplicated data and effectively trained for 2 epochs instead of 1.
- We added an `eos` token to every sequence, even those that we truncated (e.g. long code that exceeded the 1024-token maximum length).

## Conditional EOS and 1 Epoch

Using the same parameters, we then trained a model for one epoch using a "conditional" `eos` token: we only append an `eos` when the input is shorter than the maximum sequence length.
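
A minimal sketch of the conditioning (the helper name here is ours, not the repo's; the real logic lives in the tokenization path of `data.py`):

```python
def append_conditional_eos(token_ids, eos_token_id, max_length=1024):
    """Append eos only when the sequence has room for it, so truncated
    sequences never end in a misleading end-of-sequence marker."""
    if len(token_ids) < max_length:
        return token_ids + [eos_token_id]
    return token_ids[:max_length]
```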

Our training and validation loss looked fairly normal:



However, we found generations to be slightly poorer. For the same prompt of `Generate a python script to reverse a string.`, the model generates:

Write a python script to reverse a string.
Here is a python script to reverse a string:

import string

def reverse_string(string):
    return string[::-1]

print(reverse_string('hello world'))

Output:
world hello


The prompt

```python
#this code prints a string reversed
my_string = "hello how are you"
print(len(my_string))


My code above does not work. Can you help me?
```

does not generate any text.


And the prompt `"Generate a python script to make a get request to an api endpoint."` generates:

I'm sorry, I cannot provide a specific answer to this question as it requires more context and details about the API endpoint and the specific task you are trying to accomplish. Can you please provide more information?


## Multi Epoch and Full Model Training

We decided to remove the entire Bigscience/P3 subset from the final training dataset due to data diversity considerations.
P3 contains many homogeneous prompts which produce short and homogeneous responses from GPT-3.5-Turbo.
The final dataset is ~400k examples.

We trained a LoRA model using the following parameters:

| Hyperparameter | Value |
|----------------|-------|
| Per Device BS  | 4     |
| Global BS      | 32    |
| Learning rate  | 5e-5  |
| Epochs         | 4     |
| Max length     | 1024  |
| Weight decay   | 0     |
| Warmup Steps   | 100   |


We additionally trained a full model:

| Hyperparameter | Value |
|----------------|-------|
| Per Device BS  | 32    |
| Global BS      | 256   |
| Learning rate  | 5e-5  |
| Epochs         | 2     |
| Max length     | 1024  |
| Weight decay   | 0     |
| Warmup Steps   | 100   |

Taking inspiration from [the Alpaca repo](https://github.com/tatsu-lab/stanford_alpaca), we roughly scaled the learning rate by `sqrt(k)`, where `k` is the increase in batch size; Alpaca used a batch size of 128 and a learning rate of 2e-5.

Compared to the [Alpaca LoRA](https://huggingface.co/tloen/alpaca-lora-7b), our LoRA model has lower perplexity. Both on perplexity and on qualitative examples, the checkpoint trained for 3 epochs performed best.
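
The perplexity numbers come from `eval_self_instruct.py` (later in this diff), which scores the reference completions of the self-instruct user-oriented prompts. A condensed sketch of the per-example score, with the original's striding loop omitted:

```python
import torch

def sequence_perplexity(model, tokenizer, prompt, completion):
    """Perplexity of the reference completion given the prompt; prompt tokens
    are masked out of the loss with -100, mirroring eval_self_instruct.py."""
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    full_ids = tokenizer(prompt + " " + completion, return_tensors="pt").input_ids
    labels = full_ids.clone()
    labels[:, : prompt_ids.shape[1]] = -100  # only score the completion tokens
    with torch.no_grad():
        loss = model(full_ids, labels=labels).loss
    return torch.exp(loss).item()
```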

We tried training a full model using the parameters above, but found that the model diverged during the second epoch, and samples generated after training were worse than those from the first epoch.


## GPT-J Training

### Model Training Divergence

We trained multiple [GPT-J models](https://huggingface.co/EleutherAI/gpt-j-6b) with varying success. We found that training the full model diverged after epoch 1. 

We release the checkpoint after epoch 1.


Using Atlas, we extracted the embeddings of each point in the dataset and calculated the loss per sequence. We then uploaded [this to Atlas](https://atlas.nomic.ai/map/gpt4all-j-post-epoch-1-embeddings) and noticed that the higher-loss items seemed to cluster. On further inspection, the highest-density clusters seemed to be prompt/response pairs that asked for creative generations such as `Generate a story about ...`. 
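
The per-sequence loss is an unreduced cross-entropy over the shifted logits; this is what `calc_cross_entropy_no_reduction` in `inference.py` (below) computes:

```python
import torch
import torch.nn as nn

def per_sequence_loss(lm_logits, labels):
    # Shift so each position predicts the next token, then average the token-level
    # cross entropy within each sequence instead of reducing across the batch.
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = nn.CrossEntropyLoss(reduction="none")
    return loss_fct(shift_logits.permute(0, 2, 1), shift_labels).mean(dim=1)
```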


### GPT4All-J Hyperparameters

We varied learning rate, learning rate schedule, and weight decay following suggestions from the [original GPT-J codebase](https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md), but found no real performance difference (qualitative or quantitative) when varying these parameters.


The final model was trained using the following hyperparameters, with a linear warmup followed by a constant learning rate (the schedule is sketched after the table):

| Hyperparameter | Value |
|----------------|-------|
| Per Device BS  | 32    |
| Global BS      | 256   |
| Learning rate  | 2e-5  |
| Epochs         | 2     |
| Max length     | 1024  |
| Weight decay   | 0     |
| Warmup Steps   | 500   |
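
The warmup-then-constant schedule corresponds to the stock `constant_with_warmup` scheduler in `transformers` (`train.py` below imports `get_scheduler`, though its scheduler call is outside this excerpt, so treat the mapping as an assumption). A minimal sketch with a dummy parameter standing in for the model:

```python
import torch
from transformers import get_scheduler

# Dummy parameter/optimizer; in train.py these come from the model and config.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=2e-5)

# Learning rate ramps linearly for 500 steps, then stays constant at 2e-5.
scheduler = get_scheduler("constant_with_warmup", optimizer=optimizer, num_warmup_steps=500)
```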


The LoRA model was trained using the following hyperparameters, with a linear warmup followed by a constant learning rate:

| Hyperparameter | Value |
|----------------|-------|
| Per Device BS  | 4     |
| Global BS      | 32    |
| Learning rate  | 2e-5  |
| Epochs         | 2     |
| Max length     | 1024  |
| Weight decay   | 0     |
| Warmup Steps   | 500   |
54 gpt4all-training/build_map.py Normal file
@@ -0,0 +1,54 @@
import numpy as np
from nomic import atlas
import glob
from tqdm import tqdm

from datasets import load_dataset, concatenate_datasets
from sklearn.decomposition import PCA

files = glob.glob("inference/*.jsonl")
print(files)
df = concatenate_datasets([load_dataset("json", data_files=file, split="train") for file in tqdm(files)])

print(len(df))
print(df)

df = df.map(lambda example: {"inputs": [prompt + "\n" + response for prompt, response in zip(example["prompt"], example["response"])]},
            batched=True,
            num_proc=64)

df = df.map(lambda example: {"trained_on": [int(t) for t in example["is_train"]]},
            batched=True,
            num_proc=64)

df = df.remove_columns("is_train")

text = df.remove_columns(["labels", "input_ids", "embeddings"])

text_df = [text[i] for i in range(len(text))]

atlas.map_text(text_df, indexed_field="inputs",
               name="CHANGE ME!",
               colorable_fields=["source", "loss", "trained_on"],
               reset_project_if_exists=True,
               )

# index is local to train/test split, regenerate
data = df.remove_columns(["labels", "input_ids", "index"])
data = data.add_column("index", list(range(len(data))))
# max embed dim is 2048 for now
# note! this is slow in pyarrow/hf datasets
embeddings = np.array(data["embeddings"])
print("embeddings shape:", embeddings.shape)
embeddings = PCA(n_components=2048).fit_transform(embeddings)

data = data.remove_columns(["embeddings"])
columns = data.to_pandas().to_dict("records")

atlas.map_embeddings(embeddings,
                     data=columns,
                     id_field="index",
                     name="CHANGE ME!",
                     colorable_fields=["source", "loss", "trained_on"],
                     build_topic_model=True,
                     topic_label_field="inputs",
                     reset_project_if_exists=True,)
BIN gpt4all-training/chat/gpt4all-lora-quantized-OSX-intel Executable file (binary file not shown)
BIN gpt4all-training/chat/gpt4all-lora-quantized-OSX-m1 Executable file (binary file not shown)
BIN gpt4all-training/chat/gpt4all-lora-quantized-linux-x86 Executable file (binary file not shown)
BIN gpt4all-training/chat/gpt4all-lora-quantized-win64.exe Normal file (binary file not shown)
74 gpt4all-training/clean.py Normal file
@@ -0,0 +1,74 @@
import numpy as np
import glob
import os
import json
import jsonlines
import pandas as pd


prompt_generation_dir = "raw_data_sanity_cleaned_without_p3/"
for file in glob.glob(os.path.join(prompt_generation_dir, "*.jsonl")):
    if "clean.jsonl" in file:
        continue
    data = []
    print(file)
    with open(file) as f:
        for line in f:
            try:
                contents = json.loads(line)
                data.append(contents)
            except BaseException:
                pass

    processed = []

    for item in data:
        if 'source' not in item:
            item['source'] = 'unspecified'
        if 'model_settings' in item:
            item.pop('model_settings', None)

        for key in list(item.keys()):
            if key not in ['source', 'prompt', 'response']:
                item.pop(key, None)

        if isinstance(item['prompt'], dict):
            if "value" in item["prompt"]:
                item["prompt"] = item["prompt"]["value"]
            elif "description" in item["prompt"]:
                item["prompt"] = item["prompt"]["description"]
            else:
                continue
        elif not isinstance(item['prompt'], str):
            continue

        if isinstance(item['response'], dict):
            if "value" in item["response"]:
                item["response"] = item["response"]["value"]
            elif "description" in item["response"]:
                item["response"] = item["response"]["description"]
            else:
                continue
        elif not isinstance(item['response'], str):
            continue

        if item:
            processed.append(item)

    df = pd.DataFrame(processed)
    prev_len = len(df)

    # drop empty or null strings
    df = df.dropna(subset=['prompt', 'response'])
    df = df[df['prompt'] != '']
    df = df[df['response'] != '']
    df = df[df["prompt"].str.len() > 1]
    curr_len = len(df)

    print(f"Removed {prev_len - curr_len} rows")

    clean_name = file.split(".jsonl")[0] + "_clean.jsonl"
    print(f"writing {curr_len} rows to {clean_name}")
    df.to_json(clean_name, orient="records", lines=True)
48 gpt4all-training/configs/deepspeed/ds_config.json Normal file
@@ -0,0 +1,48 @@
|
||||
{
|
||||
"train_batch_size": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"min_loss_scale": 1,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"initial_scale_power": 32
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"gradient_clipping": 1,
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_param": {
|
||||
"device": "none"
|
||||
},
|
||||
"offload_optimizer": {
|
||||
"device": "none"
|
||||
},
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"contiguous_gradients": true
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": [
|
||||
0.9,
|
||||
0.999
|
||||
],
|
||||
"eps": 1e-08
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear"
|
||||
}
|
||||
}
|
||||
}
|
48 gpt4all-training/configs/deepspeed/ds_config_gptj.json Normal file
@@ -0,0 +1,48 @@
|
||||
{
|
||||
"train_batch_size": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"min_loss_scale": 1,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"initial_scale_power": 32
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"gradient_clipping": 1.0,
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_param": {
|
||||
"device": "none"
|
||||
},
|
||||
"offload_optimizer": {
|
||||
"device": "none"
|
||||
},
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"contiguous_gradients": true
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": [
|
||||
0.9,
|
||||
0.999
|
||||
],
|
||||
"eps": 1e-08
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear"
|
||||
}
|
||||
}
|
||||
}
|
48 gpt4all-training/configs/deepspeed/ds_config_gptj_lora.json Normal file
@@ -0,0 +1,48 @@
|
||||
{
|
||||
"train_batch_size": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"min_loss_scale": 1,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"initial_scale_power": 32
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"gradient_clipping": 1,
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_param": {
|
||||
"device": "cpu"
|
||||
},
|
||||
"offload_optimizer": {
|
||||
"device": "cpu"
|
||||
},
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"contiguous_gradients": true
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": [
|
||||
0.9,
|
||||
0.999
|
||||
],
|
||||
"eps": 1e-08
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear"
|
||||
}
|
||||
}
|
||||
}
|
5 gpt4all-training/configs/eval/generate_baseline.yaml Normal file
@@ -0,0 +1,5 @@
# model/tokenizer
model_name: "zpn/llama-7b"
tokenizer_name: "zpn/llama-7b"
lora: true
lora_path: "tloen/alpaca-lora-7b"
4 gpt4all-training/configs/eval/generate_gpt4all_gptj.yaml Normal file
@@ -0,0 +1,4 @@
# model/tokenizer
model_name: "nomic-ai/gpt4all-warmup-lr-epoch_0"
tokenizer_name: "EleutherAI/gpt-j-6b"
lora: false
@@ -0,0 +1,5 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6b"
tokenizer_name: "EleutherAI/gpt-j-6B"
lora: true
lora_path: "nomic-ai/gpt4all-gptj-lora-epoch_1"
@@ -0,0 +1,5 @@
# model/tokenizer
model_name: "zpn/llama-7b"
tokenizer_name: "zpn/llama-7b"
lora: true
lora_path: "nomic-ai/gpt4all-lora"
9 gpt4all-training/configs/generate/generate.yaml Normal file
@@ -0,0 +1,9 @@
# model/tokenizer
model_name: "zpn/llama-7b"
tokenizer_name: "zpn/llama-7b"
lora: true
lora_path: "nomic-ai/gpt4all-lora"

max_new_tokens: 512
temperature: 0
prompt: null
15 gpt4all-training/configs/generate/generate_gptj.yaml Normal file
@@ -0,0 +1,15 @@
# model/tokenizer
model_name: "nomic-ai/gpt4all-warmup-lr-epoch_1"
tokenizer_name: "EleutherAI/gpt-j-6b"
lora: false


max_new_tokens: 512
temperature: 0.001
prompt: |
  #this code prints a string reversed
  my_string = "hello how are you"
  print(len(my_string))


  My code above does not work. Can you help me?
15 gpt4all-training/configs/generate/generate_gptj_lora.yaml Normal file
@@ -0,0 +1,15 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6b"
tokenizer_name: "EleutherAI/gpt-j-6b"
lora: true
lora_path: "nomic-ai/gpt4all-gptj-lora-epoch_0"

max_new_tokens: 512
temperature: 0
prompt: |
  #this code prints a string reversed
  my_string = "hello how are you"
  print(len(my_string))


  My code above does not work. Can you help me?
14 gpt4all-training/configs/generate/generate_llama.yaml Normal file
@@ -0,0 +1,14 @@
# model/tokenizer
model_name: # REPLACE WITH LLAMA MODEL NAME
tokenizer_name: # REPLACE WITH LLAMA MODEL NAME


max_new_tokens: 512
temperature: 0.001
prompt: |
  #this code prints a string reversed
  my_string = "hello how are you"
  print(len(my_string))


  My code above does not work. Can you help me?
14 gpt4all-training/configs/inference/gptj.yaml Normal file
@@ -0,0 +1,14 @@
# model/tokenizer
model_name: "nomic-ai/gpt4all-warmup-lr-epoch_1"
tokenizer_name: "EleutherAI/gpt-j-6B"

# dataset
streaming: false
num_proc: 64
dataset_path: "nomic-ai/turbo-500k-multi"
max_length: 1024
batch_size: 32

# logging
seed: 42
30 gpt4all-training/configs/train/finetune.yaml Normal file
@@ -0,0 +1,30 @@
# model/tokenizer
model_name: # add model here
tokenizer_name: # add model here
gradient_checkpointing: true
save_name: # CHANGE

# dataset
streaming: false
num_proc: 64
dataset_path: # update
max_length: 1024
batch_size: 32

# train dynamics
lr: 5.0e-5
eval_every: 800
eval_steps: 100
save_every: 800
output_dir: # CHANGE
checkpoint: null
lora: false
warmup_steps: 100
num_epochs: 2

# logging
wandb: true
wandb_entity: # update
wandb_project_name: # update
seed: 42
33 gpt4all-training/configs/train/finetune_gptj.yaml Normal file
@@ -0,0 +1,33 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6B"
tokenizer_name: "EleutherAI/gpt-j-6B"
gradient_checkpointing: true
save_name: # CHANGE

# dataset
streaming: false
num_proc: 64
dataset_path: # CHANGE
max_length: 1024
batch_size: 32

# train dynamics
lr: 2.0e-5
min_lr: 0
weight_decay: 0.0
eval_every: 500
eval_steps: 105
save_every: 500
log_grads_every: 100
output_dir: # CHANGE
checkpoint: null
lora: false
warmup_steps: 500
num_epochs: 2

# logging
wandb: true
wandb_entity: # CHANGE
wandb_project_name: # CHANGE
seed: 42
33 gpt4all-training/configs/train/finetune_gptj_lora.yaml Normal file
@@ -0,0 +1,33 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6b"
tokenizer_name: "EleutherAI/gpt-j-6b"
gradient_checkpointing: false
save_name: # CHANGE

# dataset
streaming: false
num_proc: 64
dataset_path: # CHANGE
max_length: 1024
batch_size: 1

# train dynamics
lr: 2.0e-5
min_lr: 0
weight_decay: 0.0
eval_every: 500
eval_steps: 105
save_every: 500
log_grads_every: 500
output_dir: # CHANGE
checkpoint: null
lora: true
warmup_steps: 500
num_epochs: 2

# logging
wandb: true
wandb_entity: # CHANGE
wandb_project_name: # CHANGE
seed: 42
31 gpt4all-training/configs/train/finetune_lora.yaml Normal file
@@ -0,0 +1,31 @@
# model/tokenizer
model_name: # update
tokenizer_name: # update
gradient_checkpointing: false
save_name: # CHANGE

# dataset
streaming: false
num_proc: 64
dataset_path: # CHANGE
max_length: 1024
batch_size: 4

# train dynamics
lr: 5.0e-5
min_lr: 0
weight_decay: 0.0
eval_every: 2000
eval_steps: 100
save_every: 2000
output_dir: # CHANGE
checkpoint: null
lora: true
warmup_steps: 100
num_epochs: 2

# logging
wandb: true
wandb_entity: # update
wandb_project_name: # update
seed: 42
8 gpt4all-training/create_hostname.sh Normal file
@@ -0,0 +1,8 @@
#!/bin/bash

export WORKER_IP=$1
N_GPUS=8
# create the dir if it doesn't exist
sudo mkdir -p /job
printf "localhost slots=$N_GPUS\n$WORKER_IP slots=$N_GPUS" | sudo tee /job/hostfile
echo /job/hostfile
167 gpt4all-training/data.py Normal file
@@ -0,0 +1,167 @@
import glob
import torch
from datasets import load_dataset, concatenate_datasets
import os
from torch.utils.data import DataLoader
from transformers import DefaultDataCollator


def tokenize_inputs(config, tokenizer, examples):
    max_length = config["max_length"]

    # hacky backwards compatibility
    different_eos = tokenizer.eos_token != "</s>"
    out = {"labels": [], "input_ids": []}
    for prompt, response in zip(examples["prompt"], examples["response"]):
        if different_eos:
            if response.count("</s> \n") > 0:
                response = response.replace("</s> \n", f"{tokenizer.eos_token} \n")

        prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0])

        # hack if our prompt is super long
        # we need to include some labels, so we arbitrarily truncate at max_length // 2
        # if the length is too long
        if prompt_len >= max_length // 2:
            # if the prompt is too long, truncate
            # but make sure to truncate to at most 1024 tokens
            new_len = min(max_length // 2, len(prompt) // 2)
            prompt = prompt[:new_len]
            # get the new prompt length
            prompt_len = tokenizer(prompt + "\n", return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item()

        assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}"

        input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
                                 truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()

        labels = input_tokens.clone()
        labels[:prompt_len] = -100
        if len(labels) < max_length:
            # pad to max_length with -100
            labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])

        assert (labels == -100).sum() < len(labels), f"Labels are all -100, something is wrong. prompt length {prompt_len} exceeds max length {max_length}"

        if (labels == -100).sum() == len(labels) - 1:
            print(prompt)
            print(response)
            raise

        input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
        out["labels"].append(labels)
        out["input_ids"].append(input_tokens)

    out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}

    return out


def load_data(config, tokenizer):
    dataset_path = config["dataset_path"]

    if os.path.exists(dataset_path):
        if os.path.isdir(dataset_path):
            files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
        else:
            files = [dataset_path]

        print(f"Reading files {files}")

        dataset = load_dataset("json", data_files=files, split="train")

    else:
        dataset = load_dataset(dataset_path, split="train")

    dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])

    train_dataset, val_dataset = dataset["train"], dataset["test"]

    if config["streaming"] is False:
        kwargs = {"num_proc": config["num_proc"]}
    else:
        kwargs = {}

    # tokenize inputs and return labels and attention mask
    train_dataset = train_dataset.map(
        lambda ele: tokenize_inputs(config, tokenizer, ele),
        batched=True,
        remove_columns=["source", "prompt"],
        **kwargs
    )
    val_dataset = val_dataset.map(
        lambda ele: tokenize_inputs(config, tokenizer, ele),
        batched=True,
        remove_columns=["source", "prompt"],
        **kwargs
    )

    train_dataset = train_dataset.with_format("torch")
    val_dataset = val_dataset.with_format("torch")

    # create dataloaders with the default data collator since we already have labels

    train_dataloader = DataLoader(
        train_dataset,
        collate_fn=DefaultDataCollator(),
        batch_size=config["batch_size"],
    )

    val_dataloader = DataLoader(
        val_dataset,
        collate_fn=DefaultDataCollator(),
        batch_size=config["batch_size"],
    )

    return train_dataloader, val_dataloader


def load_data_for_inference(config, tokenizer):
    dataset_path = config["dataset_path"]

    if os.path.exists(dataset_path):
        # check if path is a directory
        if os.path.isdir(dataset_path):
            files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
        else:
            files = [dataset_path]

        print(f"Reading files {files}")

        dataset = load_dataset("json", data_files=files, split="train")

    else:
        dataset = load_dataset(dataset_path, split="train")

    dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])

    train_dataset, val_dataset = dataset["train"], dataset["test"]

    train_dataset = train_dataset.add_column("index", list(range(len(train_dataset))))
    # select the first N batches that are divisible by batch_size
    # gather is a bit annoying (or the way I'm using it) to get uneven batches as it duplicates data
    train_dataset = train_dataset.select(range((len(train_dataset) // config["batch_size"]) * config["batch_size"]))
    val_dataset = val_dataset.add_column("index", list(range(len(val_dataset))))
    val_dataset = val_dataset.select(range((len(val_dataset) // config["batch_size"]) * config["batch_size"]))

    if config["streaming"] is False:
        kwargs = {"num_proc": config["num_proc"]}
    else:
        kwargs = {}

    # tokenize inputs and return labels and attention mask
    train_dataset = train_dataset.map(
        lambda ele: tokenize_inputs(config, tokenizer, ele),
        batched=True,
        **kwargs
    )
    val_dataset = val_dataset.map(
        lambda ele: tokenize_inputs(config, tokenizer, ele),
        batched=True,
        **kwargs
    )
    train_dataset = train_dataset.with_format("torch")
    val_dataset = val_dataset.with_format("torch")

    return train_dataset, val_dataset
20 gpt4all-training/env.yaml Normal file
@@ -0,0 +1,20 @@
name: vicuna
channels:
  - conda-forge
  - pytorch
  - nvidia
  - huggingface
dependencies:
  - python=3.8
  - accelerate
  - datasets
  - torchmetrics
  - evaluate
  - transformers
  - wandb
  - jsonlines
  - pip:
    - peft
    - nodelist-inflator
    - deepspeed
    - sentencepiece
28 gpt4all-training/eval_figures.py Normal file
@@ -0,0 +1,28 @@
import glob
import pickle
import numpy as np
from matplotlib import pyplot as plt

plt.figure()
for fpath in glob.glob('./eval_data/*.pkl'):
    parts = fpath.split('__')
    model_name = "-".join(fpath.replace(".pkl", "").split("_")[2:])
    with open(fpath, 'rb') as f:
        data = pickle.load(f)
        perplexities = data['perplexities']
        # replace NaNs with 100 and cap at 100 so all histograms share a range
        perplexities = np.nan_to_num(perplexities, nan=100)
        perplexities = np.clip(perplexities, 0, 100)
        if 'alpaca' not in fpath:
            identifier = "-".join(fpath.replace(".pkl", "").split("eval__model-")[1:])
            label = 'GPT4all-'
            label += identifier
        else:
            label = 'alpaca-lora'
        plt.hist(perplexities, label=label, alpha=.5, bins=50)

plt.xlabel('Perplexity')
plt.ylabel('Frequency')
plt.legend()
plt.savefig('figs/perplexity_hist.png')
108 gpt4all-training/eval_self_instruct.py Normal file
@@ -0,0 +1,108 @@
|
||||
import json
|
||||
import torch
|
||||
import pickle
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from read import read_config
|
||||
from argparse import ArgumentParser
|
||||
from peft import PeftModelForCausalLM
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
'''
|
||||
Evaluates perplexity on the outputs of:
|
||||
https://github.com/yizhongw/self-instruct/blob/main/human_eval/user_oriented_instructions.jsonl
|
||||
'''
|
||||
|
||||
def read_jsonl_file(file_path):
|
||||
data = []
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
for line in file:
|
||||
json_object = json.loads(line.strip())
|
||||
data.append(json_object)
|
||||
return data
|
||||
|
||||
def setup_model(config):
|
||||
model = AutoModelForCausalLM.from_pretrained(config["model_name"], device_map="auto", torch_dtype=torch.float16, output_hidden_states=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
|
||||
added_tokens = tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>"})
|
||||
|
||||
if added_tokens > 0:
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
if 'lora' in config and config['lora']:
|
||||
model = PeftModelForCausalLM.from_pretrained(model, config["lora_path"], device_map="auto", torch_dtype=torch.float16, return_hidden_states=True)
|
||||
model.to(dtype=torch.float16)
|
||||
|
||||
print(f"Mem needed: {model.get_memory_footprint() / 1024 / 1024 / 1024:.2f} GB")
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
|
||||
|
||||
def eval_example(model, tokenizer, example, config):
|
||||
|
||||
prompt = example['instruction'] + ' ' + example['instances'][0]['input']
|
||||
gt = prompt + ' ' + example['instances'][0]['output']
|
||||
|
||||
#decode several continuations and compute their page trajectories
|
||||
input = tokenizer(prompt, return_tensors="pt")
|
||||
input = {k: v.to(model.device) for k, v in input.items()}
|
||||
|
||||
#compute the ground truth perplexity
|
||||
gt_input = tokenizer(gt, return_tensors="pt")
|
||||
gt_input = {k: v.to(model.device) for k, v in gt_input.items()}
|
||||
|
||||
nlls = []
|
||||
prev_end_loc = 0
|
||||
stride = 512
|
||||
seq_len = gt_input['input_ids'].size(1)
|
||||
|
||||
for begin_loc in tqdm(range(input['input_ids'].size(1), gt_input['input_ids'].size(1), stride)):
|
||||
end_loc = min(begin_loc + stride, seq_len)
|
||||
trg_len = end_loc - prev_end_loc # may be different from stride on last loop
|
||||
input_ids = gt_input['input_ids'][:, begin_loc:end_loc].to(model.device)
|
||||
target_ids = input_ids.clone()
|
||||
target_ids[:, :-trg_len] = -100
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids, labels=target_ids)
|
||||
neg_log_likelihood = outputs.loss * trg_len
|
||||
|
||||
nlls.append(neg_log_likelihood)
|
||||
prev_end_loc = end_loc
|
||||
if end_loc == seq_len:
|
||||
break
|
||||
|
||||
ppl = torch.exp(torch.stack(nlls).sum() / end_loc).item()
|
||||
print('ppl: ', ppl)
|
||||
|
||||
print(prompt)
|
||||
print(80*'-')
|
||||
|
||||
|
||||
return ppl
|
||||
|
||||
def do_eval(config):
|
||||
eval_data = read_jsonl_file('eval_data/user_oriented_instructions.jsonl')
|
||||
model, tokenizer = setup_model(config)
|
||||
all_perplexities = []
|
||||
for example in tqdm(eval_data):
|
||||
gt_perplexity = eval_example(model, tokenizer, example, config)
|
||||
all_perplexities.append(gt_perplexity)
|
||||
|
||||
|
||||
name = f"eval_data/eval__model-{config['model_name'].replace('/', '_')}{'__lora-' + config['lora_path'].replace('/', '_') if config['lora'] else ''}.pkl"
|
||||
|
||||
with open(name, 'wb') as f:
|
||||
r = {'perplexities': all_perplexities}
|
||||
pickle.dump(r, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--config", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
config = read_config(args.config)
|
||||
do_eval(config)
|
BIN gpt4all-training/figs/clustering_overfit.png Normal file (binary file not shown, 2.3 MiB)
BIN gpt4all-training/figs/duplicate_loss.png Normal file (binary file not shown, 362 KiB)
BIN gpt4all-training/figs/first_lora.png Normal file (binary file not shown, 308 KiB)
BIN gpt4all-training/figs/overfit-gpt-j.png Normal file (binary file not shown, 356 KiB)
BIN gpt4all-training/figs/perplexity_hist.png Normal file (binary file not shown, 15 KiB)
BIN gpt4all-training/figs/single_epoch.png Normal file (binary file not shown, 353 KiB)
58 gpt4all-training/generate.py Normal file
@@ -0,0 +1,58 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModelForCausalLM
from read import read_config
from argparse import ArgumentParser
import torch
import time


def generate(tokenizer, prompt, model, config):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    outputs = model.generate(input_ids=input_ids, max_new_tokens=config["max_new_tokens"], temperature=config["temperature"])

    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    return decoded[len(prompt):]


def setup_model(config):
    model = AutoModelForCausalLM.from_pretrained(config["model_name"], device_map="auto", torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
    added_tokens = tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>"})

    if added_tokens > 0:
        model.resize_token_embeddings(len(tokenizer))

    if config["lora"]:
        model = PeftModelForCausalLM.from_pretrained(model, config["lora_path"], device_map="auto", torch_dtype=torch.float16)
        model.to(dtype=torch.float16)

    print(f"Mem needed: {model.get_memory_footprint() / 1024 / 1024 / 1024:.2f} GB")

    return model, tokenizer


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--prompt", type=str)

    args = parser.parse_args()

    config = read_config(args.config)

    if config["prompt"] is None and args.prompt is None:
        raise ValueError("Prompt is required either in config or as argument")

    prompt = config["prompt"] if args.prompt is None else args.prompt

    print("Setting up model")
    model, tokenizer = setup_model(config)

    print("Generating")
    start = time.time()
    generation = generate(tokenizer, prompt, model, config)
    print(f"Done in {time.time() - start:.2f}s")
    print(generation)
204 gpt4all-training/inference.py Normal file
@@ -0,0 +1,204 @@
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from argparse import ArgumentParser
|
||||
from read import read_config
|
||||
from accelerate.utils import set_seed
|
||||
from data import load_data_for_inference
|
||||
from tqdm import tqdm
|
||||
from datasets import Dataset
|
||||
import torch.distributed as dist
|
||||
from transformers.trainer_pt_utils import nested_numpify
|
||||
from transformers import DefaultDataCollator
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from pyarrow import compute as pc
|
||||
|
||||
|
||||
def calc_cross_entropy_no_reduction(lm_logits, labels):
|
||||
# calculate cross entropy across batch dim
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss(reduction='none')
|
||||
loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels).mean(dim=1)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def rank0_print(msg):
|
||||
if dist.get_rank() == 0:
|
||||
print(msg)
|
||||
|
||||
|
||||
def inference(config):
|
||||
set_seed(config['seed'])
|
||||
|
||||
rank0_print(f"World size: {dist.get_world_size()}")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
|
||||
# llama has no pad token, set it to new token
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
||||
train_dataset, val_dataset = load_data_for_inference(config, tokenizer)
|
||||
|
||||
num_processes = dist.get_world_size()
|
||||
local_rank = dist.get_rank()
|
||||
|
||||
train_sampler = DistributedSampler(train_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
collate_fn=DefaultDataCollator(),
|
||||
batch_size=config["batch_size"],
|
||||
sampler=train_sampler,
|
||||
drop_last=True
|
||||
)
|
||||
|
||||
val_sampler = DistributedSampler(val_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
collate_fn=DefaultDataCollator(),
|
||||
batch_size=config["batch_size"],
|
||||
sampler=val_sampler,
|
||||
drop_last=True
|
||||
)
|
||||
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
|
||||
trust_remote_code=True,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
model.to(f"cuda:{local_rank}")
|
||||
|
||||
with torch.no_grad():
|
||||
train_outputs = {"loss": [], "embeddings": [], "index": []}
|
||||
for batch in tqdm(train_dataloader, disable=local_rank != 0):
|
||||
batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
|
||||
batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
|
||||
outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
|
||||
loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
|
||||
train_outputs["loss"].extend(loss)
|
||||
|
||||
embeddings = outputs.hidden_states[-1]
|
||||
batch_size = batch["input_ids"].shape[0]
|
||||
sequence_lengths = []
|
||||
# since we use mutiturn with multiple <|endoftext|>, we need to find the place where
|
||||
# <|endoftext|> is repeated
|
||||
for item in batch["input_ids"]:
|
||||
indices = torch.where(item == tokenizer.pad_token_id)[0]
|
||||
found = False
|
||||
for index in indices:
|
||||
# case where sequence is less than max length
|
||||
if torch.all(item[index:] == tokenizer.pad_token_id):
|
||||
sequence_lengths.append(index)
|
||||
found = True
|
||||
break
|
||||
# case where sequence is >= max length
|
||||
if not found:
|
||||
sequence_lengths.append(len(item) - 1)
|
||||
|
||||
sequence_lengths = torch.tensor(sequence_lengths)
|
||||
pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
|
||||
|
||||
train_outputs["embeddings"].append(pooled_logits)
|
||||
train_outputs["index"].extend(batch["index"].to(model.device))
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
train_outputs = nested_numpify(train_outputs)
|
||||
# stack since they're 0-dim arrays
|
||||
train_outputs["index"] = np.stack(train_outputs["index"])
|
||||
train_outputs["loss"] = np.stack(train_outputs["loss"])
|
||||
train_outputs["embeddings"] = np.concatenate(train_outputs["embeddings"])
|
||||
|
||||
df_train = Dataset.from_dict(train_outputs)
|
||||
curr_idx = df_train["index"]
|
||||
|
||||
# compute mask in pyarrow since it's super fast
|
||||
# ty @bmschmidt for showing me this!
|
||||
table = train_dataset.data
|
||||
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
|
||||
filtered_table = table.filter(mask)
|
||||
# convert from pyarrow to Dataset
|
||||
filtered_train = Dataset.from_dict(filtered_table.to_pydict())
|
||||
|
||||
filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
|
||||
filtered_train = filtered_train.add_column("loss", df_train["loss"])
|
||||
filtered_train = filtered_train.add_column("is_train", [True] * len(filtered_train))
|
||||
|
||||
filtered_train.to_json(f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
|
||||
|
||||
val_outputs = {"loss": [], "embeddings": [], "index": []}
|
||||
for batch in tqdm(val_dataloader, disable=local_rank != 0):
|
||||
batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
|
||||
batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
|
||||
outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
|
||||
loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
|
||||
val_outputs["loss"].extend(loss)
|
||||
|
||||
embeddings = outputs.hidden_states[-1]
|
||||
batch_size = batch["input_ids"].shape[0]
|
||||
sequence_lengths = []
|
||||
# since we use mutiturn with multiple <|endoftext|>, we need to find the place where
|
||||
# <|endoftext|> is repeated
|
||||
for item in batch["input_ids"]:
|
||||
indices = torch.where(item == tokenizer.pad_token_id)[0]
|
||||
found = False
|
||||
for index in indices:
|
||||
# case where sequence is less than max length
|
||||
if torch.all(item[index:] == tokenizer.pad_token_id):
|
||||
sequence_lengths.append(index)
|
||||
found = True
|
||||
break
|
||||
# case where sequence is >= max length
|
||||
if not found:
|
||||
sequence_lengths.append(len(item) - 1)
|
||||
|
||||
sequence_lengths = torch.tensor(sequence_lengths)
|
||||
pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
|
||||
|
||||
val_outputs["embeddings"].append(pooled_logits)
|
||||
val_outputs["index"].extend(batch["index"].to(model.device))
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
val_outputs = nested_numpify(val_outputs)
|
||||
val_outputs["index"] = np.stack(val_outputs["index"])
|
||||
val_outputs["loss"] = np.stack(val_outputs["loss"])
|
||||
val_outputs["embeddings"] = np.concatenate(val_outputs["embeddings"])
|
||||
|
||||
df_val = Dataset.from_dict(val_outputs)
|
||||
curr_idx = df_val["index"]
|
||||
|
||||
# compute mask in pyarrow since it's super fast
|
||||
# ty @bmschmidt for showing me this!
|
||||
table = val_dataset.data
|
||||
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
|
||||
filtered_table = table.filter(mask)
|
||||
# convert from pyarrow to Dataset
|
||||
filtered_val = Dataset.from_dict(filtered_table.to_pydict())
|
||||
filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
|
||||
filtered_val = filtered_val.add_column("loss", df_val["loss"])
|
||||
filtered_val = filtered_val.add_column("is_train", [False] * len(filtered_val))
|
||||
|
||||
filtered_val.to_json(f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
|
||||
|
||||
|
||||
def main():
|
||||
dist.init_process_group("nccl")
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="config.yaml")
|
||||
|
||||
args = parser.parse_args()
|
||||
config = read_config(args.config)
|
||||
|
||||
inference(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# parse arguments by reading in a config
|
||||
main()
|
||||
|
88 gpt4all-training/launcher.sh Normal file
@@ -0,0 +1,88 @@
|
||||
#!/bin/bash

# Display header
echo "=========================================================="
echo " ██████ ██████ ████████ ██ ██ █████ ██ ██ "
echo "██ ██ ██ ██ ██ ██ ██ ██ ██ ██ "
echo "██ ███ ██████ ██ ███████ ███████ ██ ██ "
echo "██ ██ ██ ██ ██ ██ ██ ██ ██ "
echo " ██████ ██ ██ ██ ██ ██ ███████ ███████ "
echo " └─> https://github.com/nomic-ai/gpt4all"

# Function to detect macOS architecture and set the binary filename
detect_mac_arch() {
    local mac_arch
    mac_arch=$(uname -m)
    case "$mac_arch" in
        arm64)
            os_type="M1 Mac/OSX"
            binary_filename="gpt4all-lora-quantized-OSX-m1"
            ;;
        x86_64)
            os_type="Intel Mac/OSX"
            binary_filename="gpt4all-lora-quantized-OSX-intel"
            ;;
        *)
            echo "Unknown macOS architecture"
            exit 1
            ;;
    esac
}

# Detect operating system and set the binary filename
case "$(uname -s)" in
    Darwin*)
        detect_mac_arch
        ;;
    Linux*)
        if grep -q Microsoft /proc/version; then
            os_type="Windows (WSL)"
            binary_filename="gpt4all-lora-quantized-win64.exe"
        else
            os_type="Linux"
            binary_filename="gpt4all-lora-quantized-linux-x86"
        fi
        ;;
    CYGWIN*|MINGW32*|MSYS*|MINGW*)
        os_type="Windows (Cygwin/MSYS/MINGW)"
        binary_filename="gpt4all-lora-quantized-win64.exe"
        ;;
    *)
        echo "Unknown operating system"
        exit 1
        ;;
esac

echo "================================"
echo "== You are using $os_type."

# Change to the chat directory
cd chat

# List .bin files and prompt user to select one
bin_files=(*.bin)
echo "== Available .bin files:"
for i in "${!bin_files[@]}"; do
    echo "  [$((i+1))] ${bin_files[i]}"
done

# Function to get user input and validate it
get_valid_user_input() {
    local input_valid=false

    while ! $input_valid; do
        echo "==> Please enter a number:"
        read -r user_selection
        if [[ $user_selection =~ ^[0-9]+$ ]] && (( user_selection >= 1 && user_selection <= ${#bin_files[@]} )); then
            input_valid=true
        else
            echo "Invalid input. Please enter a number between 1 and ${#bin_files[@]}."
        fi
    done
}

get_valid_user_input
selected_bin_file="${bin_files[$((user_selection-1))]}"

# Run the selected .bin file with the appropriate command
./"$binary_filename" -m "$selected_bin_file"
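Not part of the diff: a usage sketch, assuming the script is run from a directory whose `chat/` subdirectory already contains the platform binaries and the downloaded quantized `.bin` model(s):

```bash
chmod +x launcher.sh
./launcher.sh
# then pick a .bin model by number when prompted
```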
10
gpt4all-training/read.py
Normal file
@@ -0,0 +1,10 @@
import yaml


def read_config(path):
    # read yaml and return contents
    with open(path, 'r') as file:
        try:
            return yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
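A quick usage sketch (the config path is only an example, not something fixed by this diff):

```python
from read import read_config

config = read_config("configs/train/finetune.yaml")  # example path
print(config["model_name"], config["lr"])
```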
15
gpt4all-training/requirements.txt
Normal file
@@ -0,0 +1,15 @@
accelerate
datasets
torchmetrics
evaluate
transformers>=4.28.0
wandb
pip
peft
nodelist-inflator
deepspeed
sentencepiece
jsonlines
nomic
scikit-learn
matplotlib
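Installing the training dependencies is a standard pip step (the path assumes the mono-repo layout introduced by this commit):

```bash
pip install -r gpt4all-training/requirements.txt
```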
233
gpt4all-training/train.py
Normal file
@@ -0,0 +1,233 @@
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
import torch
from torch.optim import AdamW
from argparse import ArgumentParser
from read import read_config
from accelerate import Accelerator
from accelerate.utils import DummyScheduler, DummyOptim, set_seed
from peft import get_peft_model, LoraConfig, TaskType
from data import load_data
from torchmetrics import MeanMetric
from tqdm import tqdm
import wandb

torch.backends.cuda.matmul.allow_tf32 = True


def format_metrics(metrics, split, prefix=""):
    log = f"[{split}]" + prefix
    log += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()])

    return log


def evaluate(model, val_dataloader):
    model.eval()
    val_loss = MeanMetric(nan_strategy="error").to(model.device)

    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            loss = model(**batch).loss

            # note: relies on the module-level `accelerator` created in the __main__ block below
            loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})

            val_loss.update(loss_values["loss"])

    return val_loss


def train(accelerator, config):
    set_seed(config['seed'])

    accelerator.print(config)
    accelerator.print(f"Using {accelerator.num_processes} GPUs")

    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
    # if no pad token, set it to eos
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    with accelerator.main_process_first():
        train_dataloader, val_dataloader = load_data(config, tokenizer)

    checkpoint = config["gradient_checkpointing"]
    model = AutoModelForCausalLM.from_pretrained(config["model_name"],
                                                 use_cache=False if checkpoint else True,
                                                 trust_remote_code=True)
    if checkpoint:
        model.gradient_checkpointing_enable()

    if config["lora"]:
        peft_config = LoraConfig(
            # should R be configurable?
            task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    optimizer_cls = (
        AdamW
        if accelerator.state.deepspeed_plugin is None
        or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
        else DummyOptim
    )

    # karpathy doesn't decay embedding, maybe we should exclude
    # https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
    optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])

    # note: gradient_accumulation_steps is only set here, so the step math below assumes a DeepSpeed config is present
    if accelerator.state.deepspeed_plugin is not None:
        gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
            "gradient_accumulation_steps"
        ]

    # decay to min_lr instead of 0
    lr_ratio = config["min_lr"] / config["lr"]
    accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
    total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
    # instead of decaying to zero, decay to ratio of min_lr / lr
    total_num_steps += int(total_num_steps * lr_ratio) + config["warmup_steps"]
    accelerator.print(f"Total training steps: {total_num_steps}")
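    # Illustrative arithmetic (editorial note; the numbers are assumptions, not repo defaults):
    # with len(train_dataloader) = 1000, gradient_accumulation_steps = 4, num_epochs = 2,
    # lr = 2e-5, min_lr = 2e-6 and warmup_steps = 100:
    #   optimizer steps per run = (1000 / 4) * 2             = 500
    #   lr_ratio                = 2e-6 / 2e-5                = 0.1
    #   total_num_steps         = 500 + int(500 * 0.1) + 100 = 650
    # The cosine schedule is deliberately stretched past the real number of optimizer
    # steps, so when training actually stops the LR has only decayed to roughly min_lr, not to 0.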

    # create a real scheduler when DeepSpeed does not define one in its config; otherwise use DummyScheduler and let DeepSpeed drive it
    if (
        accelerator.state.deepspeed_plugin is None
        or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
    ):
        scheduler = get_scheduler(
            name="cosine",
            optimizer=optimizer,
            num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
            num_training_steps=total_num_steps,
        )
    else:
        scheduler = DummyScheduler(
            optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"]
        )

    model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, val_dataloader, scheduler
    )

    # setup for saving training states in case of preemption
    accelerator.register_for_checkpointing(scheduler)

    if config["checkpoint"]:
        accelerator.load_state(config["checkpoint"])
        accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
        path = os.path.basename(config["train_args"]["resume_from_checkpoint"])
        training_difference = os.path.splitext(path)[0]
        resume_step = int(training_difference.replace("step_", ""))
        accelerator.skip_first_batches(train_dataloader, resume_step)
        accelerator.print(f"Resuming from step {resume_step}")

    # log gradients
    if accelerator.is_main_process and config["wandb"]:
        wandb.watch(model, log_freq=config["log_grads_every"], log="all")

    for epoch in range(config["num_epochs"]):
        train_loss = MeanMetric(nan_strategy="error").to(model.device)
        for step, batch in enumerate(tqdm(train_dataloader)):
            model.train()
            outputs = model(**batch)
            loss = outputs.loss

            # gather loss before backprop in case of gradient accumulation
            loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
            train_loss.update(loss_values["loss"])

            loss = loss / gradient_accumulation_steps
            accelerator.backward(loss)
            # get gradient norm of all params

            # log LR in case something weird happens
            if step > 0 and step % (config["eval_every"] // 10) == 0:
                if config["wandb"]:
                    curr_step = step + epoch * len(train_dataloader)
                    accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)

            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if step > 0 and step % config["save_every"] == 0:
                curr_step = step + epoch * len(train_dataloader)
                accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")

            if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
                val_loss = evaluate(model, val_dataloader)

                log_train = {
                    "train_loss": train_loss.compute()
                }
                log_val = {
                    "val_loss": val_loss.compute()
                }

                if config["wandb"]:
                    curr_step = step + epoch * len(train_dataloader)
                    accelerator.log({**log_train, **log_val}, step=curr_step)

                accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
                accelerator.print(format_metrics(log_train, "train", f" step {step} "))
                accelerator.print(format_metrics(log_val, "val", f" step {step} "))

                train_loss.reset()

        accelerator.print(f"Epoch {epoch} finished")
        accelerator.print("Pushing to HF hub")
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        try:
            if accelerator.is_main_process:
                unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)

        except Exception as e:
            accelerator.print(e)
            accelerator.print("Failed to push to hub")

        unwrapped_model.save_pretrained(
            f"{config['output_dir']}/epoch_{epoch}",
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=accelerator.get_state_dict(model),
        )

    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        f"{config['output_dir']}/final",
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
        state_dict=accelerator.get_state_dict(model),
    )

    accelerator.end_training()


if __name__ == "__main__":
    # parse arguments by reading in a config
    parser = ArgumentParser()
    parser.add_argument("--config", type=str, default="config.yaml")

    args = parser.parse_args()

    config = read_config(args.config)

    if config["wandb"]:
        accelerator = Accelerator(log_with="wandb")
        accelerator.init_trackers(
            project_name=config["wandb_project_name"],
            config=config,
            init_kwargs={"wandb": {"entity": config["wandb_entity"]}},
        )
    else:
        accelerator = Accelerator()

    train(accelerator, config=config)
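`train.py` takes all of its hyperparameters from the YAML config and is meant to be launched with Accelerate, with the Dummy optimizer/scheduler classes taking over whenever a DeepSpeed config supplies its own. A minimal launch sketch, where the config path is an assumption rather than something fixed by this diff:

```bash
cd gpt4all-training
accelerate launch train.py --config configs/train/finetune.yaml
```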