Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-15 15:46:47 +00:00)

Apply patch [skip ci]

commit d0e5924876
parent 2e95b3fa71
@@ -110,6 +110,82 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         populate_by_name=True,
     )
 
+    def _mean_pooling(
+        self, model_output: Any, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Apply mean pooling to model output."""
+        # Extract token embeddings from model output
+        token_embeddings = model_output[0]
+
+        # Expand attention mask for broadcasting
+        input_mask_expanded = (
+            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        )
+
+        # Apply mean pooling
+        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
+        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+        return sum_embeddings / sum_mask
+
+    def encode(
+        self,
+        texts: list[str],
+        batch_size: int = 32,
+        show_progress_bar: Optional[bool] = None,
+        normalize_embeddings: bool = False,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        """Encode texts into embeddings.
+
+        Args:
+            texts: The list of texts to embed.
+            batch_size: Batch size for encoding.
+            show_progress_bar: Whether to show a progress bar.
+            normalize_embeddings: Whether to L2-normalize the embeddings.
+            **kwargs: Additional keyword arguments passed to the tokenizer.
+
+        Returns:
+            Array of embeddings, one row per input text.
+        """
+        if show_progress_bar is None:
+            show_progress_bar = self.show_progress
+
+        all_embeddings = []
+
+        # Process in batches
+        for i in range(0, len(texts), batch_size):
+            batch_texts = texts[i : i + batch_size]
+
+            # Tokenize texts
+            encoded_input = self.tokenizer(
+                batch_texts,
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
+                **kwargs,
+            )
+
+            # Move to device
+            encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
+
+            # Generate embeddings
+            with torch.no_grad():
+                model_output = self.model(**encoded_input)
+
+            # Apply mean pooling
+            embeddings = self._mean_pooling(
+                model_output, encoded_input["attention_mask"]
+            )
+
+            # Normalize if requested
+            if normalize_embeddings:
+                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+            all_embeddings.append(embeddings.cpu().numpy())
+
+        # Concatenate all embeddings
+        return np.vstack(all_embeddings)
+
     def _embed(
         self, texts: list[str], encode_kwargs: dict[str, Any]
     ) -> list[list[float]]:
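For orientation, the `_mean_pooling` added above averages each sequence's token vectors, weighted by the attention mask so padding tokens are ignored. A minimal, self-contained sketch of the same computation on toy values (shapes and numbers are illustrative, not from the commit):

import torch

# One sequence of three tokens with 2-dim embeddings; the third token is padding.
token_embeddings = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
attention_mask = torch.tensor([[1, 1, 0]])

# Same steps as _mean_pooling: expand the mask, sum real tokens, divide by count.
mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
summed = torch.sum(token_embeddings * mask, 1)
count = torch.clamp(mask.sum(1), min=1e-9)
print(summed / count)  # tensor([[2., 3.]]): the padded token does not contribute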
@@ -118,35 +194,19 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
 
         Args:
             texts: The list of texts to embed.
             encode_kwargs: Keyword arguments to pass when calling the
                 `encode` method.
 
         Returns:
             List of embeddings, one for each text.
 
         """
-        import sentence_transformers  # type: ignore[import]
-
         texts = [x.replace("\n", " ") for x in texts]
-        if self.multi_process:
-            pool = self._client.start_multi_process_pool()
-            embeddings = self._client.encode_multi_process(texts, pool)
-            sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
-        else:
-            embeddings = self._client.encode(
-                texts,
-                show_progress_bar=self.show_progress,
-                **encode_kwargs,
-            )
-
-        if isinstance(embeddings, list):
-            msg = (
-                "Expected embeddings to be a Tensor or a numpy array, "
-                "got a list instead."
-            )
-            raise TypeError(msg)
-
-        return embeddings.tolist()  # type: ignore[return-type]
+        embeddings = self.encode(
+            texts,
+            show_progress_bar=self.show_progress,
+            **encode_kwargs,
+        )
+        return embeddings.tolist()
 
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Compute doc embeddings using a HuggingFace transformer model.
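With the pieces above in place, `_embed` now routes through the class's own `encode` instead of the SentenceTransformer client. A hypothetical end-to-end usage sketch; the import path, model name, and constructor arguments are assumptions for illustration, not part of this commit:

from langchain_huggingface import HuggingFaceEmbeddings  # assumed import path

# Assumed model name; normalize_embeddings is forwarded to the new encode().
embedder = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    encode_kwargs={"normalize_embeddings": True},
)

vectors = embedder.embed_documents(["first document", "second document"])
print(len(vectors), len(vectors[0]))  # 2 embeddings, one per document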
@@ -179,3 +239,4 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
 
 
 
+