mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 22:59:05 +00:00
core[patch]: Add ruff rules for flake8-simplify (SIM) (#26848)
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
de0b48c41a
commit
7809b31b95
@ -127,10 +127,9 @@ def beta(
|
||||
|
||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
|
||||
"""Finalize the annotation of a class."""
|
||||
try:
|
||||
# Can't set new_doc on some extension objects.
|
||||
with contextlib.suppress(AttributeError):
|
||||
obj.__doc__ = new_doc
|
||||
except AttributeError: # Can't set on some extension objects.
|
||||
pass
|
||||
|
||||
def warn_if_direct_instance(
|
||||
self: Any, *args: Any, **kwargs: Any
|
||||
|
@ -198,10 +198,9 @@ def deprecated(
|
||||
|
||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
|
||||
"""Finalize the deprecation of a class."""
|
||||
try:
|
||||
# Can't set new_doc on some extension objects.
|
||||
with contextlib.suppress(AttributeError):
|
||||
obj.__doc__ = new_doc
|
||||
except AttributeError: # Can't set on some extension objects.
|
||||
pass
|
||||
|
||||
def warn_if_direct_instance(
|
||||
self: Any, *args: Any, **kwargs: Any
|
||||
|
@ -27,7 +27,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
mode: The mode to open the file in. Defaults to "a".
|
||||
color: The color to use for the text. Defaults to None.
|
||||
"""
|
||||
self.file = cast(TextIO, open(filename, mode, encoding="utf-8"))
|
||||
self.file = cast(TextIO, open(filename, mode, encoding="utf-8")) # noqa: SIM115
|
||||
self.color = color
|
||||
|
||||
def __del__(self) -> None:
|
||||
|
@ -2252,14 +2252,14 @@ def _configure(
|
||||
else:
|
||||
parent_run_id_ = inheritable_callbacks.parent_run_id
|
||||
# Break ties between the external tracing context and inherited context
|
||||
if parent_run_id is not None:
|
||||
if parent_run_id_ is None:
|
||||
parent_run_id_ = parent_run_id
|
||||
if parent_run_id is not None and (
|
||||
parent_run_id_ is None
|
||||
# If the LC parent has already been reflected
|
||||
# in the run tree, we know the run_tree is either the
|
||||
# same parent or a child of the parent.
|
||||
elif run_tree and str(parent_run_id_) in run_tree.dotted_order:
|
||||
parent_run_id_ = parent_run_id
|
||||
or (run_tree and str(parent_run_id_) in run_tree.dotted_order)
|
||||
):
|
||||
parent_run_id_ = parent_run_id
|
||||
# Otherwise, we assume the LC context has progressed
|
||||
# beyond the run tree and we should not inherit the parent.
|
||||
callback_manager = callback_manager_cls(
|
||||
|
@ -40,12 +40,11 @@ class OutputParserException(ValueError, LangChainException): # noqa: N818
|
||||
send_to_llm: bool = False,
|
||||
):
|
||||
super().__init__(error)
|
||||
if send_to_llm:
|
||||
if observation is None or llm_output is None:
|
||||
raise ValueError(
|
||||
"Arguments 'observation' & 'llm_output'"
|
||||
" are required if 'send_to_llm' is True"
|
||||
)
|
||||
if send_to_llm and (observation is None or llm_output is None):
|
||||
raise ValueError(
|
||||
"Arguments 'observation' & 'llm_output'"
|
||||
" are required if 'send_to_llm' is True"
|
||||
)
|
||||
self.observation = observation
|
||||
self.llm_output = llm_output
|
||||
self.send_to_llm = send_to_llm
|
||||
|
@ -92,7 +92,7 @@ class _HashedDocument(Document):
|
||||
values["metadata_hash"] = metadata_hash
|
||||
values["hash_"] = str(_hash_string_to_uuid(content_hash + metadata_hash))
|
||||
|
||||
_uid = values.get("uid", None)
|
||||
_uid = values.get("uid")
|
||||
|
||||
if _uid is None:
|
||||
values["uid"] = values["hash_"]
|
||||
|
@ -802,10 +802,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if isinstance(self.cache, BaseCache):
|
||||
llm_cache = self.cache
|
||||
else:
|
||||
llm_cache = get_llm_cache()
|
||||
llm_cache = self.cache if isinstance(self.cache, BaseCache) else get_llm_cache()
|
||||
# We should check the cache unless it's explicitly set to False
|
||||
# A None cache means we should use the default global cache
|
||||
# if it's configured.
|
||||
@ -879,10 +876,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if isinstance(self.cache, BaseCache):
|
||||
llm_cache = self.cache
|
||||
else:
|
||||
llm_cache = get_llm_cache()
|
||||
llm_cache = self.cache if isinstance(self.cache, BaseCache) else get_llm_cache()
|
||||
# We should check the cache unless it's explicitly set to False
|
||||
# A None cache means we should use the default global cache
|
||||
# if it's configured.
|
||||
@ -1054,10 +1048,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
def predict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
_stop = None if stop is None else list(stop)
|
||||
result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
|
||||
if isinstance(result.content, str):
|
||||
return result.content
|
||||
@ -1072,20 +1063,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
_stop = None if stop is None else list(stop)
|
||||
return self(messages, stop=_stop, **kwargs)
|
||||
|
||||
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
||||
async def apredict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
_stop = None if stop is None else list(stop)
|
||||
result = await self._call_async(
|
||||
[HumanMessage(content=text)], stop=_stop, **kwargs
|
||||
)
|
||||
@ -1102,10 +1087,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
_stop = None if stop is None else list(stop)
|
||||
return await self._call_async(messages, stop=_stop, **kwargs)
|
||||
|
||||
@property
|
||||
@ -1333,9 +1315,12 @@ def _cleanup_llm_representation(serialized: Any, depth: int) -> None:
|
||||
if not isinstance(serialized, dict):
|
||||
return
|
||||
|
||||
if "type" in serialized and serialized["type"] == "not_implemented":
|
||||
if "repr" in serialized:
|
||||
del serialized["repr"]
|
||||
if (
|
||||
"type" in serialized
|
||||
and serialized["type"] == "not_implemented"
|
||||
and "repr" in serialized
|
||||
):
|
||||
del serialized["repr"]
|
||||
|
||||
if "graph" in serialized:
|
||||
del serialized["graph"]
|
||||
|
@ -194,10 +194,7 @@ class GenericFakeChatModel(BaseChatModel):
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
message = next(self.messages)
|
||||
if isinstance(message, str):
|
||||
message_ = AIMessage(content=message)
|
||||
else:
|
||||
message_ = message
|
||||
message_ = AIMessage(content=message) if isinstance(message, str) else message
|
||||
generation = ChatGeneration(message=message_)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
|
@ -1305,10 +1305,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
def predict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
_stop = None if stop is None else list(stop)
|
||||
return self(text, stop=_stop, **kwargs)
|
||||
|
||||
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
||||
@ -1320,10 +1317,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
text = get_buffer_string(messages)
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
_stop = None if stop is None else list(stop)
|
||||
content = self(text, stop=_stop, **kwargs)
|
||||
return AIMessage(content=content)
|
||||
|
||||
@ -1331,10 +1325,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
async def apredict(
|
||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
_stop = None if stop is None else list(stop)
|
||||
return await self._call_async(text, stop=_stop, **kwargs)
|
||||
|
||||
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
||||
@ -1346,10 +1337,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
text = get_buffer_string(messages)
|
||||
if stop is None:
|
||||
_stop = None
|
||||
else:
|
||||
_stop = list(stop)
|
||||
_stop = None if stop is None else list(stop)
|
||||
content = await self._call_async(text, stop=_stop, **kwargs)
|
||||
return AIMessage(content=content)
|
||||
|
||||
@ -1384,10 +1372,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
llm.save(file_path="path/llm.yaml")
|
||||
"""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
save_path = Path(file_path) if isinstance(file_path, str) else file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -86,9 +86,9 @@ class Reviver:
|
||||
|
||||
def __call__(self, value: dict[str, Any]) -> Any:
|
||||
if (
|
||||
value.get("lc", None) == 1
|
||||
and value.get("type", None) == "secret"
|
||||
and value.get("id", None) is not None
|
||||
value.get("lc") == 1
|
||||
and value.get("type") == "secret"
|
||||
and value.get("id") is not None
|
||||
):
|
||||
[key] = value["id"]
|
||||
if key in self.secrets_map:
|
||||
@ -99,9 +99,9 @@ class Reviver:
|
||||
raise KeyError(f'Missing key "{key}" in load(secrets_map)')
|
||||
|
||||
if (
|
||||
value.get("lc", None) == 1
|
||||
and value.get("type", None) == "not_implemented"
|
||||
and value.get("id", None) is not None
|
||||
value.get("lc") == 1
|
||||
and value.get("type") == "not_implemented"
|
||||
and value.get("id") is not None
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Trying to load an object that doesn't implement "
|
||||
@ -109,17 +109,18 @@ class Reviver:
|
||||
)
|
||||
|
||||
if (
|
||||
value.get("lc", None) == 1
|
||||
and value.get("type", None) == "constructor"
|
||||
and value.get("id", None) is not None
|
||||
value.get("lc") == 1
|
||||
and value.get("type") == "constructor"
|
||||
and value.get("id") is not None
|
||||
):
|
||||
[*namespace, name] = value["id"]
|
||||
mapping_key = tuple(value["id"])
|
||||
|
||||
if namespace[0] not in self.valid_namespaces:
|
||||
raise ValueError(f"Invalid namespace: {value}")
|
||||
# The root namespace ["langchain"] is not a valid identifier.
|
||||
elif namespace == ["langchain"]:
|
||||
if (
|
||||
namespace[0] not in self.valid_namespaces
|
||||
# The root namespace ["langchain"] is not a valid identifier.
|
||||
or namespace == ["langchain"]
|
||||
):
|
||||
raise ValueError(f"Invalid namespace: {value}")
|
||||
# Has explicit import path.
|
||||
elif mapping_key in self.import_mappings:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import contextlib
|
||||
from abc import ABC
|
||||
from typing import (
|
||||
Any,
|
||||
@ -238,7 +239,7 @@ class Serializable(BaseModel, ABC):
|
||||
|
||||
# include all secrets, even if not specified in kwargs
|
||||
# as these secrets may be passed as an environment variable instead
|
||||
for key in secrets.keys():
|
||||
for key in secrets:
|
||||
secret_value = getattr(self, key, None) or lc_kwargs.get(key)
|
||||
if secret_value is not None:
|
||||
lc_kwargs.update({key: secret_value})
|
||||
@ -357,8 +358,6 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
|
||||
"id": _id,
|
||||
"repr": None,
|
||||
}
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
result["repr"] = repr(obj)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
@ -435,23 +435,22 @@ def filter_messages(
|
||||
messages = convert_to_messages(messages)
|
||||
filtered: list[BaseMessage] = []
|
||||
for msg in messages:
|
||||
if exclude_names and msg.name in exclude_names:
|
||||
continue
|
||||
elif exclude_types and _is_message_type(msg, exclude_types):
|
||||
continue
|
||||
elif exclude_ids and msg.id in exclude_ids:
|
||||
if (
|
||||
(exclude_names and msg.name in exclude_names)
|
||||
or (exclude_types and _is_message_type(msg, exclude_types))
|
||||
or (exclude_ids and msg.id in exclude_ids)
|
||||
):
|
||||
continue
|
||||
else:
|
||||
pass
|
||||
|
||||
# default to inclusion when no inclusion criteria given.
|
||||
if not (include_types or include_ids or include_names):
|
||||
filtered.append(msg)
|
||||
elif include_names and msg.name in include_names:
|
||||
filtered.append(msg)
|
||||
elif include_types and _is_message_type(msg, include_types):
|
||||
filtered.append(msg)
|
||||
elif include_ids and msg.id in include_ids:
|
||||
if (
|
||||
not (include_types or include_ids or include_names)
|
||||
or (include_names and msg.name in include_names)
|
||||
or (include_types and _is_message_type(msg, include_types))
|
||||
or (include_ids and msg.id in include_ids)
|
||||
):
|
||||
filtered.append(msg)
|
||||
else:
|
||||
pass
|
||||
@ -961,10 +960,7 @@ def _last_max_tokens(
|
||||
while messages and not _is_message_type(messages[-1], end_on):
|
||||
messages.pop()
|
||||
swapped_system = include_system and isinstance(messages[0], SystemMessage)
|
||||
if swapped_system:
|
||||
reversed_ = messages[:1] + messages[1:][::-1]
|
||||
else:
|
||||
reversed_ = messages[::-1]
|
||||
reversed_ = messages[:1] + messages[1:][::-1] if swapped_system else messages[::-1]
|
||||
|
||||
reversed_ = _first_max_tokens(
|
||||
reversed_,
|
||||
|
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -311,8 +312,6 @@ class BaseOutputParser(
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
"""Return dictionary representation of output parser."""
|
||||
output_parser_dict = super().dict(**kwargs)
|
||||
try:
|
||||
with contextlib.suppress(NotImplementedError):
|
||||
output_parser_dict["_type"] = self._type
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return output_parser_dict
|
||||
|
@ -49,12 +49,10 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
|
||||
if PYDANTIC_MAJOR_VERSION == 2:
|
||||
if issubclass(pydantic_object, pydantic.BaseModel):
|
||||
return pydantic_object.model_json_schema()
|
||||
elif issubclass(pydantic_object, pydantic.v1.BaseModel):
|
||||
return pydantic_object.model_json_schema()
|
||||
return pydantic_object.model_json_schema()
|
||||
if issubclass(pydantic_object, pydantic.BaseModel):
|
||||
return pydantic_object.model_json_schema()
|
||||
elif issubclass(pydantic_object, pydantic.v1.BaseModel):
|
||||
return pydantic_object.schema()
|
||||
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
@ -110,10 +110,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
|
||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
||||
prev_parsed = None
|
||||
acc_gen = None
|
||||
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
|
||||
for chunk in input:
|
||||
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
|
||||
if isinstance(chunk, BaseMessageChunk):
|
||||
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
|
||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||
elif isinstance(chunk, BaseMessage):
|
||||
chunk_gen = ChatGenerationChunk(
|
||||
message=BaseMessageChunk(**chunk.dict())
|
||||
@ -121,10 +122,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
else:
|
||||
chunk_gen = GenerationChunk(text=chunk)
|
||||
|
||||
if acc_gen is None:
|
||||
acc_gen = chunk_gen
|
||||
else:
|
||||
acc_gen = acc_gen + chunk_gen
|
||||
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
|
||||
|
||||
parsed = self.parse_result([acc_gen], partial=True)
|
||||
if parsed is not None and parsed != prev_parsed:
|
||||
@ -138,10 +136,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
) -> AsyncIterator[T]:
|
||||
prev_parsed = None
|
||||
acc_gen = None
|
||||
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
|
||||
async for chunk in input:
|
||||
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
|
||||
if isinstance(chunk, BaseMessageChunk):
|
||||
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
|
||||
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||
elif isinstance(chunk, BaseMessage):
|
||||
chunk_gen = ChatGenerationChunk(
|
||||
message=BaseMessageChunk(**chunk.dict())
|
||||
@ -149,10 +148,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
||||
else:
|
||||
chunk_gen = GenerationChunk(text=chunk)
|
||||
|
||||
if acc_gen is None:
|
||||
acc_gen = chunk_gen
|
||||
else:
|
||||
acc_gen = acc_gen + chunk_gen
|
||||
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
|
||||
|
||||
parsed = await self.aparse_result([acc_gen], partial=True)
|
||||
if parsed is not None and parsed != prev_parsed:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import contextlib
|
||||
import re
|
||||
import xml
|
||||
import xml.etree.ElementTree as ET # noqa: N817
|
||||
@ -131,11 +132,9 @@ class _StreamingParser:
|
||||
Raises:
|
||||
xml.etree.ElementTree.ParseError: If the XML is not well-formed.
|
||||
"""
|
||||
try:
|
||||
# Ignore ParseError. This will ignore any incomplete XML at the end of the input
|
||||
with contextlib.suppress(xml.etree.ElementTree.ParseError):
|
||||
self.pull_parser.close()
|
||||
except xml.etree.ElementTree.ParseError:
|
||||
# Ignore. This will ignore any incomplete XML at the end of the input
|
||||
pass
|
||||
|
||||
|
||||
class XMLOutputParser(BaseTransformOutputParser):
|
||||
|
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
@ -319,10 +320,8 @@ class BasePromptTemplate(
|
||||
NotImplementedError: If the prompt type is not implemented.
|
||||
"""
|
||||
prompt_dict = super().model_dump(**kwargs)
|
||||
try:
|
||||
with contextlib.suppress(NotImplementedError):
|
||||
prompt_dict["_type"] = self._prompt_type
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return prompt_dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
@ -350,10 +349,7 @@ class BasePromptTemplate(
|
||||
raise NotImplementedError(f"Prompt {self} does not support saving.")
|
||||
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
save_path = Path(file_path) if isinstance(file_path, str) else file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -59,8 +59,8 @@ class _FewShotPromptTemplateMixin(BaseModel):
|
||||
ValueError: If neither or both examples and example_selector are provided.
|
||||
ValueError: If both examples and example_selector are provided.
|
||||
"""
|
||||
examples = values.get("examples", None)
|
||||
example_selector = values.get("example_selector", None)
|
||||
examples = values.get("examples")
|
||||
example_selector = values.get("example_selector")
|
||||
if examples and example_selector:
|
||||
raise ValueError(
|
||||
"Only one of 'examples' and 'example_selector' should be provided"
|
||||
|
@ -51,8 +51,8 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
@classmethod
|
||||
def check_examples_and_selector(cls, values: dict) -> Any:
|
||||
"""Check that one and only one of examples/example_selector are provided."""
|
||||
examples = values.get("examples", None)
|
||||
example_selector = values.get("example_selector", None)
|
||||
examples = values.get("examples")
|
||||
example_selector = values.get("example_selector")
|
||||
if examples and example_selector:
|
||||
raise ValueError(
|
||||
"Only one of 'examples' and 'example_selector' should be provided"
|
||||
@ -138,7 +138,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
prefix_kwargs = {
|
||||
k: v for k, v in kwargs.items() if k in self.prefix.input_variables
|
||||
}
|
||||
for k in prefix_kwargs.keys():
|
||||
for k in prefix_kwargs:
|
||||
kwargs.pop(k)
|
||||
prefix = self.prefix.format(**prefix_kwargs)
|
||||
|
||||
@ -146,7 +146,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
suffix_kwargs = {
|
||||
k: v for k, v in kwargs.items() if k in self.suffix.input_variables
|
||||
}
|
||||
for k in suffix_kwargs.keys():
|
||||
for k in suffix_kwargs:
|
||||
kwargs.pop(k)
|
||||
suffix = self.suffix.format(
|
||||
**suffix_kwargs,
|
||||
@ -182,7 +182,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
prefix_kwargs = {
|
||||
k: v for k, v in kwargs.items() if k in self.prefix.input_variables
|
||||
}
|
||||
for k in prefix_kwargs.keys():
|
||||
for k in prefix_kwargs:
|
||||
kwargs.pop(k)
|
||||
prefix = await self.prefix.aformat(**prefix_kwargs)
|
||||
|
||||
@ -190,7 +190,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
suffix_kwargs = {
|
||||
k: v for k, v in kwargs.items() if k in self.suffix.input_variables
|
||||
}
|
||||
for k in suffix_kwargs.keys():
|
||||
for k in suffix_kwargs:
|
||||
kwargs.pop(k)
|
||||
suffix = await self.suffix.aformat(
|
||||
**suffix_kwargs,
|
||||
|
@ -164,10 +164,7 @@ def _load_prompt_from_file(
|
||||
) -> BasePromptTemplate:
|
||||
"""Load prompt from file."""
|
||||
# Convert file to a Path object.
|
||||
if isinstance(file, str):
|
||||
file_path = Path(file)
|
||||
else:
|
||||
file_path = file
|
||||
file_path = Path(file) if isinstance(file, str) else file
|
||||
# Load from either json or yaml.
|
||||
if file_path.suffix == ".json":
|
||||
with open(file_path, encoding=encoding) as f:
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import threading
|
||||
@ -2466,10 +2467,8 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
A JSON-serializable representation of the Runnable.
|
||||
"""
|
||||
dumped = super().to_json()
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
dumped["name"] = self.get_name()
|
||||
except Exception:
|
||||
pass
|
||||
return dumped
|
||||
|
||||
def configurable_fields(
|
||||
@ -2763,9 +2762,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
ValueError: If the sequence has less than 2 steps.
|
||||
"""
|
||||
steps_flat: list[Runnable] = []
|
||||
if not steps:
|
||||
if first is not None and last is not None:
|
||||
steps_flat = [first] + (middle or []) + [last]
|
||||
if not steps and first is not None and last is not None:
|
||||
steps_flat = [first] + (middle or []) + [last]
|
||||
for step in steps:
|
||||
if isinstance(step, RunnableSequence):
|
||||
steps_flat.extend(step.steps)
|
||||
@ -4180,12 +4178,9 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
final = None
|
||||
final: Optional[Output] = None
|
||||
for output in self.stream(input, config, **kwargs):
|
||||
if final is None:
|
||||
final = output
|
||||
else:
|
||||
final = final + output
|
||||
final = output if final is None else final + output # type: ignore[operator]
|
||||
return cast(Output, final)
|
||||
|
||||
def atransform(
|
||||
@ -4215,12 +4210,9 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
final = None
|
||||
final: Optional[Output] = None
|
||||
async for output in self.astream(input, config, **kwargs):
|
||||
if final is None:
|
||||
final = output
|
||||
else:
|
||||
final = final + output
|
||||
final = output if final is None else final + output # type: ignore[operator]
|
||||
return cast(Output, final)
|
||||
|
||||
|
||||
|
@ -139,11 +139,11 @@ def _set_config_context(config: RunnableConfig) -> None:
|
||||
None,
|
||||
)
|
||||
)
|
||||
and (run := tracer.run_map.get(str(parent_run_id)))
|
||||
):
|
||||
if run := tracer.run_map.get(str(parent_run_id)):
|
||||
from langsmith.run_helpers import _set_tracing_context
|
||||
from langsmith.run_helpers import _set_tracing_context
|
||||
|
||||
_set_tracing_context({"parent": run})
|
||||
_set_tracing_context({"parent": run})
|
||||
|
||||
|
||||
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||
|
@ -99,24 +99,15 @@ class AsciiCanvas:
|
||||
self.point(x0, y0, char)
|
||||
elif abs(dx) >= abs(dy):
|
||||
for x in range(x0, x1 + 1):
|
||||
if dx == 0:
|
||||
y = y0
|
||||
else:
|
||||
y = y0 + int(round((x - x0) * dy / float(dx)))
|
||||
y = y0 if dx == 0 else y0 + int(round((x - x0) * dy / float(dx)))
|
||||
self.point(x, y, char)
|
||||
elif y0 < y1:
|
||||
for y in range(y0, y1 + 1):
|
||||
if dy == 0:
|
||||
x = x0
|
||||
else:
|
||||
x = x0 + int(round((y - y0) * dx / float(dy)))
|
||||
x = x0 if dy == 0 else x0 + int(round((y - y0) * dx / float(dy)))
|
||||
self.point(x, y, char)
|
||||
else:
|
||||
for y in range(y1, y0 + 1):
|
||||
if dy == 0:
|
||||
x = x0
|
||||
else:
|
||||
x = x1 + int(round((y - y1) * dx / float(dy)))
|
||||
x = x0 if dy == 0 else x1 + int(round((y - y1) * dx / float(dy)))
|
||||
self.point(x, y, char)
|
||||
|
||||
def text(self, x: int, y: int, text: str) -> None:
|
||||
|
@ -131,10 +131,7 @@ def draw_mermaid(
|
||||
else:
|
||||
edge_label = f" -- {edge_data} --> "
|
||||
else:
|
||||
if edge.conditional:
|
||||
edge_label = " -.-> "
|
||||
else:
|
||||
edge_label = " --> "
|
||||
edge_label = " -.-> " if edge.conditional else " --> "
|
||||
|
||||
mermaid_graph += (
|
||||
f"\t{_escape_node_label(source)}{edge_label}"
|
||||
@ -142,7 +139,7 @@ def draw_mermaid(
|
||||
)
|
||||
|
||||
# Recursively add nested subgraphs
|
||||
for nested_prefix in edge_groups.keys():
|
||||
for nested_prefix in edge_groups:
|
||||
if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix:
|
||||
continue
|
||||
add_subgraph(edge_groups[nested_prefix], nested_prefix)
|
||||
@ -154,7 +151,7 @@ def draw_mermaid(
|
||||
add_subgraph(edge_groups.get("", []), "")
|
||||
|
||||
# Add remaining subgraphs
|
||||
for prefix in edge_groups.keys():
|
||||
for prefix in edge_groups:
|
||||
if ":" in prefix or prefix == "":
|
||||
continue
|
||||
add_subgraph(edge_groups[prefix], prefix)
|
||||
|
@ -496,12 +496,9 @@ def add(addables: Iterable[Addable]) -> Optional[Addable]:
|
||||
Returns:
|
||||
Optional[Addable]: The result of adding the addable objects.
|
||||
"""
|
||||
final = None
|
||||
final: Optional[Addable] = None
|
||||
for chunk in addables:
|
||||
if final is None:
|
||||
final = chunk
|
||||
else:
|
||||
final = final + chunk
|
||||
final = chunk if final is None else final + chunk
|
||||
return final
|
||||
|
||||
|
||||
@ -514,12 +511,9 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
||||
Returns:
|
||||
Optional[Addable]: The result of adding the addable objects.
|
||||
"""
|
||||
final = None
|
||||
final: Optional[Addable] = None
|
||||
async for chunk in addables:
|
||||
if final is None:
|
||||
final = chunk
|
||||
else:
|
||||
final = final + chunk
|
||||
final = chunk if final is None else final + chunk
|
||||
return final
|
||||
|
||||
|
||||
@ -642,9 +636,7 @@ def get_unique_config_specs(
|
||||
for id, dupes in grouped:
|
||||
first = next(dupes)
|
||||
others = list(dupes)
|
||||
if len(others) == 0:
|
||||
unique.append(first)
|
||||
elif all(o == first for o in others):
|
||||
if len(others) == 0 or all(o == first for o in others):
|
||||
unique.append(first)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
@ -258,7 +258,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
||||
if prefix is None:
|
||||
yield from self.store.keys()
|
||||
else:
|
||||
for key in self.store.keys():
|
||||
for key in self.store:
|
||||
if key.startswith(prefix):
|
||||
yield key
|
||||
|
||||
@ -272,10 +272,10 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
||||
AsyncIterator[str]: An async iterator over keys that match the given prefix.
|
||||
"""
|
||||
if prefix is None:
|
||||
for key in self.store.keys():
|
||||
for key in self.store:
|
||||
yield key
|
||||
else:
|
||||
for key in self.store.keys():
|
||||
for key in self.store:
|
||||
if key.startswith(prefix):
|
||||
yield key
|
||||
|
||||
|
@ -19,18 +19,24 @@ class Visitor(ABC):
|
||||
"""Allowed operators for the visitor."""
|
||||
|
||||
def _validate_func(self, func: Union[Operator, Comparator]) -> None:
|
||||
if isinstance(func, Operator) and self.allowed_operators is not None:
|
||||
if func not in self.allowed_operators:
|
||||
raise ValueError(
|
||||
f"Received disallowed operator {func}. Allowed "
|
||||
f"comparators are {self.allowed_operators}"
|
||||
)
|
||||
if isinstance(func, Comparator) and self.allowed_comparators is not None:
|
||||
if func not in self.allowed_comparators:
|
||||
raise ValueError(
|
||||
f"Received disallowed comparator {func}. Allowed "
|
||||
f"comparators are {self.allowed_comparators}"
|
||||
)
|
||||
if (
|
||||
isinstance(func, Operator)
|
||||
and self.allowed_operators is not None
|
||||
and func not in self.allowed_operators
|
||||
):
|
||||
raise ValueError(
|
||||
f"Received disallowed operator {func}. Allowed "
|
||||
f"comparators are {self.allowed_operators}"
|
||||
)
|
||||
if (
|
||||
isinstance(func, Comparator)
|
||||
and self.allowed_comparators is not None
|
||||
and func not in self.allowed_comparators
|
||||
):
|
||||
raise ValueError(
|
||||
f"Received disallowed comparator {func}. Allowed "
|
||||
f"comparators are {self.allowed_comparators}"
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def visit_operation(self, operation: Operation) -> Any:
|
||||
|
@ -248,11 +248,8 @@ def create_schema_from_function(
|
||||
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
||||
|
||||
# Let's ignore `self` and `cls` arguments for class and instance methods
|
||||
if func.__qualname__ and "." in func.__qualname__:
|
||||
# Then it likely belongs in a class namespace
|
||||
in_class = True
|
||||
else:
|
||||
in_class = False
|
||||
# If qualified name has a ".", then it likely belongs in a class namespace
|
||||
in_class = bool(func.__qualname__ and "." in func.__qualname__)
|
||||
|
||||
has_args = False
|
||||
has_kwargs = False
|
||||
@ -289,12 +286,10 @@ def create_schema_from_function(
|
||||
# Pydantic adds placeholder virtual fields we need to strip
|
||||
valid_properties = []
|
||||
for field in get_fields(inferred_model):
|
||||
if not has_args:
|
||||
if field == "args":
|
||||
continue
|
||||
if not has_kwargs:
|
||||
if field == "kwargs":
|
||||
continue
|
||||
if not has_args and field == "args":
|
||||
continue
|
||||
if not has_kwargs and field == "kwargs":
|
||||
continue
|
||||
|
||||
if field == "v__duplicate_kwargs": # Internal pydantic field
|
||||
continue
|
||||
@ -422,12 +417,15 @@ class ChildTool(BaseTool):
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the tool."""
|
||||
if "args_schema" in kwargs and kwargs["args_schema"] is not None:
|
||||
if not is_basemodel_subclass(kwargs["args_schema"]):
|
||||
raise TypeError(
|
||||
f"args_schema must be a subclass of pydantic BaseModel. "
|
||||
f"Got: {kwargs['args_schema']}."
|
||||
)
|
||||
if (
|
||||
"args_schema" in kwargs
|
||||
and kwargs["args_schema"] is not None
|
||||
and not is_basemodel_subclass(kwargs["args_schema"])
|
||||
):
|
||||
raise TypeError(
|
||||
f"args_schema must be a subclass of pydantic BaseModel. "
|
||||
f"Got: {kwargs['args_schema']}."
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
model_config = ConfigDict(
|
||||
@ -840,10 +838,7 @@ def _handle_tool_error(
|
||||
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
|
||||
) -> str:
|
||||
if isinstance(flag, bool):
|
||||
if e.args:
|
||||
content = e.args[0]
|
||||
else:
|
||||
content = "Tool execution error"
|
||||
content = e.args[0] if e.args else "Tool execution error"
|
||||
elif isinstance(flag, str):
|
||||
content = flag
|
||||
elif callable(flag):
|
||||
@ -902,12 +897,11 @@ def _format_output(
|
||||
|
||||
def _is_message_content_type(obj: Any) -> bool:
|
||||
"""Check for OpenAI or Anthropic format tool message content."""
|
||||
if isinstance(obj, str):
|
||||
return True
|
||||
elif isinstance(obj, list) and all(_is_message_content_block(e) for e in obj):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return (
|
||||
isinstance(obj, str)
|
||||
or isinstance(obj, list)
|
||||
and all(_is_message_content_block(e) for e in obj)
|
||||
)
|
||||
|
||||
|
||||
def _is_message_content_block(obj: Any) -> bool:
|
||||
|
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from typing import (
|
||||
@ -814,10 +815,7 @@ async def _astream_events_implementation_v1(
|
||||
data: EventData = {}
|
||||
log_entry: LogEntry = run_log.state["logs"][path]
|
||||
if log_entry["end_time"] is None:
|
||||
if log_entry["streamed_output"]:
|
||||
event_type = "stream"
|
||||
else:
|
||||
event_type = "start"
|
||||
event_type = "stream" if log_entry["streamed_output"] else "start"
|
||||
else:
|
||||
event_type = "end"
|
||||
|
||||
@ -983,14 +981,15 @@ async def _astream_events_implementation_v2(
|
||||
yield event
|
||||
continue
|
||||
|
||||
if event["run_id"] == first_event_run_id and event["event"].endswith(
|
||||
"_end"
|
||||
# If it's the end event corresponding to the root runnable
|
||||
# we dont include the input in the event since it's guaranteed
|
||||
# to be included in the first event.
|
||||
if (
|
||||
event["run_id"] == first_event_run_id
|
||||
and event["event"].endswith("_end")
|
||||
and "input" in event["data"]
|
||||
):
|
||||
# If it's the end event corresponding to the root runnable
|
||||
# we dont include the input in the event since it's guaranteed
|
||||
# to be included in the first event.
|
||||
if "input" in event["data"]:
|
||||
del event["data"]["input"]
|
||||
del event["data"]["input"]
|
||||
|
||||
yield event
|
||||
except asyncio.CancelledError as exc:
|
||||
@ -1001,7 +1000,5 @@ async def _astream_events_implementation_v2(
|
||||
# Cancel the task if it's still running
|
||||
task.cancel()
|
||||
# Await it anyway, to run any cleanup code, and propagate any exceptions
|
||||
try:
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import copy
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
@ -253,18 +254,22 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
|
||||
"""
|
||||
async for chunk in output:
|
||||
# root run is handled in .astream_log()
|
||||
if run_id != self.root_id:
|
||||
# if we can't find the run silently ignore
|
||||
# eg. because this run wasn't included in the log
|
||||
if key := self._key_map_by_run_id.get(run_id):
|
||||
if not self.send(
|
||||
# if we can't find the run silently ignore
|
||||
# eg. because this run wasn't included in the log
|
||||
if (
|
||||
run_id != self.root_id
|
||||
and (key := self._key_map_by_run_id.get(run_id))
|
||||
and (
|
||||
not self.send(
|
||||
{
|
||||
"op": "add",
|
||||
"path": f"/logs/{key}/streamed_output/-",
|
||||
"value": chunk,
|
||||
}
|
||||
):
|
||||
break
|
||||
)
|
||||
)
|
||||
):
|
||||
break
|
||||
|
||||
yield chunk
|
||||
|
||||
@ -280,18 +285,22 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
|
||||
"""
|
||||
for chunk in output:
|
||||
# root run is handled in .astream_log()
|
||||
if run_id != self.root_id:
|
||||
# if we can't find the run silently ignore
|
||||
# eg. because this run wasn't included in the log
|
||||
if key := self._key_map_by_run_id.get(run_id):
|
||||
if not self.send(
|
||||
# if we can't find the run silently ignore
|
||||
# eg. because this run wasn't included in the log
|
||||
if (
|
||||
run_id != self.root_id
|
||||
and (key := self._key_map_by_run_id.get(run_id))
|
||||
and (
|
||||
not self.send(
|
||||
{
|
||||
"op": "add",
|
||||
"path": f"/logs/{key}/streamed_output/-",
|
||||
"value": chunk,
|
||||
}
|
||||
):
|
||||
break
|
||||
)
|
||||
)
|
||||
):
|
||||
break
|
||||
|
||||
yield chunk
|
||||
|
||||
@ -439,9 +448,8 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
|
||||
|
||||
self.send(*ops)
|
||||
finally:
|
||||
if run.id == self.root_id:
|
||||
if self.auto_close:
|
||||
self.send_stream.close()
|
||||
if run.id == self.root_id and self.auto_close:
|
||||
self.send_stream.close()
|
||||
|
||||
def _on_llm_new_token(
|
||||
self,
|
||||
@ -662,7 +670,5 @@ async def _astream_log_implementation(
|
||||
yield state
|
||||
finally:
|
||||
# Wait for the runnable to finish, if not cancelled (eg. by break)
|
||||
try:
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
@ -29,9 +29,7 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]
|
||||
merged = left.copy()
|
||||
for right in others:
|
||||
for right_k, right_v in right.items():
|
||||
if right_k not in merged:
|
||||
merged[right_k] = right_v
|
||||
elif right_v is not None and merged[right_k] is None:
|
||||
if right_k not in merged or right_v is not None and merged[right_k] is None:
|
||||
merged[right_k] = right_v
|
||||
elif right_v is None:
|
||||
continue
|
||||
|
@ -43,14 +43,10 @@ def get_from_dict_or_env(
|
||||
if k in data and data[k]:
|
||||
return data[k]
|
||||
|
||||
if isinstance(key, str):
|
||||
if key in data and data[key]:
|
||||
return data[key]
|
||||
if isinstance(key, str) and key in data and data[key]:
|
||||
return data[key]
|
||||
|
||||
if isinstance(key, (list, tuple)):
|
||||
key_for_err = key[0]
|
||||
else:
|
||||
key_for_err = key
|
||||
key_for_err = key[0] if isinstance(key, (list, tuple)) else key
|
||||
|
||||
return get_from_env(key_for_err, env_key, default=default)
|
||||
|
||||
|
@ -64,7 +64,7 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
|
||||
new_kv = {}
|
||||
for k, v in kv.items():
|
||||
if k == "title":
|
||||
if isinstance(v, dict) and prev_key == "properties" and "title" in v.keys():
|
||||
if isinstance(v, dict) and prev_key == "properties" and "title" in v:
|
||||
new_kv[k] = _rm_titles(v, k)
|
||||
else:
|
||||
continue
|
||||
|
@ -139,11 +139,8 @@ def parse_json_markdown(
|
||||
match = _json_markdown_re.search(json_string)
|
||||
|
||||
# If no match found, assume the entire string is a JSON string
|
||||
if match is None:
|
||||
json_str = json_string
|
||||
else:
|
||||
# If match found, use the content within the backticks
|
||||
json_str = match.group(2)
|
||||
# Else, use the content within the backticks
|
||||
json_str = json_string if match is None else match.group(2)
|
||||
return _parse_json(json_str, parser=parser)
|
||||
|
||||
|
||||
|
@ -80,12 +80,9 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
|
||||
padding = literal.split("\n")[-1]
|
||||
|
||||
# If all the characters since the last newline are spaces
|
||||
if padding.isspace() or padding == "":
|
||||
# Then the next tag could be a standalone
|
||||
return True
|
||||
else:
|
||||
# Otherwise it can't be
|
||||
return False
|
||||
# Then the next tag could be a standalone
|
||||
# Otherwise it can't be
|
||||
return padding.isspace() or padding == ""
|
||||
else:
|
||||
return False
|
||||
|
||||
@ -107,10 +104,7 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
|
||||
on_newline = template.split("\n", 1)
|
||||
|
||||
# If the stuff to the right of us are spaces we're a standalone
|
||||
if on_newline[0].isspace() or not on_newline[0]:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return on_newline[0].isspace() or not on_newline[0]
|
||||
|
||||
# If we're a tag can't be a standalone
|
||||
else:
|
||||
@ -174,14 +168,18 @@ def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], s
|
||||
"unclosed set delimiter tag\n" f"at line {_CURRENT_LINE}"
|
||||
)
|
||||
|
||||
# If we might be a no html escape tag
|
||||
elif tag_type == "no escape?":
|
||||
elif (
|
||||
# If we might be a no html escape tag
|
||||
tag_type == "no escape?"
|
||||
# And we have a third curly brace
|
||||
# (And are using curly braces as delimiters)
|
||||
if l_del == "{{" and r_del == "}}" and template.startswith("}"):
|
||||
# Then we are a no html escape tag
|
||||
template = template[1:]
|
||||
tag_type = "no escape"
|
||||
and l_del == "{{"
|
||||
and r_del == "}}"
|
||||
and template.startswith("}")
|
||||
):
|
||||
# Then we are a no html escape tag
|
||||
template = template[1:]
|
||||
tag_type = "no escape"
|
||||
|
||||
# Strip the whitespace off the key and return
|
||||
return ((tag_type, tag.strip()), template)
|
||||
|
@ -179,22 +179,27 @@ def pre_init(func: Callable) -> Any:
|
||||
for name, field_info in fields.items():
|
||||
# Check if allow_population_by_field_name is enabled
|
||||
# If yes, then set the field name to the alias
|
||||
if hasattr(cls, "Config"):
|
||||
if hasattr(cls.Config, "allow_population_by_field_name"):
|
||||
if cls.Config.allow_population_by_field_name:
|
||||
if field_info.alias in values:
|
||||
values[name] = values.pop(field_info.alias)
|
||||
if hasattr(cls, "model_config"):
|
||||
if cls.model_config.get("populate_by_name"):
|
||||
if field_info.alias in values:
|
||||
values[name] = values.pop(field_info.alias)
|
||||
if (
|
||||
hasattr(cls, "Config")
|
||||
and hasattr(cls.Config, "allow_population_by_field_name")
|
||||
and cls.Config.allow_population_by_field_name
|
||||
and field_info.alias in values
|
||||
):
|
||||
values[name] = values.pop(field_info.alias)
|
||||
if (
|
||||
hasattr(cls, "model_config")
|
||||
and cls.model_config.get("populate_by_name")
|
||||
and field_info.alias in values
|
||||
):
|
||||
values[name] = values.pop(field_info.alias)
|
||||
|
||||
if name not in values or values[name] is None:
|
||||
if not field_info.is_required():
|
||||
if field_info.default_factory is not None:
|
||||
values[name] = field_info.default_factory()
|
||||
else:
|
||||
values[name] = field_info.default
|
||||
if (
|
||||
name not in values or values[name] is None
|
||||
) and not field_info.is_required():
|
||||
if field_info.default_factory is not None:
|
||||
values[name] = field_info.default_factory()
|
||||
else:
|
||||
values[name] = field_info.default
|
||||
|
||||
# Call the decorated function
|
||||
return func(cls, values)
|
||||
|
@ -332,9 +332,8 @@ def from_env(
|
||||
for k in key:
|
||||
if k in os.environ:
|
||||
return os.environ[k]
|
||||
if isinstance(key, str):
|
||||
if key in os.environ:
|
||||
return os.environ[key]
|
||||
if isinstance(key, str) and key in os.environ:
|
||||
return os.environ[key]
|
||||
|
||||
if isinstance(default, (str, type(None))):
|
||||
return default
|
||||
@ -395,9 +394,8 @@ def secret_from_env(
|
||||
for k in key:
|
||||
if k in os.environ:
|
||||
return SecretStr(os.environ[k])
|
||||
if isinstance(key, str):
|
||||
if key in os.environ:
|
||||
return SecretStr(os.environ[key])
|
||||
if isinstance(key, str) and key in os.environ:
|
||||
return SecretStr(os.environ[key])
|
||||
if isinstance(default, str):
|
||||
return SecretStr(default)
|
||||
elif isinstance(default, type(None)):
|
||||
|
@ -44,7 +44,7 @@ python = ">=3.12.4"
|
||||
[tool.poetry.extras]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [ "B", "C4", "E", "F", "I", "N", "PIE", "T201", "UP",]
|
||||
select = [ "B", "C4", "E", "F", "I", "N", "PIE", "SIM", "T201", "UP",]
|
||||
ignore = [ "UP007",]
|
||||
|
||||
[tool.coverage.run]
|
||||
|
@ -148,7 +148,7 @@ def test_pydantic_output_parser() -> None:
|
||||
|
||||
result = pydantic_parser.parse(DEF_RESULT)
|
||||
print("parse_result:", result) # noqa: T201
|
||||
assert DEF_EXPECTED_RESULT == result
|
||||
assert result == DEF_EXPECTED_RESULT
|
||||
assert pydantic_parser.OutputType is TestModel
|
||||
|
||||
|
||||
|
@ -115,10 +115,7 @@ def _normalize_schema(obj: Any) -> dict[str, Any]:
|
||||
Args:
|
||||
obj: The object to generate the schema for
|
||||
"""
|
||||
if isinstance(obj, BaseModel):
|
||||
data = obj.model_json_schema()
|
||||
else:
|
||||
data = obj
|
||||
data = obj.model_json_schema() if isinstance(obj, BaseModel) else obj
|
||||
remove_all_none_default(data)
|
||||
replace_all_of_with_ref(data)
|
||||
_remove_enum(data)
|
||||
|
@ -3345,9 +3345,9 @@ def test_bind_with_lambda() -> None:
|
||||
return 3 + kwargs.get("n", 0)
|
||||
|
||||
runnable = RunnableLambda(my_function).bind(n=1)
|
||||
assert 4 == runnable.invoke({})
|
||||
assert runnable.invoke({}) == 4
|
||||
chunks = list(runnable.stream({}))
|
||||
assert [4] == chunks
|
||||
assert chunks == [4]
|
||||
|
||||
|
||||
async def test_bind_with_lambda_async() -> None:
|
||||
@ -3355,9 +3355,9 @@ async def test_bind_with_lambda_async() -> None:
|
||||
return 3 + kwargs.get("n", 0)
|
||||
|
||||
runnable = RunnableLambda(my_function).bind(n=1)
|
||||
assert 4 == await runnable.ainvoke({})
|
||||
assert await runnable.ainvoke({}) == 4
|
||||
chunks = [item async for item in runnable.astream({})]
|
||||
assert [4] == chunks
|
||||
assert chunks == [4]
|
||||
|
||||
|
||||
def test_deep_stream() -> None:
|
||||
@ -5140,13 +5140,10 @@ async def test_astream_log_deep_copies() -> None:
|
||||
|
||||
chain = RunnableLambda(add_one)
|
||||
chunks = []
|
||||
final_output = None
|
||||
final_output: Optional[RunLogPatch] = None
|
||||
async for chunk in chain.astream_log(1):
|
||||
chunks.append(chunk)
|
||||
if final_output is None:
|
||||
final_output = chunk
|
||||
else:
|
||||
final_output = final_output + chunk
|
||||
final_output = chunk if final_output is None else final_output + chunk
|
||||
|
||||
run_log = _get_run_log(chunks)
|
||||
state = run_log.state.copy()
|
||||
|
@ -209,9 +209,11 @@ def test_tracing_enable_disable(
|
||||
|
||||
get_env_var.cache_clear()
|
||||
env_on = env == "true"
|
||||
with patch.dict("os.environ", {"LANGSMITH_TRACING": env}):
|
||||
with tracing_context(enabled=enabled):
|
||||
RunnableLambda(my_func).invoke(1)
|
||||
with (
|
||||
patch.dict("os.environ", {"LANGSMITH_TRACING": env}),
|
||||
tracing_context(enabled=enabled),
|
||||
):
|
||||
RunnableLambda(my_func).invoke(1)
|
||||
|
||||
mock_posts = _get_posts(mock_client_)
|
||||
if enabled is True:
|
||||
|
@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
import uuid
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
@ -10,6 +10,7 @@ from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
@ -630,14 +631,11 @@ def test_tool_calls_merge() -> None:
|
||||
{"content": ""},
|
||||
]
|
||||
|
||||
final = None
|
||||
final: Optional[BaseMessageChunk] = None
|
||||
|
||||
for chunk in chunks:
|
||||
msg = AIMessageChunk(**chunk)
|
||||
if final is None:
|
||||
final = msg
|
||||
else:
|
||||
final = final + msg
|
||||
final = msg if final is None else final + msg
|
||||
|
||||
assert final == AIMessageChunk(
|
||||
content="",
|
||||
|
@ -133,25 +133,24 @@ class LangChainProjectNameTest(unittest.TestCase):
|
||||
for case in cases:
|
||||
get_env_var.cache_clear()
|
||||
get_tracer_project.cache_clear()
|
||||
with self.subTest(msg=case.test_name):
|
||||
with pytest.MonkeyPatch.context() as mp:
|
||||
for k, v in case.envvars.items():
|
||||
mp.setenv(k, v)
|
||||
with self.subTest(msg=case.test_name), pytest.MonkeyPatch.context() as mp:
|
||||
for k, v in case.envvars.items():
|
||||
mp.setenv(k, v)
|
||||
|
||||
client = unittest.mock.MagicMock(spec=Client)
|
||||
tracer = LangChainTracer(client=client)
|
||||
projects = []
|
||||
client = unittest.mock.MagicMock(spec=Client)
|
||||
tracer = LangChainTracer(client=client)
|
||||
projects = []
|
||||
|
||||
def mock_create_run(**kwargs: Any) -> Any:
|
||||
projects.append(kwargs.get("project_name")) # noqa: B023
|
||||
return unittest.mock.MagicMock()
|
||||
def mock_create_run(**kwargs: Any) -> Any:
|
||||
projects.append(kwargs.get("project_name")) # noqa: B023
|
||||
return unittest.mock.MagicMock()
|
||||
|
||||
client.create_run = mock_create_run
|
||||
client.create_run = mock_create_run
|
||||
|
||||
tracer.on_llm_start(
|
||||
{"name": "example_1"},
|
||||
["foo"],
|
||||
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
|
||||
)
|
||||
tracer.wait_for_futures()
|
||||
assert projects == [case.expected_project_name]
|
||||
tracer.on_llm_start(
|
||||
{"name": "example_1"},
|
||||
["foo"],
|
||||
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
|
||||
)
|
||||
tracer.wait_for_futures()
|
||||
assert projects == [case.expected_project_name]
|
||||
|
@ -513,14 +513,8 @@ def test_tool_outputs() -> None:
|
||||
def test__convert_typed_dict_to_openai_function(
|
||||
use_extension_typed_dict: bool, use_extension_annotated: bool
|
||||
) -> None:
|
||||
if use_extension_typed_dict:
|
||||
typed_dict = ExtensionsTypedDict
|
||||
else:
|
||||
typed_dict = TypingTypedDict
|
||||
if use_extension_annotated:
|
||||
annotated = TypingAnnotated
|
||||
else:
|
||||
annotated = TypingAnnotated
|
||||
typed_dict = ExtensionsTypedDict if use_extension_typed_dict else TypingTypedDict
|
||||
annotated = TypingAnnotated if use_extension_annotated else TypingAnnotated
|
||||
|
||||
class SubTool(typed_dict):
|
||||
"""Subtool docstring"""
|
||||
|
@ -116,10 +116,7 @@ def test_check_package_version(
|
||||
def test_merge_dicts(
|
||||
left: dict, right: dict, expected: Union[dict, AbstractContextManager]
|
||||
) -> None:
|
||||
if isinstance(expected, AbstractContextManager):
|
||||
err = expected
|
||||
else:
|
||||
err = nullcontext()
|
||||
err = expected if isinstance(expected, AbstractContextManager) else nullcontext()
|
||||
|
||||
left_copy = deepcopy(left)
|
||||
right_copy = deepcopy(right)
|
||||
@ -147,10 +144,7 @@ def test_merge_dicts(
|
||||
def test_merge_dicts_0_3(
|
||||
left: dict, right: dict, expected: Union[dict, AbstractContextManager]
|
||||
) -> None:
|
||||
if isinstance(expected, AbstractContextManager):
|
||||
err = expected
|
||||
else:
|
||||
err = nullcontext()
|
||||
err = expected if isinstance(expected, AbstractContextManager) else nullcontext()
|
||||
|
||||
left_copy = deepcopy(left)
|
||||
right_copy = deepcopy(right)
|
||||
|
Loading…
Reference in New Issue
Block a user