Compare commits

...

14 Commits

Author SHA1 Message Date
William Fu-Hinthorn
b3d30eaa7c Merge branch 'master' into wfh/embeddings_callbacks_v3 2023-08-04 08:41:48 -07:00
William Fu-Hinthorn
5efe913936 merge 2023-08-04 08:41:35 -07:00
William Fu-Hinthorn
ee902ba7b2 Merge branch 'master' into wfh/embeddings_callbacks_v3 2023-07-27 12:40:23 -07:00
William Fu-Hinthorn
5918c2ffc0 merge 2023-07-26 20:05:24 -07:00
William Fu-Hinthorn
097538882d merge. some things still need upgraded now tho 2023-07-26 19:09:27 -07:00
William Fu-Hinthorn
619d6c0b14 merge 2023-07-24 08:39:56 -07:00
William Fu-Hinthorn
b28610c13a Merge branch 'master' into wfh/embeddings_callbacks_v3 2023-07-18 22:09:05 -07:00
William Fu-Hinthorn
f273c99158 update 2023-07-18 22:05:04 -07:00
William Fu-Hinthorn
32ca9dce3e Merge branch 'wfh/docs_nits2' into wfh/embeddings_callbacks_v3 2023-07-18 21:43:13 -07:00
William Fu-Hinthorn
d5ad0d2421 Docs Nit 2023-07-18 21:41:50 -07:00
William Fu-Hinthorn
ef930eda9a linting 2023-07-18 21:39:16 -07:00
William Fu-Hinthorn
08cf728e57 Update others but too many newlines 2023-07-18 17:40:49 -07:00
William Fu-Hinthorn
43d835dda4 Embeddings Draft 2023-07-18 15:03:55 -07:00
William Fu-Hinthorn
472b434f02 tmp 2023-06-30 14:38:54 -07:00
46 changed files with 1352 additions and 184 deletions

View File

@@ -69,6 +69,35 @@ class LLMManagerMixin:
"""Run when LLM errors."""
class EmbeddingsManagerMixin:
    """Mixin providing the embedding-lifecycle callback hooks.

    Both hooks are intentional no-ops; concrete handlers override the
    ones they care about.
    """

    def on_embedding_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Hook invoked when an embedding model raises an error.

        Args:
            error (Union[Exception, KeyboardInterrupt]): The raised error.
            run_id (UUID): ID of the failing embeddings run.
            parent_run_id (Optional[UUID]): ID of the parent run, if any.
        """

    def on_embedding_end(
        self,
        vector: List[float],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Hook invoked when an embedding model finishes a run.

        Args:
            vector (List[float]): The generated embedding vector.
            run_id (UUID): ID of the finished embeddings run.
            parent_run_id (Optional[UUID]): ID of the parent run, if any.

        Returns:
            Any: The result of the callback.
        """
