core: Add ruff rules RET (#29384)

See https://docs.astral.sh/ruff/rules/#flake8-return-ret
All changes are auto-fixes.
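
For context, the flake8-return (RET) rules flag control-flow patterns around `return`: among others, RET504 (a variable assigned only to be returned on the next line), RET505/RET506/RET507 (a superfluous `else` after a branch that ends in `return`, `raise`, or `continue`), and RET503 (a missing explicit `return None`). A minimal sketch of the kind of rewrite the auto-fix applies (illustrative only, not code from this commit):

```python
# Before: flagged by RET504 and RET505.
def describe(value: int) -> str:
    if value < 0:
        return "negative"
    else:  # RET505: `else` is redundant after `return`
        label = f"non-negative: {value}"  # RET504: assigned only to be returned
        return label


# After `ruff check --fix` with the RET rules enabled:
def describe(value: int) -> str:
    if value < 0:
        return "negative"
    return f"non-negative: {value}"
```

Enabling the rules is typically a one-line config change, e.g. adding `"RET"` to the `select` list under `[tool.ruff.lint]` in `pyproject.toml`; the exact configuration change made by this commit is not shown in this excerpt.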
Christophe Bornet 2025-04-02 22:59:56 +02:00 committed by GitHub
parent 9ae792f56c
commit f241fd5c11
70 changed files with 626 additions and 813 deletions


@ -466,8 +466,7 @@ def warn_deprecated(
f"{removal}"
)
raise NotImplementedError(msg)
else:
removal = f"in {removal}"
removal = f"in {removal}"
if not message:
message = ""


@ -185,8 +185,7 @@ def _convert_agent_action_to_messages(
"""
if isinstance(agent_action, AgentActionMessageLog):
return agent_action.message_log
else:
return [AIMessage(content=agent_action.log)]
return [AIMessage(content=agent_action.log)]
def _convert_agent_observation_to_messages(
@ -205,14 +204,13 @@ def _convert_agent_observation_to_messages(
"""
if isinstance(agent_action, AgentActionMessageLog):
return [_create_function_message(agent_action, observation)]
else:
content = observation
if not isinstance(observation, str):
try:
content = json.dumps(observation, ensure_ascii=False)
except Exception:
content = str(observation)
return [HumanMessage(content=content)]
content = observation
if not isinstance(observation, str):
try:
content = json.dumps(observation, ensure_ascii=False)
except Exception:
content = str(observation)
return [HumanMessage(content=content)]
def _create_function_message(


@ -59,11 +59,10 @@ def _key_from_id(id_: str) -> str:
wout_prefix = id_.split(CONTEXT_CONFIG_PREFIX, maxsplit=1)[1]
if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_GET):
return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_GET)]
elif wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET):
if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET):
return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_SET)]
else:
msg = f"Invalid context config id {id_}"
raise ValueError(msg)
msg = f"Invalid context config id {id_}"
raise ValueError(msg)
def _config_with_context(
@ -197,8 +196,7 @@ class ContextGet(RunnableSerializable):
configurable = config.get("configurable", {})
if isinstance(self.key, list):
return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
else:
return configurable[self.ids[0]]()
return configurable[self.ids[0]]()
@override
async def ainvoke(
@ -209,8 +207,7 @@ class ContextGet(RunnableSerializable):
if isinstance(self.key, list):
values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
return dict(zip(self.key, values))
else:
return await configurable[self.ids[0]]()
return await configurable[self.ids[0]]()
SetValue = Union[
@ -447,5 +444,4 @@ class PrefixContext:
def _print_keys(keys: Union[str, Sequence[str]]) -> str:
if isinstance(keys, str):
return f"'{keys}'"
else:
return ", ".join(f"'{k}'" for k in keys)
return ", ".join(f"'{k}'" for k in keys)


@ -128,8 +128,7 @@ class LangSmithLoader(BaseLoader):
def _stringify(x: Union[str, dict]) -> str:
if isinstance(x, str):
return x
else:
try:
return json.dumps(x, indent=2)
except Exception:
return str(x)
try:
return json.dumps(x, indent=2)
except Exception:
return str(x)


@ -54,8 +54,7 @@ class BaseMedia(Serializable):
"""
if id_value is not None:
return str(id_value)
else:
return id_value
return id_value
class Blob(BaseMedia):
@ -159,25 +158,23 @@ class Blob(BaseMedia):
"""Read data as a string."""
if self.data is None and self.path:
return Path(self.path).read_text(encoding=self.encoding)
elif isinstance(self.data, bytes):
if isinstance(self.data, bytes):
return self.data.decode(self.encoding)
elif isinstance(self.data, str):
if isinstance(self.data, str):
return self.data
else:
msg = f"Unable to get string for blob {self}"
raise ValueError(msg)
msg = f"Unable to get string for blob {self}"
raise ValueError(msg)
def as_bytes(self) -> bytes:
"""Read data as bytes."""
if isinstance(self.data, bytes):
return self.data
elif isinstance(self.data, str):
if isinstance(self.data, str):
return self.data.encode(self.encoding)
elif self.data is None and self.path:
if self.data is None and self.path:
return Path(self.path).read_bytes()
else:
msg = f"Unable to get bytes for blob {self}"
raise ValueError(msg)
msg = f"Unable to get bytes for blob {self}"
raise ValueError(msg)
@contextlib.contextmanager
def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]:
@ -316,5 +313,4 @@ class Document(BaseMedia):
# a more general solution of formatting content directly inside the prompts.
if self.metadata:
return f"page_content='{self.page_content}' metadata={self.metadata}"
else:
return f"page_content='{self.page_content}'"
return f"page_content='{self.page_content}'"


@ -79,9 +79,8 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
new_length = remaining_length - self.example_text_lengths[i]
if new_length < 0:
break
else:
examples.append(self.examples[i])
remaining_length = new_length
examples.append(self.examples[i])
remaining_length = new_length
i += 1
return examples


@ -54,8 +54,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
) -> str:
if input_keys:
return " ".join(sorted_values({key: example[key] for key in input_keys}))
else:
return " ".join(sorted_values(example))
return " ".join(sorted_values(example))
def _documents_to_examples(self, documents: list[Document]) -> list[dict]:
# Get the examples from the metadata.


@ -152,16 +152,15 @@ def _get_source_id_assigner(
"""Get the source id from the document."""
if source_id_key is None:
return lambda doc: None
elif isinstance(source_id_key, str):
if isinstance(source_id_key, str):
return lambda doc: doc.metadata[source_id_key]
elif callable(source_id_key):
if callable(source_id_key):
return source_id_key
else:
msg = (
f"source_id_key should be either None, a string or a callable. "
f"Got {source_id_key} of type {type(source_id_key)}."
)
raise ValueError(msg)
msg = (
f"source_id_key should be either None, a string or a callable. "
f"Got {source_id_key} of type {type(source_id_key)}."
)
raise ValueError(msg)
def _deduplicate_in_order(


@ -143,8 +143,7 @@ class BaseLanguageModel(
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
return verbose
@property
@override
@ -351,8 +350,7 @@ class BaseLanguageModel(
"""
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
else:
return _get_token_ids_default_method(text)
return _get_token_ids_default_method(text)
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.


@ -284,16 +284,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input
elif isinstance(input, str):
if isinstance(input, str):
return StringPromptValue(text=input)
elif isinstance(input, Sequence):
if isinstance(input, Sequence):
return ChatPromptValue(messages=convert_to_messages(input))
else:
msg = (
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
raise ValueError(msg) # noqa: TRY004
msg = (
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
raise ValueError(msg) # noqa: TRY004
@override
def invoke(
@ -610,10 +609,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
_cleanup_llm_representation(serialized_repr, 1)
llm_string = json.dumps(serialized_repr, sort_keys=True)
return llm_string + "---" + param_string
else:
params = self._get_invocation_params(stop=stop, **kwargs)
params = {**params, **kwargs}
return str(sorted(params.items()))
params = self._get_invocation_params(stop=stop, **kwargs)
params = {**params, **kwargs}
return str(sorted(params.items()))
def generate(
self,
@ -1107,9 +1105,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
).generations[0][0]
if isinstance(generation, ChatGeneration):
return generation.message
else:
msg = "Unexpected generation type"
raise ValueError(msg) # noqa: TRY004
msg = "Unexpected generation type"
raise ValueError(msg) # noqa: TRY004
async def _call_async(
self,
@ -1124,9 +1121,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
generation = result.generations[0][0]
if isinstance(generation, ChatGeneration):
return generation.message
else:
msg = "Unexpected generation type"
raise ValueError(msg) # noqa: TRY004
msg = "Unexpected generation type"
raise ValueError(msg) # noqa: TRY004
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def call_as_llm(
@ -1167,9 +1163,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
if isinstance(result.content, str):
return result.content
else:
msg = "Cannot use predict when output is not a string."
raise ValueError(msg) # noqa: TRY004
msg = "Cannot use predict when output is not a string."
raise ValueError(msg) # noqa: TRY004
@deprecated("0.1.7", alternative="invoke", removal="1.0")
@override
@ -1194,9 +1189,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
if isinstance(result.content, str):
return result.content
else:
msg = "Cannot use predict when output is not a string."
raise ValueError(msg) # noqa: TRY004
msg = "Cannot use predict when output is not a string."
raise ValueError(msg) # noqa: TRY004
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@override
@ -1391,8 +1385,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
return llm | output_parser
class SimpleChatModel(BaseChatModel):


@ -251,8 +251,7 @@ def update_cache(
prompt = prompts[missing_prompt_idxs[i]]
if llm_cache is not None:
llm_cache.update(prompt, llm_string, result)
llm_output = new_results.llm_output
return llm_output
return new_results.llm_output
async def aupdate_cache(
@ -285,8 +284,7 @@ async def aupdate_cache(
prompt = prompts[missing_prompt_idxs[i]]
if llm_cache:
await llm_cache.aupdate(prompt, llm_string, result)
llm_output = new_results.llm_output
return llm_output
return new_results.llm_output
class BaseLLM(BaseLanguageModel[str], ABC):
@ -330,16 +328,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input
elif isinstance(input, str):
if isinstance(input, str):
return StringPromptValue(text=input)
elif isinstance(input, Sequence):
if isinstance(input, Sequence):
return ChatPromptValue(messages=convert_to_messages(input))
else:
msg = (
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
raise ValueError(msg) # noqa: TRY004
msg = (
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
raise ValueError(msg) # noqa: TRY004
def _get_ls_params(
self,
@ -452,8 +449,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
except Exception as e:
if return_exceptions:
return cast("list[str]", [e for _ in inputs])
else:
raise
raise
else:
batches = [
inputs[i : i + max_concurrency]
@ -499,8 +495,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
except Exception as e:
if return_exceptions:
return cast("list[str]", [e for _ in inputs])
else:
raise
raise
else:
batches = [
inputs[i : i + max_concurrency]
@ -973,14 +968,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
callback_managers, prompts, run_name_list, run_ids_list
)
]
output = self._generate_helper(
return self._generate_helper(
prompts,
stop,
run_managers,
new_arg_supported=bool(new_arg_supported),
**kwargs,
)
return output
if len(missing_prompts) > 0:
run_managers = [
callback_managers[idx].on_llm_start(
@ -1232,14 +1226,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
]
)
run_managers = [r[0] for r in run_managers] # type: ignore[misc]
output = await self._agenerate_helper(
return await self._agenerate_helper(
prompts,
stop,
run_managers, # type: ignore[arg-type]
new_arg_supported=bool(new_arg_supported),
**kwargs, # type: ignore[arg-type]
)
return output
if len(missing_prompts) > 0:
run_managers = await asyncio.gather(
*[


@ -19,8 +19,7 @@ def default(obj: Any) -> Any:
"""
if isinstance(obj, Serializable):
return obj.to_json()
else:
return to_json_not_implemented(obj)
return to_json_not_implemented(obj)
def _dump_pydantic_models(obj: Any) -> Any:
@ -36,8 +35,7 @@ def _dump_pydantic_models(obj: Any) -> Any:
obj_copy = obj.model_copy(deep=True)
obj_copy.message.additional_kwargs["parsed"] = parsed.model_dump()
return obj_copy
else:
return obj
return obj
def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
@ -64,14 +62,12 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
if pretty:
indent = kwargs.pop("indent", 2)
return json.dumps(obj, default=default, indent=indent, **kwargs)
else:
return json.dumps(obj, default=default, **kwargs)
return json.dumps(obj, default=default, **kwargs)
except TypeError:
if pretty:
indent = kwargs.pop("indent", 2)
return json.dumps(to_json_not_implemented(obj), indent=indent, **kwargs)
else:
return json.dumps(to_json_not_implemented(obj), **kwargs)
return json.dumps(to_json_not_implemented(obj), **kwargs)
def dumpd(obj: Any) -> Any:


@ -98,10 +98,9 @@ class Reviver:
[key] = value["id"]
if key in self.secrets_map:
return self.secrets_map[key]
else:
if self.secrets_from_env and key in os.environ and os.environ[key]:
return os.environ[key]
return None
if self.secrets_from_env and key in os.environ and os.environ[key]:
return os.environ[key]
return None
if (
value.get("lc") == 1
@ -130,7 +129,7 @@ class Reviver:
msg = f"Invalid namespace: {value}"
raise ValueError(msg)
# Has explicit import path.
elif mapping_key in self.import_mappings:
if mapping_key in self.import_mappings:
import_path = self.import_mappings[mapping_key]
# Split into module and name
import_dir, name = import_path[:-1], import_path[-1]


@ -372,7 +372,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
return add_ai_message_chunks(self, other)
elif isinstance(other, (list, tuple)) and all(
if isinstance(other, (list, tuple)) and all(
isinstance(o, AIMessageChunk) for o in other
):
return add_ai_message_chunks(self, *other)


@ -65,8 +65,7 @@ class BaseMessage(Serializable):
"""Coerce the id field to a string."""
if id_value is not None:
return str(id_value)
else:
return id_value
return id_value
def __init__(
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
@ -225,7 +224,7 @@ class BaseMessageChunk(BaseMessage):
self.response_metadata, other.response_metadata
),
)
elif isinstance(other, list) and all(
if isinstance(other, list) and all(
isinstance(o, BaseMessageChunk) for o in other
):
content = merge_content(self.content, *(o.content for o in other))
@ -241,13 +240,12 @@ class BaseMessageChunk(BaseMessage):
additional_kwargs=additional_kwargs,
response_metadata=response_metadata,
)
else:
msg = (
'unsupported operand type(s) for +: "'
f"{self.__class__.__name__}"
f'" and "{other.__class__.__name__}"'
)
raise TypeError(msg)
msg = (
'unsupported operand type(s) for +: "'
f"{self.__class__.__name__}"
f'" and "{other.__class__.__name__}"'
)
raise TypeError(msg)
def message_to_dict(message: BaseMessage) -> dict:


@ -53,7 +53,7 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
),
id=self.id,
)
elif isinstance(other, BaseMessageChunk):
if isinstance(other, BaseMessageChunk):
return self.__class__(
role=self.role,
content=merge_content(self.content, other.content),
@ -65,5 +65,4 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
),
id=self.id,
)
else:
return super().__add__(other)
return super().__add__(other)


@ -320,25 +320,24 @@ def default_tool_parser(
    for raw_tool_call in raw_tool_calls:
        if "function" not in raw_tool_call:
            continue
        else:
            function_name = raw_tool_call["function"]["name"]
            try:
                function_args = json.loads(raw_tool_call["function"]["arguments"])
                parsed = tool_call(
                    name=function_name or "",
                    args=function_args or {},
                    id=raw_tool_call.get("id"),
                )
                tool_calls.append(parsed)
            except json.JSONDecodeError:
                invalid_tool_calls.append(
                    invalid_tool_call(
                        name=function_name,
                        args=raw_tool_call["function"]["arguments"],
                        id=raw_tool_call.get("id"),
                        error=None,
                    )
                )
        function_name = raw_tool_call["function"]["name"]
        try:
            function_args = json.loads(raw_tool_call["function"]["arguments"])
            parsed = tool_call(
                name=function_name or "",
                args=function_args or {},
                id=raw_tool_call.get("id"),
            )
            tool_calls.append(parsed)
        except json.JSONDecodeError:
            invalid_tool_calls.append(
                invalid_tool_call(
                    name=function_name,
                    args=raw_tool_call["function"]["arguments"],
                    id=raw_tool_call.get("id"),
                    error=None,
                )
            )
return tool_calls, invalid_tool_calls


@ -51,14 +51,13 @@ def _get_type(v: Any) -> str:
"""Get the type associated with the object for serialization purposes."""
if isinstance(v, dict) and "type" in v:
return v["type"]
elif hasattr(v, "type"):
if hasattr(v, "type"):
return v.type
else:
msg = (
f"Expected either a dictionary with a 'type' key or an object "
f"with a 'type' attribute. Instead got type {type(v)}."
)
raise TypeError(msg)
msg = (
f"Expected either a dictionary with a 'type' key or an object "
f"with a 'type' attribute. Instead got type {type(v)}."
)
raise TypeError(msg)
AnyMessage = Annotated[
@ -138,33 +137,32 @@ def _message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
elif _type == "ai":
if _type == "ai":
return AIMessage(**message["data"])
elif _type == "system":
if _type == "system":
return SystemMessage(**message["data"])
elif _type == "chat":
if _type == "chat":
return ChatMessage(**message["data"])
elif _type == "function":
if _type == "function":
return FunctionMessage(**message["data"])
elif _type == "tool":
if _type == "tool":
return ToolMessage(**message["data"])
elif _type == "remove":
if _type == "remove":
return RemoveMessage(**message["data"])
elif _type == "AIMessageChunk":
if _type == "AIMessageChunk":
return AIMessageChunk(**message["data"])
elif _type == "HumanMessageChunk":
if _type == "HumanMessageChunk":
return HumanMessageChunk(**message["data"])
elif _type == "FunctionMessageChunk":
if _type == "FunctionMessageChunk":
return FunctionMessageChunk(**message["data"])
elif _type == "ToolMessageChunk":
if _type == "ToolMessageChunk":
return ToolMessageChunk(**message["data"])
elif _type == "SystemMessageChunk":
if _type == "SystemMessageChunk":
return SystemMessageChunk(**message["data"])
elif _type == "ChatMessageChunk":
if _type == "ChatMessageChunk":
return ChatMessageChunk(**message["data"])
else:
msg = f"Got unexpected message type: {_type}"
raise ValueError(msg)
msg = f"Got unexpected message type: {_type}"
raise ValueError(msg)
def messages_from_dict(messages: Sequence[dict]) -> list[BaseMessage]:
@ -387,8 +385,7 @@ def _runnable_support(func: Callable) -> Callable:
if messages is not None:
return func(messages, **kwargs)
else:
return RunnableLambda(partial(func, **kwargs), name=func.__name__)
return RunnableLambda(partial(func, **kwargs), name=func.__name__)
wrapped.__doc__ = func.__doc__
return wrapped
@ -472,8 +469,6 @@ def filter_messages(
or (exclude_ids and msg.id in exclude_ids)
):
continue
else:
pass
if exclude_tool_calls is True and (
(isinstance(msg, AIMessage) and msg.tool_calls)
@ -926,7 +921,7 @@ def trim_messages(
partial_strategy="first" if allow_partial else None,
end_on=end_on,
)
elif strategy == "last":
if strategy == "last":
return _last_max_tokens(
messages,
max_tokens=max_tokens,
@ -937,9 +932,8 @@ def trim_messages(
end_on=end_on,
text_splitter=text_splitter_fn,
)
else:
msg = f"Unrecognized {strategy=}. Supported strategies are 'last' and 'first'."
raise ValueError(msg)
msg = f"Unrecognized {strategy=}. Supported strategies are 'last' and 'first'."
raise ValueError(msg)
def convert_to_openai_messages(
@ -1269,8 +1263,7 @@ def convert_to_openai_messages(
if is_single:
return oai_messages[0]
else:
return oai_messages
return oai_messages
def _first_max_tokens(
@ -1347,7 +1340,7 @@ def _first_max_tokens(
if isinstance(block, str):
text = block
break
elif isinstance(block, dict) and block.get("type") == "text":
if isinstance(block, dict) and block.get("type") == "text":
text = block.get("text")
break
@ -1517,19 +1510,18 @@ def _bytes_to_b64_str(bytes_: bytes) -> str:
def _get_message_openai_role(message: BaseMessage) -> str:
if isinstance(message, AIMessage):
return "assistant"
elif isinstance(message, HumanMessage):
if isinstance(message, HumanMessage):
return "user"
elif isinstance(message, ToolMessage):
if isinstance(message, ToolMessage):
return "tool"
elif isinstance(message, SystemMessage):
if isinstance(message, SystemMessage):
return message.additional_kwargs.get("__openai_role__", "system")
elif isinstance(message, FunctionMessage):
if isinstance(message, FunctionMessage):
return "function"
elif isinstance(message, ChatMessage):
if isinstance(message, ChatMessage):
return message.role
else:
msg = f"Unknown BaseMessage type {message.__class__}."
raise ValueError(msg) # noqa: TRY004
msg = f"Unknown BaseMessage type {message.__class__}."
raise ValueError(msg) # noqa: TRY004
def _convert_to_openai_tool_calls(tool_calls: list[ToolCall]) -> list[dict]:


@ -97,13 +97,12 @@ class BaseGenerationOutputParser(
config,
run_type="parser",
)
else:
return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
@override
async def ainvoke(
@ -121,13 +120,12 @@ class BaseGenerationOutputParser(
config,
run_type="parser",
)
else:
return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
class BaseOutputParser(
@ -203,13 +201,12 @@ class BaseOutputParser(
config,
run_type="parser",
)
else:
return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
@override
async def ainvoke(
@ -227,13 +224,12 @@ class BaseOutputParser(
config,
run_type="parser",
)
else:
return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
return await self._acall_with_config(
lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
input,
config,
run_type="parser",
)
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
"""Parse a list of candidate model Generations into a specific format.


@ -53,8 +53,9 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
if issubclass(pydantic_object, pydantic.BaseModel):
return pydantic_object.model_json_schema()
elif issubclass(pydantic_object, pydantic.v1.BaseModel):
if issubclass(pydantic_object, pydantic.v1.BaseModel):
return pydantic_object.schema()
return None
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
@ -106,19 +107,18 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
"""
if self.pydantic_object is None:
return "Return a JSON object."
else:
# Copy schema to avoid altering original Pydantic schema.
schema = dict(self._get_schema(self.pydantic_object).items())
# Copy schema to avoid altering original Pydantic schema.
schema = dict(self._get_schema(self.pydantic_object).items())
# Remove extraneous fields.
reduced_schema = schema
if "title" in reduced_schema:
del reduced_schema["title"]
if "type" in reduced_schema:
del reduced_schema["type"]
# Ensure json in context is well-formed with double quotes.
schema_str = json.dumps(reduced_schema, ensure_ascii=False)
return JSON_FORMAT_INSTRUCTIONS.format(schema=schema_str)
# Remove extraneous fields.
reduced_schema = schema
if "title" in reduced_schema:
del reduced_schema["title"]
if "type" in reduced_schema:
del reduced_schema["type"]
# Ensure json in context is well-formed with double quotes.
schema_str = json.dumps(reduced_schema, ensure_ascii=False)
return JSON_FORMAT_INSTRUCTIONS.format(schema=schema_str)
@property
def _type(self) -> str:


@ -99,9 +99,8 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
except KeyError as exc:
if partial:
return None
else:
msg = f"Could not parse function call: {exc}"
raise OutputParserException(msg) from exc
msg = f"Could not parse function call: {exc}"
raise OutputParserException(msg) from exc
try:
if partial:
try:
@ -109,13 +108,12 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
return parse_partial_json(
function_call["arguments"], strict=self.strict
)
else:
return {
**function_call,
"arguments": parse_partial_json(
function_call["arguments"], strict=self.strict
),
}
return {
**function_call,
"arguments": parse_partial_json(
function_call["arguments"], strict=self.strict
),
}
except json.JSONDecodeError:
return None
else:


@ -241,10 +241,9 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
)
if self.return_id:
return single_result
elif single_result:
if single_result:
return single_result["args"]
else:
return None
return None
parsed_result = [res for res in parsed_result if res["type"] == self.key_name]
if not self.return_id:
parsed_result = [res["args"] for res in parsed_result]
@ -300,5 +299,4 @@ class PydanticToolsParser(JsonOutputToolsParser):
raise
if self.first_tool_only:
return pydantic_objects[0] if pydantic_objects else None
else:
return pydantic_objects
return pydantic_objects


@ -28,12 +28,11 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
try:
if issubclass(self.pydantic_object, pydantic.BaseModel):
return self.pydantic_object.model_validate(obj)
elif issubclass(self.pydantic_object, pydantic.v1.BaseModel):
if issubclass(self.pydantic_object, pydantic.v1.BaseModel):
return self.pydantic_object.parse_obj(obj)
else:
msg = f"Unsupported model version for PydanticOutputParser: \
msg = f"Unsupported model version for PydanticOutputParser: \
{self.pydantic_object.__class__}"
raise OutputParserException(msg)
raise OutputParserException(msg)
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
raise self._parser_exception(e, obj) from e
else: # pydantic v1


@ -282,5 +282,4 @@ def nested_element(path: list[str], elem: ET.Element) -> Any:
"""
if len(path) == 0:
return AddableDict({elem.tag: elem.text})
else:
return AddableDict({path[0]: [nested_element(path[1:], elem)]})
return AddableDict({path[0]: [nested_element(path[1:], elem)]})


@ -60,11 +60,9 @@ class ChatGeneration(Generation):
if isinstance(block, str):
text = block
break
elif isinstance(block, dict) and "text" in block:
if isinstance(block, dict) and "text" in block:
text = block["text"]
break
else:
pass
else:
pass
self.text = text
@ -104,7 +102,7 @@ class ChatGenerationChunk(ChatGeneration):
message=self.message + other.message,
generation_info=generation_info or None,
)
elif isinstance(other, list) and all(
if isinstance(other, list) and all(
isinstance(x, ChatGenerationChunk) for x in other
):
generation_info = merge_dicts(
@ -115,8 +113,5 @@ class ChatGenerationChunk(ChatGeneration):
message=self.message + [chunk.message for chunk in other],
generation_info=generation_info or None,
)
else:
msg = (
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
raise TypeError(msg)
msg = f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
raise TypeError(msg)


@ -64,8 +64,5 @@ class GenerationChunk(Generation):
text=self.text + other.text,
generation_info=generation_info or None,
)
else:
msg = (
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
raise TypeError(msg)
msg = f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
raise TypeError(msg)


@ -513,7 +513,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
partial_variables=partial_variables,
)
return cls(prompt=prompt, **kwargs)
elif isinstance(template, list):
if isinstance(template, list):
if (partial_variables is not None) and len(partial_variables) > 0:
msg = "Partial variables are not supported for list of templates."
raise ValueError(msg)
@ -571,9 +571,8 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
msg = f"Invalid template: {tmpl}"
raise ValueError(msg)
return cls(prompt=prompt, **kwargs)
else:
msg = f"Invalid template: {template}"
raise ValueError(msg) # noqa: TRY004
msg = f"Invalid template: {template}"
raise ValueError(msg) # noqa: TRY004
@classmethod
def from_template_file(
@ -625,8 +624,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
List of input variable names.
"""
prompts = self.prompt if isinstance(self.prompt, list) else [self.prompt]
input_variables = [iv for prompt in prompts for iv in prompt.input_variables]
return input_variables
return [iv for prompt in prompts for iv in prompt.input_variables]
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
@ -642,19 +640,18 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
return self._msg_class(
content=text, additional_kwargs=self.additional_kwargs
)
else:
content: list = []
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
formatted: Union[str, ImageURL] = prompt.format(**inputs)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = prompt.format(**inputs)
content.append({"type": "image_url", "image_url": formatted})
return self._msg_class(
content=content, additional_kwargs=self.additional_kwargs
)
content: list = []
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
formatted: Union[str, ImageURL] = prompt.format(**inputs)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = prompt.format(**inputs)
content.append({"type": "image_url", "image_url": formatted})
return self._msg_class(
content=content, additional_kwargs=self.additional_kwargs
)
async def aformat(self, **kwargs: Any) -> BaseMessage:
"""Async format the prompt template.
@ -670,19 +667,18 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
return self._msg_class(
content=text, additional_kwargs=self.additional_kwargs
)
else:
content: list = []
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
formatted: Union[str, ImageURL] = await prompt.aformat(**inputs)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = await prompt.aformat(**inputs)
content.append({"type": "image_url", "image_url": formatted})
return self._msg_class(
content=content, additional_kwargs=self.additional_kwargs
)
content: list = []
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
formatted: Union[str, ImageURL] = await prompt.aformat(**inputs)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = await prompt.aformat(**inputs)
content.append({"type": "image_url", "image_url": formatted})
return self._msg_class(
content=content, additional_kwargs=self.additional_kwargs
)
def pretty_repr(self, html: bool = False) -> str:
"""Human-readable representation.
@ -1034,25 +1030,24 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
return ChatPromptTemplate(messages=self.messages + other.messages).partial(
**partials
) # type: ignore[call-arg]
elif isinstance(
if isinstance(
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
):
return ChatPromptTemplate(messages=self.messages + [other]).partial(
**partials
) # type: ignore[call-arg]
elif isinstance(other, (list, tuple)):
if isinstance(other, (list, tuple)):
_other = ChatPromptTemplate.from_messages(other)
return ChatPromptTemplate(messages=self.messages + _other.messages).partial(
**partials
) # type: ignore[call-arg]
elif isinstance(other, str):
if isinstance(other, str):
prompt = HumanMessagePromptTemplate.from_template(other)
return ChatPromptTemplate(messages=self.messages + [prompt]).partial(
**partials
) # type: ignore[call-arg]
else:
msg = f"Unsupported operand type for +: {type(other)}"
raise NotImplementedError(msg)
msg = f"Unsupported operand type for +: {type(other)}"
raise NotImplementedError(msg)
@model_validator(mode="before")
@classmethod
@ -1322,8 +1317,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
start, stop, step = index.indices(len(self.messages))
messages = self.messages[start:stop:step]
return ChatPromptTemplate.from_messages(messages)
else:
return self.messages[index]
return self.messages[index]
def __len__(self) -> int:
"""Get the length of the chat template."""


@ -88,11 +88,10 @@ class _FewShotPromptTemplateMixin(BaseModel):
"""
if self.examples is not None:
return self.examples
elif self.example_selector is not None:
if self.example_selector is not None:
return self.example_selector.select_examples(kwargs)
else:
msg = "One of 'examples' and 'example_selector' should be provided"
raise ValueError(msg)
msg = "One of 'examples' and 'example_selector' should be provided"
raise ValueError(msg)
async def _aget_examples(self, **kwargs: Any) -> list[dict]:
"""Async get the examples to use for formatting the prompt.
@ -108,11 +107,10 @@ class _FewShotPromptTemplateMixin(BaseModel):
"""
if self.examples is not None:
return self.examples
elif self.example_selector is not None:
if self.example_selector is not None:
return await self.example_selector.aselect_examples(kwargs)
else:
msg = "One of 'examples' and 'example_selector' should be provided"
raise ValueError(msg)
msg = "One of 'examples' and 'example_selector' should be provided"
raise ValueError(msg)
class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
@ -394,12 +392,11 @@ class FewShotChatMessagePromptTemplate(
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
]
# Format the examples.
messages = [
return [
message
for example in examples
for message in self.example_prompt.format_messages(**example)
]
return messages
async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Async format kwargs into a list of messages.
@ -416,12 +413,11 @@ class FewShotChatMessagePromptTemplate(
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
]
# Format the examples.
messages = [
return [
message
for example in examples
for message in await self.example_prompt.aformat_messages(**example)
]
return messages
def format(self, **kwargs: Any) -> str:
"""Format the prompt with inputs generating a string.


@ -97,18 +97,16 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
def _get_examples(self, **kwargs: Any) -> list[dict]:
if self.examples is not None:
return self.examples
elif self.example_selector is not None:
if self.example_selector is not None:
return self.example_selector.select_examples(kwargs)
else:
raise ValueError
raise ValueError
async def _aget_examples(self, **kwargs: Any) -> list[dict]:
if self.examples is not None:
return self.examples
elif self.example_selector is not None:
if self.example_selector is not None:
return await self.example_selector.aselect_examples(kwargs)
else:
raise ValueError
raise ValueError
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.


@ -110,14 +110,13 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
if not url:
msg = "Must provide url."
raise ValueError(msg)
elif not isinstance(url, str):
if not isinstance(url, str):
msg = "url must be a string."
raise ValueError(msg)
else:
output: ImageURL = {"url": url}
if detail:
# Don't check literal values here: let the API check them
output["detail"] = detail # type: ignore[typeddict-item]
raise ValueError(msg) # noqa: TRY004
output: ImageURL = {"url": url}
if detail:
# Don't check literal values here: let the API check them
output["detail"] = detail # type: ignore[typeddict-item]
return output
async def aformat(self, **kwargs: Any) -> ImageURL:


@ -154,8 +154,7 @@ class PromptTemplate(StringPromptTemplate):
if k in partial_variables:
msg = "Cannot have same variable partialed twice."
raise ValueError(msg)
else:
partial_variables[k] = v
partial_variables[k] = v
return PromptTemplate(
template=template,
input_variables=input_variables,
@ -163,12 +162,11 @@ class PromptTemplate(StringPromptTemplate):
template_format="f-string",
validate_template=validate_template,
)
elif isinstance(other, str):
if isinstance(other, str):
prompt = PromptTemplate.from_template(other)
return self + prompt
else:
msg = f"Unsupported operand type for +: {type(other)}"
raise NotImplementedError(msg)
msg = f"Unsupported operand type for +: {type(other)}"
raise NotImplementedError(msg)
@property
def _prompt_type(self) -> str:


@ -100,8 +100,7 @@ def _get_jinja2_variables_from_template(template: str) -> set[str]:
# noqa for insecure warning elsewhere
env = Environment() # noqa: S701
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)
return variables
return meta.find_undeclared_variables(ast)
def mustache_formatter(template: str, /, **kwargs: Any) -> str:


@ -166,6 +166,5 @@ class StructuredPrompt(ChatPromptTemplate):
*others[1:],
name=name,
)
else:
msg = "Structured prompts need to be piped to a language model."
raise NotImplementedError(msg)
msg = "Structured prompts need to be piped to a language model."
raise NotImplementedError(msg)


@ -208,8 +208,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
default_retriever_name = default_retriever_name[:-9]
default_retriever_name = default_retriever_name.lower()
ls_params = LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)
return ls_params
return LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)
def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any


@ -269,10 +269,8 @@ class Runnable(Generic[Input, Output], ABC):
if suffix:
if name_[0].isupper():
return name_ + suffix.title()
else:
return name_ + "_" + suffix.lower()
else:
return name_
return name_ + "_" + suffix.lower()
return name_
@property
def InputType(self) -> type[Input]: # noqa: N802
@ -513,10 +511,9 @@ class Runnable(Generic[Input, Output], ABC):
if field_name in [i for i in include if i != "configurable"]
},
}
model = create_model_v2( # type: ignore[call-overload]
return create_model_v2( # type: ignore[call-overload]
self.get_name("Config"), field_definitions=all_fields
)
return model
def get_config_jsonschema(
self, *, include: Optional[Sequence[str]] = None
@ -2051,8 +2048,7 @@ class Runnable(Generic[Input, Output], ABC):
run_manager.on_chain_error(e)
if return_exceptions:
return cast("list[Output]", [e for _ in input])
else:
raise
raise
else:
first_exception: Optional[Exception] = None
for run_manager, out in zip(run_managers, output):
@ -2063,8 +2059,7 @@ class Runnable(Generic[Input, Output], ABC):
run_manager.on_chain_end(out)
if return_exceptions or first_exception is None:
return cast("list[Output]", output)
else:
raise first_exception
raise first_exception
async def _abatch_with_config(
self,
@ -2130,8 +2125,7 @@ class Runnable(Generic[Input, Output], ABC):
)
if return_exceptions:
return cast("list[Output]", [e for _ in input])
else:
raise
raise
else:
first_exception: Optional[Exception] = None
coros: list[Awaitable[None]] = []
@ -2144,8 +2138,7 @@ class Runnable(Generic[Input, Output], ABC):
await asyncio.gather(*coros)
if return_exceptions or first_exception is None:
return cast("list[Output]", output)
else:
raise first_exception
raise first_exception
def _transform_stream_with_config(
self,
@ -2615,7 +2608,7 @@ def _seq_input_schema(
first = steps[0]
if len(steps) == 1:
return first.get_input_schema(config)
elif isinstance(first, RunnableAssign):
if isinstance(first, RunnableAssign):
next_input_schema = _seq_input_schema(steps[1:], config)
if not issubclass(next_input_schema, RootModel):
# it's a dict as expected
@ -2641,7 +2634,7 @@ def _seq_output_schema(
last = steps[-1]
if len(steps) == 1:
return last.get_input_schema(config)
elif isinstance(last, RunnableAssign):
if isinstance(last, RunnableAssign):
mapper_output_schema = last.mapper.get_output_schema(config)
prev_output_schema = _seq_output_schema(steps[:-1], config)
if not issubclass(prev_output_schema, RootModel):
@ -2672,11 +2665,10 @@ def _seq_output_schema(
if k in last.keys
},
)
else:
field = prev_output_schema.model_fields[last.keys]
return create_model_v2( # type: ignore[call-overload]
"RunnableSequenceOutput", root=(field.annotation, field.default)
)
field = prev_output_schema.model_fields[last.keys]
return create_model_v2( # type: ignore[call-overload]
"RunnableSequenceOutput", root=(field.annotation, field.default)
)
return last.get_output_schema(config)
@ -2988,14 +2980,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
other.last,
name=self.name or other.name,
)
else:
return RunnableSequence(
self.first,
*self.middle,
self.last,
coerce_to_runnable(other),
name=self.name,
)
return RunnableSequence(
self.first,
*self.middle,
self.last,
coerce_to_runnable(other),
name=self.name,
)
@override
def __ror__(
@ -3017,14 +3008,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
self.last,
name=other.name or self.name,
)
else:
return RunnableSequence(
coerce_to_runnable(other),
self.first,
*self.middle,
self.last,
name=self.name,
)
return RunnableSequence(
coerce_to_runnable(other),
self.first,
*self.middle,
self.last,
name=self.name,
)
@override
def invoke(
@ -3224,8 +3214,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
rm.on_chain_error(e)
if return_exceptions:
return cast("list[Output]", [e for _ in inputs])
else:
raise
raise
else:
first_exception: Optional[Exception] = None
for run_manager, out in zip(run_managers, inputs):
@ -3236,8 +3225,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
run_manager.on_chain_end(out)
if return_exceptions or first_exception is None:
return cast("list[Output]", inputs)
else:
raise first_exception
raise first_exception
@override
async def abatch(
@ -3357,8 +3345,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
if return_exceptions:
return cast("list[Output]", [e for _ in inputs])
else:
raise
raise
else:
first_exception: Optional[Exception] = None
coros: list[Awaitable[None]] = []
@ -3371,8 +3358,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
await asyncio.gather(*coros)
if return_exceptions or first_exception is None:
return cast("list[Output]", inputs)
else:
raise first_exception
raise first_exception
def _transform(
self,
@ -3826,8 +3812,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
return await asyncio.create_task( # type: ignore
step.ainvoke(input, child_config), context=context
)
else:
return await asyncio.create_task(step.ainvoke(input, child_config))
return await asyncio.create_task(step.ainvoke(input, child_config))
# gather results from all steps
try:
@ -4141,10 +4126,9 @@ class RunnableGenerator(Runnable[Input, Output]):
first_param = next(iter(params.values()), None)
if first_param and first_param.annotation != inspect.Parameter.empty:
return getattr(first_param.annotation, "__args__", (Any,))[0]
else:
return Any
except ValueError:
return Any
pass
return Any
@override
def get_input_schema(
@ -4220,12 +4204,10 @@ class RunnableGenerator(Runnable[Input, Output]):
if isinstance(other, RunnableGenerator):
if hasattr(self, "_transform") and hasattr(other, "_transform"):
return self._transform == other._transform
elif hasattr(self, "_atransform") and hasattr(other, "_atransform"):
if hasattr(self, "_atransform") and hasattr(other, "_atransform"):
return self._atransform == other._atransform
else:
return False
else:
return False
return False
@override
def __repr__(self) -> str:
@ -4443,10 +4425,9 @@ class RunnableLambda(Runnable[Input, Output]):
first_param = next(iter(params.values()), None)
if first_param and first_param.annotation != inspect.Parameter.empty:
return first_param.annotation
else:
return Any
except ValueError:
return Any
pass
return Any
@override
def get_input_schema(
@ -4472,16 +4453,15 @@ class RunnableLambda(Runnable[Input, Output]):
fields = {item[1:-1]: (Any, ...) for item in items}
# It's a dict, lol
return create_model_v2(self.get_name("Input"), field_definitions=fields)
else:
module = getattr(func, "__module__", None)
return create_model_v2(
self.get_name("Input"),
root=list[Any],
# To create the schema, we need to provide the module
# where the underlying function is defined.
# This allows pydantic to resolve type annotations appropriately.
module_name=module,
)
module = getattr(func, "__module__", None)
return create_model_v2(
self.get_name("Input"),
root=list[Any],
# To create the schema, we need to provide the module
# where the underlying function is defined.
# This allows pydantic to resolve type annotations appropriately.
module_name=module,
)
if self.InputType != Any:
return super().get_input_schema(config)
@ -4513,10 +4493,9 @@ class RunnableLambda(Runnable[Input, Output]):
):
return getattr(sig.return_annotation, "__args__", (Any,))[0]
return sig.return_annotation
else:
return Any
except ValueError:
return Any
pass
return Any
@override
def get_output_schema(
@ -4607,12 +4586,10 @@ class RunnableLambda(Runnable[Input, Output]):
if isinstance(other, RunnableLambda):
if hasattr(self, "func") and hasattr(other, "func"):
return self.func == other.func
elif hasattr(self, "afunc") and hasattr(other, "afunc"):
if hasattr(self, "afunc") and hasattr(other, "afunc"):
return self.afunc == other.afunc
else:
return False
else:
return False
return False
def __repr__(self) -> str:
"""A string representation of this Runnable."""
@ -4806,12 +4783,8 @@ class RunnableLambda(Runnable[Input, Output]):
self._config(config, self.func),
**kwargs,
)
else:
msg = (
"Cannot invoke a coroutine function synchronously."
"Use `ainvoke` instead."
)
raise TypeError(msg)
msg = "Cannot invoke a coroutine function synchronously.Use `ainvoke` instead."
raise TypeError(msg)
@override
async def ainvoke(
@ -5886,7 +5859,7 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
)
return wrapper
elif config_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
if config_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
idx = list(inspect.signature(attr).parameters).index("config")
@wraps(attr)
@ -5895,14 +5868,11 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
argsl = list(args)
argsl[idx] = merge_configs(self.config, argsl[idx])
return attr(*argsl, **kwargs)
else:
return attr(
*args,
config=merge_configs(
self.config, kwargs.pop("config", None)
),
**kwargs,
)
return attr(
*args,
config=merge_configs(self.config, kwargs.pop("config", None)),
**kwargs,
)
return wrapper
@ -5957,18 +5927,17 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
"""
if isinstance(thing, Runnable):
return thing
elif is_async_generator(thing) or inspect.isgeneratorfunction(thing):
if is_async_generator(thing) or inspect.isgeneratorfunction(thing):
return RunnableGenerator(thing)
elif callable(thing):
if callable(thing):
return RunnableLambda(cast("Callable[[Input], Output]", thing))
elif isinstance(thing, dict):
if isinstance(thing, dict):
return cast("Runnable[Input, Output]", RunnableParallel(thing))
else:
msg = (
f"Expected a Runnable, callable or dict."
f"Instead got an unsupported type: {type(thing)}"
)
raise TypeError(msg)
msg = (
f"Expected a Runnable, callable or dict."
f"Instead got an unsupported type: {type(thing)}"
)
raise TypeError(msg)
@overload


@ -314,8 +314,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
return wrapper
else:
return attr
return attr
class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
@ -462,8 +461,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
self.default.__class__(**{**init_params, **configurable}),
config,
)
else:
return (self.default, config)
return (self.default, config)
RunnableConfigurableFields.model_rebuild()
@ -638,15 +636,13 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
# return the chosen alternative
if which == self.default_key:
return (self.default, config)
elif which in self.alternatives:
if which in self.alternatives:
alt = self.alternatives[which]
if isinstance(alt, Runnable):
return (alt, config)
else:
return (alt(), config)
else:
msg = f"Unknown alternative: {which}"
raise ValueError(msg)
return (alt(), config)
msg = f"Unknown alternative: {which}"
raise ValueError(msg)
def _strremoveprefix(s: str, prefix: str) -> str:
@ -714,12 +710,11 @@ def make_options_spec(
default=spec.default,
is_shared=spec.is_shared,
)
else:
return ConfigurableFieldSpec(
id=spec.id,
name=spec.name,
description=spec.description or description,
annotation=Sequence[enum], # type: ignore[valid-type]
default=spec.default,
is_shared=spec.is_shared,
)
return ConfigurableFieldSpec(
id=spec.id,
name=spec.name,
description=spec.description or description,
annotation=Sequence[enum], # type: ignore[valid-type]
default=spec.default,
is_shared=spec.is_shared,
)


@ -661,7 +661,6 @@ def _is_runnable_type(type_: Any) -> bool:
origin = getattr(type_, "__origin__", None)
if inspect.isclass(origin):
return issubclass(origin, Runnable)
elif origin is typing.Union:
if origin is typing.Union:
return all(_is_runnable_type(t) for t in type_.__args__)
else:
return False
return False


@ -195,10 +195,7 @@ def node_data_str(id: str, data: Union[type[BaseModel], RunnableType]) -> str:
if not is_uuid(id):
return id
elif isinstance(data, Runnable):
data_str = data.get_name()
else:
data_str = data.__name__
data_str = data.get_name() if isinstance(data, Runnable) else data.__name__
return data_str if not data_str.startswith("Runnable") else data_str[8:]
@ -449,8 +446,7 @@ class Graph:
label = unique_labels[node_id]
if is_uuid(node_id):
return label
else:
return node_id
return node_id
return Graph(
nodes={


@ -407,9 +407,8 @@ def _render_mermaid_using_api(
Path(output_file_path).write_bytes(response.content)
return img_bytes
else:
msg = (
f"Failed to render the graph using the Mermaid.INK API. "
f"Status code: {response.status_code}."
)
raise ValueError(msg)
msg = (
f"Failed to render the graph using the Mermaid.INK API. "
f"Status code: {response.status_code}."
)
raise ValueError(msg)


@ -398,8 +398,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
@property
@override
def OutputType(self) -> type[Output]:
output_type = self._history_chain.OutputType
return output_type
return self._history_chain.OutputType
def get_output_schema(
self, config: Optional[RunnableConfig] = None
@ -460,10 +459,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
return [HumanMessage(content=input_val)]
# If value is a single message, convert to a list
elif isinstance(input_val, BaseMessage):
if isinstance(input_val, BaseMessage):
return [input_val]
# If value is a list or tuple...
elif isinstance(input_val, (list, tuple)):
if isinstance(input_val, (list, tuple)):
# Handle empty case
if len(input_val) == 0:
return list(input_val)
@ -475,12 +474,11 @@ class RunnableWithMessageHistory(RunnableBindingBase):
raise ValueError(msg)
return input_val[0]
return list(input_val)
else:
msg = (
f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
f"Got {input_val}."
)
raise ValueError(msg) # noqa: TRY004
msg = (
f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
f"Got {input_val}."
)
raise ValueError(msg) # noqa: TRY004
def _get_output_messages(
self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
@ -507,16 +505,15 @@ class RunnableWithMessageHistory(RunnableBindingBase):
return [AIMessage(content=output_val)]
# If value is a single message, convert to a list
elif isinstance(output_val, BaseMessage):
if isinstance(output_val, BaseMessage):
return [output_val]
elif isinstance(output_val, (list, tuple)):
if isinstance(output_val, (list, tuple)):
return list(output_val)
else:
msg = (
f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
f"Got {output_val}."
)
raise ValueError(msg) # noqa: TRY004
msg = (
f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
f"Got {output_val}."
)
raise ValueError(msg) # noqa: TRY004
def _enter_history(self, input: Any, config: RunnableConfig) -> list[BaseMessage]:
hist: BaseChatMessageHistory = config["configurable"]["message_history"]


@ -459,7 +459,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
return create_model_v2( # type: ignore[call-overload]
"RunnableAssignOutput", field_definitions=fields
)
elif not issubclass(map_output_schema, RootModel):
if not issubclass(map_output_schema, RootModel):
# ie. only map output is a dict
# ie. input type is either unknown or inferred incorrectly
return map_output_schema
@ -741,12 +741,10 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
if isinstance(self.keys, str):
return input.get(self.keys)
else:
picked = {k: input.get(k) for k in self.keys if k in input}
if picked:
return AddableDict(picked)
else:
return None
picked = {k: input.get(k) for k in self.keys if k in input}
if picked:
return AddableDict(picked)
return None
def _invoke(
self,


@ -440,11 +440,10 @@ def get_function_nonlocals(func: Callable) -> list[Any]:
for part in kk.split(".")[1:]:
if vv is None:
break
else:
try:
vv = getattr(vv, part)
except AttributeError:
break
try:
vv = getattr(vv, part)
except AttributeError:
break
else:
values.append(vv)
except (SyntaxError, TypeError, OSError, SystemError):


@ -501,8 +501,7 @@ class ChildTool(BaseTool):
if isinstance(self.args_schema, dict):
return super().get_input_schema(config)
return self.args_schema
else:
return create_schema_from_function(self.name, self._run)
return create_schema_from_function(self.name, self._run)
@override
def invoke(
@ -550,58 +549,54 @@ class ChildTool(BaseTool):
else:
input_args.parse_obj({key_: tool_input})
return tool_input
else:
if input_args is not None:
if isinstance(input_args, dict):
return tool_input
elif issubclass(input_args, BaseModel):
for k, v in get_all_basemodel_annotations(input_args).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
):
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.model_validate(tool_input)
result_dict = result.model_dump()
elif issubclass(input_args, BaseModelV1):
for k, v in get_all_basemodel_annotations(input_args).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
):
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.parse_obj(tool_input)
result_dict = result.dict()
else:
msg = (
"args_schema must be a Pydantic BaseModel, "
f"got {self.args_schema}"
)
raise NotImplementedError(msg)
return {
k: getattr(result, k)
for k, v in result_dict.items()
if k in tool_input
}
return tool_input
if input_args is not None:
if isinstance(input_args, dict):
return tool_input
if issubclass(input_args, BaseModel):
for k, v in get_all_basemodel_annotations(input_args).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
):
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.model_validate(tool_input)
result_dict = result.model_dump()
elif issubclass(input_args, BaseModelV1):
for k, v in get_all_basemodel_annotations(input_args).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
):
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.parse_obj(tool_input)
result_dict = result.dict()
else:
msg = (
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
)
raise NotImplementedError(msg)
return {
k: getattr(result, k) for k, v in result_dict.items() if k in tool_input
}
return tool_input
@model_validator(mode="before")
@classmethod
@ -659,17 +654,16 @@ class ChildTool(BaseTool):
# pass as a positional argument.
if isinstance(tool_input, str):
return (tool_input,), {}
elif isinstance(tool_input, dict):
if isinstance(tool_input, dict):
# Make a shallow copy of the input to allow downstream code
# to modify the root level of the input without affecting the
# original input.
# This is used by the tool to inject run time information like
# the callback manager.
return (), tool_input.copy()
else:
# This code path is not expected to be reachable.
msg = f"Invalid tool input type: {type(tool_input)}"
raise TypeError(msg)
# This code path is not expected to be reachable.
msg = f"Invalid tool input type: {type(tool_input)}"
raise TypeError(msg)
def run(
self,
@ -1012,10 +1006,9 @@ def _is_message_content_block(obj: Any) -> bool:
"""Check for OpenAI or Anthropic format tool message content blocks."""
if isinstance(obj, str):
return True
elif isinstance(obj, dict):
if isinstance(obj, dict):
return obj.get("type", None) in ("text", "image_url", "image", "json")
else:
return False
return False
def _stringify(content: Any) -> str:
@ -1153,18 +1146,16 @@ def _replace_type_vars(
if isinstance(type_, TypeVar):
if type_ in generic_map:
return generic_map[type_]
elif default_to_bound:
if default_to_bound:
return type_.__bound__ or Any
else:
return type_
elif (origin := get_origin(type_)) and (args := get_args(type_)):
return type_
if (origin := get_origin(type_)) and (args := get_args(type_)):
new_args = tuple(
_replace_type_vars(arg, generic_map, default_to_bound=default_to_bound)
for arg in args
)
return _py_38_safe_origin(origin)[new_args] # type: ignore[index]
else:
return type_
return type_
class BaseToolkit(BaseModel, ABC):
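The tool-input hunk above (`@ -659,17 +654,16`) shows the companion rewrite: once the first branch returns, the autofix also demotes `elif` to `if` and hoists the final `else` body to function level. The same shape as a hypothetical free function:

def to_args_and_kwargs(tool_input):
    if isinstance(tool_input, str):
        return (tool_input,), {}
    if isinstance(tool_input, dict):
        # Shallow copy so downstream code can add run-time values
        # without mutating the caller's dict.
        return (), tool_input.copy()
    msg = f"Invalid tool input type: {type(tool_input)}"
    raise TypeError(msg)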

View File

@ -310,14 +310,14 @@ def tool(
msg = "Name must be a string for tool constructor"
raise ValueError(msg)
return _create_tool_factory(name_or_callable)(runnable)
elif name_or_callable is not None:
if name_or_callable is not None:
if callable(name_or_callable) and hasattr(name_or_callable, "__name__"):
# Used as a decorator without parameters
# @tool
# def my_tool():
# pass
return _create_tool_factory(name_or_callable.__name__)(name_or_callable)
elif isinstance(name_or_callable, str):
if isinstance(name_or_callable, str):
# Used with a new name for the tool
# @tool("search")
# def my_tool():
@ -329,24 +329,23 @@ def tool(
# def my_tool():
# pass
return _create_tool_factory(name_or_callable)
else:
msg = (
f"The first argument must be a string or a callable with a __name__ "
f"for tool decorator. Got {type(name_or_callable)}"
)
raise ValueError(msg)
else:
# Tool is used as a decorator with parameters specified
# @tool(parse_docstring=True)
# def my_tool():
# pass
def _partial(func: Union[Callable, Runnable]) -> BaseTool:
"""Partial function that takes a callable and returns a tool."""
name_ = func.get_name() if isinstance(func, Runnable) else func.__name__
tool_factory = _create_tool_factory(name_)
return tool_factory(func)
msg = (
f"The first argument must be a string or a callable with a __name__ "
f"for tool decorator. Got {type(name_or_callable)}"
)
raise ValueError(msg)
return _partial
# Tool is used as a decorator with parameters specified
# @tool(parse_docstring=True)
# def my_tool():
# pass
def _partial(func: Union[Callable, Runnable]) -> BaseTool:
"""Partial function that takes a callable and returns a tool."""
name_ = func.get_name() if isinstance(func, Runnable) else func.__name__
tool_factory = _create_tool_factory(name_)
return tool_factory(func)
return _partial
def _get_description_from_runnable(runnable: Runnable) -> str:
@ -408,31 +407,30 @@ def convert_runnable_to_tool(
coroutine=runnable.ainvoke,
description=description,
)
async def ainvoke_wrapper(
callbacks: Optional[Callbacks] = None, **kwargs: Any
) -> Any:
return await runnable.ainvoke(kwargs, config={"callbacks": callbacks})
def invoke_wrapper(callbacks: Optional[Callbacks] = None, **kwargs: Any) -> Any:
return runnable.invoke(kwargs, config={"callbacks": callbacks})
if (
arg_types is None
and schema.get("type") == "object"
and schema.get("properties")
):
args_schema = runnable.input_schema
else:
async def ainvoke_wrapper(
callbacks: Optional[Callbacks] = None, **kwargs: Any
) -> Any:
return await runnable.ainvoke(kwargs, config={"callbacks": callbacks})
def invoke_wrapper(callbacks: Optional[Callbacks] = None, **kwargs: Any) -> Any:
return runnable.invoke(kwargs, config={"callbacks": callbacks})
if (
arg_types is None
and schema.get("type") == "object"
and schema.get("properties")
):
args_schema = runnable.input_schema
else:
args_schema = _get_schema_from_runnable_and_arg_types(
runnable, name, arg_types=arg_types
)
return StructuredTool.from_function(
name=name,
func=invoke_wrapper,
coroutine=ainvoke_wrapper,
description=description,
args_schema=args_schema,
args_schema = _get_schema_from_runnable_and_arg_types(
runnable, name, arg_types=arg_types
)
return StructuredTool.from_function(
name=name,
func=invoke_wrapper,
coroutine=ainvoke_wrapper,
description=description,
args_schema=args_schema,
)
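The flattened dispatch above distinguishes exactly the three call shapes named in its comments. A usage sketch (the tool names and bodies are placeholders; only the decorator shapes come from the hunk):

from langchain_core.tools import tool

@tool
def search_a(query: str) -> str:
    """Decorator used bare, without parameters."""
    return query

@tool("search")
def search_b(query: str) -> str:
    """Decorator used with a new name for the tool."""
    return query

@tool(parse_docstring=True)
def search_c(query: str) -> str:
    """Decorator used with parameters.

    Args:
        query: The search query.
    """
    return query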

View File

@ -183,11 +183,10 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
Returns:
The run.
"""
llm_run = self._llm_run_with_retry_event(
return self._llm_run_with_retry_event(
retry_state=retry_state,
run_id=run_id,
)
return llm_run
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for an LLM run.

View File

@ -335,25 +335,23 @@ class _TracerCore(ABC):
"""Get the inputs for a chain run."""
if self._schema_format in ("original", "original+chat"):
return inputs if isinstance(inputs, dict) else {"input": inputs}
elif self._schema_format == "streaming_events":
if self._schema_format == "streaming_events":
return {
"input": inputs,
}
else:
msg = f"Invalid format: {self._schema_format}"
raise ValueError(msg)
msg = f"Invalid format: {self._schema_format}"
raise ValueError(msg)
def _get_chain_outputs(self, outputs: Any) -> Any:
"""Get the outputs for a chain run."""
if self._schema_format in ("original", "original+chat"):
return outputs if isinstance(outputs, dict) else {"output": outputs}
elif self._schema_format == "streaming_events":
if self._schema_format == "streaming_events":
return {
"output": outputs,
}
else:
msg = f"Invalid format: {self._schema_format}"
raise ValueError(msg)
msg = f"Invalid format: {self._schema_format}"
raise ValueError(msg)
def _complete_chain_run(
self,

View File

@ -78,7 +78,7 @@ def _assign_name(name: Optional[str], serialized: Optional[dict[str, Any]]) -> s
if serialized is not None:
if "name" in serialized:
return serialized["name"]
elif "id" in serialized:
if "id" in serialized:
return serialized["id"][-1]
return "Unnamed"

View File

@ -91,13 +91,12 @@ class FunctionCallbackHandler(BaseTracer):
A string with the breadcrumbs of the run.
"""
parents = self.get_parents(run)[::-1]
string = " > ".join(
return " > ".join(
f"{parent.run_type}:{parent.name}"
if i != len(parents) - 1
else f"{parent.run_type}:{parent.name}"
for i, parent in enumerate(parents + [run])
)
return string
# logging methods
def _on_chain_start(self, run: Run) -> None:

View File

@ -85,7 +85,7 @@ def merge_lists(left: Optional[list], *others: Optional[list]) -> Optional[list]
for other in others:
if other is None:
continue
elif merged is None:
if merged is None:
merged = other.copy()
else:
for e in other:
@ -131,23 +131,22 @@ def merge_obj(left: Any, right: Any) -> Any:
"""
if left is None or right is None:
return left if left is not None else right
elif type(left) is not type(right):
if type(left) is not type(right):
msg = (
f"left and right are of different types. Left type: {type(left)}. Right "
f"type: {type(right)}."
)
raise TypeError(msg)
elif isinstance(left, str):
if isinstance(left, str):
return left + right
elif isinstance(left, dict):
if isinstance(left, dict):
return merge_dicts(left, right)
elif isinstance(left, list):
if isinstance(left, list):
return merge_lists(left, right)
elif left == right:
if left == right:
return left
else:
msg = (
f"Unable to merge {left=} and {right=}. Both must be of type str, dict, or "
f"list, or else be two equal objects."
)
raise ValueError(msg)
msg = (
f"Unable to merge {left=} and {right=}. Both must be of type str, dict, or "
f"list, or else be two equal objects."
)
raise ValueError(msg)
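Flattened, `merge_obj` reads as a chain of guarded returns over the operand type. Expected behavior per the code above (illustrative calls; the import path is an assumption):

from langchain_core.utils._merge import merge_obj  # assumed location

assert merge_obj(None, "x") == "x"          # None yields the other side
assert merge_obj("foo", "bar") == "foobar"  # str branch
assert merge_obj([1], [2]) == [1, 2]        # list branch -> merge_lists
assert merge_obj(3, 3) == 3                 # equal objects
try:
    merge_obj(3, 4)  # same type, unequal, not str/dict/list
except ValueError:
    pass
try:
    merge_obj("x", 1)  # mismatched types
except TypeError:
    pass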

View File

@ -72,12 +72,11 @@ def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
"""
if env_key in os.environ and os.environ[env_key]:
return os.environ[env_key]
elif default is not None:
if default is not None:
return default
else:
msg = (
f"Did not find {key}, please add an environment variable"
f" `{env_key}` which contains it, or pass"
f" `{key}` as a named parameter."
)
raise ValueError(msg)
msg = (
f"Did not find {key}, please add an environment variable"
f" `{env_key}` which contains it, or pass"
f" `{key}` as a named parameter."
)
raise ValueError(msg)

View File

@ -266,9 +266,9 @@ def _convert_any_typed_dicts_to_pydantic(
if type_ in visited:
return visited[type_]
elif depth >= _MAX_TYPED_DICT_RECURSION:
if depth >= _MAX_TYPED_DICT_RECURSION:
return type_
elif is_typeddict(type_):
if is_typeddict(type_):
typed_dict = type_
docstring = inspect.getdoc(typed_dict)
annotations_ = typed_dict.__annotations__
@ -292,7 +292,7 @@ def _convert_any_typed_dicts_to_pydantic(
f"type {type(field_desc)}."
)
raise ValueError(msg)
elif arg_desc := arg_descriptions.get(arg):
if arg_desc := arg_descriptions.get(arg):
field_kwargs["description"] = arg_desc
else:
pass
@ -309,15 +309,14 @@ def _convert_any_typed_dicts_to_pydantic(
model.__doc__ = description
visited[typed_dict] = model
return model
elif (origin := get_origin(type_)) and (type_args := get_args(type_)):
if (origin := get_origin(type_)) and (type_args := get_args(type_)):
subscriptable_origin = _py_38_safe_origin(origin)
type_args = tuple(
_convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited)
for arg in type_args # type: ignore[index]
)
return subscriptable_origin[type_args] # type: ignore[index]
else:
return type_
return type_
def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
@ -337,33 +336,31 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
return _convert_json_schema_to_openai_function(
tool.tool_call_schema, name=tool.name, description=tool.description
)
elif issubclass(tool.tool_call_schema, (BaseModel, BaseModelV1)):
if issubclass(tool.tool_call_schema, (BaseModel, BaseModelV1)):
return _convert_pydantic_to_openai_function(
tool.tool_call_schema, name=tool.name, description=tool.description
)
else:
error_msg = (
f"Unsupported tool call schema: {tool.tool_call_schema}. "
"Tool call schema must be a JSON schema dict or a Pydantic model."
)
raise ValueError(error_msg)
else:
return {
"name": tool.name,
"description": tool.description,
"parameters": {
# This is a hack to get around the fact that some tools
# do not expose an args_schema, and expect an argument
# which is a string.
# And OpenAI does not support an array type for the
# parameters.
"properties": {
"__arg1": {"title": "__arg1", "type": "string"},
},
"required": ["__arg1"],
"type": "object",
error_msg = (
f"Unsupported tool call schema: {tool.tool_call_schema}. "
"Tool call schema must be a JSON schema dict or a Pydantic model."
)
raise ValueError(error_msg)
return {
"name": tool.name,
"description": tool.description,
"parameters": {
# This is a hack to get around the fact that some tools
# do not expose an args_schema, and expect an argument
# which is a string.
# And OpenAI does not support an array type for the
# parameters.
"properties": {
"__arg1": {"title": "__arg1", "type": "string"},
},
}
"required": ["__arg1"],
"type": "object",
},
}
format_tool_to_openai_function = deprecated(
@ -730,7 +727,7 @@ def _parse_google_docstring(
if block.startswith("Args:"):
args_block = block
break
elif block.startswith(("Returns:", "Example:")):
if block.startswith(("Returns:", "Example:")):
# Don't break in case Args come after
past_descriptors = True
elif not past_descriptors:
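The demoted `elif` keeps the scan going: `Returns:` and `Example:` blocks set `past_descriptors` instead of stopping, because an `Args:` block may appear after them. A docstring layout that exercises this path (illustrative):

def my_func(a: int) -> str:
    """Summarize the behavior.

    Returns:
        A short answer.

    Args:
        a: Still parsed even though Returns comes first.
    """
    return str(a)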

View File

@ -26,8 +26,7 @@ def get_color_mapping(
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
return color_mapping
return {item: colors[i % len(colors)] for i, item in enumerate(items)}
def get_colored_text(text: str, color: str) -> str:
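`get_color_mapping` is a RET504 fix (unnecessary assignment before return): the intermediate name added nothing, so the comprehension is returned directly. The shape of the fix on a hypothetical helper:

# Before: RET504
def squares(n: int) -> list[int]:
    result = [i * i for i in range(n)]
    return result

# After
def squares_fixed(n: int) -> list[int]:
    return [i * i for i in range(n)]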

View File

@ -30,15 +30,13 @@ def _custom_parser(multiline_string: str) -> str:
if isinstance(multiline_string, (bytes, bytearray)):
multiline_string = multiline_string.decode()
multiline_string = re.sub(
return re.sub(
r'("action_input"\:\s*")(.*?)(")',
_replace_new_line,
multiline_string,
flags=re.DOTALL,
)
return multiline_string
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/5b6080fae1f8c68938a1e4fa8667e3744084ee21/interpreter/utils/parse_partial_json.py
# MIT License

View File

@ -60,13 +60,12 @@ def _dereference_refs_helper(
else:
obj_out[k] = v
return obj_out
elif isinstance(obj, list):
if isinstance(obj, list):
return [
_dereference_refs_helper(el, full_schema, skip_keys, processed_refs)
for el in obj
]
else:
return obj
return obj
def _infer_skip_keys(

View File

@ -84,8 +84,7 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
# Then the next tag could be a standalone
# Otherwise it can't be
return padding.isspace() or padding == ""
else:
return False
return False
def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
@ -107,8 +106,7 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
return on_newline[0].isspace() or not on_newline[0]
# Otherwise the tag can't be a standalone
else:
return False
return False
def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], str]:

View File

@ -89,7 +89,7 @@ def is_pydantic_v1_subclass(cls: type) -> bool:
"""Check if the installed Pydantic version is 1.x-like."""
if PYDANTIC_MAJOR_VERSION == 1:
return True
elif PYDANTIC_MAJOR_VERSION == 2:
if PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1 import BaseModel as BaseModelV1
if issubclass(cls, BaseModelV1):
@ -335,7 +335,7 @@ def _create_subset_model(
descriptions=descriptions,
fn_description=fn_description,
)
elif PYDANTIC_MAJOR_VERSION == 2:
if PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1 import BaseModel as BaseModelV1
if issubclass(model, BaseModelV1):
@ -346,17 +346,15 @@ def _create_subset_model(
descriptions=descriptions,
fn_description=fn_description,
)
else:
return _create_subset_model_v2(
name,
model,
field_names,
descriptions=descriptions,
fn_description=fn_description,
)
else:
msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
raise NotImplementedError(msg)
return _create_subset_model_v2(
name,
model,
field_names,
descriptions=descriptions,
fn_description=fn_description,
)
msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
raise NotImplementedError(msg)
if PYDANTIC_MAJOR_VERSION == 2:
@ -387,11 +385,10 @@ if PYDANTIC_MAJOR_VERSION == 2:
if hasattr(model, "model_fields"):
return model.model_fields # type: ignore
elif hasattr(model, "__fields__"):
if hasattr(model, "__fields__"):
return model.__fields__ # type: ignore
else:
msg = f"Expected a Pydantic model. Got {type(model)}"
raise TypeError(msg)
msg = f"Expected a Pydantic model. Got {type(model)}"
raise TypeError(msg)
elif PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1_

View File

@ -14,12 +14,11 @@ def stringify_value(val: Any) -> str:
"""
if isinstance(val, str):
return val
elif isinstance(val, dict):
if isinstance(val, dict):
return "\n" + stringify_dict(val)
elif isinstance(val, list):
if isinstance(val, list):
return "\n".join(stringify_value(v) for v in val)
else:
return str(val)
return str(val)
def stringify_dict(data: dict) -> str:
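`stringify_value` now dispatches with early returns. Expected outputs per the branches above (illustrative; `stringify_dict` is the sibling defined next):

stringify_value("plain")   # -> "plain"
stringify_value([1, "x"])  # -> "1\nx" (one line per element)
stringify_value({"a": 1})  # -> "\n" + stringify_dict({"a": 1})
stringify_value(3.5)       # -> "3.5" (fallback: str())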

View File

@ -392,16 +392,14 @@ def from_env(
if isinstance(default, (str, type(None))):
return default
else:
if error_message:
raise ValueError(error_message)
else:
msg = (
f"Did not find {key}, please add an environment variable"
f" `{key}` which contains it, or pass"
f" `{key}` as a named parameter."
)
raise ValueError(msg)
if error_message:
raise ValueError(error_message)
msg = (
f"Did not find {key}, please add an environment variable"
f" `{key}` which contains it, or pass"
f" `{key}` as a named parameter."
)
raise ValueError(msg)
return get_from_env_fn
@ -454,17 +452,15 @@ def secret_from_env(
return SecretStr(os.environ[key])
if isinstance(default, str):
return SecretStr(default)
elif default is None:
if default is None:
return None
else:
if error_message:
raise ValueError(error_message)
else:
msg = (
f"Did not find {key}, please add an environment variable"
f" `{key}` which contains it, or pass"
f" `{key}` as a named parameter."
)
raise ValueError(msg)
if error_message:
raise ValueError(error_message)
msg = (
f"Did not find {key}, please add an environment variable"
f" `{key}` which contains it, or pass"
f" `{key}` as a named parameter."
)
raise ValueError(msg)
return get_secret_from_env

View File

@ -340,20 +340,19 @@ class VectorStore(ABC):
"""
if search_type == "similarity":
return self.similarity_search(query, **kwargs)
elif search_type == "similarity_score_threshold":
if search_type == "similarity_score_threshold":
docs_and_similarities = self.similarity_search_with_relevance_scores(
query, **kwargs
)
return [doc for doc, _ in docs_and_similarities]
elif search_type == "mmr":
if search_type == "mmr":
return self.max_marginal_relevance_search(query, **kwargs)
else:
msg = (
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity', 'similarity_score_threshold'"
" or 'mmr'."
)
raise ValueError(msg)
msg = (
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity', 'similarity_score_threshold'"
" or 'mmr'."
)
raise ValueError(msg)
async def asearch(
self, query: str, search_type: str, **kwargs: Any
@ -375,19 +374,18 @@ class VectorStore(ABC):
"""
if search_type == "similarity":
return await self.asimilarity_search(query, **kwargs)
elif search_type == "similarity_score_threshold":
if search_type == "similarity_score_threshold":
docs_and_similarities = await self.asimilarity_search_with_relevance_scores(
query, **kwargs
)
return [doc for doc, _ in docs_and_similarities]
elif search_type == "mmr":
if search_type == "mmr":
return await self.amax_marginal_relevance_search(query, **kwargs)
else:
msg = (
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity', 'similarity_score_threshold' or 'mmr'."
)
raise ValueError(msg)
msg = (
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity', 'similarity_score_threshold' or 'mmr'."
)
raise ValueError(msg)
@abstractmethod
def similarity_search(
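After the rewrite, `search` is a flat dispatch over the three allowed `search_type` values, with the `ValueError` as the fall-through. Usage sketch (hypothetical `store` instance; extra kwargs are forwarded to the underlying method):

docs = store.search("llm agents", search_type="similarity", k=4)
docs = store.search("llm agents", search_type="mmr")
store.search("llm agents", search_type="keyword")  # raises ValueError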

View File

@ -431,24 +431,22 @@ class InMemoryVectorStore(VectorStore):
**kwargs: Any,
) -> list[tuple[Document, float]]:
embedding = self.embedding.embed_query(query)
docs = self.similarity_search_with_score_by_vector(
return self.similarity_search_with_score_by_vector(
embedding,
k,
**kwargs,
)
return docs
@override
async def asimilarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
) -> list[tuple[Document, float]]:
embedding = await self.embedding.aembed_query(query)
docs = self.similarity_search_with_score_by_vector(
return self.similarity_search_with_score_by_vector(
embedding,
k,
**kwargs,
)
return docs
@override
def similarity_search_by_vector(

View File

@ -103,7 +103,6 @@ ignore = [
"PGH",
"PLR",
"PYI",
"RET",
"RUF",
"SLF",
"TD",

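Deleting `"RET"` from this `ignore` table is what activates the flake8-return rules; every other hunk in the commit is ruff's mechanical output. A minimal standalone sketch of the same setup (hypothetical pyproject, not this repo's full config):

[tool.ruff.lint]
select = ["ALL"]
ignore = []  # "RET" no longer listed, so RET501-RET508 apply
# then run: ruff check --fix .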
View File

@ -17,12 +17,11 @@ EXAMPLES = [
def selector() -> LengthBasedExampleSelector:
"""Get length based selector to use in tests."""
prompts = PromptTemplate(input_variables=["question"], template="{question}")
selector = LengthBasedExampleSelector(
return LengthBasedExampleSelector(
examples=EXAMPLES,
example_prompt=prompts,
max_length=30,
)
return selector
def test_selector_valid(selector: LengthBasedExampleSelector) -> None:

View File

@ -18,9 +18,8 @@ def _fake_runnable(
) -> Union[BaseModel, dict]:
if isclass(schema) and is_basemodel_subclass(schema):
return schema(name="yo", value=value)
else:
params = cast("dict", schema)["parameters"]
return {k: 1 if k != "value" else value for k, v in params.items()}
params = cast("dict", schema)["parameters"]
return {k: 1 if k != "value" else value for k, v in params.items()}
class FakeStructuredChatModel(FakeListChatModel):

View File

@ -219,8 +219,7 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
def conditional_str_parser(input: str) -> Runnable:
if input == "a":
return str_parser
else:
return xml_parser
return xml_parser
sequence: Runnable = (
prompt

View File

@ -2954,11 +2954,10 @@ def test_higher_order_lambda_runnable(
def router(input: dict[str, Any]) -> Runnable:
if input["key"] == "math":
return itemgetter("input") | math_chain
elif input["key"] == "english":
if input["key"] == "english":
return itemgetter("input") | english_chain
else:
msg = f"Unknown key: {input['key']}"
raise ValueError(msg)
msg = f"Unknown key: {input['key']}"
raise ValueError(msg)
chain: Runnable = input_map | router
assert dumps(chain, pretty=True) == snapshot
@ -3011,11 +3010,10 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
def router(input: dict[str, Any]) -> Runnable:
if input["key"] == "math":
return itemgetter("input") | math_chain
elif input["key"] == "english":
if input["key"] == "english":
return itemgetter("input") | english_chain
else:
msg = f"Unknown key: {input['key']}"
raise ValueError(msg)
msg = f"Unknown key: {input['key']}"
raise ValueError(msg)
chain: Runnable = input_map | router
@ -3034,11 +3032,10 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
async def arouter(input: dict[str, Any]) -> Runnable:
if input["key"] == "math":
return itemgetter("input") | math_chain
elif input["key"] == "english":
if input["key"] == "english":
return itemgetter("input") | english_chain
else:
msg = f"Unknown key: {input['key']}"
raise ValueError(msg)
msg = f"Unknown key: {input['key']}"
raise ValueError(msg)
achain: Runnable = input_map | arouter
math_spy = mocker.spy(math_chain.__class__, "ainvoke")
@ -3858,8 +3855,7 @@ def test_recursive_lambda() -> None:
def _simple_recursion(x: int) -> Union[int, Runnable]:
if x < 10:
return RunnableLambda(lambda *args: _simple_recursion(x + 1))
else:
return x
return x
runnable = RunnableLambda(_simple_recursion)
assert runnable.invoke(5) == 10
@ -3873,11 +3869,10 @@ def test_retrying(mocker: MockerFixture) -> None:
if x == 1:
msg = "x is 1"
raise ValueError(msg)
elif x == 2:
if x == 2:
msg = "x is 2"
raise RuntimeError(msg)
else:
return x
return x
_lambda_mock = mocker.Mock(side_effect=_lambda)
runnable = RunnableLambda(_lambda_mock)
@ -3938,11 +3933,10 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
if x == 1:
msg = "x is 1"
raise ValueError(msg)
elif x == 2:
if x == 2:
msg = "x is 2"
raise RuntimeError(msg)
else:
return x
return x
_lambda_mock = mocker.Mock(side_effect=_lambda)
runnable = RunnableLambda(_lambda_mock)

View File

@ -545,8 +545,7 @@ async def test_astream_events_from_model() -> None:
def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
if sys.version_info >= (3, 11):
return model.invoke(input)
else:
return model.invoke(input, config)
return model.invoke(input, config)
events = await _collect_events(i_dont_stream.astream_events("hello", version="v1"))
_assert_events_equal_allow_superset_metadata(
@ -670,8 +669,7 @@ async def test_astream_events_from_model() -> None:
async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
if sys.version_info >= (3, 11):
return await model.ainvoke(input)
else:
return await model.ainvoke(input, config)
return await model.ainvoke(input, config)
events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1"))
_assert_events_equal_allow_superset_metadata(

View File

@ -615,8 +615,7 @@ async def test_astream_with_model_in_chain() -> None:
def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
if sys.version_info >= (3, 11):
return model.invoke(input)
else:
return model.invoke(input, config)
return model.invoke(input, config)
events = await _collect_events(i_dont_stream.astream_events("hello", version="v2"))
_assert_events_equal_allow_superset_metadata(
@ -724,8 +723,7 @@ async def test_astream_with_model_in_chain() -> None:
async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
if sys.version_info >= (3, 11):
return await model.ainvoke(input)
else:
return await model.ainvoke(input, config)
return await model.ainvoke(input, config)
events = await _collect_events(ai_dont_stream.astream_events("hello", version="v2"))
_assert_events_equal_allow_superset_metadata(

View File

@ -334,23 +334,22 @@ class TestRunnableSequenceParallelTraceNesting:
parent_id_map[n] = matching_post.get("parent_run_id")
i += len(name)
continue
else:
assert posts[i]["name"] == name
dotted_order = posts[i]["dotted_order"]
if prev_dotted_order is not None and not str(
expected_parents[name]
).startswith("RunnableParallel"):
assert dotted_order > prev_dotted_order, (
f"{name} not after {name_order[i - 1]}"
)
prev_dotted_order = dotted_order
if name in dotted_order_map:
msg = f"Duplicate name {name}"
raise ValueError(msg)
dotted_order_map[name] = dotted_order
id_map[name] = posts[i]["id"]
parent_id_map[name] = posts[i].get("parent_run_id")
i += 1
assert posts[i]["name"] == name
dotted_order = posts[i]["dotted_order"]
if prev_dotted_order is not None and not str(
expected_parents[name]
).startswith("RunnableParallel"):
assert dotted_order > prev_dotted_order, (
f"{name} not after {name_order[i - 1]}"
)
prev_dotted_order = dotted_order
if name in dotted_order_map:
msg = f"Duplicate name {name}"
raise ValueError(msg)
dotted_order_map[name] = dotted_order
id_map[name] = posts[i]["id"]
parent_id_map[name] = posts[i].get("parent_run_id")
i += 1
# Now check the dotted orders
for name, parent_ in expected_parents.items():
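This test hunk is the `continue` variant (RET507, superfluous else after continue): the branch ends in `continue`, so the rest of the loop body no longer needs an `else:` wrapper and dedents one level. The pattern in miniature (hypothetical loop):

def drop_nones(items: list) -> list:
    out = []
    for item in items:
        if item is None:
            continue
        out.append(item)  # was nested under `else:` before the fix
    return out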

View File

@ -80,8 +80,7 @@ def _get_tool_call_json_schema(tool: BaseTool) -> dict:
if hasattr(tool_schema, "model_json_schema"):
return tool_schema.model_json_schema()
else:
return tool_schema.schema()
return tool_schema.schema()
def test_unnamed_decorator() -> None:

View File

@ -599,8 +599,7 @@ def test_tracer_nested_runs_on_error() -> None:
def _get_mock_client() -> Client:
mock_session = MagicMock()
client = Client(session=mock_session, api_key="test")
return client
return Client(session=mock_session, api_key="test")
def test_traceable_to_tracing() -> None: