Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-08-16 15:07:18 +00:00)
refactor: clean up embed texts
commit b0f92b610e
parent c20379f7e9
@@ -1,3 +1,4 @@
+import os
 import torch.distributed as dist
 from argparse import ArgumentParser
 from datasets import Dataset
@@ -12,6 +13,8 @@ import torch
 from datasets import load_dataset
 
 
+# this isn't used but keeping in case we need it in the future
+# collate and pad inputs to the right shape
 class PadCollateInputs:
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
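Note: PadCollateInputs stays in the file but is now marked as unused. For context, a pad-collate helper of this kind typically right-pads every example in a batch to the longest sequence before stacking tensors. The sketch below is a hypothetical illustration built on the Hugging Face tokenizer.pad API, not the repository's implementation; the class name PadCollate is invented.

from transformers import AutoTokenizer

class PadCollate:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        # batch is a list of dicts holding "input_ids" / "attention_mask";
        # tokenizer.pad pads them to the longest sequence and returns tensors.
        features = [{"input_ids": ex["input_ids"], "attention_mask": ex["attention_mask"]}
                    for ex in batch]
        return self.tokenizer.pad(features, padding=True, return_tensors="pt")

# usage sketch (model name is only an example):
# collator = PadCollate(AutoTokenizer.from_pretrained("bert-base-uncased"))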
@@ -31,39 +34,44 @@ class PadCollateInputs:
 
 
 def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False, split='train'):
-    rank0_print(f"World size: {dist.get_world_size()}")
-    dataset = load_dataset(f"{ds_path}", split=split)
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    rank0_print(f"World size: {world_size}")
+
+    if os.path.exists(ds_path):
+        dataset = Dataset.load_from_disk(ds_path)
+    else:
+        dataset = load_dataset(ds_path, split=split)
     rank0_print(f"Dataset size: {len(dataset)}")
 
     model = Embedder()
 
-    dataset = dataset.map(lambda x: model.tokenize(x[embed_on]), batched=True, num_proc=64)
+    if "input_ids" not in dataset.column_names:
+        dataset = dataset.map(lambda x: model.tokenize(x[embed_on]), batched=True, num_proc=64)
 
-    columns_to_keep = ["input_ids", "attention_mask"]
-    #to_remove = [e for e in dataset.column_names if not e in columns_to_keep]
-    print('cols: ', dataset.column_names)
-    #dataset = dataset.remove_columns(to_remove)
 
-    #dataset = Dataset.load_from_disk(ds_path)
-    #dataset = dataset.remove_columns(["url", "title", "text"])
+    columns_to_keep = ["id", "input_ids", "attention_mask"]
+    to_remove = [e for e in dataset.column_names if not e in columns_to_keep]
+    dataset = dataset.remove_columns(to_remove)
 
     dataset = dataset.with_format("torch")
 
-    num_processes = dist.get_world_size()
-    local_rank = dist.get_rank()
+    num_processes = dist.get_world_size() if dist.is_initialized() else 1
+    local_rank = dist.get_rank() if dist.is_initialized() else 0
 
 
-    #collator = PadCollateInputs(model.tokenizer)
-    sampler = DistributedSampler(
-        dataset,
-        shuffle=False,
-        drop_last=False,
-        num_replicas=num_processes,
-        rank=local_rank,
-    )
+    if num_processes > 1:
+        sampler = DistributedSampler(
+            dataset,
+            shuffle=False,
+            drop_last=False,
+            num_replicas=num_processes,
+            rank=local_rank,
+        )
+    else:
+        sampler = None
+
     dataloader = DataLoader(
         dataset,
-        # collate_fn=collator,
         batch_size=batch_size,
         sampler=sampler,
         drop_last=False,
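Note: the core of this hunk is that every torch.distributed call is now guarded by dist.is_initialized(), and a DistributedSampler is only constructed when more than one process is running, so the same code path works for a plain single-process run. A minimal, self-contained sketch of that pattern (the function name build_loader is invented for illustration, not the repository's code):

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

def build_loader(dataset, batch_size):
    # Fall back to one process / rank 0 when torch.distributed was never initialized.
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    if world_size > 1:
        # Shard the dataset across ranks; keep order and the tail batch.
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                     shuffle=False, drop_last=False)
    else:
        sampler = None
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, drop_last=False)

# usage sketch: the same call works under a distributed launcher (sharded)
# and in a plain `python` run (sampler=None).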
@@ -100,7 +108,7 @@ def main():
 
     args = parser.parse_args()
 
-    embed_texts(args.ds_path, args.batch_size)
+    embed_texts(args.ds_path, args.batch_size, save_to_disk=True)
 
 
 if __name__ == "__main__":
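Note: with the guards introduced above, the script should run both under a distributed launcher such as torchrun and as an ordinary single-process invocation. The sketch below shows one hedged way the call site could be wired up; the helper name, backend choice, and example arguments are assumptions, not code from this commit.

import os
import torch.distributed as dist

def maybe_init_distributed():
    # torchrun exports WORLD_SIZE / RANK / MASTER_ADDR / MASTER_PORT; a plain
    # `python` run does not, so embedding then proceeds single-process.
    if int(os.environ.get("WORLD_SIZE", "1")) > 1:
        dist.init_process_group(backend="gloo")  # "nccl" would be the usual choice on GPUs

# usage sketch (paths and batch size are placeholders):
# maybe_init_distributed()
# embed_texts("path/to/arrow_dir_or_hub_id", batch_size=32, save_to_disk=True)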