class ChainManagerMixin:
"""Mixin for chain callbacks."""
@@ -182,6 +211,19 @@ class CallbackManagerMixin:
) -> Any:
"""Run when Retriever starts running."""
def on_embedding_start(
    self,
    serialized: Dict[str, Any],
    texts: List[str],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> Any:
    """No-op hook invoked when an Embeddings model begins running.

    Args:
        serialized (Dict[str, Any]): The serialized embeddings model.
        texts (List[str]): The texts to embed.
        run_id (UUID): ID of the new embeddings run.
        parent_run_id (Optional[UUID]): ID of the parent run, if any.
        tags (Optional[List[str]]): Tags attached to the run.
        metadata (Optional[Dict[str, Any]]): Metadata attached to the run.
    """
def on_chain_start(
self,
serialized: Dict[str, Any],
@@ -225,6 +267,7 @@ class RunManagerMixin:
class BaseCallbackHandler(
LLMManagerMixin,
EmbeddingsManagerMixin,
ChainManagerMixin,
ToolManagerMixin,
RetrieverManagerMixin,
@@ -267,6 +310,11 @@ class BaseCallbackHandler(
"""Whether to ignore chat model callbacks."""
return False
# Consulted (by name, via the "ignore_embeddings" flag passed to the event
# dispatcher) before sending embedding events to this handler; subclasses
# override and return True to opt out of embedding callbacks.
@property
def ignore_embeddings(self) -> bool:
"""Whether to ignore embeddings callbacks."""
return False
class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that can be used to handle callbacks from langchain."""
@@ -294,7 +342,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
) -> None:
"""Run when a chat model starts running."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
@@ -333,6 +381,38 @@ class AsyncCallbackHandler(BaseCallbackHandler):
) -> None:
"""Run when LLM errors."""
async def on_embedding_start(
    self,
    serialized: Dict[str, Any],
    texts: List[str],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Run when embeddings call starts running.

    Args:
        serialized (Dict[str, Any]): The serialized embeddings model.
        texts (List[str]): The texts to embed.
        run_id (UUID): ID of the new embeddings run.
        parent_run_id (Optional[UUID]): ID of the parent run, if any.
        tags (Optional[List[str]]): Tags attached to the run.
        metadata (Optional[Dict[str, Any]]): Metadata attached to the run.
            Accepted for parity with the sync ``on_embedding_start``.
    """
async def on_embedding_end(
    self,
    vector: List[float],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> None:
    """Run when embeddings call ends running.

    Args:
        vector (List[float]): The generated embedding vector.
        run_id (UUID): ID of the finished embeddings run.
        parent_run_id (Optional[UUID]): ID of the parent run, if any.
    """
async def on_embedding_error(
    self,
    error: Union[Exception, KeyboardInterrupt],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> None:
    """Run when embeddings call errors.

    Args:
        error (Union[Exception, KeyboardInterrupt]): The raised error.
        run_id (UUID): ID of the failing embeddings run.
        parent_run_id (Optional[UUID]): ID of the parent run, if any.
    """
async def on_chain_start(
self,
serialized: Dict[str, Any],

View File

@@ -31,6 +31,7 @@ from langchain.callbacks.base import (
BaseCallbackManager,
Callbacks,
ChainManagerMixin,
EmbeddingsManagerMixin,
LLMManagerMixin,
RetrieverManagerMixin,
RunManagerMixin,
@@ -645,6 +646,50 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
)
class CallbackManagerForEmbeddingsRun(RunManager, EmbeddingsManagerMixin):
    """Callback manager for an embeddings run (one run per embedded text)."""

    def on_embedding_end(
        self,
        vector: List[float],
        **kwargs: Any,
    ) -> Any:
        """Run when the embedding for this run's text has been generated.

        Args:
            vector (List[float]): The generated embedding vector.
        """
        _handle_event(
            self.handlers,
            "on_embedding_end",
            "ignore_embeddings",
            vector,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    def on_embedding_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        """Run when the embeddings run errors.

        Args:
            error (Exception or KeyboardInterrupt): The error.
        """
        _handle_event(
            self.handlers,
            "on_embedding_error",
            "ignore_embeddings",
            error,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
"""Async callback manager for LLM run."""
@@ -872,6 +917,50 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
)
# Base fixed from sync ParentRunManager to AsyncRunManager: the async manager
# must mirror CallbackManagerForEmbeddingsRun(RunManager, ...) the same way
# AsyncCallbackManagerForLLMRun mirrors CallbackManagerForLLMRun.
class AsyncCallbackManagerForEmbeddingsRun(AsyncRunManager, EmbeddingsManagerMixin):
    """Async callback manager for an embeddings run (one run per embedded text)."""

    async def on_embedding_end(
        self,
        vector: List[float],
        **kwargs: Any,
    ) -> Any:
        """Run when the embedding for this run's text has been generated.

        Args:
            vector (List[float]): The generated embedding vector.
        """
        await _ahandle_event(
            self.handlers,
            "on_embedding_end",
            "ignore_embeddings",
            vector,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    async def on_embedding_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        """Run when the embeddings run raises an error.

        Args:
            error (Exception or KeyboardInterrupt): The error.
        """
        await _ahandle_event(
            self.handlers,
            "on_embedding_error",
            "ignore_embeddings",
            error,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )
class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
"""Callback manager for tool run."""
@@ -1137,6 +1226,44 @@ class CallbackManager(BaseCallbackManager):
return managers
def on_embedding_start(
    self,
    serialized: Dict[str, Any],
    texts: List[str],
    **kwargs: Any,
) -> List[CallbackManagerForEmbeddingsRun]:
    """Run when an embeddings model starts running.

    A separate run (with its own run_id and manager) is created for each
    input text.

    Args:
        serialized (Dict[str, Any]): The serialized embeddings model.
        texts (List[str]): The list of texts.

    Returns:
        List[CallbackManagerForEmbeddingsRun]: A callback manager for each
            text as an embeddings run.
    """
    managers = []
    for text in texts:
        run_id_ = uuid.uuid4()
        _handle_event(
            self.handlers,
            "on_embedding_start",
            "ignore_embeddings",
            serialized,
            [text],
            run_id=run_id_,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )
        managers.append(
            CallbackManagerForEmbeddingsRun(
                run_id=run_id_,
                handlers=self.handlers,
                inheritable_handlers=self.inheritable_handlers,
                parent_run_id=self.parent_run_id,
                inheritable_tags=self.inheritable_tags,
                # Propagate metadata to child managers, matching the other
                # on_*_start factories in this module.
                inheritable_metadata=self.inheritable_metadata,
            )
        )
    return managers
def on_chain_start(
self,
serialized: Dict[str, Any],
@@ -1422,6 +1549,48 @@ class AsyncCallbackManager(BaseCallbackManager):
inheritable_metadata=self.inheritable_metadata,
)
)
await asyncio.gather(*tasks)
return managers
async def on_embedding_start(
    self,
    serialized: Dict[str, Any],
    texts: List[str],
    **kwargs: Any,
) -> List[AsyncCallbackManagerForEmbeddingsRun]:
    """Run when an embeddings model starts running.

    A separate run (with its own run_id and manager) is created for each
    input text; handler start events are dispatched concurrently.

    Args:
        serialized (Dict[str, Any]): The serialized embeddings model.
        texts (List[str]): The list of texts.

    Returns:
        List[AsyncCallbackManagerForEmbeddingsRun]: A callback manager for
            each text as an embeddings run.
    """
    tasks = []
    managers: List[AsyncCallbackManagerForEmbeddingsRun] = []
    for text in texts:
        run_id_ = uuid.uuid4()
        tasks.append(
            _ahandle_event(
                self.handlers,
                "on_embedding_start",
                "ignore_embeddings",
                serialized,
                [text],
                run_id=run_id_,
                parent_run_id=self.parent_run_id,
                **kwargs,
            )
        )
        managers.append(
            AsyncCallbackManagerForEmbeddingsRun(
                run_id=run_id_,
                handlers=self.handlers,
                inheritable_handlers=self.inheritable_handlers,
                parent_run_id=self.parent_run_id,
                inheritable_tags=self.inheritable_tags,
                # Propagate metadata to child managers, matching the other
                # on_*_start factories in this module.
                inheritable_metadata=self.inheritable_metadata,
            )
        )
    await asyncio.gather(*tasks)
    return managers

View File

@@ -427,6 +427,82 @@ class BaseTracer(BaseCallbackHandler, ABC):
self._end_trace(retrieval_run)
self._on_retriever_end(retrieval_run)
def on_embedding_start(
    self,
    serialized: Dict[str, Any],
    texts: List[str],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    metadata: Optional[Dict[str, Any]] = None,
    tags: Optional[List[str]] = None,
    **kwargs: Any,
) -> None:
    """Create and start tracing a Run for an embeddings invocation.

    Args:
        serialized (Dict[str, Any]): The serialized embeddings model.
        texts (List[str]): The list of texts.

    Returns:
        None
    """
    parent_id_str = str(parent_run_id) if parent_run_id else None
    order = self._get_execution_order(parent_id_str)
    started_at = datetime.utcnow()
    if metadata:
        # Metadata rides along with the run inside the `extra` payload.
        kwargs["metadata"] = metadata
    run = Run(
        id=run_id,
        name="Embeddings",  # TODO: Derive from serialized model
        parent_run_id=parent_run_id,
        serialized=serialized,
        inputs={"texts": texts},
        extra=kwargs,
        events=[{"name": "start", "time": started_at}],
        start_time=started_at,
        execution_order=order,
        child_execution_order=order,
        child_runs=[],
        tags=tags,
        run_type="embedding",
    )
    self._start_trace(run)
    self._on_embedding_start(run)
def on_embedding_error(
    self,
    error: Union[Exception, KeyboardInterrupt],
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Mark the traced embeddings Run as failed and close it out."""
    if not run_id:
        raise TracerException("No run_id provided for on_embedding_error callback.")
    run = self.run_map.get(str(run_id))
    if run is None or run.run_type != "embedding":
        raise TracerException("No embeddings Run found to be traced")
    run.error = repr(error)
    run.end_time = datetime.utcnow()
    run.events.append({"name": "error", "time": run.end_time})
    self._end_trace(run)
    self._on_embedding_error(run)
def on_embedding_end(
    self, vector: List[float], *, run_id: UUID, **kwargs: Any
) -> None:
    """Record the produced vector and close out the traced embeddings Run."""
    if not run_id:
        raise TracerException("No run_id provided for on_embedding_end callback.")
    run = self.run_map.get(str(run_id))
    if run is None or run.run_type != "embedding":
        raise TracerException("No embeddings Run found to be traced")
    run.outputs = {"vector": vector}
    run.end_time = datetime.utcnow()
    run.events.append({"name": "end", "time": run.end_time})
    self._end_trace(run)
    self._on_embedding_end(run)
def __deepcopy__(self, memo: dict) -> BaseTracer:
"""Deepcopy the tracer."""
# Intentionally returns self: the tracer instance is shared, never copied.
return self
@@ -473,3 +549,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
def _on_retriever_error(self, run: Run) -> None:
"""Process the Retriever Run upon error."""
# No-op subclass hooks: concrete tracers override these to persist/update runs.
def _on_embedding_start(self, run: Run) -> None:
"""Process the Embeddings Run upon start."""
def _on_embedding_end(self, run: Run) -> None:
"""Process the Embeddings Run."""
def _on_embedding_error(self, run: Run) -> None:
"""Process the Embeddings Run upon error."""

View File

@@ -216,6 +216,20 @@ class LangChainTracer(BaseTracer):
"""Process the Retriever Run upon error."""
self._submit(self._update_run_single, run.copy(deep=True))
def _on_embedding_start(self, run: Run) -> None:
"""Process the Embeddings Run upon start."""
# Root runs (no parent) are associated with the tracer's reference example.
if run.parent_run_id is None:
run.reference_example_id = self.example_id
# Deep-copy before handing off; _submit appears to run work asynchronously
# (see wait_for_futures), so later mutations of `run` must not leak in.
self._submit(self._persist_run_single, run.copy(deep=True))
def _on_embedding_end(self, run: Run) -> None:
"""Process the Embeddings Run."""
# End events update the previously persisted run rather than creating one.
self._submit(self._update_run_single, run.copy(deep=True))
def _on_embedding_error(self, run: Run) -> None:
"""Process the Embeddings Run upon error."""
self._submit(self._update_run_single, run.copy(deep=True))
def wait_for_futures(self) -> None:
"""Wait for the given futures to complete."""
futures = list(self._futures)

View File

@@ -4,12 +4,15 @@ https://arxiv.org/abs/2212.10496
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence
import numpy as np
from pydantic import Extra
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.callbacks.manager import (
CallbackManagerForChainRun,
CallbackManagerForEmbeddingsRun,
)
from langchain.chains.base import Chain
from langchain.chains.hyde.prompts import PROMPT_MAP
from langchain.chains.llm import LLMChain
@@ -42,7 +45,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
"""Output keys for Hyde's LLM chain."""
return self.llm_chain.output_keys
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Call the base embeddings."""
return self.base_embeddings.embed_documents(texts)
@@ -50,7 +58,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
"""Combine embeddings into final embeddings."""
return list(np.array(embeddings).mean(axis=0))
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Generate a hypothetical document and embedded it."""
var_name = self.llm_chain.input_keys[0]
result = self.llm_chain.generate([{var_name: text}])

View File

@@ -1,7 +1,10 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence
from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -106,7 +109,12 @@ class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, Embeddings):
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Call out to Aleph Alpha's asymmetric Document endpoint.
Args:
@@ -147,7 +155,13 @@ class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, Embeddings):
return document_embeddings
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed query text."""
"""Call out to Aleph Alpha's asymmetric, query embedding endpoint
Args:
text: The text to embed.
@@ -230,7 +244,12 @@ class AlephAlphaSymmetricSemanticEmbedding(AlephAlphaAsymmetricSemanticEmbedding
return query_response.embedding
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Call out to Aleph Alpha's Document endpoint.
Args:
@@ -245,7 +264,13 @@ class AlephAlphaSymmetricSemanticEmbedding(AlephAlphaAsymmetricSemanticEmbedding
document_embeddings.append(self._embed(text))
return document_embeddings
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed query text."""
"""Call out to Aleph Alpha's asymmetric, query embedding endpoint
Args:
text: The text to embed.

View File

@@ -1,7 +1,8 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Sequence
from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
from langchain.embeddings.base import Embeddings
@@ -33,7 +34,12 @@ class AwaEmbeddings(BaseModel, Embeddings):
self.model = model_name
self.client.model_name = model_name
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Embed a list of documents using AwaEmbedding.
Args:
@@ -44,7 +50,12 @@ class AwaEmbeddings(BaseModel, Embeddings):
"""
return self.client.EmbeddingBatch(texts)
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Compute query embeddings using AwaEmbedding.
Args:

View File

@@ -1,22 +1,228 @@
"""Interface for embedding models."""
import asyncio
import warnings
from abc import ABC, abstractmethod
from typing import List
from inspect import signature
from typing import Any, Dict, List, Optional, Sequence
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForEmbeddingsRun,
CallbackManager,
CallbackManagerForEmbeddingsRun,
Callbacks,
)
class Embeddings(ABC):
    """Interface for embedding models.

    The public entry points (``embed_documents``, ``embed_query`` and their
    async variants) configure callback managers and dispatch embedding
    start/end/error events; implementations override the underscore-prefixed
    ``_embed_*`` / ``_aembed_*`` methods.
    """

    # True when the subclass's `_embed_documents` accepts `run_managers`
    # (i.e. it was written against the new callback-aware interface).
    _new_arg_supported: bool = False
    # Same flag for `_embed_query` and its `run_manager` keyword. Tracked
    # separately: a subclass may have migrated one method but not the other.
    _query_arg_supported: bool = False

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__()
        # Back-compat shim: legacy subclasses overrode the public methods.
        # Move those overrides into the private slots so the dispatching
        # public methods defined here stay in place. Without the query swap,
        # a legacy subclass overriding `embed_query` would leave the abstract
        # `_embed_query` unimplemented and the class uninstantiable.
        if cls.embed_documents != Embeddings.embed_documents:
            warnings.warn(
                "Embedding models must implement abstract `_embed_documents` method"
                " instead of `embed_documents`",
                DeprecationWarning,
            )
            swap = cls.embed_documents
            cls.embed_documents = Embeddings.embed_documents  # type: ignore[assignment]
            cls._embed_documents = swap  # type: ignore[assignment]
        if cls.embed_query != Embeddings.embed_query:
            warnings.warn(
                "Embedding models must implement abstract `_embed_query` method"
                " instead of `embed_query`",
                DeprecationWarning,
            )
            qswap = cls.embed_query
            cls.embed_query = Embeddings.embed_query  # type: ignore[assignment]
            cls._embed_query = qswap  # type: ignore[assignment]
        if (
            hasattr(cls, "aembed_documents")
            and cls.aembed_documents != Embeddings.aembed_documents
        ):
            warnings.warn(
                "Embedding models must implement abstract `_aembed_documents` method"
                " instead of `aembed_documents`",
                DeprecationWarning,
            )
            aswap = cls.aembed_documents
            cls.aembed_documents = (  # type: ignore[assignment]
                Embeddings.aembed_documents
            )
            cls._aembed_documents = aswap  # type: ignore[assignment]
        if (
            hasattr(cls, "aembed_query")
            and cls.aembed_query != Embeddings.aembed_query
        ):
            warnings.warn(
                "Embedding models must implement abstract `_aembed_query` method"
                " instead of `aembed_query`",
                DeprecationWarning,
            )
            aqswap = cls.aembed_query
            cls.aembed_query = Embeddings.aembed_query  # type: ignore[assignment]
            cls._aembed_query = aqswap  # type: ignore[assignment]
        parameters = signature(cls._embed_documents).parameters
        cls._new_arg_supported = parameters.get("run_managers") is not None
        query_parameters = signature(cls._embed_query).parameters
        cls._query_arg_supported = query_parameters.get("run_manager") is not None

    @abstractmethod
    def _embed_documents(
        self,
        texts: List[str],
        *,
        run_managers: Sequence[CallbackManagerForEmbeddingsRun],
    ) -> List[List[float]]:
        """Embed search docs."""

    @abstractmethod
    def _embed_query(
        self,
        text: str,
        *,
        run_manager: CallbackManagerForEmbeddingsRun,
    ) -> List[float]:
        """Embed query text."""

    async def _aembed_documents(
        self,
        texts: List[str],
        *,
        run_managers: Sequence[AsyncCallbackManagerForEmbeddingsRun],
    ) -> List[List[float]]:
        """Embed search docs."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support async")

    async def _aembed_query(
        self,
        text: str,
        *,
        run_manager: AsyncCallbackManagerForEmbeddingsRun,
    ) -> List[float]:
        """Embed query text."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support async")

    def embed_documents(
        self,
        texts: List[str],
        *,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> List[List[float]]:
        """Embed search docs, dispatching embedding callbacks.

        Args:
            texts (List[str]): The texts to embed.
            callbacks (Callbacks): Callback handlers or manager.
            tags (Optional[List[str]]): Tags attached to the runs.
            metadata (Optional[Dict[str, Any]]): Metadata attached to the runs.

        Returns:
            List[List[float]]: One embedding vector per input text.
        """
        callback_manager = CallbackManager.configure(
            callbacks, None, inheritable_tags=tags, inheritable_metadata=metadata
        )
        # One run manager per text: each text is traced as its own run.
        run_managers: List[
            CallbackManagerForEmbeddingsRun
        ] = callback_manager.on_embedding_start(
            {},  # TODO: make embeddings serializable
            texts,
        )
        try:
            if self._new_arg_supported:
                result = self._embed_documents(
                    texts,
                    run_managers=run_managers,
                )
            else:
                result = self._embed_documents(texts)  # type: ignore[call-arg]
        except Exception as e:
            for run_manager in run_managers:
                run_manager.on_embedding_error(e)
            raise e
        else:
            for single_result, run_manager in zip(result, run_managers):
                run_manager.on_embedding_end(single_result)
        return result

    def embed_query(
        self,
        text: str,
        *,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> List[float]:
        """Embed query text, dispatching embedding callbacks.

        Args:
            text (str): The query text to embed.
            callbacks (Callbacks): Callback handlers or manager.
            tags (Optional[List[str]]): Tags attached to the run.
            metadata (Optional[Dict[str, Any]]): Metadata attached to the run.

        Returns:
            List[float]: The embedding vector for the query.
        """
        callback_manager = CallbackManager.configure(
            callbacks, None, inheritable_tags=tags, inheritable_metadata=metadata
        )
        run_managers: List[
            CallbackManagerForEmbeddingsRun
        ] = callback_manager.on_embedding_start(
            {},  # TODO: make embeddings serializable
            [text],
        )
        try:
            # Dispatch on the query-specific flag: `_embed_query` may have a
            # different (legacy) signature than `_embed_documents`.
            if self._query_arg_supported:
                result = self._embed_query(text, run_manager=run_managers[0])
            else:
                result = self._embed_query(text)  # type: ignore[call-arg]
        except Exception as e:
            run_managers[0].on_embedding_error(e)
            raise e
        else:
            run_managers[0].on_embedding_end(result)
        return result

    async def aembed_documents(
        self,
        texts: List[str],
        *,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> List[List[float]]:
        """Asynchronously embed search docs, dispatching embedding callbacks.

        Args:
            texts (List[str]): The texts to embed.
            callbacks (Callbacks): Callback handlers or manager.
            tags (Optional[List[str]]): Tags attached to the runs.
            metadata (Optional[Dict[str, Any]]): Metadata attached to the runs.

        Returns:
            List[List[float]]: One embedding vector per input text.
        """
        callback_manager = AsyncCallbackManager.configure(
            callbacks, None, inheritable_tags=tags, inheritable_metadata=metadata
        )
        run_managers: List[
            AsyncCallbackManagerForEmbeddingsRun
        ] = await callback_manager.on_embedding_start(
            {},  # TODO: make embeddings serializable
            texts,
        )
        try:
            if self._new_arg_supported:
                result = await self._aembed_documents(
                    texts,
                    run_managers=run_managers,
                )
            else:
                result = await self._aembed_documents(texts)  # type: ignore[call-arg]
        except Exception as e:
            tasks = [run_manager.on_embedding_error(e) for run_manager in run_managers]
            await asyncio.gather(*tasks)
            raise e
        else:
            tasks = [
                run_manager.on_embedding_end(single_result)
                for run_manager, single_result in zip(run_managers, result)
            ]
            await asyncio.gather(*tasks)
        return result

    async def aembed_query(
        self,
        text: str,
        *,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> List[float]:
        """Asynchronously embed query text, dispatching embedding callbacks.

        Args:
            text (str): The query text to embed.
            callbacks (Callbacks): Callback handlers or manager.
            tags (Optional[List[str]]): Tags attached to the run.
            metadata (Optional[Dict[str, Any]]): Metadata attached to the run.

        Returns:
            List[float]: The embedding vector for the query.
        """
        callback_manager = AsyncCallbackManager.configure(
            callbacks, None, inheritable_tags=tags, inheritable_metadata=metadata
        )
        run_managers: List[
            AsyncCallbackManagerForEmbeddingsRun
        ] = await callback_manager.on_embedding_start(
            {},  # TODO: make embeddings serializable
            [text],
        )
        try:
            # Same query-specific dispatch as the sync path.
            if self._query_arg_supported:
                result = await self._aembed_query(
                    text,
                    run_manager=run_managers[0],
                )
            else:
                result = await self._aembed_query(text)  # type: ignore[call-arg]
        except Exception as e:
            await run_managers[0].on_embedding_error(e)
            raise e
        else:
            await run_managers[0].on_embedding_end(result)
        return result

View File

@@ -1,9 +1,10 @@
import json
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
from langchain.embeddings.base import Embeddings
@@ -128,8 +129,11 @@ class BedrockEmbeddings(BaseModel, Embeddings):
except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}")
def embed_documents(
self, texts: List[str], chunk_size: int = 1
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Compute doc embeddings using a Bedrock model.
@@ -149,7 +153,13 @@ class BedrockEmbeddings(BaseModel, Embeddings):
results.append(response)
return results
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed query text."""
"""Compute query embeddings using a Bedrock model.
Args:

View File

@@ -1,8 +1,11 @@
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -82,7 +85,12 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Call out to Clarifai's embedding models.
Args:
@@ -137,7 +145,12 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
]
return embeddings
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Call out to Clarifai's embedding models.
Args:

View File

@@ -1,7 +1,12 @@
from typing import Any, Dict, List, Optional
"""Wrapper around Cohere embedding models."""
from typing import Any, Dict, List, Optional, Sequence
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForEmbeddingsRun,
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -57,7 +62,12 @@ class CohereEmbeddings(BaseModel, Embeddings):
)
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Call out to Cohere's embedding endpoint.
Args:
@@ -71,7 +81,12 @@ class CohereEmbeddings(BaseModel, Embeddings):
).embeddings
return [list(map(float, e)) for e in embeddings]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
async def _aembed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[AsyncCallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Async call out to Cohere's embedding endpoint.
Args:
@@ -85,7 +100,12 @@ class CohereEmbeddings(BaseModel, Embeddings):
)
return [list(map(float, e)) for e in embeddings.embeddings]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Call out to Cohere's embedding endpoint.
Args:
@@ -94,9 +114,11 @@ class CohereEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
return self._embed_documents([text], run_managers=[run_manager])[0]
async def aembed_query(self, text: str) -> List[float]:
async def _aembed_query(
self, text: str, *, run_manager: AsyncCallbackManagerForEmbeddingsRun
) -> List[float]:
"""Async call out to Cohere's embedding endpoint.
Args:
@@ -105,5 +127,5 @@ class CohereEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
embeddings = await self.aembed_documents([text])
embeddings = await self._aembed_documents([text], run_managers=[run_manager])
return embeddings[0]

View File

@@ -7,6 +7,7 @@ from typing import (
Dict,
List,
Optional,
Sequence,
)
from pydantic import BaseModel, Extra, root_validator
@@ -19,6 +20,9 @@ from tenacity import (
wait_exponential,
)
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -123,7 +127,12 @@ class DashScopeEmbeddings(BaseModel, Embeddings):
)
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Call out to DashScope's embedding endpoint for embedding search docs.
Args:
@@ -140,7 +149,12 @@ class DashScopeEmbeddings(BaseModel, Embeddings):
embedding_list = [item["embedding"] for item in embeddings]
return embedding_list
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Call out to DashScope's embedding endpoint for embedding query text.
Args:

View File

@@ -1,8 +1,11 @@
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Dict, List, Mapping, Optional, Sequence
import requests
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -102,7 +105,12 @@ class DeepInfraEmbeddings(BaseModel, Embeddings):
return embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Embed documents using a Deep Infra deployed embedding model.
Args:
@@ -115,7 +123,12 @@ class DeepInfraEmbeddings(BaseModel, Embeddings):
embeddings = self._embed(instruction_pairs)
return embeddings
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed a query using a Deep Infra deployed embedding model.
Args:

View File

@@ -1,7 +1,8 @@
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Sequence
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
from langchain.embeddings.base import Embeddings
from langchain.requests import Requests
from langchain.utils import get_from_dict_or_env
@@ -64,7 +65,12 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
return embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Embed a list of documents using EdenAI.
Args:
@@ -76,7 +82,12 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
return self._generate_embeddings(texts)
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed a query using EdenAI.
Args:

View File

@@ -1,7 +1,10 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, List, Optional, Sequence
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.utils import get_from_env
if TYPE_CHECKING:
@@ -195,7 +198,12 @@ class ElasticsearchEmbeddings(Embeddings):
embeddings = [doc["predicted_value"] for doc in response["inference_results"]]
return embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""
Generate embeddings for a list of documents.
@@ -209,7 +217,13 @@ class ElasticsearchEmbeddings(Embeddings):
"""
return self._embedding_func(texts)
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed query text."""
"""
Generate an embedding for a single query text.

View File

@@ -1,9 +1,13 @@
from typing import Any, Dict, List, Mapping, Optional
"""Wrapper around embaas embeddings API."""
from typing import Any, Dict, List, Mapping, Optional, Sequence
import requests
from pydantic import BaseModel, Extra, root_validator
from typing_extensions import NotRequired, TypedDict
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -110,7 +114,12 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
)
raise
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Get embeddings for a list of texts.
Args:
@@ -126,7 +135,13 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
# flatten the list of lists into a single list
return [embedding for batch in embeddings for embedding in batch]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed query text."""
"""Get embeddings for a single text.
Args:

View File

@@ -1,9 +1,10 @@
import hashlib
from typing import List
from typing import List, Sequence
import numpy as np
from pydantic import BaseModel
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
from langchain.embeddings.base import Embeddings
@@ -16,10 +17,21 @@ class FakeEmbeddings(Embeddings, BaseModel):
def _get_embedding(self) -> List[float]:
return list(np.random.normal(size=self.size))
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
return [self._get_embedding() for _ in texts]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed query text."""
return self._get_embedding()
@@ -43,8 +55,15 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
"""
return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
return [self._get_embedding(seed=self._get_seed(_)) for _ in texts]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self, text: str, *, run_manager: CallbackManagerForEmbeddingsRun
) -> List[float]:
return self._get_embedding(seed=self._get_seed(text))

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import logging
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Sequence
from pydantic import BaseModel, root_validator
from tenacity import (
@@ -12,6 +12,9 @@ from tenacity import (
wait_exponential,
)
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -78,10 +81,21 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Embed search docs."""
return [self.embed_query(text) for text in texts]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed query text."""
embedding = embed_with_retry(self, self.model_name, text)
return embedding["embedding"]

View File

@@ -1,7 +1,11 @@
from typing import Any, Dict, List
"""Wrapper around GPT4All embedding models."""
from typing import Any, Dict, List, Sequence
from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
@@ -36,7 +40,12 @@ class GPT4AllEmbeddings(BaseModel, Embeddings):
)
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Embed a list of documents using GPT4All.
Args:
@@ -49,7 +58,12 @@ class GPT4AllEmbeddings(BaseModel, Embeddings):
embeddings = [self.client.embed(text) for text in texts]
return [list(map(float, e)) for e in embeddings]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed a query using GPT4All.
Args:

View File

@@ -1,7 +1,11 @@
from typing import Any, Dict, List, Optional
"""Wrapper around HuggingFace embedding models."""
from typing import Any, Dict, List, Optional, Sequence
from pydantic import BaseModel, Extra, Field
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
@@ -64,7 +68,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
Args:
@@ -77,7 +86,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
embeddings = self.client.encode(texts, **self.encode_kwargs)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model.
Args:
@@ -144,7 +158,12 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace instruct model.
Args:
@@ -157,7 +176,12 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
embeddings = self.client.encode(instruction_pairs, **self.encode_kwargs)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Compute query embeddings using a HuggingFace instruct model.
Args:

View File

@@ -1,7 +1,11 @@
from typing import Any, Dict, List, Optional
"""Wrapper around HuggingFace Hub embedding models."""
from typing import Any, Dict, List, Optional, Sequence
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -76,7 +80,12 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
)
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Call out to HuggingFaceHub's embedding endpoint for embedding search docs.
Args:
@@ -91,7 +100,12 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
responses = self.client(inputs=texts, params=_model_kwargs)
return responses
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Call out to HuggingFaceHub's embedding endpoint for embedding query text.
Args:

View File

@@ -1,9 +1,12 @@
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence
import requests
from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -71,7 +74,12 @@ class JinaEmbeddings(BaseModel, Embeddings):
payload = dict(inputs=docs, metadata=self.request_headers, **kwargs)
return self.client.post(on="/encode", **payload)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Call out to Jina's embedding endpoint.
Args:
texts: The list of texts to embed.
@@ -85,7 +93,12 @@ class JinaEmbeddings(BaseModel, Embeddings):
).embeddings
return [list(map(float, e)) for e in embeddings]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Call out to Jina's embedding endpoint.
Args:
text: The text to embed.

View File

@@ -1,7 +1,11 @@
from typing import Any, Dict, List, Optional
"""Wrapper around llama.cpp embedding models."""
from typing import Any, Dict, List, Optional, Sequence
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
@@ -98,7 +102,12 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Embed a list of documents using the Llama model.
Args:
@@ -110,7 +119,12 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
embeddings = [self.client.embed(text) for text in texts]
return [list(map(float, e)) for e in embeddings]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed a query using the Llama model.
Args:

View File

@@ -25,6 +25,10 @@ from tenacity import (
wait_exponential,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForEmbeddingsRun,
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
@@ -285,8 +289,11 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
)
)["data"][0]["embedding"]
def embed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Call out to LocalAI's embedding endpoint for embedding search docs.
@@ -301,8 +308,11 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
# call _embedding_func for each text
return [self._embedding_func(text, engine=self.deployment) for text in texts]
async def aembed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
async def _aembed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[AsyncCallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Call out to LocalAI's embedding endpoint async for embedding search docs.
@@ -320,7 +330,9 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
embeddings.append(response)
return embeddings
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self, text: str, *, run_manager: CallbackManagerForEmbeddingsRun
) -> List[float]:
"""Call out to LocalAI's embedding endpoint for embedding query text.
Args:
@@ -332,7 +344,9 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
embedding = self._embedding_func(text, engine=self.deployment)
return embedding
async def aembed_query(self, text: str) -> List[float]:
async def _aembed_query(
self, text: str, *, run_manager: AsyncCallbackManagerForEmbeddingsRun
) -> List[float]:
"""Call out to LocalAI's embedding endpoint async for embedding query text.
Args:

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import logging
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Sequence
import requests
from pydantic import BaseModel, Extra, root_validator
@@ -12,6 +12,9 @@ from tenacity import (
wait_exponential,
)
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -135,7 +138,12 @@ class MiniMaxEmbeddings(BaseModel, Embeddings):
return embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Embed documents using a MiniMax embedding endpoint.
Args:
@@ -147,7 +155,12 @@ class MiniMaxEmbeddings(BaseModel, Embeddings):
embeddings = embed_with_retry(self, texts=texts, embed_type=self.embed_type_db)
return embeddings
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed a query using a MiniMax embedding endpoint.
Args:

View File

@@ -1,9 +1,10 @@
from __future__ import annotations
from typing import Any, Iterator, List, Optional
from typing import Any, Iterator, List, Optional, Sequence
from pydantic import BaseModel
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
from langchain.embeddings.base import Embeddings
@@ -63,8 +64,18 @@ class MlflowAIGatewayEmbeddings(Embeddings, BaseModel):
embeddings.append(resp["embeddings"])
return embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
return self._query(texts)
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
return self._query([text])[0]

View File

@@ -1,7 +1,9 @@
from typing import Any, List, Optional
"""Wrapper around ModelScopeHub embedding models."""
from typing import Any, List, Optional, Sequence
from pydantic import BaseModel, Extra
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
from langchain.embeddings.base import Embeddings
@@ -45,7 +47,12 @@ class ModelScopeEmbeddings(BaseModel, Embeddings):
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Compute doc embeddings using a modelscope embedding model.
Args:
@@ -59,7 +66,12 @@ class ModelScopeEmbeddings(BaseModel, Embeddings):
embeddings = self.embed(input=inputs)["text_embedding"]
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Compute query embeddings using a modelscope embedding model.
Args:

View File

@@ -1,8 +1,12 @@
from typing import Any, Dict, List, Mapping, Optional, Tuple
"""Wrapper around MosaicML APIs."""
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
import requests
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -136,7 +140,12 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
return embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Embed documents using a MosaicML deployed instructor embedding model.
Args:
@@ -149,7 +158,12 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
embeddings = self._embed(instruction_pairs)
return embeddings
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed a query using a MosaicML deployed instructor embedding model.
Args:

View File

@@ -1,7 +1,8 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Sequence
from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -27,7 +28,7 @@ class NLPCloudEmbeddings(BaseModel, Embeddings):
self,
model_name: str = "paraphrase-multilingual-mpnet-base-v2",
gpu: bool = False,
**kwargs: Any
**kwargs: Any,
) -> None:
super().__init__(model_name=model_name, gpu=gpu, **kwargs)
@@ -50,7 +51,12 @@ class NLPCloudEmbeddings(BaseModel, Embeddings):
)
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Embed a list of documents using NLP Cloud.
Args:
@@ -62,7 +68,12 @@ class NLPCloudEmbeddings(BaseModel, Embeddings):
return self.client.embeddings(texts)["embeddings"]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed a query using NLP Cloud.
Args:

View File

@@ -1,7 +1,12 @@
from typing import Any, Dict, List, Mapping, Optional
"""Module providing a wrapper around OctoAI Compute Service embedding models."""
from typing import Any, Dict, List, Mapping, Optional, Sequence
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@@ -80,12 +85,22 @@ class OctoAIEmbeddings(BaseModel, Embeddings):
return embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Compute document embeddings using an OctoAI instruct model."""
texts = list(map(lambda x: x.replace("\n", " "), texts))
return self._compute_embeddings(texts, self.embed_instruction)
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Compute query embedding using an OctoAI instruct model."""
text = text.replace("\n", " ")
return self._compute_embeddings([text], self.embed_instruction)[0]

View File

@@ -26,6 +26,10 @@ from tenacity import (
wait_exponential,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForEmbeddingsRun,
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
@@ -124,29 +128,23 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
To use, you should have the ``openai`` python package installed, and the
environment variable ``OPENAI_API_KEY`` set with your API key or pass it
as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain.embeddings import OpenAIEmbeddings
openai = OpenAIEmbeddings(openai_api_key="my-api-key")
In order to use the library with Microsoft Azure endpoints, you need to set
the OPENAI_API_TYPE, OPENAI_API_BASE, OPENAI_API_KEY and OPENAI_API_VERSION.
The OPENAI_API_TYPE must be set to 'azure' and the others correspond to
the properties of your endpoint.
In addition, the deployment name must be passed as the model parameter.
Example:
.. code-block:: python
import os
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["OPENAI_API_BASE"] = "https://<your-endpoint.openai.azure.com/"
os.environ["OPENAI_API_KEY"] = "your AzureOpenAI key"
os.environ["OPENAI_API_VERSION"] = "2023-05-15"
os.environ["OPENAI_PROXY"] = "http://your-corporate-proxy:8080"
from langchain.embeddings.openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings(
deployment="your-embeddings-deployment-name",
@@ -156,7 +154,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
)
text = "This is a test query."
query_result = embeddings.embed_query(text)
"""
client: Any #: :meta private:
@@ -454,16 +451,51 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
return embeddings
def embed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint."""
# handle large input text
if len(text) > self.embedding_ctx_length:
return self._get_len_safe_embeddings([text], engine=engine)[0]
else:
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return embed_with_retry(
self,
input=[text],
**self._invocation_params,
)[
"data"
][0]["embedding"]
async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint."""
# handle large input text
if len(text) > self.embedding_ctx_length:
return (await self._aget_len_safe_embeddings([text], engine=engine))[0]
else:
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (
await async_embed_with_retry(
self,
input=[text],
**self._invocation_params,
)
)["data"][0]["embedding"]
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Call out to OpenAI's embedding endpoint for embedding search docs.
Args:
texts: The list of texts to embed.
chunk_size: The chunk size of embeddings. If None, will use the chunk size
specified by the class.
Returns:
List of embeddings, one for each text.
"""
@@ -471,16 +503,15 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# than the maximum context and use length-safe embedding function.
return self._get_len_safe_embeddings(texts, engine=self.deployment)
async def aembed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
async def _aembed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[AsyncCallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Call out to OpenAI's embedding endpoint async for embedding search docs.
Args:
texts: The list of texts to embed.
chunk_size: The chunk size of embeddings. If None, will use the chunk size
specified by the class.
Returns:
List of embeddings, one for each text.
"""
@@ -488,23 +519,29 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# than the maximum context and use length-safe embedding function.
return await self._aget_len_safe_embeddings(texts, engine=self.deployment)
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""
return self.embed_documents([text])[0]
async def aembed_query(self, text: str) -> List[float]:
async def _aembed_query(
self,
text: str,
*,
run_manager: AsyncCallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Call out to OpenAI's embedding endpoint async for embedding query text.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""

View File

@@ -1,7 +1,11 @@
from typing import Any, Dict, List, Optional
"""Wrapper around Sagemaker InvokeEndpoint API."""
from typing import Any, Dict, List, Optional, Sequence
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
@@ -98,6 +102,8 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
function. See `boto3`_. docs for more info.
.. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
"""
chunk_size: int = 64
"""The number of documents to send to the endpoint at a time."""
class Config:
"""Configuration for this pydantic object."""
@@ -163,8 +169,11 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
return self.content_handler.transform_output(response["Body"])
def embed_documents(
self, texts: List[str], chunk_size: int = 64
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Compute doc embeddings using a SageMaker Inference Endpoint.
@@ -179,13 +188,19 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
List of embeddings, one for each text.
"""
results = []
_chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
_chunk_size = len(texts) if self.chunk_size > len(texts) else self.chunk_size
for i in range(0, len(texts), _chunk_size):
response = self._embedding_func(texts[i : i + _chunk_size])
results.extend(response)
return results
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed query text."""
"""Compute query embeddings using a SageMaker inference endpoint.
Args:

View File

@@ -1,7 +1,11 @@
from typing import Any, Callable, List
"""Running custom embedding models on self-hosted remote hardware."""
from typing import Any, Callable, List, Sequence
from pydantic import Extra
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
from langchain.llms import SelfHostedPipeline
@@ -71,7 +75,12 @@ class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings):
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
Args:
@@ -86,7 +95,13 @@ class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings):
return embeddings.tolist()
return embeddings
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed query text."""
"""Compute query embeddings using a HuggingFace transformer model.
Args:

View File

@@ -1,7 +1,10 @@
import importlib
import logging
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Sequence
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.self_hosted import SelfHostedEmbeddings
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
@@ -139,7 +142,12 @@ class SelfHostedHuggingFaceInstructEmbeddings(SelfHostedHuggingFaceEmbeddings):
load_fn_kwargs["device"] = load_fn_kwargs.get("device", 0)
super().__init__(load_fn_kwargs=load_fn_kwargs, **kwargs)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace instruct model.
Args:
@@ -154,7 +162,13 @@ class SelfHostedHuggingFaceInstructEmbeddings(SelfHostedHuggingFaceEmbeddings):
embeddings = self.client(self.pipeline_ref, instruction_pairs)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed query text."""
"""Compute query embeddings using a HuggingFace instruct model.
Args:

View File

@@ -1,8 +1,12 @@
import importlib.util
from typing import Any, Dict, List
from typing import Any, Dict, List, Sequence
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForEmbeddingsRun,
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
@@ -63,7 +67,12 @@ class SpacyEmbeddings(BaseModel, Embeddings):
)
return values # Return the validated values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""
Generates embeddings for a list of documents.
@@ -75,7 +84,13 @@ class SpacyEmbeddings(BaseModel, Embeddings):
"""
return [self.nlp(text).vector.tolist() for text in texts]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed query text."""
"""
Generates an embedding for a single piece of text.
@@ -87,7 +102,12 @@ class SpacyEmbeddings(BaseModel, Embeddings):
"""
return self.nlp(text).vector.tolist()
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
async def _aembed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[AsyncCallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""
Asynchronously generates embeddings for a list of documents.
This method is not implemented and raises a NotImplementedError.
@@ -100,7 +120,12 @@ class SpacyEmbeddings(BaseModel, Embeddings):
"""
raise NotImplementedError("Asynchronous embedding generation is not supported.")
async def aembed_query(self, text: str) -> List[float]:
async def _aembed_query(
self,
text: str,
*,
run_manager: AsyncCallbackManagerForEmbeddingsRun,
) -> List[float]:
"""
Asynchronously generates an embedding for a single piece of text.
This method is not implemented and raises a NotImplementedError.

View File

@@ -1,7 +1,11 @@
from typing import Any, List
"""Wrapper around TensorflowHub embedding models."""
from typing import Any, List, Sequence
from pydantic import BaseModel, Extra
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.embeddings.base import Embeddings
DEFAULT_MODEL_URL = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
@@ -49,7 +53,12 @@ class TensorflowHubEmbeddings(BaseModel, Embeddings):
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Compute doc embeddings using a TensorflowHub embedding model.
Args:
@@ -62,7 +71,12 @@ class TensorflowHubEmbeddings(BaseModel, Embeddings):
embeddings = self.embed(texts).numpy()
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Compute query embeddings using a TensorflowHub embedding model.
Args:

View File

@@ -1,7 +1,9 @@
from typing import Dict, List
"""Wrapper around Google VertexAI embedding models."""
from typing import Dict, List, Sequence
from pydantic import root_validator
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
from langchain.embeddings.base import Embeddings
from langchain.llms.vertexai import _VertexAICommon
from langchain.utilities.vertexai import raise_vertex_import_error
@@ -11,6 +13,8 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
"""Google Cloud VertexAI embedding models."""
model_name: str = "textembedding-gecko"
batch_size: int = 5
"""Number of texts to embed in a single API call."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@@ -23,8 +27,11 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
return values
def embed_documents(
self, texts: List[str], batch_size: int = 5
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Embed a list of strings. Vertex AI currently
sets a max batch size of 5 strings.
@@ -37,13 +44,18 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
List of embeddings, one for each text.
"""
embeddings = []
for batch in range(0, len(texts), batch_size):
text_batch = texts[batch : batch + batch_size]
for batch in range(0, len(texts), self.batch_size):
text_batch = texts[batch : batch + self.batch_size]
embeddings_batch = self.client.get_embeddings(text_batch)
embeddings.extend([el.values for el in embeddings_batch])
return embeddings
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed a text.
Args:

View File

@@ -1,6 +1,7 @@
"""Wrapper around Xinference embedding models."""
from typing import Any, List, Optional
from typing import Any, List, Optional, Sequence
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
from langchain.embeddings.base import Embeddings
@@ -81,7 +82,12 @@ class XinferenceEmbeddings(Embeddings):
self.client = RESTfulClient(server_url)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Embed a list of documents using Xinference.
Args:
texts: The list of texts to embed.
@@ -96,7 +102,12 @@ class XinferenceEmbeddings(Embeddings):
]
return [list(map(float, e)) for e in embeddings]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Embed a query of documents using Xinference.
Args:
text: The text to embed.

