diff --git a/gpt4all/index/embed_texts.py b/gpt4all/index/embed_texts.py
index c2007fde..9168a6c8 100644
--- a/gpt4all/index/embed_texts.py
+++ b/gpt4all/index/embed_texts.py
@@ -1,3 +1,4 @@
+import os
 import torch.distributed as dist
 from argparse import ArgumentParser
 from datasets import Dataset
@@ -12,6 +13,8 @@
 import torch
 from datasets import load_dataset

+# this isn't used but keeping in case we need it in the future
+# collate and pad inputs to the right shape
 class PadCollateInputs:
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
@@ -31,39 +34,44 @@ class PadCollateInputs:


 def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False, split='train'):
-    rank0_print(f"World size: {dist.get_world_size()}")
-    dataset = load_dataset(f"{ds_path}", split=split)
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    rank0_print(f"World size: {world_size}")
+
+    if os.path.exists(ds_path):
+        dataset = Dataset.load_from_disk(ds_path)
+    else:
+        dataset = load_dataset(ds_path, split=split)
     rank0_print(f"Dataset size: {len(dataset)}")

     model = Embedder()
-    dataset = dataset.map(lambda x: model.tokenize(x[embed_on]), batched=True, num_proc=64)
+    if "input_ids" not in dataset.column_names:
+        dataset = dataset.map(lambda x: model.tokenize(x[embed_on]), batched=True, num_proc=64)

-    columns_to_keep = ["input_ids", "attention_mask"]
-    #to_remove = [e for e in dataset.column_names if not e in columns_to_keep]
-    print('cols: ', dataset.column_names)
-    #dataset = dataset.remove_columns(to_remove)
-    #dataset = Dataset.load_from_disk(ds_path)
-    #dataset = dataset.remove_columns(["url", "title", "text"])
+    columns_to_keep = ["id", "input_ids", "attention_mask"]
+    to_remove = [e for e in dataset.column_names if not e in columns_to_keep]
+    dataset = dataset.remove_columns(to_remove)
+
     dataset = dataset.with_format("torch")

-    num_processes = dist.get_world_size()
-    local_rank = dist.get_rank()
+    num_processes = dist.get_world_size() if dist.is_initialized() else 1
+    local_rank = dist.get_rank() if dist.is_initialized() else 0

-    #collator = PadCollateInputs(model.tokenizer)
+    if num_processes > 1:
+        sampler = DistributedSampler(
+            dataset,
+            shuffle=False,
+            drop_last=False,
+            num_replicas=num_processes,
+            rank=local_rank,
+        )
+    else:
+        sampler = None

-    sampler = DistributedSampler(
-        dataset,
-        shuffle=False,
-        drop_last=False,
-        num_replicas=num_processes,
-        rank=local_rank,
-    )
     dataloader = DataLoader(
         dataset,
-        # collate_fn=collator,
         batch_size=batch_size,
         sampler=sampler,
         drop_last=False,
@@ -100,7 +108,7 @@ def main():

     args = parser.parse_args()

-    embed_texts(args.ds_path, args.batch_size)
+    embed_texts(args.ds_path, args.batch_size, save_to_disk=True)


 if __name__ == "__main__":
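
For reviewers, a minimal usage sketch of the updated embed_texts entry point (not part of the patch); the dataset id, local path, and batch size below are placeholder assumptions, not values from this change. It exercises the two branches the patch adds: falling back to load_dataset plus single-process defaults when torch.distributed is not initialized, and loading a directory written with Dataset.save_to_disk when ds_path exists locally.

    # Hypothetical usage sketch; "wikipedia", the local path, and batch_size=128 are placeholders.
    from gpt4all.index.embed_texts import embed_texts

    # Hub dataset id: os.path.exists(ds_path) is False, so load_dataset(ds_path, split=split) runs;
    # with torch.distributed uninitialized, world size falls back to 1 and the DataLoader gets sampler=None.
    embed_texts("wikipedia", batch_size=128, embed_on="text", split="train")

    # Local directory previously saved with Dataset.save_to_disk: the new os.path.exists branch loads it
    # via Dataset.load_from_disk, and the "input_ids" check skips re-tokenization if it is already tokenized.
    embed_texts("/path/to/tokenized_dataset", batch_size=128, save_to_disk=True)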