Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-24 12:01:54 +00:00)
Propagate context vars in all classes/methods (#15329)
- Any direct usage of ThreadPoolExecutor or asyncio.run_in_executor needs manual handling of context vars
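Background for the note above: Python's contextvars (which carry LangChain's callback and tracing state) do not follow a call into a worker thread. ThreadPoolExecutor and loop.run_in_executor run the submitted callable in a fresh context, so values set by the caller are invisible there unless the context is copied across explicitly. A minimal, self-contained sketch of the failure mode and of the manual fix with contextvars.copy_context (illustrative only; the variable and function names are made up, this is not LangChain code):

```python
import asyncio
import contextvars
from functools import partial

# Stand-in for per-run state such as a callback or tracing configuration.
current_run: contextvars.ContextVar[str] = contextvars.ContextVar(
    "current_run", default="<unset>"
)


def read_run() -> str:
    # Executes inside the executor's worker thread.
    return current_run.get()


async def main() -> None:
    current_run.set("run-123")
    loop = asyncio.get_running_loop()

    # Lost: the worker thread does not see the caller's context.
    lost = await loop.run_in_executor(None, read_run)

    # Kept: snapshot the caller's context and replay the call inside it.
    ctx = contextvars.copy_context()
    kept = await loop.run_in_executor(None, partial(ctx.run, read_run))

    print(lost, kept)  # "<unset>" "run-123"


asyncio.run(main())
```

asyncio.to_thread performs this copy automatically on Python 3.9+; the run_in_executor helper adopted throughout the diff below plays the same role for LangChain's executor-based call sites.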
@@ -30,7 +30,7 @@ from langchain_core.prompts import BasePromptTemplate
 from langchain_core.prompts.few_shot import FewShotPromptTemplate
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.pydantic_v1 import BaseModel, root_validator
-from langchain_core.runnables import Runnable, RunnableConfig
+from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
 from langchain_core.runnables.utils import AddableDict
 from langchain_core.tools import BaseTool
 from langchain_core.utils.input import get_color_mapping
@@ -1437,7 +1437,7 @@ class AgentExecutor(Chain):
         **kwargs: Any,
     ) -> Iterator[AddableDict]:
         """Enables streaming over steps taken to reach final output."""
-        config = config or {}
+        config = ensure_config(config)
         iterator = AgentExecutorIterator(
             self,
             input,
@@ -1458,7 +1458,7 @@ class AgentExecutor(Chain):
         **kwargs: Any,
     ) -> AsyncIterator[AddableDict]:
         """Enables streaming over steps taken to reach final output."""
-        config = config or {}
+        config = ensure_config(config)
         iterator = AgentExecutorIterator(
             self,
             input,
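The recurring `config = config or {}` → `config = ensure_config(config)` edit replaces an ad-hoc empty dict with the shared normalization helper from langchain_core.runnables, so every entry point works against a fully populated RunnableConfig and config handling has a single seam. A rough sketch of what that kind of normalizer does (a simplified assumption for illustration, not the langchain_core implementation):

```python
from typing import Any, Dict, List, Optional, TypedDict


class ConfigSketch(TypedDict, total=False):
    """Simplified stand-in for RunnableConfig."""

    tags: List[str]
    metadata: Dict[str, Any]
    callbacks: Any
    recursion_limit: int


def ensure_config_sketch(config: Optional[ConfigSketch] = None) -> ConfigSketch:
    # Start from defaults so callers can read keys without None checks,
    # then layer any explicitly provided values on top.
    base: ConfigSketch = {
        "tags": [],
        "metadata": {},
        "callbacks": None,
        "recursion_limit": 25,
    }
    if config:
        base.update({k: v for k, v in config.items() if v is not None})
    return base
```

Centralizing the normalization in one helper also gives a single place to attach run-scoped defaults, rather than repeating `config or {}` at every call site.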
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Un
 from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.load import dumpd
 from langchain_core.pydantic_v1 import Field
-from langchain_core.runnables import RunnableConfig, RunnableSerializable
+from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
 from langchain_core.tools import BaseTool

 from langchain.callbacks.manager import CallbackManager
@@ -222,7 +222,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
                 Union[List[ThreadMessage], List[RequiredActionFunctionToolCall]].
         """

-        config = config or {}
+        config = ensure_config(config)
         callback_manager = CallbackManager.configure(
             inheritable_callbacks=config.get("callbacks"),
             inheritable_tags=config.get("tags"),

@@ -1,4 +1,3 @@
-import asyncio
 import json
 from json import JSONDecodeError
 from typing import List, Union
@@ -85,12 +84,5 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
         message = result[0].message
         return self._parse_ai_message(message)

-    async def aparse_result(
-        self, result: List[Generation], *, partial: bool = False
-    ) -> Union[AgentAction, AgentFinish]:
-        return await asyncio.get_running_loop().run_in_executor(
-            None, self.parse_result, result
-        )
-
     def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
         raise ValueError("Can only parse messages")

@@ -1,4 +1,3 @@
-import asyncio
 import json
 from json import JSONDecodeError
 from typing import List, Union
@@ -92,12 +91,5 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
         message = result[0].message
         return parse_ai_message_to_openai_tool_action(message)

-    async def aparse_result(
-        self, result: List[Generation], *, partial: bool = False
-    ) -> Union[List[AgentAction], AgentFinish]:
-        return await asyncio.get_running_loop().run_in_executor(
-            None, self.parse_result, result
-        )
-
     def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
         raise ValueError("Can only parse messages")
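Both parsers delete their hand-rolled aparse_result overrides instead of patching them: once the inherited async default runs the sync parse_result through a context-propagating executor helper, a per-subclass copy of the bare asyncio.get_running_loop().run_in_executor call is only a liability. A hedged sketch of that inherited-default pattern, using asyncio.to_thread as a stand-in for the helper (illustrative; not the langchain_core base class):

```python
import asyncio
from typing import Any, List


class ParserSketch:
    """Illustrative base class with a sync implementation and one async default."""

    def parse_result(self, result: List[Any], *, partial: bool = False) -> Any:
        raise NotImplementedError

    async def aparse_result(self, result: List[Any], *, partial: bool = False) -> Any:
        # Run the sync parser on a worker thread. asyncio.to_thread copies the
        # caller's contextvars (Python 3.9+), which is exactly what the removed
        # per-subclass overrides failed to do.
        return await asyncio.to_thread(self.parse_result, result, partial=partial)
```

Subclasses then implement only the sync path and inherit a context-safe async path for free.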
@@ -1,5 +1,4 @@
 """Base interface that all chains should implement."""
-import asyncio
 import inspect
 import json
 import logging
@@ -19,7 +18,12 @@ from langchain_core.pydantic_v1 import (
     root_validator,
     validator,
 )
-from langchain_core.runnables import RunnableConfig, RunnableSerializable
+from langchain_core.runnables import (
+    RunnableConfig,
+    RunnableSerializable,
+    ensure_config,
+    run_in_executor,
+)

 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import (
@@ -85,7 +89,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        config = config or {}
+        config = ensure_config(config)
         return self(
             input,
             callbacks=config.get("callbacks"),
@@ -101,7 +105,7 @@
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        config = config or {}
+        config = ensure_config(config)
         return await self.acall(
             input,
             callbacks=config.get("callbacks"),
@@ -245,8 +249,8 @@
             A dict of named outputs. Should contain all outputs specified in
                 `Chain.output_keys`.
         """
-        return await asyncio.get_running_loop().run_in_executor(
-            None, self._call, inputs, run_manager
+        return await run_in_executor(
+            None, self._call, inputs, run_manager.get_sync() if run_manager else None
         )

     def __call__(
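The `_acall` default above shows the core substitution repeated through the rest of the diff: the raw `asyncio.get_running_loop().run_in_executor(None, ...)` call, which drops the caller's contextvars at the thread boundary, becomes a call to the `run_in_executor` helper imported from langchain_core, and the async run manager is converted with `run_manager.get_sync()` before crossing into sync code. A minimal sketch of what a context-propagating executor helper of that shape looks like (a sketch of the general technique, assuming the copy-context approach; not the langchain_core source):

```python
import asyncio
import contextvars
from concurrent.futures import Executor
from functools import partial
from typing import Any, Callable, Optional, TypeVar

T = TypeVar("T")


async def run_in_executor_sketch(
    executor: Optional[Executor],
    func: Callable[..., T],
    *args: Any,
    **kwargs: Any,
) -> T:
    """Run a sync callable in an executor without losing the caller's contextvars."""
    # Snapshot the current context (callbacks, tracing, ...) and replay the call
    # inside it on the worker thread; also forward keyword arguments, which the
    # raw loop.run_in_executor API cannot do on its own.
    ctx = contextvars.copy_context()
    return await asyncio.get_running_loop().run_in_executor(
        executor, partial(ctx.run, func, *args, **kwargs)
    )
```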
@@ -1,16 +1,15 @@
 """Interfaces to be implemented by general evaluators."""
 from __future__ import annotations

-import asyncio
 import logging
 from abc import ABC, abstractmethod
 from enum import Enum
-from functools import partial
 from typing import Any, Optional, Sequence, Tuple, Union
 from warnings import warn

 from langchain_core.agents import AgentAction
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.runnables.config import run_in_executor

 from langchain.chains.base import Chain

@@ -189,15 +188,13 @@ class StringEvaluator(_EvalArgsMixin, ABC):
                 - value: the string value of the evaluation, if applicable.
                 - reasoning: the reasoning for the evaluation, if applicable.
         """  # noqa: E501
-        return await asyncio.get_running_loop().run_in_executor(
+        return await run_in_executor(
             None,
-            partial(
-                self._evaluate_strings,
-                prediction=prediction,
-                reference=reference,
-                input=input,
-                **kwargs,
-            ),
+            self._evaluate_strings,
+            prediction=prediction,
+            reference=reference,
+            input=input,
+            **kwargs,
         )

     def evaluate_strings(
@@ -292,16 +289,14 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
         Returns:
             dict: A dictionary containing the preference, scores, and/or other information.
         """  # noqa: E501
-        return await asyncio.get_running_loop().run_in_executor(
+        return await run_in_executor(
             None,
-            partial(
-                self._evaluate_string_pairs,
-                prediction=prediction,
-                prediction_b=prediction_b,
-                reference=reference,
-                input=input,
-                **kwargs,
-            ),
+            self._evaluate_string_pairs,
+            prediction=prediction,
+            prediction_b=prediction_b,
+            reference=reference,
+            input=input,
+            **kwargs,
         )

     def evaluate_string_pairs(
@@ -415,16 +410,14 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
         Returns:
             dict: The evaluation result.
         """
-        return await asyncio.get_running_loop().run_in_executor(
+        return await run_in_executor(
             None,
-            partial(
-                self._evaluate_agent_trajectory,
-                prediction=prediction,
-                agent_trajectory=agent_trajectory,
-                reference=reference,
-                input=input,
-                **kwargs,
-            ),
+            self._evaluate_agent_trajectory,
+            prediction=prediction,
+            agent_trajectory=agent_trajectory,
+            reference=reference,
+            input=input,
+            **kwargs,
         )

     def evaluate_agent_trajectory(
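A side effect visible in all three evaluator hunks: because `run_in_executor` forwards keyword arguments to the wrapped callable itself, the `functools.partial` wrapping (and its import at the top of the file) can go away. A small comparison sketch under that assumption, using the import path the diff adds; `fn` is a hypothetical keyword-only sync function:

```python
import asyncio
from functools import partial

from langchain_core.runnables.config import run_in_executor


def fn(*, a: int, b: int) -> int:
    return a + b


async def main() -> None:
    # Old style: pre-bind keyword arguments, since the raw loop API only
    # forwards positionals and copies no context.
    old = await asyncio.get_running_loop().run_in_executor(None, partial(fn, a=1, b=2))

    # New style: run_in_executor forwards **kwargs and carries contextvars across.
    new = await run_in_executor(None, fn, a=1, b=2)

    assert old == new == 3


asyncio.run(main())
```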
@@ -1,10 +1,10 @@
-import asyncio
 from abc import ABC, abstractmethod
 from inspect import signature
 from typing import List, Optional, Sequence, Union

 from langchain_core.documents import BaseDocumentTransformer, Document
 from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.runnables.config import run_in_executor

 from langchain.callbacks.manager import Callbacks

@@ -28,7 +28,7 @@ class BaseDocumentCompressor(BaseModel, ABC):
         callbacks: Optional[Callbacks] = None,
     ) -> Sequence[Document]:
         """Compress retrieved documents given the query context."""
-        return await asyncio.get_running_loop().run_in_executor(
+        return await run_in_executor(
             None, self.compress_documents, documents, query, callbacks
         )

@@ -21,7 +21,6 @@ Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive

 from __future__ import annotations

-import asyncio
 import copy
 import logging
 import pathlib
@@ -29,7 +28,6 @@ import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
-from functools import partial
 from io import BytesIO, StringIO
 from typing import (
     AbstractSet,
@@ -283,14 +281,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         """Transform sequence of documents by splitting them."""
         return self.split_documents(list(documents))

-    async def atransform_documents(
-        self, documents: Sequence[Document], **kwargs: Any
-    ) -> Sequence[Document]:
-        """Asynchronously transform a sequence of documents by splitting them."""
-        return await asyncio.get_running_loop().run_in_executor(
-            None, partial(self.transform_documents, **kwargs), documents
-        )
-

 class CharacterTextSplitter(TextSplitter):
     """Splitting text that looks at characters."""