View File

@@ -239,8 +239,12 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
Returns:
Dict[str, Any]: The computed score.
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
vectors = np.array(
self.embeddings.embed_documents([inputs["prediction"], inputs["reference"]])
self.embeddings.embed_documents(
[inputs["prediction"], inputs["reference"]],
callbacks=_run_manager.get_child(),
)
)
score = self._compute_score(vectors)
return {"score": score}
@@ -260,8 +264,10 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
Returns:
Dict[str, Any]: The computed score.
"""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
embedded = await self.embeddings.aembed_documents(
[inputs["prediction"], inputs["reference"]]
[inputs["prediction"], inputs["reference"]],
callbacks=_run_manager.get_child(),
)
vectors = np.array(embedded)
score = self._compute_score(vectors)
@@ -376,9 +382,11 @@ class PairwiseEmbeddingDistanceEvalChain(
Returns:
Dict[str, Any]: The computed score.
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
vectors = np.array(
self.embeddings.embed_documents(
[inputs["prediction"], inputs["prediction_b"]]
[inputs["prediction"], inputs["prediction_b"]],
callbacks=_run_manager.get_child(),
)
)
score = self._compute_score(vectors)
@@ -399,8 +407,10 @@ class PairwiseEmbeddingDistanceEvalChain(
Returns:
Dict[str, Any]: The computed score.
"""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
embedded = await self.embeddings.aembed_documents(
[inputs["prediction"], inputs["prediction_b"]]
[inputs["prediction"], inputs["prediction_b"]],
callbacks=_run_manager.get_child(),
)
vectors = np.array(embedded)
score = self._compute_score(vectors)

