mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 15:46:47 +00:00
core[patch]: Add ruff rules for flake8-simplify (SIM) (#26848)
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
de0b48c41a
commit
7809b31b95
@ -127,10 +127,9 @@ def beta(
|
|||||||
|
|
||||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
|
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
|
||||||
"""Finalize the annotation of a class."""
|
"""Finalize the annotation of a class."""
|
||||||
try:
|
# Can't set new_doc on some extension objects.
|
||||||
|
with contextlib.suppress(AttributeError):
|
||||||
obj.__doc__ = new_doc
|
obj.__doc__ = new_doc
|
||||||
except AttributeError: # Can't set on some extension objects.
|
|
||||||
pass
|
|
||||||
|
|
||||||
def warn_if_direct_instance(
|
def warn_if_direct_instance(
|
||||||
self: Any, *args: Any, **kwargs: Any
|
self: Any, *args: Any, **kwargs: Any
|
||||||
|
@ -198,10 +198,9 @@ def deprecated(
|
|||||||
|
|
||||||
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
|
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
|
||||||
"""Finalize the deprecation of a class."""
|
"""Finalize the deprecation of a class."""
|
||||||
try:
|
# Can't set new_doc on some extension objects.
|
||||||
|
with contextlib.suppress(AttributeError):
|
||||||
obj.__doc__ = new_doc
|
obj.__doc__ = new_doc
|
||||||
except AttributeError: # Can't set on some extension objects.
|
|
||||||
pass
|
|
||||||
|
|
||||||
def warn_if_direct_instance(
|
def warn_if_direct_instance(
|
||||||
self: Any, *args: Any, **kwargs: Any
|
self: Any, *args: Any, **kwargs: Any
|
||||||
|
@ -27,7 +27,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
|||||||
mode: The mode to open the file in. Defaults to "a".
|
mode: The mode to open the file in. Defaults to "a".
|
||||||
color: The color to use for the text. Defaults to None.
|
color: The color to use for the text. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.file = cast(TextIO, open(filename, mode, encoding="utf-8"))
|
self.file = cast(TextIO, open(filename, mode, encoding="utf-8")) # noqa: SIM115
|
||||||
self.color = color
|
self.color = color
|
||||||
|
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
|
@ -2252,14 +2252,14 @@ def _configure(
|
|||||||
else:
|
else:
|
||||||
parent_run_id_ = inheritable_callbacks.parent_run_id
|
parent_run_id_ = inheritable_callbacks.parent_run_id
|
||||||
# Break ties between the external tracing context and inherited context
|
# Break ties between the external tracing context and inherited context
|
||||||
if parent_run_id is not None:
|
if parent_run_id is not None and (
|
||||||
if parent_run_id_ is None:
|
parent_run_id_ is None
|
||||||
parent_run_id_ = parent_run_id
|
|
||||||
# If the LC parent has already been reflected
|
# If the LC parent has already been reflected
|
||||||
# in the run tree, we know the run_tree is either the
|
# in the run tree, we know the run_tree is either the
|
||||||
# same parent or a child of the parent.
|
# same parent or a child of the parent.
|
||||||
elif run_tree and str(parent_run_id_) in run_tree.dotted_order:
|
or (run_tree and str(parent_run_id_) in run_tree.dotted_order)
|
||||||
parent_run_id_ = parent_run_id
|
):
|
||||||
|
parent_run_id_ = parent_run_id
|
||||||
# Otherwise, we assume the LC context has progressed
|
# Otherwise, we assume the LC context has progressed
|
||||||
# beyond the run tree and we should not inherit the parent.
|
# beyond the run tree and we should not inherit the parent.
|
||||||
callback_manager = callback_manager_cls(
|
callback_manager = callback_manager_cls(
|
||||||
|
@ -40,12 +40,11 @@ class OutputParserException(ValueError, LangChainException): # noqa: N818
|
|||||||
send_to_llm: bool = False,
|
send_to_llm: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(error)
|
super().__init__(error)
|
||||||
if send_to_llm:
|
if send_to_llm and (observation is None or llm_output is None):
|
||||||
if observation is None or llm_output is None:
|
raise ValueError(
|
||||||
raise ValueError(
|
"Arguments 'observation' & 'llm_output'"
|
||||||
"Arguments 'observation' & 'llm_output'"
|
" are required if 'send_to_llm' is True"
|
||||||
" are required if 'send_to_llm' is True"
|
)
|
||||||
)
|
|
||||||
self.observation = observation
|
self.observation = observation
|
||||||
self.llm_output = llm_output
|
self.llm_output = llm_output
|
||||||
self.send_to_llm = send_to_llm
|
self.send_to_llm = send_to_llm
|
||||||
|
@ -92,7 +92,7 @@ class _HashedDocument(Document):
|
|||||||
values["metadata_hash"] = metadata_hash
|
values["metadata_hash"] = metadata_hash
|
||||||
values["hash_"] = str(_hash_string_to_uuid(content_hash + metadata_hash))
|
values["hash_"] = str(_hash_string_to_uuid(content_hash + metadata_hash))
|
||||||
|
|
||||||
_uid = values.get("uid", None)
|
_uid = values.get("uid")
|
||||||
|
|
||||||
if _uid is None:
|
if _uid is None:
|
||||||
values["uid"] = values["hash_"]
|
values["uid"] = values["hash_"]
|
||||||
|
@ -802,10 +802,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
if isinstance(self.cache, BaseCache):
|
llm_cache = self.cache if isinstance(self.cache, BaseCache) else get_llm_cache()
|
||||||
llm_cache = self.cache
|
|
||||||
else:
|
|
||||||
llm_cache = get_llm_cache()
|
|
||||||
# We should check the cache unless it's explicitly set to False
|
# We should check the cache unless it's explicitly set to False
|
||||||
# A None cache means we should use the default global cache
|
# A None cache means we should use the default global cache
|
||||||
# if it's configured.
|
# if it's configured.
|
||||||
@ -879,10 +876,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
if isinstance(self.cache, BaseCache):
|
llm_cache = self.cache if isinstance(self.cache, BaseCache) else get_llm_cache()
|
||||||
llm_cache = self.cache
|
|
||||||
else:
|
|
||||||
llm_cache = get_llm_cache()
|
|
||||||
# We should check the cache unless it's explicitly set to False
|
# We should check the cache unless it's explicitly set to False
|
||||||
# A None cache means we should use the default global cache
|
# A None cache means we should use the default global cache
|
||||||
# if it's configured.
|
# if it's configured.
|
||||||
@ -1054,10 +1048,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
def predict(
|
def predict(
|
||||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||||
) -> str:
|
) -> str:
|
||||||
if stop is None:
|
_stop = None if stop is None else list(stop)
|
||||||
_stop = None
|
|
||||||
else:
|
|
||||||
_stop = list(stop)
|
|
||||||
result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
|
result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
|
||||||
if isinstance(result.content, str):
|
if isinstance(result.content, str):
|
||||||
return result.content
|
return result.content
|
||||||
@ -1072,20 +1063,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
stop: Optional[Sequence[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> BaseMessage:
|
) -> BaseMessage:
|
||||||
if stop is None:
|
_stop = None if stop is None else list(stop)
|
||||||
_stop = None
|
|
||||||
else:
|
|
||||||
_stop = list(stop)
|
|
||||||
return self(messages, stop=_stop, **kwargs)
|
return self(messages, stop=_stop, **kwargs)
|
||||||
|
|
||||||
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
||||||
async def apredict(
|
async def apredict(
|
||||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||||
) -> str:
|
) -> str:
|
||||||
if stop is None:
|
_stop = None if stop is None else list(stop)
|
||||||
_stop = None
|
|
||||||
else:
|
|
||||||
_stop = list(stop)
|
|
||||||
result = await self._call_async(
|
result = await self._call_async(
|
||||||
[HumanMessage(content=text)], stop=_stop, **kwargs
|
[HumanMessage(content=text)], stop=_stop, **kwargs
|
||||||
)
|
)
|
||||||
@ -1102,10 +1087,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
stop: Optional[Sequence[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> BaseMessage:
|
) -> BaseMessage:
|
||||||
if stop is None:
|
_stop = None if stop is None else list(stop)
|
||||||
_stop = None
|
|
||||||
else:
|
|
||||||
_stop = list(stop)
|
|
||||||
return await self._call_async(messages, stop=_stop, **kwargs)
|
return await self._call_async(messages, stop=_stop, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -1333,9 +1315,12 @@ def _cleanup_llm_representation(serialized: Any, depth: int) -> None:
|
|||||||
if not isinstance(serialized, dict):
|
if not isinstance(serialized, dict):
|
||||||
return
|
return
|
||||||
|
|
||||||
if "type" in serialized and serialized["type"] == "not_implemented":
|
if (
|
||||||
if "repr" in serialized:
|
"type" in serialized
|
||||||
del serialized["repr"]
|
and serialized["type"] == "not_implemented"
|
||||||
|
and "repr" in serialized
|
||||||
|
):
|
||||||
|
del serialized["repr"]
|
||||||
|
|
||||||
if "graph" in serialized:
|
if "graph" in serialized:
|
||||||
del serialized["graph"]
|
del serialized["graph"]
|
||||||
|
@ -194,10 +194,7 @@ class GenericFakeChatModel(BaseChatModel):
|
|||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
"""Top Level call"""
|
"""Top Level call"""
|
||||||
message = next(self.messages)
|
message = next(self.messages)
|
||||||
if isinstance(message, str):
|
message_ = AIMessage(content=message) if isinstance(message, str) else message
|
||||||
message_ = AIMessage(content=message)
|
|
||||||
else:
|
|
||||||
message_ = message
|
|
||||||
generation = ChatGeneration(message=message_)
|
generation = ChatGeneration(message=message_)
|
||||||
return ChatResult(generations=[generation])
|
return ChatResult(generations=[generation])
|
||||||
|
|
||||||
|
@ -1305,10 +1305,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
def predict(
|
def predict(
|
||||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||||
) -> str:
|
) -> str:
|
||||||
if stop is None:
|
_stop = None if stop is None else list(stop)
|
||||||
_stop = None
|
|
||||||
else:
|
|
||||||
_stop = list(stop)
|
|
||||||
return self(text, stop=_stop, **kwargs)
|
return self(text, stop=_stop, **kwargs)
|
||||||
|
|
||||||
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
||||||
@ -1320,10 +1317,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> BaseMessage:
|
) -> BaseMessage:
|
||||||
text = get_buffer_string(messages)
|
text = get_buffer_string(messages)
|
||||||
if stop is None:
|
_stop = None if stop is None else list(stop)
|
||||||
_stop = None
|
|
||||||
else:
|
|
||||||
_stop = list(stop)
|
|
||||||
content = self(text, stop=_stop, **kwargs)
|
content = self(text, stop=_stop, **kwargs)
|
||||||
return AIMessage(content=content)
|
return AIMessage(content=content)
|
||||||
|
|
||||||
@ -1331,10 +1325,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
async def apredict(
|
async def apredict(
|
||||||
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
|
||||||
) -> str:
|
) -> str:
|
||||||
if stop is None:
|
_stop = None if stop is None else list(stop)
|
||||||
_stop = None
|
|
||||||
else:
|
|
||||||
_stop = list(stop)
|
|
||||||
return await self._call_async(text, stop=_stop, **kwargs)
|
return await self._call_async(text, stop=_stop, **kwargs)
|
||||||
|
|
||||||
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
||||||
@ -1346,10 +1337,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> BaseMessage:
|
) -> BaseMessage:
|
||||||
text = get_buffer_string(messages)
|
text = get_buffer_string(messages)
|
||||||
if stop is None:
|
_stop = None if stop is None else list(stop)
|
||||||
_stop = None
|
|
||||||
else:
|
|
||||||
_stop = list(stop)
|
|
||||||
content = await self._call_async(text, stop=_stop, **kwargs)
|
content = await self._call_async(text, stop=_stop, **kwargs)
|
||||||
return AIMessage(content=content)
|
return AIMessage(content=content)
|
||||||
|
|
||||||
@ -1384,10 +1372,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
llm.save(file_path="path/llm.yaml")
|
llm.save(file_path="path/llm.yaml")
|
||||||
"""
|
"""
|
||||||
# Convert file to Path object.
|
# Convert file to Path object.
|
||||||
if isinstance(file_path, str):
|
save_path = Path(file_path) if isinstance(file_path, str) else file_path
|
||||||
save_path = Path(file_path)
|
|
||||||
else:
|
|
||||||
save_path = file_path
|
|
||||||
|
|
||||||
directory_path = save_path.parent
|
directory_path = save_path.parent
|
||||||
directory_path.mkdir(parents=True, exist_ok=True)
|
directory_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -86,9 +86,9 @@ class Reviver:
|
|||||||
|
|
||||||
def __call__(self, value: dict[str, Any]) -> Any:
|
def __call__(self, value: dict[str, Any]) -> Any:
|
||||||
if (
|
if (
|
||||||
value.get("lc", None) == 1
|
value.get("lc") == 1
|
||||||
and value.get("type", None) == "secret"
|
and value.get("type") == "secret"
|
||||||
and value.get("id", None) is not None
|
and value.get("id") is not None
|
||||||
):
|
):
|
||||||
[key] = value["id"]
|
[key] = value["id"]
|
||||||
if key in self.secrets_map:
|
if key in self.secrets_map:
|
||||||
@ -99,9 +99,9 @@ class Reviver:
|
|||||||
raise KeyError(f'Missing key "{key}" in load(secrets_map)')
|
raise KeyError(f'Missing key "{key}" in load(secrets_map)')
|
||||||
|
|
||||||
if (
|
if (
|
||||||
value.get("lc", None) == 1
|
value.get("lc") == 1
|
||||||
and value.get("type", None) == "not_implemented"
|
and value.get("type") == "not_implemented"
|
||||||
and value.get("id", None) is not None
|
and value.get("id") is not None
|
||||||
):
|
):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Trying to load an object that doesn't implement "
|
"Trying to load an object that doesn't implement "
|
||||||
@ -109,17 +109,18 @@ class Reviver:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
value.get("lc", None) == 1
|
value.get("lc") == 1
|
||||||
and value.get("type", None) == "constructor"
|
and value.get("type") == "constructor"
|
||||||
and value.get("id", None) is not None
|
and value.get("id") is not None
|
||||||
):
|
):
|
||||||
[*namespace, name] = value["id"]
|
[*namespace, name] = value["id"]
|
||||||
mapping_key = tuple(value["id"])
|
mapping_key = tuple(value["id"])
|
||||||
|
|
||||||
if namespace[0] not in self.valid_namespaces:
|
if (
|
||||||
raise ValueError(f"Invalid namespace: {value}")
|
namespace[0] not in self.valid_namespaces
|
||||||
# The root namespace ["langchain"] is not a valid identifier.
|
# The root namespace ["langchain"] is not a valid identifier.
|
||||||
elif namespace == ["langchain"]:
|
or namespace == ["langchain"]
|
||||||
|
):
|
||||||
raise ValueError(f"Invalid namespace: {value}")
|
raise ValueError(f"Invalid namespace: {value}")
|
||||||
# Has explicit import path.
|
# Has explicit import path.
|
||||||
elif mapping_key in self.import_mappings:
|
elif mapping_key in self.import_mappings:
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import contextlib
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -238,7 +239,7 @@ class Serializable(BaseModel, ABC):
|
|||||||
|
|
||||||
# include all secrets, even if not specified in kwargs
|
# include all secrets, even if not specified in kwargs
|
||||||
# as these secrets may be passed as an environment variable instead
|
# as these secrets may be passed as an environment variable instead
|
||||||
for key in secrets.keys():
|
for key in secrets:
|
||||||
secret_value = getattr(self, key, None) or lc_kwargs.get(key)
|
secret_value = getattr(self, key, None) or lc_kwargs.get(key)
|
||||||
if secret_value is not None:
|
if secret_value is not None:
|
||||||
lc_kwargs.update({key: secret_value})
|
lc_kwargs.update({key: secret_value})
|
||||||
@ -357,8 +358,6 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
|
|||||||
"id": _id,
|
"id": _id,
|
||||||
"repr": None,
|
"repr": None,
|
||||||
}
|
}
|
||||||
try:
|
with contextlib.suppress(Exception):
|
||||||
result["repr"] = repr(obj)
|
result["repr"] = repr(obj)
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return result
|
return result
|
||||||
|
@ -435,23 +435,22 @@ def filter_messages(
|
|||||||
messages = convert_to_messages(messages)
|
messages = convert_to_messages(messages)
|
||||||
filtered: list[BaseMessage] = []
|
filtered: list[BaseMessage] = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if exclude_names and msg.name in exclude_names:
|
if (
|
||||||
continue
|
(exclude_names and msg.name in exclude_names)
|
||||||
elif exclude_types and _is_message_type(msg, exclude_types):
|
or (exclude_types and _is_message_type(msg, exclude_types))
|
||||||
continue
|
or (exclude_ids and msg.id in exclude_ids)
|
||||||
elif exclude_ids and msg.id in exclude_ids:
|
):
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# default to inclusion when no inclusion criteria given.
|
# default to inclusion when no inclusion criteria given.
|
||||||
if not (include_types or include_ids or include_names):
|
if (
|
||||||
filtered.append(msg)
|
not (include_types or include_ids or include_names)
|
||||||
elif include_names and msg.name in include_names:
|
or (include_names and msg.name in include_names)
|
||||||
filtered.append(msg)
|
or (include_types and _is_message_type(msg, include_types))
|
||||||
elif include_types and _is_message_type(msg, include_types):
|
or (include_ids and msg.id in include_ids)
|
||||||
filtered.append(msg)
|
):
|
||||||
elif include_ids and msg.id in include_ids:
|
|
||||||
filtered.append(msg)
|
filtered.append(msg)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
@ -961,10 +960,7 @@ def _last_max_tokens(
|
|||||||
while messages and not _is_message_type(messages[-1], end_on):
|
while messages and not _is_message_type(messages[-1], end_on):
|
||||||
messages.pop()
|
messages.pop()
|
||||||
swapped_system = include_system and isinstance(messages[0], SystemMessage)
|
swapped_system = include_system and isinstance(messages[0], SystemMessage)
|
||||||
if swapped_system:
|
reversed_ = messages[:1] + messages[1:][::-1] if swapped_system else messages[::-1]
|
||||||
reversed_ = messages[:1] + messages[1:][::-1]
|
|
||||||
else:
|
|
||||||
reversed_ = messages[::-1]
|
|
||||||
|
|
||||||
reversed_ = _first_max_tokens(
|
reversed_ = _first_max_tokens(
|
||||||
reversed_,
|
reversed_,
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
@ -311,8 +312,6 @@ class BaseOutputParser(
|
|||||||
def dict(self, **kwargs: Any) -> dict:
|
def dict(self, **kwargs: Any) -> dict:
|
||||||
"""Return dictionary representation of output parser."""
|
"""Return dictionary representation of output parser."""
|
||||||
output_parser_dict = super().dict(**kwargs)
|
output_parser_dict = super().dict(**kwargs)
|
||||||
try:
|
with contextlib.suppress(NotImplementedError):
|
||||||
output_parser_dict["_type"] = self._type
|
output_parser_dict["_type"] = self._type
|
||||||
except NotImplementedError:
|
|
||||||
pass
|
|
||||||
return output_parser_dict
|
return output_parser_dict
|
||||||
|
@ -49,12 +49,10 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
return jsonpatch.make_patch(prev, next).patch
|
return jsonpatch.make_patch(prev, next).patch
|
||||||
|
|
||||||
def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
|
def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
|
||||||
if PYDANTIC_MAJOR_VERSION == 2:
|
if issubclass(pydantic_object, pydantic.BaseModel):
|
||||||
if issubclass(pydantic_object, pydantic.BaseModel):
|
return pydantic_object.model_json_schema()
|
||||||
return pydantic_object.model_json_schema()
|
elif issubclass(pydantic_object, pydantic.v1.BaseModel):
|
||||||
elif issubclass(pydantic_object, pydantic.v1.BaseModel):
|
return pydantic_object.schema()
|
||||||
return pydantic_object.model_json_schema()
|
|
||||||
return pydantic_object.model_json_schema()
|
|
||||||
|
|
||||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||||
"""Parse the result of an LLM call to a JSON object.
|
"""Parse the result of an LLM call to a JSON object.
|
||||||
|
@ -110,10 +110,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
|
|
||||||
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
||||||
prev_parsed = None
|
prev_parsed = None
|
||||||
acc_gen = None
|
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
|
||||||
for chunk in input:
|
for chunk in input:
|
||||||
|
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
|
||||||
if isinstance(chunk, BaseMessageChunk):
|
if isinstance(chunk, BaseMessageChunk):
|
||||||
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
|
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||||
elif isinstance(chunk, BaseMessage):
|
elif isinstance(chunk, BaseMessage):
|
||||||
chunk_gen = ChatGenerationChunk(
|
chunk_gen = ChatGenerationChunk(
|
||||||
message=BaseMessageChunk(**chunk.dict())
|
message=BaseMessageChunk(**chunk.dict())
|
||||||
@ -121,10 +122,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
else:
|
else:
|
||||||
chunk_gen = GenerationChunk(text=chunk)
|
chunk_gen = GenerationChunk(text=chunk)
|
||||||
|
|
||||||
if acc_gen is None:
|
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
|
||||||
acc_gen = chunk_gen
|
|
||||||
else:
|
|
||||||
acc_gen = acc_gen + chunk_gen
|
|
||||||
|
|
||||||
parsed = self.parse_result([acc_gen], partial=True)
|
parsed = self.parse_result([acc_gen], partial=True)
|
||||||
if parsed is not None and parsed != prev_parsed:
|
if parsed is not None and parsed != prev_parsed:
|
||||||
@ -138,10 +136,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||||
) -> AsyncIterator[T]:
|
) -> AsyncIterator[T]:
|
||||||
prev_parsed = None
|
prev_parsed = None
|
||||||
acc_gen = None
|
acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
|
||||||
async for chunk in input:
|
async for chunk in input:
|
||||||
|
chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
|
||||||
if isinstance(chunk, BaseMessageChunk):
|
if isinstance(chunk, BaseMessageChunk):
|
||||||
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
|
chunk_gen = ChatGenerationChunk(message=chunk)
|
||||||
elif isinstance(chunk, BaseMessage):
|
elif isinstance(chunk, BaseMessage):
|
||||||
chunk_gen = ChatGenerationChunk(
|
chunk_gen = ChatGenerationChunk(
|
||||||
message=BaseMessageChunk(**chunk.dict())
|
message=BaseMessageChunk(**chunk.dict())
|
||||||
@ -149,10 +148,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|||||||
else:
|
else:
|
||||||
chunk_gen = GenerationChunk(text=chunk)
|
chunk_gen = GenerationChunk(text=chunk)
|
||||||
|
|
||||||
if acc_gen is None:
|
acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator]
|
||||||
acc_gen = chunk_gen
|
|
||||||
else:
|
|
||||||
acc_gen = acc_gen + chunk_gen
|
|
||||||
|
|
||||||
parsed = await self.aparse_result([acc_gen], partial=True)
|
parsed = await self.aparse_result([acc_gen], partial=True)
|
||||||
if parsed is not None and parsed != prev_parsed:
|
if parsed is not None and parsed != prev_parsed:
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import contextlib
|
||||||
import re
|
import re
|
||||||
import xml
|
import xml
|
||||||
import xml.etree.ElementTree as ET # noqa: N817
|
import xml.etree.ElementTree as ET # noqa: N817
|
||||||
@ -131,11 +132,9 @@ class _StreamingParser:
|
|||||||
Raises:
|
Raises:
|
||||||
xml.etree.ElementTree.ParseError: If the XML is not well-formed.
|
xml.etree.ElementTree.ParseError: If the XML is not well-formed.
|
||||||
"""
|
"""
|
||||||
try:
|
# Ignore ParseError. This will ignore any incomplete XML at the end of the input
|
||||||
|
with contextlib.suppress(xml.etree.ElementTree.ParseError):
|
||||||
self.pull_parser.close()
|
self.pull_parser.close()
|
||||||
except xml.etree.ElementTree.ParseError:
|
|
||||||
# Ignore. This will ignore any incomplete XML at the end of the input
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class XMLOutputParser(BaseTransformOutputParser):
|
class XMLOutputParser(BaseTransformOutputParser):
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import typing
|
import typing
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@ -319,10 +320,8 @@ class BasePromptTemplate(
|
|||||||
NotImplementedError: If the prompt type is not implemented.
|
NotImplementedError: If the prompt type is not implemented.
|
||||||
"""
|
"""
|
||||||
prompt_dict = super().model_dump(**kwargs)
|
prompt_dict = super().model_dump(**kwargs)
|
||||||
try:
|
with contextlib.suppress(NotImplementedError):
|
||||||
prompt_dict["_type"] = self._prompt_type
|
prompt_dict["_type"] = self._prompt_type
|
||||||
except NotImplementedError:
|
|
||||||
pass
|
|
||||||
return prompt_dict
|
return prompt_dict
|
||||||
|
|
||||||
def save(self, file_path: Union[Path, str]) -> None:
|
def save(self, file_path: Union[Path, str]) -> None:
|
||||||
@ -350,10 +349,7 @@ class BasePromptTemplate(
|
|||||||
raise NotImplementedError(f"Prompt {self} does not support saving.")
|
raise NotImplementedError(f"Prompt {self} does not support saving.")
|
||||||
|
|
||||||
# Convert file to Path object.
|
# Convert file to Path object.
|
||||||
if isinstance(file_path, str):
|
save_path = Path(file_path) if isinstance(file_path, str) else file_path
|
||||||
save_path = Path(file_path)
|
|
||||||
else:
|
|
||||||
save_path = file_path
|
|
||||||
|
|
||||||
directory_path = save_path.parent
|
directory_path = save_path.parent
|
||||||
directory_path.mkdir(parents=True, exist_ok=True)
|
directory_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -59,8 +59,8 @@ class _FewShotPromptTemplateMixin(BaseModel):
|
|||||||
ValueError: If neither or both examples and example_selector are provided.
|
ValueError: If neither or both examples and example_selector are provided.
|
||||||
ValueError: If both examples and example_selector are provided.
|
ValueError: If both examples and example_selector are provided.
|
||||||
"""
|
"""
|
||||||
examples = values.get("examples", None)
|
examples = values.get("examples")
|
||||||
example_selector = values.get("example_selector", None)
|
example_selector = values.get("example_selector")
|
||||||
if examples and example_selector:
|
if examples and example_selector:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Only one of 'examples' and 'example_selector' should be provided"
|
"Only one of 'examples' and 'example_selector' should be provided"
|
||||||
|
@ -51,8 +51,8 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def check_examples_and_selector(cls, values: dict) -> Any:
|
def check_examples_and_selector(cls, values: dict) -> Any:
|
||||||
"""Check that one and only one of examples/example_selector are provided."""
|
"""Check that one and only one of examples/example_selector are provided."""
|
||||||
examples = values.get("examples", None)
|
examples = values.get("examples")
|
||||||
example_selector = values.get("example_selector", None)
|
example_selector = values.get("example_selector")
|
||||||
if examples and example_selector:
|
if examples and example_selector:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Only one of 'examples' and 'example_selector' should be provided"
|
"Only one of 'examples' and 'example_selector' should be provided"
|
||||||
@ -138,7 +138,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
|||||||
prefix_kwargs = {
|
prefix_kwargs = {
|
||||||
k: v for k, v in kwargs.items() if k in self.prefix.input_variables
|
k: v for k, v in kwargs.items() if k in self.prefix.input_variables
|
||||||
}
|
}
|
||||||
for k in prefix_kwargs.keys():
|
for k in prefix_kwargs:
|
||||||
kwargs.pop(k)
|
kwargs.pop(k)
|
||||||
prefix = self.prefix.format(**prefix_kwargs)
|
prefix = self.prefix.format(**prefix_kwargs)
|
||||||
|
|
||||||
@ -146,7 +146,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
|||||||
suffix_kwargs = {
|
suffix_kwargs = {
|
||||||
k: v for k, v in kwargs.items() if k in self.suffix.input_variables
|
k: v for k, v in kwargs.items() if k in self.suffix.input_variables
|
||||||
}
|
}
|
||||||
for k in suffix_kwargs.keys():
|
for k in suffix_kwargs:
|
||||||
kwargs.pop(k)
|
kwargs.pop(k)
|
||||||
suffix = self.suffix.format(
|
suffix = self.suffix.format(
|
||||||
**suffix_kwargs,
|
**suffix_kwargs,
|
||||||
@ -182,7 +182,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
|||||||
prefix_kwargs = {
|
prefix_kwargs = {
|
||||||
k: v for k, v in kwargs.items() if k in self.prefix.input_variables
|
k: v for k, v in kwargs.items() if k in self.prefix.input_variables
|
||||||
}
|
}
|
||||||
for k in prefix_kwargs.keys():
|
for k in prefix_kwargs:
|
||||||
kwargs.pop(k)
|
kwargs.pop(k)
|
||||||
prefix = await self.prefix.aformat(**prefix_kwargs)
|
prefix = await self.prefix.aformat(**prefix_kwargs)
|
||||||
|
|
||||||
@ -190,7 +190,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
|||||||
suffix_kwargs = {
|
suffix_kwargs = {
|
||||||
k: v for k, v in kwargs.items() if k in self.suffix.input_variables
|
k: v for k, v in kwargs.items() if k in self.suffix.input_variables
|
||||||
}
|
}
|
||||||
for k in suffix_kwargs.keys():
|
for k in suffix_kwargs:
|
||||||
kwargs.pop(k)
|
kwargs.pop(k)
|
||||||
suffix = await self.suffix.aformat(
|
suffix = await self.suffix.aformat(
|
||||||
**suffix_kwargs,
|
**suffix_kwargs,
|
||||||
|
@ -164,10 +164,7 @@ def _load_prompt_from_file(
|
|||||||
) -> BasePromptTemplate:
|
) -> BasePromptTemplate:
|
||||||
"""Load prompt from file."""
|
"""Load prompt from file."""
|
||||||
# Convert file to a Path object.
|
# Convert file to a Path object.
|
||||||
if isinstance(file, str):
|
file_path = Path(file) if isinstance(file, str) else file
|
||||||
file_path = Path(file)
|
|
||||||
else:
|
|
||||||
file_path = file
|
|
||||||
# Load from either json or yaml.
|
# Load from either json or yaml.
|
||||||
if file_path.suffix == ".json":
|
if file_path.suffix == ".json":
|
||||||
with open(file_path, encoding=encoding) as f:
|
with open(file_path, encoding=encoding) as f:
|
||||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import collections
|
import collections
|
||||||
|
import contextlib
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import threading
|
import threading
|
||||||
@ -2466,10 +2467,8 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
|||||||
A JSON-serializable representation of the Runnable.
|
A JSON-serializable representation of the Runnable.
|
||||||
"""
|
"""
|
||||||
dumped = super().to_json()
|
dumped = super().to_json()
|
||||||
try:
|
with contextlib.suppress(Exception):
|
||||||
dumped["name"] = self.get_name()
|
dumped["name"] = self.get_name()
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return dumped
|
return dumped
|
||||||
|
|
||||||
def configurable_fields(
|
def configurable_fields(
|
||||||
@ -2763,9 +2762,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
ValueError: If the sequence has less than 2 steps.
|
ValueError: If the sequence has less than 2 steps.
|
||||||
"""
|
"""
|
||||||
steps_flat: list[Runnable] = []
|
steps_flat: list[Runnable] = []
|
||||||
if not steps:
|
if not steps and first is not None and last is not None:
|
||||||
if first is not None and last is not None:
|
steps_flat = [first] + (middle or []) + [last]
|
||||||
steps_flat = [first] + (middle or []) + [last]
|
|
||||||
for step in steps:
|
for step in steps:
|
||||||
if isinstance(step, RunnableSequence):
|
if isinstance(step, RunnableSequence):
|
||||||
steps_flat.extend(step.steps)
|
steps_flat.extend(step.steps)
|
||||||
@ -4180,12 +4178,9 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
final = None
|
final: Optional[Output] = None
|
||||||
for output in self.stream(input, config, **kwargs):
|
for output in self.stream(input, config, **kwargs):
|
||||||
if final is None:
|
final = output if final is None else final + output # type: ignore[operator]
|
||||||
final = output
|
|
||||||
else:
|
|
||||||
final = final + output
|
|
||||||
return cast(Output, final)
|
return cast(Output, final)
|
||||||
|
|
||||||
def atransform(
|
def atransform(
|
||||||
@ -4215,12 +4210,9 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
final = None
|
final: Optional[Output] = None
|
||||||
async for output in self.astream(input, config, **kwargs):
|
async for output in self.astream(input, config, **kwargs):
|
||||||
if final is None:
|
final = output if final is None else final + output # type: ignore[operator]
|
||||||
final = output
|
|
||||||
else:
|
|
||||||
final = final + output
|
|
||||||
return cast(Output, final)
|
return cast(Output, final)
|
||||||
|
|
||||||
|
|
||||||
|
@ -139,11 +139,11 @@ def _set_config_context(config: RunnableConfig) -> None:
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
and (run := tracer.run_map.get(str(parent_run_id)))
|
||||||
):
|
):
|
||||||
if run := tracer.run_map.get(str(parent_run_id)):
|
from langsmith.run_helpers import _set_tracing_context
|
||||||
from langsmith.run_helpers import _set_tracing_context
|
|
||||||
|
|
||||||
_set_tracing_context({"parent": run})
|
_set_tracing_context({"parent": run})
|
||||||
|
|
||||||
|
|
||||||
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||||
|
@ -99,24 +99,15 @@ class AsciiCanvas:
|
|||||||
self.point(x0, y0, char)
|
self.point(x0, y0, char)
|
||||||
elif abs(dx) >= abs(dy):
|
elif abs(dx) >= abs(dy):
|
||||||
for x in range(x0, x1 + 1):
|
for x in range(x0, x1 + 1):
|
||||||
if dx == 0:
|
y = y0 if dx == 0 else y0 + int(round((x - x0) * dy / float(dx)))
|
||||||
y = y0
|
|
||||||
else:
|
|
||||||
y = y0 + int(round((x - x0) * dy / float(dx)))
|
|
||||||
self.point(x, y, char)
|
self.point(x, y, char)
|
||||||
elif y0 < y1:
|
elif y0 < y1:
|
||||||
for y in range(y0, y1 + 1):
|
for y in range(y0, y1 + 1):
|
||||||
if dy == 0:
|
x = x0 if dy == 0 else x0 + int(round((y - y0) * dx / float(dy)))
|
||||||
x = x0
|
|
||||||
else:
|
|
||||||
x = x0 + int(round((y - y0) * dx / float(dy)))
|
|
||||||
self.point(x, y, char)
|
self.point(x, y, char)
|
||||||
else:
|
else:
|
||||||
for y in range(y1, y0 + 1):
|
for y in range(y1, y0 + 1):
|
||||||
if dy == 0:
|
x = x0 if dy == 0 else x1 + int(round((y - y1) * dx / float(dy)))
|
||||||
x = x0
|
|
||||||
else:
|
|
||||||
x = x1 + int(round((y - y1) * dx / float(dy)))
|
|
||||||
self.point(x, y, char)
|
self.point(x, y, char)
|
||||||
|
|
||||||
def text(self, x: int, y: int, text: str) -> None:
|
def text(self, x: int, y: int, text: str) -> None:
|
||||||
|
@ -131,10 +131,7 @@ def draw_mermaid(
|
|||||||
else:
|
else:
|
||||||
edge_label = f" -- {edge_data} --> "
|
edge_label = f" -- {edge_data} --> "
|
||||||
else:
|
else:
|
||||||
if edge.conditional:
|
edge_label = " -.-> " if edge.conditional else " --> "
|
||||||
edge_label = " -.-> "
|
|
||||||
else:
|
|
||||||
edge_label = " --> "
|
|
||||||
|
|
||||||
mermaid_graph += (
|
mermaid_graph += (
|
||||||
f"\t{_escape_node_label(source)}{edge_label}"
|
f"\t{_escape_node_label(source)}{edge_label}"
|
||||||
@ -142,7 +139,7 @@ def draw_mermaid(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Recursively add nested subgraphs
|
# Recursively add nested subgraphs
|
||||||
for nested_prefix in edge_groups.keys():
|
for nested_prefix in edge_groups:
|
||||||
if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix:
|
if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix:
|
||||||
continue
|
continue
|
||||||
add_subgraph(edge_groups[nested_prefix], nested_prefix)
|
add_subgraph(edge_groups[nested_prefix], nested_prefix)
|
||||||
@ -154,7 +151,7 @@ def draw_mermaid(
|
|||||||
add_subgraph(edge_groups.get("", []), "")
|
add_subgraph(edge_groups.get("", []), "")
|
||||||
|
|
||||||
# Add remaining subgraphs
|
# Add remaining subgraphs
|
||||||
for prefix in edge_groups.keys():
|
for prefix in edge_groups:
|
||||||
if ":" in prefix or prefix == "":
|
if ":" in prefix or prefix == "":
|
||||||
continue
|
continue
|
||||||
add_subgraph(edge_groups[prefix], prefix)
|
add_subgraph(edge_groups[prefix], prefix)
|
||||||
|
@ -496,12 +496,9 @@ def add(addables: Iterable[Addable]) -> Optional[Addable]:
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[Addable]: The result of adding the addable objects.
|
Optional[Addable]: The result of adding the addable objects.
|
||||||
"""
|
"""
|
||||||
final = None
|
final: Optional[Addable] = None
|
||||||
for chunk in addables:
|
for chunk in addables:
|
||||||
if final is None:
|
final = chunk if final is None else final + chunk
|
||||||
final = chunk
|
|
||||||
else:
|
|
||||||
final = final + chunk
|
|
||||||
return final
|
return final
|
||||||
|
|
||||||
|
|
||||||
@ -514,12 +511,9 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[Addable]: The result of adding the addable objects.
|
Optional[Addable]: The result of adding the addable objects.
|
||||||
"""
|
"""
|
||||||
final = None
|
final: Optional[Addable] = None
|
||||||
async for chunk in addables:
|
async for chunk in addables:
|
||||||
if final is None:
|
final = chunk if final is None else final + chunk
|
||||||
final = chunk
|
|
||||||
else:
|
|
||||||
final = final + chunk
|
|
||||||
return final
|
return final
|
||||||
|
|
||||||
|
|
||||||
@ -642,9 +636,7 @@ def get_unique_config_specs(
|
|||||||
for id, dupes in grouped:
|
for id, dupes in grouped:
|
||||||
first = next(dupes)
|
first = next(dupes)
|
||||||
others = list(dupes)
|
others = list(dupes)
|
||||||
if len(others) == 0:
|
if len(others) == 0 or all(o == first for o in others):
|
||||||
unique.append(first)
|
|
||||||
elif all(o == first for o in others):
|
|
||||||
unique.append(first)
|
unique.append(first)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -258,7 +258,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
|||||||
if prefix is None:
|
if prefix is None:
|
||||||
yield from self.store.keys()
|
yield from self.store.keys()
|
||||||
else:
|
else:
|
||||||
for key in self.store.keys():
|
for key in self.store:
|
||||||
if key.startswith(prefix):
|
if key.startswith(prefix):
|
||||||
yield key
|
yield key
|
||||||
|
|
||||||
@ -272,10 +272,10 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
|||||||
AsyncIterator[str]: An async iterator over keys that match the given prefix.
|
AsyncIterator[str]: An async iterator over keys that match the given prefix.
|
||||||
"""
|
"""
|
||||||
if prefix is None:
|
if prefix is None:
|
||||||
for key in self.store.keys():
|
for key in self.store:
|
||||||
yield key
|
yield key
|
||||||
else:
|
else:
|
||||||
for key in self.store.keys():
|
for key in self.store:
|
||||||
if key.startswith(prefix):
|
if key.startswith(prefix):
|
||||||
yield key
|
yield key
|
||||||
|
|
||||||
|
@ -19,18 +19,24 @@ class Visitor(ABC):
|
|||||||
"""Allowed operators for the visitor."""
|
"""Allowed operators for the visitor."""
|
||||||
|
|
||||||
def _validate_func(self, func: Union[Operator, Comparator]) -> None:
|
def _validate_func(self, func: Union[Operator, Comparator]) -> None:
|
||||||
if isinstance(func, Operator) and self.allowed_operators is not None:
|
if (
|
||||||
if func not in self.allowed_operators:
|
isinstance(func, Operator)
|
||||||
raise ValueError(
|
and self.allowed_operators is not None
|
||||||
f"Received disallowed operator {func}. Allowed "
|
and func not in self.allowed_operators
|
||||||
f"comparators are {self.allowed_operators}"
|
):
|
||||||
)
|
raise ValueError(
|
||||||
if isinstance(func, Comparator) and self.allowed_comparators is not None:
|
f"Received disallowed operator {func}. Allowed "
|
||||||
if func not in self.allowed_comparators:
|
f"comparators are {self.allowed_operators}"
|
||||||
raise ValueError(
|
)
|
||||||
f"Received disallowed comparator {func}. Allowed "
|
if (
|
||||||
f"comparators are {self.allowed_comparators}"
|
isinstance(func, Comparator)
|
||||||
)
|
and self.allowed_comparators is not None
|
||||||
|
and func not in self.allowed_comparators
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Received disallowed comparator {func}. Allowed "
|
||||||
|
f"comparators are {self.allowed_comparators}"
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visit_operation(self, operation: Operation) -> Any:
|
def visit_operation(self, operation: Operation) -> Any:
|
||||||
|
@ -248,11 +248,8 @@ def create_schema_from_function(
|
|||||||
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
|
||||||
|
|
||||||
# Let's ignore `self` and `cls` arguments for class and instance methods
|
# Let's ignore `self` and `cls` arguments for class and instance methods
|
||||||
if func.__qualname__ and "." in func.__qualname__:
|
# If qualified name has a ".", then it likely belongs in a class namespace
|
||||||
# Then it likely belongs in a class namespace
|
in_class = bool(func.__qualname__ and "." in func.__qualname__)
|
||||||
in_class = True
|
|
||||||
else:
|
|
||||||
in_class = False
|
|
||||||
|
|
||||||
has_args = False
|
has_args = False
|
||||||
has_kwargs = False
|
has_kwargs = False
|
||||||
@ -289,12 +286,10 @@ def create_schema_from_function(
|
|||||||
# Pydantic adds placeholder virtual fields we need to strip
|
# Pydantic adds placeholder virtual fields we need to strip
|
||||||
valid_properties = []
|
valid_properties = []
|
||||||
for field in get_fields(inferred_model):
|
for field in get_fields(inferred_model):
|
||||||
if not has_args:
|
if not has_args and field == "args":
|
||||||
if field == "args":
|
continue
|
||||||
continue
|
if not has_kwargs and field == "kwargs":
|
||||||
if not has_kwargs:
|
continue
|
||||||
if field == "kwargs":
|
|
||||||
continue
|
|
||||||
|
|
||||||
if field == "v__duplicate_kwargs": # Internal pydantic field
|
if field == "v__duplicate_kwargs": # Internal pydantic field
|
||||||
continue
|
continue
|
||||||
@ -422,12 +417,15 @@ class ChildTool(BaseTool):
|
|||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
"""Initialize the tool."""
|
"""Initialize the tool."""
|
||||||
if "args_schema" in kwargs and kwargs["args_schema"] is not None:
|
if (
|
||||||
if not is_basemodel_subclass(kwargs["args_schema"]):
|
"args_schema" in kwargs
|
||||||
raise TypeError(
|
and kwargs["args_schema"] is not None
|
||||||
f"args_schema must be a subclass of pydantic BaseModel. "
|
and not is_basemodel_subclass(kwargs["args_schema"])
|
||||||
f"Got: {kwargs['args_schema']}."
|
):
|
||||||
)
|
raise TypeError(
|
||||||
|
f"args_schema must be a subclass of pydantic BaseModel. "
|
||||||
|
f"Got: {kwargs['args_schema']}."
|
||||||
|
)
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
@ -840,10 +838,7 @@ def _handle_tool_error(
|
|||||||
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
|
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
|
||||||
) -> str:
|
) -> str:
|
||||||
if isinstance(flag, bool):
|
if isinstance(flag, bool):
|
||||||
if e.args:
|
content = e.args[0] if e.args else "Tool execution error"
|
||||||
content = e.args[0]
|
|
||||||
else:
|
|
||||||
content = "Tool execution error"
|
|
||||||
elif isinstance(flag, str):
|
elif isinstance(flag, str):
|
||||||
content = flag
|
content = flag
|
||||||
elif callable(flag):
|
elif callable(flag):
|
||||||
@ -902,12 +897,11 @@ def _format_output(
|
|||||||
|
|
||||||
def _is_message_content_type(obj: Any) -> bool:
|
def _is_message_content_type(obj: Any) -> bool:
|
||||||
"""Check for OpenAI or Anthropic format tool message content."""
|
"""Check for OpenAI or Anthropic format tool message content."""
|
||||||
if isinstance(obj, str):
|
return (
|
||||||
return True
|
isinstance(obj, str)
|
||||||
elif isinstance(obj, list) and all(_is_message_content_block(e) for e in obj):
|
or isinstance(obj, list)
|
||||||
return True
|
and all(_is_message_content_block(e) for e in obj)
|
||||||
else:
|
)
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_message_content_block(obj: Any) -> bool:
|
def _is_message_content_block(obj: Any) -> bool:
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -814,10 +815,7 @@ async def _astream_events_implementation_v1(
|
|||||||
data: EventData = {}
|
data: EventData = {}
|
||||||
log_entry: LogEntry = run_log.state["logs"][path]
|
log_entry: LogEntry = run_log.state["logs"][path]
|
||||||
if log_entry["end_time"] is None:
|
if log_entry["end_time"] is None:
|
||||||
if log_entry["streamed_output"]:
|
event_type = "stream" if log_entry["streamed_output"] else "start"
|
||||||
event_type = "stream"
|
|
||||||
else:
|
|
||||||
event_type = "start"
|
|
||||||
else:
|
else:
|
||||||
event_type = "end"
|
event_type = "end"
|
||||||
|
|
||||||
@ -983,14 +981,15 @@ async def _astream_events_implementation_v2(
|
|||||||
yield event
|
yield event
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if event["run_id"] == first_event_run_id and event["event"].endswith(
|
# If it's the end event corresponding to the root runnable
|
||||||
"_end"
|
# we dont include the input in the event since it's guaranteed
|
||||||
|
# to be included in the first event.
|
||||||
|
if (
|
||||||
|
event["run_id"] == first_event_run_id
|
||||||
|
and event["event"].endswith("_end")
|
||||||
|
and "input" in event["data"]
|
||||||
):
|
):
|
||||||
# If it's the end event corresponding to the root runnable
|
del event["data"]["input"]
|
||||||
# we dont include the input in the event since it's guaranteed
|
|
||||||
# to be included in the first event.
|
|
||||||
if "input" in event["data"]:
|
|
||||||
del event["data"]["input"]
|
|
||||||
|
|
||||||
yield event
|
yield event
|
||||||
except asyncio.CancelledError as exc:
|
except asyncio.CancelledError as exc:
|
||||||
@ -1001,7 +1000,5 @@ async def _astream_events_implementation_v2(
|
|||||||
# Cancel the task if it's still running
|
# Cancel the task if it's still running
|
||||||
task.cancel()
|
task.cancel()
|
||||||
# Await it anyway, to run any cleanup code, and propagate any exceptions
|
# Await it anyway, to run any cleanup code, and propagate any exceptions
|
||||||
try:
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
await task
|
await task
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import threading
|
import threading
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
@ -253,18 +254,22 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
|
|||||||
"""
|
"""
|
||||||
async for chunk in output:
|
async for chunk in output:
|
||||||
# root run is handled in .astream_log()
|
# root run is handled in .astream_log()
|
||||||
if run_id != self.root_id:
|
# if we can't find the run silently ignore
|
||||||
# if we can't find the run silently ignore
|
# eg. because this run wasn't included in the log
|
||||||
# eg. because this run wasn't included in the log
|
if (
|
||||||
if key := self._key_map_by_run_id.get(run_id):
|
run_id != self.root_id
|
||||||
if not self.send(
|
and (key := self._key_map_by_run_id.get(run_id))
|
||||||
|
and (
|
||||||
|
not self.send(
|
||||||
{
|
{
|
||||||
"op": "add",
|
"op": "add",
|
||||||
"path": f"/logs/{key}/streamed_output/-",
|
"path": f"/logs/{key}/streamed_output/-",
|
||||||
"value": chunk,
|
"value": chunk,
|
||||||
}
|
}
|
||||||
):
|
)
|
||||||
break
|
)
|
||||||
|
):
|
||||||
|
break
|
||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
@ -280,18 +285,22 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
|
|||||||
"""
|
"""
|
||||||
for chunk in output:
|
for chunk in output:
|
||||||
# root run is handled in .astream_log()
|
# root run is handled in .astream_log()
|
||||||
if run_id != self.root_id:
|
# if we can't find the run silently ignore
|
||||||
# if we can't find the run silently ignore
|
# eg. because this run wasn't included in the log
|
||||||
# eg. because this run wasn't included in the log
|
if (
|
||||||
if key := self._key_map_by_run_id.get(run_id):
|
run_id != self.root_id
|
||||||
if not self.send(
|
and (key := self._key_map_by_run_id.get(run_id))
|
||||||
|
and (
|
||||||
|
not self.send(
|
||||||
{
|
{
|
||||||
"op": "add",
|
"op": "add",
|
||||||
"path": f"/logs/{key}/streamed_output/-",
|
"path": f"/logs/{key}/streamed_output/-",
|
||||||
"value": chunk,
|
"value": chunk,
|
||||||
}
|
}
|
||||||
):
|
)
|
||||||
break
|
)
|
||||||
|
):
|
||||||
|
break
|
||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
@ -439,9 +448,8 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
|
|||||||
|
|
||||||
self.send(*ops)
|
self.send(*ops)
|
||||||
finally:
|
finally:
|
||||||
if run.id == self.root_id:
|
if run.id == self.root_id and self.auto_close:
|
||||||
if self.auto_close:
|
self.send_stream.close()
|
||||||
self.send_stream.close()
|
|
||||||
|
|
||||||
def _on_llm_new_token(
|
def _on_llm_new_token(
|
||||||
self,
|
self,
|
||||||
@ -662,7 +670,5 @@ async def _astream_log_implementation(
|
|||||||
yield state
|
yield state
|
||||||
finally:
|
finally:
|
||||||
# Wait for the runnable to finish, if not cancelled (eg. by break)
|
# Wait for the runnable to finish, if not cancelled (eg. by break)
|
||||||
try:
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
await task
|
await task
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
@ -29,9 +29,7 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]
|
|||||||
merged = left.copy()
|
merged = left.copy()
|
||||||
for right in others:
|
for right in others:
|
||||||
for right_k, right_v in right.items():
|
for right_k, right_v in right.items():
|
||||||
if right_k not in merged:
|
if right_k not in merged or right_v is not None and merged[right_k] is None:
|
||||||
merged[right_k] = right_v
|
|
||||||
elif right_v is not None and merged[right_k] is None:
|
|
||||||
merged[right_k] = right_v
|
merged[right_k] = right_v
|
||||||
elif right_v is None:
|
elif right_v is None:
|
||||||
continue
|
continue
|
||||||
|
@ -43,14 +43,10 @@ def get_from_dict_or_env(
|
|||||||
if k in data and data[k]:
|
if k in data and data[k]:
|
||||||
return data[k]
|
return data[k]
|
||||||
|
|
||||||
if isinstance(key, str):
|
if isinstance(key, str) and key in data and data[key]:
|
||||||
if key in data and data[key]:
|
return data[key]
|
||||||
return data[key]
|
|
||||||
|
|
||||||
if isinstance(key, (list, tuple)):
|
key_for_err = key[0] if isinstance(key, (list, tuple)) else key
|
||||||
key_for_err = key[0]
|
|
||||||
else:
|
|
||||||
key_for_err = key
|
|
||||||
|
|
||||||
return get_from_env(key_for_err, env_key, default=default)
|
return get_from_env(key_for_err, env_key, default=default)
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
|
|||||||
new_kv = {}
|
new_kv = {}
|
||||||
for k, v in kv.items():
|
for k, v in kv.items():
|
||||||
if k == "title":
|
if k == "title":
|
||||||
if isinstance(v, dict) and prev_key == "properties" and "title" in v.keys():
|
if isinstance(v, dict) and prev_key == "properties" and "title" in v:
|
||||||
new_kv[k] = _rm_titles(v, k)
|
new_kv[k] = _rm_titles(v, k)
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
@ -139,11 +139,8 @@ def parse_json_markdown(
|
|||||||
match = _json_markdown_re.search(json_string)
|
match = _json_markdown_re.search(json_string)
|
||||||
|
|
||||||
# If no match found, assume the entire string is a JSON string
|
# If no match found, assume the entire string is a JSON string
|
||||||
if match is None:
|
# Else, use the content within the backticks
|
||||||
json_str = json_string
|
json_str = json_string if match is None else match.group(2)
|
||||||
else:
|
|
||||||
# If match found, use the content within the backticks
|
|
||||||
json_str = match.group(2)
|
|
||||||
return _parse_json(json_str, parser=parser)
|
return _parse_json(json_str, parser=parser)
|
||||||
|
|
||||||
|
|
||||||
|
@ -80,12 +80,9 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
|
|||||||
padding = literal.split("\n")[-1]
|
padding = literal.split("\n")[-1]
|
||||||
|
|
||||||
# If all the characters since the last newline are spaces
|
# If all the characters since the last newline are spaces
|
||||||
if padding.isspace() or padding == "":
|
# Then the next tag could be a standalone
|
||||||
# Then the next tag could be a standalone
|
# Otherwise it can't be
|
||||||
return True
|
return padding.isspace() or padding == ""
|
||||||
else:
|
|
||||||
# Otherwise it can't be
|
|
||||||
return False
|
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -107,10 +104,7 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
|
|||||||
on_newline = template.split("\n", 1)
|
on_newline = template.split("\n", 1)
|
||||||
|
|
||||||
# If the stuff to the right of us are spaces we're a standalone
|
# If the stuff to the right of us are spaces we're a standalone
|
||||||
if on_newline[0].isspace() or not on_newline[0]:
|
return on_newline[0].isspace() or not on_newline[0]
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# If we're a tag can't be a standalone
|
# If we're a tag can't be a standalone
|
||||||
else:
|
else:
|
||||||
@ -174,14 +168,18 @@ def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], s
|
|||||||
"unclosed set delimiter tag\n" f"at line {_CURRENT_LINE}"
|
"unclosed set delimiter tag\n" f"at line {_CURRENT_LINE}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# If we might be a no html escape tag
|
elif (
|
||||||
elif tag_type == "no escape?":
|
# If we might be a no html escape tag
|
||||||
|
tag_type == "no escape?"
|
||||||
# And we have a third curly brace
|
# And we have a third curly brace
|
||||||
# (And are using curly braces as delimiters)
|
# (And are using curly braces as delimiters)
|
||||||
if l_del == "{{" and r_del == "}}" and template.startswith("}"):
|
and l_del == "{{"
|
||||||
# Then we are a no html escape tag
|
and r_del == "}}"
|
||||||
template = template[1:]
|
and template.startswith("}")
|
||||||
tag_type = "no escape"
|
):
|
||||||
|
# Then we are a no html escape tag
|
||||||
|
template = template[1:]
|
||||||
|
tag_type = "no escape"
|
||||||
|
|
||||||
# Strip the whitespace off the key and return
|
# Strip the whitespace off the key and return
|
||||||
return ((tag_type, tag.strip()), template)
|
return ((tag_type, tag.strip()), template)
|
||||||
|
@ -179,22 +179,27 @@ def pre_init(func: Callable) -> Any:
|
|||||||
for name, field_info in fields.items():
|
for name, field_info in fields.items():
|
||||||
# Check if allow_population_by_field_name is enabled
|
# Check if allow_population_by_field_name is enabled
|
||||||
# If yes, then set the field name to the alias
|
# If yes, then set the field name to the alias
|
||||||
if hasattr(cls, "Config"):
|
if (
|
||||||
if hasattr(cls.Config, "allow_population_by_field_name"):
|
hasattr(cls, "Config")
|
||||||
if cls.Config.allow_population_by_field_name:
|
and hasattr(cls.Config, "allow_population_by_field_name")
|
||||||
if field_info.alias in values:
|
and cls.Config.allow_population_by_field_name
|
||||||
values[name] = values.pop(field_info.alias)
|
and field_info.alias in values
|
||||||
if hasattr(cls, "model_config"):
|
):
|
||||||
if cls.model_config.get("populate_by_name"):
|
values[name] = values.pop(field_info.alias)
|
||||||
if field_info.alias in values:
|
if (
|
||||||
values[name] = values.pop(field_info.alias)
|
hasattr(cls, "model_config")
|
||||||
|
and cls.model_config.get("populate_by_name")
|
||||||
|
and field_info.alias in values
|
||||||
|
):
|
||||||
|
values[name] = values.pop(field_info.alias)
|
||||||
|
|
||||||
if name not in values or values[name] is None:
|
if (
|
||||||
if not field_info.is_required():
|
name not in values or values[name] is None
|
||||||
if field_info.default_factory is not None:
|
) and not field_info.is_required():
|
||||||
values[name] = field_info.default_factory()
|
if field_info.default_factory is not None:
|
||||||
else:
|
values[name] = field_info.default_factory()
|
||||||
values[name] = field_info.default
|
else:
|
||||||
|
values[name] = field_info.default
|
||||||
|
|
||||||
# Call the decorated function
|
# Call the decorated function
|
||||||
return func(cls, values)
|
return func(cls, values)
|
||||||
|
@ -332,9 +332,8 @@ def from_env(
|
|||||||
for k in key:
|
for k in key:
|
||||||
if k in os.environ:
|
if k in os.environ:
|
||||||
return os.environ[k]
|
return os.environ[k]
|
||||||
if isinstance(key, str):
|
if isinstance(key, str) and key in os.environ:
|
||||||
if key in os.environ:
|
return os.environ[key]
|
||||||
return os.environ[key]
|
|
||||||
|
|
||||||
if isinstance(default, (str, type(None))):
|
if isinstance(default, (str, type(None))):
|
||||||
return default
|
return default
|
||||||
@ -395,9 +394,8 @@ def secret_from_env(
|
|||||||
for k in key:
|
for k in key:
|
||||||
if k in os.environ:
|
if k in os.environ:
|
||||||
return SecretStr(os.environ[k])
|
return SecretStr(os.environ[k])
|
||||||
if isinstance(key, str):
|
if isinstance(key, str) and key in os.environ:
|
||||||
if key in os.environ:
|
return SecretStr(os.environ[key])
|
||||||
return SecretStr(os.environ[key])
|
|
||||||
if isinstance(default, str):
|
if isinstance(default, str):
|
||||||
return SecretStr(default)
|
return SecretStr(default)
|
||||||
elif isinstance(default, type(None)):
|
elif isinstance(default, type(None)):
|
||||||
|
@ -44,7 +44,7 @@ python = ">=3.12.4"
|
|||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [ "B", "C4", "E", "F", "I", "N", "PIE", "T201", "UP",]
|
select = [ "B", "C4", "E", "F", "I", "N", "PIE", "SIM", "T201", "UP",]
|
||||||
ignore = [ "UP007",]
|
ignore = [ "UP007",]
|
||||||
|
|
||||||
[tool.coverage.run]
|
[tool.coverage.run]
|
||||||
|
@ -148,7 +148,7 @@ def test_pydantic_output_parser() -> None:
|
|||||||
|
|
||||||
result = pydantic_parser.parse(DEF_RESULT)
|
result = pydantic_parser.parse(DEF_RESULT)
|
||||||
print("parse_result:", result) # noqa: T201
|
print("parse_result:", result) # noqa: T201
|
||||||
assert DEF_EXPECTED_RESULT == result
|
assert result == DEF_EXPECTED_RESULT
|
||||||
assert pydantic_parser.OutputType is TestModel
|
assert pydantic_parser.OutputType is TestModel
|
||||||
|
|
||||||
|
|
||||||
|
@ -115,10 +115,7 @@ def _normalize_schema(obj: Any) -> dict[str, Any]:
|
|||||||
Args:
|
Args:
|
||||||
obj: The object to generate the schema for
|
obj: The object to generate the schema for
|
||||||
"""
|
"""
|
||||||
if isinstance(obj, BaseModel):
|
data = obj.model_json_schema() if isinstance(obj, BaseModel) else obj
|
||||||
data = obj.model_json_schema()
|
|
||||||
else:
|
|
||||||
data = obj
|
|
||||||
remove_all_none_default(data)
|
remove_all_none_default(data)
|
||||||
replace_all_of_with_ref(data)
|
replace_all_of_with_ref(data)
|
||||||
_remove_enum(data)
|
_remove_enum(data)
|
||||||
|
@ -3345,9 +3345,9 @@ def test_bind_with_lambda() -> None:
|
|||||||
return 3 + kwargs.get("n", 0)
|
return 3 + kwargs.get("n", 0)
|
||||||
|
|
||||||
runnable = RunnableLambda(my_function).bind(n=1)
|
runnable = RunnableLambda(my_function).bind(n=1)
|
||||||
assert 4 == runnable.invoke({})
|
assert runnable.invoke({}) == 4
|
||||||
chunks = list(runnable.stream({}))
|
chunks = list(runnable.stream({}))
|
||||||
assert [4] == chunks
|
assert chunks == [4]
|
||||||
|
|
||||||
|
|
||||||
async def test_bind_with_lambda_async() -> None:
|
async def test_bind_with_lambda_async() -> None:
|
||||||
@ -3355,9 +3355,9 @@ async def test_bind_with_lambda_async() -> None:
|
|||||||
return 3 + kwargs.get("n", 0)
|
return 3 + kwargs.get("n", 0)
|
||||||
|
|
||||||
runnable = RunnableLambda(my_function).bind(n=1)
|
runnable = RunnableLambda(my_function).bind(n=1)
|
||||||
assert 4 == await runnable.ainvoke({})
|
assert await runnable.ainvoke({}) == 4
|
||||||
chunks = [item async for item in runnable.astream({})]
|
chunks = [item async for item in runnable.astream({})]
|
||||||
assert [4] == chunks
|
assert chunks == [4]
|
||||||
|
|
||||||
|
|
||||||
def test_deep_stream() -> None:
|
def test_deep_stream() -> None:
|
||||||
@ -5140,13 +5140,10 @@ async def test_astream_log_deep_copies() -> None:
|
|||||||
|
|
||||||
chain = RunnableLambda(add_one)
|
chain = RunnableLambda(add_one)
|
||||||
chunks = []
|
chunks = []
|
||||||
final_output = None
|
final_output: Optional[RunLogPatch] = None
|
||||||
async for chunk in chain.astream_log(1):
|
async for chunk in chain.astream_log(1):
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
if final_output is None:
|
final_output = chunk if final_output is None else final_output + chunk
|
||||||
final_output = chunk
|
|
||||||
else:
|
|
||||||
final_output = final_output + chunk
|
|
||||||
|
|
||||||
run_log = _get_run_log(chunks)
|
run_log = _get_run_log(chunks)
|
||||||
state = run_log.state.copy()
|
state = run_log.state.copy()
|
||||||
|
@ -209,9 +209,11 @@ def test_tracing_enable_disable(
|
|||||||
|
|
||||||
get_env_var.cache_clear()
|
get_env_var.cache_clear()
|
||||||
env_on = env == "true"
|
env_on = env == "true"
|
||||||
with patch.dict("os.environ", {"LANGSMITH_TRACING": env}):
|
with (
|
||||||
with tracing_context(enabled=enabled):
|
patch.dict("os.environ", {"LANGSMITH_TRACING": env}),
|
||||||
RunnableLambda(my_func).invoke(1)
|
tracing_context(enabled=enabled),
|
||||||
|
):
|
||||||
|
RunnableLambda(my_func).invoke(1)
|
||||||
|
|
||||||
mock_posts = _get_posts(mock_client_)
|
mock_posts = _get_posts(mock_client_)
|
||||||
if enabled is True:
|
if enabled is True:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -10,6 +10,7 @@ from langchain_core.messages import (
|
|||||||
AIMessage,
|
AIMessage,
|
||||||
AIMessageChunk,
|
AIMessageChunk,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
|
BaseMessageChunk,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
ChatMessageChunk,
|
ChatMessageChunk,
|
||||||
FunctionMessage,
|
FunctionMessage,
|
||||||
@ -630,14 +631,11 @@ def test_tool_calls_merge() -> None:
|
|||||||
{"content": ""},
|
{"content": ""},
|
||||||
]
|
]
|
||||||
|
|
||||||
final = None
|
final: Optional[BaseMessageChunk] = None
|
||||||
|
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
msg = AIMessageChunk(**chunk)
|
msg = AIMessageChunk(**chunk)
|
||||||
if final is None:
|
final = msg if final is None else final + msg
|
||||||
final = msg
|
|
||||||
else:
|
|
||||||
final = final + msg
|
|
||||||
|
|
||||||
assert final == AIMessageChunk(
|
assert final == AIMessageChunk(
|
||||||
content="",
|
content="",
|
||||||
|
@ -133,25 +133,24 @@ class LangChainProjectNameTest(unittest.TestCase):
|
|||||||
for case in cases:
|
for case in cases:
|
||||||
get_env_var.cache_clear()
|
get_env_var.cache_clear()
|
||||||
get_tracer_project.cache_clear()
|
get_tracer_project.cache_clear()
|
||||||
with self.subTest(msg=case.test_name):
|
with self.subTest(msg=case.test_name), pytest.MonkeyPatch.context() as mp:
|
||||||
with pytest.MonkeyPatch.context() as mp:
|
for k, v in case.envvars.items():
|
||||||
for k, v in case.envvars.items():
|
mp.setenv(k, v)
|
||||||
mp.setenv(k, v)
|
|
||||||
|
|
||||||
client = unittest.mock.MagicMock(spec=Client)
|
client = unittest.mock.MagicMock(spec=Client)
|
||||||
tracer = LangChainTracer(client=client)
|
tracer = LangChainTracer(client=client)
|
||||||
projects = []
|
projects = []
|
||||||
|
|
||||||
def mock_create_run(**kwargs: Any) -> Any:
|
def mock_create_run(**kwargs: Any) -> Any:
|
||||||
projects.append(kwargs.get("project_name")) # noqa: B023
|
projects.append(kwargs.get("project_name")) # noqa: B023
|
||||||
return unittest.mock.MagicMock()
|
return unittest.mock.MagicMock()
|
||||||
|
|
||||||
client.create_run = mock_create_run
|
client.create_run = mock_create_run
|
||||||
|
|
||||||
tracer.on_llm_start(
|
tracer.on_llm_start(
|
||||||
{"name": "example_1"},
|
{"name": "example_1"},
|
||||||
["foo"],
|
["foo"],
|
||||||
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
|
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
|
||||||
)
|
)
|
||||||
tracer.wait_for_futures()
|
tracer.wait_for_futures()
|
||||||
assert projects == [case.expected_project_name]
|
assert projects == [case.expected_project_name]
|
||||||
|
@ -513,14 +513,8 @@ def test_tool_outputs() -> None:
|
|||||||
def test__convert_typed_dict_to_openai_function(
|
def test__convert_typed_dict_to_openai_function(
|
||||||
use_extension_typed_dict: bool, use_extension_annotated: bool
|
use_extension_typed_dict: bool, use_extension_annotated: bool
|
||||||
) -> None:
|
) -> None:
|
||||||
if use_extension_typed_dict:
|
typed_dict = ExtensionsTypedDict if use_extension_typed_dict else TypingTypedDict
|
||||||
typed_dict = ExtensionsTypedDict
|
annotated = TypingAnnotated if use_extension_annotated else TypingAnnotated
|
||||||
else:
|
|
||||||
typed_dict = TypingTypedDict
|
|
||||||
if use_extension_annotated:
|
|
||||||
annotated = TypingAnnotated
|
|
||||||
else:
|
|
||||||
annotated = TypingAnnotated
|
|
||||||
|
|
||||||
class SubTool(typed_dict):
|
class SubTool(typed_dict):
|
||||||
"""Subtool docstring"""
|
"""Subtool docstring"""
|
||||||
|
@ -116,10 +116,7 @@ def test_check_package_version(
|
|||||||
def test_merge_dicts(
|
def test_merge_dicts(
|
||||||
left: dict, right: dict, expected: Union[dict, AbstractContextManager]
|
left: dict, right: dict, expected: Union[dict, AbstractContextManager]
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(expected, AbstractContextManager):
|
err = expected if isinstance(expected, AbstractContextManager) else nullcontext()
|
||||||
err = expected
|
|
||||||
else:
|
|
||||||
err = nullcontext()
|
|
||||||
|
|
||||||
left_copy = deepcopy(left)
|
left_copy = deepcopy(left)
|
||||||
right_copy = deepcopy(right)
|
right_copy = deepcopy(right)
|
||||||
@ -147,10 +144,7 @@ def test_merge_dicts(
|
|||||||
def test_merge_dicts_0_3(
|
def test_merge_dicts_0_3(
|
||||||
left: dict, right: dict, expected: Union[dict, AbstractContextManager]
|
left: dict, right: dict, expected: Union[dict, AbstractContextManager]
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(expected, AbstractContextManager):
|
err = expected if isinstance(expected, AbstractContextManager) else nullcontext()
|
||||||
err = expected
|
|
||||||
else:
|
|
||||||
err = nullcontext()
|
|
||||||
|
|
||||||
left_copy = deepcopy(left)
|
left_copy = deepcopy(left)
|
||||||
right_copy = deepcopy(right)
|
right_copy = deepcopy(right)
|
||||||
|
Loading…
Reference in New Issue
Block a user