Compare commits

...

1 Commits

Author SHA1 Message Date
William Fu-Hinthorn
472b434f02 tmp 2023-06-30 14:38:54 -07:00
3 changed files with 461 additions and 8 deletions

View File

@@ -46,6 +46,36 @@ class LLMManagerMixin:
"""Run when LLM errors."""
class EmbeddingsManagerMixin:
"""Mixin for Embeddings callbacks."""
def on_embeddings_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when an embedding model throws an error."""
def on_embeddings_end(
self,
embeddings: List[List[float]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when embeddings model finishes generating embeddings.
Args:
embeddings (List[List[float]]): The generated embeddings.
Returns:
Any: The result of the callback.
"""
class ChainManagerMixin:
"""Mixin for chain callbacks."""
@@ -144,6 +174,18 @@ class CallbackManagerMixin:
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
)
def on_embeddings_start(
self,
serialized: Dict[str, Any],
texts: List[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Any:
"""Run when Embeddings starts running."""
def on_chain_start(
self,
serialized: Dict[str, Any],
@@ -185,6 +227,7 @@ class RunManagerMixin:
class BaseCallbackHandler(
LLMManagerMixin,
EmbeddingsManagerMixin,
ChainManagerMixin,
ToolManagerMixin,
CallbackManagerMixin,
@@ -216,6 +259,11 @@ class BaseCallbackHandler(
"""Whether to ignore chat model callbacks."""
return False
@property
def ignore_embeddings(self) -> bool:
"""Whether to ignore embeddings callbacks."""
return False
class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that can be used to handle callbacks from langchain."""
@@ -241,7 +289,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Any:
) -> None:
"""Run when a chat model starts running."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
@@ -277,6 +325,38 @@ class AsyncCallbackHandler(BaseCallbackHandler):
) -> None:
"""Run when LLM errors."""
async def on_embeddings_start(
self,
serialized: Dict[str, Any],
texts: List[str],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: List[str] | None = None,
**kwargs: Any,
) -> None:
"""Run when embeddings call starts running."""
async def on_embeddings_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any,
) -> None:
"""Run when embeddings call ends running."""
async def on_embeddings_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any,
) -> None:
"""Run when embeddings call errors."""
async def on_chain_start(
self,
serialized: Dict[str, Any],

View File

@@ -26,6 +26,7 @@ from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
ChainManagerMixin,
EmbeddingsManagerMixin,
LLMManagerMixin,
RunManagerMixin,
ToolManagerMixin,
@@ -531,6 +532,52 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
)
class CallbackManagerForEmbeddingsRun(RunManager, EmbeddingsManagerMixin):
"""Callback manager for embeddings run."""
def on_embeddings_end(
self,
embeddings: List[List[float]],
**kwargs: Any,
) -> Any:
"""Run when embeddings are generated.
Args:
embeddings (List[List[float]]): The generated embeddings.
Returns:
Any: The result of the callback.
"""
_handle_event(
self.handlers,
"on_embeddings_end",
"ignore_embeddings",
embeddings,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_embeddings_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when embeddings errors.
Args:
error (Exception or KeyboardInterrupt): The error.
"""
_handle_event(
self.handlers,
"on_embeddings_error",
"ignore_embeddings",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
"""Async callback manager for LLM run."""
@@ -781,6 +828,52 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
)
class AsyncCallbackManagerForEmbeddingsRun(RunManager, EmbeddingsManagerMixin):
"""Callback manager for embeddings run."""
async def on_embeddings_end(
self,
embeddings: List[List[float]],
**kwargs: Any,
) -> Any:
"""Run when embeddings are generated.
Args:
embeddings (List[List[float]]): The generated embeddings.
Returns:
Any: The result of the callback.
"""
_ahandle_event(
self.handlers,
"on_embeddings_end",
"ignore_embeddings",
embeddings,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_embeddings_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when embeddings errors.
Args:
error (Exception or KeyboardInterrupt): The error.
"""
_ahandle_event(
self.handlers,
"on_embeddings_error",
"ignore_embeddings",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
"""Callback manager for tool run."""
@@ -993,6 +1086,47 @@ class CallbackManager(BaseCallbackManager):
return managers
def on_embeddings_start(
self,
serialized: Dict[str, Any],
texts: List[str],
**kwargs: Any,
) -> List[CallbackManagerForEmbeddingsRun]:
"""Run when embeddings model starts running.
Args:
serialized (Dict[str, Any]): The serialized embeddings model.
texts (List[str]): The list of texts.
Returns:
List[CallbackManagerForEmbeddingsRun]: A callback manager for each
text as an embeddings run.
"""
managers = []
for text in texts:
run_id_ = uuid4()
_handle_event(
self.handlers,
"on_embeddings_start",
"ignore_embeddings",
serialized,
[text],
run_id=run_id_,
parent_run_id=self.parent_run_id,
**kwargs,
)
managers.append(
CallbackManagerForEmbeddingsRun(
run_id=run_id_,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
inheritable_tags=self.inheritable_tags,
)
)
return managers
def on_chain_start(
self,
serialized: Dict[str, Any],
@@ -1226,6 +1360,51 @@ class AsyncCallbackManager(BaseCallbackManager):
await asyncio.gather(*tasks)
return managers
async def on_embeddings_start(
self,
serialized: Dict[str, Any],
texts: List[str],
**kwargs: Any,
) -> List[AsyncCallbackManagerForEmbeddingsRun]:
"""Run when embeddings model starts running.
Args:
serialized (Dict[str, Any]): The serialized embeddings model.
texts (List[str]): The list of texts.
Returns:
List[CallbackManagerForEmbeddingsRun]: A callback manager for each
text as an embeddings run.
"""
tasks = []
managers = []
for text in texts:
run_id_ = uuid4()
tasks.append(
_ahandle_event(
self.handlers,
"on_embeddings_start",
"ignore_embeddings",
serialized,
[text],
run_id=run_id_,
parent_run_id=self.parent_run_id,
**kwargs,
)
)
managers.append(
AsyncCallbackManagerForEmbeddingsRun(
run_id=run_id_,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
inheritable_tags=self.inheritable_tags,
)
)
await asyncio.gather(*tasks)
return managers
async def on_chain_start(
self,
serialized: Dict[str, Any],

View File

@@ -1,23 +1,217 @@
"""Interface for embedding models."""
import asyncio
import warnings
from abc import ABC, abstractmethod
from typing import List
from inspect import signature
from typing import Any, List, Sequence
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForEmbeddingsRun,
CallbackManager,
CallbackManagerForEmbeddingsRun,
Callbacks,
)
class Embeddings(ABC):
"""Interface for embedding models."""
_new_arg_supported: bool = False
_expects_other_args: bool = False
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if cls.embed_documents != Embeddings.embed_documents:
warnings.warn(
"Embedding models must implement abstract `_embed_documents` method"
" instead of `embed_documents`",
DeprecationWarning,
)
swap = cls.embed_documents
cls.embed_documents = Embeddings.embed_documents # type: ignore[assignment]
cls._embed_documents = swap # type: ignore[assignment]
if (
hasattr(cls, "aembed_documents")
and cls.aembed_documents != Embeddings.aembed_documents
):
warnings.warn(
"Embedding models must implement abstract `_aembed_documents` method"
" instead of `aembed_documents`",
DeprecationWarning,
)
aswap = cls.aembed_documents
cls.aembed_documents = ( # type: ignore[assignment]
Embeddings.aembed_documents
)
cls._aembed_documents = aswap # type: ignore[assignment]
parameters = signature(cls._embed_documents).parameters
cls._new_arg_supported = parameters.get("run_manager") is not None
cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 1
@abstractmethod
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def _embed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[CallbackManagerForEmbeddingsRun],
**kwargs: Any
) -> List[List[float]]:
"""Embed search docs."""
@abstractmethod
def embed_query(self, text: str) -> List[float]:
def _embed_query(
self, text: str, *, run_manager: CallbackManagerForEmbeddingsRun, **kwargs: Any
) -> List[float]:
"""Embed query text."""
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
@abstractmethod
async def _aembed_documents(
self,
texts: List[str],
*,
run_managers: Sequence[AsyncCallbackManagerForEmbeddingsRun],
**kwargs: Any
) -> List[List[float]]:
"""Embed search docs."""
raise NotImplementedError
async def aembed_query(self, text: str) -> List[float]:
@abstractmethod
async def _aembed_query(
self,
text: str,
*,
run_manager: AsyncCallbackManagerForEmbeddingsRun,
**kwargs: Any
) -> List[float]:
"""Embed query text."""
raise NotImplementedError
def embed_documents(
self, texts: List[str], *, callbacks: Callbacks = None, **kwargs: Any
) -> List[List[float]]:
"""Embed search docs."""
callback_manager = CallbackManager.configure(
callbacks, None, verbose=kwargs.get("verbose", False)
)
run_managers = callback_manager.on_embeddings_start(
texts,
**kwargs,
)
try:
if self._new_arg_supported:
result = self._embed_documents(
texts, run_managers=run_managers, **kwargs
)
elif self._expects_other_args:
result = self._embed_documents(texts, **kwargs)
else:
result = self._embed_documents(texts) # type: ignore[call-arg]
except Exception as e:
for run_manager in run_managers:
run_manager.on_embeddings_error(e)
raise e
else:
for run_manager in run_managers:
run_manager.on_embeddings_end(
result,
**kwargs,
)
return result
def embed_query(
self, text: str, *, callbacks: Callbacks = None, **kwargs: Any
) -> List[float]:
"""Embed query text."""
from langchain.callbacks.manager import CallbackManager
callback_manager = CallbackManager.configure(
callbacks, None, verbose=kwargs.get("verbose", False)
)
run_managers = callback_manager.on_embeddings_start(
[text],
**kwargs,
)
try:
if self._new_arg_supported:
result = self._embed_query(text, run_manager=run_managers[0], **kwargs)
elif self._expects_other_args:
result = self._embed_query(text, **kwargs)
else:
result = self._embed_query(text) # type: ignore[call-arg]
except Exception as e:
run_managers[0].on_embeddings_error(e)
raise e
else:
run_managers[0].on_embeddings_end(
result,
**kwargs,
)
return result
async def aembed_documents(
self, texts: List[str], *, callbacks: Callbacks = None, **kwargs: Any
) -> List[List[float]]:
"""Asynchronously embed search docs."""
callback_manager = AsyncCallbackManager.configure(
callbacks, None, verbose=kwargs.get("verbose", False)
)
run_managers = await callback_manager.on_embeddings_start(
texts,
**kwargs,
)
try:
if self._new_arg_supported:
result = await self._aembed_documents(
texts, run_managers=run_managers, **kwargs
)
elif self._expects_other_args:
result = await self._aembed_documents(texts, **kwargs)
else:
result = await self._aembed_documents(texts) # type: ignore[call-arg]
except Exception as e:
tasks = [run_manager.on_embeddings_error(e) for run_manager in run_managers]
await asyncio.gather(*tasks)
raise e
else:
tasks = [
run_manager.on_embeddings_end(
results,
**kwargs,
)
for run_manager, results in zip(run_managers, result)
]
await asyncio.gather(*tasks)
return result
async def aembed_query(
self, text: str, *, callbacks: Callbacks = None, **kwargs: Any
) -> List[float]:
"""Asynchronously embed query text."""
from langchain.callbacks.manager import AsyncCallbackManager
callback_manager = AsyncCallbackManager.configure(
callbacks, None, verbose=kwargs.get("verbose", False)
)
run_managers = await callback_manager.on_embeddings_start(
[text],
**kwargs,
)
try:
if self._new_arg_supported:
result = await self._aembed_query(
text, run_manager=run_managers[0], **kwargs
)
elif self._expects_other_args:
result = await self._aembed_query(text, **kwargs)
else:
result = await self._aembed_query(text) # type: ignore[call-arg]
except Exception as e:
await run_managers[0].on_embeddings_error(e)
raise e
else:
await run_managers[0].on_embeddings_end(
result,
**kwargs,
)
return result