Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-08-15 22:53:22 +00:00)
refactor: clean up embed texts

commit b0f92b610e
parent c20379f7e9
@@ -1,3 +1,4 @@
+import os
 import torch.distributed as dist
 from argparse import ArgumentParser
 from datasets import Dataset
@@ -12,6 +13,8 @@ import torch
 from datasets import load_dataset
 
+
+# this isn't used but keeping in case we need it in the future
 # collate and pad inputs to the right shape
 class PadCollateInputs:
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
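Side note on PadCollateInputs: the new comment marks it as unused but kept around for later. A collator of this kind usually just defers to the tokenizer's own padding; the following is a minimal sketch assuming a Hugging Face-style tokenizer, not the implementation from this repo:

class PadCollatorSketch:
    # Pads a list of variable-length tokenized examples into one equal-length batch.
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        # batch: list of dicts with "input_ids" / "attention_mask" of differing lengths;
        # tokenizer.pad right-pads to the longest sequence and stacks into PyTorch tensors.
        return self.tokenizer.pad(batch, padding="longest", return_tensors="pt")

Passed as collate_fn to a DataLoader, a collator like this replaces the default stacking, which fails on ragged sequences.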
@@ -31,39 +34,44 @@
 def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False, split='train'):
-    rank0_print(f"World size: {dist.get_world_size()}")
-    dataset = load_dataset(f"{ds_path}", split=split)
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    rank0_print(f"World size: {world_size}")
+
+    if os.path.exists(ds_path):
+        dataset = Dataset.load_from_disk(ds_path)
+    else:
+        dataset = load_dataset(ds_path, split=split)
+    rank0_print(f"Dataset size: {len(dataset)}")
 
     model = Embedder()
 
-    dataset = dataset.map(lambda x: model.tokenize(x[embed_on]), batched=True, num_proc=64)
+    if "input_ids" not in dataset.column_names:
+        dataset = dataset.map(lambda x: model.tokenize(x[embed_on]), batched=True, num_proc=64)
 
-    columns_to_keep = ["input_ids", "attention_mask"]
-    #to_remove = [e for e in dataset.column_names if not e in columns_to_keep]
-    print('cols: ', dataset.column_names)
-    #dataset = dataset.remove_columns(to_remove)
-
-    #dataset = Dataset.load_from_disk(ds_path)
-    #dataset = dataset.remove_columns(["url", "title", "text"])
+    columns_to_keep = ["id", "input_ids", "attention_mask"]
     to_remove = [e for e in dataset.column_names if not e in columns_to_keep]
     dataset = dataset.remove_columns(to_remove)
 
     dataset = dataset.with_format("torch")
 
-    num_processes = dist.get_world_size()
-    local_rank = dist.get_rank()
+    num_processes = dist.get_world_size() if dist.is_initialized() else 1
+    local_rank = dist.get_rank() if dist.is_initialized() else 0
 
 
     #collator = PadCollateInputs(model.tokenizer)
+    if num_processes > 1:
+        sampler = DistributedSampler(
+            dataset,
+            shuffle=False,
+            drop_last=False,
+            num_replicas=num_processes,
+            rank=local_rank,
+        )
+    else:
+        sampler = None
 
-    sampler = DistributedSampler(
-        dataset,
-        shuffle=False,
-        drop_last=False,
-        num_replicas=num_processes,
-        rank=local_rank,
-    )
     dataloader = DataLoader(
         dataset,
         # collate_fn=collator,
         batch_size=batch_size,
         sampler=sampler,
         drop_last=False,
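A central piece of this cleanup is the dist.is_initialized() guard: the same embed_texts path now works both under a distributed launcher and as an ordinary single-process run, and the DistributedSampler is only built when there is more than one worker. The guard pattern in isolation (hypothetical helper, same torch.distributed calls as above):

import torch.distributed as dist

def rank_and_world_size():
    # Without init_process_group (plain `python script.py`), no process group exists,
    # so report a single worker instead of raising.
    if dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1

rank, world_size = rank_and_world_size()
print(f"worker {rank} of {world_size}")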
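The dataset-preparation changes (tokenize only when "input_ids" is missing, then keep just "id"/"input_ids"/"attention_mask") follow the standard datasets pattern of map plus remove_columns. A self-contained sketch of that pattern with an assumed tokenizer and toy data, separate from this file's Embedder:

from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed model, for illustration
ds = Dataset.from_dict({"id": [0, 1], "text": ["hello world", "embed me"]})

# Tokenize once; a dataset reloaded from disk may already carry input_ids.
if "input_ids" not in ds.column_names:
    ds = ds.map(lambda batch: tokenizer(batch["text"], truncation=True), batched=True)

# Drop everything the embedding DataLoader does not need.
keep = ["id", "input_ids", "attention_mask"]
ds = ds.remove_columns([c for c in ds.column_names if c not in keep])
ds = ds.with_format("torch")
print(ds.column_names)  # ['id', 'input_ids', 'attention_mask']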
@@ -100,7 +108,7 @@ def main():
 
     args = parser.parse_args()
 
-    embed_texts(args.ds_path, args.batch_size)
+    embed_texts(args.ds_path, args.batch_size, save_to_disk=True)
 
 
 if __name__ == "__main__":
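main() now asks embed_texts to persist its output (save_to_disk=True). Assuming the parser exposes --ds_path and --batch_size flags and the module is named embed_texts.py (neither detail is visible in this hunk), launching it would look roughly like:

python embed_texts.py --ds_path my_dataset --batch_size 32
torchrun --nproc_per_node=4 embed_texts.py --ds_path my_dataset --batch_size 32

In the first case the is_initialized() guards fall back to a single worker with no sampler; in the second, provided the rest of the script initializes torch.distributed (not shown in this hunk), DistributedSampler gives each rank its own slice of the dataset.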