Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-07-04 02:58:04 +00:00)
fix: embeddings instead of logits!!!
commit 1c6d2d9622
parent 4b51e6ef37
inference.py (13 lines changed)
@@ -115,7 +115,6 @@ def inference(config):
         train_outputs["embeddings"] = np.concatenate(train_outputs["embeddings"])
 
         df_train = Dataset.from_dict(train_outputs)
-        df_train = df_train.sort("index")
         curr_idx = df_train["index"]
 
         # compute mask in pyarrow since it's super fast
@@ -136,11 +135,11 @@ def inference(config):
         for batch in tqdm(val_dataloader, disable=local_rank != 0):
             batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
             batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
-            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
+            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
             loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
             val_outputs["loss"].extend(loss)
 
-            logits = outputs.logits
+            embeddings = outputs.hidden_states[-1]
             batch_size = batch["input_ids"].shape[0]
             sequence_lengths = []
             # since we use mutiturn with multiple <|endoftext|>, we need to find the place where
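The hunk above is the heart of the fix: instead of reading vocabulary-space logits off the model output, the model is called with output_hidden_states=True and the last transformer layer, outputs.hidden_states[-1], is kept as the per-token embedding tensor that gets pooled in the next hunk. A minimal, self-contained sketch of the shape difference, assuming a Hugging Face causal LM; the "gpt2" checkpoint below is only a stand-in, not the model this repo actually loads:

# Not part of the commit: illustrates logits vs. last hidden state shapes.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # illustrative checkpoint only
model = AutoModelForCausalLM.from_pretrained("gpt2")

batch = tokenizer(["hello world"], return_tensors="pt")
with torch.no_grad():
    outputs = model(**batch, output_hidden_states=True)

print(outputs.logits.shape)             # [batch, seq_len, vocab_size] -- vocabulary space
print(outputs.hidden_states[-1].shape)  # [batch, seq_len, hidden_dim] -- what gets pooled as embeddings

The logits row for each token is as wide as the output vocabulary and tied to next-token prediction; the last hidden state is the model-space representation that is the usual choice for sequence embeddings, which is exactly what the commit message calls out.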
@@ -149,17 +148,17 @@ def inference(config):
                 indices = torch.where(item == tokenizer.pad_token_id)[0]
                 found = False
                 for index in indices:
+                    # case where sequence is less than max length
                     if torch.all(item[index:] == tokenizer.pad_token_id):
                         sequence_lengths.append(index)
                         found = True
                         break
-
-                # no match found
+                # case where sequence is >= max length
                 if not found:
                     sequence_lengths.append(len(item) - 1)
 
             sequence_lengths = torch.tensor(sequence_lengths)
-            pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+            pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
 
             val_outputs["embeddings"].append(pooled_logits)
             val_outputs["index"].extend(batch["index"].to(model.device))
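The pooling logic itself is unchanged; only its input switches from logits to embeddings. Because the multi-turn samples contain several <|endoftext|> tokens (which share an id with the pad token here), the first pad hit is not necessarily the end of the sequence: the loop keeps the first pad position after which the whole remainder of the row is padding, and falls back to the last position when the row fills the entire context. A self-contained toy version of that selection plus the gather, with a made-up pad_token_id and random hidden states standing in for outputs.hidden_states[-1]:

# Not part of the commit: toy reproduction of the last-non-pad pooling above.
import torch

pad_token_id = 0                      # made up for the example
input_ids = torch.tensor([
    [5, 6, 0, 7, 8, 0, 0, 0],         # multi-turn: a pad mid-sequence, then trailing pads
    [5, 6, 7, 8, 9, 1, 2, 3],         # fills the whole context, no pads at all
])
hidden = torch.randn(2, 8, 4)         # stand-in for outputs.hidden_states[-1]

sequence_lengths = []
for item in input_ids:
    indices = torch.where(item == pad_token_id)[0]
    found = False
    for index in indices:
        # keep only the pad position after which everything is padding
        if torch.all(item[index:] == pad_token_id):
            sequence_lengths.append(index)
            found = True
            break
    if not found:                     # sequence uses the full context window
        sequence_lengths.append(len(item) - 1)

sequence_lengths = torch.tensor(sequence_lengths)
pooled = hidden[torch.arange(input_ids.shape[0]), sequence_lengths]
print(sequence_lengths)               # tensor([5, 7])
print(pooled.shape)                   # torch.Size([2, 4])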
@@ -172,7 +171,6 @@ def inference(config):
         val_outputs["embeddings"] = np.concatenate(val_outputs["embeddings"])
 
         df_val = Dataset.from_dict(val_outputs)
-        df_val = df_val.sort("index")
         curr_idx = df_val["index"]
 
         # compute mask in pyarrow since it's super fast
@@ -182,7 +180,6 @@ def inference(config):
         filtered_table = table.filter(mask)
         # convert from pyarrow to Dataset
         filtered_val = Dataset.from_dict(filtered_table.to_pydict())
-
         filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
         filtered_val = filtered_val.add_column("loss", df_val["loss"])
         filtered_val = filtered_val.add_column("is_train", [False] * len(filtered_val))
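The context lines in the last two hunks show where the pooled embeddings end up: the validation set is filtered with a boolean mask computed in pyarrow ("super fast", per the comment), converted back to a datasets.Dataset, and the embeddings, losses, and an is_train flag are attached as columns. A rough sketch of that filter-then-add_column pattern, assuming pyarrow and datasets; the table contents, the use of pyarrow.compute.is_in for the mask, and the placeholder embedding values are illustrative, not the repo's:

# Not part of the commit: toy version of the pyarrow filter + add_column pattern.
import pyarrow as pa
import pyarrow.compute as pc
from datasets import Dataset

table = pa.table({"index": [0, 1, 2, 3], "text": ["a", "b", "c", "d"]})
curr_idx = [1, 3]                                            # rows we computed embeddings for
mask = pc.is_in(table["index"], value_set=pa.array(curr_idx))
filtered_table = table.filter(mask)                          # keep only those rows

# back to a datasets.Dataset, then attach new columns row-aligned with the filter
filtered_val = Dataset.from_dict(filtered_table.to_pydict())
filtered_val = filtered_val.add_column("embeddings", [[0.0] * 4, [1.0] * 4])
filtered_val = filtered_val.add_column("is_train", [False] * len(filtered_val))
print(filtered_val)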