Propagate context vars in all classes/methods (#15329)

- Any direct usage of ThreadPoolExecutor or asyncio.run_in_executor
needs manual handling of context vars

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos
2023-12-29 15:59:00 -08:00
committed by GitHub
39 changed files with 395 additions and 377 deletions

View File

@@ -30,7 +30,7 @@ from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from langchain_core.runnables.utils import AddableDict
from langchain_core.tools import BaseTool
from langchain_core.utils.input import get_color_mapping
@@ -1437,7 +1437,7 @@ class AgentExecutor(Chain):
**kwargs: Any,
) -> Iterator[AddableDict]:
"""Enables streaming over steps taken to reach final output."""
config = config or {}
config = ensure_config(config)
iterator = AgentExecutorIterator(
self,
input,
@@ -1458,7 +1458,7 @@ class AgentExecutor(Chain):
**kwargs: Any,
) -> AsyncIterator[AddableDict]:
"""Enables streaming over steps taken to reach final output."""
config = config or {}
config = ensure_config(config)
iterator = AgentExecutorIterator(
self,
input,

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Un
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.load import dumpd
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
from langchain_core.tools import BaseTool
from langchain.callbacks.manager import CallbackManager
@@ -222,7 +222,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
Union[List[ThreadMessage], List[RequiredActionFunctionToolCall]].
"""
config = config or {}
config = ensure_config(config)
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),

View File

@@ -1,4 +1,3 @@
import asyncio
import json
from json import JSONDecodeError
from typing import List, Union
@@ -85,12 +84,5 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
message = result[0].message
return self._parse_ai_message(message)
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
return await asyncio.get_running_loop().run_in_executor(
None, self.parse_result, result
)
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
raise ValueError("Can only parse messages")

View File

@@ -1,4 +1,3 @@
import asyncio
import json
from json import JSONDecodeError
from typing import List, Union
@@ -92,12 +91,5 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
message = result[0].message
return parse_ai_message_to_openai_tool_action(message)
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[List[AgentAction], AgentFinish]:
return await asyncio.get_running_loop().run_in_executor(
None, self.parse_result, result
)
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
raise ValueError("Can only parse messages")

View File

@@ -1,5 +1,4 @@
"""Base interface that all chains should implement."""
import asyncio
import inspect
import json
import logging
@@ -19,7 +18,12 @@ from langchain_core.pydantic_v1 import (
root_validator,
validator,
)
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables import (
RunnableConfig,
RunnableSerializable,
ensure_config,
run_in_executor,
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
@@ -85,7 +89,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
config = config or {}
config = ensure_config(config)
return self(
input,
callbacks=config.get("callbacks"),
@@ -101,7 +105,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
config = config or {}
config = ensure_config(config)
return await self.acall(
input,
callbacks=config.get("callbacks"),
@@ -245,8 +249,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self._call, inputs, run_manager
return await run_in_executor(
None, self._call, inputs, run_manager.get_sync() if run_manager else None
)
def __call__(

View File

@@ -1,16 +1,15 @@
"""Interfaces to be implemented by general evaluators."""
from __future__ import annotations
import asyncio
import logging
from abc import ABC, abstractmethod
from enum import Enum
from functools import partial
from typing import Any, Optional, Sequence, Tuple, Union
from warnings import warn
from langchain_core.agents import AgentAction
from langchain_core.language_models import BaseLanguageModel
from langchain_core.runnables.config import run_in_executor
from langchain.chains.base import Chain
@@ -189,15 +188,13 @@ class StringEvaluator(_EvalArgsMixin, ABC):
- value: the string value of the evaluation, if applicable.
- reasoning: the reasoning for the evaluation, if applicable.
""" # noqa: E501
return await asyncio.get_running_loop().run_in_executor(
return await run_in_executor(
None,
partial(
self._evaluate_strings,
prediction=prediction,
reference=reference,
input=input,
**kwargs,
),
self._evaluate_strings,
prediction=prediction,
reference=reference,
input=input,
**kwargs,
)
def evaluate_strings(
@@ -292,16 +289,14 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: A dictionary containing the preference, scores, and/or other information.
""" # noqa: E501
return await asyncio.get_running_loop().run_in_executor(
return await run_in_executor(
None,
partial(
self._evaluate_string_pairs,
prediction=prediction,
prediction_b=prediction_b,
reference=reference,
input=input,
**kwargs,
),
self._evaluate_string_pairs,
prediction=prediction,
prediction_b=prediction_b,
reference=reference,
input=input,
**kwargs,
)
def evaluate_string_pairs(
@@ -415,16 +410,14 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
Returns:
dict: The evaluation result.
"""
return await asyncio.get_running_loop().run_in_executor(
return await run_in_executor(
None,
partial(
self._evaluate_agent_trajectory,
prediction=prediction,
agent_trajectory=agent_trajectory,
reference=reference,
input=input,
**kwargs,
),
self._evaluate_agent_trajectory,
prediction=prediction,
agent_trajectory=agent_trajectory,
reference=reference,
input=input,
**kwargs,
)
def evaluate_agent_trajectory(

View File

@@ -1,10 +1,10 @@
import asyncio
from abc import ABC, abstractmethod
from inspect import signature
from typing import List, Optional, Sequence, Union
from langchain_core.documents import BaseDocumentTransformer, Document
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.config import run_in_executor
from langchain.callbacks.manager import Callbacks
@@ -28,7 +28,7 @@ class BaseDocumentCompressor(BaseModel, ABC):
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Compress retrieved documents given the query context."""
return await asyncio.get_running_loop().run_in_executor(
return await run_in_executor(
None, self.compress_documents, documents, query, callbacks
)

View File

@@ -21,7 +21,6 @@ Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive
from __future__ import annotations
import asyncio
import copy
import logging
import pathlib
@@ -29,7 +28,6 @@ import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from functools import partial
from io import BytesIO, StringIO
from typing import (
AbstractSet,
@@ -283,14 +281,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
"""Transform sequence of documents by splitting them."""
return self.split_documents(list(documents))
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Asynchronously transform a sequence of documents by splitting them."""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.transform_documents, **kwargs), documents
)
class CharacterTextSplitter(TextSplitter):
"""Splitting text that looks at characters."""