mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 07:07:34 +00:00
core[patch]: Add B(bugbear) ruff rules (#25520)
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
d5ddaac1fc
commit
ff0df5ea15
@ -270,7 +270,7 @@ def warn_beta(
|
|||||||
message += f" {addendum}"
|
message += f" {addendum}"
|
||||||
|
|
||||||
warning = LangChainBetaWarning(message)
|
warning = LangChainBetaWarning(message)
|
||||||
warnings.warn(warning, category=LangChainBetaWarning, stacklevel=2)
|
warnings.warn(warning, category=LangChainBetaWarning, stacklevel=4)
|
||||||
|
|
||||||
|
|
||||||
def surface_langchain_beta_warnings() -> None:
|
def surface_langchain_beta_warnings() -> None:
|
||||||
|
@ -444,7 +444,7 @@ def warn_deprecated(
|
|||||||
LangChainPendingDeprecationWarning if pending else LangChainDeprecationWarning
|
LangChainPendingDeprecationWarning if pending else LangChainDeprecationWarning
|
||||||
)
|
)
|
||||||
warning = warning_cls(message)
|
warning = warning_cls(message)
|
||||||
warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=2)
|
warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=4)
|
||||||
|
|
||||||
|
|
||||||
def surface_langchain_deprecation_warnings() -> None:
|
def surface_langchain_deprecation_warnings() -> None:
|
||||||
|
@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
|||||||
from langchain_core.documents.base import Blob
|
from langchain_core.documents.base import Blob
|
||||||
|
|
||||||
|
|
||||||
class BaseLoader(ABC):
|
class BaseLoader(ABC): # noqa: B024
|
||||||
"""Interface for Document Loader.
|
"""Interface for Document Loader.
|
||||||
|
|
||||||
Implementations should implement the lazy-loading method using generators
|
Implementations should implement the lazy-loading method using generators
|
||||||
|
@ -80,12 +80,12 @@ def _texts_to_nodes(
|
|||||||
for text in texts:
|
for text in texts:
|
||||||
try:
|
try:
|
||||||
_metadata = next(metadatas_it).copy() if metadatas_it else {}
|
_metadata = next(metadatas_it).copy() if metadatas_it else {}
|
||||||
except StopIteration:
|
except StopIteration as e:
|
||||||
raise ValueError("texts iterable longer than metadatas")
|
raise ValueError("texts iterable longer than metadatas") from e
|
||||||
try:
|
try:
|
||||||
_id = next(ids_it) if ids_it else None
|
_id = next(ids_it) if ids_it else None
|
||||||
except StopIteration:
|
except StopIteration as e:
|
||||||
raise ValueError("texts iterable longer than ids")
|
raise ValueError("texts iterable longer than ids") from e
|
||||||
|
|
||||||
links = _metadata.pop(METADATA_LINKS_KEY, [])
|
links = _metadata.pop(METADATA_LINKS_KEY, [])
|
||||||
if not isinstance(links, list):
|
if not isinstance(links, list):
|
||||||
|
@ -91,7 +91,7 @@ class _HashedDocument(Document):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Failed to hash metadata: {e}. "
|
f"Failed to hash metadata: {e}. "
|
||||||
f"Please use a dict that can be serialized using json."
|
f"Please use a dict that can be serialized using json."
|
||||||
)
|
) from e
|
||||||
|
|
||||||
values["content_hash"] = content_hash
|
values["content_hash"] = content_hash
|
||||||
values["metadata_hash"] = metadata_hash
|
values["metadata_hash"] = metadata_hash
|
||||||
|
@ -64,12 +64,12 @@ def get_tokenizer() -> Any:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from transformers import GPT2TokenizerFast # type: ignore[import]
|
from transformers import GPT2TokenizerFast # type: ignore[import]
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Could not import transformers python package. "
|
"Could not import transformers python package. "
|
||||||
"This is needed in order to calculate get_token_ids. "
|
"This is needed in order to calculate get_token_ids. "
|
||||||
"Please install it with `pip install transformers`."
|
"Please install it with `pip install transformers`."
|
||||||
)
|
) from e
|
||||||
# create a GPT-2 tokenizer instance
|
# create a GPT-2 tokenizer instance
|
||||||
return GPT2TokenizerFast.from_pretrained("gpt2")
|
return GPT2TokenizerFast.from_pretrained("gpt2")
|
||||||
|
|
||||||
|
@ -235,6 +235,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
warnings.warn(
|
warnings.warn(
|
||||||
"callback_manager is deprecated. Please use callbacks instead.",
|
"callback_manager is deprecated. Please use callbacks instead.",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
stacklevel=5,
|
||||||
)
|
)
|
||||||
values["callbacks"] = values.pop("callback_manager", None)
|
values["callbacks"] = values.pop("callback_manager", None)
|
||||||
return values
|
return values
|
||||||
|
@ -69,6 +69,10 @@ class FakeListLLM(LLM):
|
|||||||
return {"responses": self.responses}
|
return {"responses": self.responses}
|
||||||
|
|
||||||
|
|
||||||
|
class FakeListLLMError(Exception):
|
||||||
|
"""Fake error for testing purposes."""
|
||||||
|
|
||||||
|
|
||||||
class FakeStreamingListLLM(FakeListLLM):
|
class FakeStreamingListLLM(FakeListLLM):
|
||||||
"""Fake streaming list LLM for testing purposes.
|
"""Fake streaming list LLM for testing purposes.
|
||||||
|
|
||||||
@ -98,7 +102,7 @@ class FakeStreamingListLLM(FakeListLLM):
|
|||||||
self.error_on_chunk_number is not None
|
self.error_on_chunk_number is not None
|
||||||
and i_c == self.error_on_chunk_number
|
and i_c == self.error_on_chunk_number
|
||||||
):
|
):
|
||||||
raise Exception("Fake error")
|
raise FakeListLLMError
|
||||||
yield c
|
yield c
|
||||||
|
|
||||||
async def astream(
|
async def astream(
|
||||||
@ -118,5 +122,5 @@ class FakeStreamingListLLM(FakeListLLM):
|
|||||||
self.error_on_chunk_number is not None
|
self.error_on_chunk_number is not None
|
||||||
and i_c == self.error_on_chunk_number
|
and i_c == self.error_on_chunk_number
|
||||||
):
|
):
|
||||||
raise Exception("Fake error")
|
raise FakeListLLMError
|
||||||
yield c
|
yield c
|
||||||
|
@ -44,6 +44,10 @@ class FakeMessagesListChatModel(BaseChatModel):
|
|||||||
return "fake-messages-list-chat-model"
|
return "fake-messages-list-chat-model"
|
||||||
|
|
||||||
|
|
||||||
|
class FakeListChatModelError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FakeListChatModel(SimpleChatModel):
|
class FakeListChatModel(SimpleChatModel):
|
||||||
"""Fake ChatModel for testing purposes."""
|
"""Fake ChatModel for testing purposes."""
|
||||||
|
|
||||||
@ -93,7 +97,7 @@ class FakeListChatModel(SimpleChatModel):
|
|||||||
self.error_on_chunk_number is not None
|
self.error_on_chunk_number is not None
|
||||||
and i_c == self.error_on_chunk_number
|
and i_c == self.error_on_chunk_number
|
||||||
):
|
):
|
||||||
raise Exception("Fake error")
|
raise FakeListChatModelError
|
||||||
|
|
||||||
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
|
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
|
||||||
|
|
||||||
@ -116,7 +120,7 @@ class FakeListChatModel(SimpleChatModel):
|
|||||||
self.error_on_chunk_number is not None
|
self.error_on_chunk_number is not None
|
||||||
and i_c == self.error_on_chunk_number
|
and i_c == self.error_on_chunk_number
|
||||||
):
|
):
|
||||||
raise Exception("Fake error")
|
raise FakeListChatModelError
|
||||||
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
|
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -114,7 +114,6 @@ def create_base_retry_decorator(
|
|||||||
_log_error_once(f"Error in on_retry: {e}")
|
_log_error_once(f"Error in on_retry: {e}")
|
||||||
else:
|
else:
|
||||||
run_manager.on_retry(retry_state)
|
run_manager.on_retry(retry_state)
|
||||||
return None
|
|
||||||
|
|
||||||
min_seconds = 4
|
min_seconds = 4
|
||||||
max_seconds = 10
|
max_seconds = 10
|
||||||
@ -311,6 +310,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
warnings.warn(
|
warnings.warn(
|
||||||
"callback_manager is deprecated. Please use callbacks instead.",
|
"callback_manager is deprecated. Please use callbacks instead.",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
stacklevel=5,
|
||||||
)
|
)
|
||||||
values["callbacks"] = values.pop("callback_manager", None)
|
values["callbacks"] = values.pop("callback_manager", None)
|
||||||
return values
|
return values
|
||||||
|
@ -290,10 +290,10 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
|
|||||||
msg_type = msg_kwargs.pop("type")
|
msg_type = msg_kwargs.pop("type")
|
||||||
# None msg content is not allowed
|
# None msg content is not allowed
|
||||||
msg_content = msg_kwargs.pop("content") or ""
|
msg_content = msg_kwargs.pop("content") or ""
|
||||||
except KeyError:
|
except KeyError as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Message dict must contain 'role' and 'content' keys, got {message}"
|
f"Message dict must contain 'role' and 'content' keys, got {message}"
|
||||||
)
|
) from e
|
||||||
_message = _create_message_from_message_type(
|
_message = _create_message_from_message_type(
|
||||||
msg_type, msg_content, **msg_kwargs
|
msg_type, msg_content, **msg_kwargs
|
||||||
)
|
)
|
||||||
@ -344,9 +344,7 @@ def _runnable_support(func: Callable) -> Callable:
|
|||||||
if messages is not None:
|
if messages is not None:
|
||||||
return func(messages, **kwargs)
|
return func(messages, **kwargs)
|
||||||
else:
|
else:
|
||||||
return RunnableLambda(
|
return RunnableLambda(partial(func, **kwargs), name=func.__name__)
|
||||||
partial(func, **kwargs), name=getattr(func, "__name__")
|
|
||||||
)
|
|
||||||
|
|
||||||
wrapped.__doc__ = func.__doc__
|
wrapped.__doc__ = func.__doc__
|
||||||
return wrapped
|
return wrapped
|
||||||
@ -791,7 +789,7 @@ def trim_messages(
|
|||||||
raise ValueError
|
raise ValueError
|
||||||
messages = convert_to_messages(messages)
|
messages = convert_to_messages(messages)
|
||||||
if hasattr(token_counter, "get_num_tokens_from_messages"):
|
if hasattr(token_counter, "get_num_tokens_from_messages"):
|
||||||
list_token_counter = getattr(token_counter, "get_num_tokens_from_messages")
|
list_token_counter = token_counter.get_num_tokens_from_messages
|
||||||
elif callable(token_counter):
|
elif callable(token_counter):
|
||||||
if (
|
if (
|
||||||
list(inspect.signature(token_counter).parameters.values())[0].annotation
|
list(inspect.signature(token_counter).parameters.values())[0].annotation
|
||||||
|
@ -42,7 +42,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
|||||||
try:
|
try:
|
||||||
func_call = copy.deepcopy(message.additional_kwargs["function_call"])
|
func_call = copy.deepcopy(message.additional_kwargs["function_call"])
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise OutputParserException(f"Could not parse function call: {exc}")
|
raise OutputParserException(
|
||||||
|
f"Could not parse function call: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
if self.args_only:
|
if self.args_only:
|
||||||
return func_call["arguments"]
|
return func_call["arguments"]
|
||||||
@ -100,7 +102,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
if partial:
|
if partial:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
raise OutputParserException(f"Could not parse function call: {exc}")
|
raise OutputParserException(
|
||||||
|
f"Could not parse function call: {exc}"
|
||||||
|
) from exc
|
||||||
try:
|
try:
|
||||||
if partial:
|
if partial:
|
||||||
try:
|
try:
|
||||||
@ -126,7 +130,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
except (json.JSONDecodeError, TypeError) as exc:
|
except (json.JSONDecodeError, TypeError) as exc:
|
||||||
raise OutputParserException(
|
raise OutputParserException(
|
||||||
f"Could not parse function call data: {exc}"
|
f"Could not parse function call data: {exc}"
|
||||||
)
|
) from exc
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return {
|
return {
|
||||||
@ -138,7 +142,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
except (json.JSONDecodeError, TypeError) as exc:
|
except (json.JSONDecodeError, TypeError) as exc:
|
||||||
raise OutputParserException(
|
raise OutputParserException(
|
||||||
f"Could not parse function call data: {exc}"
|
f"Could not parse function call data: {exc}"
|
||||||
)
|
) from exc
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ def parse_tool_call(
|
|||||||
f"Function {raw_tool_call['function']['name']} arguments:\n\n"
|
f"Function {raw_tool_call['function']['name']} arguments:\n\n"
|
||||||
f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
|
f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
|
||||||
f"Received JSONDecodeError {e}"
|
f"Received JSONDecodeError {e}"
|
||||||
)
|
) from e
|
||||||
parsed = {
|
parsed = {
|
||||||
"name": raw_tool_call["function"]["name"] or "",
|
"name": raw_tool_call["function"]["name"] or "",
|
||||||
"args": function_args or {},
|
"args": function_args or {},
|
||||||
|
@ -32,12 +32,12 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
|||||||
{self.pydantic_object.__class__}"
|
{self.pydantic_object.__class__}"
|
||||||
)
|
)
|
||||||
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
|
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
|
||||||
raise self._parser_exception(e, obj)
|
raise self._parser_exception(e, obj) from e
|
||||||
else: # pydantic v1
|
else: # pydantic v1
|
||||||
try:
|
try:
|
||||||
return self.pydantic_object.parse_obj(obj)
|
return self.pydantic_object.parse_obj(obj)
|
||||||
except pydantic.ValidationError as e:
|
except pydantic.ValidationError as e:
|
||||||
raise self._parser_exception(e, obj)
|
raise self._parser_exception(e, obj) from e
|
||||||
|
|
||||||
def _parser_exception(
|
def _parser_exception(
|
||||||
self, e: Exception, json_object: dict
|
self, e: Exception, json_object: dict
|
||||||
|
@ -46,12 +46,12 @@ class _StreamingParser:
|
|||||||
if parser == "defusedxml":
|
if parser == "defusedxml":
|
||||||
try:
|
try:
|
||||||
from defusedxml import ElementTree as DET # type: ignore
|
from defusedxml import ElementTree as DET # type: ignore
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"defusedxml is not installed. "
|
"defusedxml is not installed. "
|
||||||
"Please install it to use the defusedxml parser."
|
"Please install it to use the defusedxml parser."
|
||||||
"You can install it with `pip install defusedxml` "
|
"You can install it with `pip install defusedxml` "
|
||||||
)
|
) from e
|
||||||
_parser = DET.DefusedXMLParser(target=TreeBuilder())
|
_parser = DET.DefusedXMLParser(target=TreeBuilder())
|
||||||
else:
|
else:
|
||||||
_parser = None
|
_parser = None
|
||||||
@ -189,13 +189,13 @@ class XMLOutputParser(BaseTransformOutputParser):
|
|||||||
if self.parser == "defusedxml":
|
if self.parser == "defusedxml":
|
||||||
try:
|
try:
|
||||||
from defusedxml import ElementTree as DET # type: ignore
|
from defusedxml import ElementTree as DET # type: ignore
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"defusedxml is not installed. "
|
"defusedxml is not installed. "
|
||||||
"Please install it to use the defusedxml parser."
|
"Please install it to use the defusedxml parser."
|
||||||
"You can install it with `pip install defusedxml`"
|
"You can install it with `pip install defusedxml`"
|
||||||
"See https://github.com/tiran/defusedxml for more details"
|
"See https://github.com/tiran/defusedxml for more details"
|
||||||
)
|
) from e
|
||||||
_ET = DET # Use the defusedxml parser
|
_ET = DET # Use the defusedxml parser
|
||||||
else:
|
else:
|
||||||
_ET = ET # Use the standard library parser
|
_ET = ET # Use the standard library parser
|
||||||
|
@ -235,7 +235,9 @@ class PromptTemplate(StringPromptTemplate):
|
|||||||
template = f.read()
|
template = f.read()
|
||||||
if input_variables:
|
if input_variables:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"`input_variables' is deprecated and ignored.", DeprecationWarning
|
"`input_variables' is deprecated and ignored.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
return cls.from_template(template=template, **kwargs)
|
return cls.from_template(template=template, **kwargs)
|
||||||
|
|
||||||
|
@ -40,14 +40,14 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from jinja2.sandbox import SandboxedEnvironment
|
from jinja2.sandbox import SandboxedEnvironment
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
||||||
"Please install it with `pip install jinja2`."
|
"Please install it with `pip install jinja2`."
|
||||||
"Please be cautious when using jinja2 templates. "
|
"Please be cautious when using jinja2 templates. "
|
||||||
"Do not expand jinja2 templates using unverified or user-controlled "
|
"Do not expand jinja2 templates using unverified or user-controlled "
|
||||||
"inputs as that can result in arbitrary Python code execution."
|
"inputs as that can result in arbitrary Python code execution."
|
||||||
)
|
) from e
|
||||||
|
|
||||||
# This uses a sandboxed environment to prevent arbitrary code execution.
|
# This uses a sandboxed environment to prevent arbitrary code execution.
|
||||||
# Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
|
# Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
|
||||||
@ -81,17 +81,17 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None:
|
|||||||
warning_message += f"Extra variables: {extra_variables}"
|
warning_message += f"Extra variables: {extra_variables}"
|
||||||
|
|
||||||
if warning_message:
|
if warning_message:
|
||||||
warnings.warn(warning_message.strip())
|
warnings.warn(warning_message.strip(), stacklevel=7)
|
||||||
|
|
||||||
|
|
||||||
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
||||||
try:
|
try:
|
||||||
from jinja2 import Environment, meta
|
from jinja2 import Environment, meta
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
||||||
"Please install it with `pip install jinja2`."
|
"Please install it with `pip install jinja2`."
|
||||||
)
|
) from e
|
||||||
env = Environment()
|
env = Environment()
|
||||||
ast = env.parse(template)
|
ast = env.parse(template)
|
||||||
variables = meta.find_undeclared_variables(ast)
|
variables = meta.find_undeclared_variables(ast)
|
||||||
|
@ -155,6 +155,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
"Retrievers must implement abstract `_get_relevant_documents` method"
|
"Retrievers must implement abstract `_get_relevant_documents` method"
|
||||||
" instead of `get_relevant_documents`",
|
" instead of `get_relevant_documents`",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
stacklevel=4,
|
||||||
)
|
)
|
||||||
swap = cls.get_relevant_documents
|
swap = cls.get_relevant_documents
|
||||||
cls.get_relevant_documents = ( # type: ignore[assignment]
|
cls.get_relevant_documents = ( # type: ignore[assignment]
|
||||||
@ -169,6 +170,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
"Retrievers must implement abstract `_aget_relevant_documents` method"
|
"Retrievers must implement abstract `_aget_relevant_documents` method"
|
||||||
" instead of `aget_relevant_documents`",
|
" instead of `aget_relevant_documents`",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
stacklevel=4,
|
||||||
)
|
)
|
||||||
aswap = cls.aget_relevant_documents
|
aswap = cls.aget_relevant_documents
|
||||||
cls.aget_relevant_documents = ( # type: ignore[assignment]
|
cls.aget_relevant_documents = ( # type: ignore[assignment]
|
||||||
|
@ -3915,7 +3915,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def InputType(self) -> Any:
|
def InputType(self) -> Any:
|
||||||
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
|
func = getattr(self, "_transform", None) or self._atransform
|
||||||
try:
|
try:
|
||||||
params = inspect.signature(func).parameters
|
params = inspect.signature(func).parameters
|
||||||
first_param = next(iter(params.values()), None)
|
first_param = next(iter(params.values()), None)
|
||||||
@ -3928,7 +3928,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def OutputType(self) -> Any:
|
def OutputType(self) -> Any:
|
||||||
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
|
func = getattr(self, "_transform", None) or self._atransform
|
||||||
try:
|
try:
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
return (
|
return (
|
||||||
@ -4152,7 +4152,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
@property
|
@property
|
||||||
def InputType(self) -> Any:
|
def InputType(self) -> Any:
|
||||||
"""The type of the input to this Runnable."""
|
"""The type of the input to this Runnable."""
|
||||||
func = getattr(self, "func", None) or getattr(self, "afunc")
|
func = getattr(self, "func", None) or self.afunc
|
||||||
try:
|
try:
|
||||||
params = inspect.signature(func).parameters
|
params = inspect.signature(func).parameters
|
||||||
first_param = next(iter(params.values()), None)
|
first_param = next(iter(params.values()), None)
|
||||||
@ -4174,7 +4174,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
Returns:
|
Returns:
|
||||||
The input schema for this Runnable.
|
The input schema for this Runnable.
|
||||||
"""
|
"""
|
||||||
func = getattr(self, "func", None) or getattr(self, "afunc")
|
func = getattr(self, "func", None) or self.afunc
|
||||||
|
|
||||||
if isinstance(func, itemgetter):
|
if isinstance(func, itemgetter):
|
||||||
# This is terrible, but afaict it's not possible to access _items
|
# This is terrible, but afaict it's not possible to access _items
|
||||||
@ -4212,7 +4212,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
Returns:
|
Returns:
|
||||||
The type of the output of this Runnable.
|
The type of the output of this Runnable.
|
||||||
"""
|
"""
|
||||||
func = getattr(self, "func", None) or getattr(self, "afunc")
|
func = getattr(self, "func", None) or self.afunc
|
||||||
try:
|
try:
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
if sig.return_annotation != inspect.Signature.empty:
|
if sig.return_annotation != inspect.Signature.empty:
|
||||||
|
@ -236,6 +236,7 @@ def get_config_list(
|
|||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Provided run_id be used only for the first element of the batch.",
|
"Provided run_id be used only for the first element of the batch.",
|
||||||
category=RuntimeWarning,
|
category=RuntimeWarning,
|
||||||
|
stacklevel=3,
|
||||||
)
|
)
|
||||||
subsequent = cast(
|
subsequent = cast(
|
||||||
RunnableConfig, {k: v for k, v in config.items() if k != "run_id"}
|
RunnableConfig, {k: v for k, v in config.items() if k != "run_id"}
|
||||||
|
@ -537,7 +537,7 @@ class Graph:
|
|||||||
*,
|
*,
|
||||||
with_styles: bool = True,
|
with_styles: bool = True,
|
||||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||||
node_colors: NodeStyles = NodeStyles(),
|
node_colors: Optional[NodeStyles] = None,
|
||||||
wrap_label_n_words: int = 9,
|
wrap_label_n_words: int = 9,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Draw the graph as a Mermaid syntax string.
|
"""Draw the graph as a Mermaid syntax string.
|
||||||
@ -573,7 +573,7 @@ class Graph:
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||||
node_colors: NodeStyles = NodeStyles(),
|
node_colors: Optional[NodeStyles] = None,
|
||||||
wrap_label_n_words: int = 9,
|
wrap_label_n_words: int = 9,
|
||||||
output_file_path: Optional[str] = None,
|
output_file_path: Optional[str] = None,
|
||||||
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
|
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
|
||||||
|
@ -20,7 +20,7 @@ def draw_mermaid(
|
|||||||
last_node: Optional[str] = None,
|
last_node: Optional[str] = None,
|
||||||
with_styles: bool = True,
|
with_styles: bool = True,
|
||||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||||
node_styles: NodeStyles = NodeStyles(),
|
node_styles: Optional[NodeStyles] = None,
|
||||||
wrap_label_n_words: int = 9,
|
wrap_label_n_words: int = 9,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Draws a Mermaid graph using the provided graph data.
|
"""Draws a Mermaid graph using the provided graph data.
|
||||||
@ -153,7 +153,7 @@ def draw_mermaid(
|
|||||||
|
|
||||||
# Add custom styles for nodes
|
# Add custom styles for nodes
|
||||||
if with_styles:
|
if with_styles:
|
||||||
mermaid_graph += _generate_mermaid_graph_styles(node_styles)
|
mermaid_graph += _generate_mermaid_graph_styles(node_styles or NodeStyles())
|
||||||
return mermaid_graph
|
return mermaid_graph
|
||||||
|
|
||||||
|
|
||||||
|
@ -218,6 +218,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
|||||||
def pending(iterable: List[U]) -> List[U]:
|
def pending(iterable: List[U]) -> List[U]:
|
||||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||||
|
|
||||||
|
not_set: List[Output] = []
|
||||||
|
result = not_set
|
||||||
try:
|
try:
|
||||||
for attempt in self._sync_retrying():
|
for attempt in self._sync_retrying():
|
||||||
with attempt:
|
with attempt:
|
||||||
@ -247,9 +249,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
|||||||
):
|
):
|
||||||
attempt.retry_state.set_result(result)
|
attempt.retry_state.set_result(result)
|
||||||
except RetryError as e:
|
except RetryError as e:
|
||||||
try:
|
if result is not_set:
|
||||||
result
|
|
||||||
except UnboundLocalError:
|
|
||||||
result = cast(List[Output], [e] * len(inputs))
|
result = cast(List[Output], [e] * len(inputs))
|
||||||
|
|
||||||
outputs: List[Union[Output, Exception]] = []
|
outputs: List[Union[Output, Exception]] = []
|
||||||
@ -284,6 +284,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
|||||||
def pending(iterable: List[U]) -> List[U]:
|
def pending(iterable: List[U]) -> List[U]:
|
||||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||||
|
|
||||||
|
not_set: List[Output] = []
|
||||||
|
result = not_set
|
||||||
try:
|
try:
|
||||||
async for attempt in self._async_retrying():
|
async for attempt in self._async_retrying():
|
||||||
with attempt:
|
with attempt:
|
||||||
@ -313,9 +315,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
|||||||
):
|
):
|
||||||
attempt.retry_state.set_result(result)
|
attempt.retry_state.set_result(result)
|
||||||
except RetryError as e:
|
except RetryError as e:
|
||||||
try:
|
if result is not_set:
|
||||||
result
|
|
||||||
except UnboundLocalError:
|
|
||||||
result = cast(List[Output], [e] * len(inputs))
|
result = cast(List[Output], [e] * len(inputs))
|
||||||
|
|
||||||
outputs: List[Union[Output, Exception]] = []
|
outputs: List[Union[Output, Exception]] = []
|
||||||
|
@ -748,7 +748,7 @@ def is_async_generator(
|
|||||||
"""
|
"""
|
||||||
return (
|
return (
|
||||||
inspect.isasyncgenfunction(func)
|
inspect.isasyncgenfunction(func)
|
||||||
or hasattr(func, "__call__")
|
or hasattr(func, "__call__") # noqa: B004
|
||||||
and inspect.isasyncgenfunction(func.__call__)
|
and inspect.isasyncgenfunction(func.__call__)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -767,6 +767,6 @@ def is_async_callable(
|
|||||||
"""
|
"""
|
||||||
return (
|
return (
|
||||||
asyncio.iscoroutinefunction(func)
|
asyncio.iscoroutinefunction(func)
|
||||||
or hasattr(func, "__call__")
|
or hasattr(func, "__call__") # noqa: B004
|
||||||
and asyncio.iscoroutinefunction(func.__call__)
|
and asyncio.iscoroutinefunction(func.__call__)
|
||||||
)
|
)
|
||||||
|
@ -443,6 +443,7 @@ class ChildTool(BaseTool):
|
|||||||
warnings.warn(
|
warnings.warn(
|
||||||
"callback_manager is deprecated. Please use callbacks instead.",
|
"callback_manager is deprecated. Please use callbacks instead.",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
stacklevel=6,
|
||||||
)
|
)
|
||||||
values["callbacks"] = values.pop("callback_manager", None)
|
values["callbacks"] = values.pop("callback_manager", None)
|
||||||
return values
|
return values
|
||||||
|
@ -244,7 +244,7 @@ def _get_schema_from_runnable_and_arg_types(
|
|||||||
"Tool input must be str or dict. If dict, dict arguments must be "
|
"Tool input must be str or dict. If dict, dict arguments must be "
|
||||||
"typed. Either annotate types (e.g., with TypedDict) or pass "
|
"typed. Either annotate types (e.g., with TypedDict) or pass "
|
||||||
f"arg_types into `.as_tool` to specify. {str(e)}"
|
f"arg_types into `.as_tool` to specify. {str(e)}"
|
||||||
)
|
) from e
|
||||||
fields = {key: (key_type, Field(...)) for key, key_type in arg_types.items()}
|
fields = {key: (key_type, Field(...)) for key, key_type in arg_types.items()}
|
||||||
return create_model(name, **fields) # type: ignore
|
return create_model(name, **fields) # type: ignore
|
||||||
|
|
||||||
|
@ -516,15 +516,19 @@ class _TracerCore(ABC):
|
|||||||
|
|
||||||
def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""End a trace for a run."""
|
"""End a trace for a run."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process a run upon creation."""
|
"""Process a run upon creation."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process a run upon update."""
|
"""Process a run upon update."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process the LLM Run upon start."""
|
"""Process the LLM Run upon start."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_llm_new_token(
|
def _on_llm_new_token(
|
||||||
self,
|
self,
|
||||||
@ -533,39 +537,52 @@ class _TracerCore(ABC):
|
|||||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
|
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
|
||||||
) -> Union[None, Coroutine[Any, Any, None]]:
|
) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process new LLM token."""
|
"""Process new LLM token."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process the LLM Run."""
|
"""Process the LLM Run."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process the LLM Run upon error."""
|
"""Process the LLM Run upon error."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process the Chain Run upon start."""
|
"""Process the Chain Run upon start."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process the Chain Run."""
|
"""Process the Chain Run."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process the Chain Run upon error."""
|
"""Process the Chain Run upon error."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process the Tool Run upon start."""
|
"""Process the Tool Run upon start."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process the Tool Run."""
|
"""Process the Tool Run."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process the Tool Run upon error."""
|
"""Process the Tool Run upon error."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process the Chat Model Run upon start."""
|
"""Process the Chat Model Run upon start."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process the Retriever Run upon start."""
|
"""Process the Retriever Run upon start."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process the Retriever Run."""
|
"""Process the Retriever Run."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||||
"""Process the Retriever Run upon error."""
|
"""Process the Retriever Run upon error."""
|
||||||
|
return None
|
||||||
|
@ -144,11 +144,7 @@ class EvaluatorCallbackHandler(BaseTracer):
|
|||||||
example_id = str(run.reference_example_id)
|
example_id = str(run.reference_example_id)
|
||||||
with self.lock:
|
with self.lock:
|
||||||
for res in eval_results:
|
for res in eval_results:
|
||||||
run_id = (
|
run_id = str(getattr(res, "target_run_id", run.id))
|
||||||
str(getattr(res, "target_run_id"))
|
|
||||||
if hasattr(res, "target_run_id")
|
|
||||||
else str(run.id)
|
|
||||||
)
|
|
||||||
self.logged_eval_results.setdefault((run_id, example_id), []).append(
|
self.logged_eval_results.setdefault((run_id, example_id), []).append(
|
||||||
res
|
res
|
||||||
)
|
)
|
||||||
@ -179,11 +175,9 @@ class EvaluatorCallbackHandler(BaseTracer):
|
|||||||
source_info_: Dict[str, Any] = {}
|
source_info_: Dict[str, Any] = {}
|
||||||
if res.evaluator_info:
|
if res.evaluator_info:
|
||||||
source_info_ = {**res.evaluator_info, **source_info_}
|
source_info_ = {**res.evaluator_info, **source_info_}
|
||||||
run_id_ = (
|
run_id_ = getattr(res, "target_run_id", None)
|
||||||
getattr(res, "target_run_id")
|
if run_id_ is None:
|
||||||
if hasattr(res, "target_run_id") and res.target_run_id is not None
|
run_id_ = run.id
|
||||||
else run.id
|
|
||||||
)
|
|
||||||
self.client.create_feedback(
|
self.client.create_feedback(
|
||||||
run_id_,
|
run_id_,
|
||||||
res.key,
|
res.key,
|
||||||
|
@ -22,6 +22,7 @@ def RunTypeEnum() -> Type[RunTypeEnumDep]:
|
|||||||
"RunTypeEnum is deprecated. Please directly use a string instead"
|
"RunTypeEnum is deprecated. Please directly use a string instead"
|
||||||
" (e.g. 'llm', 'chain', 'tool').",
|
" (e.g. 'llm', 'chain', 'tool').",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
return RunTypeEnumDep
|
return RunTypeEnumDep
|
||||||
|
|
||||||
|
@ -62,8 +62,8 @@ def py_anext(
|
|||||||
__anext__ = cast(
|
__anext__ = cast(
|
||||||
Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
|
Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
|
||||||
)
|
)
|
||||||
except AttributeError:
|
except AttributeError as e:
|
||||||
raise TypeError(f"{iterator!r} is not an async iterator")
|
raise TypeError(f"{iterator!r} is not an async iterator") from e
|
||||||
|
|
||||||
if default is _no_default:
|
if default is _no_default:
|
||||||
return __anext__(iterator)
|
return __anext__(iterator)
|
||||||
|
@ -182,7 +182,7 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
|||||||
try:
|
try:
|
||||||
json_obj = parse_json_markdown(text)
|
json_obj = parse_json_markdown(text)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
|
raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e
|
||||||
for key in expected_keys:
|
for key in expected_keys:
|
||||||
if key not in json_obj:
|
if key not in json_obj:
|
||||||
raise OutputParserException(
|
raise OutputParserException(
|
||||||
|
@ -21,7 +21,9 @@ def try_load_from_hub(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Loading from the deprecated github-based Hub is no longer supported. "
|
"Loading from the deprecated github-based Hub is no longer supported. "
|
||||||
"Please use the new LangChain Hub at https://smith.langchain.com/hub instead."
|
"Please use the new LangChain Hub at https://smith.langchain.com/hub instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
# return None, which indicates that we shouldn't load from old hub
|
# return None, which indicates that we shouldn't load from old hub
|
||||||
# and might just be a filepath for e.g. load_chain
|
# and might just be a filepath for e.g. load_chain
|
||||||
|
@ -3,13 +3,17 @@ Adapted from https://github.com/noahmorrison/chevron
|
|||||||
MIT License
|
MIT License
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from types import MappingProxyType
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
Dict,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Literal,
|
Literal,
|
||||||
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
@ -22,7 +26,7 @@ from typing_extensions import TypeAlias
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
Scopes: TypeAlias = List[Union[Literal[False, 0], Dict[str, Any]]]
|
Scopes: TypeAlias = List[Union[Literal[False, 0], Mapping[str, Any]]]
|
||||||
|
|
||||||
|
|
||||||
# Globals
|
# Globals
|
||||||
@ -152,8 +156,8 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], s
|
|||||||
# Get the tag
|
# Get the tag
|
||||||
try:
|
try:
|
||||||
tag, template = template.split(r_del, 1)
|
tag, template = template.split(r_del, 1)
|
||||||
except ValueError:
|
except ValueError as e:
|
||||||
raise ChevronError("unclosed tag " f"at line {_CURRENT_LINE}")
|
raise ChevronError("unclosed tag " f"at line {_CURRENT_LINE}") from e
|
||||||
|
|
||||||
# Find the type meaning of the first character
|
# Find the type meaning of the first character
|
||||||
tag_type = tag_types.get(tag[0], "variable")
|
tag_type = tag_types.get(tag[0], "variable")
|
||||||
@ -279,12 +283,12 @@ def tokenize(
|
|||||||
# is the same as us
|
# is the same as us
|
||||||
try:
|
try:
|
||||||
last_section = open_sections.pop()
|
last_section = open_sections.pop()
|
||||||
except IndexError:
|
except IndexError as e:
|
||||||
raise ChevronError(
|
raise ChevronError(
|
||||||
f'Trying to close tag "{tag_key}"\n'
|
f'Trying to close tag "{tag_key}"\n'
|
||||||
"Looks like it was not opened.\n"
|
"Looks like it was not opened.\n"
|
||||||
f"line {_CURRENT_LINE + 1}"
|
f"line {_CURRENT_LINE + 1}"
|
||||||
)
|
) from e
|
||||||
if tag_key != last_section:
|
if tag_key != last_section:
|
||||||
# Otherwise we need to complain
|
# Otherwise we need to complain
|
||||||
raise ChevronError(
|
raise ChevronError(
|
||||||
@ -411,7 +415,7 @@ def _get_key(
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def _get_partial(name: str, partials_dict: Dict[str, str]) -> str:
|
def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str:
|
||||||
"""Load a partial"""
|
"""Load a partial"""
|
||||||
try:
|
try:
|
||||||
# Maybe the partial is in the dictionary
|
# Maybe the partial is in the dictionary
|
||||||
@ -425,11 +429,13 @@ def _get_partial(name: str, partials_dict: Dict[str, str]) -> str:
|
|||||||
#
|
#
|
||||||
g_token_cache: Dict[str, List[Tuple[str, str]]] = {}
|
g_token_cache: Dict[str, List[Tuple[str, str]]] = {}
|
||||||
|
|
||||||
|
EMPTY_DICT: MappingProxyType[str, str] = MappingProxyType({})
|
||||||
|
|
||||||
|
|
||||||
def render(
|
def render(
|
||||||
template: Union[str, List[Tuple[str, str]]] = "",
|
template: Union[str, List[Tuple[str, str]]] = "",
|
||||||
data: Dict[str, Any] = {},
|
data: Mapping[str, Any] = EMPTY_DICT,
|
||||||
partials_dict: Dict[str, str] = {},
|
partials_dict: Mapping[str, str] = EMPTY_DICT,
|
||||||
padding: str = "",
|
padding: str = "",
|
||||||
def_ldel: str = "{{",
|
def_ldel: str = "{{",
|
||||||
def_rdel: str = "}}",
|
def_rdel: str = "}}",
|
||||||
|
@ -131,12 +131,12 @@ def guard_import(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
module = importlib.import_module(module_name, package)
|
module = importlib.import_module(module_name, package)
|
||||||
except (ImportError, ModuleNotFoundError):
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
|
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"Could not import {module_name} python package. "
|
f"Could not import {module_name} python package. "
|
||||||
f"Please install it with `pip install {pip_name}`."
|
f"Please install it with `pip install {pip_name}`."
|
||||||
)
|
) from e
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
@ -235,7 +235,8 @@ def build_extra_kwargs(
|
|||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"""WARNING! {field_name} is not default parameter.
|
f"""WARNING! {field_name} is not default parameter.
|
||||||
{field_name} was transferred to model_kwargs.
|
{field_name} was transferred to model_kwargs.
|
||||||
Please confirm that {field_name} is what you intended."""
|
Please confirm that {field_name} is what you intended.""",
|
||||||
|
stacklevel=7,
|
||||||
)
|
)
|
||||||
extra_kwargs[field_name] = values.pop(field_name)
|
extra_kwargs[field_name] = values.pop(field_name)
|
||||||
|
|
||||||
|
@ -558,7 +558,8 @@ class VectorStore(ABC):
|
|||||||
):
|
):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Relevance scores must be between"
|
"Relevance scores must be between"
|
||||||
f" 0 and 1, got {docs_and_similarities}"
|
f" 0 and 1, got {docs_and_similarities}",
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if score_threshold is not None:
|
if score_threshold is not None:
|
||||||
@ -568,7 +569,7 @@ class VectorStore(ABC):
|
|||||||
if similarity >= score_threshold
|
if similarity >= score_threshold
|
||||||
]
|
]
|
||||||
if len(docs_and_similarities) == 0:
|
if len(docs_and_similarities) == 0:
|
||||||
warnings.warn(
|
logger.warning(
|
||||||
"No relevant docs were retrieved using the relevance score"
|
"No relevant docs were retrieved using the relevance score"
|
||||||
f" threshold {score_threshold}"
|
f" threshold {score_threshold}"
|
||||||
)
|
)
|
||||||
@ -605,7 +606,8 @@ class VectorStore(ABC):
|
|||||||
):
|
):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Relevance scores must be between"
|
"Relevance scores must be between"
|
||||||
f" 0 and 1, got {docs_and_similarities}"
|
f" 0 and 1, got {docs_and_similarities}",
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if score_threshold is not None:
|
if score_threshold is not None:
|
||||||
@ -615,7 +617,7 @@ class VectorStore(ABC):
|
|||||||
if similarity >= score_threshold
|
if similarity >= score_threshold
|
||||||
]
|
]
|
||||||
if len(docs_and_similarities) == 0:
|
if len(docs_and_similarities) == 0:
|
||||||
warnings.warn(
|
logger.warning(
|
||||||
"No relevant docs were retrieved using the relevance score"
|
"No relevant docs were retrieved using the relevance score"
|
||||||
f" threshold {score_threshold}"
|
f" threshold {score_threshold}"
|
||||||
)
|
)
|
||||||
|
@ -435,11 +435,11 @@ class InMemoryVectorStore(VectorStore):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"numpy must be installed to use max_marginal_relevance_search "
|
"numpy must be installed to use max_marginal_relevance_search "
|
||||||
"pip install numpy"
|
"pip install numpy"
|
||||||
)
|
) from e
|
||||||
|
|
||||||
mmr_chosen_indices = maximal_marginal_relevance(
|
mmr_chosen_indices = maximal_marginal_relevance(
|
||||||
np.array(embedding, dtype=np.float32),
|
np.array(embedding, dtype=np.float32),
|
||||||
|
@ -34,11 +34,11 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"cosine_similarity requires numpy to be installed. "
|
"cosine_similarity requires numpy to be installed. "
|
||||||
"Please install numpy with `pip install numpy`."
|
"Please install numpy with `pip install numpy`."
|
||||||
)
|
) from e
|
||||||
|
|
||||||
if len(X) == 0 or len(Y) == 0:
|
if len(X) == 0 or len(Y) == 0:
|
||||||
return np.array([])
|
return np.array([])
|
||||||
@ -93,11 +93,11 @@ def maximal_marginal_relevance(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"maximal_marginal_relevance requires numpy to be installed. "
|
"maximal_marginal_relevance requires numpy to be installed. "
|
||||||
"Please install numpy with `pip install numpy`."
|
"Please install numpy with `pip install numpy`."
|
||||||
)
|
) from e
|
||||||
|
|
||||||
if min(k, len(embedding_list)) <= 0:
|
if min(k, len(embedding_list)) <= 0:
|
||||||
return []
|
return []
|
||||||
|
@ -41,7 +41,7 @@ python = ">=3.12.4"
|
|||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [ "E", "F", "I", "T201", "UP",]
|
select = [ "B", "E", "F", "I", "T201", "UP",]
|
||||||
ignore = [ "UP006", "UP007",]
|
ignore = [ "UP006", "UP007",]
|
||||||
|
|
||||||
[tool.coverage.run]
|
[tool.coverage.run]
|
||||||
|
@ -7,6 +7,7 @@ import pytest
|
|||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
from langchain_core.language_models import BaseChatModel, FakeListChatModel
|
from langchain_core.language_models import BaseChatModel, FakeListChatModel
|
||||||
|
from langchain_core.language_models.fake_chat_models import FakeListChatModelError
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
AIMessageChunk,
|
AIMessageChunk,
|
||||||
@ -110,7 +111,7 @@ async def test_stream_error_callback() -> None:
|
|||||||
responses=[message],
|
responses=[message],
|
||||||
error_on_chunk_number=i,
|
error_on_chunk_number=i,
|
||||||
)
|
)
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(FakeListChatModelError):
|
||||||
cb_async = FakeAsyncCallbackHandler()
|
cb_async = FakeAsyncCallbackHandler()
|
||||||
async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
|
async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
|
||||||
pass
|
pass
|
||||||
|
@ -7,6 +7,7 @@ from langchain_core.callbacks import (
|
|||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
from langchain_core.language_models import BaseLLM, FakeListLLM, FakeStreamingListLLM
|
from langchain_core.language_models import BaseLLM, FakeListLLM, FakeStreamingListLLM
|
||||||
|
from langchain_core.language_models.fake import FakeListLLMError
|
||||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||||
from langchain_core.tracers.context import collect_runs
|
from langchain_core.tracers.context import collect_runs
|
||||||
from tests.unit_tests.fake.callbacks import (
|
from tests.unit_tests.fake.callbacks import (
|
||||||
@ -108,7 +109,7 @@ async def test_stream_error_callback() -> None:
|
|||||||
responses=[message],
|
responses=[message],
|
||||||
error_on_chunk_number=i,
|
error_on_chunk_number=i,
|
||||||
)
|
)
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(FakeListLLMError):
|
||||||
cb_async = FakeAsyncCallbackHandler()
|
cb_async = FakeAsyncCallbackHandler()
|
||||||
async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
|
async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
|
||||||
pass
|
pass
|
||||||
|
@ -3,6 +3,7 @@ from typing import Any, AsyncIterator, Iterator, Tuple
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.output_parsers.json import (
|
from langchain_core.output_parsers.json import (
|
||||||
SimpleJsonOutputParser,
|
SimpleJsonOutputParser,
|
||||||
)
|
)
|
||||||
@ -531,7 +532,7 @@ async def test_partial_text_json_output_parser_diff_async() -> None:
|
|||||||
|
|
||||||
def test_raises_error() -> None:
|
def test_raises_error() -> None:
|
||||||
parser = SimpleJsonOutputParser()
|
parser = SimpleJsonOutputParser()
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(OutputParserException):
|
||||||
parser.invoke("hi")
|
parser.invoke("hi")
|
||||||
|
|
||||||
|
|
||||||
|
@ -164,13 +164,9 @@ def test_pydantic_output_parser_fail() -> None:
|
|||||||
pydantic_object=TestModel
|
pydantic_object=TestModel
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
with pytest.raises(OutputParserException) as e:
|
||||||
pydantic_parser.parse(DEF_RESULT_FAIL)
|
pydantic_parser.parse(DEF_RESULT_FAIL)
|
||||||
except OutputParserException as e:
|
|
||||||
print("parse_result:", e) # noqa: T201
|
|
||||||
assert "Failed to parse TestModel from completion" in str(e)
|
assert "Failed to parse TestModel from completion" in str(e)
|
||||||
else:
|
|
||||||
assert False, "Expected OutputParserException"
|
|
||||||
|
|
||||||
|
|
||||||
def test_pydantic_output_parser_type_inference() -> None:
|
def test_pydantic_output_parser_type_inference() -> None:
|
||||||
|
@ -28,7 +28,7 @@ def _replace_all_of_with_ref(schema: Any) -> None:
|
|||||||
del schema["default"]
|
del schema["default"]
|
||||||
else:
|
else:
|
||||||
# Recursively process nested schemas
|
# Recursively process nested schemas
|
||||||
for key, value in schema.items():
|
for value in schema.values():
|
||||||
if isinstance(value, (dict, list)):
|
if isinstance(value, (dict, list)):
|
||||||
_replace_all_of_with_ref(value)
|
_replace_all_of_with_ref(value)
|
||||||
elif isinstance(schema, list):
|
elif isinstance(schema, list):
|
||||||
@ -47,7 +47,7 @@ def _remove_bad_none_defaults(schema: Any) -> None:
|
|||||||
See difference between Optional and NotRequired types in python.
|
See difference between Optional and NotRequired types in python.
|
||||||
"""
|
"""
|
||||||
if isinstance(schema, dict):
|
if isinstance(schema, dict):
|
||||||
for key, value in schema.items():
|
for value in schema.values():
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
if "default" in value and value["default"] is None:
|
if "default" in value and value["default"] is None:
|
||||||
any_of = value.get("anyOf", [])
|
any_of = value.get("anyOf", [])
|
||||||
|
@ -307,7 +307,7 @@ async def test_fallbacks_astream() -> None:
|
|||||||
runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
|
runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
|
||||||
[RunnableGenerator(_agenerate)]
|
[RunnableGenerator(_agenerate)]
|
||||||
)
|
)
|
||||||
async for c in runnable.astream({}):
|
async for _ in runnable.astream({}):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -373,7 +373,7 @@ def test_fallbacks_getattr() -> None:
|
|||||||
assert llm_with_fallbacks.foo == 3
|
assert llm_with_fallbacks.foo == 3
|
||||||
|
|
||||||
with pytest.raises(AttributeError):
|
with pytest.raises(AttributeError):
|
||||||
llm_with_fallbacks.bar
|
assert llm_with_fallbacks.bar == 4
|
||||||
|
|
||||||
|
|
||||||
def test_fallbacks_getattr_runnable_output() -> None:
|
def test_fallbacks_getattr_runnable_output() -> None:
|
||||||
|
@ -1516,7 +1516,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
) == [5, 7]
|
) == [5, 7]
|
||||||
|
|
||||||
assert len(spy.call_args_list) == 2
|
assert len(spy.call_args_list) == 2
|
||||||
for i, call in enumerate(spy.call_args_list):
|
for call in spy.call_args_list:
|
||||||
call_arg = call.args[0]
|
call_arg = call.args[0]
|
||||||
|
|
||||||
if call_arg == "hello":
|
if call_arg == "hello":
|
||||||
@ -1533,7 +1533,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
|
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
|
||||||
assert len(spy.call_args_list) == 2
|
assert len(spy.call_args_list) == 2
|
||||||
assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
|
assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
|
||||||
for i, call in enumerate(spy.call_args_list):
|
for call in spy.call_args_list:
|
||||||
assert call.args[1].get("tags") == ["a-tag"]
|
assert call.args[1].get("tags") == ["a-tag"]
|
||||||
assert call.args[1].get("metadata") == {}
|
assert call.args[1].get("metadata") == {}
|
||||||
spy.reset_mock()
|
spy.reset_mock()
|
||||||
@ -5205,7 +5205,7 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
|
|||||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
|
assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
|
||||||
|
|
||||||
tracer = FakeTracer()
|
tracer = FakeTracer()
|
||||||
for item in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
|
for _ in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert tracer.runs[0].name == "RunnableAssign<urls>"
|
assert tracer.runs[0].name == "RunnableAssign<urls>"
|
||||||
@ -5225,7 +5225,7 @@ async def test_ainvoke_astream_passthrough_assign_trace() -> None:
|
|||||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
|
assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
|
||||||
|
|
||||||
tracer = FakeTracer()
|
tracer = FakeTracer()
|
||||||
async for item in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
|
async for _ in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert tracer.runs[0].name == "RunnableAssign<urls>"
|
assert tracer.runs[0].name == "RunnableAssign<urls>"
|
||||||
|
@ -140,7 +140,7 @@ class LangChainProjectNameTest(unittest.TestCase):
|
|||||||
projects = []
|
projects = []
|
||||||
|
|
||||||
def mock_create_run(**kwargs: Any) -> Any:
|
def mock_create_run(**kwargs: Any) -> Any:
|
||||||
projects.append(kwargs.get("project_name"))
|
projects.append(kwargs.get("project_name")) # noqa: B023
|
||||||
return unittest.mock.MagicMock()
|
return unittest.mock.MagicMock()
|
||||||
|
|
||||||
client.create_run = mock_create_run
|
client.create_run = mock_create_run
|
||||||
@ -151,6 +151,4 @@ class LangChainProjectNameTest(unittest.TestCase):
|
|||||||
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
|
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
|
||||||
)
|
)
|
||||||
tracer.wait_for_futures()
|
tracer.wait_for_futures()
|
||||||
assert (
|
assert projects == [case.expected_project_name]
|
||||||
len(projects) == 1 and projects[0] == case.expected_project_name
|
|
||||||
)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user