mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 20:58:25 +00:00
Propagate context vars in all classes/methods (#15329)
- Any direct usage of ThreadPoolExecutor or asyncio.run_in_executor needs manual handling of context vars <!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
commit
99000c612e
@ -1,12 +1,9 @@
|
||||
"""ChatModel wrapper which returns user input as the response.."""
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from io import StringIO
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional
|
||||
|
||||
import yaml
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
@ -111,15 +108,3 @@ class HumanInputChatModel(BaseChatModel):
|
||||
self.message_func(messages, **self.message_kwargs)
|
||||
user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
|
||||
return ChatResult(generations=[ChatGeneration(message=user_input)])
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
func = partial(
|
||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
@ -1,11 +1,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
@ -125,18 +122,6 @@ class ChatMlflow(BaseChatModel):
|
||||
resp = self._client.predict(endpoint=self.endpoint, inputs=data)
|
||||
return ChatMlflow._create_chat_result(resp)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
func = partial(
|
||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
return self._default_params
|
||||
|
@ -1,11 +1,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
@ -116,18 +113,6 @@ class ChatMLflowAIGateway(BaseChatModel):
|
||||
resp = mlflow.gateway.query(self.route, data=data)
|
||||
return ChatMLflowAIGateway._create_chat_result(resp)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
func = partial(
|
||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
return self._default_params
|
||||
|
@ -1,7 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, cast
|
||||
|
||||
import requests
|
||||
@ -300,25 +298,3 @@ class PaiEasChatEndpoint(BaseChatModel):
|
||||
# break if stop sequence found
|
||||
if stop_seq_found:
|
||||
break
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if stream if stream is not None else self.streaming:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
async for chunk in self._astream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
generation = chunk
|
||||
assert generation is not None
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
func = partial(
|
||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
@ -1,11 +1,11 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
|
||||
class BedrockEmbeddings(BaseModel, Embeddings):
|
||||
@ -181,9 +181,7 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
||||
Embeddings for the text.
|
||||
"""
|
||||
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.embed_query, text)
|
||||
)
|
||||
return await run_in_executor(None, self.embed_query, text)
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Asynchronous compute doc embeddings using a Bedrock model.
|
||||
|
@ -1,12 +1,12 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -134,9 +134,7 @@ class ErnieEmbeddings(BaseModel, Embeddings):
|
||||
List[float]: Embeddings for the text.
|
||||
"""
|
||||
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.embed_query, text)
|
||||
)
|
||||
return await run_in_executor(None, self.embed_query, text)
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Asynchronous Embed search docs.
|
||||
|
@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Optional, Type
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
@ -57,11 +55,3 @@ Note: SessionId must be received from previous Browser window creation."""
|
||||
print(f"{e}, retrying...")
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred: {e}")
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
sessionId: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> None:
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, self._run, sessionId)
|
||||
|
@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Optional, Type
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
@ -67,14 +65,3 @@ class MultionCreateSession(BaseTool):
|
||||
}
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred: {e}")
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
url: Optional[str] = "https://www.google.com/",
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> dict:
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(None, self._run, query, url)
|
||||
|
||||
return result
|
||||
|
@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Optional, Type
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
@ -74,15 +72,3 @@ Note: sessionId must be received from previous Browser window creation."""
|
||||
return {"error": f"{e}", "Response": "retrying..."}
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred: {e}")
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
sessionId: str,
|
||||
query: str,
|
||||
url: Optional[str] = "https://www.google.com/",
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> dict:
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(None, self._run, sessionId, query, url)
|
||||
|
||||
return result
|
||||
|
@ -1,10 +1,8 @@
|
||||
import asyncio
|
||||
import platform
|
||||
import warnings
|
||||
from typing import Any, List, Optional, Type, Union
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
@ -77,13 +75,3 @@ class ShellTool(BaseTool):
|
||||
) -> str:
|
||||
"""Run commands and return final output."""
|
||||
return self.process.run(commands)
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
commands: Union[str, List[str]],
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Run commands asynchronously and return final output."""
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.process.run, commands
|
||||
)
|
||||
|
@ -1,13 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import operator
|
||||
import os
|
||||
import pickle
|
||||
import uuid
|
||||
import warnings
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
@ -24,6 +22,7 @@ from typing import (
|
||||
import numpy as np
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langchain_community.docstore.base import AddableMixin, Docstore
|
||||
@ -359,7 +358,8 @@ class FAISS(VectorStore):
|
||||
"""
|
||||
|
||||
# This is a temporary workaround to make the similarity search asynchronous.
|
||||
func = partial(
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self.similarity_search_with_score_by_vector,
|
||||
embedding,
|
||||
k=k,
|
||||
@ -367,7 +367,6 @@ class FAISS(VectorStore):
|
||||
fetch_k=fetch_k,
|
||||
**kwargs,
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
@ -640,7 +639,8 @@ class FAISS(VectorStore):
|
||||
relevance and score for each.
|
||||
"""
|
||||
# This is a temporary workaround to make the similarity search asynchronous.
|
||||
func = partial(
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self.max_marginal_relevance_search_with_score_by_vector,
|
||||
embedding,
|
||||
k=k,
|
||||
@ -648,7 +648,6 @@ class FAISS(VectorStore):
|
||||
lambda_mult=lambda_mult,
|
||||
filter=filter,
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
|
@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import enum
|
||||
import logging
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@ -31,6 +29,7 @@ except ImportError:
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
@ -941,7 +940,8 @@ class PGVector(VectorStore):
|
||||
# This is a temporary workaround to make the similarity search
|
||||
# asynchronous. The proper solution is to make the similarity search
|
||||
# asynchronous in the vector store implementations.
|
||||
func = partial(
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self.max_marginal_relevance_search_by_vector,
|
||||
embedding,
|
||||
k=k,
|
||||
@ -950,4 +950,3 @@ class PGVector(VectorStore):
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import uuid
|
||||
import warnings
|
||||
@ -25,6 +24,7 @@ from typing import (
|
||||
import numpy as np
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||
@ -58,10 +58,9 @@ def sync_call_fallback(method: Callable) -> Callable:
|
||||
# by removing the first letter from the method name. For example,
|
||||
# if the async method is called ``aaad_texts``, the synchronous method
|
||||
# will be called ``aad_texts``.
|
||||
sync_method = functools.partial(
|
||||
getattr(self, method.__name__[1:]), *args, **kwargs
|
||||
return await run_in_executor(
|
||||
None, getattr(self, method.__name__[1:]), *args, **kwargs
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, sync_method)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
@ -23,7 +23,7 @@ from langchain_core.runnables.base import (
|
||||
RunnableSerializable,
|
||||
coerce_to_runnable,
|
||||
)
|
||||
from langchain_core.runnables.config import RunnableConfig, patch_config
|
||||
from langchain_core.runnables.config import RunnableConfig, ensure_config, patch_config
|
||||
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
|
||||
|
||||
T = TypeVar("T")
|
||||
@ -186,7 +186,7 @@ class ContextGet(RunnableSerializable):
|
||||
]
|
||||
|
||||
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
configurable = config.get("configurable", {})
|
||||
if isinstance(self.key, list):
|
||||
return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
|
||||
@ -196,7 +196,7 @@ class ContextGet(RunnableSerializable):
|
||||
async def ainvoke(
|
||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
configurable = config.get("configurable", {})
|
||||
if isinstance(self.key, list):
|
||||
values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
|
||||
@ -281,7 +281,7 @@ class ContextSet(RunnableSerializable):
|
||||
]
|
||||
|
||||
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
configurable = config.get("configurable", {})
|
||||
for id_, mapper in zip(self.ids, self.keys.values()):
|
||||
if mapper is not None:
|
||||
@ -293,7 +293,7 @@ class ContextSet(RunnableSerializable):
|
||||
async def ainvoke(
|
||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
configurable = config.get("configurable", {})
|
||||
for id_, mapper in zip(self.ids, self.keys.values()):
|
||||
if mapper is not None:
|
||||
|
@ -4,13 +4,15 @@ import asyncio
|
||||
import functools
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextvars import Context, copy_context
|
||||
from contextvars import copy_context
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generator,
|
||||
@ -272,25 +274,14 @@ def handle_event(
|
||||
# we end up in a deadlock, as we'd have gotten here from a
|
||||
# running coroutine, which we cannot interrupt to run this one.
|
||||
# The solution is to create a new loop in a new thread.
|
||||
with _executor_w_context(1) as executor:
|
||||
executor.submit(_run_coros, coros).result()
|
||||
with ThreadPoolExecutor(1) as executor:
|
||||
executor.submit(
|
||||
cast(Callable, copy_context().run), _run_coros, coros
|
||||
).result()
|
||||
else:
|
||||
_run_coros(coros)
|
||||
|
||||
|
||||
def _set_context(context: Context) -> None:
|
||||
for var, value in context.items():
|
||||
var.set(value)
|
||||
|
||||
|
||||
def _executor_w_context(max_workers: Optional[int] = None) -> ThreadPoolExecutor:
|
||||
return ThreadPoolExecutor(
|
||||
max_workers=max_workers,
|
||||
initializer=_set_context,
|
||||
initargs=(copy_context(),),
|
||||
)
|
||||
|
||||
|
||||
def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
|
||||
if hasattr(asyncio, "Runner"):
|
||||
# Python 3.11+
|
||||
@ -315,7 +306,6 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
|
||||
|
||||
|
||||
async def _ahandle_event_for_handler(
|
||||
executor: ThreadPoolExecutor,
|
||||
handler: BaseCallbackHandler,
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
@ -332,13 +322,18 @@ async def _ahandle_event_for_handler(
|
||||
event(*args, **kwargs)
|
||||
else:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
executor, functools.partial(event, *args, **kwargs)
|
||||
None,
|
||||
cast(
|
||||
Callable,
|
||||
functools.partial(
|
||||
copy_context().run, event, *args, **kwargs
|
||||
),
|
||||
),
|
||||
)
|
||||
except NotImplementedError as e:
|
||||
if event_name == "on_chat_model_start":
|
||||
message_strings = [get_buffer_string(m) for m in args[1]]
|
||||
await _ahandle_event_for_handler(
|
||||
executor,
|
||||
handler,
|
||||
"on_llm_start",
|
||||
"ignore_llm",
|
||||
@ -380,25 +375,23 @@ async def ahandle_event(
|
||||
*args: The arguments to pass to the event handler
|
||||
**kwargs: The keyword arguments to pass to the event handler
|
||||
"""
|
||||
with _executor_w_context() as executor:
|
||||
for handler in [h for h in handlers if h.run_inline]:
|
||||
await _ahandle_event_for_handler(
|
||||
executor, handler, event_name, ignore_condition_name, *args, **kwargs
|
||||
)
|
||||
await asyncio.gather(
|
||||
*(
|
||||
_ahandle_event_for_handler(
|
||||
executor,
|
||||
handler,
|
||||
event_name,
|
||||
ignore_condition_name,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
for handler in handlers
|
||||
if not handler.run_inline
|
||||
)
|
||||
for handler in [h for h in handlers if h.run_inline]:
|
||||
await _ahandle_event_for_handler(
|
||||
handler, event_name, ignore_condition_name, *args, **kwargs
|
||||
)
|
||||
await asyncio.gather(
|
||||
*(
|
||||
_ahandle_event_for_handler(
|
||||
handler,
|
||||
event_name,
|
||||
ignore_condition_name,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
for handler in handlers
|
||||
if not handler.run_inline
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
BRM = TypeVar("BRM", bound="BaseRunManager")
|
||||
@ -526,9 +519,17 @@ class ParentRunManager(RunManager):
|
||||
return manager
|
||||
|
||||
|
||||
class AsyncRunManager(BaseRunManager):
|
||||
class AsyncRunManager(BaseRunManager, ABC):
|
||||
"""Async Run Manager."""
|
||||
|
||||
@abstractmethod
|
||||
def get_sync(self) -> RunManager:
|
||||
"""Get the equivalent sync RunManager.
|
||||
|
||||
Returns:
|
||||
RunManager: The sync RunManager.
|
||||
"""
|
||||
|
||||
async def on_text(
|
||||
self,
|
||||
text: str,
|
||||
@ -664,6 +665,23 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
"""Async callback manager for LLM run."""
|
||||
|
||||
def get_sync(self) -> CallbackManagerForLLMRun:
|
||||
"""Get the equivalent sync RunManager.
|
||||
|
||||
Returns:
|
||||
CallbackManagerForLLMRun: The sync RunManager.
|
||||
"""
|
||||
return CallbackManagerForLLMRun(
|
||||
run_id=self.run_id,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
@ -818,6 +836,23 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
||||
class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
"""Async callback manager for chain run."""
|
||||
|
||||
def get_sync(self) -> CallbackManagerForChainRun:
|
||||
"""Get the equivalent sync RunManager.
|
||||
|
||||
Returns:
|
||||
CallbackManagerForChainRun: The sync RunManager.
|
||||
"""
|
||||
return CallbackManagerForChainRun(
|
||||
run_id=self.run_id,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
async def on_chain_end(
|
||||
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
||||
) -> None:
|
||||
@ -948,6 +983,23 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
|
||||
class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
||||
"""Async callback manager for tool run."""
|
||||
|
||||
def get_sync(self) -> CallbackManagerForToolRun:
|
||||
"""Get the equivalent sync RunManager.
|
||||
|
||||
Returns:
|
||||
CallbackManagerForToolRun: The sync RunManager.
|
||||
"""
|
||||
return CallbackManagerForToolRun(
|
||||
run_id=self.run_id,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running.
|
||||
|
||||
@ -1031,6 +1083,23 @@ class AsyncCallbackManagerForRetrieverRun(
|
||||
):
|
||||
"""Async callback manager for retriever run."""
|
||||
|
||||
def get_sync(self) -> CallbackManagerForRetrieverRun:
|
||||
"""Get the equivalent sync RunManager.
|
||||
|
||||
Returns:
|
||||
CallbackManagerForRetrieverRun: The sync RunManager.
|
||||
"""
|
||||
return CallbackManagerForRetrieverRun(
|
||||
run_id=self.run_id,
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
async def on_retriever_end(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> None:
|
||||
|
@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Sequence
|
||||
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.documents import Document
|
||||
|
||||
@ -69,6 +69,6 @@ class BaseDocumentTransformer(ABC):
|
||||
Returns:
|
||||
A list of transformed Documents.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.transform_documents, **kwargs), documents
|
||||
return await run_in_executor(
|
||||
None, self.transform_documents, documents, **kwargs
|
||||
)
|
||||
|
@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
|
||||
class Embeddings(ABC):
|
||||
"""Interface for embedding models."""
|
||||
@ -16,12 +17,8 @@ class Embeddings(ABC):
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Asynchronous Embed search docs."""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.embed_documents, texts
|
||||
)
|
||||
return await run_in_executor(None, self.embed_documents, texts)
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""Asynchronous Embed query text."""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.embed_query, text
|
||||
)
|
||||
return await run_in_executor(None, self.embed_query, text)
|
||||
|
@ -4,7 +4,6 @@ import asyncio
|
||||
import inspect
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -45,6 +44,7 @@ from langchain_core.outputs import (
|
||||
)
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.runnables.config import ensure_config, run_in_executor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
@ -158,7 +158,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
return cast(
|
||||
ChatGeneration,
|
||||
self.generate_prompt(
|
||||
@ -180,7 +180,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
llm_result = await self.agenerate_prompt(
|
||||
[self._convert_input(input)],
|
||||
stop=stop,
|
||||
@ -206,7 +206,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
|
||||
)
|
||||
else:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
messages = self._convert_input(input).to_messages()
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
options = {"stop": stop, **kwargs}
|
||||
@ -264,7 +264,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
await self.ainvoke(input, config=config, stop=stop, **kwargs),
|
||||
)
|
||||
else:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
messages = self._convert_input(input).to_messages()
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
options = {"stop": stop, **kwargs}
|
||||
@ -605,8 +605,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self._generate, **kwargs), messages, stop, run_manager
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self._generate,
|
||||
messages,
|
||||
stop,
|
||||
run_manager.get_sync() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _stream(
|
||||
@ -766,7 +771,11 @@ class SimpleChatModel(BaseChatModel):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
func = partial(
|
||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self._generate,
|
||||
messages,
|
||||
stop=stop,
|
||||
run_manager=run_manager.get_sync() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
@ -8,7 +8,6 @@ import json
|
||||
import logging
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
@ -52,7 +51,8 @@ from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.pydantic_v1 import Field, root_validator, validator
|
||||
from langchain_core.runnables import RunnableConfig, get_config_list
|
||||
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -221,7 +221,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
return (
|
||||
self.generate_prompt(
|
||||
[self._convert_input(input)],
|
||||
@ -244,7 +244,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
llm_result = await self.agenerate_prompt(
|
||||
[self._convert_input(input)],
|
||||
stop=stop,
|
||||
@ -362,7 +362,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
yield self.invoke(input, config=config, stop=stop, **kwargs)
|
||||
else:
|
||||
prompt = self._convert_input(input).to_string()
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
params = {**params, **kwargs}
|
||||
@ -419,7 +419,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
|
||||
else:
|
||||
prompt = self._convert_input(input).to_string()
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
params = {**params, **kwargs}
|
||||
@ -483,8 +483,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompts."""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self._generate, **kwargs), prompts, stop, run_manager
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self._generate,
|
||||
prompts,
|
||||
stop,
|
||||
run_manager.get_sync() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _stream(
|
||||
@ -1049,8 +1054,13 @@ class LLM(BaseLLM):
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self._call, **kwargs), prompt, stop, run_manager
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self._call,
|
||||
prompt,
|
||||
stop,
|
||||
run_manager.get_sync() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
|
@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -20,6 +18,7 @@ from typing_extensions import get_args
|
||||
from langchain_core.messages import AnyMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
@ -54,9 +53,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.parse_result, result
|
||||
)
|
||||
return await run_in_executor(None, self.parse_result, result)
|
||||
|
||||
|
||||
class BaseGenerationOutputParser(
|
||||
@ -247,9 +244,7 @@ class BaseOutputParser(
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, functools.partial(self.parse_result, partial=partial), result
|
||||
)
|
||||
return await run_in_executor(None, self.parse_result, result, partial=partial)
|
||||
|
||||
async def aparse(self, text: str) -> T:
|
||||
"""Parse a single string model output into some structure.
|
||||
@ -260,7 +255,7 @@ class BaseOutputParser(
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(None, self.parse, text)
|
||||
return await run_in_executor(None, self.parse, text)
|
||||
|
||||
# TODO: rename 'completion' -> 'text'.
|
||||
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
|
||||
|
@ -1,15 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableSerializable,
|
||||
ensure_config,
|
||||
)
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.callbacks.manager import (
|
||||
@ -113,7 +117,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
def invoke(
|
||||
self, input: str, config: Optional[RunnableConfig] = None
|
||||
) -> List[Document]:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
return self.get_relevant_documents(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
@ -128,7 +132,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Document]:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
return await self.aget_relevant_documents(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
@ -159,8 +163,11 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self._get_relevant_documents, run_manager=run_manager), query
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self._get_relevant_documents,
|
||||
query,
|
||||
run_manager=run_manager.get_sync(),
|
||||
)
|
||||
|
||||
def get_relevant_documents(
|
||||
|
@ -27,8 +27,10 @@ from langchain_core.runnables.base import (
|
||||
from langchain_core.runnables.branch import RunnableBranch
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
ensure_config,
|
||||
get_config_list,
|
||||
patch_config,
|
||||
run_in_executor,
|
||||
)
|
||||
from langchain_core.runnables.fallbacks import RunnableWithFallbacks
|
||||
from langchain_core.runnables.passthrough import (
|
||||
@ -42,6 +44,7 @@ from langchain_core.runnables.utils import (
|
||||
ConfigurableField,
|
||||
ConfigurableFieldMultiOption,
|
||||
ConfigurableFieldSingleOption,
|
||||
ConfigurableFieldSpec,
|
||||
aadd,
|
||||
add,
|
||||
)
|
||||
@ -51,6 +54,9 @@ __all__ = [
|
||||
"ConfigurableField",
|
||||
"ConfigurableFieldSingleOption",
|
||||
"ConfigurableFieldMultiOption",
|
||||
"ConfigurableFieldSpec",
|
||||
"ensure_config",
|
||||
"run_in_executor",
|
||||
"patch_config",
|
||||
"RouterInput",
|
||||
"RouterRunnable",
|
||||
|
@ -6,7 +6,7 @@ import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import FIRST_COMPLETED, wait
|
||||
from copy import deepcopy
|
||||
from functools import partial, wraps
|
||||
from functools import wraps
|
||||
from itertools import groupby, tee
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
@ -47,6 +47,7 @@ from langchain_core.runnables.config import (
|
||||
get_executor_for_config,
|
||||
merge_configs,
|
||||
patch_config,
|
||||
run_in_executor,
|
||||
)
|
||||
from langchain_core.runnables.graph import Graph
|
||||
from langchain_core.runnables.utils import (
|
||||
@ -472,10 +473,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
Subclasses should override this method if they can run asynchronously.
|
||||
"""
|
||||
with get_executor_for_config(config) as executor:
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
executor, partial(self.invoke, **kwargs), input, config
|
||||
)
|
||||
return await run_in_executor(config, self.invoke, input, config, **kwargs)
|
||||
|
||||
def batch(
|
||||
self,
|
||||
@ -665,7 +663,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
)
|
||||
|
||||
# Assign the stream handler to the config
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks is None:
|
||||
config["callbacks"] = [stream]
|
||||
@ -2883,10 +2881,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
@wraps(self.func)
|
||||
async def f(*args, **kwargs): # type: ignore[no-untyped-def]
|
||||
with get_executor_for_config(config) as executor:
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
executor, partial(self.func, **kwargs), *args
|
||||
)
|
||||
return await run_in_executor(config, self.func, *args, **kwargs)
|
||||
|
||||
afunc = f
|
||||
|
||||
@ -2913,7 +2908,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
def _config(
|
||||
self, config: Optional[RunnableConfig], callable: Callable[..., Any]
|
||||
) -> RunnableConfig:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
|
||||
if config.get("run_name") is None:
|
||||
try:
|
||||
@ -3052,9 +3047,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
@wraps(self.func)
|
||||
async def f(*args, **kwargs): # type: ignore[no-untyped-def]
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.func, **kwargs), *args
|
||||
)
|
||||
return await run_in_executor(config, self.func, *args, **kwargs)
|
||||
|
||||
afunc = f
|
||||
|
||||
|
@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
import asyncio
|
||||
from concurrent.futures import Executor, Future, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from contextvars import Context, copy_context
|
||||
from contextvars import ContextVar, copy_context
|
||||
from functools import partial
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -10,13 +12,16 @@ from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import ParamSpec, TypedDict
|
||||
|
||||
from langchain_core.runnables.utils import (
|
||||
Input,
|
||||
@ -91,6 +96,11 @@ class RunnableConfig(TypedDict, total=False):
|
||||
"""
|
||||
|
||||
|
||||
var_child_runnable_config = ContextVar(
|
||||
"child_runnable_config", default=RunnableConfig()
|
||||
)
|
||||
|
||||
|
||||
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||
"""Ensure that a config is a dict with all keys present.
|
||||
|
||||
@ -107,6 +117,10 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||
callbacks=None,
|
||||
recursion_limit=25,
|
||||
)
|
||||
if var_config := var_child_runnable_config.get():
|
||||
empty.update(
|
||||
cast(RunnableConfig, {k: v for k, v in var_config.items() if v is not None})
|
||||
)
|
||||
if config is not None:
|
||||
empty.update(
|
||||
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
|
||||
@ -388,9 +402,51 @@ def get_async_callback_manager_for_config(
|
||||
)
|
||||
|
||||
|
||||
def _set_context(context: Context) -> None:
|
||||
for var, value in context.items():
|
||||
var.set(value)
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ContextThreadPoolExecutor(ThreadPoolExecutor):
|
||||
"""ThreadPoolExecutor that copies the context to the child thread."""
|
||||
|
||||
def submit( # type: ignore[override]
|
||||
self,
|
||||
func: Callable[P, T],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Future[T]:
|
||||
"""Submit a function to the executor.
|
||||
|
||||
Args:
|
||||
func (Callable[..., T]): The function to submit.
|
||||
*args (Any): The positional arguments to the function.
|
||||
**kwargs (Any): The keyword arguments to the function.
|
||||
|
||||
Returns:
|
||||
Future[T]: The future for the function.
|
||||
"""
|
||||
return super().submit(
|
||||
cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs))
|
||||
)
|
||||
|
||||
def map(
|
||||
self,
|
||||
fn: Callable[..., T],
|
||||
*iterables: Iterable[Any],
|
||||
timeout: float | None = None,
|
||||
chunksize: int = 1,
|
||||
) -> Iterator[T]:
|
||||
contexts = [copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type]
|
||||
|
||||
def _wrapped_fn(*args: Any) -> T:
|
||||
return contexts.pop().run(fn, *args)
|
||||
|
||||
return super().map(
|
||||
_wrapped_fn,
|
||||
*iterables,
|
||||
timeout=timeout,
|
||||
chunksize=chunksize,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
@ -406,9 +462,36 @@ def get_executor_for_config(
|
||||
Generator[Executor, None, None]: The executor.
|
||||
"""
|
||||
config = config or {}
|
||||
with ThreadPoolExecutor(
|
||||
max_workers=config.get("max_concurrency"),
|
||||
initializer=_set_context,
|
||||
initargs=(copy_context(),),
|
||||
with ContextThreadPoolExecutor(
|
||||
max_workers=config.get("max_concurrency")
|
||||
) as executor:
|
||||
yield executor
|
||||
|
||||
|
||||
async def run_in_executor(
|
||||
executor_or_config: Optional[Union[Executor, RunnableConfig]],
|
||||
func: Callable[P, T],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> T:
|
||||
"""Run a function in an executor.
|
||||
|
||||
Args:
|
||||
executor (Executor): The executor.
|
||||
func (Callable[P, Output]): The function.
|
||||
*args (Any): The positional arguments to the function.
|
||||
**kwargs (Any): The keyword arguments to the function.
|
||||
|
||||
Returns:
|
||||
Output: The output of the function.
|
||||
"""
|
||||
if executor_or_config is None or isinstance(executor_or_config, dict):
|
||||
# Use default executor with context copied from current context
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None,
|
||||
cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs)),
|
||||
)
|
||||
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
executor_or_config, partial(func, **kwargs), *args
|
||||
)
|
||||
|
@ -23,6 +23,7 @@ from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
ensure_config,
|
||||
get_config_list,
|
||||
get_executor_for_config,
|
||||
)
|
||||
@ -259,7 +260,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
def _prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
|
||||
configurable_fields = {
|
||||
specs_by_id[k][0]: v
|
||||
@ -392,7 +393,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
def _prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
which = config.get("configurable", {}).get(self.which.id, self.default_key)
|
||||
# remap configurable keys for the chosen alternative
|
||||
if self.prefix_keys:
|
||||
|
@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -18,6 +17,7 @@ from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.load import load
|
||||
from langchain_core.pydantic_v1 import BaseModel, create_model
|
||||
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from langchain_core.runnables.utils import (
|
||||
ConfigurableFieldSpec,
|
||||
@ -331,9 +331,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
async def _aenter_history(
|
||||
self, input: Dict[str, Any], config: RunnableConfig
|
||||
) -> List[BaseMessage]:
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self._enter_history, input, config
|
||||
)
|
||||
return await run_in_executor(config, self._enter_history, input, config)
|
||||
|
||||
def _exit_history(self, run: Run, config: RunnableConfig) -> None:
|
||||
hist = config["configurable"]["message_history"]
|
||||
|
@ -31,6 +31,7 @@ from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
acall_func_with_variable_args,
|
||||
call_func_with_variable_args,
|
||||
ensure_config,
|
||||
get_executor_for_config,
|
||||
patch_config,
|
||||
)
|
||||
@ -206,7 +207,9 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Other:
|
||||
if self.func is not None:
|
||||
call_func_with_variable_args(self.func, input, config or {}, **kwargs)
|
||||
call_func_with_variable_args(
|
||||
self.func, input, ensure_config(config), **kwargs
|
||||
)
|
||||
return self._call_with_config(identity, input, config)
|
||||
|
||||
async def ainvoke(
|
||||
@ -217,10 +220,12 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
) -> Other:
|
||||
if self.afunc is not None:
|
||||
await acall_func_with_variable_args(
|
||||
self.afunc, input, config or {}, **kwargs
|
||||
self.afunc, input, ensure_config(config), **kwargs
|
||||
)
|
||||
elif self.func is not None:
|
||||
call_func_with_variable_args(self.func, input, config or {}, **kwargs)
|
||||
call_func_with_variable_args(
|
||||
self.func, input, ensure_config(config), **kwargs
|
||||
)
|
||||
return await self._acall_with_config(aidentity, input, config)
|
||||
|
||||
def transform(
|
||||
@ -243,7 +248,9 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
final = final + chunk
|
||||
|
||||
if final is not None:
|
||||
call_func_with_variable_args(self.func, final, config or {}, **kwargs)
|
||||
call_func_with_variable_args(
|
||||
self.func, final, ensure_config(config), **kwargs
|
||||
)
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
@ -269,7 +276,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
final = final + chunk
|
||||
|
||||
if final is not None:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
if self.afunc is not None:
|
||||
await acall_func_with_variable_args(
|
||||
self.afunc, final, config, **kwargs
|
||||
@ -458,7 +465,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
)
|
||||
|
||||
# get executor to start map output stream in background
|
||||
with get_executor_for_config(config or {}) as executor:
|
||||
with get_executor_for_config(config) as executor:
|
||||
# start map output stream
|
||||
first_map_chunk_future = executor.submit(
|
||||
next,
|
||||
|
@ -1,11 +1,9 @@
|
||||
"""Base implementation for tools or skills."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
@ -26,7 +24,13 @@ from langchain_core.pydantic_v1 import (
|
||||
root_validator,
|
||||
validate_arguments,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableSerializable,
|
||||
ensure_config,
|
||||
)
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
|
||||
class SchemaAnnotationError(TypeError):
|
||||
@ -202,7 +206,7 @@ class ChildTool(BaseTool):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
return self.run(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
@ -218,7 +222,7 @@ class ChildTool(BaseTool):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
return await self.arun(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
@ -280,11 +284,7 @@ class ChildTool(BaseTool):
|
||||
Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
|
||||
to child implementations to enable tracing,
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None,
|
||||
partial(self._run, **kwargs),
|
||||
*args,
|
||||
)
|
||||
return await run_in_executor(None, self._run, *args, **kwargs)
|
||||
|
||||
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
||||
# For backwards compatibility, if run_input is a string,
|
||||
@ -468,9 +468,7 @@ class Tool(BaseTool):
|
||||
) -> Any:
|
||||
if not self.coroutine:
|
||||
# If the tool does not implement async, fall back to default implementation
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.invoke, input, config, **kwargs)
|
||||
)
|
||||
return await run_in_executor(config, self.invoke, input, config, **kwargs)
|
||||
|
||||
return await super().ainvoke(input, config, **kwargs)
|
||||
|
||||
@ -538,8 +536,12 @@ class Tool(BaseTool):
|
||||
else await self.coroutine(*args, **kwargs)
|
||||
)
|
||||
else:
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self._run, run_manager=run_manager, **kwargs), *args
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self._run,
|
||||
run_manager=run_manager.get_sync() if run_manager else None,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# TODO: this is for backwards compatibility, remove in future
|
||||
@ -599,9 +601,7 @@ class StructuredTool(BaseTool):
|
||||
) -> Any:
|
||||
if not self.coroutine:
|
||||
# If the tool does not implement async, fall back to default implementation
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.invoke, input, config, **kwargs)
|
||||
)
|
||||
return await run_in_executor(config, self.invoke, input, config, **kwargs)
|
||||
|
||||
return await super().ainvoke(input, config, **kwargs)
|
||||
|
||||
@ -652,10 +652,12 @@ class StructuredTool(BaseTool):
|
||||
if new_argument_supported
|
||||
else await self.coroutine(*args, **kwargs)
|
||||
)
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
return await run_in_executor(
|
||||
None,
|
||||
partial(self._run, run_manager=run_manager, **kwargs),
|
||||
self._run,
|
||||
run_manager=run_manager.get_sync() if run_manager else None,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -24,6 +22,7 @@ from typing import (
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.callbacks.manager import (
|
||||
@ -103,9 +102,7 @@ class VectorStore(ABC):
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore."""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.add_texts, **kwargs), texts, metadatas
|
||||
)
|
||||
return await run_in_executor(None, self.add_texts, texts, metadatas, **kwargs)
|
||||
|
||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||
"""Run more documents through the embeddings and add to the vectorstore.
|
||||
@ -224,8 +221,9 @@ class VectorStore(ABC):
|
||||
# This is a temporary workaround to make the similarity search
|
||||
# asynchronous. The proper solution is to make the similarity search
|
||||
# asynchronous in the vector store implementations.
|
||||
func = partial(self.similarity_search_with_score, *args, **kwargs)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
return await run_in_executor(
|
||||
None, self.similarity_search_with_score, *args, **kwargs
|
||||
)
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
@ -383,8 +381,7 @@ class VectorStore(ABC):
|
||||
# This is a temporary workaround to make the similarity search
|
||||
# asynchronous. The proper solution is to make the similarity search
|
||||
# asynchronous in the vector store implementations.
|
||||
func = partial(self.similarity_search, query, k=k, **kwargs)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
return await run_in_executor(None, self.similarity_search, query, k=k, **kwargs)
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self, embedding: List[float], k: int = 4, **kwargs: Any
|
||||
@ -408,8 +405,9 @@ class VectorStore(ABC):
|
||||
# This is a temporary workaround to make the similarity search
|
||||
# asynchronous. The proper solution is to make the similarity search
|
||||
# asynchronous in the vector store implementations.
|
||||
func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
return await run_in_executor(
|
||||
None, self.similarity_search_by_vector, embedding, k=k, **kwargs
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
@ -450,7 +448,8 @@ class VectorStore(ABC):
|
||||
# This is a temporary workaround to make the similarity search
|
||||
# asynchronous. The proper solution is to make the similarity search
|
||||
# asynchronous in the vector store implementations.
|
||||
func = partial(
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self.max_marginal_relevance_search,
|
||||
query,
|
||||
k=k,
|
||||
@ -458,7 +457,6 @@ class VectorStore(ABC):
|
||||
lambda_mult=lambda_mult,
|
||||
**kwargs,
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
@ -541,8 +539,8 @@ class VectorStore(ABC):
|
||||
**kwargs: Any,
|
||||
) -> VST:
|
||||
"""Return VectorStore initialized from texts and embeddings."""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
|
||||
return await run_in_executor(
|
||||
None, cls.from_texts, texts, embedding, metadatas, **kwargs
|
||||
)
|
||||
|
||||
def _get_retriever_tags(self) -> List[str]:
|
||||
|
@ -5,6 +5,9 @@ EXPECTED_ALL = [
|
||||
"ConfigurableField",
|
||||
"ConfigurableFieldSingleOption",
|
||||
"ConfigurableFieldMultiOption",
|
||||
"ConfigurableFieldSpec",
|
||||
"ensure_config",
|
||||
"run_in_executor",
|
||||
"patch_config",
|
||||
"RouterInput",
|
||||
"RouterRunnable",
|
||||
|
@ -1,7 +1,6 @@
|
||||
"""A tool for running python code in a REPL."""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import re
|
||||
import sys
|
||||
from contextlib import redirect_stdout
|
||||
@ -14,6 +13,7 @@ from langchain.callbacks.manager import (
|
||||
)
|
||||
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
from langchain_experimental.utilities.python import PythonREPL
|
||||
|
||||
@ -72,10 +72,7 @@ class PythonREPLTool(BaseTool):
|
||||
if self.sanitize_input:
|
||||
query = sanitize_input(query)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(None, self.run, query)
|
||||
|
||||
return result
|
||||
return await run_in_executor(None, self.run, query)
|
||||
|
||||
|
||||
class PythonInputs(BaseModel):
|
||||
@ -144,7 +141,4 @@ class PythonAstREPLTool(BaseTool):
|
||||
) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(None, self._run, query)
|
||||
|
||||
return result
|
||||
return await run_in_executor(None, self._run, query)
|
||||
|
@ -30,7 +30,7 @@ from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
|
||||
from langchain_core.runnables.utils import AddableDict
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.input import get_color_mapping
|
||||
@ -1437,7 +1437,7 @@ class AgentExecutor(Chain):
|
||||
**kwargs: Any,
|
||||
) -> Iterator[AddableDict]:
|
||||
"""Enables streaming over steps taken to reach final output."""
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
iterator = AgentExecutorIterator(
|
||||
self,
|
||||
input,
|
||||
@ -1458,7 +1458,7 @@ class AgentExecutor(Chain):
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[AddableDict]:
|
||||
"""Enables streaming over steps taken to reach final output."""
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
iterator = AgentExecutorIterator(
|
||||
self,
|
||||
input,
|
||||
|
@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Un
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
@ -222,7 +222,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
||||
Union[List[ThreadMessage], List[RequiredActionFunctionToolCall]].
|
||||
"""
|
||||
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
callback_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
|
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import List, Union
|
||||
@ -85,12 +84,5 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
|
||||
message = result[0].message
|
||||
return self._parse_ai_message(message)
|
||||
|
||||
async def aparse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.parse_result, result
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
raise ValueError("Can only parse messages")
|
||||
|
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import List, Union
|
||||
@ -92,12 +91,5 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
|
||||
message = result[0].message
|
||||
return parse_ai_message_to_openai_tool_action(message)
|
||||
|
||||
async def aparse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> Union[List[AgentAction], AgentFinish]:
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.parse_result, result
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
|
||||
raise ValueError("Can only parse messages")
|
||||
|
@ -1,5 +1,4 @@
|
||||
"""Base interface that all chains should implement."""
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
@ -19,7 +18,12 @@ from langchain_core.pydantic_v1 import (
|
||||
root_validator,
|
||||
validator,
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables import (
|
||||
RunnableConfig,
|
||||
RunnableSerializable,
|
||||
ensure_config,
|
||||
run_in_executor,
|
||||
)
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import (
|
||||
@ -85,7 +89,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
return self(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
@ -101,7 +105,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
config = config or {}
|
||||
config = ensure_config(config)
|
||||
return await self.acall(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
@ -245,8 +249,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self._call, inputs, run_manager
|
||||
return await run_in_executor(
|
||||
None, self._call, inputs, run_manager.get_sync() if run_manager else None
|
||||
)
|
||||
|
||||
def __call__(
|
||||
|
@ -1,16 +1,15 @@
|
||||
"""Interfaces to be implemented by general evaluators."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
from warnings import warn
|
||||
|
||||
from langchain_core.agents import AgentAction
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
@ -189,15 +188,13 @@ class StringEvaluator(_EvalArgsMixin, ABC):
|
||||
- value: the string value of the evaluation, if applicable.
|
||||
- reasoning: the reasoning for the evaluation, if applicable.
|
||||
""" # noqa: E501
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
return await run_in_executor(
|
||||
None,
|
||||
partial(
|
||||
self._evaluate_strings,
|
||||
prediction=prediction,
|
||||
reference=reference,
|
||||
input=input,
|
||||
**kwargs,
|
||||
),
|
||||
self._evaluate_strings,
|
||||
prediction=prediction,
|
||||
reference=reference,
|
||||
input=input,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def evaluate_strings(
|
||||
@ -292,16 +289,14 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
|
||||
Returns:
|
||||
dict: A dictionary containing the preference, scores, and/or other information.
|
||||
""" # noqa: E501
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
return await run_in_executor(
|
||||
None,
|
||||
partial(
|
||||
self._evaluate_string_pairs,
|
||||
prediction=prediction,
|
||||
prediction_b=prediction_b,
|
||||
reference=reference,
|
||||
input=input,
|
||||
**kwargs,
|
||||
),
|
||||
self._evaluate_string_pairs,
|
||||
prediction=prediction,
|
||||
prediction_b=prediction_b,
|
||||
reference=reference,
|
||||
input=input,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def evaluate_string_pairs(
|
||||
@ -415,16 +410,14 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
|
||||
Returns:
|
||||
dict: The evaluation result.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
return await run_in_executor(
|
||||
None,
|
||||
partial(
|
||||
self._evaluate_agent_trajectory,
|
||||
prediction=prediction,
|
||||
agent_trajectory=agent_trajectory,
|
||||
reference=reference,
|
||||
input=input,
|
||||
**kwargs,
|
||||
),
|
||||
self._evaluate_agent_trajectory,
|
||||
prediction=prediction,
|
||||
agent_trajectory=agent_trajectory,
|
||||
reference=reference,
|
||||
input=input,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def evaluate_agent_trajectory(
|
||||
|
@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.documents import BaseDocumentTransformer, Document
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
|
||||
@ -28,7 +28,7 @@ class BaseDocumentCompressor(BaseModel, ABC):
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""Compress retrieved documents given the query context."""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
return await run_in_executor(
|
||||
None, self.compress_documents, documents, query, callbacks
|
||||
)
|
||||
|
||||
|
@ -21,7 +21,6 @@ Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import pathlib
|
||||
@ -29,7 +28,6 @@ import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from io import BytesIO, StringIO
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
@ -283,14 +281,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
"""Transform sequence of documents by splitting them."""
|
||||
return self.split_documents(list(documents))
|
||||
|
||||
async def atransform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
"""Asynchronously transform a sequence of documents by splitting them."""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.transform_documents, **kwargs), documents
|
||||
)
|
||||
|
||||
|
||||
class CharacterTextSplitter(TextSplitter):
|
||||
"""Splitting text that looks at characters."""
|
||||
|
Loading…
Reference in New Issue
Block a user