diff --git a/inference.py b/inference.py
index 25e3b508..8a4efb51 100644
--- a/inference.py
+++ b/inference.py
@@ -115,7 +115,6 @@ def inference(config):
     train_outputs["embeddings"] = np.concatenate(train_outputs["embeddings"])
 
     df_train = Dataset.from_dict(train_outputs)
-    df_train = df_train.sort("index")
     curr_idx = df_train["index"]
 
     # compute mask in pyarrow since it's super fast
@@ -136,11 +135,11 @@ def inference(config):
     for batch in tqdm(val_dataloader, disable=local_rank != 0):
         batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
         batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
-        outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
+        outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
         loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
         val_outputs["loss"].extend(loss)
 
-        logits = outputs.logits
+        embeddings = outputs.hidden_states[-1]
         batch_size = batch["input_ids"].shape[0]
         sequence_lengths = []
         # since we use multiturn with multiple <|endoftext|>, we need to find the place where
@@ -149,17 +148,17 @@ def inference(config):
             indices = torch.where(item == tokenizer.pad_token_id)[0]
             found = False
             for index in indices:
+                # case where sequence is less than max length
                 if torch.all(item[index:] == tokenizer.pad_token_id):
                     sequence_lengths.append(index)
                     found = True
                     break
-
-            # no match found
+            # case where sequence is >= max length
            if not found:
                sequence_lengths.append(len(item) - 1)
 
         sequence_lengths = torch.tensor(sequence_lengths)
-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
 
         val_outputs["embeddings"].append(pooled_logits)
         val_outputs["index"].extend(batch["index"].to(model.device))
@@ -172,7 +171,6 @@ def inference(config):
     val_outputs["embeddings"] = np.concatenate(val_outputs["embeddings"])
 
     df_val = Dataset.from_dict(val_outputs)
-    df_val = df_val.sort("index")
     curr_idx = df_val["index"]
 
     # compute mask in pyarrow since it's super fast
@@ -182,7 +180,6 @@ def inference(config):
     filtered_table = table.filter(mask)
     # convert from pyarrow to Dataset
     filtered_val = Dataset.from_dict(filtered_table.to_pydict())
-    filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
     filtered_val = filtered_val.add_column("loss", df_val["loss"])
     filtered_val = filtered_val.add_column("is_train", [False] * len(filtered_val))
 
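
For reference, here is a minimal, self-contained sketch of the last-token pooling strategy these hunks switch to: the model is asked for `output_hidden_states=True`, and for each sequence the final-layer hidden state is gathered at the end of that sequence rather than using the logits. This is an illustration only, not code from this PR — the helper name `last_token_pool` and the toy tensors are made up, and `pad_token_id` stands in for the tokenizer's `<|endoftext|>` id.

```python
import torch

def last_token_pool(hidden_states: torch.Tensor,
                    input_ids: torch.Tensor,
                    pad_token_id: int) -> torch.Tensor:
    """Pool one embedding per sequence from (batch, seq_len, dim) hidden states.

    Picks the hidden state at the start of the trailing pad run (i.e. the
    final <|endoftext|>), or the last position if the sequence fills the
    whole context window.
    """
    batch_size, seq_len, _ = hidden_states.shape
    sequence_lengths = []
    for item in input_ids:
        indices = torch.where(item == pad_token_id)[0]
        found = False
        for index in indices:
            # first pad position after which *everything* is padding;
            # mid-sequence <|endoftext|> tokens (multiturn) are skipped
            if torch.all(item[index:] == pad_token_id):
                sequence_lengths.append(int(index))
                found = True
                break
        if not found:
            # sequence is >= max length: no trailing pad run, use last position
            sequence_lengths.append(seq_len - 1)
    sequence_lengths = torch.tensor(sequence_lengths, device=hidden_states.device)
    return hidden_states[torch.arange(batch_size, device=hidden_states.device),
                         sequence_lengths]

# toy check with pad_token_id=0: seq 0 ends early, seq 1 fills the window,
# seq 2 contains a mid-sequence "pad" token that must be skipped
ids = torch.tensor([[5, 6, 0, 0],
                    [7, 8, 9, 1],
                    [5, 0, 6, 0]])
h = torch.randn(3, 4, 8)
pooled = last_token_pool(h, ids, pad_token_id=0)
assert torch.equal(pooled[0], h[0, 2])  # start of trailing pad run
assert torch.equal(pooled[1], h[1, 3])  # full-length: last position
assert torch.equal(pooled[2], h[2, 3])  # skips the internal pad at index 1
```

This also shows the design choice behind the reworded comments in the diff: because multiturn samples contain several `<|endoftext|>` tokens, the code cannot take the first pad occurrence; it looks for the first pad after which only padding follows, and falls back to the final position when no trailing pad run exists.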