From 4b51e6ef376871a5e59c9c80903b8e08a0ec7256 Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Fri, 7 Apr 2023 19:04:19 +0000
Subject: [PATCH] fix: pyarrow filter

---
 inference.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/inference.py b/inference.py
index 0fa73025..25e3b508 100644
--- a/inference.py
+++ b/inference.py
@@ -12,6 +12,8 @@ from transformers.trainer_pt_utils import nested_numpify
 from transformers import DefaultDataCollator
 from torch.utils.data import DataLoader, DistributedSampler
 import numpy as np
+import pyarrow as pa
+from pyarrow import compute as pc
 
 
 def calc_cross_entropy_no_reduction(lm_logits, labels):
@@ -116,7 +118,13 @@ def inference(config):
         df_train = df_train.sort("index")
         curr_idx = df_train["index"]
 
-        filtered_train = train_dataset.filter(lambda example: example["index"] in curr_idx)
+        # compute mask in pyarrow since it's super fast
+        # ty @bmschmidt for showing me this!
+        table = train_dataset.data
+        mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
+        filtered_table = table.filter(mask)
+        # convert from pyarrow to Dataset
+        filtered_train = Dataset.from_dict(filtered_table.to_pydict())
         filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
         filtered_train = filtered_train.add_column("loss", df_train["loss"])
 
@@ -167,7 +175,13 @@ def inference(config):
         df_val = df_val.sort("index")
         curr_idx = df_val["index"]
 
-        filtered_val = val_dataset.filter(lambda example: example["index"] in curr_idx)
+        # compute mask in pyarrow since it's super fast
+        # ty @bmschmidt for showing me this!
+        table = val_dataset.data
+        mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
+        filtered_table = table.filter(mask)
+        # convert from pyarrow to Dataset
+        filtered_val = Dataset.from_dict(filtered_table.to_pydict())
         filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
         filtered_val = filtered_val.add_column("loss", df_val["loss"])
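
Below is a minimal, self-contained sketch of the Arrow-side filtering this patch introduces. The toy table, the `keep` index list, and the printed output are illustrative stand-ins only, not taken from inference.py.

import pyarrow as pa
from pyarrow import compute as pc
from datasets import Dataset

# Toy stand-in for the Arrow table backing train_dataset (illustrative only).
table = pa.table({
    "index": pa.array([0, 1, 2, 3, 4], pa.int32()),
    "text": ["a", "b", "c", "d", "e"],
})

# Stand-in for curr_idx: the indices we want to keep.
keep = [1, 3]

# Build a boolean mask with Arrow's vectorized is_in kernel and filter the
# table in one shot, instead of running a per-example Python lambda.
mask = pc.is_in(table["index"], value_set=pa.array(keep, pa.int32()))
filtered = table.filter(mask)

# Convert back to a Hugging Face Dataset, as the patch does.
filtered_ds = Dataset.from_dict(filtered.to_pydict())
print(filtered_ds["index"])  # [1, 3]

The speedup over Dataset.filter comes from doing the membership test as a single columnar kernel in Arrow rather than invoking a Python callable once per row.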