Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-07-06 12:06:54 +00:00
feat: sbert abstractor
This commit is contained in: parent 4671f4e82f, commit 2dae153c68
gpt4all/index/embed.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F


class Embedder:
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.embedder = AutoModel.from_pretrained(model_name)
        # hack: step half the context window at a time so that long texts are
        # split into overlapping chunks in chunk_text
        self.offset = self.tokenizer.model_max_length // 2

    def _mean_pool(self, model_output, attention_mask):
        # First element of model_output contains all token embeddings
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        # Average the token embeddings, ignoring padding positions
        sentence_embeddings = torch.sum(
            token_embeddings * input_mask_expanded, 1
        ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return F.normalize(sentence_embeddings, p=2, dim=1)

    def chunk_text(self, text):
        # Split texts longer than the model's context window into overlapping
        # windows of at most model_max_length tokens, stepping by self.offset
        tokenized_text = {"input_ids": [], "attention_mask": []}
        tokenized = self.tokenizer(text)
        tokenized_len = len(tokenized["input_ids"])
        max_len = self.tokenizer.model_max_length
        if tokenized_len > max_len:
            start = 0
            while start < tokenized_len:
                tokenized_text["input_ids"].append(
                    tokenized["input_ids"][start : start + max_len]
                )
                tokenized_text["attention_mask"].append(
                    tokenized["attention_mask"][start : start + max_len]
                )
                # this could probably be done better
                start += self.offset
        else:
            tokenized_text["input_ids"].append(tokenized["input_ids"])
            tokenized_text["attention_mask"].append(tokenized["attention_mask"])

        return tokenized_text

    def __call__(self, batch):
        if isinstance(batch, dict):
            # Pre-tokenized, padded batch carrying an "id" field
            outputs = self.embedder(
                input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
            )
            embedding = self._mean_pool(outputs, batch["attention_mask"])
            return {"id": batch["id"], "embedding": embedding}
        elif isinstance(batch, str):
            # Single string: tokenize (with truncation) and embed directly
            tokenized = self.tokenizer(batch, return_tensors="pt", truncation=True)
            return self._mean_pool(
                self.embedder(
                    input_ids=tokenized["input_ids"],
                    attention_mask=tokenized["attention_mask"],
                ),
                tokenized["attention_mask"],
            )

    def to(self, device):
        self.embedder.to(device)
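
For context, a minimal usage sketch of the Embedder added above. The calling code is not part of this commit, so the import path, example strings, chunk ids, and the padding step are illustrative assumptions; the dict branch of __call__ expects an already padded, pre-tokenized batch with an "id" field, so the sketch pads the output of chunk_text with the tokenizer before embedding it.

import torch

from gpt4all.index.embed import Embedder  # assumed import path for the module added here

embedder = Embedder()  # defaults to sentence-transformers/all-MiniLM-L6-v2

# Embed a single string (truncated to the model's context window).
with torch.no_grad():
    vec = embedder("GPT4All is an ecosystem of open-source chatbots.")
print(vec.shape)  # torch.Size([1, 384]) for MiniLM-L6-v2

# Embed a long document as overlapping chunks.
chunks = embedder.chunk_text("a very long document ... " * 500)
padded = embedder.tokenizer.pad(chunks, return_tensors="pt")  # pad chunks to equal length
batch = {
    "id": torch.arange(padded["input_ids"].shape[0]),  # hypothetical per-chunk ids
    "input_ids": padded["input_ids"],
    "attention_mask": padded["attention_mask"],
}
with torch.no_grad():
    out = embedder(batch)
print(out["embedding"].shape)  # (num_chunks, 384), one vector per chunk

Note that to(device) only moves the model, not any inputs, so a caller that puts the embedder on a GPU also needs to move the batch tensors (or the tokenizer output) to the same device before calling it.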