mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-30 13:50:11 +00:00
Compare commits
1 Commits
langchain-
...
wfh/embedd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
472b434f02 |
@@ -46,6 +46,36 @@ class LLMManagerMixin:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
|
||||
class EmbeddingsManagerMixin:
    """Mixin declaring the callback hooks fired by an embeddings run.

    Both hooks are no-ops here; concrete handlers or run managers override
    them to react to the corresponding event.
    """

    def on_embeddings_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Hook invoked when an embedding model throws an error.

        Args:
            error (Exception or KeyboardInterrupt): The error that was thrown.
            run_id (UUID): Identifier of the run reporting the error.
            parent_run_id (Optional[UUID]): Identifier of the parent run,
                if there is one.
            **kwargs: Additional keyword arguments supplied with the event.
        Returns:
            Any: The result of the callback (``None`` for this default).
        """
        return None

    def on_embeddings_end(
        self,
        embeddings: List[List[float]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Hook invoked when an embeddings model finishes generating embeddings.

        Args:
            embeddings (List[List[float]]): The generated embeddings.
            run_id (UUID): Identifier of the run that finished.
            parent_run_id (Optional[UUID]): Identifier of the parent run,
                if there is one.
            **kwargs: Additional keyword arguments supplied with the event.
        Returns:
            Any: The result of the callback (``None`` for this default).
        """
        return None
