Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-07-06 20:09:58 +00:00)
feat: distributed eval of embedder
commit a2b1f99838
parent 4fb19d67b5
gpt4all/index/embed_texts.py (new file, +96 lines)
@@ -0,0 +1,96 @@
import torch.distributed as dist
from argparse import ArgumentParser
from datasets import Dataset
from gpt4all.index.embed import Embedder
from gpt4all.utils.distributed_utils import rank0_print
from torch.utils.data import DataLoader, DistributedSampler
from transformers.trainer_pt_utils import nested_numpify
from transformers import BatchEncoding
from tqdm import tqdm
import numpy as np
import torch


class PadCollateInputs:
    """Collate pre-tokenized chunks into padded, batched tensors."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        # Rename the dataset's pre-tokenized columns to the names
        # tokenizer.pad expects.
        mapped_inputs = {"input_ids": [], "attention_mask": []}
        mapped_inputs["input_ids"] = [b["tokenized_chunk"] for b in batch]
        mapped_inputs["attention_mask"] = [b["tokenized_attn_mask"] for b in batch]
        encoding = BatchEncoding(mapped_inputs)

        padded_inputs = self.tokenizer.pad(
            encoding, padding="max_length", return_tensors="pt"
        )
        # Carry the row ids through so each embedding can be matched back
        # to its source chunk.
        padded_inputs["id"] = [b["id"] for b in batch]

        return padded_inputs


def embed_texts(ds_path, batch_size):
    rank0_print(f"World size: {dist.get_world_size()}")

    dataset = Dataset.load_from_disk(ds_path)
    rank0_print(f"Dataset size: {len(dataset)}")
    # Drop the raw-text columns; only the pre-tokenized chunks and ids
    # are needed for embedding.
    dataset = dataset.remove_columns(["url", "title", "text"])
    dataset = dataset.with_format("torch")

    num_processes = dist.get_world_size()
    local_rank = dist.get_rank()

    model = Embedder()

    collator = PadCollateInputs(model.tokenizer)

    # Shard the dataset across ranks without shuffling, so each process
    # embeds a disjoint slice.
    sampler = DistributedSampler(
        dataset,
        shuffle=False,
        drop_last=False,
        num_replicas=num_processes,
        rank=local_rank,
    )
    dataloader = DataLoader(
        dataset,
        collate_fn=collator,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=False,
    )

    model.to(f"cuda:{local_rank}")
    with torch.no_grad():
        embedded_outputs = {"id": [], "embedding": []}
        # Only rank 0 shows the progress bar.
        for batch in tqdm(dataloader, disable=local_rank != 0):
            batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
            batch["attention_mask"] = batch["attention_mask"].to(f"cuda:{local_rank}")
            outputs = model(batch)
            embedded_outputs["id"].extend(batch["id"])
            embedded_outputs["embedding"].extend(outputs["embedding"])

    embedded_outputs["embedding"] = nested_numpify(embedded_outputs["embedding"])
    embedded_outputs["id"] = np.stack(embedded_outputs["id"])
    embedded_outputs["embedding"] = np.stack(embedded_outputs["embedding"])

    ds = Dataset.from_dict(embedded_outputs)

    # feeling lazy, don't want to wait for all_gather to finish
    # will load and concat in a single process in another script
    ds.save_to_disk(f"embedded/{ds_path}_embedded_rank_{local_rank}")


def main():
    dist.init_process_group("nccl")
    parser = ArgumentParser()
    parser.add_argument("--ds_path", type=str, default="tokenized")
    parser.add_argument("--batch_size", type=int, default=1)

    args = parser.parse_args()

    embed_texts(args.ds_path, args.batch_size)


if __name__ == "__main__":
    # parse arguments by reading in a config
    main()
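
A note on launching: main() initializes the process group with the NCCL backend and never sets ranks explicitly, so the script presumably expects a distributed launcher such as torchrun to supply the usual RANK/WORLD_SIZE/MASTER_ADDR environment variables, e.g. torchrun --nproc_per_node=<num_gpus> gpt4all/index/embed_texts.py --ds_path tokenized --batch_size 64. Note also that dist.get_rank() is used directly as the CUDA device index, which assumes a single-node run; each rank then writes its own output shard under embedded/.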
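
The trailing comment in embed_texts defers merging to a separate single-process script. A minimal sketch of what that merge step could look like, assuming the embedded/{ds_path}_embedded_rank_{rank} layout written above and a known world size (merge_shards and its arguments are illustrative, not part of this commit):

from datasets import Dataset, concatenate_datasets


def merge_shards(ds_path, world_size):
    # Load the shard each rank saved with save_to_disk.
    shards = [
        Dataset.load_from_disk(f"embedded/{ds_path}_embedded_rank_{rank}")
        for rank in range(world_size)
    ]
    merged = concatenate_datasets(shards)
    # DistributedSampler pads the index list so every rank processes the same
    # number of samples, meaning a few rows may appear twice; the "id" column
    # carried through PadCollateInputs allows de-duplicating afterwards.
    return merged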