refactor: clean up embed texts

Zach Nussbaum 2023-04-25 20:34:49 +00:00
parent c20379f7e9
commit b0f92b610e


@@ -1,3 +1,4 @@
+import os
 import torch.distributed as dist
 from argparse import ArgumentParser
 from datasets import Dataset
@@ -12,6 +13,8 @@ import torch
 from datasets import load_dataset
+# this isn't used but keeping in case we need it in the future
+# collate and pad inputs to the right shape
 class PadCollateInputs:
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
@@ -31,39 +34,44 @@ class PadCollateInputs
 def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False, split='train'):
-    rank0_print(f"World size: {dist.get_world_size()}")
-    dataset = load_dataset(f"{ds_path}", split=split)
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    rank0_print(f"World size: {world_size}")
+    if os.path.exists(ds_path):
+        dataset = Dataset.load_from_disk(ds_path)
+    else:
+        dataset = load_dataset(ds_path, split=split)
     rank0_print(f"Dataset size: {len(dataset)}")
     model = Embedder()
-    dataset = dataset.map(lambda x: model.tokenize(x[embed_on]), batched=True, num_proc=64)
+    if "input_ids" not in dataset.column_names:
+        dataset = dataset.map(lambda x: model.tokenize(x[embed_on]), batched=True, num_proc=64)
-    columns_to_keep = ["input_ids", "attention_mask"]
-    #to_remove = [e for e in dataset.column_names if not e in columns_to_keep]
-    print('cols: ', dataset.column_names)
-    #dataset = dataset.remove_columns(to_remove)
-    #dataset = Dataset.load_from_disk(ds_path)
-    #dataset = dataset.remove_columns(["url", "title", "text"])
+    columns_to_keep = ["id", "input_ids", "attention_mask"]
+    to_remove = [e for e in dataset.column_names if not e in columns_to_keep]
+    dataset = dataset.remove_columns(to_remove)
     dataset = dataset.with_format("torch")
-    num_processes = dist.get_world_size()
-    local_rank = dist.get_rank()
+    num_processes = dist.get_world_size() if dist.is_initialized() else 1
+    local_rank = dist.get_rank() if dist.is_initialized() else 0
     #collator = PadCollateInputs(model.tokenizer)
-    if num_processes > 1:
-        sampler = DistributedSampler(
-            dataset,
-            shuffle=False,
-            drop_last=False,
-            num_replicas=num_processes,
-            rank=local_rank,
-        )
-    else:
-        sampler = None
+    sampler = DistributedSampler(
+        dataset,
+        shuffle=False,
+        drop_last=False,
+        num_replicas=num_processes,
+        rank=local_rank,
+    )
     dataloader = DataLoader(
         dataset,
         # collate_fn=collator,
         batch_size=batch_size,
         sampler=sampler,
         drop_last=False,
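
Note on the sampler change in the hunk above: DistributedSampler only queries the torch.distributed process group when num_replicas or rank are omitted, and with num_replicas=1, rank=0 it simply yields every index in order. That is why the old if/else fallback to sampler=None can be dropped once the guarded values are passed in explicitly. A minimal, self-contained sketch of that behavior (the toy list stands in for the real tokenized dataset):

    import torch.distributed as dist
    from torch.utils.data import DataLoader, DistributedSampler

    dataset = list(range(10))  # stand-in for the tokenized dataset

    # Explicit num_replicas/rank mean the sampler never calls into the
    # (possibly uninitialized) process group.
    num_processes = dist.get_world_size() if dist.is_initialized() else 1
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    sampler = DistributedSampler(
        dataset,
        shuffle=False,
        drop_last=False,
        num_replicas=num_processes,
        rank=local_rank,
    )
    dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
    print(list(sampler))  # single process: [0, 1, ..., 9]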
@@ -100,7 +108,7 @@ def main():
     args = parser.parse_args()
-    embed_texts(args.ds_path, args.batch_size)
+    embed_texts(args.ds_path, args.batch_size, save_to_disk=True)
 if __name__ == "__main__":
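
For completeness: the dist.is_initialized() guards added throughout embed_texts are what let the same script run both under a multi-process launcher like torchrun and as a plain single python process. A self-contained sketch of that pattern (the init_process_group call and the gloo backend are illustrative assumptions, not code from this commit):

    import os
    import torch.distributed as dist

    # torchrun sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT; a bare `python`
    # invocation sets none of them, so the process group is never created.
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        dist.init_process_group(backend="gloo")

    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    print(f"rank {rank} of {world_size}")

    if dist.is_initialized():
        dist.destroy_process_group()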