View File

@@ -1,7 +1,8 @@
"""Fake Embedding class for testing purposes."""
import math
from typing import List
from typing import List, Sequence
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
from langchain.embeddings.base import Embeddings
fake_texts = ["foo", "bar", "baz"]
@@ -10,12 +11,22 @@ fake_texts = ["foo", "bar", "baz"]
class FakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Return simple embeddings.
Embeddings encode each text as its index."""
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Return constant query embeddings.
Embeddings are identical to embed_documents(texts)[0].
Distance to each text will be that text's index,
@@ -31,7 +42,12 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
self.known_texts: List[str] = []
self.dimensionality = dimensionality
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Return consistent embeddings for each text seen so far."""
out_vectors = []
for text in texts:
@@ -43,7 +59,12 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
out_vectors.append(vector)
return out_vectors
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown."""
if text not in self.known_texts:
@@ -58,13 +79,23 @@ class AngularTwoDimensionalEmbeddings(Embeddings):
From angles (as strings in units of pi) to unit embedding vectors on a circle.
"""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""
Make a list of texts into a list of embedding vectors.
"""
return [self.embed_query(text) for text in texts]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""
Convert input text to a 'vector' (list of floats).
If the text is a number, use it as the angle for the

View File

@@ -1,5 +1,8 @@
from typing import List
from typing import List, Sequence
from langchain.callbacks.manager import (
CallbackManagerForEmbeddingsRun,
)
from langchain.schema import Document
from langchain.vectorstores.alibabacloud_opensearch import (
AlibabaCloudOpenSearch,
@@ -15,14 +18,23 @@ texts = ["foo", "bar", "baz"]
class FakeEmbeddingsWithOsDimension(FakeEmbeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, embedding_texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Return simple embeddings."""
return [
[float(1.0)] * (OS_TOKEN_COUNT - 1) + [float(i)]
for i in range(len(embedding_texts))
[float(1.0)] * (OS_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))
]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Return simple embeddings."""
return [float(1.0)] * (OS_TOKEN_COUNT - 1) + [float(texts.index(text))]

View File

@@ -1,7 +1,8 @@
"""Test PGVector functionality."""
import os
from typing import List
from typing import List, Sequence
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
from langchain.docstore.document import Document
from langchain.vectorstores.analyticdb import AnalyticDB
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
@@ -22,13 +23,23 @@ ADA_TOKEN_COUNT = 1536
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Return simple embeddings."""
return [
[float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))
]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Return simple embeddings."""
return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]

View File

@@ -1,7 +1,8 @@
"""Test Hologres functionality."""
import os
from typing import List
from typing import List, Sequence
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
from langchain.docstore.document import Document
from langchain.vectorstores.hologres import Hologres
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
@@ -21,13 +22,23 @@ ADA_TOKEN_COUNT = 1536
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Return simple embeddings."""
return [
[float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))
]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Return simple embeddings."""
return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]

View File

@@ -1,9 +1,10 @@
"""Test PGVector functionality."""
import os
from typing import List
from typing import List, Sequence
from sqlalchemy.orm import Session
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
from langchain.docstore.document import Document
from langchain.vectorstores.pgvector import PGVector
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
@@ -24,13 +25,23 @@ ADA_TOKEN_COUNT = 1536
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Return simple embeddings."""
return [
[float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts))
]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Return simple embeddings."""
return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]

View File

@@ -1,9 +1,10 @@
"""Test SingleStoreDB functionality."""
from typing import List
from typing import List, Sequence
import numpy as np
import pytest
from langchain.callbacks.manager import CallbackManagerForEmbeddingsRun
from langchain.docstore.document import Document
from langchain.vectorstores.singlestoredb import SingleStoreDB
from langchain.vectorstores.utils import DistanceStrategy
@@ -36,10 +37,20 @@ class NormilizedFakeEmbeddings(FakeEmbeddings):
"""Normalize vector."""
return [float(v / np.linalg.norm(vector)) for v in vector]
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
return [self.normalize(v) for v in super().embed_documents(texts)]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
return self.normalize(super().embed_query(text))

View File

@@ -1,10 +1,11 @@
"""Test HyDE."""
from typing import Any, List, Optional
from typing import Any, List, Optional, Sequence
import numpy as np
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForEmbeddingsRun,
CallbackManagerForLLMRun,
)
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
@@ -17,11 +18,21 @@ from langchain.schema import Generation, LLMResult
class FakeEmbeddings(Embeddings):
"""Fake embedding class for tests."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
) -> List[List[float]]:
"""Return random floats."""
return [list(np.random.uniform(0, 1, 10)) for _ in range(10)]
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self,
text: str,
*,
run_manager: CallbackManagerForEmbeddingsRun,
) -> List[float]:
"""Return random floats."""
return list(np.random.uniform(0, 1, 10))