From a514492d527061c679b715dd56b75c20ce81f6a4 Mon Sep 17 00:00:00 2001
From: zanussbaum
Date: Wed, 19 Apr 2023 18:11:02 -0400
Subject: [PATCH 01/36] chore: ignore ds store

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 2952a641..5fadbd31 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+.DS_Store
 *.pkl
 ckpts*
 .deepspeed_env

From 29c7ac7f735e672486bf8e1174562f269ac8d79d Mon Sep 17 00:00:00 2001
From: zanussbaum
Date: Wed, 19 Apr 2023 18:12:03 -0400
Subject: [PATCH 02/36] refactor: clean up directory structure

---
 gpt4all/__init__.py                          |  0
 gpt4all/eval/__init__.py                     |  0
 .../eval/eval_figures.py                     |  0
 .../eval/eval_self_instruct.py               |  2 +-
 gpt4all/inference/__init__.py                |  0
 generate.py => gpt4all/inference/generate.py |  2 +-
 .../inference/inference.py                   |  4 +--
 gpt4all/train/__init__.py                    |  0
 train.py => gpt4all/train/train.py           |  4 +--
 gpt4all/utils/__init__.py                    |  0
 data.py => gpt4all/utils/data.py             |  2 +-
 read.py => gpt4all/utils/read.py             |  0
 setup.py                                     | 34 +++++++++++++++++++
 13 files changed, 41 insertions(+), 7 deletions(-)
 create mode 100644 gpt4all/__init__.py
 create mode 100644 gpt4all/eval/__init__.py
 rename eval_figures.py => gpt4all/eval/eval_figures.py (100%)
 rename eval_self_instruct.py => gpt4all/eval/eval_self_instruct.py (98%)
 create mode 100644 gpt4all/inference/__init__.py
 rename generate.py => gpt4all/inference/generate.py (97%)
 rename inference.py => gpt4all/inference/inference.py (98%)
 create mode 100644 gpt4all/train/__init__.py
 rename train.py => gpt4all/train/train.py (99%)
 create mode 100644 gpt4all/utils/__init__.py
 rename data.py => gpt4all/utils/data.py (99%)
 rename read.py => gpt4all/utils/read.py (100%)
 create mode 100644 setup.py

diff --git a/gpt4all/__init__.py b/gpt4all/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/gpt4all/eval/__init__.py b/gpt4all/eval/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/eval_figures.py b/gpt4all/eval/eval_figures.py
similarity index 100%
rename from eval_figures.py
rename to gpt4all/eval/eval_figures.py
diff --git a/eval_self_instruct.py b/gpt4all/eval/eval_self_instruct.py
similarity index 98%
rename from eval_self_instruct.py
rename to gpt4all/eval/eval_self_instruct.py
index e05a68e4..7206fdd5 100644
--- a/eval_self_instruct.py
+++ b/gpt4all/eval/eval_self_instruct.py
@@ -3,7 +3,7 @@ import torch
 import pickle
 import numpy as np
 from tqdm import tqdm
-from read import read_config
+from gpt4all.utils.read import read_config
 from argparse import ArgumentParser
 from peft import PeftModelForCausalLM
 from transformers import AutoModelForCausalLM, AutoTokenizer
diff --git a/gpt4all/inference/__init__.py b/gpt4all/inference/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/generate.py b/gpt4all/inference/generate.py
similarity index 97%
rename from generate.py
rename to gpt4all/inference/generate.py
index fa1c43fa..f4184d62 100644
--- a/generate.py
+++ b/gpt4all/inference/generate.py
@@ -1,6 +1,6 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import PeftModelForCausalLM
-from read import read_config
+from gpt4all.utils.read import read_config
 from argparse import ArgumentParser
 import torch
 import time
diff --git a/inference.py b/gpt4all/inference/inference.py
similarity index 98%
rename from inference.py
rename to gpt4all/inference/inference.py
index 8a4efb51..5e351c46 100644
--- a/inference.py
+++ b/gpt4all/inference/inference.py
@@ -2,9 +2,9 @@ from transformers import AutoModelForCausalLM, 
AutoTokenizer import torch import torch.nn as nn from argparse import ArgumentParser -from read import read_config +from gpt4all.utils.read import read_config from accelerate.utils import set_seed -from data import load_data_for_inference +from gpt4all.utils.data import load_data_for_inference from tqdm import tqdm from datasets import Dataset import torch.distributed as dist diff --git a/gpt4all/train/__init__.py b/gpt4all/train/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/train.py b/gpt4all/train/train.py similarity index 99% rename from train.py rename to gpt4all/train/train.py index 8605af11..97b6c9a8 100644 --- a/train.py +++ b/gpt4all/train/train.py @@ -3,11 +3,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, Lla import torch from torch.optim import AdamW from argparse import ArgumentParser -from read import read_config +from gpt4all.utils.read import read_config from accelerate import Accelerator from accelerate.utils import DummyScheduler, DummyOptim, set_seed from peft import get_peft_model, LoraConfig, TaskType -from data import load_data +from gpt4all.utils.data import load_data from torchmetrics import MeanMetric from tqdm import tqdm import wandb diff --git a/gpt4all/utils/__init__.py b/gpt4all/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/data.py b/gpt4all/utils/data.py similarity index 99% rename from data.py rename to gpt4all/utils/data.py index 8227de00..b55a589a 100644 --- a/data.py +++ b/gpt4all/utils/data.py @@ -1,6 +1,6 @@ import glob import torch -from datasets import load_dataset, concatenate_datasets +from datasets import load_dataset import os from torch.utils.data import DataLoader from transformers import DefaultDataCollator diff --git a/read.py b/gpt4all/utils/read.py similarity index 100% rename from read.py rename to gpt4all/utils/read.py diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..6b100ed5 --- /dev/null +++ b/setup.py @@ -0,0 +1,34 @@ +from setuptools import setup, find_packages + +with open('README.md', 'r', encoding='utf-8') as f: + long_description = f.read() + +with open('requirements.txt', 'r', encoding='utf-8') as f: + requirements = [line.strip() for line in f if line.strip()] + +setup( + name='gpt4all', + version='0.0.1', + author='nomic-ai', + author_email='zach@nomic-ai', + description='an ecosystem of open-source chatbots trained on a massive collections of clean assistant data including code, stories and dialogue', + long_description=long_description, + long_description_content_type='text/markdown', + url='https://github.com/nomic-ai/gpt4all', + packages=find_packages(), + install_requires=requirements, + classifiers=[ + 'Development Status :: 3 - Alpha', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Topic :: Text Processing :: Linguistic', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Intended Audience :: Science/Research', + 'Operating System :: OS Independent', + ], + python_requires='>=3.6', +) \ No newline at end of file From 97a1cd05395e4ab36f197384774bbf79d26a4971 Mon Sep 17 00:00:00 2001 From: zanussbaum Date: Wed, 19 Apr 2023 18:13:24 -0400 Subject: [PATCH 03/36] chore: delete unused --- create_hostname.sh | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 create_hostname.sh diff 
--git a/create_hostname.sh b/create_hostname.sh deleted file mode 100644 index 8a9187f2..00000000 --- a/create_hostname.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -export WORKER_IP=$1 -N_GPUS=8 -# create dir if doesn't exist -sudo mkdir -p /job -printf "localhost slots=$N_GPUS\n$WORKER_IP slots=$N_GPUS" | sudo tee /job/hostfile -echo /job/hostfile \ No newline at end of file From 09cddbedc07f2ce406f96d4f81a8ea6dca5da61c Mon Sep 17 00:00:00 2001 From: zanussbaum Date: Thu, 20 Apr 2023 16:04:41 -0400 Subject: [PATCH 04/36] feat: models wip --- gpt4all/models/__init__.py | 8 + gpt4all/models/configuration_gpt_jr.py | 148 +++++ gpt4all/models/modeling_gpt_jr.py | 831 +++++++++++++++++++++++++ test_gpt_jr.py | 41 ++ 4 files changed, 1028 insertions(+) create mode 100644 gpt4all/models/__init__.py create mode 100644 gpt4all/models/configuration_gpt_jr.py create mode 100644 gpt4all/models/modeling_gpt_jr.py create mode 100644 test_gpt_jr.py diff --git a/gpt4all/models/__init__.py b/gpt4all/models/__init__.py new file mode 100644 index 00000000..138b5b41 --- /dev/null +++ b/gpt4all/models/__init__.py @@ -0,0 +1,8 @@ +from .configuration_gpt_jr import GPTJRConfig +from .modeling_gpt_jr import GPTJRForCausalLM + + +__all__ = [ + "GPTJRConfig", + "GPTJRForCausalLM" +] \ No newline at end of file diff --git a/gpt4all/models/configuration_gpt_jr.py b/gpt4all/models/configuration_gpt_jr.py new file mode 100644 index 00000000..e98ff2a8 --- /dev/null +++ b/gpt4all/models/configuration_gpt_jr.py @@ -0,0 +1,148 @@ +# coding=utf-8 +# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" GPT-J model configuration""" +from collections import OrderedDict +from typing import Any, List, Mapping, Optional + +from transformers import PreTrainedTokenizer, TensorType, is_torch_available +from transformers.configuration_utils import PretrainedConfig +from transformers.onnx import OnnxConfigWithPast, PatchingSpec +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "EleutherAI/gpt-j-6B": "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/config.json", + # See all GPT-J models at https://huggingface.co/models?filter=gpt_j +} + + +class GPTJRConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GPTJModel`]. It is used to instantiate a GPT-J + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the GPT-J + [EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6B) architecture. Configuration objects inherit from + [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] + for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 50400): + Vocabulary size of the GPT-J model. 
Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPTJModel`]. + n_positions (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 4096): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + rotary_dim (`int`, *optional*, defaults to 64): + Number of dimensions in the embedding that Rotary Position Embedding is applied to. + n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu_new"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`int`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). 
+ + Example: + + ```python + >>> from transformers import GPTJModel, GPTJConfig + + >>> # Initializing a GPT-J 6B configuration + >>> configuration = GPTJConfig() + + >>> # Initializing a model from the configuration + >>> model = GPTJModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "gptj" + attribute_map = { + "max_position_embeddings": "n_positions", + "hidden_size": "n_embd", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50400, + n_positions=2048, + n_embd=4096, + n_layer=28, + n_head=16, + rotary_dim=64, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + tie_word_embeddings=False, + encoder_ndim=4096, + alpha=.5, + encoder_path=None, + **kwargs + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.rotary_dim = rotary_dim + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.encoder_ndim = encoder_ndim + self.alpha = alpha + self.encoder_path = encoder_path + + super().__init__( + bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs + ) diff --git a/gpt4all/models/modeling_gpt_jr.py b/gpt4all/models/modeling_gpt_jr.py new file mode 100644 index 00000000..e86de220 --- /dev/null +++ b/gpt4all/models/modeling_gpt_jr.py @@ -0,0 +1,831 @@ +# coding=utf-8 +# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
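# A minimal construction sketch (illustration only, not patch content) for the
# retrieval fields GPTJRConfig adds above. The 384 encoder width assumes the
# SBERT all-MiniLM-L6-v2 encoder used later in this series; the small n_layer
# just keeps the example cheap to instantiate.
import torch
from gpt4all.models import GPTJRConfig, GPTJRForCausalLM

config = GPTJRConfig(encoder_ndim=384, alpha=0.5, n_layer=4)
model = GPTJRForCausalLM(config)  # encoder_path=None, so no frozen retrieval encoder is loaded
neighbors = torch.randn(1, 2, config.encoder_ndim)  # expected layout: (batch, n_neighbors, encoder_ndim)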
+""" PyTorch GPT-J model.""" + +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers import AutoModel +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from gpt4all.models.configuration_gpt_jr import GPTJRConfig + + +logger = logging.get_logger(__name__) + + +GPTJR_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "EleutherAI/gpt-j-6B", + # See all GPT-J models at https://huggingface.co/models?filter=gptj +] + + +def fixed_pos_embedding(x, seq_dim=1, seq_len=None): + dim = x.shape[-1] + if seq_len is None: + seq_len = x.shape[seq_dim] + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) + sinusoid_inp = ( + torch.einsum("i , j -> i j", torch.arange(seq_len, dtype=torch.float), inv_freq).to(x.device).float() + ) + return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + + +def rotate_every_two(x): + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') + + +def duplicate_interleave(m): + """ + A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. + """ + dim0 = m.shape[0] + m = m.view(-1, 1) # flatten the matrix + m = m.repeat(1, 2) # repeat all elements into the 2nd dimension + m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy + return m + + +def apply_rotary_pos_emb(x, sincos, offset=0): + sin, cos = map(lambda t: duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :], sincos) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + + + +class GPTJRAttention(nn.Module): + def __init__(self, config): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), + ) + self.register_buffer("masked_bias", torch.tensor(-1e9)) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.embed_dim = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_attention_heads + if self.head_dim * self.num_attention_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and" + f" `num_attention_heads`: {self.num_attention_heads})." 
+ ) + self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()) + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.rotary_dim = None + if config.rotary_dim is not None: + self.rotary_dim = config.rotary_dim + + def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): + """ + Splits hidden dim into attn_head_size and num_attention_heads + """ + new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) + tensor = tensor.view(new_shape) + if rotary: + return tensor + if len(tensor.shape) == 5: + return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features) + elif len(tensor.shape) == 4: + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + + def _merge_heads(self, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + if len(tensor.shape) == 5: + tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() + elif len(tensor.shape) == 4: + tensor = tensor.permute(0, 2, 1, 3).contiguous() + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) + return tensor.view(new_shape) + + def _attn( + self, + query, + key, + value, + attention_mask=None, + head_mask=None, + ): + + # compute causal mask from causal mask buffer + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) + + # Keep the attention weights computation in fp32 to avoid overflow issues + # TODO: do we need to do this with bfloat16?? + # query = query.to(torch.float32) + # key = key.to(torch.float32) + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + attn_weights = attn_weights / self.scale_attn + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = attn_weights.to(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + encoder_hidden_states: Optional[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Tuple[torch.Tensor]], + Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], + ]: + + query = self.q_proj(hidden_states) + # if we are doing cross attention + if encoder_hidden_states is not None: + key = self.k_proj(encoder_hidden_states) + value = self.v_proj(encoder_hidden_states) + else: + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, self.num_attention_heads, self.head_dim, True) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, True) + value = self._split_heads(value, self.num_attention_heads, self.head_dim, False) + + seq_len = key.shape[1] + offset = 0 + + if layer_past is not None: + offset = layer_past[0].shape[-2] + seq_len += offset + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) + k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) + q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) + + key = torch.cat([k_rot, k_pass], dim=-1) + query = torch.cat([q_rot, q_pass], dim=-1) + else: + sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) + key = apply_rotary_pos_emb(key, sincos, offset=offset) + query = apply_rotary_pos_emb(query, sincos, offset=offset) + + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + # compute self-attention: V x Softmax(QK^T) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class GPTJRCrossAttention(GPTJRAttention): + def __init__(self, config): + super().__init__(config) + + max_positions = config.max_position_embeddings + self.register_buffer( + 
"bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), + ) + self.register_buffer("masked_bias", torch.tensor(-1e9)) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.embed_dim = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_attention_heads + if self.head_dim * self.num_attention_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and" + f" `num_attention_heads`: {self.num_attention_heads})." + ) + self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()) + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(config.encoder_ndim, self.embed_dim, bias=False) + self.q_proj = nn.Linear(config.encoder_ndim, self.embed_dim, bias=False) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.rotary_dim = None + if config.rotary_dim is not None: + self.rotary_dim = config.rotary_dim + + + +class GPTJRMLP(nn.Module): + def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim + super().__init__() + embed_dim = config.n_embd + + self.fc_in = nn.Linear(embed_dim, intermediate_size) + self.fc_out = nn.Linear(intermediate_size, embed_dim) + + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor: + hidden_states = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc_out(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class GPTJRBlock(nn.Module): + def __init__(self, config): + super().__init__() + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd + self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.attn = GPTJRAttention(config) + self.mlp = GPTJRMLP(inner_dim, config) + + # TODO: fix for n neighbors + # for SBERT this is 384 + self.ln_2 = nn.LayerNorm(config.encoder_ndim, eps=config.layer_norm_epsilon) + self.cross_attn = GPTJRCrossAttention(config) + self.alpha = config.alpha + + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + encoder_hidden_states: Optional[torch.FloatTensor], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + # shape (bs, seq_len, hidden_dim) + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + + feed_forward_hidden_states = self.mlp(hidden_states) + self_attention_residual = attn_output + feed_forward_hidden_states + residual + + # encoder_hidden_states (bs, knn, encoder_dim) + encoder_normed = self.ln_2(encoder_hidden_states) + # TODO: how do we handle neighbors + + # 
TODO: we have to make sure we're doing masking right here + # TODO: T5 passes query length to cross attention, do we need that? + cross_attn_outputs = self.cross_attn( + residual, + encoder_hidden_states=encoder_normed, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + cross_attn_output = cross_attn_outputs[0] # output_attn: a, present, (attentions) + cross_attn_outputs = cross_attn_outputs[1:] + + hidden_states = self.alpha * cross_attn_output + (1 - self.alpha) * self_attention_residual + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions) + + +class GPTJRPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPTJRConfig + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPTJRBlock"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear,)): + # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPTJRModel): + module.gradient_checkpointing = value + + +class GPTJRModel(GPTJRPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.n_embd + self.vocab_size = config.vocab_size + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPTJRBlock(config) for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + for index in range(len(self.h)): + 
self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
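# A standalone illustration (not patch code) of the additive mask described in
# the comments above; dtype and sizes are arbitrary.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])          # (batch, seq_len): 1 = attend, 0 = masked
mask = attention_mask[:, None, None, :].float()        # (batch, 1, 1, seq_len), broadcasts over heads and query positions
mask = (1.0 - mask) * torch.finfo(torch.float32).min   # 0.0 where attended, most-negative value where masked
# adding `mask` to the raw attention scores before softmax effectively removes masked positions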
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + None, + attention_mask, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + encoder_hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class GPTJRForCausalLM(GPTJRPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPTJRModel(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size) + if config.encoder_path is not None: + self.encoder = AutoModel.from_pretrained(config.encoder_path) + # freeze encoder and don't get gradiets + self.encoder.requires_grad_(False) + + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + def deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = 
attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.FloatTensor] = None, + decoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + transformer_outputs = self.transformer( + input_ids, + encoder_hidden_states=encoder_outputs[0], + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + # make sure sampling in fp16 works correctly and + # compute loss in fp32 to match with mesh-tf version + # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179 + # TODO: do we need to do conversion to fp32 if training in bf16? 
+ lm_logits = self.lm_head(hidden_states).to(torch.float32) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or + [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + diff --git a/test_gpt_jr.py b/test_gpt_jr.py new file mode 100644 index 00000000..8f0dde8d --- /dev/null +++ b/test_gpt_jr.py @@ -0,0 +1,41 @@ +import torch +from gpt4all.models import GPTJRForCausalLM, GPTJRConfig +from transformers import AutoTokenizer, AutoModel + +print("loading model") +config = GPTJRConfig(encoder_ndim=384) +model = GPTJRForCausalLM(config) +print("loaded model") + + +tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b") + +encoder_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2') +encoder = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2') + + +def mean_pooling(model_output, attention_mask): + token_embeddings = model_output[0] #First element of model_output contains all token embeddings + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + +text = "The quick brown fox jumps over the lazy dog." +print("Encoded knn") +tokenized = encoder_tokenizer(text, return_tensors="pt") + +encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"]) + +# make 2 neighbors +# (bs, knn, encoding_dim) +encoder_outputs = torch.stack([encodings, encodings]).unsqueeze(0) + +inputs = "What did the fox do?" + +print("Encoded inputs") +tokenized_input = tokenizer(inputs, padding="max_length", truncation="true", return_tensors="pt") + +print("Running model") +outputs = model(**tokenized_input, encoder_outputs=encoder_outputs) + +print(outputs.shape) + From df79fd64b017987f23ebebf95cde7d0273698ce9 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Fri, 21 Apr 2023 02:50:27 +0000 Subject: [PATCH 05/36] fix: forward works! 
--- gpt4all/models/configuration_gpt_jr.py | 6 +- gpt4all/models/modeling_gpt_jr.py | 194 +++++++++++++++++++++---- 2 files changed, 166 insertions(+), 34 deletions(-) diff --git a/gpt4all/models/configuration_gpt_jr.py b/gpt4all/models/configuration_gpt_jr.py index e98ff2a8..fdb9eec4 100644 --- a/gpt4all/models/configuration_gpt_jr.py +++ b/gpt4all/models/configuration_gpt_jr.py @@ -115,8 +115,7 @@ class GPTJRConfig(PretrainedConfig): bos_token_id=50256, eos_token_id=50256, tie_word_embeddings=False, - encoder_ndim=4096, - alpha=.5, + encoder_dim=4096, encoder_path=None, **kwargs ): @@ -139,8 +138,7 @@ class GPTJRConfig(PretrainedConfig): self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id - self.encoder_ndim = encoder_ndim - self.alpha = alpha + self.encoder_dim = encoder_dim self.encoder_path = encoder_path super().__init__( diff --git a/gpt4all/models/modeling_gpt_jr.py b/gpt4all/models/modeling_gpt_jr.py index e86de220..d2e57857 100644 --- a/gpt4all/models/modeling_gpt_jr.py +++ b/gpt4all/models/modeling_gpt_jr.py @@ -154,11 +154,6 @@ class GPTJRAttention(nn.Module): query_length, key_length = query.size(-2), key.size(-2) causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) - # Keep the attention weights computation in fp32 to avoid overflow issues - # TODO: do we need to do this with bfloat16?? - # query = query.to(torch.float32) - # key = key.to(torch.float32) - attn_weights = torch.matmul(query, key.transpose(-1, -2)) mask_value = torch.finfo(attn_weights.dtype).min @@ -188,7 +183,7 @@ class GPTJRAttention(nn.Module): def forward( self, hidden_states: Optional[torch.FloatTensor], - encoder_hidden_states: Optional[torch.FloatTensor], + encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, layer_past: Optional[Tuple[torch.Tensor]] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -291,13 +286,131 @@ class GPTJRCrossAttention(GPTJRAttention): ) self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()) - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) - self.v_proj = nn.Linear(config.encoder_ndim, self.embed_dim, bias=False) - self.q_proj = nn.Linear(config.encoder_ndim, self.embed_dim, bias=False) + self.k_proj = nn.Linear(config.encoder_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(config.encoder_dim, self.embed_dim, bias=False) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) - self.rotary_dim = None - if config.rotary_dim is not None: - self.rotary_dim = config.rotary_dim + + + def _split_knn_attn_heads(self, tensor, num_attention_heads, attn_head_size): + new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) + tensor = tensor.view(new_shape) + + return tensor.permute(0, 2, 1) + + + def _merge_heads(self, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + # tensor -> (bs, seq_len, num_attention_heads, head_dim) + new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) + return tensor.view(new_shape) + + def _attn( + self, + query, + key, + value, + attention_mask=None, + head_mask=None, + ): + + # compute causal mask from causal mask buffer + # since key and value don't have seq length, just use causal mask as normal + query_length = query.size(-2) + causal_mask = self.bias[:, :, 
: query_length, :query_length].to(torch.bool) + + # Keep the attention weights computation in fp32 to avoid overflow issues + # TODO: do we need to do this with bfloat16?? + # query = query.to(torch.float32) + # key = key.to(torch.float32) + + # query -> (bs, seq_len, num_attention_heads, head_dim) + # key -> (bs, num_attention_heads, head_dim) + # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads) + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + attn_weights = attn_weights / self.scale_attn + + if attention_mask is not None: + # Apply the attention mask + # attn mask (1, 1, 1, seq_len) + attn_weights = (attn_weights.permute(0, 2, 3, 1) + attention_mask).permute(0, 3, 1, 2) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = attn_weights.to(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + # value -> (bs, num_attention_heads, head_dim) + # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads) + # attn_output -> (bs, num_attention_heads, seq_len, head_dim) + attn_output = torch.matmul(attn_weights, value.transpose(-1, -2)) + + return attn_output, attn_weights + + + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + encoder_hidden_states: Optional[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Tuple[torch.Tensor]], + Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], + ]: + + query = self.q_proj(hidden_states) + # if we are doing cross attention + key = self.k_proj(encoder_hidden_states) + value = self.v_proj(encoder_hidden_states) + # (bs, seq_len, dim) -> (bs, num_attention_heads, seq_len, head_dim) + query = self._split_heads(query, self.num_attention_heads, self.head_dim, False) + # (bs, dim) -> (bs, num_attention_heads, head_dim) + key = self._split_knn_attn_heads(key, self.num_attention_heads, self.head_dim) + value = self._split_knn_attn_heads(value, self.num_attention_heads, self.head_dim) + + + key = key.permute(0, 2, 1) + query = query.permute(0, 2, 1, 3) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + # compute self-attention: V x Softmax(QK^T) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) @@ -330,9 +443,13 @@ class 
GPTJRBlock(nn.Module): # TODO: fix for n neighbors # for SBERT this is 384 - self.ln_2 = nn.LayerNorm(config.encoder_ndim, eps=config.layer_norm_epsilon) + self.ln_2 = nn.LayerNorm(config.encoder_dim, eps=config.layer_norm_epsilon) self.cross_attn = GPTJRCrossAttention(config) - self.alpha = config.alpha + self.cross_attn_mlp = GPTJRMLP(inner_dim, config) + + self.alpha = nn.Parameter(torch.ones(1), requires_grad=False).to(self.ln_1.weight.dtype) + self.step = 1 + self.world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else torch.cuda.device_count() or 1 def forward( self, @@ -361,31 +478,48 @@ class GPTJRBlock(nn.Module): feed_forward_hidden_states = self.mlp(hidden_states) self_attention_residual = attn_output + feed_forward_hidden_states + residual - # encoder_hidden_states (bs, knn, encoder_dim) + # encoder_hidden_states -> (bs, knn, encoder_dim) encoder_normed = self.ln_2(encoder_hidden_states) - # TODO: how do we handle neighbors - # TODO: we have to make sure we're doing masking right here - # TODO: T5 passes query length to cross attention, do we need that? - cross_attn_outputs = self.cross_attn( - residual, - encoder_hidden_states=encoder_normed, - attention_mask=attention_mask, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, + num_neighbors = encoder_normed.shape[1] + cross_attn_outputs = [] + for k in range(num_neighbors): + # cross_attn_outputs -> (bs, seq_len, num_attention_heads, head_dim) + cross_attn_output = self.cross_attn( + residual, + encoder_hidden_states=encoder_normed[:, k, :], + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + cross_attn_outputs.append(cross_attn_output[0]) + + cross_attn_output = torch.stack(cross_attn_outputs, dim=1).mean(dim=1) + # gpt-j has parallel ff + attn, can do ff on encoder_normed too I guess? 
+ cross_attn_ff = self.cross_attn_mlp( + cross_attn_output ) - cross_attn_output = cross_attn_outputs[0] # output_attn: a, present, (attentions) - cross_attn_outputs = cross_attn_outputs[1:] - - hidden_states = self.alpha * cross_attn_output + (1 - self.alpha) * self_attention_residual + + alpha = self.alpha if self.training else 0.5 + + hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual if use_cache: outputs = (hidden_states,) + outputs else: outputs = (hidden_states,) + outputs[1:] + # if training update alpha + if self.training: + self.step += 1 + self._update_alpha(self.step) + + return outputs # hidden_states, present, (attentions) + + def _update_alpha(self, iteration): + self.alpha.data = torch.clamp(torch.tensor([1 / (iteration * self.world_size) ** 0.05]), min=torch.tensor([0.5]), max=torch.tensor([1.0])) class GPTJRPreTrainedModel(PreTrainedModel): @@ -769,7 +903,7 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel): transformer_outputs = self.transformer( input_ids, - encoder_hidden_states=encoder_outputs[0], + encoder_hidden_states=encoder_outputs, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, From aa814757fcfe54bbd8fe544622f940ba7f21fbf7 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Fri, 21 Apr 2023 04:18:16 +0000 Subject: [PATCH 06/36] fix: testing works --- test_gpt_jr.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/test_gpt_jr.py b/test_gpt_jr.py index 8f0dde8d..e8c406dc 100644 --- a/test_gpt_jr.py +++ b/test_gpt_jr.py @@ -2,13 +2,16 @@ import torch from gpt4all.models import GPTJRForCausalLM, GPTJRConfig from transformers import AutoTokenizer, AutoModel +config = GPTJRConfig(encoder_dim=384, n_layer=4) +print("loaded config") + print("loading model") -config = GPTJRConfig(encoder_ndim=384) model = GPTJRForCausalLM(config) print("loaded model") tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b") +tokenizer.pad_token = tokenizer.eos_token encoder_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2') encoder = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2') @@ -23,19 +26,21 @@ text = "The quick brown fox jumps over the lazy dog." print("Encoded knn") tokenized = encoder_tokenizer(text, return_tensors="pt") +# bs, seq_len, dim encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"]) # make 2 neighbors # (bs, knn, encoding_dim) -encoder_outputs = torch.stack([encodings, encodings]).unsqueeze(0) +encoder_outputs = torch.stack([encodings, encodings]).squeeze().unsqueeze(0) inputs = "What did the fox do?" 
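# Worked sketch (illustration only) of the alpha schedule patch 05 adds to
# GPTJRBlock: alpha decays from ~1.0 toward the 0.5 floor, so the cross-attention
# branch, weighted by (1 - alpha), is blended in gradually. world_size = 8 is an
# assumed GPU count, purely for illustration.
world_size = 8
for step in (1, 10, 100, 10_000, 1_000_000):
    alpha = min(max(1 / (step * world_size) ** 0.05, 0.5), 1.0)
    print(step, round(alpha, 3))  # 1 -> 0.901, 100 -> 0.716, 1_000_000 -> 0.5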
print("Encoded inputs") -tokenized_input = tokenizer(inputs, padding="max_length", truncation="true", return_tensors="pt") +tokenized_input = tokenizer([inputs], padding="max_length", truncation=True, return_tensors="pt") print("Running model") outputs = model(**tokenized_input, encoder_outputs=encoder_outputs) -print(outputs.shape) +print(outputs) +print(outputs[0].shape) From e62baf87f8eb5b10a891816aa6e36d1ed5f6ed5b Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Fri, 21 Apr 2023 04:19:37 +0000 Subject: [PATCH 07/36] fix: seed --- test_gpt_jr.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test_gpt_jr.py b/test_gpt_jr.py index e8c406dc..77a2c2fd 100644 --- a/test_gpt_jr.py +++ b/test_gpt_jr.py @@ -2,6 +2,10 @@ import torch from gpt4all.models import GPTJRForCausalLM, GPTJRConfig from transformers import AutoTokenizer, AutoModel +# seed torch + +torch.manual_seed(0) + config = GPTJRConfig(encoder_dim=384, n_layer=4) print("loaded config") From ca66d12d89042a6ca0a91caccd82f365eb522306 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Fri, 21 Apr 2023 14:23:33 +0000 Subject: [PATCH 08/36] fix: remove causal cross attn mask --- gpt4all/models/modeling_gpt_jr.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/gpt4all/models/modeling_gpt_jr.py b/gpt4all/models/modeling_gpt_jr.py index d2e57857..3a0cf584 100644 --- a/gpt4all/models/modeling_gpt_jr.py +++ b/gpt4all/models/modeling_gpt_jr.py @@ -316,34 +316,11 @@ class GPTJRCrossAttention(GPTJRAttention): head_mask=None, ): - # compute causal mask from causal mask buffer - # since key and value don't have seq length, just use causal mask as normal - query_length = query.size(-2) - causal_mask = self.bias[:, :, : query_length, :query_length].to(torch.bool) - - # Keep the attention weights computation in fp32 to avoid overflow issues - # TODO: do we need to do this with bfloat16?? - # query = query.to(torch.float32) - # key = key.to(torch.float32) - # query -> (bs, seq_len, num_attention_heads, head_dim) # key -> (bs, num_attention_heads, head_dim) # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads) attn_weights = torch.matmul(query, key.transpose(-1, -2)) - mask_value = torch.finfo(attn_weights.dtype).min - # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
- # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) - attn_weights = torch.where(causal_mask, attn_weights, mask_value) - - attn_weights = attn_weights / self.scale_attn - - if attention_mask is not None: - # Apply the attention mask - # attn mask (1, 1, 1, seq_len) - attn_weights = (attn_weights.permute(0, 2, 3, 1) + attention_mask).permute(0, 3, 1, 2) - attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = attn_weights.to(value.dtype) attn_weights = self.attn_dropout(attn_weights) @@ -441,8 +418,6 @@ class GPTJRBlock(nn.Module): self.attn = GPTJRAttention(config) self.mlp = GPTJRMLP(inner_dim, config) - # TODO: fix for n neighbors - # for SBERT this is 384 self.ln_2 = nn.LayerNorm(config.encoder_dim, eps=config.layer_norm_epsilon) self.cross_attn = GPTJRCrossAttention(config) self.cross_attn_mlp = GPTJRMLP(inner_dim, config) From e255e0a805da5455f12a27fd2bafed24713d04aa Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Fri, 21 Apr 2023 21:54:47 +0000 Subject: [PATCH 09/36] fix: batched xattn --- gpt4all/models/modeling_gpt_jr.py | 48 ++++++++++++++----------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/gpt4all/models/modeling_gpt_jr.py b/gpt4all/models/modeling_gpt_jr.py index 3a0cf584..a38ad5b0 100644 --- a/gpt4all/models/modeling_gpt_jr.py +++ b/gpt4all/models/modeling_gpt_jr.py @@ -296,7 +296,7 @@ class GPTJRCrossAttention(GPTJRAttention): new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1) + return tensor.permute(0, 1, 3, 2) def _merge_heads(self, tensor, num_attention_heads, attn_head_size): @@ -304,6 +304,7 @@ class GPTJRCrossAttention(GPTJRAttention): Merges attn_head_size dim and num_attn_heads dim into hidden dim """ # tensor -> (bs, seq_len, num_attention_heads, head_dim) + tensor = tensor.permute(0, 2, 1, 3).contiguous() new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) return tensor.view(new_shape) @@ -316,10 +317,10 @@ class GPTJRCrossAttention(GPTJRAttention): head_mask=None, ): - # query -> (bs, seq_len, num_attention_heads, head_dim) - # key -> (bs, num_attention_heads, head_dim) - # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads) - attn_weights = torch.matmul(query, key.transpose(-1, -2)) + # query -> (bs, num_attention_heads, seq_len, head_dim) + # key -> (bs, num_attention_heads, head_dim, neighbors) + # attn_weights -> (bs, num_attention_heads, seq_len, neighbors) + attn_weights = torch.matmul(query, key) attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = attn_weights.to(value.dtype) @@ -329,10 +330,10 @@ class GPTJRCrossAttention(GPTJRAttention): if head_mask is not None: attn_weights = attn_weights * head_mask - # value -> (bs, num_attention_heads, head_dim) - # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads) + # value -> (bs, num_attention_heads, seq_len, head_dim) + # attn_weights -> (bs, num_attention_heads, seq_len, neighbors) # attn_output -> (bs, num_attention_heads, seq_len, head_dim) - attn_output = torch.matmul(attn_weights, value.transpose(-1, -2)) + attn_output = torch.matmul(attn_weights, value) return attn_output, attn_weights @@ -362,8 +363,8 @@ class GPTJRCrossAttention(GPTJRAttention): value = self._split_knn_attn_heads(value, self.num_attention_heads, self.head_dim) - key = key.permute(0, 2, 1) 
- query = query.permute(0, 2, 1, 3) + value = value.permute(0, 3, 1, 2) + key = key.permute(0, 3, 2, 1) if layer_past is not None: past_key = layer_past[0] @@ -454,30 +455,25 @@ class GPTJRBlock(nn.Module): self_attention_residual = attn_output + feed_forward_hidden_states + residual # encoder_hidden_states -> (bs, knn, encoder_dim) + # may not need, can norm encodings encoder_normed = self.ln_2(encoder_hidden_states) - num_neighbors = encoder_normed.shape[1] - cross_attn_outputs = [] - for k in range(num_neighbors): - # cross_attn_outputs -> (bs, seq_len, num_attention_heads, head_dim) - cross_attn_output = self.cross_attn( - residual, - encoder_hidden_states=encoder_normed[:, k, :], - attention_mask=attention_mask, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - cross_attn_outputs.append(cross_attn_output[0]) + # cross_attn_outputs -> (bs, seq_len, dim) + cross_attn_output = self.cross_attn( + hidden_states, + encoder_hidden_states=encoder_normed, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) - cross_attn_output = torch.stack(cross_attn_outputs, dim=1).mean(dim=1) # gpt-j has parallel ff + attn, can do ff on encoder_normed too I guess? cross_attn_ff = self.cross_attn_mlp( - cross_attn_output + cross_attn_output[0] ) alpha = self.alpha if self.training else 0.5 - hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual if use_cache: From 4671f4e82f3eb0905eb0be2a25e808463da9ee1d Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Sat, 22 Apr 2023 19:37:30 +0000 Subject: [PATCH 10/36] chore: pull out common dist print fn --- gpt4all/inference/inference.py | 102 +++++++++++++++++++---------- gpt4all/utils/distributed_utils.py | 6 ++ 2 files changed, 73 insertions(+), 35 deletions(-) create mode 100644 gpt4all/utils/distributed_utils.py diff --git a/gpt4all/inference/inference.py b/gpt4all/inference/inference.py index 5e351c46..2096819c 100644 --- a/gpt4all/inference/inference.py +++ b/gpt4all/inference/inference.py @@ -3,12 +3,13 @@ import torch import torch.nn as nn from argparse import ArgumentParser from gpt4all.utils.read import read_config -from accelerate.utils import set_seed +from accelerate.utils import set_seed from gpt4all.utils.data import load_data_for_inference +from gpt4all.utils.distributed_utils import rank0_print from tqdm import tqdm -from datasets import Dataset +from datasets import Dataset import torch.distributed as dist -from transformers.trainer_pt_utils import nested_numpify +from transformers.trainer_pt_utils import nested_numpify from transformers import DefaultDataCollator from torch.utils.data import DataLoader, DistributedSampler import numpy as np @@ -21,56 +22,64 @@ def calc_cross_entropy_no_reduction(lm_logits, labels): shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = nn.CrossEntropyLoss(reduction='none') + loss_fct = nn.CrossEntropyLoss(reduction="none") loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels).mean(dim=1) return loss -def rank0_print(msg): - if dist.get_rank() == 0: - print(msg) - - def inference(config): - set_seed(config['seed']) + set_seed(config["seed"]) rank0_print(f"World size: {dist.get_world_size()}") - tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length']) + tokenizer = AutoTokenizer.from_pretrained( + config["tokenizer_name"], 
model_max_length=config["max_length"] + ) # llama has no pad token, set it to new token if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - - train_dataset, val_dataset = load_data_for_inference(config, tokenizer) + train_dataset, val_dataset = load_data_for_inference(config, tokenizer) num_processes = dist.get_world_size() local_rank = dist.get_rank() - train_sampler = DistributedSampler(train_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank) + train_sampler = DistributedSampler( + train_dataset, + shuffle=False, + drop_last=True, + num_replicas=num_processes, + rank=local_rank, + ) train_dataloader = DataLoader( train_dataset, collate_fn=DefaultDataCollator(), batch_size=config["batch_size"], sampler=train_sampler, - drop_last=True + drop_last=True, ) - val_sampler = DistributedSampler(val_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank) + val_sampler = DistributedSampler( + val_dataset, + shuffle=False, + drop_last=True, + num_replicas=num_processes, + rank=local_rank, + ) val_dataloader = DataLoader( val_dataset, collate_fn=DefaultDataCollator(), batch_size=config["batch_size"], sampler=val_sampler, - drop_last=True + drop_last=True, ) - - model = AutoModelForCausalLM.from_pretrained(config["model_name"], - trust_remote_code=True, - torch_dtype=torch.bfloat16, - ) + model = AutoModelForCausalLM.from_pretrained( + config["model_name"], + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) model.to(f"cuda:{local_rank}") with torch.no_grad(): @@ -78,14 +87,18 @@ def inference(config): for batch in tqdm(train_dataloader, disable=local_rank != 0): batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}") batch["labels"] = batch["labels"].to(f"cuda:{local_rank}") - outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True) + outputs = model( + input_ids=batch["input_ids"], + labels=batch["labels"], + output_hidden_states=True, + ) loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"]) train_outputs["loss"].extend(loss) embeddings = outputs.hidden_states[-1] batch_size = batch["input_ids"].shape[0] sequence_lengths = [] - # since we use mutiturn with multiple <|endoftext|>, we need to find the place where + # since we use mutiturn with multiple <|endoftext|>, we need to find the place where # <|endoftext|> is repeated for item in batch["input_ids"]: indices = torch.where(item == tokenizer.pad_token_id)[0] @@ -101,7 +114,9 @@ def inference(config): sequence_lengths.append(len(item) - 1) sequence_lengths = torch.tensor(sequence_lengths) - pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths] + pooled_logits = embeddings[ + torch.arange(batch_size, device=embeddings.device), sequence_lengths + ] train_outputs["embeddings"].append(pooled_logits) train_outputs["index"].extend(batch["index"].to(model.device)) @@ -120,29 +135,40 @@ def inference(config): # compute mask in pyarrow since it's super fast # ty @bmschmidt for showing me this! 
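A toy, self-contained version of the pyarrow filtering trick referenced in the comment above (illustrative data only):

```
import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({"index": pa.array([0, 1, 2, 3], pa.int32()),
                  "text": ["a", "b", "c", "d"]})
mask = pc.is_in(table["index"], value_set=pa.array([1, 3], pa.int32()))
print(table.filter(mask).to_pydict())  # {'index': [1, 3], 'text': ['b', 'd']}
```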
table = train_dataset.data - mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32())) + mask = pc.is_in(table["index"], value_set=pa.array(curr_idx, pa.int32())) filtered_table = table.filter(mask) # convert from pyarrow to Dataset filtered_train = Dataset.from_dict(filtered_table.to_pydict()) filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"]) filtered_train = filtered_train.add_column("loss", df_train["loss"]) - filtered_train = filtered_train.add_column("is_train", [True] * len(filtered_train)) + filtered_train = filtered_train.add_column( + "is_train", [True] * len(filtered_train) + ) - filtered_train.to_json(f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64) + filtered_train.to_json( + f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl", + lines=True, + orient="records", + num_proc=64, + ) val_outputs = {"loss": [], "embeddings": [], "index": []} for batch in tqdm(val_dataloader, disable=local_rank != 0): batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}") batch["labels"] = batch["labels"].to(f"cuda:{local_rank}") - outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True) + outputs = model( + input_ids=batch["input_ids"], + labels=batch["labels"], + output_hidden_states=True, + ) loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"]) val_outputs["loss"].extend(loss) embeddings = outputs.hidden_states[-1] batch_size = batch["input_ids"].shape[0] sequence_lengths = [] - # since we use mutiturn with multiple <|endoftext|>, we need to find the place where + # since we use mutiturn with multiple <|endoftext|>, we need to find the place where # <|endoftext|> is repeated for item in batch["input_ids"]: indices = torch.where(item == tokenizer.pad_token_id)[0] @@ -158,7 +184,9 @@ def inference(config): sequence_lengths.append(len(item) - 1) sequence_lengths = torch.tensor(sequence_lengths) - pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths] + pooled_logits = embeddings[ + torch.arange(batch_size, device=embeddings.device), sequence_lengths + ] val_outputs["embeddings"].append(pooled_logits) val_outputs["index"].extend(batch["index"].to(model.device)) @@ -176,7 +204,7 @@ def inference(config): # compute mask in pyarrow since it's super fast # ty @bmschmidt for showing me this! 
table = val_dataset.data - mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32())) + mask = pc.is_in(table["index"], value_set=pa.array(curr_idx, pa.int32())) filtered_table = table.filter(mask) # convert from pyarrow to Dataset filtered_val = Dataset.from_dict(filtered_table.to_pydict()) @@ -184,8 +212,13 @@ def inference(config): filtered_val = filtered_val.add_column("loss", df_val["loss"]) filtered_val = filtered_val.add_column("is_train", [False] * len(filtered_val)) - filtered_val.to_json(f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64) - + filtered_val.to_json( + f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl", + lines=True, + orient="records", + num_proc=64, + ) + def main(): dist.init_process_group("nccl") @@ -201,4 +234,3 @@ def main(): if __name__ == "__main__": # parse arguments by reading in a config main() - diff --git a/gpt4all/utils/distributed_utils.py b/gpt4all/utils/distributed_utils.py new file mode 100644 index 00000000..839a7a92 --- /dev/null +++ b/gpt4all/utils/distributed_utils.py @@ -0,0 +1,6 @@ +import torch.distributed as dist + + +def rank0_print(msg): + if dist.get_rank() == 0: + print(msg) From 2dae153c68998da6c7e5c30dab89e982d9f1c3f7 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Sat, 22 Apr 2023 19:38:34 +0000 Subject: [PATCH 11/36] feat: sbert abstractor --- gpt4all/index/embed.py | 68 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 gpt4all/index/embed.py diff --git a/gpt4all/index/embed.py b/gpt4all/index/embed.py new file mode 100644 index 00000000..6ac75a3b --- /dev/null +++ b/gpt4all/index/embed.py @@ -0,0 +1,68 @@ +import torch +from transformers import AutoTokenizer, AutoModel +import torch.nn.functional as F + + +class Embedder: + def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"): + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.embedder = AutoModel.from_pretrained(model_name) + # hack + self.offset = self.tokenizer.model_max_length // 2 + + def _mean_pool(self, model_output, attention_mask): + token_embeddings = model_output[ + 0 + ] # First element of model_output contains all token embeddings + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + sentence_embeddings = torch.sum( + token_embeddings * input_mask_expanded, 1 + ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + return F.normalize(sentence_embeddings, p=2, dim=1) + + def chunk_text(self, text): + tokenized_text = {"input_ids": [], "attention_mask": []} + tokenized = self.tokenizer(text) + tokenized_len = len(tokenized["input_ids"]) + max_len = self.tokenizer.model_max_length + if tokenized_len > max_len: + start = 0 + while start < tokenized_len: + tokenized_text["input_ids"].append( + tokenized["input_ids"][start : start + max_len] + ) + tokenized_text["attention_mask"].append( + tokenized["attention_mask"][start : start + max_len] + ) + # this could probably be done better + start += self.offset + + else: + tokenized_text["input_ids"].append(tokenized["input_ids"]) + tokenized_text["attention_mask"].append(tokenized["attention_mask"]) + + return tokenized_text + + def __call__(self, batch): + if isinstance(batch, dict): + outputs = self.embedder( + input_ids=batch["input_ids"], attention_mask=batch["attention_mask"] + ) + embedding = self._mean_pool(outputs, batch["attention_mask"]) + + return {"id": batch["id"], "embedding": embedding} + + elif 
isinstance(batch, str): + tokenized = self.tokenizer(batch, return_tensors="pt", truncation=True) + return self._mean_pool( + self.embedder( + input_ids=tokenized["input_ids"], + attention_mask=tokenized["attention_mask"], + ), + tokenized["attention_mask"], + ) + + def to(self, device): + self.embedder.to(device) From 4fb19d67b5ddd888f47900ab73aed76733e7b8c3 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Sat, 22 Apr 2023 19:38:51 +0000 Subject: [PATCH 12/36] feat: tokenize texts into chunks --- gpt4all/index/tokenize_texts.py | 62 +++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 gpt4all/index/tokenize_texts.py diff --git a/gpt4all/index/tokenize_texts.py b/gpt4all/index/tokenize_texts.py new file mode 100644 index 00000000..2c3e83e2 --- /dev/null +++ b/gpt4all/index/tokenize_texts.py @@ -0,0 +1,62 @@ +from datasets import load_dataset +from argparse import ArgumentParser +from gpt4all.index.embed import Embedder + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument("--tokenized_save_path", type=str, default="tokenized") + + return parser.parse_args() + + +def tokenize_texts(examples, embedder): + split_data = {k: [] for k in examples.keys()} + split_data["tokenized_chunk"] = [] + split_data["tokenized_attn_mask"] = [] + + keys = [k for k in examples.keys() if k != "text"] + for i, text in enumerate(examples["text"]): + tokenized_text = embedder.chunk_text(text) + # do we want to add sep/cls tokens to beginning and end? + decoded_text = embedder.tokenizer.batch_decode( + sequences=tokenized_text["input_ids"] + ) + + num_splits = len(tokenized_text["input_ids"]) + split_data["id"].extend( + [f"{examples['id'][i]}_split_{j}" for j in range(num_splits)] + ) + + for col in keys: + if col != "id": + split_data[col].extend( + [examples[col][i]] * len(tokenized_text["input_ids"]) + ) + + split_data["text"].extend(decoded_text) + split_data["tokenized_chunk"].extend(tokenized_text["input_ids"]) + split_data["tokenized_attn_mask"].extend(tokenized_text["attention_mask"]) + + return split_data + + +def chunk_dataset( + ds_name="wikipedia", + version="20220301.simple", + sbert_model="sentence-transformers/all-MiniLM-L6-v2", + save_path="tokenized", +): + dataset = load_dataset(ds_name, version, split="train") + print(len(dataset)) + embedder = Embedder(sbert_model) + dataset = dataset.map( + lambda x: tokenize_texts(x, embedder), batched=True, num_proc=64 + ) + + dataset.save_to_disk(save_path) + + +if __name__ == "__main__": + args = parse_args() + chunked_dataset = chunk_dataset(save_path=args.tokenized_save_path) From a2b1f9983829ea4e05c4ef8f15350f725a5ccb4b Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Sat, 22 Apr 2023 19:39:30 +0000 Subject: [PATCH 13/36] feat: distributed eval of embedder --- gpt4all/index/embed_texts.py | 96 ++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 gpt4all/index/embed_texts.py diff --git a/gpt4all/index/embed_texts.py b/gpt4all/index/embed_texts.py new file mode 100644 index 00000000..f130217d --- /dev/null +++ b/gpt4all/index/embed_texts.py @@ -0,0 +1,96 @@ +import torch.distributed as dist +from argparse import ArgumentParser +from datasets import Dataset +from gpt4all.index.embed import Embedder +from gpt4all.utils.distributed_utils import rank0_print +from torch.utils.data import DataLoader, DistributedSampler +from transformers.trainer_pt_utils import nested_numpify +from transformers import BatchEncoding +from tqdm import tqdm +import numpy as np +import torch 
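The script below drives the `Embedder` added two patches up; for reference, a minimal single-string call looks roughly like this (a sketch; the model weights are downloaded on first use):

```
from gpt4all.index.embed import Embedder

embedder = Embedder()                   # defaults to sentence-transformers/all-MiniLM-L6-v2
vec = embedder("What did the fox do?")  # str input -> mean-pooled, L2-normalised embedding
print(vec.shape)                        # torch.Size([1, 384])
```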
+ + +class PadCollateInputs: + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def __call__(self, batch): + mapped_inputs = {"input_ids": [], "attention_mask": []} + mapped_inputs["input_ids"] = [b["tokenized_chunk"] for b in batch] + mapped_inputs["attention_mask"] = [b["tokenized_attn_mask"] for b in batch] + encoding = BatchEncoding(mapped_inputs) + + padded_inputs = self.tokenizer.pad( + encoding, padding="max_length", return_tensors="pt" + ) + padded_inputs["id"] = [b["id"] for b in batch] + + return padded_inputs + + +def embed_texts(ds_path, batch_size): + rank0_print(f"World size: {dist.get_world_size()}") + + dataset = Dataset.load_from_disk(ds_path) + rank0_print(f"Dataset size: {len(dataset)}") + dataset = dataset.remove_columns(["url", "title", "text"]) + dataset = dataset.with_format("torch") + + num_processes = dist.get_world_size() + local_rank = dist.get_rank() + + model = Embedder() + + collator = PadCollateInputs(model.tokenizer) + + sampler = DistributedSampler( + dataset, + shuffle=False, + drop_last=False, + num_replicas=num_processes, + rank=local_rank, + ) + dataloader = DataLoader( + dataset, + collate_fn=collator, + batch_size=batch_size, + sampler=sampler, + drop_last=False, + ) + + model.to(f"cuda:{local_rank}") + with torch.no_grad(): + embedded_outputs = {"id": [], "embedding": []} + for batch in tqdm(dataloader, disable=local_rank != 0): + batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}") + batch["attention_mask"] = batch["attention_mask"].to(f"cuda:{local_rank}") + outputs = model(batch) + embedded_outputs["id"].extend(batch["id"]) + embedded_outputs["embedding"].extend(outputs["embedding"]) + + embedded_outputs["embedding"] = nested_numpify(embedded_outputs["embedding"]) + embedded_outputs["id"] = np.stack(embedded_outputs["id"]) + embedded_outputs["embedding"] = np.stack(embedded_outputs["embedding"]) + + ds = Dataset.from_dict(embedded_outputs) + + # feeling lazy, don't want to wait for all_gather to finish + # will load and concat in a single process in another script + ds.save_to_disk(f"embedded/{ds_path}_embedded_rank_{local_rank}") + + +def main(): + dist.init_process_group("nccl") + parser = ArgumentParser() + parser.add_argument("--ds_path", type=str, default="tokenized") + parser.add_argument("--batch_size", type=int, default=1) + + args = parser.parse_args() + + embed_texts(args.ds_path, args.batch_size) + + +if __name__ == "__main__": + # parse arguments by reading in a config + main() From 4eeab60306f67b43d39cd8c83755fb0d822cfc5d Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Sat, 22 Apr 2023 19:39:46 +0000 Subject: [PATCH 14/36] feat: build knn index --- gpt4all/index/build_index.py | 90 ++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 gpt4all/index/build_index.py diff --git a/gpt4all/index/build_index.py b/gpt4all/index/build_index.py new file mode 100644 index 00000000..7e61e2ac --- /dev/null +++ b/gpt4all/index/build_index.py @@ -0,0 +1,90 @@ +import os +from datasets import Dataset, concatenate_datasets +import glob +from argparse import ArgumentParser +import hnswlib +import pyarrow as pa +import pyarrow.compute as pc + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument("--ds_path", type=str, required=True) + parser.add_argument("--embed_folder", type=str, required=True) + parser.add_argument("--index_path", type=str, default="wiki-index") + + return parser.parse_args() + + +def concat_embedded(folder): + files = glob.glob(f"{folder}/*") + + 
all_embeddings = [] + for file in files: + all_embeddings.append(Dataset.load_from_disk(file)) + + all_embeddings = concatenate_datasets(all_embeddings) + + return all_embeddings + + +def join(original_ds, embedded_ds): + embedded_ds = embedded_ds.add_column("index", range(len(embedded_ds))) + embed_table = embedded_ds.data.table + + seen = set() + indices = [] + for i, id in enumerate(original_ds["id"]): + if id not in seen: + indices.append(i) + seen.add(id) + + mask = pc.is_in(embed_table["index"], value_set=pa.array(indices, pa.int32())) + filtered_table = embed_table.filter(mask) + + # sort to make sure we're adding in right order + filtered_table = filtered_table.sort_by("id") + + original_table = original_ds.data.table + original_table = original_table.sort_by("id") + + original_table = original_table.append_column( + "embedding", filtered_table["embedding"] + ) + # there's definitely a better way to do this but + # Dataset(original_table) throws `KeyError: 'embedding'` + joined = Dataset.from_dict(original_table.to_pydict()) + + return joined + + +def build_index(orig_path, embed_folder_path, index_path): + if not os.path.exists(orig_path + "_embedded_with_text"): + ds = Dataset.load_from_disk(orig_path) + embed_ds = concat_embedded(embed_folder_path) + print("Concatenated embeddings") + print(f"Length: {len(ds)}") + print(f"Length: {len(embed_ds)}") + ds = join(ds, embed_ds) + ds = ds.add_column("index", range(len(ds))) + print("Saving to disk") + ds.save_to_disk(f"{orig_path}_embedded_with_text") + else: + ds = Dataset.load_from_disk(orig_path + "_embedded_with_text") + + print(f"Length of ds: {len(ds)}") + + print("Building index") + index = hnswlib.Index(space="cosine", dim=384) + # not sure what we should set M and ef_construction to + index.init_index(max_elements=len(ds), M=64, ef_construction=200) + print("Adding items") + index.add_items(ds["embedding"], ds["index"]) + + print("Saving index") + index.save_index(index_path + ".bin") + + +if __name__ == "__main__": + args = parse_args() + build_index(args.ds_path, args.embed_folder, args.index_path) From 8cb3c8de35e608b006802ca817b8326d7211f8d3 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Sat, 22 Apr 2023 19:40:06 +0000 Subject: [PATCH 15/36] chore: reqs, ignore, readme --- .gitignore | 2 ++ gpt4all/index/README.md | 27 +++++++++++++++++++++++++++ gpt4all/index/__init__.py | 0 requirements.txt | 5 ++++- 4 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 gpt4all/index/README.md create mode 100644 gpt4all/index/__init__.py diff --git a/.gitignore b/.gitignore index 5fadbd31..be27e4e7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +# ignore knn index +gpt4all/index/**/ .DS_Store *.pkl ckpts* diff --git a/gpt4all/index/README.md b/gpt4all/index/README.md new file mode 100644 index 00000000..0ba876d9 --- /dev/null +++ b/gpt4all/index/README.md @@ -0,0 +1,27 @@ +# How to Tokenize and Embed + +Split text into chunks +``` +python tokenize_texts.py +``` + +Embbed Texts + +``` +torchrun --master_port=29085 --nproc-per-node 8 embed_texts.py --ds_path=tokenized --batch_size=2048 +``` + + +Combine Embeddings and Build Index +``` +python build_index.py --ds_path=wiki_sample_tokenized --embed_folder=wiki_sample_embedded +``` + +To use the Index + +``` +import hnswlib + +index = hnswlib.Index(space='l2', dim=384) +index.load_index() +``` diff --git a/gpt4all/index/__init__.py b/gpt4all/index/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/requirements.txt b/requirements.txt index 
b38ab36c..1b0400ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,7 @@ sentencepiece jsonlines nomic scikit-learn -matplotlib \ No newline at end of file +matplotlib +apache_beam +mwparserfromhell +hnswlib \ No newline at end of file From 9bc88fb33d2ae7463350b04807d94bdfa276706f Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Sat, 22 Apr 2023 19:53:47 +0000 Subject: [PATCH 16/36] feat: add options to specify different datasets --- gpt4all/index/tokenize_texts.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/gpt4all/index/tokenize_texts.py b/gpt4all/index/tokenize_texts.py index 2c3e83e2..adc35781 100644 --- a/gpt4all/index/tokenize_texts.py +++ b/gpt4all/index/tokenize_texts.py @@ -5,7 +5,12 @@ from gpt4all.index.embed import Embedder def parse_args(): parser = ArgumentParser() + # fmt: off parser.add_argument("--tokenized_save_path", type=str, default="tokenized") + parser.add_argument("--ds_name", type=str, default="wikipedia") + parser.add_argument("--ds_version", type=str, default="20220301.simple") + parser.add_argument("--sbert_model", type=str, default="sentence-transformers/all-MiniLM-L6-v2") + # fmt: on return parser.parse_args() @@ -59,4 +64,9 @@ def chunk_dataset( if __name__ == "__main__": args = parse_args() - chunked_dataset = chunk_dataset(save_path=args.tokenized_save_path) + chunked_dataset = chunk_dataset( + ds_name=args.ds_name, + version=args.ds_version, + sbert_model=args.sbert_model, + save_path=args.tokenized_save_path, + ) From 84acbc82250bc563228bc2360098eb035f93384c Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Sun, 23 Apr 2023 17:34:49 +0000 Subject: [PATCH 17/36] chore: ignore swp files --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index be27e4e7..42f80eae 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +*.swp # ignore knn index gpt4all/index/**/ .DS_Store From 240feae27728d1d526c502d10cd506458feff376 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Sun, 23 Apr 2023 20:02:22 +0000 Subject: [PATCH 18/36] added initial files for dataset prep and ingestion for gpt4all jr --- configs/train/finetune_gptjr.yaml | 41 +++++ gpt4all/index/build_index.py | 14 +- gpt4all/index/embed.py | 20 ++- gpt4all/index/embed_texts.py | 28 ++- gpt4all/index/prep_index_for_train.py | 58 ++++++ gpt4all/index/test_load_index.py | 0 gpt4all/train/train_r.py | 246 ++++++++++++++++++++++++++ gpt4all/utils/data.py | 65 +++++++ 8 files changed, 454 insertions(+), 18 deletions(-) create mode 100644 configs/train/finetune_gptjr.yaml create mode 100644 gpt4all/index/prep_index_for_train.py create mode 100644 gpt4all/index/test_load_index.py create mode 100644 gpt4all/train/train_r.py diff --git a/configs/train/finetune_gptjr.yaml b/configs/train/finetune_gptjr.yaml new file mode 100644 index 00000000..291cf841 --- /dev/null +++ b/configs/train/finetune_gptjr.yaml @@ -0,0 +1,41 @@ +# model/tokenizer +model_name: "nomic-ai/gpt4all-j" +tokenizer_name: "nomic-ai/gpt4all-j" +version: 'v1.2-jazzy' +gradient_checkpointing: true +save_name: # CHANGE + +# dataset +streaming: false +num_proc: 64 +dataset_path: "squad" +max_length: 1024 +batch_size: 32 + +#index +index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-sample-index.bin" +index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki-full-tokenized_embedded_with_text" +index_space: "cosine" +index_dim: 384 +query_embedding_field: 'question' + +# train dynamics +lr: 2.0e-5 +min_lr: 0 +weight_decay: 0.0 +eval_every: 500 +eval_steps: 
105 +save_every: 500 +log_grads_every: 100 +output_dir: # CHANGE +checkpoint: null +lora: false +warmup_steps: 500 +num_epochs: 2 + +# logging +wandb: false +wandb_entity: # CHANGE +wandb_project_name: # CHANGE +seed: 42 + diff --git a/gpt4all/index/build_index.py b/gpt4all/index/build_index.py index 7e61e2ac..566f5abd 100644 --- a/gpt4all/index/build_index.py +++ b/gpt4all/index/build_index.py @@ -5,6 +5,7 @@ from argparse import ArgumentParser import hnswlib import pyarrow as pa import pyarrow.compute as pc +from tqdm import tqdm def parse_args(): @@ -42,6 +43,7 @@ def join(original_ds, embedded_ds): mask = pc.is_in(embed_table["index"], value_set=pa.array(indices, pa.int32())) filtered_table = embed_table.filter(mask) + import pdb; pdb.set_trace() # sort to make sure we're adding in right order filtered_table = filtered_table.sort_by("id") @@ -60,6 +62,8 @@ def join(original_ds, embedded_ds): def build_index(orig_path, embed_folder_path, index_path): if not os.path.exists(orig_path + "_embedded_with_text"): + # TODO: this doesn't work for large datasets! + # just convert to pandas and do this manually ds = Dataset.load_from_disk(orig_path) embed_ds = concat_embedded(embed_folder_path) print("Concatenated embeddings") @@ -79,7 +83,15 @@ def build_index(orig_path, embed_folder_path, index_path): # not sure what we should set M and ef_construction to index.init_index(max_elements=len(ds), M=64, ef_construction=200) print("Adding items") - index.add_items(ds["embedding"], ds["index"]) + chunk_size = 50_000 + num_chunks = len(ds) // chunk_size + progbar = tqdm(total=num_chunks) + start = 0 + while start < len(ds): + chunk = ds[start:start + chunk_size] + index.add_items(chunk["embedding"], chunk["index"], num_threads=64) + progbar.update(1) + start += chunk_size print("Saving index") index.save_index(index_path + ".bin") diff --git a/gpt4all/index/embed.py b/gpt4all/index/embed.py index 6ac75a3b..2cda33b2 100644 --- a/gpt4all/index/embed.py +++ b/gpt4all/index/embed.py @@ -45,16 +45,11 @@ class Embedder: return tokenized_text + def tokenize(self, text): + return self.tokenizer(text, return_tensors="pt", truncation=True, padding="max_length") + def __call__(self, batch): - if isinstance(batch, dict): - outputs = self.embedder( - input_ids=batch["input_ids"], attention_mask=batch["attention_mask"] - ) - embedding = self._mean_pool(outputs, batch["attention_mask"]) - - return {"id": batch["id"], "embedding": embedding} - - elif isinstance(batch, str): + if isinstance(batch, str): tokenized = self.tokenizer(batch, return_tensors="pt", truncation=True) return self._mean_pool( self.embedder( @@ -63,6 +58,13 @@ class Embedder: ), tokenized["attention_mask"], ) + else: + outputs = self.embedder( + input_ids=batch["input_ids"], attention_mask=batch["attention_mask"] + ) + embedding = self._mean_pool(outputs, batch["attention_mask"]) + + return {"id": batch["id"], "embedding": embedding} def to(self, device): self.embedder.to(device) diff --git a/gpt4all/index/embed_texts.py b/gpt4all/index/embed_texts.py index f130217d..a139f5df 100644 --- a/gpt4all/index/embed_texts.py +++ b/gpt4all/index/embed_texts.py @@ -9,6 +9,7 @@ from transformers import BatchEncoding from tqdm import tqdm import numpy as np import torch +from datasets import load_dataset class PadCollateInputs: @@ -29,20 +30,29 @@ class PadCollateInputs: return padded_inputs -def embed_texts(ds_path, batch_size): +def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False): rank0_print(f"World size: {dist.get_world_size()}") - - 
dataset = Dataset.load_from_disk(ds_path) + dataset = load_dataset(f"{ds_path}", split="train") rank0_print(f"Dataset size: {len(dataset)}") - dataset = dataset.remove_columns(["url", "title", "text"]) + + model = Embedder() + + dataset = dataset.map(lambda x: model.tokenize(x[embed_on]), batched=True, num_proc=64) + + columns_to_keep = ["input_ids", "attention_mask"] + #to_remove = [e for e in dataset.column_names if not e in columns_to_keep] + print('cols: ', dataset.column_names) + #dataset = dataset.remove_columns(to_remove) + + #dataset = Dataset.load_from_disk(ds_path) + #dataset = dataset.remove_columns(["url", "title", "text"]) dataset = dataset.with_format("torch") num_processes = dist.get_world_size() local_rank = dist.get_rank() - model = Embedder() - collator = PadCollateInputs(model.tokenizer) + #collator = PadCollateInputs(model.tokenizer) sampler = DistributedSampler( dataset, @@ -53,7 +63,7 @@ def embed_texts(ds_path, batch_size): ) dataloader = DataLoader( dataset, - collate_fn=collator, + # collate_fn=collator, batch_size=batch_size, sampler=sampler, drop_last=False, @@ -77,7 +87,9 @@ def embed_texts(ds_path, batch_size): # feeling lazy, don't want to wait for all_gather to finish # will load and concat in a single process in another script - ds.save_to_disk(f"embedded/{ds_path}_embedded_rank_{local_rank}") + if save_to_disk: + ds.save_to_disk(f"{ds_path}_embedded/{ds_path}_embedded_rank_{local_rank}") + return ds def main(): diff --git a/gpt4all/index/prep_index_for_train.py b/gpt4all/index/prep_index_for_train.py new file mode 100644 index 00000000..dcb90da3 --- /dev/null +++ b/gpt4all/index/prep_index_for_train.py @@ -0,0 +1,58 @@ +import os +import hnswlib +import numpy as np +from datasets import Dataset +import torch.distributed as dist +from datasets import load_dataset +from argparse import ArgumentParser +from gpt4all.utils.read import read_config +from gpt4all.index.embed_texts import embed_texts + +CHUNK_SIZE = 1024 +K = 5 + + +if __name__ == "__main__": + + dist.init_process_group("nccl") + + parser = ArgumentParser() + parser.add_argument("--config", type=str, default="config.yaml") + + args = parser.parse_args() + + config = read_config(args.config) + + #load index + index = hnswlib.Index(space=config['index_space'], dim=config['index_dim']) + index.load_index(config['index_path']) + + #load query dataset + ds_path = config['dataset_path'] + + #load retrieval dataset + retrieval_dataset = Dataset.load_from_disk(config['index_database']) + print(type(retrieval_dataset._data)) + raise + + #vectorize queries + query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded" + if not os.path.exists(query_vector_path): + print('Embedding dataset...') + ds = embed_texts(ds_path, config['batch_size'], embed_on=config['query_embedding_field'], save_to_disk=False) + ds.save_to_disk(query_vector_path) + else: + print('Found cached embedding dataset!') + ds = Dataset.load_from_disk(query_vector_path) + + #search the index for each query + for chunk_start in range(0, len(ds), CHUNK_SIZE): + chunk_end = chunk_start + CHUNK_SIZE + chunk = ds[chunk_start:chunk_end] + query_vectors = np.array(chunk['embedding']) + print(query_vectors.shape) + neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1) + raise + + #get the embeddings for each of the neighbor ids + diff --git a/gpt4all/index/test_load_index.py b/gpt4all/index/test_load_index.py new file mode 100644 index 00000000..e69de29b diff --git a/gpt4all/train/train_r.py b/gpt4all/train/train_r.py new file 
mode 100644 index 00000000..9254d6ef --- /dev/null +++ b/gpt4all/train/train_r.py @@ -0,0 +1,246 @@ +import os +from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM +import torch +from torch.optim import AdamW +from argparse import ArgumentParser +from gpt4all.utils.read import read_config +from accelerate import Accelerator +from accelerate.utils import DummyScheduler, DummyOptim, set_seed +from peft import get_peft_model, LoraConfig, TaskType +from gpt4all.utils.data import load_data, load_retrieval_augmented_data +from torchmetrics import MeanMetric +from tqdm import tqdm +from gpt4all.models import GPTJRForCausalLM +import wandb + +torch.backends.cuda.matmul.allow_tf32 = True + +def format_metrics(metrics, split, prefix=""): + log = f"[{split}]" + prefix + log += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()]) + + return log + + +def evaluate(model, val_dataloader): + model.eval() + val_loss = MeanMetric(nan_strategy="error").to(model.device) + + with torch.no_grad(): + for batch in tqdm(val_dataloader): + loss = model(**batch).loss + + loss_values = accelerator.gather_for_metrics({"loss": loss.detach()}) + + val_loss.update(loss_values["loss"]) + + return val_loss + + +def train(accelerator, config): + set_seed(config['seed']) + + accelerator.print(config) + accelerator.print(f"Using {accelerator.num_processes} GPUs") + + tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length']) + # if no pad token, set it to eos + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + + with accelerator.main_process_first(): + + if 'index_path' in config: + train_dataloader, val_dataloader = load_retrieval_augmented_data(config, tokenizer) + else: + train_dataloader, val_dataloader = load_data(config, tokenizer) + + + checkpoint = config["gradient_checkpointing"] + #ensures back compat with non retrieval models + if 'index_path' in config: + model = GPTJRForCausalLM.from_pretrained(config["model_name"], + revision=config['version'], + use_cache=False if checkpoint else True, + trust_remote_code=True) + else: + model = AutoModelForCausalLM.from_pretrained(config["model_name"], + use_cache=False if checkpoint else True, + trust_remote_code=True) + + if checkpoint: + model.gradient_checkpointing_enable() + + if config["lora"]: + peft_config = LoraConfig( + # should R be configurable? 
+ task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1 + ) + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + optimizer_cls = ( + AdamW + if accelerator.state.deepspeed_plugin is None + or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config + else DummyOptim + ) + + # karpathy doesn't decay embeddding, maybe we should exclude + # https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s + optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"]) + + if accelerator.state.deepspeed_plugin is not None: + gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[ + "gradient_accumulation_steps" + ] + + # decay to min_lr instead of 0 + lr_ratio = config["min_lr"] / config["lr"] + accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}") + total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"] + # instead of decaying to zero, decay to ratio of min_lr / lr + total_num_steps += int(total_num_steps * lr_ratio) + config["warmup_steps"] + accelerator.print(f"Total training steps: {total_num_steps}") + + # Creates Dummy Scheduler if `scheduler` was spcified in the config file else creates `args.lr_scheduler_type` Scheduler + if ( + accelerator.state.deepspeed_plugin is None + or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ): + scheduler = get_scheduler( + name="cosine", + optimizer=optimizer, + num_warmup_steps=config["warmup_steps"] * accelerator.num_processes, + num_training_steps=total_num_steps, + ) + else: + scheduler = DummyScheduler( + optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"] + ) + + model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare( + model, optimizer, train_dataloader, val_dataloader, scheduler + ) + + # setup for saving training states in case preemption + accelerator.register_for_checkpointing(scheduler) + + if config["checkpoint"]: + accelerator.load_state(config["checkpoint"]) + accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}") + path = os.path.basename(config["train_args"]["resume_from_checkpoint"]) + training_difference = os.path.splitext(path)[0] + resume_step = int(training_difference.replace("step_", "")) + accelerator.skip_first_batches(train_dataloader, resume_step) + accelerator.print(f"Resuming from step {resume_step}") + + + # log gradients + if accelerator.is_main_process and config["wandb"]: + wandb.watch(model, log_freq=config["log_grads_every"], log="all") + + for epoch in range(config["num_epochs"]): + train_loss = MeanMetric(nan_strategy="error").to(model.device) + for step, batch in enumerate(tqdm(train_dataloader)): + model.train() + outputs = model(**batch) + loss = outputs.loss + + # gather loss before backprop in case of gradient accumulation + loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()}) + train_loss.update(loss_values["loss"]) + + loss = loss / gradient_accumulation_steps + accelerator.backward(loss) + # get gradient norm of all params + + # log LR in case something weird happens + if step > 0 and step % (config["eval_every"] // 10) == 0: + if config["wandb"]: + curr_step = step + epoch * len(train_dataloader) + accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step) + + if (step + 1) % 
gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + + if step > 0 and step % config["save_every"] == 0: + curr_step = step + epoch * len(train_dataloader) + accelerator.save_state(f"{config['output_dir']}/step_{curr_step}") + + if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1): + val_loss = evaluate(model, val_dataloader) + + log_train = { + "train_loss": train_loss.compute() + } + log_val = { + "val_loss": val_loss.compute() + } + + if config["wandb"]: + curr_step = step + epoch * len(train_dataloader) + accelerator.log({**log_train, **log_val}, step=curr_step) + + accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}") + accelerator.print(format_metrics(log_train, "train", f" step {step} ")) + accelerator.print(format_metrics(log_val, "val", f" step {step} ")) + + train_loss.reset() + + accelerator.print(f"Epoch {epoch} finished") + accelerator.print(f"Pushing to HF hub") + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + try: + if accelerator.is_main_process: + unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True) + + except Exception as e: + accelerator.print(e) + accelerator.print(f"Failed to push to hub") + + unwrapped_model.save_pretrained( + f"{config['output_dir']}/epoch_{epoch}", + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(model), + ) + + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + f"{config['output_dir']}/final", + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + state_dict=accelerator.get_state_dict(model), + ) + + accelerator.end_training() + + + +if __name__ == "__main__": + # parse arguments by reading in a config + parser = ArgumentParser() + parser.add_argument("--config", type=str, default="config.yaml") + + args = parser.parse_args() + + config = read_config(args.config) + + if config["wandb"]: + accelerator = Accelerator(log_with="wandb") + accelerator.init_trackers( + project_name=config["wandb_project_name"], + config=config, + init_kwargs={"wandb": {"entity": config["wandb_entity"]}}, + ) + else: + accelerator = Accelerator() + + train(accelerator, config=config) diff --git a/gpt4all/utils/data.py b/gpt4all/utils/data.py index b55a589a..3216441f 100644 --- a/gpt4all/utils/data.py +++ b/gpt4all/utils/data.py @@ -2,6 +2,7 @@ import glob import torch from datasets import load_dataset import os +import hnswlib from torch.utils.data import DataLoader from transformers import DefaultDataCollator @@ -116,6 +117,70 @@ def load_data(config, tokenizer): return train_dataloader, val_dataloader +def load_retrieval_augmented_data(config, tokenizer): + dataset_path = config["dataset_path"] + index_path = config['index_path'] + + #TODO this should precache at some point + index = hnswlib.Index(space=config['index_space'], dim=config['index_dim']) + index.load_index(index_path) + + if os.path.exists(dataset_path): + if os.path.isdir(dataset_path): + files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl")) + else: + files = [dataset_path] + + print(f"Reading files {files}") + + dataset = load_dataset("json", data_files=files, split="train") + + else: + dataset = load_dataset(dataset_path, split="train") + + dataset = dataset.train_test_split(test_size=.05, seed=config["seed"]) + + train_dataset, 
val_dataset = dataset["train"], dataset["test"] + + if config["streaming"] is False: + kwargs = {"num_proc": config["num_proc"]} + else: + kwargs = {} + + # tokenize inputs and return labels and attention mask + train_dataset = train_dataset.map( + lambda ele: tokenize_inputs(config, tokenizer, ele), + batched=True, + remove_columns=["source", "prompt"], + **kwargs + ) + val_dataset = val_dataset.map( + lambda ele: tokenize_inputs(config, tokenizer, ele), + batched=True, + remove_columns=["source", "prompt"], + **kwargs + ) + + train_dataset = train_dataset.with_format("torch") + val_dataset = val_dataset.with_format("torch") + + # create dataloader with default data collator since we already have labels + + train_dataloader = DataLoader( + train_dataset, + collate_fn=DefaultDataCollator(), + batch_size=config["batch_size"], + ) + + val_dataloader = DataLoader( + val_dataset, + collate_fn=DefaultDataCollator(), + batch_size=config["batch_size"], + ) + + return train_dataloader, val_dataloader + + def load_data_for_inference(config, tokenizer): dataset_path = config["dataset_path"] From cc3cc3f7e9fba7ef8432b490316876826621fb0d Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Sun, 23 Apr 2023 21:04:43 +0000 Subject: [PATCH 19/36] we now get the neighbors embeddings from the disk index --- configs/train/finetune_gptjr.yaml | 2 +- gpt4all/index/prep_index_for_train.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/configs/train/finetune_gptjr.yaml b/configs/train/finetune_gptjr.yaml index 291cf841..a03f8ff0 100644 --- a/configs/train/finetune_gptjr.yaml +++ b/configs/train/finetune_gptjr.yaml @@ -14,7 +14,7 @@ batch_size: 32 #index index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-sample-index.bin" -index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki-full-tokenized_embedded_with_text" +index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki_sample_tokenized_embedded_with_text" index_space: "cosine" index_dim: 384 query_embedding_field: 'question' diff --git a/gpt4all/index/prep_index_for_train.py b/gpt4all/index/prep_index_for_train.py index dcb90da3..cf8152fd 100644 --- a/gpt4all/index/prep_index_for_train.py +++ b/gpt4all/index/prep_index_for_train.py @@ -1,9 +1,11 @@ import os import hnswlib import numpy as np +import pyarrow as pa from datasets import Dataset import torch.distributed as dist from datasets import load_dataset +from pyarrow import compute as pc from argparse import ArgumentParser from gpt4all.utils.read import read_config from gpt4all.index.embed_texts import embed_texts @@ -32,8 +34,6 @@ if __name__ == "__main__": #load retrieval dataset retrieval_dataset = Dataset.load_from_disk(config['index_database']) - print(type(retrieval_dataset._data)) - raise #vectorize queries query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded" @@ -50,9 +50,8 @@ if __name__ == "__main__": chunk_end = chunk_start + CHUNK_SIZE chunk = ds[chunk_start:chunk_end] query_vectors = np.array(chunk['embedding']) - print(query_vectors.shape) neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1) - raise + value_set = pa.array([str(e) for e in neighbor_ids.flatten()]) + out = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set)) - #get the embeddings for each of the neighbor ids From 869829a06576fd2c3298c550fef0e8e8d8f62d22 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Sun, 23 Apr 2023 22:28:57 +0000 Subject: [PATCH 20/36] nearly have the neighbor caching working, but combinging info into 
final dataset is challenging --- gpt4all/index/embed_texts.py | 4 ++-- gpt4all/index/prep_index_for_train.py | 24 +++++++++++++++++++++--- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/gpt4all/index/embed_texts.py b/gpt4all/index/embed_texts.py index a139f5df..c2007fde 100644 --- a/gpt4all/index/embed_texts.py +++ b/gpt4all/index/embed_texts.py @@ -30,9 +30,9 @@ class PadCollateInputs: return padded_inputs -def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False): +def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False, split='train'): rank0_print(f"World size: {dist.get_world_size()}") - dataset = load_dataset(f"{ds_path}", split="train") + dataset = load_dataset(f"{ds_path}", split=split) rank0_print(f"Dataset size: {len(dataset)}") model = Embedder() diff --git a/gpt4all/index/prep_index_for_train.py b/gpt4all/index/prep_index_for_train.py index cf8152fd..f442e995 100644 --- a/gpt4all/index/prep_index_for_train.py +++ b/gpt4all/index/prep_index_for_train.py @@ -10,6 +10,7 @@ from argparse import ArgumentParser from gpt4all.utils.read import read_config from gpt4all.index.embed_texts import embed_texts +SPLIT = 'train' CHUNK_SIZE = 1024 K = 5 @@ -36,22 +37,39 @@ if __name__ == "__main__": retrieval_dataset = Dataset.load_from_disk(config['index_database']) #vectorize queries - query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded" + query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded_{SPLIT}" if not os.path.exists(query_vector_path): print('Embedding dataset...') - ds = embed_texts(ds_path, config['batch_size'], embed_on=config['query_embedding_field'], save_to_disk=False) + ds = embed_texts(ds_path, + config['batch_size'], + embed_on=config['query_embedding_field'], + save_to_disk=False, + split=SPLIT) ds.save_to_disk(query_vector_path) else: print('Found cached embedding dataset!') ds = Dataset.load_from_disk(query_vector_path) + #build training dataset + train_dataset = load_dataset(ds_path, split=SPLIT) + #search the index for each query + neighbor_embs_column = [] + neighbor_ids_column = [] for chunk_start in range(0, len(ds), CHUNK_SIZE): chunk_end = chunk_start + CHUNK_SIZE chunk = ds[chunk_start:chunk_end] query_vectors = np.array(chunk['embedding']) neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1) value_set = pa.array([str(e) for e in neighbor_ids.flatten()]) - out = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set)) + #TODO @nussy should be id + neighbor_objs = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set)) + neighbor_ids_column.extend(neighbor_objs['index']) #TODO @nussy should be id + neighbor_embs_column.extend(neighbor_objs['embedding']) + #import pdb;pdb.set_trace() + train_dataset = train_dataset.add_column('neighbor_ids', neighbor_ids_column) + train_dataset = train_dataset.add_column('neighbor_embeddings', neighbor_embs_column) + supplemented_dataset_path = f"{ds_path}_supplemented_{SPLIT}/" + train_dataset.save_to_disk(supplemented_dataset_path) From 7832707c375a6689e4355ac3f8e8dc5599b9feeb Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Tue, 25 Apr 2023 20:33:49 +0000 Subject: [PATCH 21/36] fix: index -> id --- gpt4all/index/build_index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt4all/index/build_index.py b/gpt4all/index/build_index.py index 566f5abd..77a6aa24 100644 --- a/gpt4all/index/build_index.py +++ b/gpt4all/index/build_index.py @@ -89,7 +89,7 @@ def 
build_index(orig_path, embed_folder_path, index_path): start = 0 while start < len(ds): chunk = ds[start:start + chunk_size] - index.add_items(chunk["embedding"], chunk["index"], num_threads=64) + index.add_items(chunk["embedding"], chunk["id"], num_threads=64) progbar.update(1) start += chunk_size From 586a8abc0673a71f1d22d5399515406d6da83b46 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Tue, 25 Apr 2023 20:34:12 +0000 Subject: [PATCH 22/36] fix: allow for print for fns that are used in both dist and single --- gpt4all/utils/distributed_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gpt4all/utils/distributed_utils.py b/gpt4all/utils/distributed_utils.py index 839a7a92..3fcf276f 100644 --- a/gpt4all/utils/distributed_utils.py +++ b/gpt4all/utils/distributed_utils.py @@ -2,5 +2,8 @@ import torch.distributed as dist def rank0_print(msg): - if dist.get_rank() == 0: + if dist.is_initialized(): + if dist.get_rank() == 0: + print(msg) + else: print(msg) From f2161f7e5900ce68b910bbaf69ee7916e1f910ca Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Tue, 25 Apr 2023 20:34:26 +0000 Subject: [PATCH 23/36] docs: prep index barebones --- gpt4all/index/README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/gpt4all/index/README.md b/gpt4all/index/README.md index 0ba876d9..675ce92a 100644 --- a/gpt4all/index/README.md +++ b/gpt4all/index/README.md @@ -25,3 +25,10 @@ import hnswlib index = hnswlib.Index(space='l2', dim=384) index.load_index() ``` + + +Prep index for train + +``` +CUDA_VISIBLE_DEVICES=7 torchrun --master_port=29086 --nproc-per-node 1 prep_index_for_train.py --config=../../configs/train/finetune_gptjr.yaml +``` From c20379f7e99914779d19eaaf1b2f2f4265bedb3a Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Tue, 25 Apr 2023 20:34:38 +0000 Subject: [PATCH 24/36] refactor: clean up prep index --- gpt4all/index/prep_index_for_train.py | 69 +++++++++++++++++---------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/gpt4all/index/prep_index_for_train.py b/gpt4all/index/prep_index_for_train.py index f442e995..64e306d8 100644 --- a/gpt4all/index/prep_index_for_train.py +++ b/gpt4all/index/prep_index_for_train.py @@ -9,67 +9,86 @@ from pyarrow import compute as pc from argparse import ArgumentParser from gpt4all.utils.read import read_config from gpt4all.index.embed_texts import embed_texts +from tqdm import tqdm -SPLIT = 'train' CHUNK_SIZE = 1024 -K = 5 - - -if __name__ == "__main__": - - dist.init_process_group("nccl") +def parse_args(): parser = ArgumentParser() parser.add_argument("--config", type=str, default="config.yaml") + parser.add_argument("--split", type=str, default="train") + parser.add_argument("--k", type=int, default=5) - args = parser.parse_args() + return parser.parse_args() + +def prep_index(): + args = parse_args() config = read_config(args.config) - - #load index index = hnswlib.Index(space=config['index_space'], dim=config['index_dim']) + print("loading index") index.load_index(config['index_path']) - #load query dataset + # load query dataset ds_path = config['dataset_path'] - #load retrieval dataset - retrieval_dataset = Dataset.load_from_disk(config['index_database']) + # load retrieval dataset + print("loading retrieval dataset") + print(config["index_database"]) + if os.path.exists(config['index_database']): + retrieval_dataset = Dataset.load_from_disk(config['index_database']) + else: + retrieval_dataset = load_dataset(config['index_database'], split=args.split) - #vectorize queries - query_vector_path = 
f"{ds_path}_queries_embedded/{ds_path}_embedded_{SPLIT}" + # vectorize queries + query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded_{args.split}" if not os.path.exists(query_vector_path): print('Embedding dataset...') ds = embed_texts(ds_path, config['batch_size'], embed_on=config['query_embedding_field'], save_to_disk=False, - split=SPLIT) + split=args.split) ds.save_to_disk(query_vector_path) else: print('Found cached embedding dataset!') ds = Dataset.load_from_disk(query_vector_path) #build training dataset - train_dataset = load_dataset(ds_path, split=SPLIT) + train_dataset = load_dataset(ds_path, split=args.split) #search the index for each query neighbor_embs_column = [] neighbor_ids_column = [] - for chunk_start in range(0, len(ds), CHUNK_SIZE): + for chunk_start in tqdm(range(0, len(ds), CHUNK_SIZE)): chunk_end = chunk_start + CHUNK_SIZE chunk = ds[chunk_start:chunk_end] query_vectors = np.array(chunk['embedding']) - neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1) + neighbor_ids, _ = index.knn_query(query_vectors, k=args.k, num_threads=-1) # neighbor ids is of shape [n_queries, n_neighbors] value_set = pa.array([str(e) for e in neighbor_ids.flatten()]) - #TODO @nussy should be id - neighbor_objs = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set)) - neighbor_ids_column.extend(neighbor_objs['index']) #TODO @nussy should be id - neighbor_embs_column.extend(neighbor_objs['embedding']) + neighbor_objs = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['id'], value_set)) - #import pdb;pdb.set_trace() + # build mapping between indices and embeddings + neighbor_id_list = neighbor_objs['id'] + emb_list = neighbor_objs['embedding'] + idx_to_embedding = {idx.as_py(): emb_list[i] for i, idx in enumerate(neighbor_id_list)} + + neighbor_embs = [] + for cur_neighbor_ids in neighbor_ids: + cur_embs = [idx_to_embedding[id].as_py() for id in cur_neighbor_ids] + neighbor_embs.append(cur_embs) + + neighbor_embs_column.extend(neighbor_embs) + neighbor_ids_column.extend(neighbor_ids) + + print("adding neighbor ids") train_dataset = train_dataset.add_column('neighbor_ids', neighbor_ids_column) + print("adding neighbor embeddings") train_dataset = train_dataset.add_column('neighbor_embeddings', neighbor_embs_column) - supplemented_dataset_path = f"{ds_path}_supplemented_{SPLIT}/" + supplemented_dataset_path = f"{ds_path}_supplemented_{args.split}/" train_dataset.save_to_disk(supplemented_dataset_path) + + +if __name__ == "__main__": + prep_index() \ No newline at end of file From b0f92b610e45f6315f0c7b965608405dda727fea Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Tue, 25 Apr 2023 20:34:49 +0000 Subject: [PATCH 25/36] refactor: clean up embed texts --- gpt4all/index/embed_texts.py | 50 +++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/gpt4all/index/embed_texts.py b/gpt4all/index/embed_texts.py index c2007fde..9168a6c8 100644 --- a/gpt4all/index/embed_texts.py +++ b/gpt4all/index/embed_texts.py @@ -1,3 +1,4 @@ +import os import torch.distributed as dist from argparse import ArgumentParser from datasets import Dataset @@ -12,6 +13,8 @@ import torch from datasets import load_dataset +# this isn't used but keeping in case we need it in the future +# collate and pad inputs to the right shape class PadCollateInputs: def __init__(self, tokenizer): self.tokenizer = tokenizer @@ -31,39 +34,44 @@ class PadCollateInputs: def embed_texts(ds_path, batch_size, 
embed_on='text', save_to_disk=False, split='train'): - rank0_print(f"World size: {dist.get_world_size()}") - dataset = load_dataset(f"{ds_path}", split=split) + world_size = dist.get_world_size() if dist.is_initialized() else 1 + rank0_print(f"World size: {world_size}") + + if os.path.exists(ds_path): + dataset = Dataset.load_from_disk(ds_path) + else: + dataset = load_dataset(ds_path, split=split) rank0_print(f"Dataset size: {len(dataset)}") model = Embedder() - dataset = dataset.map(lambda x: model.tokenize(x[embed_on]), batched=True, num_proc=64) + if "input_ids" not in dataset.column_names: + dataset = dataset.map(lambda x: model.tokenize(x[embed_on]), batched=True, num_proc=64) - columns_to_keep = ["input_ids", "attention_mask"] - #to_remove = [e for e in dataset.column_names if not e in columns_to_keep] - print('cols: ', dataset.column_names) - #dataset = dataset.remove_columns(to_remove) - #dataset = Dataset.load_from_disk(ds_path) - #dataset = dataset.remove_columns(["url", "title", "text"]) + columns_to_keep = ["id", "input_ids", "attention_mask"] + to_remove = [e for e in dataset.column_names if not e in columns_to_keep] + dataset = dataset.remove_columns(to_remove) + dataset = dataset.with_format("torch") - num_processes = dist.get_world_size() - local_rank = dist.get_rank() + num_processes = dist.get_world_size() if dist.is_initialized() else 1 + local_rank = dist.get_rank() if dist.is_initialized() else 0 - #collator = PadCollateInputs(model.tokenizer) + if num_processes > 1: + sampler = DistributedSampler( + dataset, + shuffle=False, + drop_last=False, + num_replicas=num_processes, + rank=local_rank, + ) + else: + sampler = None - sampler = DistributedSampler( - dataset, - shuffle=False, - drop_last=False, - num_replicas=num_processes, - rank=local_rank, - ) dataloader = DataLoader( dataset, - # collate_fn=collator, batch_size=batch_size, sampler=sampler, drop_last=False, @@ -100,7 +108,7 @@ def main(): args = parser.parse_args() - embed_texts(args.ds_path, args.batch_size) + embed_texts(args.ds_path, args.batch_size, save_to_disk=True) if __name__ == "__main__": From da5ce0a181e512a7a98a2bd8fd02ebffd1ae3bb3 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Tue, 25 Apr 2023 20:36:04 +0000 Subject: [PATCH 26/36] chore: ignore large arrow files --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 42f80eae..26f820e5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +*.arrow *.swp # ignore knn index gpt4all/index/**/ From a58d1eb3bd9fb7ff2d7d9d246d9a891e3d1e1ad9 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Tue, 25 Apr 2023 21:28:42 +0000 Subject: [PATCH 27/36] refactor: move file around --- test_gpt_jr.py => gpt4all/models/test_gpt_jr.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test_gpt_jr.py => gpt4all/models/test_gpt_jr.py (100%) diff --git a/test_gpt_jr.py b/gpt4all/models/test_gpt_jr.py similarity index 100% rename from test_gpt_jr.py rename to gpt4all/models/test_gpt_jr.py From 8a917ad4e1735672436b706340aff3aa10d096d1 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Tue, 25 Apr 2023 21:28:56 +0000 Subject: [PATCH 28/36] chore: create data folder --- gpt4all/data/__init__.py | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 gpt4all/data/__init__.py diff --git a/gpt4all/data/__init__.py b/gpt4all/data/__init__.py new file mode 100644 index 00000000..c530d861 --- /dev/null +++ b/gpt4all/data/__init__.py @@ -0,0 +1,2 @@ +from .instruction_tuning_dataloader import * +from .retrieval_dataloader 
import * \ No newline at end of file From 80d810322a42b1eb49b9addfd129c41c1426943a Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 1 May 2023 21:38:01 +0000 Subject: [PATCH 29/36] fix: lr schedule --- .../instruction_tuning_dataloader.py} | 63 ------------------- gpt4all/train/train.py | 6 +- 2 files changed, 3 insertions(+), 66 deletions(-) rename gpt4all/{utils/data.py => data/instruction_tuning_dataloader.py} (75%) diff --git a/gpt4all/utils/data.py b/gpt4all/data/instruction_tuning_dataloader.py similarity index 75% rename from gpt4all/utils/data.py rename to gpt4all/data/instruction_tuning_dataloader.py index 3216441f..37fef4d2 100644 --- a/gpt4all/utils/data.py +++ b/gpt4all/data/instruction_tuning_dataloader.py @@ -117,69 +117,6 @@ def load_data(config, tokenizer): return train_dataloader, val_dataloader -def load_retrieval_augmented_data(config, tokenizer): - dataset_path = config["dataset_path"] - index_path = config['index_path'] - - #TODO this should precache at some point - index = hnswlib.Index(space=config['index_space'], dim=config['index_dim']) - index.load_index(index_path) - - if os.path.exists(dataset_path): - if os.path.isdir(dataset_path): - files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl")) - else: - files = [dataset_path] - - print(f"Reading files {files}") - - dataset = load_dataset("json", data_files=files, split="train") - - else: - dataset = load_dataset(dataset_path, split="train") - - dataset = dataset.train_test_split(test_size=.05, seed=config["seed"]) - - train_dataset, val_dataset = dataset["train"], dataset["test"] - - if config["streaming"] is False: - kwargs = {"num_proc": config["num_proc"]} - else: - kwargs = {} - - # tokenize inputs and return labels and attention mask - train_dataset = train_dataset.map( - lambda ele: tokenize_inputs(config, tokenizer, ele), - batched=True, - remove_columns=["source", "prompt"], - **kwargs - ) - val_dataset = val_dataset.map( - lambda ele: tokenize_inputs(config, tokenizer, ele), - batched=True, - remove_columns=["source", "prompt"], - **kwargs - ) - - train_dataset = train_dataset.with_format("torch") - val_dataset = val_dataset.with_format("torch") - - # create dataloader with default data collator since we already have labels - - train_dataloader = DataLoader( - train_dataset, - collate_fn=DefaultDataCollator(), - batch_size=config["batch_size"], - ) - - val_dataloader = DataLoader( - val_dataset, - collate_fn=DefaultDataCollator(), - batch_size=config["batch_size"], - ) - - return train_dataloader, val_dataloader - def load_data_for_inference(config, tokenizer): diff --git a/gpt4all/train/train.py b/gpt4all/train/train.py index 97b6c9a8..75d8eda3 100644 --- a/gpt4all/train/train.py +++ b/gpt4all/train/train.py @@ -1,5 +1,5 @@ import os -from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM +from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler import torch from torch.optim import AdamW from argparse import ArgumentParser @@ -7,7 +7,7 @@ from gpt4all.utils.read import read_config from accelerate import Accelerator from accelerate.utils import DummyScheduler, DummyOptim, set_seed from peft import get_peft_model, LoraConfig, TaskType -from gpt4all.utils.data import load_data +from gpt4all.data.instruction_tuning_dataloader import load_data from torchmetrics import MeanMetric from tqdm import tqdm import wandb @@ -104,7 +104,7 @@ def train(accelerator, config): ) else: scheduler = DummyScheduler( - optimizer, 
total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"] + optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"] ) model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare( From 48e07be9e961c850a5ad817d8e31b0c04a66fd43 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 1 May 2023 21:38:23 +0000 Subject: [PATCH 30/36] feat: training script --- configs/train/finetune_gptjr.yaml | 37 ++++++----- gpt4all/data/instruction_tuning_dataloader.py | 61 ++----------------- .../train/{train_r.py => train_retrieval.py} | 54 +++++++++------- 3 files changed, 55 insertions(+), 97 deletions(-) rename gpt4all/train/{train_r.py => train_retrieval.py} (81%) diff --git a/configs/train/finetune_gptjr.yaml b/configs/train/finetune_gptjr.yaml index a03f8ff0..352487c7 100644 --- a/configs/train/finetune_gptjr.yaml +++ b/configs/train/finetune_gptjr.yaml @@ -1,41 +1,40 @@ # model/tokenizer -model_name: "nomic-ai/gpt4all-j" -tokenizer_name: "nomic-ai/gpt4all-j" -version: 'v1.2-jazzy' +model_name: "EleutherAI/gpt-j-6B" +tokenizer_name: "EleutherAI/gpt-j-6B" +version: null gradient_checkpointing: true -save_name: # CHANGE +save_name: "nomic-ai/gpt-jr-decay-alpha" +encoder_dim: 384 # dataset streaming: false num_proc: 64 -dataset_path: "squad" +dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train" max_length: 1024 batch_size: 32 +pct_test: 0.05 +q_column: "question" +a_column: "answers" +encoder_column: "neighbor_embeddings" -#index -index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-sample-index.bin" -index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki_sample_tokenized_embedded_with_text" -index_space: "cosine" -index_dim: 384 -query_embedding_field: 'question' # train dynamics -lr: 2.0e-5 +lr: 1.0e-4 min_lr: 0 weight_decay: 0.0 -eval_every: 500 -eval_steps: 105 +eval_every: 50 save_every: 500 log_grads_every: 100 -output_dir: # CHANGE +log_lr_every: 10 +output_dir: "ckpts/decay_alpha" checkpoint: null lora: false warmup_steps: 500 -num_epochs: 2 +num_epochs: 5 # logging -wandb: false -wandb_entity: # CHANGE -wandb_project_name: # CHANGE +wandb: true +wandb_entity: gpt4all +wandb_project_name: retrieval seed: 42 diff --git a/gpt4all/data/instruction_tuning_dataloader.py b/gpt4all/data/instruction_tuning_dataloader.py index 37fef4d2..5803ea76 100644 --- a/gpt4all/data/instruction_tuning_dataloader.py +++ b/gpt4all/data/instruction_tuning_dataloader.py @@ -5,58 +5,7 @@ import os import hnswlib from torch.utils.data import DataLoader from transformers import DefaultDataCollator - - - -def tokenize_inputs(config, tokenizer, examples): - max_length = config["max_length"] - - # hacky backward compatible - different_eos = tokenizer.eos_token != "" - out = {"labels": [], "input_ids": []} - for prompt, response in zip(examples["prompt"], examples["response"]): - if different_eos: - if response.count(" \n") > 0: - response = response.replace(" \n", f"{tokenizer.eos_token} \n") - - prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0]) - - # hack if our prompt is super long - # we need to include some labels so we arbitrarily trunacate at max_length // 2 - # if the length is too long - if prompt_len >= max_length // 2: - # if prompt is too long, truncate - # but make sure to truncate to at max 1024 tokens - new_len = min(max_length // 2, len(prompt) // 2) - prompt = prompt[:new_len] - # get new prompt length - prompt_len = tokenizer(prompt + "\n", return_tensors="pt", 
max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item() - - assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}" - - input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token, - truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze() - - labels = input_tokens.clone() - labels[:prompt_len] = -100 - if len(labels) < max_length: - # pad to max_length with -100 - labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)]) - - assert (labels == -100).sum() < len(labels), f"Labels are all -100, something wrong. prompt length {prompt_len} exceeds max length {max_length}" - - if (labels == -100).sum() == len(labels) - 1: - print(prompt) - print(response) - raise - - input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"] - out["labels"].append(labels) - out["input_ids"].append(input_tokens) - - out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()} - - return out +from .preprocess import tokenize_inputs def load_data(config, tokenizer): @@ -86,13 +35,13 @@ def load_data(config, tokenizer): # tokenize inputs and return labels and attention mask train_dataset = train_dataset.map( - lambda ele: tokenize_inputs(config, tokenizer, ele), + lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"), batched=True, remove_columns=["source", "prompt"], **kwargs ) val_dataset = val_dataset.map( - lambda ele: tokenize_inputs(config, tokenizer, ele), + lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"), batched=True, remove_columns=["source", "prompt"], **kwargs @@ -154,12 +103,12 @@ def load_data_for_inference(config, tokenizer): # tokenize inputs and return labels and attention mask train_dataset = train_dataset.map( - lambda ele: tokenize_inputs(config, tokenizer, ele), + lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"), batched=True, **kwargs ) val_dataset = val_dataset.map( - lambda ele: tokenize_inputs(config, tokenizer, ele), + lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"), batched=True, **kwargs ) diff --git a/gpt4all/train/train_r.py b/gpt4all/train/train_retrieval.py similarity index 81% rename from gpt4all/train/train_r.py rename to gpt4all/train/train_retrieval.py index 9254d6ef..1e31f01d 100644 --- a/gpt4all/train/train_r.py +++ b/gpt4all/train/train_retrieval.py @@ -7,11 +7,13 @@ from gpt4all.utils.read import read_config from accelerate import Accelerator from accelerate.utils import DummyScheduler, DummyOptim, set_seed from peft import get_peft_model, LoraConfig, TaskType -from gpt4all.utils.data import load_data, load_retrieval_augmented_data +from gpt4all.data.retrieval_dataloader import load_retrieval_augmented_data from torchmetrics import MeanMetric from tqdm import tqdm from gpt4all.models import GPTJRForCausalLM +from gpt4all.train.metrics import f1_score, exact_match_score import wandb +import torch.distributed as dist torch.backends.cuda.matmul.allow_tf32 = True @@ -22,15 +24,18 @@ def format_metrics(metrics, split, prefix=""): return log -def evaluate(model, val_dataloader): +def evaluate(model, val_dataloader, step, main_process=False): model.eval() val_loss = MeanMetric(nan_strategy="error").to(model.device) with torch.no_grad(): - for batch in tqdm(val_dataloader): - loss = model(**batch).loss + for batch in tqdm(val_dataloader, disable=not main_process): + outputs = 
model(input_ids=batch["input_ids"], + labels=batch["labels"], + encoder_hidden_states=batch["encoder_hidden_states"], + step=step) + loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()}) - loss_values = accelerator.gather_for_metrics({"loss": loss.detach()}) val_loss.update(loss_values["loss"]) @@ -50,20 +55,18 @@ def train(accelerator, config): with accelerator.main_process_first(): - - if 'index_path' in config: - train_dataloader, val_dataloader = load_retrieval_augmented_data(config, tokenizer) - else: - train_dataloader, val_dataloader = load_data(config, tokenizer) + train_dataloader, val_dataloader = load_retrieval_augmented_data(config, tokenizer) checkpoint = config["gradient_checkpointing"] #ensures back compat with non retrieval models - if 'index_path' in config: - model = GPTJRForCausalLM.from_pretrained(config["model_name"], - revision=config['version'], - use_cache=False if checkpoint else True, - trust_remote_code=True) + if 'encoder_dim' in config: + with accelerator.main_process_first(): + model = GPTJRForCausalLM.from_pretrained(config["model_name"], + revision=config['version'] if 'version' in config else None, + use_cache=False if checkpoint else True, + encoder_dim=config["encoder_dim"], + ) else: model = AutoModelForCausalLM.from_pretrained(config["model_name"], use_cache=False if checkpoint else True, @@ -117,13 +120,14 @@ def train(accelerator, config): ) else: scheduler = DummyScheduler( - optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"] + optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"] ) model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare( model, optimizer, train_dataloader, val_dataloader, scheduler ) + # setup for saving training states in case preemption accelerator.register_for_checkpointing(scheduler) @@ -141,11 +145,16 @@ def train(accelerator, config): if accelerator.is_main_process and config["wandb"]: wandb.watch(model, log_freq=config["log_grads_every"], log="all") + main_process = accelerator.is_main_process + for epoch in range(config["num_epochs"]): train_loss = MeanMetric(nan_strategy="error").to(model.device) - for step, batch in enumerate(tqdm(train_dataloader)): + for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)): model.train() - outputs = model(**batch) + outputs = model(input_ids=batch["input_ids"], + labels=batch["labels"], + encoder_hidden_states=batch["encoder_hidden_states"], + step=step) loss = outputs.loss # gather loss before backprop in case of gradient accumulation @@ -157,8 +166,8 @@ def train(accelerator, config): # get gradient norm of all params # log LR in case something weird happens - if step > 0 and step % (config["eval_every"] // 10) == 0: - if config["wandb"]: + if config["wandb"]: + if step > 0 and step % (config["log_lr_every"] ) == 0: curr_step = step + epoch * len(train_dataloader) accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step) @@ -173,13 +182,14 @@ def train(accelerator, config): accelerator.save_state(f"{config['output_dir']}/step_{curr_step}") if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1): - val_loss = evaluate(model, val_dataloader) + curr_step = step + epoch * len(train_dataloader) + val_loss = evaluate(model, val_dataloader, step=curr_step, main_process=main_process) log_train = { "train_loss": train_loss.compute() } log_val = { - "val_loss": val_loss.compute() + "val_loss": 
val_loss.compute(), } if config["wandb"]: From c9dd9152c3fa4cae1049a3bfebd9a4a4b5427665 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 1 May 2023 21:38:36 +0000 Subject: [PATCH 31/36] feat: model def + metrics --- gpt4all/models/configuration_gpt_jr.py | 1 - gpt4all/models/modeling_gpt_jr.py | 56 +++++++++----------------- gpt4all/train/metrics.py | 50 +++++++++++++++++++++++ 3 files changed, 68 insertions(+), 39 deletions(-) create mode 100644 gpt4all/train/metrics.py diff --git a/gpt4all/models/configuration_gpt_jr.py b/gpt4all/models/configuration_gpt_jr.py index fdb9eec4..314a87af 100644 --- a/gpt4all/models/configuration_gpt_jr.py +++ b/gpt4all/models/configuration_gpt_jr.py @@ -139,7 +139,6 @@ class GPTJRConfig(PretrainedConfig): self.eos_token_id = eos_token_id self.encoder_dim = encoder_dim - self.encoder_path = encoder_path super().__init__( bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs diff --git a/gpt4all/models/modeling_gpt_jr.py b/gpt4all/models/modeling_gpt_jr.py index a38ad5b0..36e053ea 100644 --- a/gpt4all/models/modeling_gpt_jr.py +++ b/gpt4all/models/modeling_gpt_jr.py @@ -423,8 +423,6 @@ class GPTJRBlock(nn.Module): self.cross_attn = GPTJRCrossAttention(config) self.cross_attn_mlp = GPTJRMLP(inner_dim, config) - self.alpha = nn.Parameter(torch.ones(1), requires_grad=False).to(self.ln_1.weight.dtype) - self.step = 1 self.world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else torch.cuda.device_count() or 1 def forward( @@ -436,6 +434,7 @@ class GPTJRBlock(nn.Module): head_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + step: Optional[int] = None, ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: # shape (bs, seq_len, hidden_dim) residual = hidden_states @@ -455,7 +454,8 @@ class GPTJRBlock(nn.Module): self_attention_residual = attn_output + feed_forward_hidden_states + residual # encoder_hidden_states -> (bs, knn, encoder_dim) - # may not need, can norm encodings + if encoder_hidden_states.dtype != hidden_states.dtype: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype) encoder_normed = self.ln_2(encoder_hidden_states) # cross_attn_outputs -> (bs, seq_len, dim) @@ -473,24 +473,22 @@ class GPTJRBlock(nn.Module): cross_attn_output[0] ) - alpha = self.alpha if self.training else 0.5 - hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual + if step is not None: + alpha = self._update_alpha(step) + alpha = alpha.to(cross_attn_ff.device).to(cross_attn_ff.dtype) + else: + alpha = 0.5 + hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual if use_cache: outputs = (hidden_states,) + outputs else: outputs = (hidden_states,) + outputs[1:] - # if training update alpha - if self.training: - self.step += 1 - self._update_alpha(self.step) - - return outputs # hidden_states, present, (attentions) def _update_alpha(self, iteration): - self.alpha.data = torch.clamp(torch.tensor([1 / (iteration * self.world_size) ** 0.05]), min=torch.tensor([0.5]), max=torch.tensor([1.0])) + return torch.clamp(torch.tensor([1 / (max(iteration * self.world_size, 1)) ** 0.08]), min=torch.tensor([0.5]), max=torch.tensor([1.0])) class GPTJRPreTrainedModel(PreTrainedModel): @@ -597,6 +595,7 @@ class GPTJRModel(GPTJRPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, 
return_dict: Optional[bool] = None, + step: Optional[int] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -704,7 +703,7 @@ class GPTJRModel(GPTJRPreTrainedModel): def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, use_cache, output_attentions) + return module(*inputs, use_cache, output_attentions, step) return custom_forward @@ -725,6 +724,7 @@ class GPTJRModel(GPTJRPreTrainedModel): head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, + step=step ) hidden_states = outputs[0] @@ -765,11 +765,6 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel): super().__init__(config) self.transformer = GPTJRModel(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size) - if config.encoder_path is not None: - self.encoder = AutoModel.from_pretrained(config.encoder_path) - # freeze encoder and don't get gradiets - self.encoder.requires_grad_(False) - # Model parallel self.model_parallel = False @@ -832,23 +827,20 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel): def forward( self, - input_ids: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor, + encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.FloatTensor] = None, - decoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.FloatTensor] = None, - encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + step: Optional[int] = None ) -> Union[Tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -859,22 +851,9 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - # Convert encoder inputs in embeddings if needed - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - transformer_outputs = self.transformer( input_ids, - encoder_hidden_states=encoder_outputs, + encoder_hidden_states=encoder_hidden_states, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -885,6 +864,7 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + step=step, ) hidden_states = transformer_outputs[0] diff --git a/gpt4all/train/metrics.py b/gpt4all/train/metrics.py new file mode 100644 index 00000000..dac602d8 --- /dev/null +++ b/gpt4all/train/metrics.py @@ -0,0 
+1,50 @@ +from collections import Counter +import string +import re + + +# adapted from huggingface +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(predictions, ground_truths): + total_f1 = [] + for prediction, ground_truth in zip(predictions, ground_truths): + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + total_f1.append(0) + continue + + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + total_f1.append(f1) + + return total_f1 + + +def exact_match_score(predictions, ground_truths): + exact_scores = [] + for prediction, ground_truth in zip(predictions, ground_truths): + exact_scores.append(normalize_answer(prediction) == normalize_answer(ground_truth)) + + return exact_scores From 0c0a56acab47c16b47e5cc6d1083c5e7702c9cc6 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 1 May 2023 21:38:46 +0000 Subject: [PATCH 32/36] feat: data preprocessing --- gpt4all/data/preprocess.py | 51 +++++++++++++++++++ gpt4all/data/retrieval_dataloader.py | 75 ++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 gpt4all/data/preprocess.py create mode 100644 gpt4all/data/retrieval_dataloader.py diff --git a/gpt4all/data/preprocess.py b/gpt4all/data/preprocess.py new file mode 100644 index 00000000..6f5c2e26 --- /dev/null +++ b/gpt4all/data/preprocess.py @@ -0,0 +1,51 @@ +import torch + +def tokenize_inputs(config, tokenizer, examples, input_col, target_col): + max_length = config["max_length"] + + # hacky backward compatible + different_eos = tokenizer.eos_token != "" + out = {"labels": [], "input_ids": []} + for prompt, response in zip(examples[input_col], examples[target_col]): + if different_eos: + if response.count(" \n") > 0: + response = response.replace(" \n", f"{tokenizer.eos_token} \n") + + prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0]) + + # hack if our prompt is super long + # we need to include some labels so we arbitrarily trunacate at max_length // 2 + # if the length is too long + if prompt_len >= max_length // 2: + # if prompt is too long, truncate + # but make sure to truncate to at max 1024 tokens + new_len = min(max_length // 2, len(prompt) // 2) + prompt = prompt[:new_len] + # get new prompt length + prompt_len = tokenizer(prompt + "\n", return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item() + + assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}" + + input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token, + truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze() + + labels = input_tokens.clone() + labels[:prompt_len] = -100 + if len(labels) < max_length: + # pad to max_length with -100 + labels = torch.cat([labels, torch.full((max_length 
- len(labels),), -100)]) + + assert (labels == -100).sum() < len(labels), f"Labels are all -100, something wrong. prompt length {prompt_len} exceeds max length {max_length}" + + if (labels == -100).sum() == len(labels) - 1: + print(prompt) + print(response) + raise + + input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"] + out["labels"].append(labels) + out["input_ids"].append(input_tokens) + + out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()} + + return out \ No newline at end of file diff --git a/gpt4all/data/retrieval_dataloader.py b/gpt4all/data/retrieval_dataloader.py new file mode 100644 index 00000000..75c87574 --- /dev/null +++ b/gpt4all/data/retrieval_dataloader.py @@ -0,0 +1,75 @@ +from datasets import load_dataset, Dataset +import os +from torch.utils.data import DataLoader +from .preprocess import tokenize_inputs +from transformers import DefaultDataCollator + + +def load_retrieval_augmented_data(config, tokenizer, split="train", split_dataset=True): + dataset_path = config["dataset_path"] + + if os.path.exists(dataset_path): + dataset = Dataset.load_from_disk(dataset_path) + else: + dataset = load_dataset(dataset_path, split=split) + + + question_col = config["q_column"] + answer_col = config["a_column"] + encoder_column = config["encoder_column"] + + if config["streaming"] is False: + kwargs = {"num_proc": config["num_proc"]} + else: + kwargs = {} + + # strip any unneccessary whitespace + # there's one question that's includes a ton of whitespace + dataset = dataset.map(lambda ele: {question_col: [q.strip() for q in ele[question_col]]}, batched=True) + # in squad, the data is formatted where each ele in answers is a dict where the key text holds + # a list of the answer + dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True) + + dataset = dataset.map( + lambda ele: tokenize_inputs(config, tokenizer, ele, question_col, answer_col), + batched=True, + **kwargs + ) + + # tokenize inputs + labels in teacher-force format + # rename encoder hidden states if not already called that + if encoder_column != "encoder_hidden_states": + dataset = dataset.rename_column(encoder_column, "encoder_hidden_states") + + columns_to_keep = ["input_ids", "labels", "encoder_hidden_states"] + + col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep] + dataset = dataset.remove_columns(col_names_to_rm) + + if split_dataset: + dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"]) + train_dataset, val_dataset = dataset["train"], dataset["test"] + + train_dataloader = DataLoader( + train_dataset, + batch_size=config["batch_size"], + collate_fn=DefaultDataCollator(), + ) + + val_dataloader = DataLoader( + val_dataset, + batch_size=config["batch_size"], + collate_fn=DefaultDataCollator(), + ) + + return train_dataloader, val_dataloader + + else: + dataloader = DataLoader( + dataset, + batch_size=config["batch_size"], + collate_fn=DefaultDataCollator(), + ) + + return dataloader + From 1b3f18bef262340a03998057849747362bc680cd Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 1 May 2023 21:39:09 +0000 Subject: [PATCH 33/36] fix: import path --- gpt4all/inference/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt4all/inference/inference.py b/gpt4all/inference/inference.py index 2096819c..c69e50bf 100644 --- a/gpt4all/inference/inference.py +++ 
b/gpt4all/inference/inference.py @@ -4,7 +4,7 @@ import torch.nn as nn from argparse import ArgumentParser from gpt4all.utils.read import read_config from accelerate.utils import set_seed -from gpt4all.utils.data import load_data_for_inference +from gpt4all.data.instruction_tuning_dataloader import load_data_for_inference from gpt4all.utils.distributed_utils import rank0_print from tqdm import tqdm from datasets import Dataset From 3736eda56a1594127f8a32931d1ec3052e9b47e3 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 1 May 2023 21:39:21 +0000 Subject: [PATCH 34/36] feat: eval for retrieval --- configs/eval/evaluate_gpt4all_jr.yaml | 18 +++++++++ gpt4all/eval/eval_squad.py | 54 +++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 configs/eval/evaluate_gpt4all_jr.yaml create mode 100644 gpt4all/eval/eval_squad.py diff --git a/configs/eval/evaluate_gpt4all_jr.yaml b/configs/eval/evaluate_gpt4all_jr.yaml new file mode 100644 index 00000000..989928e8 --- /dev/null +++ b/configs/eval/evaluate_gpt4all_jr.yaml @@ -0,0 +1,18 @@ +# model/tokenizer +model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/epoch_2" +tokenizer_name: "EleutherAI/gpt-j-6B" +version: null +gradient_checkpointing: true +save_name: "nomic-ai/gpt-jr" +encoder_dim: 384 + +# dataset +streaming: false +num_proc: 64 +dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_validation" +max_length: 1024 +batch_size: 32 +pct_test: 0.05 +q_column: "question" +a_column: "answers" +encoder_column: "neighbor_embeddings" \ No newline at end of file diff --git a/gpt4all/eval/eval_squad.py b/gpt4all/eval/eval_squad.py new file mode 100644 index 00000000..d474e4cc --- /dev/null +++ b/gpt4all/eval/eval_squad.py @@ -0,0 +1,54 @@ +import torch +from gpt4all.models import GPTJRForCausalLM +from gpt4all.data.retrieval_dataloader import load_retrieval_augmented_data +from gpt4all.train.metrics import f1_score, exact_match_score +from gpt4all.utils.read import read_config +from transformers import AutoTokenizer +from argparse import ArgumentParser +from tqdm import tqdm + +parser = ArgumentParser() +parser.add_argument("--config", type=str, default="config.yaml") + +args = parser.parse_args() + +config = read_config(args.config) + +tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"]) +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +dataloader = load_retrieval_augmented_data(config, tokenizer, split_dataset=False) + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = GPTJRForCausalLM.from_pretrained(config["model_name"], use_cache=False) +model.to(device) +model.eval() + +# Evaluate the model on the SQUAD dataset +f1s = [] +exact_matches = [] +with torch.no_grad(): + for batch in tqdm(dataloader): + outputs = model(input_ids=batch["input_ids"].to(device), + labels=batch["labels"].to(device), + encoder_hidden_states=batch["encoder_hidden_states"].to(device)) + + predicted_tokens = outputs.logits.argmax(dim=-1) + predicted = tokenizer.batch_decode(predicted_tokens, skip_special_tokens=True) + + labels = batch["labels"] + mask = labels == -100 + labels[mask] = tokenizer.pad_token_id + ground_truth = tokenizer.batch_decode(labels, skip_special_tokens=True) + + f1 = f1_score(predicted, ground_truth) + exact_match = exact_match_score(predicted, ground_truth) + + f1s.extend(f1) + exact_matches.extend(exact_match) + + +print(torch.tensor(f1s).mean()) +print(torch.tensor(exact_matches).to(torch.float32).mean()) \ No newline at 
end of file From 0f61cd8b42503785c258451a67751be6667b779f Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 1 May 2023 21:39:40 +0000 Subject: [PATCH 35/36] fix: retrieval dataset only has train split --- gpt4all/index/prep_index_for_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt4all/index/prep_index_for_train.py b/gpt4all/index/prep_index_for_train.py index 64e306d8..b0d32711 100644 --- a/gpt4all/index/prep_index_for_train.py +++ b/gpt4all/index/prep_index_for_train.py @@ -38,7 +38,7 @@ def prep_index(): if os.path.exists(config['index_database']): retrieval_dataset = Dataset.load_from_disk(config['index_database']) else: - retrieval_dataset = load_dataset(config['index_database'], split=args.split) + retrieval_dataset = load_dataset(config['index_database'], split="train") # vectorize queries query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded_{args.split}" From 00f04360d23ab091ed74219d6ea33ff4d26e9c1e Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Mon, 1 May 2023 21:39:55 +0000 Subject: [PATCH 36/36] fix: config for index building --- configs/index/finetune_gptjr.yaml | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 configs/index/finetune_gptjr.yaml diff --git a/configs/index/finetune_gptjr.yaml b/configs/index/finetune_gptjr.yaml new file mode 100644 index 00000000..883a59f2 --- /dev/null +++ b/configs/index/finetune_gptjr.yaml @@ -0,0 +1,13 @@ +# dataset +streaming: false +num_proc: 64 +dataset_path: "squad" +max_length: 1024 +batch_size: 32 + +#index +index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-cohere-index.bin" +index_database: "nomic-ai/cohere-wiki-sbert" +index_space: "cosine" +index_dim: 384 +query_embedding_field: 'question' \ No newline at end of file
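
As a standalone sketch of the cross-attention mixing schedule added to `GPTJRBlock` in PATCH 31 (`_update_alpha` in `gpt4all/models/modeling_gpt_jr.py`), the snippet below reproduces that formula outside the model so the decay can be inspected numerically: the weight `alpha` on the self-attention residual starts at 1.0 and decays toward its 0.5 floor as `step * world_size` grows, while `1 - alpha` weights the cross-attention feed-forward output. The helper name `mixing_alpha` and the sample step values are illustrative only and are not part of the patch series.

```python
import torch


def mixing_alpha(step: int, world_size: int = 1) -> float:
    # Mirrors GPTJRBlock._update_alpha: 1 / (step * world_size) ** 0.08,
    # clamped to [0.5, 1.0]. In the block's forward pass, alpha multiplies
    # the self-attention residual and (1 - alpha) multiplies the
    # cross-attention feed-forward output.
    value = torch.tensor([1.0 / max(step * world_size, 1) ** 0.08])
    return torch.clamp(value, min=0.5, max=1.0).item()


if __name__ == "__main__":
    # Early steps keep alpha near 1.0 (mostly self-attention); by ~10k
    # effective steps the schedule has reached the 0.5 floor (equal mixing).
    for step in (1, 10, 100, 1_000, 10_000, 100_000):
        print(f"step={step:>6}  alpha={mixing_alpha(step):.3f}")
```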