|
||||
|
||||
|
||||
class ChainManagerMixin:
|
||||
"""Mixin for chain callbacks."""
|
||||
|
||||
@@ -144,6 +174,18 @@ class CallbackManagerMixin:
|
||||
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
|
||||
)
|
||||
|
||||
def on_embeddings_start(
    self,
    serialized: Dict[str, Any],
    texts: List[str],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
    **kwargs: Any,
) -> Any:
    """Hook invoked when Embeddings starts running.

    This default implementation does nothing.

    Args:
        serialized (Dict[str, Any]): The serialized embeddings model.
        texts (List[str]): The texts about to be embedded.
        run_id (UUID): Identifier of the run being started.
        parent_run_id (Optional[UUID]): Identifier of the parent run,
            if there is one.
        tags (Optional[List[str]]): Tags attached to the run, if any.
        **kwargs: Additional keyword arguments supplied with the event.
    Returns:
        Any: The result of the callback (``None`` for this default).
    """
    return None
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
@@ -185,6 +227,7 @@ class RunManagerMixin:
|
||||
|
||||
class BaseCallbackHandler(
|
||||
LLMManagerMixin,
|
||||
EmbeddingsManagerMixin,
|
||||
ChainManagerMixin,
|
||||
ToolManagerMixin,
|
||||
CallbackManagerMixin,
|
||||
@@ -216,6 +259,11 @@ class BaseCallbackHandler(
|
||||
"""Whether to ignore chat model callbacks."""
|
||||
return False
|
||||
|
||||
@property
def ignore_embeddings(self) -> bool:
    """Report whether embeddings callbacks should be ignored.

    Returns:
        bool: Always ``False`` in this default implementation.
    """
    # Default: embeddings callbacks are not ignored.
    return False
|
||||
|
||||
|
||||
class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
"""Async callback handler that can be used to handle callbacks from langchain."""
|
||||
@@ -241,7 +289,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
) -> None:
|
||||
"""Run when a chat model starts running."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
|
||||
@@ -277,6 +325,38 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
async def on_embeddings_start(
    self,
    serialized: Dict[str, Any],
    texts: List[str],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
    **kwargs: Any,
) -> None:
    """Run when embeddings call starts running.

    This default implementation does nothing.

    Args:
        serialized (Dict[str, Any]): The serialized embeddings model.
        texts (List[str]): The texts about to be embedded.
        run_id (UUID): Identifier of the run being started.
        parent_run_id (Optional[UUID]): Identifier of the parent run,
            if there is one.
        tags (Optional[List[str]]): Tags attached to the run, if any.
        **kwargs: Additional keyword arguments supplied with the event.
    """
    # NOTE: annotations use ``Optional[...]`` rather than ``X | None``: the
    # PEP 604 form is evaluated when the function is defined and raises
    # TypeError on Python < 3.10. ``Optional`` also matches the sync mixins.
|
||||
|
||||
async def on_embeddings_end(
    self,
    outputs: List[List[float]],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> None:
    """Run when embeddings call ends running.

    This default implementation does nothing.

    Args:
        outputs (List[List[float]]): The generated embeddings. The
            parameter keeps its original name, but is annotated like the
            ``embeddings`` argument of the synchronous hook because the
            run managers forward that value positionally.
        run_id (UUID): Identifier of the run that finished.
        parent_run_id (Optional[UUID]): Identifier of the parent run,
            if there is one.
        **kwargs: Additional keyword arguments supplied with the event.
    """
    # NOTE: ``Optional[...]`` is used instead of ``X | None`` so the
    # definition also works on Python < 3.10.
|
||||
|
||||
async def on_embeddings_error(
    self,
    error: Union[Exception, KeyboardInterrupt],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> None:
    """Run when embeddings call errors.

    This default implementation does nothing.

    Args:
        error (Exception or KeyboardInterrupt): The error that was raised.
        run_id (UUID): Identifier of the run reporting the error.
        parent_run_id (Optional[UUID]): Identifier of the parent run,
            if there is one.
        **kwargs: Additional keyword arguments supplied with the event.
    """
    # NOTE: ``Optional[UUID]`` is used instead of ``UUID | None`` so the
    # definition also works on Python < 3.10.
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
|
||||
@@ -26,6 +26,7 @@ from langchain.callbacks.base import (
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
ChainManagerMixin,
|
||||
EmbeddingsManagerMixin,
|
||||
LLMManagerMixin,
|
||||
RunManagerMixin,
|
||||
ToolManagerMixin,
|
||||
@@ -531,6 +532,52 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForEmbeddingsRun(RunManager, EmbeddingsManagerMixin):
    """Callback manager scoped to a single embeddings run.

    Each method dispatches one event to ``self.handlers`` via
    ``_handle_event``, tagging it with this manager's ``run_id`` and
    ``parent_run_id``.
    """

    def on_embeddings_end(
        self,
        embeddings: List[List[float]],
        **kwargs: Any,
    ) -> Any:
        """Dispatch ``on_embeddings_end`` to the handlers.

        Args:
            embeddings (List[List[float]]): The generated embeddings.
            **kwargs: Extra keyword arguments forwarded to the handlers.
        Returns:
            Any: The result of the callback.
        """
        run_info = {"run_id": self.run_id, "parent_run_id": self.parent_run_id}
        _handle_event(
            self.handlers,
            "on_embeddings_end",
            "ignore_embeddings",
            embeddings,
            **run_info,
            **kwargs,
        )

    def on_embeddings_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        """Dispatch ``on_embeddings_error`` to the handlers.

        Args:
            error (Exception or KeyboardInterrupt): The error.
            **kwargs: Extra keyword arguments forwarded to the handlers.
        """
        run_info = {"run_id": self.run_id, "parent_run_id": self.parent_run_id}
        _handle_event(
            self.handlers,
            "on_embeddings_error",
            "ignore_embeddings",
            error,
            **run_info,
            **kwargs,
        )
|
||||
|
||||
|
||||
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
"""Async callback manager for LLM run."""
|
||||
|
||||
@@ -781,6 +828,52 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForEmbeddingsRun(AsyncRunManager, EmbeddingsManagerMixin):
    """Async callback manager for embeddings run.

    Derives from ``AsyncRunManager`` like the other async run managers in
    this module. Each method awaits ``_ahandle_event`` so the event is
    actually dispatched to ``self.handlers``.
    """

    async def on_embeddings_end(
        self,
        embeddings: List[List[float]],
        **kwargs: Any,
    ) -> Any:
        """Run when embeddings are generated.

        Args:
            embeddings (List[List[float]]): The generated embeddings.
            **kwargs: Extra keyword arguments forwarded to the handlers.
        Returns:
            Any: The result of the callback.
        """
        # The coroutine must be awaited; calling it without ``await`` only
        # creates a coroutine object and no handler is ever invoked.
        await _ahandle_event(
            self.handlers,
            "on_embeddings_end",
            "ignore_embeddings",
            embeddings,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )

    async def on_embeddings_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        **kwargs: Any,
    ) -> None:
        """Run when embeddings errors.

        Args:
            error (Exception or KeyboardInterrupt): The error.
            **kwargs: Extra keyword arguments forwarded to the handlers.
        """
        await _ahandle_event(
            self.handlers,
            "on_embeddings_error",
            "ignore_embeddings",
            error,
            run_id=self.run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )
|
||||
|
||||
|
||||
class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
|
||||
"""Callback manager for tool run."""
|
||||
|
||||
@@ -993,6 +1086,47 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
return managers
|
||||
|
||||
def on_embeddings_start(
    self,
    serialized: Dict[str, Any],
    texts: List[str],
    **kwargs: Any,
) -> List[CallbackManagerForEmbeddingsRun]:
    """Announce the start of an embeddings call, one run per text.

    For every text a fresh run id is generated, ``on_embeddings_start`` is
    dispatched to the handlers with that single text, and a run manager
    bound to the new run id is created.

    Args:
        serialized (Dict[str, Any]): The serialized embeddings model.
        texts (List[str]): The list of texts.
        **kwargs: Extra keyword arguments forwarded to the handlers.
    Returns:
        List[CallbackManagerForEmbeddingsRun]: A callback manager for each
            text as an embeddings run.
    """
    run_managers: List[CallbackManagerForEmbeddingsRun] = []
    for single_text in texts:
        new_run_id = uuid4()
        _handle_event(
            self.handlers,
            "on_embeddings_start",
            "ignore_embeddings",
            serialized,
            [single_text],
            run_id=new_run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )
        run_manager = CallbackManagerForEmbeddingsRun(
            run_id=new_run_id,
            handlers=self.handlers,
            inheritable_handlers=self.inheritable_handlers,
            parent_run_id=self.parent_run_id,
            inheritable_tags=self.inheritable_tags,
        )
        run_managers.append(run_manager)
    return run_managers
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
@@ -1226,6 +1360,51 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
await asyncio.gather(*tasks)
|
||||
return managers
|
||||
|
||||
async def on_embeddings_start(
    self,
    serialized: Dict[str, Any],
    texts: List[str],
    **kwargs: Any,
) -> List[AsyncCallbackManagerForEmbeddingsRun]:
    """Announce the start of an embeddings call, one run per text.

    For every text a fresh run id is generated, an ``on_embeddings_start``
    dispatch is queued for that single text, and a run manager bound to the
    new run id is created. All queued dispatches are awaited together
    before the managers are returned.

    Args:
        serialized (Dict[str, Any]): The serialized embeddings model.
        texts (List[str]): The list of texts.
        **kwargs: Extra keyword arguments forwarded to the handlers.
    Returns:
        List[AsyncCallbackManagerForEmbeddingsRun]: An async callback
            manager for each text as an embeddings run.
    """
    pending = []
    run_managers: List[AsyncCallbackManagerForEmbeddingsRun] = []
    for single_text in texts:
        new_run_id = uuid4()
        dispatch = _ahandle_event(
            self.handlers,
            "on_embeddings_start",
            "ignore_embeddings",
            serialized,
            [single_text],
            run_id=new_run_id,
            parent_run_id=self.parent_run_id,
            **kwargs,
        )
        pending.append(dispatch)
        run_manager = AsyncCallbackManagerForEmbeddingsRun(
            run_id=new_run_id,
            handlers=self.handlers,
            inheritable_handlers=self.inheritable_handlers,
            parent_run_id=self.parent_run_id,
            inheritable_tags=self.inheritable_tags,
        )
        run_managers.append(run_manager)

    await asyncio.gather(*pending)
    return run_managers
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
|
||||
@@ -1,23 +1,217 @@
|
||||
"""Interface for embedding models."""
|
||||
import asyncio
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from inspect import signature
|
||||
from typing import Any, List, Sequence
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForEmbeddingsRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForEmbeddingsRun,
|
||||
Callbacks,
|
||||
)
|
||||
|
||||
|
||||
class Embeddings(ABC):
    """Interface for embedding models.

    Subclasses implement the private ``_embed_documents`` / ``_embed_query``
    hooks (and, optionally, ``_aembed_documents`` / ``_aembed_query``).
    The public ``embed_*`` / ``aembed_*`` methods wrap those hooks with
    callback handling: one embeddings run is started per input text, and
    each run is notified when the call ends or errors.

    Legacy subclasses that override the public methods directly are still
    accepted: ``__init_subclass__`` moves each override onto the matching
    private hook and emits a ``DeprecationWarning``.
    """

    # True when the subclass's ``_embed_documents`` declares ``run_managers``.
    _new_arg_supported: bool = False
    # True when a legacy ``_embed_documents`` takes parameters beyond
    # ``(self, texts)`` and should therefore receive the caller's kwargs.
    _expects_other_args: bool = False

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Adapt legacy subclasses and record which calling convention to use."""
        super().__init_subclass__(**kwargs)
        # All four public entry points are checked. Swapping only the
        # document methods would leave a legacy ``embed_query`` override
        # shadowing the wrapper while ``_embed_query`` stays abstract.
        for public_name in (
            "embed_documents",
            "embed_query",
            "aembed_documents",
            "aembed_query",
        ):
            override = getattr(cls, public_name)
            wrapper = getattr(Embeddings, public_name)
            if override is wrapper:
                continue
            warnings.warn(
                f"Embedding models must implement abstract `_{public_name}` method"
                f" instead of `{public_name}`",
                DeprecationWarning,
            )
            setattr(cls, public_name, wrapper)
            setattr(cls, f"_{public_name}", override)
        parameters = signature(cls._embed_documents).parameters
        # The document hook receives ``run_managers`` (plural); looking for
        # ``run_manager`` would never match a new-style implementation.
        cls._new_arg_supported = "run_managers" in parameters
        # ``parameters`` includes ``self``, so a plain legacy signature
        # ``(self, texts)`` has exactly two entries.
        cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2

    def _hook_kwargs(self, kwargs: dict, **manager_kwargs: Any) -> dict:
        """Choose the keyword arguments to forward to a private hook.

        Args:
            kwargs: Keyword arguments given to the public method.
            **manager_kwargs: The ``run_manager`` / ``run_managers`` argument
                for hooks that support it.
        Returns:
            dict: ``kwargs`` plus the manager argument for new-style hooks,
                ``kwargs`` alone for legacy hooks with extra parameters,
                and an empty dict for plain legacy hooks.
        """
        if self._new_arg_supported:
            return {**kwargs, **manager_kwargs}
        if self._expects_other_args:
            return dict(kwargs)
        return {}

    @abstractmethod
    def _embed_documents(
        self,
        texts: List[str],
        *,
        run_managers: Sequence[CallbackManagerForEmbeddingsRun],
        **kwargs: Any
    ) -> List[List[float]]:
        """Embed search docs."""

    @abstractmethod
    def _embed_query(
        self, text: str, *, run_manager: CallbackManagerForEmbeddingsRun, **kwargs: Any
    ) -> List[float]:
        """Embed query text."""

    # The async hooks are optional: they are not abstract, so sync-only
    # subclasses remain instantiable, and they raise if actually called.
    async def _aembed_documents(
        self,
        texts: List[str],
        *,
        run_managers: Sequence[AsyncCallbackManagerForEmbeddingsRun],
        **kwargs: Any
    ) -> List[List[float]]:
        """Embed search docs."""
        raise NotImplementedError

    async def _aembed_query(
        self,
        text: str,
        *,
        run_manager: AsyncCallbackManagerForEmbeddingsRun,
        **kwargs: Any
    ) -> List[float]:
        """Embed query text."""
        raise NotImplementedError

    def embed_documents(
        self, texts: List[str], *, callbacks: Callbacks = None, **kwargs: Any
    ) -> List[List[float]]:
        """Embed search docs.

        Args:
            texts (List[str]): The texts to embed.
            callbacks (Callbacks): Callbacks to notify about the run.
            **kwargs: Extra keyword arguments; ``verbose`` is read here and
                all of them are forwarded to the callbacks and, where
                supported, to ``_embed_documents``.
        Returns:
            List[List[float]]: The value returned by ``_embed_documents``.
        Raises:
            Exception: Any error raised by ``_embed_documents`` is reported
                to every run manager and then re-raised.
        """
        callback_manager = CallbackManager.configure(
            callbacks, None, verbose=kwargs.get("verbose", False)
        )
        # ``on_embeddings_start`` requires the serialized model as its first
        # argument; the class name is used as a minimal description.
        run_managers = callback_manager.on_embeddings_start(
            {"name": self.__class__.__name__},
            texts,
            **kwargs,
        )
        try:
            result = self._embed_documents(
                texts, **self._hook_kwargs(kwargs, run_managers=run_managers)
            )
        except Exception as e:
            for run_manager in run_managers:
                run_manager.on_embeddings_error(e)
            raise
        # Each run covers exactly one text, so it is told only about that
        # text's embedding, wrapped in a list to match the
        # ``List[List[float]]`` payload declared by ``on_embeddings_end``.
        for run_manager, embedding in zip(run_managers, result):
            run_manager.on_embeddings_end(
                [embedding],
                **kwargs,
            )
        return result

    def embed_query(
        self, text: str, *, callbacks: Callbacks = None, **kwargs: Any
    ) -> List[float]:
        """Embed query text.

        Args:
            text (str): The query text to embed.
            callbacks (Callbacks): Callbacks to notify about the run.
            **kwargs: Extra keyword arguments; ``verbose`` is read here and
                all of them are forwarded to the callbacks and, where
                supported, to ``_embed_query``.
        Returns:
            List[float]: The value returned by ``_embed_query``.
        Raises:
            Exception: Any error raised by ``_embed_query`` is reported to
                the run manager and then re-raised.
        """
        callback_manager = CallbackManager.configure(
            callbacks, None, verbose=kwargs.get("verbose", False)
        )
        run_manager = callback_manager.on_embeddings_start(
            {"name": self.__class__.__name__},
            [text],
            **kwargs,
        )[0]
        try:
            result = self._embed_query(
                text, **self._hook_kwargs(kwargs, run_manager=run_manager)
            )
        except Exception as e:
            run_manager.on_embeddings_error(e)
            raise
        run_manager.on_embeddings_end(
            [result],
            **kwargs,
        )
        return result

    async def aembed_documents(
        self, texts: List[str], *, callbacks: Callbacks = None, **kwargs: Any
    ) -> List[List[float]]:
        """Asynchronously embed search docs.

        Args:
            texts (List[str]): The texts to embed.
            callbacks (Callbacks): Callbacks to notify about the run.
            **kwargs: Extra keyword arguments; ``verbose`` is read here and
                all of them are forwarded to the callbacks and, where
                supported, to ``_aembed_documents``.
        Returns:
            List[List[float]]: The value returned by ``_aembed_documents``.
        Raises:
            Exception: Any error raised by ``_aembed_documents`` is reported
                to every run manager and then re-raised.
        """
        callback_manager = AsyncCallbackManager.configure(
            callbacks, None, verbose=kwargs.get("verbose", False)
        )
        run_managers = await callback_manager.on_embeddings_start(
            {"name": self.__class__.__name__},
            texts,
            **kwargs,
        )
        try:
            result = await self._aembed_documents(
                texts, **self._hook_kwargs(kwargs, run_managers=run_managers)
            )
        except Exception as e:
            tasks = [run_manager.on_embeddings_error(e) for run_manager in run_managers]
            await asyncio.gather(*tasks)
            raise
        tasks = [
            run_manager.on_embeddings_end(
                [embedding],
                **kwargs,
            )
            for run_manager, embedding in zip(run_managers, result)
        ]
        await asyncio.gather(*tasks)
        return result

    async def aembed_query(
        self, text: str, *, callbacks: Callbacks = None, **kwargs: Any
    ) -> List[float]:
        """Asynchronously embed query text.

        Args:
            text (str): The query text to embed.
            callbacks (Callbacks): Callbacks to notify about the run.
            **kwargs: Extra keyword arguments; ``verbose`` is read here and
                all of them are forwarded to the callbacks and, where
                supported, to ``_aembed_query``.
        Returns:
            List[float]: The value returned by ``_aembed_query``.
        Raises:
            Exception: Any error raised by ``_aembed_query`` is reported to
                the run manager and then re-raised.
        """
        callback_manager = AsyncCallbackManager.configure(
            callbacks, None, verbose=kwargs.get("verbose", False)
        )
        run_manager = (
            await callback_manager.on_embeddings_start(
                {"name": self.__class__.__name__},
                [text],
                **kwargs,
            )
        )[0]
        try:
            result = await self._aembed_query(
                text, **self._hook_kwargs(kwargs, run_manager=run_manager)
            )
        except Exception as e:
            await run_manager.on_embeddings_error(e)
            raise
        await run_manager.on_embeddings_end(
            [result],
            **kwargs,
        )
        return result
|
||||
|
||||
Reference in New Issue
Block a user