fix: pyarrow filter

This commit is contained in:
Zach Nussbaum 2023-04-07 19:04:19 +00:00
parent f974ca651c
commit 1b5e660476

View File

@ -12,6 +12,8 @@ from transformers.trainer_pt_utils import nested_numpify
from transformers import DefaultDataCollator from transformers import DefaultDataCollator
from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data import DataLoader, DistributedSampler
import numpy as np import numpy as np
import pyarrow as pa
from pyarrow import compute as pc
def calc_cross_entropy_no_reduction(lm_logits, labels): def calc_cross_entropy_no_reduction(lm_logits, labels):
@ -116,7 +118,13 @@ def inference(config):
df_train = df_train.sort("index") df_train = df_train.sort("index")
curr_idx = df_train["index"] curr_idx = df_train["index"]
filtered_train = train_dataset.filter(lambda example: example["index"] in curr_idx) # compute mask in pyarrow since it's super fast
# ty @bmschmidt for showing me this!
table = train_dataset.data
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
filtered_train = Dataset.from_dict(filtered_table.to_pydict())
filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"]) filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
filtered_train = filtered_train.add_column("loss", df_train["loss"]) filtered_train = filtered_train.add_column("loss", df_train["loss"])
@ -167,7 +175,13 @@ def inference(config):
df_val = df_val.sort("index") df_val = df_val.sort("index")
curr_idx = df_val["index"] curr_idx = df_val["index"]
filtered_val = val_dataset.filter(lambda example: example["index"] in curr_idx) # compute mask in pyarrow since it's super fast
# ty @bmschmidt for showing me this!
table = val_dataset.data
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
filtered_val = Dataset.from_dict(filtered_table.to_pydict())
filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"]) filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
filtered_val = filtered_val.add_column("loss", df_val["loss"]) filtered_val = filtered_val.add_column("loss", df_val["loss"])