mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-05 00:30:18 +00:00
Compare commits
14 Commits
fork/async
...
wfh/embedd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b3d30eaa7c | ||
|
|
5efe913936 | ||
|
|
ee902ba7b2 | ||
|
|
5918c2ffc0 | ||
|
|
097538882d | ||
|
|
619d6c0b14 | ||
|
|
b28610c13a | ||
|
|
f273c99158 | ||
|
|
32ca9dce3e | ||
|
|
d5ad0d2421 | ||
|
|
ef930eda9a | ||
|
|
08cf728e57 | ||
|
|
43d835dda4 | ||
|
|
472b434f02 |
@@ -69,6 +69,35 @@ class LLMManagerMixin:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
|
||||
class EmbeddingsManagerMixin:
|
||||
"""Mixin for Embeddings callbacks."""
|
||||
|
||||
def on_embedding_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Called when an embedding model throws an error."""
|
||||
|
||||
def on_embedding_end(
|
||||
self,
|
||||
vector: List[float],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Called when embeddings model finishes generating embeddings.
|
||||
Args:
|
||||
vector (List[float]): The generated embeddings.
|
||||
Returns:
|
||||
Any: The result of the callback.
|
||||
"""
|
||||
|
||||
|
||||
class ChainManagerMixin:
|
||||
"""Mixin for chain callbacks."""
|
||||
|
||||
@@ -182,6 +211,19 @@ class CallbackManagerMixin:
|
||||
) -> Any:
|
||||
"""Run when Retriever starts running."""
|
||||
|
||||
def on_embedding_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
texts: List[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Embeddings starts running."""
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
@@ -225,6 +267,7 @@ class RunManagerMixin:
|
||||
|
||||
class BaseCallbackHandler(
|
||||
LLMManagerMixin,
|
||||
EmbeddingsManagerMixin,
|
||||
ChainManagerMixin,
|
||||
ToolManagerMixin,
|
||||
RetrieverManagerMixin,
|
||||
@@ -267,6 +310,11 @@ class BaseCallbackHandler(
|
||||
"""Whether to ignore chat model callbacks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_embeddings(self) -> bool:
|
||||
"""Whether to ignore embeddings callbacks."""
|
||||
return False
|
||||
|
||||
|
||||
class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
"""Async callback handler that can be used to handle callbacks from langchain."""
|
||||
@@ -294,7 +342,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
) -> None:
|
||||
"""Run when a chat model starts running."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
|
||||
@@ -333,6 +381,38 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
async def on_embedding_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
texts: List[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
tags: List[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when embeddings call starts running."""
|
||||
|
||||
async def on_embedding_end(
|
||||
self,
|
||||
vector: List[float],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when embeddings call ends running."""
|
||||
|
||||
async def on_embedding_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when embeddings call errors."""
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
|
||||
@@ -31,6 +31,7 @@ from langchain.callbacks.base import (
|
||||
BaseCallbackManager,
|
||||
Callbacks,
|
||||
ChainManagerMixin,
|
||||
EmbeddingsManagerMixin,
|
||||
LLMManagerMixin,
|
||||
RetrieverManagerMixin,
|
||||
RunManagerMixin,
|
||||
@@ -645,6 +646,50 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForEmbeddingsRun(RunManager, EmbeddingsManagerMixin):
|
||||
"""Callback manager for embeddings run."""
|
||||
|
||||
def on_embedding_end(
|
||||
self,
|
||||
vector: List[float],
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when embeddings are generated.
|
||||
Args:
|
||||
embeddings (List[List[float]]): The generated embeddings.
|
||||
Returns:
|
||||
Any: The result of the callback.
|
||||
"""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_embedding_end",
|
||||
"ignore_embeddings",
|
||||
vector,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_embedding_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when embeddings errors.
|
||||
Args:
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
"""
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_embedding_error",
|
||||
"ignore_embeddings",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
"""Async callback manager for LLM run."""
|
||||
|
||||
@@ -872,6 +917,50 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForEmbeddingsRun(ParentRunManager, EmbeddingsManagerMixin):
|
||||
"""Callback manager for embeddings run."""
|
||||
|
||||
async def on_embedding_end(
|
||||
self,
|
||||
vector: List[float],
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when embeddings are generated.
|
||||
Args:
|
||||
vector (List[float]): The generated embedding vector.
|
||||
Returns:
|
||||
Any: The result of the callback.
|
||||
"""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_embedding_end",
|
||||
"ignore_embeddings",
|
||||
vector,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def on_embedding_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when an embeddings run raises an error.
|
||||
Args:
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
"""
|
||||
await _ahandle_event(
|
||||
self.handlers,
|
||||
"on_embedding_error",
|
||||
"ignore_embeddings",
|
||||
error,
|
||||
run_id=self.run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
|
||||
"""Callback manager for tool run."""
|
||||
|
||||
@@ -1137,6 +1226,44 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
return managers
|
||||
|
||||
def on_embedding_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
texts: List[str],
|
||||
**kwargs: Any,
|
||||
) -> List[CallbackManagerForEmbeddingsRun]:
|
||||
"""Run when embeddings model starts running.
|
||||
Args:
|
||||
serialized (Dict[str, Any]): The serialized embeddings model.
|
||||
texts (List[str]): The list of texts.
|
||||
Returns:
|
||||
List[CallbackManagerForEmbeddingsRun]: A callback manager for each
|
||||
text as an embeddings run.
|
||||
"""
|
||||
managers = []
|
||||
for text in texts:
|
||||
run_id_ = uuid.uuid4()
|
||||
_handle_event(
|
||||
self.handlers,
|
||||
"on_embedding_start",
|
||||
"ignore_embeddings",
|
||||
serialized,
|
||||
[text],
|
||||
run_id=run_id_,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
managers.append(
|
||||
CallbackManagerForEmbeddingsRun(
|
||||
run_id=run_id_,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
)
|
||||
)
|
||||
return managers
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
@@ -1422,6 +1549,48 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
return managers
|
||||
|
||||
async def on_embedding_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
texts: List[str],
|
||||
**kwargs: Any,
|
||||
) -> List[AsyncCallbackManagerForEmbeddingsRun]:
|
||||
"""Run when embeddings model starts running.
|
||||
Args:
|
||||
serialized (Dict[str, Any]): The serialized embeddings model.
|
||||
texts (List[str]): The list of texts.
|
||||
Returns:
|
||||
List[CallbackManagerForEmbeddingsRun]: A callback manager for each
|
||||
text as an embeddings run.
|
||||
"""
|
||||
tasks = []
|
||||
managers: List[AsyncCallbackManagerForEmbeddingsRun] = []
|
||||
for text in texts:
|
||||
run_id_ = uuid.uuid4()
|
||||
tasks.append(
|
||||
_ahandle_event(
|
||||
self.handlers,
|
||||
"on_embedding_start",
|
||||
"ignore_embeddings",
|
||||
serialized,
|
||||
[text],
|
||||
run_id=run_id_,
|
||||
parent_run_id=self.parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
managers.append(
|
||||
AsyncCallbackManagerForEmbeddingsRun(
|
||||
run_id=run_id_,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
return managers
|
||||
|
||||
@@ -427,6 +427,82 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
self._end_trace(retrieval_run)
|
||||
self._on_retriever_end(retrieval_run)
|
||||
|
||||
def on_embedding_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
texts: List[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when embeddings model starts running.
|
||||
Args:
|
||||
serialized (Dict[str, Any]): The serialized embeddings model.
|
||||
texts (List[str]): The list of texts.
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||
execution_order = self._get_execution_order(parent_run_id_)
|
||||
start_time = datetime.utcnow()
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
embeddings_run = Run(
|
||||
id=run_id,
|
||||
name="Embeddings", # TODO: Derive from serialized model
|
||||
parent_run_id=parent_run_id,
|
||||
serialized=serialized,
|
||||
inputs={"texts": texts},
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
child_runs=[],
|
||||
tags=tags,
|
||||
run_type="embedding",
|
||||
)
|
||||
self._start_trace(embeddings_run)
|
||||
self._on_embedding_start(embeddings_run)
|
||||
|
||||
def on_embedding_error(
|
||||
self,
|
||||
error: Union[Exception, KeyboardInterrupt],
|
||||
*,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when embeddings model errors."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_embedding_error callback.")
|
||||
embeddings_run = self.run_map.get(str(run_id))
|
||||
if embeddings_run is None or embeddings_run.run_type != "embedding":
|
||||
raise TracerException("No embeddings Run found to be traced")
|
||||
|
||||
embeddings_run.error = repr(error)
|
||||
embeddings_run.end_time = datetime.utcnow()
|
||||
embeddings_run.events.append({"name": "error", "time": embeddings_run.end_time})
|
||||
self._end_trace(embeddings_run)
|
||||
self._on_embedding_error(embeddings_run)
|
||||
|
||||
def on_embedding_end(
|
||||
self, vector: List[float], *, run_id: UUID, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when embeddings model ends running."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_embedding_end callback.")
|
||||
embeddings_run = self.run_map.get(str(run_id))
|
||||
if embeddings_run is None or embeddings_run.run_type != "embedding":
|
||||
raise TracerException("No embeddings Run found to be traced")
|
||||
embeddings_run.outputs = {"vector": vector}
|
||||
embeddings_run.end_time = datetime.utcnow()
|
||||
embeddings_run.events.append({"name": "end", "time": embeddings_run.end_time})
|
||||
self._end_trace(embeddings_run)
|
||||
self._on_embedding_end(embeddings_run)
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> BaseTracer:
|
||||
"""Deepcopy the tracer."""
|
||||
return self
|
||||
@@ -473,3 +549,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
def _on_retriever_error(self, run: Run) -> None:
|
||||
"""Process the Retriever Run upon error."""
|
||||
|
||||
def _on_embedding_start(self, run: Run) -> None:
|
||||
"""Process the Embeddings Run upon start."""
|
||||
|
||||
def _on_embedding_end(self, run: Run) -> None:
|
||||
"""Process the Embeddings Run."""
|
||||
|
||||
def _on_embedding_error(self, run: Run) -> None:
|
||||
"""Process the Embeddings Run upon error."""
|
||||
|
||||
@@ -216,6 +216,20 @@ class LangChainTracer(BaseTracer):
|
||||
"""Process the Retriever Run upon error."""
|
||||
self._submit(self._update_run_single, run.copy(deep=True))
|
||||
|
||||
def _on_embedding_start(self, run: Run) -> None:
|
||||
"""Process the Embeddings Run upon start."""
|
||||
if run.parent_run_id is None:
|
||||
run.reference_example_id = self.example_id
|
||||
self._submit(self._persist_run_single, run.copy(deep=True))
|
||||
|
||||
def _on_embedding_end(self, run: Run) -> None:
|
||||
"""Process the Embeddings Run."""
|
||||
self._submit(self._update_run_single, run.copy(deep=True))
|
||||
|
||||
def _on_embedding_error(self, run: Run) -> None:
|
||||
"""Process the Embeddings Run upon error."""
|
||||
self._submit(self._update_run_single, run.copy(deep=True))
|
||||
|
||||
def wait_for_futures(self) -> None:
|
||||
"""Wait for the given futures to complete."""
|
||||
futures = list(self._futures)
|
||||
|
||||
@@ -4,12 +4,15 @@ https://arxiv.org/abs/2212.10496
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForChainRun,
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.hyde.prompts import PROMPT_MAP
|
||||
from langchain.chains.llm import LLMChain
|
||||
@@ -42,7 +45,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
||||
"""Output keys for Hyde's LLM chain."""
|
||||
return self.llm_chain.output_keys
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Call the base embeddings."""
|
||||
return self.base_embeddings.embed_documents(texts)
|
||||
|
||||
@@ -50,7 +58,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
||||
"""Combine embeddings into final embeddings."""
|
||||
return list(np.array(embeddings).mean(axis=0))
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Generate a hypothetical document and embedded it."""
|
||||
var_name = self.llm_chain.input_keys[0]
|
||||
result = self.llm_chain.generate([{var_name: text}])
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -106,7 +109,12 @@ class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, Embeddings):
|
||||
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Call out to Aleph Alpha's asymmetric Document endpoint.
|
||||
|
||||
Args:
|
||||
@@ -147,7 +155,13 @@ class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, Embeddings):
|
||||
|
||||
return document_embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed query text."""
|
||||
"""Call out to Aleph Alpha's asymmetric, query embedding endpoint
|
||||
Args:
|
||||
text: The text to embed.
|
||||
@@ -230,7 +244,12 @@ class AlephAlphaSymmetricSemanticEmbedding(AlephAlphaAsymmetricSemanticEmbedding
|
||||
|
||||
return query_response.embedding
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Call out to Aleph Alpha's Document endpoint.
|
||||
|
||||
Args:
|
||||
@@ -245,7 +264,13 @@ class AlephAlphaSymmetricSemanticEmbedding(AlephAlphaAsymmetricSemanticEmbedding
|
||||
document_embeddings.append(self._embed(text))
|
||||
return document_embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed query text."""
|
||||
"""Call out to Aleph Alpha's asymmetric, query embedding endpoint
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Sequence
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
@@ -33,7 +34,12 @@ class AwaEmbeddings(BaseModel, Embeddings):
|
||||
self.model = model_name
|
||||
self.client.model_name = model_name
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Embed a list of documents using AwaEmbedding.
|
||||
|
||||
Args:
|
||||
@@ -44,7 +50,12 @@ class AwaEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
return self.client.EmbeddingBatch(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Compute query embeddings using AwaEmbedding.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,22 +1,228 @@
|
||||
"""Interface for embedding models."""
|
||||
import asyncio
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from inspect import signature
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForEmbeddingsRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
Callbacks,
|
||||
)
|
||||
|
||||
|
||||
class Embeddings(ABC):
|
||||
"""Interface for embedding models."""
|
||||
|
||||
_new_arg_supported: bool = False
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__()
|
||||
if cls.embed_documents != Embeddings.embed_documents:
|
||||
warnings.warn(
|
||||
"Embedding models must implement abstract `_embed_documents` method"
|
||||
" instead of `embed_documents`",
|
||||
DeprecationWarning,
|
||||
)
|
||||
swap = cls.embed_documents
|
||||
cls.embed_documents = Embeddings.embed_documents # type: ignore[assignment]
|
||||
cls._embed_documents = swap # type: ignore[assignment]
|
||||
if (
|
||||
hasattr(cls, "aembed_documents")
|
||||
and cls.aembed_documents != Embeddings.aembed_documents
|
||||
):
|
||||
warnings.warn(
|
||||
"Embedding models must implement abstract `_aembed_documents` method"
|
||||
" instead of `aembed_documents`",
|
||||
DeprecationWarning,
|
||||
)
|
||||
aswap = cls.aembed_documents
|
||||
cls.aembed_documents = ( # type: ignore[assignment]
|
||||
Embeddings.aembed_documents
|
||||
)
|
||||
cls._aembed_documents = aswap # type: ignore[assignment]
|
||||
parameters = signature(cls._embed_documents).parameters
|
||||
cls._new_arg_supported = parameters.get("run_managers") is not None
|
||||
|
||||
@abstractmethod
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Embed search docs."""
|
||||
|
||||
@abstractmethod
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed query text."""
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Asynchronous Embed search docs."""
|
||||
raise NotImplementedError
|
||||
async def _aembed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[AsyncCallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Embed search docs."""
|
||||
raise NotImplementedError(f"{self.__class__.__name__} does not support async")
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""Asynchronous Embed query text."""
|
||||
raise NotImplementedError
|
||||
async def _aembed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed query text."""
|
||||
raise NotImplementedError(f"{self.__class__.__name__} does not support async")
|
||||
|
||||
def embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> List[List[float]]:
|
||||
"""Embed search docs."""
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, None, inheritable_tags=tags, inheritable_metadata=metadata
|
||||
)
|
||||
run_managers: List[
|
||||
CallbackManagerForEmbeddingsRun
|
||||
] = callback_manager.on_embedding_start(
|
||||
{}, # TODO: make embeddings serializable
|
||||
texts,
|
||||
)
|
||||
try:
|
||||
if self._new_arg_supported:
|
||||
result = self._embed_documents(
|
||||
texts,
|
||||
run_managers=run_managers,
|
||||
)
|
||||
else:
|
||||
result = self._embed_documents(texts) # type: ignore[call-arg]
|
||||
except Exception as e:
|
||||
for run_manager in run_managers:
|
||||
run_manager.on_embedding_error(e)
|
||||
raise e
|
||||
else:
|
||||
for single_result, run_manager in zip(result, run_managers):
|
||||
run_manager.on_embedding_end(
|
||||
single_result,
|
||||
)
|
||||
return result
|
||||
|
||||
def embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> List[float]:
|
||||
"""Embed query text."""
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, None, inheritable_tags=tags, inheritable_metadata=metadata
|
||||
)
|
||||
run_managers: List[
|
||||
CallbackManagerForEmbeddingsRun
|
||||
] = callback_manager.on_embedding_start(
|
||||
{}, # TODO: make embeddings serializable
|
||||
[text],
|
||||
)
|
||||
try:
|
||||
if self._new_arg_supported:
|
||||
result = self._embed_query(text, run_manager=run_managers[0])
|
||||
else:
|
||||
result = self._embed_query(text) # type: ignore[call-arg]
|
||||
except Exception as e:
|
||||
run_managers[0].on_embedding_error(e)
|
||||
raise e
|
||||
else:
|
||||
run_managers[0].on_embedding_end(
|
||||
result,
|
||||
)
|
||||
return result
|
||||
|
||||
async def aembed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> List[List[float]]:
|
||||
"""Asynchronously embed search docs."""
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, None, inheritable_tags=tags, inheritable_metadata=metadata
|
||||
)
|
||||
run_managers: List[
|
||||
AsyncCallbackManagerForEmbeddingsRun
|
||||
] = await callback_manager.on_embedding_start(
|
||||
{}, # TODO: make embeddings serializable
|
||||
texts,
|
||||
)
|
||||
try:
|
||||
if self._new_arg_supported:
|
||||
result = await self._aembed_documents(
|
||||
texts,
|
||||
run_managers=run_managers,
|
||||
)
|
||||
else:
|
||||
result = await self._aembed_documents(texts) # type: ignore[call-arg]
|
||||
except Exception as e:
|
||||
tasks = [run_manager.on_embedding_error(e) for run_manager in run_managers]
|
||||
await asyncio.gather(*tasks)
|
||||
raise e
|
||||
else:
|
||||
tasks = [
|
||||
run_manager.on_embedding_end(
|
||||
single_result,
|
||||
)
|
||||
for run_manager, single_result in zip(run_managers, result)
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
return result
|
||||
|
||||
async def aembed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> List[float]:
|
||||
"""Asynchronously embed query text."""
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, None, inheritable_tags=tags, inheritable_metadata=metadata
|
||||
)
|
||||
run_managers: List[
|
||||
AsyncCallbackManagerForEmbeddingsRun
|
||||
] = await callback_manager.on_embedding_start(
|
||||
{}, # TODO: make embeddings serializable
|
||||
[text],
|
||||
)
|
||||
try:
|
||||
if self._new_arg_supported:
|
||||
result = await self._aembed_query(
|
||||
text,
|
||||
run_manager=run_managers[0],
|
||||
)
|
||||
else:
|
||||
result = await self._aembed_query(text) # type: ignore[call-arg]
|
||||
except Exception as e:
|
||||
await run_managers[0].on_embedding_error(e)
|
||||
raise e
|
||||
else:
|
||||
await run_managers[0].on_embedding_end(
|
||||
result,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
@@ -128,8 +129,11 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], chunk_size: int = 1
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a Bedrock model.
|
||||
|
||||
@@ -149,7 +153,13 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
||||
results.append(response)
|
||||
return results
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed query text."""
|
||||
"""Compute query embeddings using a Bedrock model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -82,7 +85,12 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Call out to Clarifai's embedding models.
|
||||
|
||||
Args:
|
||||
@@ -137,7 +145,12 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
|
||||
]
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Call out to Clarifai's embedding models.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
"""Wrapper around Cohere embedding models."""
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForEmbeddingsRun,
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -57,7 +62,12 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Call out to Cohere's embedding endpoint.
|
||||
|
||||
Args:
|
||||
@@ -71,7 +81,12 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
).embeddings
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
async def _aembed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[AsyncCallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Async call out to Cohere's embedding endpoint.
|
||||
|
||||
Args:
|
||||
@@ -85,7 +100,12 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return [list(map(float, e)) for e in embeddings.embeddings]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Call out to Cohere's embedding endpoint.
|
||||
|
||||
Args:
|
||||
@@ -94,9 +114,11 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
return self._embed_documents([text], run_managers=[run_manager])[0]
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
async def _aembed_query(
|
||||
self, text: str, *, run_manager: AsyncCallbackManagerForEmbeddingsRun
|
||||
) -> List[float]:
|
||||
"""Async call out to Cohere's embedding endpoint.
|
||||
|
||||
Args:
|
||||
@@ -105,5 +127,5 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
embeddings = await self.aembed_documents([text])
|
||||
embeddings = await self._aembed_documents([text], run_managers=[run_manager])
|
||||
return embeddings[0]
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
@@ -19,6 +20,9 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -123,7 +127,12 @@ class DashScopeEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Call out to DashScope's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
@@ -140,7 +149,12 @@ class DashScopeEmbeddings(BaseModel, Embeddings):
|
||||
embedding_list = [item["embedding"] for item in embeddings]
|
||||
return embedding_list
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Call out to DashScope's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -102,7 +105,12 @@ class DeepInfraEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return embeddings
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Embed documents using a Deep Infra deployed embedding model.
|
||||
|
||||
Args:
|
||||
@@ -115,7 +123,12 @@ class DeepInfraEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = self._embed(instruction_pairs)
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed a query using a Deep Infra deployed embedding model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.requests import Requests
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -64,7 +65,12 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return embeddings
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Embed a list of documents using EdenAI.
|
||||
|
||||
Args:
|
||||
@@ -76,7 +82,12 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return self._generate_embeddings(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed a query using EdenAI.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional, Sequence
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.utils import get_from_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -195,7 +198,12 @@ class ElasticsearchEmbeddings(Embeddings):
|
||||
embeddings = [doc["predicted_value"] for doc in response["inference_results"]]
|
||||
return embeddings
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Generate embeddings for a list of documents.
|
||||
|
||||
@@ -209,7 +217,13 @@ class ElasticsearchEmbeddings(Embeddings):
|
||||
"""
|
||||
return self._embedding_func(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed query text."""
|
||||
"""
|
||||
Generate an embedding for a single query text.
|
||||
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
"""Wrapper around embaas embeddings API."""
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -110,7 +114,12 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
raise
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Get embeddings for a list of texts.
|
||||
|
||||
Args:
|
||||
@@ -126,7 +135,13 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
|
||||
# flatten the list of lists into a single list
|
||||
return [embedding for batch in embeddings for embedding in batch]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed query text."""
|
||||
"""Get embeddings for a single text.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import hashlib
|
||||
from typing import List
|
||||
from typing import List, Sequence
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
@@ -16,10 +17,21 @@ class FakeEmbeddings(Embeddings, BaseModel):
|
||||
def _get_embedding(self) -> List[float]:
|
||||
return list(np.random.normal(size=self.size))
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
return [self._get_embedding() for _ in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed query text."""
|
||||
return self._get_embedding()
|
||||
|
||||
|
||||
@@ -43,8 +55,15 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
|
||||
"""
|
||||
return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
return [self._get_embedding(seed=self._get_seed(_)) for _ in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self, text: str, *, run_manager: CallbackManagerForEmbeddingsRun
|
||||
) -> List[float]:
|
||||
return self._get_embedding(seed=self._get_seed(text))
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
from tenacity import (
|
||||
@@ -12,6 +12,9 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -78,10 +81,21 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Embed search docs."""
|
||||
return [self.embed_query(text) for text in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed query text."""
|
||||
embedding = embed_with_retry(self, self.model_name, text)
|
||||
return embedding["embedding"]
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from typing import Any, Dict, List
|
||||
"""Wrapper around GPT4All embedding models."""
|
||||
from typing import Any, Dict, List, Sequence
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
@@ -36,7 +40,12 @@ class GPT4AllEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Embed a list of documents using GPT4All.
|
||||
|
||||
Args:
|
||||
@@ -49,7 +58,12 @@ class GPT4AllEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = [self.client.embed(text) for text in texts]
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed a query using GPT4All.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
"""Wrapper around HuggingFace embedding models."""
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
|
||||
@@ -64,7 +68,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
@@ -77,7 +86,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = self.client.encode(texts, **self.encode_kwargs)
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Compute query embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
@@ -144,7 +158,12 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace instruct model.
|
||||
|
||||
Args:
|
||||
@@ -157,7 +176,12 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = self.client.encode(instruction_pairs, **self.encode_kwargs)
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Compute query embeddings using a HuggingFace instruct model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
"""Wrapper around HuggingFace Hub embedding models."""
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -76,7 +80,12 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Call out to HuggingFaceHub's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
@@ -91,7 +100,12 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
|
||||
responses = self.client(inputs=texts, params=_model_kwargs)
|
||||
return responses
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Call out to HuggingFaceHub's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -71,7 +74,12 @@ class JinaEmbeddings(BaseModel, Embeddings):
|
||||
payload = dict(inputs=docs, metadata=self.request_headers, **kwargs)
|
||||
return self.client.post(on="/encode", **payload)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Call out to Jina's embedding endpoint.
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
@@ -85,7 +93,12 @@ class JinaEmbeddings(BaseModel, Embeddings):
|
||||
).embeddings
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Call out to Jina's embedding endpoint.
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
"""Wrapper around llama.cpp embedding models."""
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
@@ -98,7 +102,12 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Embed a list of documents using the Llama model.
|
||||
|
||||
Args:
|
||||
@@ -110,7 +119,12 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = [self.client.embed(text) for text in texts]
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed a query using the Llama model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -25,6 +25,10 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForEmbeddingsRun,
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
|
||||
@@ -285,8 +289,11 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
)["data"][0]["embedding"]
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Call out to LocalAI's embedding endpoint for embedding search docs.
|
||||
|
||||
@@ -301,8 +308,11 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
|
||||
# call _embedding_func for each text
|
||||
return [self._embedding_func(text, engine=self.deployment) for text in texts]
|
||||
|
||||
async def aembed_documents(
|
||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||
async def _aembed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[AsyncCallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Call out to LocalAI's embedding endpoint async for embedding search docs.
|
||||
|
||||
@@ -320,7 +330,9 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
|
||||
embeddings.append(response)
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self, text: str, *, run_manager: CallbackManagerForEmbeddingsRun
|
||||
) -> List[float]:
|
||||
"""Call out to LocalAI's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
@@ -332,7 +344,9 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
|
||||
embedding = self._embedding_func(text, engine=self.deployment)
|
||||
return embedding
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
async def _aembed_query(
|
||||
self, text: str, *, run_manager: AsyncCallbackManagerForEmbeddingsRun
|
||||
) -> List[float]:
|
||||
"""Call out to LocalAI's embedding endpoint async for embedding query text.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
@@ -12,6 +12,9 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -135,7 +138,12 @@ class MiniMaxEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return embeddings
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Embed documents using a MiniMax embedding endpoint.
|
||||
|
||||
Args:
|
||||
@@ -147,7 +155,12 @@ class MiniMaxEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = embed_with_retry(self, texts=texts, embed_type=self.embed_type_db)
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed a query using a MiniMax embedding endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterator, List, Optional
|
||||
from typing import Any, Iterator, List, Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
@@ -63,8 +64,18 @@ class MlflowAIGatewayEmbeddings(Embeddings, BaseModel):
|
||||
embeddings.append(resp["embeddings"])
|
||||
return embeddings
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
return self._query(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
return self._query([text])[0]
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Any, List, Optional
|
||||
"""Wrapper around ModelScopeHub embedding models."""
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
@@ -45,7 +47,12 @@ class ModelScopeEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a modelscope embedding model.
|
||||
|
||||
Args:
|
||||
@@ -59,7 +66,12 @@ class ModelScopeEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = self.embed(input=inputs)["text_embedding"]
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Compute query embeddings using a modelscope embedding model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from typing import Any, Dict, List, Mapping, Optional, Tuple
|
||||
"""Wrapper around MosaicML APIs."""
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -136,7 +140,12 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return embeddings
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Embed documents using a MosaicML deployed instructor embedding model.
|
||||
|
||||
Args:
|
||||
@@ -149,7 +158,12 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = self._embed(instruction_pairs)
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed a query using a MosaicML deployed instructor embedding model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Sequence
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -27,7 +28,7 @@ class NLPCloudEmbeddings(BaseModel, Embeddings):
|
||||
self,
|
||||
model_name: str = "paraphrase-multilingual-mpnet-base-v2",
|
||||
gpu: bool = False,
|
||||
**kwargs: Any
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(model_name=model_name, gpu=gpu, **kwargs)
|
||||
|
||||
@@ -50,7 +51,12 @@ class NLPCloudEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Embed a list of documents using NLP Cloud.
|
||||
|
||||
Args:
|
||||
@@ -62,7 +68,12 @@ class NLPCloudEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return self.client.embeddings(texts)["embeddings"]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed a query using NLP Cloud.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
"""Module providing a wrapper around OctoAI Compute Service embedding models."""
|
||||
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -80,12 +85,22 @@ class OctoAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return embeddings
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Compute document embeddings using an OctoAI instruct model."""
|
||||
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||
return self._compute_embeddings(texts, self.embed_instruction)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Compute query embedding using an OctoAI instruct model."""
|
||||
text = text.replace("\n", " ")
|
||||
return self._compute_embeddings([text], self.embed_instruction)[0]
|
||||
|
||||
@@ -26,6 +26,10 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForEmbeddingsRun,
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
|
||||
@@ -124,29 +128,23 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
To use, you should have the ``openai`` python package installed, and the
|
||||
environment variable ``OPENAI_API_KEY`` set with your API key or pass it
|
||||
as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
openai = OpenAIEmbeddings(openai_api_key="my-api-key")
|
||||
|
||||
In order to use the library with Microsoft Azure endpoints, you need to set
|
||||
the OPENAI_API_TYPE, OPENAI_API_BASE, OPENAI_API_KEY and OPENAI_API_VERSION.
|
||||
The OPENAI_API_TYPE must be set to 'azure' and the others correspond to
|
||||
the properties of your endpoint.
|
||||
In addition, the deployment name must be passed as the model parameter.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
import os
|
||||
os.environ["OPENAI_API_TYPE"] = "azure"
|
||||
os.environ["OPENAI_API_BASE"] = "https://<your-endpoint.openai.azure.com/"
|
||||
os.environ["OPENAI_API_KEY"] = "your AzureOpenAI key"
|
||||
os.environ["OPENAI_API_VERSION"] = "2023-05-15"
|
||||
os.environ["OPENAI_PROXY"] = "http://your-corporate-proxy:8080"
|
||||
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
embeddings = OpenAIEmbeddings(
|
||||
deployment="your-embeddings-deployment-name",
|
||||
@@ -156,7 +154,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
text = "This is a test query."
|
||||
query_result = embeddings.embed_query(text)
|
||||
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
@@ -454,16 +451,51 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return embeddings
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
|
||||
"""Call out to OpenAI's embedding endpoint."""
|
||||
# handle large input text
|
||||
if len(text) > self.embedding_ctx_length:
|
||||
return self._get_len_safe_embeddings([text], engine=engine)[0]
|
||||
else:
|
||||
if self.model.endswith("001"):
|
||||
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
|
||||
# replace newlines, which can negatively affect performance.
|
||||
text = text.replace("\n", " ")
|
||||
return embed_with_retry(
|
||||
self,
|
||||
input=[text],
|
||||
**self._invocation_params,
|
||||
)[
|
||||
"data"
|
||||
][0]["embedding"]
|
||||
|
||||
async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
|
||||
"""Call out to OpenAI's embedding endpoint."""
|
||||
# handle large input text
|
||||
if len(text) > self.embedding_ctx_length:
|
||||
return (await self._aget_len_safe_embeddings([text], engine=engine))[0]
|
||||
else:
|
||||
if self.model.endswith("001"):
|
||||
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
|
||||
# replace newlines, which can negatively affect performance.
|
||||
text = text.replace("\n", " ")
|
||||
return (
|
||||
await async_embed_with_retry(
|
||||
self,
|
||||
input=[text],
|
||||
**self._invocation_params,
|
||||
)
|
||||
)["data"][0]["embedding"]
|
||||
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Call out to OpenAI's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
chunk_size: The chunk size of embeddings. If None, will use the chunk size
|
||||
specified by the class.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
@@ -471,16 +503,15 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
# than the maximum context and use length-safe embedding function.
|
||||
return self._get_len_safe_embeddings(texts, engine=self.deployment)
|
||||
|
||||
async def aembed_documents(
|
||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||
async def _aembed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[AsyncCallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Call out to OpenAI's embedding endpoint async for embedding search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
chunk_size: The chunk size of embeddings. If None, will use the chunk size
|
||||
specified by the class.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
@@ -488,23 +519,29 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
# than the maximum context and use length-safe embedding function.
|
||||
return await self._aget_len_safe_embeddings(texts, engine=self.deployment)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
async def _aembed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Call out to OpenAI's embedding endpoint async for embedding query text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
"""Wrapper around Sagemaker InvokeEndpoint API."""
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
|
||||
|
||||
@@ -98,6 +102,8 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
||||
function. See `boto3`_. docs for more info.
|
||||
.. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
|
||||
"""
|
||||
chunk_size: int = 64
|
||||
"""The number of documents to send to the endpoint at a time."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -163,8 +169,11 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
return self.content_handler.transform_output(response["Body"])
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], chunk_size: int = 64
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a SageMaker Inference Endpoint.
|
||||
|
||||
@@ -179,13 +188,19 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
results = []
|
||||
_chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
|
||||
_chunk_size = len(texts) if self.chunk_size > len(texts) else self.chunk_size
|
||||
for i in range(0, len(texts), _chunk_size):
|
||||
response = self._embedding_func(texts[i : i + _chunk_size])
|
||||
results.extend(response)
|
||||
return results
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed query text."""
|
||||
"""Compute query embeddings using a SageMaker inference endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from typing import Any, Callable, List
|
||||
"""Running custom embedding models on self-hosted remote hardware."""
|
||||
from typing import Any, Callable, List, Sequence
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms import SelfHostedPipeline
|
||||
|
||||
@@ -71,7 +75,12 @@ class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
@@ -86,7 +95,13 @@ class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings):
|
||||
return embeddings.tolist()
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed query text."""
|
||||
"""Compute query embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Any, Callable, List, Optional
|
||||
from typing import Any, Callable, List, Optional, Sequence
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.self_hosted import SelfHostedEmbeddings
|
||||
|
||||
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
|
||||
@@ -139,7 +142,12 @@ class SelfHostedHuggingFaceInstructEmbeddings(SelfHostedHuggingFaceEmbeddings):
|
||||
load_fn_kwargs["device"] = load_fn_kwargs.get("device", 0)
|
||||
super().__init__(load_fn_kwargs=load_fn_kwargs, **kwargs)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace instruct model.
|
||||
|
||||
Args:
|
||||
@@ -154,7 +162,13 @@ class SelfHostedHuggingFaceInstructEmbeddings(SelfHostedHuggingFaceEmbeddings):
|
||||
embeddings = self.client(self.pipeline_ref, instruction_pairs)
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed query text."""
|
||||
"""Compute query embeddings using a HuggingFace instruct model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import importlib.util
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Sequence
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForEmbeddingsRun,
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
@@ -63,7 +67,12 @@ class SpacyEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return values # Return the validated values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Generates embeddings for a list of documents.
|
||||
|
||||
@@ -75,7 +84,13 @@ class SpacyEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
return [self.nlp(text).vector.tolist() for text in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed query text."""
|
||||
"""
|
||||
Generates an embedding for a single piece of text.
|
||||
|
||||
@@ -87,7 +102,12 @@ class SpacyEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
return self.nlp(text).vector.tolist()
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
async def _aembed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[AsyncCallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Asynchronously generates embeddings for a list of documents.
|
||||
This method is not implemented and raises a NotImplementedError.
|
||||
@@ -100,7 +120,12 @@ class SpacyEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
raise NotImplementedError("Asynchronous embedding generation is not supported.")
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
async def _aembed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""
|
||||
Asynchronously generates an embedding for a single piece of text.
|
||||
This method is not implemented and raises a NotImplementedError.
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from typing import Any, List
|
||||
"""Wrapper around TensorflowHub embedding models."""
|
||||
from typing import Any, List, Sequence
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
DEFAULT_MODEL_URL = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
|
||||
@@ -49,7 +53,12 @@ class TensorflowHubEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a TensorflowHub embedding model.
|
||||
|
||||
Args:
|
||||
@@ -62,7 +71,12 @@ class TensorflowHubEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = self.embed(texts).numpy()
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Compute query embeddings using a TensorflowHub embedding model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Dict, List
|
||||
"""Wrapper around Google VertexAI embedding models."""
|
||||
from typing import Dict, List, Sequence
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms.vertexai import _VertexAICommon
|
||||
from langchain.utilities.vertexai import raise_vertex_import_error
|
||||
@@ -11,6 +13,8 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
|
||||
"""Google Cloud VertexAI embedding models."""
|
||||
|
||||
model_name: str = "textembedding-gecko"
|
||||
batch_size: int = 5
|
||||
"""Number of texts to embed in a single API call."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@@ -23,8 +27,11 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
|
||||
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
|
||||
return values
|
||||
|
||||
def embed_documents(
|
||||
self, texts: List[str], batch_size: int = 5
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Embed a list of strings. Vertex AI currently
|
||||
sets a max batch size of 5 strings.
|
||||
@@ -37,13 +44,18 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = []
|
||||
for batch in range(0, len(texts), batch_size):
|
||||
text_batch = texts[batch : batch + batch_size]
|
||||
for batch in range(0, len(texts), self.batch_size):
|
||||
text_batch = texts[batch : batch + self.batch_size]
|
||||
embeddings_batch = self.client.get_embeddings(text_batch)
|
||||
embeddings.extend([el.values for el in embeddings_batch])
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed a text.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Wrapper around Xinference embedding models."""
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
@@ -81,7 +82,12 @@ class XinferenceEmbeddings(Embeddings):
|
||||
|
||||
self.client = RESTfulClient(server_url)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Embed a list of documents using Xinference.
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
@@ -96,7 +102,12 @@ class XinferenceEmbeddings(Embeddings):
|
||||
]
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Embed a query of documents using Xinference.
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
@@ -239,8 +239,12 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
||||
Returns:
|
||||
Dict[str, Any]: The computed score.
|
||||
"""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
vectors = np.array(
|
||||
self.embeddings.embed_documents([inputs["prediction"], inputs["reference"]])
|
||||
self.embeddings.embed_documents(
|
||||
[inputs["prediction"], inputs["reference"]],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
)
|
||||
score = self._compute_score(vectors)
|
||||
return {"score": score}
|
||||
@@ -260,8 +264,10 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
||||
Returns:
|
||||
Dict[str, Any]: The computed score.
|
||||
"""
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
embedded = await self.embeddings.aembed_documents(
|
||||
[inputs["prediction"], inputs["reference"]]
|
||||
[inputs["prediction"], inputs["reference"]],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
vectors = np.array(embedded)
|
||||
score = self._compute_score(vectors)
|
||||
@@ -376,9 +382,11 @@ class PairwiseEmbeddingDistanceEvalChain(
|
||||
Returns:
|
||||
Dict[str, Any]: The computed score.
|
||||
"""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
vectors = np.array(
|
||||
self.embeddings.embed_documents(
|
||||
[inputs["prediction"], inputs["prediction_b"]]
|
||||
[inputs["prediction"], inputs["prediction_b"]],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
)
|
||||
score = self._compute_score(vectors)
|
||||
@@ -399,8 +407,10 @@ class PairwiseEmbeddingDistanceEvalChain(
|
||||
Returns:
|
||||
Dict[str, Any]: The computed score.
|
||||
"""
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
embedded = await self.embeddings.aembed_documents(
|
||||
[inputs["prediction"], inputs["prediction_b"]]
|
||||
[inputs["prediction"], inputs["prediction_b"]],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
vectors = np.array(embedded)
|
||||
score = self._compute_score(vectors)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Fake Embedding class for testing purposes."""
|
||||
import math
|
||||
from typing import List
|
||||
from typing import List, Sequence
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
fake_texts = ["foo", "bar", "baz"]
|
||||
@@ -10,12 +11,22 @@ fake_texts = ["foo", "bar", "baz"]
|
||||
class FakeEmbeddings(Embeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Return simple embeddings.
|
||||
Embeddings encode each text as its index."""
|
||||
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Return constant query embeddings.
|
||||
Embeddings are identical to embed_documents(texts)[0].
|
||||
Distance to each text will be that text's index,
|
||||
@@ -31,7 +42,12 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
|
||||
self.known_texts: List[str] = []
|
||||
self.dimensionality = dimensionality
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Return consistent embeddings for each text seen so far."""
|
||||
out_vectors = []
|
||||
for text in texts:
|
||||
@@ -43,7 +59,12 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
|
||||
out_vectors.append(vector)
|
||||
return out_vectors
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Return consistent embeddings for the text, if seen before, or a constant
|
||||
one if the text is unknown."""
|
||||
if text not in self.known_texts:
|
||||
@@ -58,13 +79,23 @@ class AngularTwoDimensionalEmbeddings(Embeddings):
|
||||
From angles (as strings in units of pi) to unit embedding vectors on a circle.
|
||||
"""
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Make a list of texts into a list of embedding vectors.
|
||||
"""
|
||||
return [self.embed_query(text) for text in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""
|
||||
Convert input text to a 'vector' (list of floats).
|
||||
If the text is a number, use it as the angle for the
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from typing import List
|
||||
from typing import List, Sequence
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
)
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.alibabacloud_opensearch import (
|
||||
AlibabaCloudOpenSearch,
|
||||
@@ -15,14 +18,23 @@ texts = ["foo", "bar", "baz"]
|
||||
class FakeEmbeddingsWithOsDimension(FakeEmbeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
|
||||
def embed_documents(self, embedding_texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Return simple embeddings."""
|
||||
return [
|
||||
[float(1.0)] * (OS_TOKEN_COUNT - 1) + [float(i)]
|
||||
for i in range(len(embedding_texts))
|
||||
[float(1.0)] * (OS_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))
|
||||
]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Return simple embeddings."""
|
||||
return [float(1.0)] * (OS_TOKEN_COUNT - 1) + [float(texts.index(text))]
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Test PGVector functionality."""
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Sequence
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores.analyticdb import AnalyticDB
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
@@ -22,13 +23,23 @@ ADA_TOKEN_COUNT = 1536
|
||||
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Return simple embeddings."""
|
||||
return [
|
||||
[float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))
|
||||
]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Return simple embeddings."""
|
||||
return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Test Hologres functionality."""
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Sequence
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores.hologres import Hologres
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
@@ -21,13 +22,23 @@ ADA_TOKEN_COUNT = 1536
|
||||
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Return simple embeddings."""
|
||||
return [
|
||||
[float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))
|
||||
]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Return simple embeddings."""
|
||||
return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Test PGVector functionality."""
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Sequence
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores.pgvector import PGVector
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
@@ -24,13 +25,23 @@ ADA_TOKEN_COUNT = 1536
|
||||
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Return simple embeddings."""
|
||||
return [
|
||||
[float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))
|
||||
]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Return simple embeddings."""
|
||||
return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Test SingleStoreDB functionality."""
|
||||
from typing import List
|
||||
from typing import List, Sequence
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores.singlestoredb import SingleStoreDB
|
||||
from langchain.vectorstores.utils import DistanceStrategy
|
||||
@@ -36,10 +37,20 @@ class NormilizedFakeEmbeddings(FakeEmbeddings):
|
||||
"""Normalize vector."""
|
||||
return [float(v / np.linalg.norm(vector)) for v in vector]
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
return [self.normalize(v) for v in super().embed_documents(texts)]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
return self.normalize(super().embed_query(text))
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""Test HyDE."""
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
|
||||
@@ -17,11 +18,21 @@ from langchain.schema import Generation, LLMResult
|
||||
class FakeEmbeddings(Embeddings):
|
||||
"""Fake embedding class for tests."""
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def _embed_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
*,
|
||||
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
|
||||
) -> List[List[float]]:
|
||||
"""Return random floats."""
|
||||
return [list(np.random.uniform(0, 1, 10)) for _ in range(10)]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def _embed_query(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForEmbeddingsRun,
|
||||
) -> List[float]:
|
||||
"""Return random floats."""
|
||||
return list(np.random.uniform(0, 1, 10))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user