mirror of https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00

core: Add ruff rules RET (#29384)

See https://docs.astral.sh/ruff/rules/#flake8-return-ret. All changes are auto-fixes.

This commit is contained in:
parent 9ae792f56c
commit f241fd5c11
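For context, the flake8-return (RET) rules flag redundant control flow around return statements. Below is a minimal sketch of the two rewrites that account for nearly every hunk in this commit, using hypothetical functions rather than code from the diff: RET505 removes an else branch that directly follows a return and dedents its body, and RET504 removes an assignment whose only use is the return on the next line.

# Hypothetical examples, not taken from the diff: the "before" forms
# trigger RET505 and RET504; the "after" forms are what ruff's auto-fix
# (ruff check --fix) produces.

def sign_before(n: int) -> str:
    if n < 0:
        return "negative"
    else:  # RET505: unnecessary `else` after `return`
        return "non-negative"

def sign_after(n: int) -> str:
    if n < 0:
        return "negative"
    return "non-negative"

def total_before(xs: list[int]) -> int:
    total = sum(xs)  # RET504: assignment used only by the return below
    return total

def total_after(xs: list[int]) -> int:
    return sum(xs)

Enabling the family is a matter of adding "RET" to the selected rule set in the project's ruff configuration; with ruff check --fix, both rewrites above are applied automatically, which is why every change in this diff is mechanical.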
@@ -466,7 +466,6 @@ def warn_deprecated(
                 f"{removal}"
             )
             raise NotImplementedError(msg)
-        else:
-            removal = f"in {removal}"
+        removal = f"in {removal}"

     if not message:
@@ -185,7 +185,6 @@ def _convert_agent_action_to_messages(
     """
     if isinstance(agent_action, AgentActionMessageLog):
         return agent_action.message_log
-    else:
-        return [AIMessage(content=agent_action.log)]
+    return [AIMessage(content=agent_action.log)]


@@ -205,7 +204,6 @@ def _convert_agent_observation_to_messages(
     """
     if isinstance(agent_action, AgentActionMessageLog):
         return [_create_function_message(agent_action, observation)]
-    else:
-        content = observation
-        if not isinstance(observation, str):
-            try:
+    content = observation
+    if not isinstance(observation, str):
+        try:
@@ -59,9 +59,8 @@ def _key_from_id(id_: str) -> str:
     wout_prefix = id_.split(CONTEXT_CONFIG_PREFIX, maxsplit=1)[1]
     if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_GET):
         return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_GET)]
-    elif wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET):
+    if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET):
         return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_SET)]
-    else:
-        msg = f"Invalid context config id {id_}"
-        raise ValueError(msg)
+    msg = f"Invalid context config id {id_}"
+    raise ValueError(msg)

@@ -197,7 +196,6 @@ class ContextGet(RunnableSerializable):
         configurable = config.get("configurable", {})
         if isinstance(self.key, list):
             return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
-        else:
-            return configurable[self.ids[0]]()
+        return configurable[self.ids[0]]()

     @override
@@ -209,7 +207,6 @@ class ContextGet(RunnableSerializable):
         if isinstance(self.key, list):
             values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
             return dict(zip(self.key, values))
-        else:
-            return await configurable[self.ids[0]]()
+        return await configurable[self.ids[0]]()


@@ -447,5 +444,4 @@ class PrefixContext:
 def _print_keys(keys: Union[str, Sequence[str]]) -> str:
     if isinstance(keys, str):
         return f"'{keys}'"
-    else:
-        return ", ".join(f"'{k}'" for k in keys)
+    return ", ".join(f"'{k}'" for k in keys)
@@ -128,7 +128,6 @@ class LangSmithLoader(BaseLoader):
 def _stringify(x: Union[str, dict]) -> str:
     if isinstance(x, str):
         return x
-    else:
-        try:
-            return json.dumps(x, indent=2)
-        except Exception:
+    try:
+        return json.dumps(x, indent=2)
+    except Exception:
@@ -54,7 +54,6 @@ class BaseMedia(Serializable):
         """
         if id_value is not None:
             return str(id_value)
-        else:
-            return id_value
+        return id_value


@@ -159,11 +158,10 @@ class Blob(BaseMedia):
         """Read data as a string."""
         if self.data is None and self.path:
             return Path(self.path).read_text(encoding=self.encoding)
-        elif isinstance(self.data, bytes):
+        if isinstance(self.data, bytes):
             return self.data.decode(self.encoding)
-        elif isinstance(self.data, str):
+        if isinstance(self.data, str):
             return self.data
-        else:
-            msg = f"Unable to get string for blob {self}"
-            raise ValueError(msg)
+        msg = f"Unable to get string for blob {self}"
+        raise ValueError(msg)

@@ -171,11 +169,10 @@ class Blob(BaseMedia):
         """Read data as bytes."""
         if isinstance(self.data, bytes):
             return self.data
-        elif isinstance(self.data, str):
+        if isinstance(self.data, str):
             return self.data.encode(self.encoding)
-        elif self.data is None and self.path:
+        if self.data is None and self.path:
             return Path(self.path).read_bytes()
-        else:
-            msg = f"Unable to get bytes for blob {self}"
-            raise ValueError(msg)
+        msg = f"Unable to get bytes for blob {self}"
+        raise ValueError(msg)

@@ -316,5 +313,4 @@ class Document(BaseMedia):
         # a more general solution of formatting content directly inside the prompts.
         if self.metadata:
             return f"page_content='{self.page_content}' metadata={self.metadata}"
-        else:
-            return f"page_content='{self.page_content}'"
+        return f"page_content='{self.page_content}'"
@@ -79,7 +79,6 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
             new_length = remaining_length - self.example_text_lengths[i]
             if new_length < 0:
                 break
-            else:
-                examples.append(self.examples[i])
-                remaining_length = new_length
-                i += 1
+            examples.append(self.examples[i])
+            remaining_length = new_length
+            i += 1
@@ -54,7 +54,6 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
     ) -> str:
         if input_keys:
             return " ".join(sorted_values({key: example[key] for key in input_keys}))
-        else:
-            return " ".join(sorted_values(example))
+        return " ".join(sorted_values(example))

     def _documents_to_examples(self, documents: list[Document]) -> list[dict]:
@@ -152,11 +152,10 @@ def _get_source_id_assigner(
     """Get the source id from the document."""
     if source_id_key is None:
         return lambda doc: None
-    elif isinstance(source_id_key, str):
+    if isinstance(source_id_key, str):
         return lambda doc: doc.metadata[source_id_key]
-    elif callable(source_id_key):
+    if callable(source_id_key):
         return source_id_key
-    else:
-        msg = (
-            f"source_id_key should be either None, a string or a callable. "
-            f"Got {source_id_key} of type {type(source_id_key)}."
+    msg = (
+        f"source_id_key should be either None, a string or a callable. "
+        f"Got {source_id_key} of type {type(source_id_key)}."
@@ -143,7 +143,6 @@ class BaseLanguageModel(
         """
         if verbose is None:
             return _get_verbosity()
-        else:
-            return verbose
+        return verbose

     @property
@@ -351,7 +350,6 @@ class BaseLanguageModel(
         """
         if self.custom_get_token_ids is not None:
             return self.custom_get_token_ids(text)
-        else:
-            return _get_token_ids_default_method(text)
+        return _get_token_ids_default_method(text)

     def get_num_tokens(self, text: str) -> int:
@@ -284,11 +284,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     def _convert_input(self, input: LanguageModelInput) -> PromptValue:
         if isinstance(input, PromptValue):
             return input
-        elif isinstance(input, str):
+        if isinstance(input, str):
             return StringPromptValue(text=input)
-        elif isinstance(input, Sequence):
+        if isinstance(input, Sequence):
             return ChatPromptValue(messages=convert_to_messages(input))
-        else:
-            msg = (
-                f"Invalid input type {type(input)}. "
-                "Must be a PromptValue, str, or list of BaseMessages."
+        msg = (
+            f"Invalid input type {type(input)}. "
+            "Must be a PromptValue, str, or list of BaseMessages."
@@ -610,7 +609,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             _cleanup_llm_representation(serialized_repr, 1)
             llm_string = json.dumps(serialized_repr, sort_keys=True)
             return llm_string + "---" + param_string
-        else:
-            params = self._get_invocation_params(stop=stop, **kwargs)
-            params = {**params, **kwargs}
-            return str(sorted(params.items()))
+        params = self._get_invocation_params(stop=stop, **kwargs)
+        params = {**params, **kwargs}
+        return str(sorted(params.items()))
@@ -1107,7 +1105,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         ).generations[0][0]
         if isinstance(generation, ChatGeneration):
             return generation.message
-        else:
-            msg = "Unexpected generation type"
-            raise ValueError(msg)  # noqa: TRY004
+        msg = "Unexpected generation type"
+        raise ValueError(msg)  # noqa: TRY004

@@ -1124,7 +1121,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         generation = result.generations[0][0]
         if isinstance(generation, ChatGeneration):
             return generation.message
-        else:
-            msg = "Unexpected generation type"
-            raise ValueError(msg)  # noqa: TRY004
+        msg = "Unexpected generation type"
+        raise ValueError(msg)  # noqa: TRY004

@@ -1167,7 +1163,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
         if isinstance(result.content, str):
             return result.content
-        else:
-            msg = "Cannot use predict when output is not a string."
-            raise ValueError(msg)  # noqa: TRY004
+        msg = "Cannot use predict when output is not a string."
+        raise ValueError(msg)  # noqa: TRY004

@@ -1194,7 +1189,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         )
         if isinstance(result.content, str):
             return result.content
-        else:
-            msg = "Cannot use predict when output is not a string."
-            raise ValueError(msg)  # noqa: TRY004
+        msg = "Cannot use predict when output is not a string."
+        raise ValueError(msg)  # noqa: TRY004

@@ -1391,7 +1385,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 [parser_none], exception_key="parsing_error"
             )
             return RunnableMap(raw=llm) | parser_with_fallback
-        else:
-            return llm | output_parser
+        return llm | output_parser

@@ -251,8 +251,7 @@ def update_cache(
         prompt = prompts[missing_prompt_idxs[i]]
         if llm_cache is not None:
             llm_cache.update(prompt, llm_string, result)
-    llm_output = new_results.llm_output
-    return llm_output
+    return new_results.llm_output


 async def aupdate_cache(
@@ -285,8 +284,7 @@ async def aupdate_cache(
         prompt = prompts[missing_prompt_idxs[i]]
         if llm_cache:
             await llm_cache.aupdate(prompt, llm_string, result)
-    llm_output = new_results.llm_output
-    return llm_output
+    return new_results.llm_output


 class BaseLLM(BaseLanguageModel[str], ABC):
@@ -330,11 +328,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     def _convert_input(self, input: LanguageModelInput) -> PromptValue:
         if isinstance(input, PromptValue):
             return input
-        elif isinstance(input, str):
+        if isinstance(input, str):
             return StringPromptValue(text=input)
-        elif isinstance(input, Sequence):
+        if isinstance(input, Sequence):
             return ChatPromptValue(messages=convert_to_messages(input))
-        else:
-            msg = (
-                f"Invalid input type {type(input)}. "
-                "Must be a PromptValue, str, or list of BaseMessages."
+        msg = (
+            f"Invalid input type {type(input)}. "
+            "Must be a PromptValue, str, or list of BaseMessages."
@@ -452,7 +449,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             except Exception as e:
                 if return_exceptions:
                     return cast("list[str]", [e for _ in inputs])
-                else:
-                    raise
+                raise
         else:
             batches = [
@@ -499,7 +495,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             except Exception as e:
                 if return_exceptions:
                     return cast("list[str]", [e for _ in inputs])
-                else:
-                    raise
+                raise
         else:
             batches = [
@@ -973,14 +968,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                     callback_managers, prompts, run_name_list, run_ids_list
                 )
             ]
-            output = self._generate_helper(
+            return self._generate_helper(
                 prompts,
                 stop,
                 run_managers,
                 new_arg_supported=bool(new_arg_supported),
                 **kwargs,
             )
-            return output
         if len(missing_prompts) > 0:
             run_managers = [
                 callback_managers[idx].on_llm_start(
@@ -1232,14 +1226,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 ]
             )
             run_managers = [r[0] for r in run_managers]  # type: ignore[misc]
-            output = await self._agenerate_helper(
+            return await self._agenerate_helper(
                 prompts,
                 stop,
                 run_managers,  # type: ignore[arg-type]
                 new_arg_supported=bool(new_arg_supported),
                 **kwargs,  # type: ignore[arg-type]
             )
-            return output
         if len(missing_prompts) > 0:
             run_managers = await asyncio.gather(
                 *[
@@ -19,7 +19,6 @@ def default(obj: Any) -> Any:
     """
     if isinstance(obj, Serializable):
         return obj.to_json()
-    else:
-        return to_json_not_implemented(obj)
+    return to_json_not_implemented(obj)


@@ -36,7 +35,6 @@ def _dump_pydantic_models(obj: Any) -> Any:
         obj_copy = obj.model_copy(deep=True)
         obj_copy.message.additional_kwargs["parsed"] = parsed.model_dump()
         return obj_copy
-    else:
-        return obj
+    return obj


@@ -64,13 +62,11 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
         if pretty:
             indent = kwargs.pop("indent", 2)
             return json.dumps(obj, default=default, indent=indent, **kwargs)
-        else:
-            return json.dumps(obj, default=default, **kwargs)
+        return json.dumps(obj, default=default, **kwargs)
     except TypeError:
         if pretty:
             indent = kwargs.pop("indent", 2)
             return json.dumps(to_json_not_implemented(obj), indent=indent, **kwargs)
-        else:
-            return json.dumps(to_json_not_implemented(obj), **kwargs)
+        return json.dumps(to_json_not_implemented(obj), **kwargs)


@@ -98,7 +98,6 @@ class Reviver:
             [key] = value["id"]
             if key in self.secrets_map:
                 return self.secrets_map[key]
-            else:
-                if self.secrets_from_env and key in os.environ and os.environ[key]:
-                    return os.environ[key]
-                return None
+            if self.secrets_from_env and key in os.environ and os.environ[key]:
+                return os.environ[key]
+            return None
@@ -130,7 +129,7 @@ class Reviver:
                 msg = f"Invalid namespace: {value}"
                 raise ValueError(msg)
             # Has explicit import path.
-            elif mapping_key in self.import_mappings:
+            if mapping_key in self.import_mappings:
                 import_path = self.import_mappings[mapping_key]
                 # Split into module and name
                 import_dir, name = import_path[:-1], import_path[-1]
@@ -372,7 +372,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
     def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
         if isinstance(other, AIMessageChunk):
             return add_ai_message_chunks(self, other)
-        elif isinstance(other, (list, tuple)) and all(
+        if isinstance(other, (list, tuple)) and all(
             isinstance(o, AIMessageChunk) for o in other
         ):
             return add_ai_message_chunks(self, *other)
@@ -65,7 +65,6 @@ class BaseMessage(Serializable):
         """Coerce the id field to a string."""
         if id_value is not None:
             return str(id_value)
-        else:
-            return id_value
+        return id_value

     def __init__(
@@ -225,7 +224,7 @@ class BaseMessageChunk(BaseMessage):
                     self.response_metadata, other.response_metadata
                 ),
             )
-        elif isinstance(other, list) and all(
+        if isinstance(other, list) and all(
             isinstance(o, BaseMessageChunk) for o in other
         ):
             content = merge_content(self.content, *(o.content for o in other))
@@ -241,7 +240,6 @@ class BaseMessageChunk(BaseMessage):
                 additional_kwargs=additional_kwargs,
                 response_metadata=response_metadata,
             )
-        else:
-            msg = (
-                'unsupported operand type(s) for +: "'
-                f"{self.__class__.__name__}"
+        msg = (
+            'unsupported operand type(s) for +: "'
+            f"{self.__class__.__name__}"
@@ -53,7 +53,7 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
                 ),
                 id=self.id,
             )
-        elif isinstance(other, BaseMessageChunk):
+        if isinstance(other, BaseMessageChunk):
             return self.__class__(
                 role=self.role,
                 content=merge_content(self.content, other.content),
@@ -65,5 +65,4 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
                 ),
                 id=self.id,
             )
-        else:
-            return super().__add__(other)
+        return super().__add__(other)
@@ -320,7 +320,6 @@ def default_tool_parser(
     for raw_tool_call in raw_tool_calls:
         if "function" not in raw_tool_call:
             continue
-        else:
-            function_name = raw_tool_call["function"]["name"]
-            try:
-                function_args = json.loads(raw_tool_call["function"]["arguments"])
+        function_name = raw_tool_call["function"]["name"]
+        try:
+            function_args = json.loads(raw_tool_call["function"]["arguments"])
@@ -51,9 +51,8 @@ def _get_type(v: Any) -> str:
     """Get the type associated with the object for serialization purposes."""
     if isinstance(v, dict) and "type" in v:
         return v["type"]
-    elif hasattr(v, "type"):
+    if hasattr(v, "type"):
         return v.type
-    else:
-        msg = (
-            f"Expected either a dictionary with a 'type' key or an object "
-            f"with a 'type' attribute. Instead got type {type(v)}."
+    msg = (
+        f"Expected either a dictionary with a 'type' key or an object "
+        f"with a 'type' attribute. Instead got type {type(v)}."
@@ -138,31 +137,30 @@ def _message_from_dict(message: dict) -> BaseMessage:
     _type = message["type"]
     if _type == "human":
         return HumanMessage(**message["data"])
-    elif _type == "ai":
+    if _type == "ai":
         return AIMessage(**message["data"])
-    elif _type == "system":
+    if _type == "system":
         return SystemMessage(**message["data"])
-    elif _type == "chat":
+    if _type == "chat":
         return ChatMessage(**message["data"])
-    elif _type == "function":
+    if _type == "function":
        return FunctionMessage(**message["data"])
-    elif _type == "tool":
+    if _type == "tool":
         return ToolMessage(**message["data"])
-    elif _type == "remove":
+    if _type == "remove":
         return RemoveMessage(**message["data"])
-    elif _type == "AIMessageChunk":
+    if _type == "AIMessageChunk":
         return AIMessageChunk(**message["data"])
-    elif _type == "HumanMessageChunk":
+    if _type == "HumanMessageChunk":
         return HumanMessageChunk(**message["data"])
-    elif _type == "FunctionMessageChunk":
+    if _type == "FunctionMessageChunk":
         return FunctionMessageChunk(**message["data"])
-    elif _type == "ToolMessageChunk":
+    if _type == "ToolMessageChunk":
         return ToolMessageChunk(**message["data"])
-    elif _type == "SystemMessageChunk":
+    if _type == "SystemMessageChunk":
         return SystemMessageChunk(**message["data"])
-    elif _type == "ChatMessageChunk":
+    if _type == "ChatMessageChunk":
         return ChatMessageChunk(**message["data"])
-    else:
-        msg = f"Got unexpected message type: {_type}"
-        raise ValueError(msg)
+    msg = f"Got unexpected message type: {_type}"
+    raise ValueError(msg)

@@ -387,7 +385,6 @@ def _runnable_support(func: Callable) -> Callable:

         if messages is not None:
             return func(messages, **kwargs)
-        else:
-            return RunnableLambda(partial(func, **kwargs), name=func.__name__)
+        return RunnableLambda(partial(func, **kwargs), name=func.__name__)

     wrapped.__doc__ = func.__doc__
@@ -472,8 +469,6 @@ def filter_messages(
             or (exclude_ids and msg.id in exclude_ids)
         ):
             continue
-        else:
-            pass

         if exclude_tool_calls is True and (
             (isinstance(msg, AIMessage) and msg.tool_calls)
@@ -926,7 +921,7 @@ def trim_messages(
             partial_strategy="first" if allow_partial else None,
             end_on=end_on,
         )
-    elif strategy == "last":
+    if strategy == "last":
         return _last_max_tokens(
             messages,
             max_tokens=max_tokens,
@@ -937,7 +932,6 @@ def trim_messages(
             end_on=end_on,
             text_splitter=text_splitter_fn,
         )
-    else:
-        msg = f"Unrecognized {strategy=}. Supported strategies are 'last' and 'first'."
-        raise ValueError(msg)
+    msg = f"Unrecognized {strategy=}. Supported strategies are 'last' and 'first'."
+    raise ValueError(msg)

@@ -1269,7 +1263,6 @@ def convert_to_openai_messages(

     if is_single:
         return oai_messages[0]
-    else:
-        return oai_messages
+    return oai_messages


@@ -1347,7 +1340,7 @@ def _first_max_tokens(
                 if isinstance(block, str):
                     text = block
                     break
-                elif isinstance(block, dict) and block.get("type") == "text":
+                if isinstance(block, dict) and block.get("type") == "text":
                     text = block.get("text")
                     break

@@ -1517,17 +1510,16 @@ def _bytes_to_b64_str(bytes_: bytes) -> str:
 def _get_message_openai_role(message: BaseMessage) -> str:
     if isinstance(message, AIMessage):
         return "assistant"
-    elif isinstance(message, HumanMessage):
+    if isinstance(message, HumanMessage):
         return "user"
-    elif isinstance(message, ToolMessage):
+    if isinstance(message, ToolMessage):
         return "tool"
-    elif isinstance(message, SystemMessage):
+    if isinstance(message, SystemMessage):
         return message.additional_kwargs.get("__openai_role__", "system")
-    elif isinstance(message, FunctionMessage):
+    if isinstance(message, FunctionMessage):
         return "function"
-    elif isinstance(message, ChatMessage):
+    if isinstance(message, ChatMessage):
         return message.role
-    else:
-        msg = f"Unknown BaseMessage type {message.__class__}."
-        raise ValueError(msg)  # noqa: TRY004
+    msg = f"Unknown BaseMessage type {message.__class__}."
+    raise ValueError(msg)  # noqa: TRY004

@@ -97,7 +97,6 @@ class BaseGenerationOutputParser(
                 config,
                 run_type="parser",
             )
-        else:
-            return self._call_with_config(
-                lambda inner_input: self.parse_result([Generation(text=inner_input)]),
-                input,
+        return self._call_with_config(
+            lambda inner_input: self.parse_result([Generation(text=inner_input)]),
+            input,
@@ -121,7 +120,6 @@ class BaseGenerationOutputParser(
                 config,
                 run_type="parser",
             )
-        else:
-            return await self._acall_with_config(
-                lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
-                input,
+        return await self._acall_with_config(
+            lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
+            input,
@@ -203,7 +201,6 @@ class BaseOutputParser(
                 config,
                 run_type="parser",
             )
-        else:
-            return self._call_with_config(
-                lambda inner_input: self.parse_result([Generation(text=inner_input)]),
-                input,
+        return self._call_with_config(
+            lambda inner_input: self.parse_result([Generation(text=inner_input)]),
+            input,
@@ -227,7 +224,6 @@ class BaseOutputParser(
                 config,
                 run_type="parser",
             )
-        else:
-            return await self._acall_with_config(
-                lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
-                input,
+        return await self._acall_with_config(
+            lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
+            input,
@@ -53,8 +53,9 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
     def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
         if issubclass(pydantic_object, pydantic.BaseModel):
             return pydantic_object.model_json_schema()
-        elif issubclass(pydantic_object, pydantic.v1.BaseModel):
+        if issubclass(pydantic_object, pydantic.v1.BaseModel):
             return pydantic_object.schema()
+        return None

     def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
         """Parse the result of an LLM call to a JSON object.
@@ -106,7 +107,6 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
         """
         if self.pydantic_object is None:
             return "Return a JSON object."
-        else:
-            # Copy schema to avoid altering original Pydantic schema.
-            schema = dict(self._get_schema(self.pydantic_object).items())
+        # Copy schema to avoid altering original Pydantic schema.
+        schema = dict(self._get_schema(self.pydantic_object).items())

@@ -99,7 +99,6 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
             except KeyError as exc:
                 if partial:
                     return None
-                else:
-                    msg = f"Could not parse function call: {exc}"
-                    raise OutputParserException(msg) from exc
+                msg = f"Could not parse function call: {exc}"
+                raise OutputParserException(msg) from exc
             try:
@@ -109,7 +108,6 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
                     return parse_partial_json(
                         function_call["arguments"], strict=self.strict
                     )
-                else:
-                    return {
-                        **function_call,
-                        "arguments": parse_partial_json(
+                return {
+                    **function_call,
+                    "arguments": parse_partial_json(
@@ -241,9 +241,8 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
             )
             if self.return_id:
                 return single_result
-            elif single_result:
+            if single_result:
                 return single_result["args"]
-            else:
-                return None
+            return None
         parsed_result = [res for res in parsed_result if res["type"] == self.key_name]
         if not self.return_id:
@@ -300,5 +299,4 @@ class PydanticToolsParser(JsonOutputToolsParser):
                 raise
         if self.first_tool_only:
             return pydantic_objects[0] if pydantic_objects else None
-        else:
-            return pydantic_objects
+        return pydantic_objects
@@ -28,9 +28,8 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
         try:
             if issubclass(self.pydantic_object, pydantic.BaseModel):
                 return self.pydantic_object.model_validate(obj)
-            elif issubclass(self.pydantic_object, pydantic.v1.BaseModel):
+            if issubclass(self.pydantic_object, pydantic.v1.BaseModel):
                 return self.pydantic_object.parse_obj(obj)
-            else:
-                msg = f"Unsupported model version for PydanticOutputParser: \
-                    {self.pydantic_object.__class__}"
-                raise OutputParserException(msg)
+            msg = f"Unsupported model version for PydanticOutputParser: \
+                {self.pydantic_object.__class__}"
+            raise OutputParserException(msg)
@@ -282,5 +282,4 @@ def nested_element(path: list[str], elem: ET.Element) -> Any:
     """
     if len(path) == 0:
         return AddableDict({elem.tag: elem.text})
-    else:
-        return AddableDict({path[0]: [nested_element(path[1:], elem)]})
+    return AddableDict({path[0]: [nested_element(path[1:], elem)]})
@@ -60,13 +60,11 @@ class ChatGeneration(Generation):
                     if isinstance(block, str):
                         text = block
                         break
-                    elif isinstance(block, dict) and "text" in block:
+                    if isinstance(block, dict) and "text" in block:
                         text = block["text"]
                         break
                 else:
                     pass
-            else:
-                pass
             self.text = text
         except (KeyError, AttributeError) as e:
             msg = "Error while initializing ChatGeneration"
@@ -104,7 +102,7 @@ class ChatGenerationChunk(ChatGeneration):
                 message=self.message + other.message,
                 generation_info=generation_info or None,
             )
-        elif isinstance(other, list) and all(
+        if isinstance(other, list) and all(
             isinstance(x, ChatGenerationChunk) for x in other
         ):
             generation_info = merge_dicts(
@@ -115,8 +113,5 @@ class ChatGenerationChunk(ChatGeneration):
                 message=self.message + [chunk.message for chunk in other],
                 generation_info=generation_info or None,
             )
-        else:
-            msg = (
-                f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
-            )
-            raise TypeError(msg)
+        msg = f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
+        raise TypeError(msg)
@@ -64,8 +64,5 @@ class GenerationChunk(Generation):
                 text=self.text + other.text,
                 generation_info=generation_info or None,
             )
-        else:
-            msg = (
-                f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
-            )
-            raise TypeError(msg)
+        msg = f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
+        raise TypeError(msg)
@@ -513,7 +513,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
                 partial_variables=partial_variables,
             )
             return cls(prompt=prompt, **kwargs)
-        elif isinstance(template, list):
+        if isinstance(template, list):
             if (partial_variables is not None) and len(partial_variables) > 0:
                 msg = "Partial variables are not supported for list of templates."
                 raise ValueError(msg)
@@ -571,7 +571,6 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
                     msg = f"Invalid template: {tmpl}"
                     raise ValueError(msg)
             return cls(prompt=prompt, **kwargs)
-        else:
-            msg = f"Invalid template: {template}"
-            raise ValueError(msg)  # noqa: TRY004
+        msg = f"Invalid template: {template}"
+        raise ValueError(msg)  # noqa: TRY004

@@ -625,8 +624,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
             List of input variable names.
         """
         prompts = self.prompt if isinstance(self.prompt, list) else [self.prompt]
-        input_variables = [iv for prompt in prompts for iv in prompt.input_variables]
-        return input_variables
+        return [iv for prompt in prompts for iv in prompt.input_variables]

     def format(self, **kwargs: Any) -> BaseMessage:
         """Format the prompt template.
@@ -642,7 +640,6 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
             return self._msg_class(
                 content=text, additional_kwargs=self.additional_kwargs
             )
-        else:
-            content: list = []
-            for prompt in self.prompt:
-                inputs = {var: kwargs[var] for var in prompt.input_variables}
+        content: list = []
+        for prompt in self.prompt:
+            inputs = {var: kwargs[var] for var in prompt.input_variables}
@@ -670,7 +667,6 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
             return self._msg_class(
                 content=text, additional_kwargs=self.additional_kwargs
             )
-        else:
-            content: list = []
-            for prompt in self.prompt:
-                inputs = {var: kwargs[var] for var in prompt.input_variables}
+        content: list = []
+        for prompt in self.prompt:
+            inputs = {var: kwargs[var] for var in prompt.input_variables}
@@ -1034,23 +1030,22 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
             return ChatPromptTemplate(messages=self.messages + other.messages).partial(
                 **partials
             )  # type: ignore[call-arg]
-        elif isinstance(
+        if isinstance(
             other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
         ):
             return ChatPromptTemplate(messages=self.messages + [other]).partial(
                 **partials
             )  # type: ignore[call-arg]
-        elif isinstance(other, (list, tuple)):
+        if isinstance(other, (list, tuple)):
             _other = ChatPromptTemplate.from_messages(other)
             return ChatPromptTemplate(messages=self.messages + _other.messages).partial(
                 **partials
             )  # type: ignore[call-arg]
-        elif isinstance(other, str):
+        if isinstance(other, str):
             prompt = HumanMessagePromptTemplate.from_template(other)
             return ChatPromptTemplate(messages=self.messages + [prompt]).partial(
                 **partials
             )  # type: ignore[call-arg]
-        else:
-            msg = f"Unsupported operand type for +: {type(other)}"
-            raise NotImplementedError(msg)
+        msg = f"Unsupported operand type for +: {type(other)}"
+        raise NotImplementedError(msg)

@@ -1322,7 +1317,6 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
             start, stop, step = index.indices(len(self.messages))
             messages = self.messages[start:stop:step]
             return ChatPromptTemplate.from_messages(messages)
-        else:
-            return self.messages[index]
+        return self.messages[index]

     def __len__(self) -> int:
@@ -88,9 +88,8 @@ class _FewShotPromptTemplateMixin(BaseModel):
         """
         if self.examples is not None:
             return self.examples
-        elif self.example_selector is not None:
+        if self.example_selector is not None:
             return self.example_selector.select_examples(kwargs)
-        else:
-            msg = "One of 'examples' and 'example_selector' should be provided"
-            raise ValueError(msg)
+        msg = "One of 'examples' and 'example_selector' should be provided"
+        raise ValueError(msg)

@@ -108,9 +107,8 @@ class _FewShotPromptTemplateMixin(BaseModel):
         """
         if self.examples is not None:
             return self.examples
-        elif self.example_selector is not None:
+        if self.example_selector is not None:
             return await self.example_selector.aselect_examples(kwargs)
-        else:
-            msg = "One of 'examples' and 'example_selector' should be provided"
-            raise ValueError(msg)
+        msg = "One of 'examples' and 'example_selector' should be provided"
+        raise ValueError(msg)

@@ -394,12 +392,11 @@ class FewShotChatMessagePromptTemplate(
             {k: e[k] for k in self.example_prompt.input_variables} for e in examples
         ]
         # Format the examples.
-        messages = [
+        return [
             message
             for example in examples
             for message in self.example_prompt.format_messages(**example)
         ]
-        return messages

     async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
         """Async format kwargs into a list of messages.
@@ -416,12 +413,11 @@ class FewShotChatMessagePromptTemplate(
             {k: e[k] for k in self.example_prompt.input_variables} for e in examples
         ]
         # Format the examples.
-        messages = [
+        return [
             message
             for example in examples
             for message in await self.example_prompt.aformat_messages(**example)
         ]
-        return messages

     def format(self, **kwargs: Any) -> str:
         """Format the prompt with inputs generating a string.
@@ -97,17 +97,15 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
     def _get_examples(self, **kwargs: Any) -> list[dict]:
         if self.examples is not None:
             return self.examples
-        elif self.example_selector is not None:
+        if self.example_selector is not None:
             return self.example_selector.select_examples(kwargs)
-        else:
-            raise ValueError
+        raise ValueError

     async def _aget_examples(self, **kwargs: Any) -> list[dict]:
         if self.examples is not None:
             return self.examples
-        elif self.example_selector is not None:
+        if self.example_selector is not None:
             return await self.example_selector.aselect_examples(kwargs)
-        else:
-            raise ValueError
+        raise ValueError

     def format(self, **kwargs: Any) -> str:
@@ -110,10 +110,9 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
         if not url:
             msg = "Must provide url."
             raise ValueError(msg)
-        elif not isinstance(url, str):
+        if not isinstance(url, str):
             msg = "url must be a string."
-            raise ValueError(msg)
-        else:
-            output: ImageURL = {"url": url}
-            if detail:
-                # Don't check literal values here: let the API check them
+            raise ValueError(msg)  # noqa: TRY004
+        output: ImageURL = {"url": url}
+        if detail:
+            # Don't check literal values here: let the API check them
@@ -154,7 +154,6 @@ class PromptTemplate(StringPromptTemplate):
                 if k in partial_variables:
                     msg = "Cannot have same variable partialed twice."
                     raise ValueError(msg)
-                else:
-                    partial_variables[k] = v
+                partial_variables[k] = v
             return PromptTemplate(
                 template=template,
@@ -163,10 +162,9 @@ class PromptTemplate(StringPromptTemplate):
                 template_format="f-string",
                 validate_template=validate_template,
             )
-        elif isinstance(other, str):
+        if isinstance(other, str):
             prompt = PromptTemplate.from_template(other)
             return self + prompt
-        else:
-            msg = f"Unsupported operand type for +: {type(other)}"
-            raise NotImplementedError(msg)
+        msg = f"Unsupported operand type for +: {type(other)}"
+        raise NotImplementedError(msg)

@@ -100,8 +100,7 @@ def _get_jinja2_variables_from_template(template: str) -> set[str]:
     # noqa for insecure warning elsewhere
     env = Environment()  # noqa: S701
     ast = env.parse(template)
-    variables = meta.find_undeclared_variables(ast)
-    return variables
+    return meta.find_undeclared_variables(ast)


 def mustache_formatter(template: str, /, **kwargs: Any) -> str:
@@ -166,6 +166,5 @@ class StructuredPrompt(ChatPromptTemplate):
                 *others[1:],
                 name=name,
             )
-        else:
-            msg = "Structured prompts need to be piped to a language model."
-            raise NotImplementedError(msg)
+        msg = "Structured prompts need to be piped to a language model."
+        raise NotImplementedError(msg)
@@ -208,8 +208,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
             default_retriever_name = default_retriever_name[:-9]
         default_retriever_name = default_retriever_name.lower()

-        ls_params = LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)
-        return ls_params
+        return LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)

     def invoke(
         self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
@ -269,9 +269,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
if suffix:
|
if suffix:
|
||||||
if name_[0].isupper():
|
if name_[0].isupper():
|
||||||
return name_ + suffix.title()
|
return name_ + suffix.title()
|
||||||
else:
|
|
||||||
return name_ + "_" + suffix.lower()
|
return name_ + "_" + suffix.lower()
|
||||||
else:
|
|
||||||
return name_
|
return name_
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -513,10 +511,9 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
if field_name in [i for i in include if i != "configurable"]
|
if field_name in [i for i in include if i != "configurable"]
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
model = create_model_v2( # type: ignore[call-overload]
|
return create_model_v2( # type: ignore[call-overload]
|
||||||
self.get_name("Config"), field_definitions=all_fields
|
self.get_name("Config"), field_definitions=all_fields
|
||||||
)
|
)
|
||||||
return model
|
|
||||||
|
|
||||||
def get_config_jsonschema(
|
def get_config_jsonschema(
|
||||||
self, *, include: Optional[Sequence[str]] = None
|
self, *, include: Optional[Sequence[str]] = None
|
||||||
@ -2051,7 +2048,6 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
return cast("list[Output]", [e for _ in input])
|
return cast("list[Output]", [e for _ in input])
|
||||||
else:
|
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
first_exception: Optional[Exception] = None
|
first_exception: Optional[Exception] = None
|
||||||
@ -2063,7 +2059,6 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
run_manager.on_chain_end(out)
|
run_manager.on_chain_end(out)
|
||||||
if return_exceptions or first_exception is None:
|
if return_exceptions or first_exception is None:
|
||||||
return cast("list[Output]", output)
|
return cast("list[Output]", output)
|
||||||
else:
|
|
||||||
raise first_exception
|
raise first_exception
|
||||||
|
|
||||||
async def _abatch_with_config(
|
async def _abatch_with_config(
|
||||||
@ -2130,7 +2125,6 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
)
|
)
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
return cast("list[Output]", [e for _ in input])
|
return cast("list[Output]", [e for _ in input])
|
||||||
else:
|
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
first_exception: Optional[Exception] = None
|
first_exception: Optional[Exception] = None
|
||||||
@ -2144,7 +2138,6 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
await asyncio.gather(*coros)
|
await asyncio.gather(*coros)
|
||||||
if return_exceptions or first_exception is None:
|
if return_exceptions or first_exception is None:
|
||||||
return cast("list[Output]", output)
|
return cast("list[Output]", output)
|
||||||
else:
|
|
||||||
raise first_exception
|
raise first_exception
|
||||||
|
|
||||||
def _transform_stream_with_config(
|
def _transform_stream_with_config(
|
||||||
@@ -2615,7 +2608,7 @@ def _seq_input_schema(
     first = steps[0]
     if len(steps) == 1:
         return first.get_input_schema(config)
-    elif isinstance(first, RunnableAssign):
+    if isinstance(first, RunnableAssign):
         next_input_schema = _seq_input_schema(steps[1:], config)
         if not issubclass(next_input_schema, RootModel):
             # it's a dict as expected

@@ -2641,7 +2634,7 @@ def _seq_output_schema(
     last = steps[-1]
     if len(steps) == 1:
         return last.get_input_schema(config)
-    elif isinstance(last, RunnableAssign):
+    if isinstance(last, RunnableAssign):
         mapper_output_schema = last.mapper.get_output_schema(config)
         prev_output_schema = _seq_output_schema(steps[:-1], config)
         if not issubclass(prev_output_schema, RootModel):

@@ -2672,7 +2665,6 @@ def _seq_output_schema(
                         if k in last.keys
                     },
                 )
-            else:
-                field = prev_output_schema.model_fields[last.keys]
-                return create_model_v2(  # type: ignore[call-overload]
-                    "RunnableSequenceOutput", root=(field.annotation, field.default)
+            field = prev_output_schema.model_fields[last.keys]
+            return create_model_v2(  # type: ignore[call-overload]
+                "RunnableSequenceOutput", root=(field.annotation, field.default)

@@ -2988,7 +2980,6 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
                 other.last,
                 name=self.name or other.name,
             )
-        else:
-            return RunnableSequence(
-                self.first,
-                *self.middle,
+        return RunnableSequence(
+            self.first,
+            *self.middle,

@@ -3017,7 +3008,6 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
                 self.last,
                 name=other.name or self.name,
             )
-        else:
-            return RunnableSequence(
-                coerce_to_runnable(other),
-                self.first,
+        return RunnableSequence(
+            coerce_to_runnable(other),
+            self.first,

@@ -3224,7 +3214,6 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
                     rm.on_chain_error(e)
                 if return_exceptions:
                     return cast("list[Output]", [e for _ in inputs])
-                else:
-                    raise
+                raise
             else:
                 first_exception: Optional[Exception] = None

@@ -3236,7 +3225,6 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
                     run_manager.on_chain_end(out)
             if return_exceptions or first_exception is None:
                 return cast("list[Output]", inputs)
-            else:
-                raise first_exception
+            raise first_exception

     @override

@@ -3357,7 +3345,6 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
                 await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
                 if return_exceptions:
                     return cast("list[Output]", [e for _ in inputs])
-                else:
-                    raise
+                raise
             else:
                 first_exception: Optional[Exception] = None

@@ -3371,7 +3358,6 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
                 await asyncio.gather(*coros)
             if return_exceptions or first_exception is None:
                 return cast("list[Output]", inputs)
-            else:
-                raise first_exception
+            raise first_exception

     def _transform(
@@ -3826,7 +3812,6 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
                 return await asyncio.create_task(  # type: ignore
                     step.ainvoke(input, child_config), context=context
                 )
-            else:
-                return await asyncio.create_task(step.ainvoke(input, child_config))
+            return await asyncio.create_task(step.ainvoke(input, child_config))

         # gather results from all steps

@@ -4141,9 +4126,8 @@ class RunnableGenerator(Runnable[Input, Output]):
             first_param = next(iter(params.values()), None)
             if first_param and first_param.annotation != inspect.Parameter.empty:
                 return getattr(first_param.annotation, "__args__", (Any,))[0]
-            else:
-                return Any
         except ValueError:
-            return Any
+            pass
+        return Any

     @override

@@ -4220,11 +4204,9 @@ class RunnableGenerator(Runnable[Input, Output]):
         if isinstance(other, RunnableGenerator):
             if hasattr(self, "_transform") and hasattr(other, "_transform"):
                 return self._transform == other._transform
-            elif hasattr(self, "_atransform") and hasattr(other, "_atransform"):
+            if hasattr(self, "_atransform") and hasattr(other, "_atransform"):
                 return self._atransform == other._atransform
-            else:
-                return False
-        else:
-            return False
+            return False
+        return False

     @override

@@ -4443,9 +4425,8 @@ class RunnableLambda(Runnable[Input, Output]):
             first_param = next(iter(params.values()), None)
             if first_param and first_param.annotation != inspect.Parameter.empty:
                 return first_param.annotation
-            else:
-                return Any
         except ValueError:
-            return Any
+            pass
+        return Any

     @override

@@ -4472,7 +4453,6 @@ class RunnableLambda(Runnable[Input, Output]):
             fields = {item[1:-1]: (Any, ...) for item in items}
             # It's a dict, lol
             return create_model_v2(self.get_name("Input"), field_definitions=fields)
-        else:
-            module = getattr(func, "__module__", None)
-            return create_model_v2(
-                self.get_name("Input"),
+        module = getattr(func, "__module__", None)
+        return create_model_v2(
+            self.get_name("Input"),

@@ -4513,9 +4493,8 @@ class RunnableLambda(Runnable[Input, Output]):
                 ):
                     return getattr(sig.return_annotation, "__args__", (Any,))[0]
                 return sig.return_annotation
-            else:
-                return Any
         except ValueError:
-            return Any
+            pass
+        return Any

     @override

@@ -4607,11 +4586,9 @@ class RunnableLambda(Runnable[Input, Output]):
         if isinstance(other, RunnableLambda):
             if hasattr(self, "func") and hasattr(other, "func"):
                 return self.func == other.func
-            elif hasattr(self, "afunc") and hasattr(other, "afunc"):
+            if hasattr(self, "afunc") and hasattr(other, "afunc"):
                 return self.afunc == other.afunc
-            else:
-                return False
-        else:
-            return False
+            return False
+        return False

     def __repr__(self) -> str:

@@ -4806,11 +4783,7 @@ class RunnableLambda(Runnable[Input, Output]):
                 self._config(config, self.func),
                 **kwargs,
             )
-        else:
-            msg = (
-                "Cannot invoke a coroutine function synchronously."
-                "Use `ainvoke` instead."
-            )
+        msg = "Cannot invoke a coroutine function synchronously.Use `ainvoke` instead."
         raise TypeError(msg)

     @override

@@ -5886,7 +5859,7 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
             )

             return wrapper
-        elif config_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
+        if config_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
             idx = list(inspect.signature(attr).parameters).index("config")

             @wraps(attr)

@@ -5895,12 +5868,9 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
                 argsl = list(args)
                 argsl[idx] = merge_configs(self.config, argsl[idx])
                 return attr(*argsl, **kwargs)
-            else:
-                return attr(
-                    *args,
-                    config=merge_configs(
-                        self.config, kwargs.pop("config", None)
-                    ),
-                    **kwargs,
-                )
+            return attr(
+                *args,
+                config=merge_configs(self.config, kwargs.pop("config", None)),
+                **kwargs,
+            )

@@ -5957,13 +5927,12 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
     """
     if isinstance(thing, Runnable):
         return thing
-    elif is_async_generator(thing) or inspect.isgeneratorfunction(thing):
+    if is_async_generator(thing) or inspect.isgeneratorfunction(thing):
        return RunnableGenerator(thing)
-    elif callable(thing):
+    if callable(thing):
        return RunnableLambda(cast("Callable[[Input], Output]", thing))
-    elif isinstance(thing, dict):
+    if isinstance(thing, dict):
        return cast("Runnable[Input, Output]", RunnableParallel(thing))
-    else:
-        msg = (
-            f"Expected a Runnable, callable or dict."
-            f"Instead got an unsupported type: {type(thing)}"
+    msg = (
+        f"Expected a Runnable, callable or dict."
+        f"Instead got an unsupported type: {type(thing)}"
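
The `coerce_to_runnable` hunk above shows what RET505 does to a whole `elif` ladder: because every branch returns, each `elif` becomes an independent `if` guard, and the final `else` that raised becomes an unconditional tail. Control flow is unchanged; an input that passes all the guards still reaches the raise. A schematic reduction, not the real function:

    def coerce(thing):
        if isinstance(thing, str):
            return [thing]
        if callable(thing):  # was: elif callable(thing):
            return thing
        # was nested under an else: branch
        raise TypeError(f"unsupported type: {type(thing)}")
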
@@ -314,7 +314,6 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):

             return wrapper

-        else:
-            return attr
+        return attr

@@ -462,7 +461,6 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
                 self.default.__class__(**{**init_params, **configurable}),
                 config,
             )
-        else:
-            return (self.default, config)
+        return (self.default, config)

@@ -638,13 +636,11 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
         # return the chosen alternative
         if which == self.default_key:
             return (self.default, config)
-        elif which in self.alternatives:
+        if which in self.alternatives:
             alt = self.alternatives[which]
             if isinstance(alt, Runnable):
                 return (alt, config)
-            else:
-                return (alt(), config)
-        else:
-            msg = f"Unknown alternative: {which}"
-            raise ValueError(msg)
+            return (alt(), config)
+        msg = f"Unknown alternative: {which}"
+        raise ValueError(msg)

@@ -714,7 +710,6 @@ def make_options_spec(
             default=spec.default,
             is_shared=spec.is_shared,
         )
-    else:
-        return ConfigurableFieldSpec(
-            id=spec.id,
-            name=spec.name,
+    return ConfigurableFieldSpec(
+        id=spec.id,
+        name=spec.name,

@@ -661,7 +661,6 @@ def _is_runnable_type(type_: Any) -> bool:
     origin = getattr(type_, "__origin__", None)
     if inspect.isclass(origin):
         return issubclass(origin, Runnable)
-    elif origin is typing.Union:
+    if origin is typing.Union:
         return all(_is_runnable_type(t) for t in type_.__args__)
-    else:
-        return False
+    return False

@@ -195,10 +195,7 @@ def node_data_str(id: str, data: Union[type[BaseModel], RunnableType]) -> str:

     if not is_uuid(id):
         return id
-    elif isinstance(data, Runnable):
-        data_str = data.get_name()
-    else:
-        data_str = data.__name__
+    data_str = data.get_name() if isinstance(data, Runnable) else data.__name__
     return data_str if not data_str.startswith("Runnable") else data_str[8:]

@@ -449,7 +446,6 @@ class Graph:
             label = unique_labels[node_id]
             if is_uuid(node_id):
                 return label
-            else:
-                return node_id
+            return node_id

         return Graph(

@@ -407,7 +407,6 @@ def _render_mermaid_using_api(
             Path(output_file_path).write_bytes(response.content)

         return img_bytes
-    else:
-        msg = (
-            f"Failed to render the graph using the Mermaid.INK API. "
-            f"Status code: {response.status_code}."
+    msg = (
+        f"Failed to render the graph using the Mermaid.INK API. "
+        f"Status code: {response.status_code}."

@@ -398,8 +398,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
     @property
     @override
     def OutputType(self) -> type[Output]:
-        output_type = self._history_chain.OutputType
-        return output_type
+        return self._history_chain.OutputType

     def get_output_schema(
         self, config: Optional[RunnableConfig] = None

@@ -460,10 +459,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):

             return [HumanMessage(content=input_val)]
         # If value is a single message, convert to a list
-        elif isinstance(input_val, BaseMessage):
+        if isinstance(input_val, BaseMessage):
             return [input_val]
         # If value is a list or tuple...
-        elif isinstance(input_val, (list, tuple)):
+        if isinstance(input_val, (list, tuple)):
             # Handle empty case
             if len(input_val) == 0:
                 return list(input_val)

@@ -475,7 +474,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):
                     raise ValueError(msg)
                 return input_val[0]
             return list(input_val)
-        else:
-            msg = (
-                f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
-                f"Got {input_val}."
+        msg = (
+            f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
+            f"Got {input_val}."

@@ -507,11 +505,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):

             return [AIMessage(content=output_val)]
         # If value is a single message, convert to a list
-        elif isinstance(output_val, BaseMessage):
+        if isinstance(output_val, BaseMessage):
             return [output_val]
-        elif isinstance(output_val, (list, tuple)):
+        if isinstance(output_val, (list, tuple)):
             return list(output_val)
-        else:
-            msg = (
-                f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
-                f"Got {output_val}."
+        msg = (
+            f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
+            f"Got {output_val}."

@@ -459,7 +459,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
             return create_model_v2(  # type: ignore[call-overload]
                 "RunnableAssignOutput", field_definitions=fields
             )
-        elif not issubclass(map_output_schema, RootModel):
+        if not issubclass(map_output_schema, RootModel):
             # ie. only map output is a dict
             # ie. input type is either unknown or inferred incorrectly
             return map_output_schema

@@ -741,11 +741,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):

         if isinstance(self.keys, str):
             return input.get(self.keys)
-        else:
-            picked = {k: input.get(k) for k in self.keys if k in input}
-            if picked:
-                return AddableDict(picked)
-            else:
-                return None
+        picked = {k: input.get(k) for k in self.keys if k in input}
+        if picked:
+            return AddableDict(picked)
+        return None

     def _invoke(
@@ -440,7 +440,6 @@ def get_function_nonlocals(func: Callable) -> list[Any]:
                 for part in kk.split(".")[1:]:
                     if vv is None:
                         break
-                    else:
-                        try:
-                            vv = getattr(vv, part)
-                        except AttributeError:
+                    try:
+                        vv = getattr(vv, part)
+                    except AttributeError:

@@ -501,7 +501,6 @@ class ChildTool(BaseTool):
             if isinstance(self.args_schema, dict):
                 return super().get_input_schema(config)
             return self.args_schema
-        else:
-            return create_schema_from_function(self.name, self._run)
+        return create_schema_from_function(self.name, self._run)

     @override

@@ -550,11 +549,10 @@ class ChildTool(BaseTool):
             else:
                 input_args.parse_obj({key_: tool_input})
             return tool_input
-        else:
-            if input_args is not None:
-                if isinstance(input_args, dict):
-                    return tool_input
-                elif issubclass(input_args, BaseModel):
+        if input_args is not None:
+            if isinstance(input_args, dict):
+                return tool_input
+            if issubclass(input_args, BaseModel):
                 for k, v in get_all_basemodel_annotations(input_args).items():
                     if (
                         _is_injected_arg_type(v, injected_type=InjectedToolCallId)

@@ -592,14 +590,11 @@ class ChildTool(BaseTool):
                 result_dict = result.dict()
             else:
                 msg = (
-                    "args_schema must be a Pydantic BaseModel, "
-                    f"got {self.args_schema}"
+                    f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
                 )
                 raise NotImplementedError(msg)
             return {
-                k: getattr(result, k)
-                for k, v in result_dict.items()
-                if k in tool_input
+                k: getattr(result, k) for k, v in result_dict.items() if k in tool_input
             }
         return tool_input

@@ -659,14 +654,13 @@ class ChildTool(BaseTool):
         # pass as a positional argument.
         if isinstance(tool_input, str):
             return (tool_input,), {}
-        elif isinstance(tool_input, dict):
+        if isinstance(tool_input, dict):
             # Make a shallow copy of the input to allow downstream code
             # to modify the root level of the input without affecting the
             # original input.
             # This is used by the tool to inject run time information like
             # the callback manager.
             return (), tool_input.copy()
-        else:
-            # This code path is not expected to be reachable.
-            msg = f"Invalid tool input type: {type(tool_input)}"
-            raise TypeError(msg)
+        # This code path is not expected to be reachable.
+        msg = f"Invalid tool input type: {type(tool_input)}"
+        raise TypeError(msg)

@@ -1012,9 +1006,8 @@ def _is_message_content_block(obj: Any) -> bool:
     """Check for OpenAI or Anthropic format tool message content blocks."""
     if isinstance(obj, str):
         return True
-    elif isinstance(obj, dict):
+    if isinstance(obj, dict):
         return obj.get("type", None) in ("text", "image_url", "image", "json")
-    else:
-        return False
+    return False

@@ -1153,17 +1146,15 @@ def _replace_type_vars(
     if isinstance(type_, TypeVar):
         if type_ in generic_map:
             return generic_map[type_]
-        elif default_to_bound:
+        if default_to_bound:
             return type_.__bound__ or Any
-        else:
-            return type_
-    elif (origin := get_origin(type_)) and (args := get_args(type_)):
+        return type_
+    if (origin := get_origin(type_)) and (args := get_args(type_)):
         new_args = tuple(
             _replace_type_vars(arg, generic_map, default_to_bound=default_to_bound)
             for arg in args
         )
         return _py_38_safe_origin(origin)[new_args]  # type: ignore[index]
-    else:
-        return type_
+    return type_

@@ -310,14 +310,14 @@ def tool(
             msg = "Name must be a string for tool constructor"
             raise ValueError(msg)
         return _create_tool_factory(name_or_callable)(runnable)
-    elif name_or_callable is not None:
+    if name_or_callable is not None:
         if callable(name_or_callable) and hasattr(name_or_callable, "__name__"):
             # Used as a decorator without parameters
             # @tool
             # def my_tool():
             #    pass
             return _create_tool_factory(name_or_callable.__name__)(name_or_callable)
-        elif isinstance(name_or_callable, str):
+        if isinstance(name_or_callable, str):
             # Used with a new name for the tool
             # @tool("search")
             # def my_tool():

@@ -329,13 +329,12 @@ def tool(
             # def my_tool():
             #    pass
             return _create_tool_factory(name_or_callable)
-        else:
-            msg = (
-                f"The first argument must be a string or a callable with a __name__ "
-                f"for tool decorator. Got {type(name_or_callable)}"
-            )
-            raise ValueError(msg)
-    else:
-        # Tool is used as a decorator with parameters specified
-        # @tool(parse_docstring=True)
-        # def my_tool():
+        msg = (
+            f"The first argument must be a string or a callable with a __name__ "
+            f"for tool decorator. Got {type(name_or_callable)}"
+        )
+        raise ValueError(msg)
+    # Tool is used as a decorator with parameters specified
+    # @tool(parse_docstring=True)
+    # def my_tool():

@@ -408,7 +407,6 @@ def convert_runnable_to_tool(
             coroutine=runnable.ainvoke,
             description=description,
         )
-    else:

-        async def ainvoke_wrapper(
-            callbacks: Optional[Callbacks] = None, **kwargs: Any
+    async def ainvoke_wrapper(
+        callbacks: Optional[Callbacks] = None, **kwargs: Any
@@ -183,11 +183,10 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
        Returns:
            The run.
        """
-        llm_run = self._llm_run_with_retry_event(
+        return self._llm_run_with_retry_event(
            retry_state=retry_state,
            run_id=run_id,
        )
-        return llm_run

    def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
        """End a trace for an LLM run.
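
This tracer hunk is pure RET504: `llm_run` was bound only to be returned on the next line, so the call is returned directly. A hedged, generic sketch of the shape (the names below are illustrative, not from this diff):

    # Before: the temporary binding adds a name but no information (RET504).
    def fetch_run(run_id: str, runs: dict) -> dict:
        run = runs[run_id]
        return run

    # After the auto-fix:
    def fetch_run(run_id: str, runs: dict) -> dict:
        return runs[run_id]
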
@@ -335,11 +335,10 @@ class _TracerCore(ABC):
         """Get the inputs for a chain run."""
         if self._schema_format in ("original", "original+chat"):
             return inputs if isinstance(inputs, dict) else {"input": inputs}
-        elif self._schema_format == "streaming_events":
+        if self._schema_format == "streaming_events":
             return {
                 "input": inputs,
             }
-        else:
-            msg = f"Invalid format: {self._schema_format}"
-            raise ValueError(msg)
+        msg = f"Invalid format: {self._schema_format}"
+        raise ValueError(msg)

@@ -347,11 +346,10 @@ class _TracerCore(ABC):
         """Get the outputs for a chain run."""
         if self._schema_format in ("original", "original+chat"):
             return outputs if isinstance(outputs, dict) else {"output": outputs}
-        elif self._schema_format == "streaming_events":
+        if self._schema_format == "streaming_events":
             return {
                 "output": outputs,
             }
-        else:
-            msg = f"Invalid format: {self._schema_format}"
-            raise ValueError(msg)
+        msg = f"Invalid format: {self._schema_format}"
+        raise ValueError(msg)

@@ -78,7 +78,7 @@ def _assign_name(name: Optional[str], serialized: Optional[dict[str, Any]]) -> str:
     if serialized is not None:
         if "name" in serialized:
             return serialized["name"]
-        elif "id" in serialized:
+        if "id" in serialized:
             return serialized["id"][-1]
     return "Unnamed"

@@ -91,13 +91,12 @@ class FunctionCallbackHandler(BaseTracer):
            A string with the breadcrumbs of the run.
        """
        parents = self.get_parents(run)[::-1]
-        string = " > ".join(
+        return " > ".join(
            f"{parent.run_type}:{parent.name}"
            if i != len(parents) - 1
            else f"{parent.run_type}:{parent.name}"
            for i, parent in enumerate(parents + [run])
        )
-        return string

    # logging methods
    def _on_chain_start(self, run: Run) -> None:

@@ -85,7 +85,7 @@ def merge_lists(left: Optional[list], *others: Optional[list]) -> Optional[list]:
     for other in others:
         if other is None:
             continue
-        elif merged is None:
+        if merged is None:
             merged = other.copy()
         else:
             for e in other:

@@ -131,21 +131,20 @@ def merge_obj(left: Any, right: Any) -> Any:
     """
     if left is None or right is None:
         return left if left is not None else right
-    elif type(left) is not type(right):
+    if type(left) is not type(right):
         msg = (
             f"left and right are of different types. Left type: {type(left)}. Right "
             f"type: {type(right)}."
         )
         raise TypeError(msg)
-    elif isinstance(left, str):
+    if isinstance(left, str):
         return left + right
-    elif isinstance(left, dict):
+    if isinstance(left, dict):
         return merge_dicts(left, right)
-    elif isinstance(left, list):
+    if isinstance(left, list):
         return merge_lists(left, right)
-    elif left == right:
+    if left == right:
         return left
-    else:
-        msg = (
-            f"Unable to merge {left=} and {right=}. Both must be of type str, dict, or "
-            f"list, or else be two equal objects."
+    msg = (
+        f"Unable to merge {left=} and {right=}. Both must be of type str, dict, or "
+        f"list, or else be two equal objects."
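
The `merge_lists` and `merge_obj` hunks exercise two siblings of RET505: RET507 drops an `elif` that follows `continue`, and RET506 drops an `elif`/`else` that follows `raise`. In both cases the earlier branch has already left the block, so the extra keyword is dead weight. A hedged sketch of both, with illustrative functions:

    from typing import Optional

    # RET507: after `continue`, the next guard needs no elif.
    def first_positive(items: list) -> Optional[int]:
        for x in items:
            if x is None:
                continue
            if x > 0:  # was: elif x > 0:
                return x
        return None

    # RET506: after `raise`, the next branch needs no elif/else.
    def checked(n: int) -> int:
        if n < 0:
            raise ValueError("n must be non-negative")
        return n  # was nested under an else:
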
@@ -72,9 +72,8 @@ def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
     """
     if env_key in os.environ and os.environ[env_key]:
         return os.environ[env_key]
-    elif default is not None:
+    if default is not None:
         return default
-    else:
-        msg = (
-            f"Did not find {key}, please add an environment variable"
-            f" `{env_key}` which contains it, or pass"
+    msg = (
+        f"Did not find {key}, please add an environment variable"
+        f" `{env_key}` which contains it, or pass"

@@ -266,9 +266,9 @@ def _convert_any_typed_dicts_to_pydantic(

     if type_ in visited:
         return visited[type_]
-    elif depth >= _MAX_TYPED_DICT_RECURSION:
+    if depth >= _MAX_TYPED_DICT_RECURSION:
         return type_
-    elif is_typeddict(type_):
+    if is_typeddict(type_):
         typed_dict = type_
         docstring = inspect.getdoc(typed_dict)
         annotations_ = typed_dict.__annotations__

@@ -292,7 +292,7 @@ def _convert_any_typed_dicts_to_pydantic(
                         f"type {type(field_desc)}."
                     )
                     raise ValueError(msg)
-                elif arg_desc := arg_descriptions.get(arg):
+                if arg_desc := arg_descriptions.get(arg):
                     field_kwargs["description"] = arg_desc
                 else:
                     pass

@@ -309,14 +309,13 @@ def _convert_any_typed_dicts_to_pydantic(
             model.__doc__ = description
         visited[typed_dict] = model
         return model
-    elif (origin := get_origin(type_)) and (type_args := get_args(type_)):
+    if (origin := get_origin(type_)) and (type_args := get_args(type_)):
         subscriptable_origin = _py_38_safe_origin(origin)
         type_args = tuple(
             _convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited)
             for arg in type_args  # type: ignore[index]
         )
         return subscriptable_origin[type_args]  # type: ignore[index]
-    else:
-        return type_
+    return type_

@@ -337,17 +336,15 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
            return _convert_json_schema_to_openai_function(
                tool.tool_call_schema, name=tool.name, description=tool.description
            )
-        elif issubclass(tool.tool_call_schema, (BaseModel, BaseModelV1)):
+        if issubclass(tool.tool_call_schema, (BaseModel, BaseModelV1)):
            return _convert_pydantic_to_openai_function(
                tool.tool_call_schema, name=tool.name, description=tool.description
            )
-        else:
-            error_msg = (
-                f"Unsupported tool call schema: {tool.tool_call_schema}. "
-                "Tool call schema must be a JSON schema dict or a Pydantic model."
-            )
-            raise ValueError(error_msg)
-    else:
-        return {
-            "name": tool.name,
-            "description": tool.description,
+        error_msg = (
+            f"Unsupported tool call schema: {tool.tool_call_schema}. "
+            "Tool call schema must be a JSON schema dict or a Pydantic model."
+        )
+        raise ValueError(error_msg)
+    return {
+        "name": tool.name,
+        "description": tool.description,

@@ -730,7 +727,7 @@ def _parse_google_docstring(
             if block.startswith("Args:"):
                 args_block = block
                 break
-            elif block.startswith(("Returns:", "Example:")):
+            if block.startswith(("Returns:", "Example:")):
                 # Don't break in case Args come after
                 past_descriptors = True
             elif not past_descriptors:
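
The docstring-parser hunk above is RET508: an `elif` that follows `break` can become a plain `if`, since `break` already exits the loop. Note that the trailing `elif not past_descriptors:` is kept, because the branch directly above it falls through rather than breaking. A schematic version under those assumptions:

    from typing import Optional

    def find_args_block(blocks: list) -> Optional[str]:
        args_block = None
        past_descriptors = False
        for block in blocks:
            if block.startswith("Args:"):
                args_block = block
                break
            if block.startswith(("Returns:", "Example:")):  # was: elif (RET508)
                past_descriptors = True
            elif not past_descriptors:
                pass  # this elif stays: its preceding branch falls through
        return args_block
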
@@ -26,8 +26,7 @@ def get_color_mapping(
     colors = list(_TEXT_COLOR_MAPPING.keys())
     if excluded_colors is not None:
         colors = [c for c in colors if c not in excluded_colors]
-    color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
-    return color_mapping
+    return {item: colors[i % len(colors)] for i, item in enumerate(items)}


 def get_colored_text(text: str, color: str) -> str:

@@ -30,15 +30,13 @@ def _custom_parser(multiline_string: str) -> str:
     if isinstance(multiline_string, (bytes, bytearray)):
         multiline_string = multiline_string.decode()

-    multiline_string = re.sub(
+    return re.sub(
         r'("action_input"\:\s*")(.*?)(")',
         _replace_new_line,
         multiline_string,
         flags=re.DOTALL,
     )

-    return multiline_string


 # Adapted from https://github.com/KillianLucas/open-interpreter/blob/5b6080fae1f8c68938a1e4fa8667e3744084ee21/interpreter/utils/parse_partial_json.py
 # MIT License

@@ -60,12 +60,11 @@ def _dereference_refs_helper(
         else:
             obj_out[k] = v
         return obj_out
-    elif isinstance(obj, list):
+    if isinstance(obj, list):
         return [
             _dereference_refs_helper(el, full_schema, skip_keys, processed_refs)
             for el in obj
         ]
-    else:
-        return obj
+    return obj

@@ -84,7 +84,6 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
         # Then the next tag could be a standalone
         # Otherwise it can't be
         return padding.isspace() or padding == ""
-    else:
-        return False
+    return False

@@ -107,7 +106,6 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
         return on_newline[0].isspace() or not on_newline[0]

     # If we're a tag can't be a standalone
-    else:
-        return False
+    return False

@@ -89,7 +89,7 @@ def is_pydantic_v1_subclass(cls: type) -> bool:
     """Check if the installed Pydantic version is 1.x-like."""
     if PYDANTIC_MAJOR_VERSION == 1:
         return True
-    elif PYDANTIC_MAJOR_VERSION == 2:
+    if PYDANTIC_MAJOR_VERSION == 2:
         from pydantic.v1 import BaseModel as BaseModelV1

         if issubclass(cls, BaseModelV1):

@@ -335,7 +335,7 @@ def _create_subset_model(
             descriptions=descriptions,
             fn_description=fn_description,
         )
-    elif PYDANTIC_MAJOR_VERSION == 2:
+    if PYDANTIC_MAJOR_VERSION == 2:
         from pydantic.v1 import BaseModel as BaseModelV1

         if issubclass(model, BaseModelV1):

@@ -346,7 +346,6 @@ def _create_subset_model(
                 descriptions=descriptions,
                 fn_description=fn_description,
             )
-        else:
-            return _create_subset_model_v2(
-                name,
-                model,
+        return _create_subset_model_v2(
+            name,
+            model,

@@ -354,7 +353,6 @@ def _create_subset_model(
             descriptions=descriptions,
             fn_description=fn_description,
         )
-    else:
-        msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
-        raise NotImplementedError(msg)
+    msg = f"Unsupported pydantic version: {PYDANTIC_MAJOR_VERSION}"
+    raise NotImplementedError(msg)

@@ -387,9 +385,8 @@ if PYDANTIC_MAJOR_VERSION == 2:
         if hasattr(model, "model_fields"):
             return model.model_fields  # type: ignore

-        elif hasattr(model, "__fields__"):
+        if hasattr(model, "__fields__"):
             return model.__fields__  # type: ignore
-        else:
-            msg = f"Expected a Pydantic model. Got {type(model)}"
-            raise TypeError(msg)
+        msg = f"Expected a Pydantic model. Got {type(model)}"
+        raise TypeError(msg)

@@ -14,11 +14,10 @@ def stringify_value(val: Any) -> str:
     """
     if isinstance(val, str):
         return val
-    elif isinstance(val, dict):
+    if isinstance(val, dict):
         return "\n" + stringify_dict(val)
-    elif isinstance(val, list):
+    if isinstance(val, list):
         return "\n".join(stringify_value(v) for v in val)
-    else:
-        return str(val)
+    return str(val)

@@ -392,10 +392,8 @@ def from_env(

         if isinstance(default, (str, type(None))):
             return default
-        else:
-            if error_message:
-                raise ValueError(error_message)
-            else:
-                msg = (
-                    f"Did not find {key}, please add an environment variable"
-                    f" `{key}` which contains it, or pass"
+        if error_message:
+            raise ValueError(error_message)
+        msg = (
+            f"Did not find {key}, please add an environment variable"
+            f" `{key}` which contains it, or pass"

@@ -454,12 +452,10 @@ def secret_from_env(
             return SecretStr(os.environ[key])
         if isinstance(default, str):
             return SecretStr(default)
-        elif default is None:
+        if default is None:
             return None
-        else:
-            if error_message:
-                raise ValueError(error_message)
-            else:
-                msg = (
-                    f"Did not find {key}, please add an environment variable"
-                    f" `{key}` which contains it, or pass"
+        if error_message:
+            raise ValueError(error_message)
+        msg = (
+            f"Did not find {key}, please add an environment variable"
+            f" `{key}` which contains it, or pass"

@@ -340,14 +340,13 @@ class VectorStore(ABC):
         """
         if search_type == "similarity":
             return self.similarity_search(query, **kwargs)
-        elif search_type == "similarity_score_threshold":
+        if search_type == "similarity_score_threshold":
             docs_and_similarities = self.similarity_search_with_relevance_scores(
                 query, **kwargs
             )
             return [doc for doc, _ in docs_and_similarities]
-        elif search_type == "mmr":
+        if search_type == "mmr":
             return self.max_marginal_relevance_search(query, **kwargs)
-        else:
-            msg = (
-                f"search_type of {search_type} not allowed. Expected "
-                "search_type to be 'similarity', 'similarity_score_threshold'"
+        msg = (
+            f"search_type of {search_type} not allowed. Expected "
+            "search_type to be 'similarity', 'similarity_score_threshold'"

@@ -375,14 +374,13 @@ class VectorStore(ABC):
         """
         if search_type == "similarity":
             return await self.asimilarity_search(query, **kwargs)
-        elif search_type == "similarity_score_threshold":
+        if search_type == "similarity_score_threshold":
             docs_and_similarities = await self.asimilarity_search_with_relevance_scores(
                 query, **kwargs
             )
             return [doc for doc, _ in docs_and_similarities]
-        elif search_type == "mmr":
+        if search_type == "mmr":
             return await self.amax_marginal_relevance_search(query, **kwargs)
-        else:
-            msg = (
-                f"search_type of {search_type} not allowed. Expected "
-                "search_type to be 'similarity', 'similarity_score_threshold' or 'mmr'."
+        msg = (
+            f"search_type of {search_type} not allowed. Expected "
+            "search_type to be 'similarity', 'similarity_score_threshold' or 'mmr'."

@@ -431,24 +431,22 @@ class InMemoryVectorStore(VectorStore):
         **kwargs: Any,
     ) -> list[tuple[Document, float]]:
         embedding = self.embedding.embed_query(query)
-        docs = self.similarity_search_with_score_by_vector(
+        return self.similarity_search_with_score_by_vector(
             embedding,
             k,
             **kwargs,
         )
-        return docs

     @override
     async def asimilarity_search_with_score(
         self, query: str, k: int = 4, **kwargs: Any
     ) -> list[tuple[Document, float]]:
         embedding = await self.embedding.aembed_query(query)
-        docs = self.similarity_search_with_score_by_vector(
+        return self.similarity_search_with_score_by_vector(
             embedding,
             k,
             **kwargs,
         )
-        return docs

     @override
     def similarity_search_by_vector(

@@ -103,7 +103,6 @@ ignore = [
   "PGH",
   "PLR",
   "PYI",
-  "RET",
   "RUF",
   "SLF",
   "TD",
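
The `pyproject.toml` change is what drives everything above: deleting `"RET"` from the `ignore` list enables the whole flake8-return family (RET501-RET508) for the package, and the code hunks are simply ruff's auto-fixes for the newly active rules, so re-running `ruff check --fix` with this config should be a no-op afterwards (the exact fixes depend on the pinned ruff version). A hedged sketch of two family members that happen not to appear in this diff:

    # RET501: explicit `return None` where None is the only possible value.
    def log_message(msg: str) -> None:
        print(msg)
        return None  # ruff prefers a bare `return`, or no return at all

    # RET503: a function that can return a value must return explicitly
    # on every path.
    def sign(n: int):
        if n < 0:
            return -1
        # ruff flags the implicit fall-through here
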
@@ -17,12 +17,11 @@ EXAMPLES = [
 def selector() -> LengthBasedExampleSelector:
     """Get length based selector to use in tests."""
     prompts = PromptTemplate(input_variables=["question"], template="{question}")
-    selector = LengthBasedExampleSelector(
+    return LengthBasedExampleSelector(
         examples=EXAMPLES,
         example_prompt=prompts,
         max_length=30,
     )
-    return selector


 def test_selector_valid(selector: LengthBasedExampleSelector) -> None:

@@ -18,7 +18,6 @@ def _fake_runnable(
 ) -> Union[BaseModel, dict]:
     if isclass(schema) and is_basemodel_subclass(schema):
         return schema(name="yo", value=value)
-    else:
-        params = cast("dict", schema)["parameters"]
-        return {k: 1 if k != "value" else value for k, v in params.items()}
+    params = cast("dict", schema)["parameters"]
+    return {k: 1 if k != "value" else value for k, v in params.items()}

@@ -219,7 +219,6 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
     def conditional_str_parser(input: str) -> Runnable:
         if input == "a":
             return str_parser
-        else:
-            return xml_parser
+        return xml_parser

     sequence: Runnable = (

@@ -2954,9 +2954,8 @@ def test_higher_order_lambda_runnable(
     def router(input: dict[str, Any]) -> Runnable:
         if input["key"] == "math":
             return itemgetter("input") | math_chain
-        elif input["key"] == "english":
+        if input["key"] == "english":
             return itemgetter("input") | english_chain
-        else:
-            msg = f"Unknown key: {input['key']}"
-            raise ValueError(msg)
+        msg = f"Unknown key: {input['key']}"
+        raise ValueError(msg)

@@ -3011,9 +3010,8 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
     def router(input: dict[str, Any]) -> Runnable:
         if input["key"] == "math":
             return itemgetter("input") | math_chain
-        elif input["key"] == "english":
+        if input["key"] == "english":
             return itemgetter("input") | english_chain
-        else:
-            msg = f"Unknown key: {input['key']}"
-            raise ValueError(msg)
+        msg = f"Unknown key: {input['key']}"
+        raise ValueError(msg)

@@ -3034,9 +3032,8 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
     async def arouter(input: dict[str, Any]) -> Runnable:
         if input["key"] == "math":
             return itemgetter("input") | math_chain
-        elif input["key"] == "english":
+        if input["key"] == "english":
             return itemgetter("input") | english_chain
-        else:
-            msg = f"Unknown key: {input['key']}"
-            raise ValueError(msg)
+        msg = f"Unknown key: {input['key']}"
+        raise ValueError(msg)

@@ -3858,7 +3855,6 @@ def test_recursive_lambda() -> None:
     def _simple_recursion(x: int) -> Union[int, Runnable]:
         if x < 10:
             return RunnableLambda(lambda *args: _simple_recursion(x + 1))
-        else:
-            return x
+        return x

     runnable = RunnableLambda(_simple_recursion)

@@ -3873,10 +3869,9 @@ def test_retrying(mocker: MockerFixture) -> None:
         if x == 1:
             msg = "x is 1"
             raise ValueError(msg)
-        elif x == 2:
+        if x == 2:
             msg = "x is 2"
             raise RuntimeError(msg)
-        else:
-            return x
+        return x

     _lambda_mock = mocker.Mock(side_effect=_lambda)

@@ -3938,10 +3933,9 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
         if x == 1:
             msg = "x is 1"
             raise ValueError(msg)
-        elif x == 2:
+        if x == 2:
             msg = "x is 2"
             raise RuntimeError(msg)
-        else:
-            return x
+        return x

     _lambda_mock = mocker.Mock(side_effect=_lambda)

@@ -545,7 +545,6 @@ async def test_astream_events_from_model() -> None:
     def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
         if sys.version_info >= (3, 11):
             return model.invoke(input)
-        else:
-            return model.invoke(input, config)
+        return model.invoke(input, config)

     events = await _collect_events(i_dont_stream.astream_events("hello", version="v1"))

@@ -670,7 +669,6 @@ async def test_astream_events_from_model() -> None:
     async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
         if sys.version_info >= (3, 11):
             return await model.ainvoke(input)
-        else:
-            return await model.ainvoke(input, config)
+        return await model.ainvoke(input, config)

     events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1"))

@@ -615,7 +615,6 @@ async def test_astream_with_model_in_chain() -> None:
     def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
         if sys.version_info >= (3, 11):
             return model.invoke(input)
-        else:
-            return model.invoke(input, config)
+        return model.invoke(input, config)

     events = await _collect_events(i_dont_stream.astream_events("hello", version="v2"))

@@ -724,7 +723,6 @@ async def test_astream_with_model_in_chain() -> None:
     async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
         if sys.version_info >= (3, 11):
             return await model.ainvoke(input)
-        else:
-            return await model.ainvoke(input, config)
+        return await model.ainvoke(input, config)

     events = await _collect_events(ai_dont_stream.astream_events("hello", version="v2"))

@@ -334,7 +334,6 @@ class TestRunnableSequenceParallelTraceNesting:
                 parent_id_map[n] = matching_post.get("parent_run_id")
                 i += len(name)
                 continue
-            else:
-                assert posts[i]["name"] == name
-                dotted_order = posts[i]["dotted_order"]
-                if prev_dotted_order is not None and not str(
+            assert posts[i]["name"] == name
+            dotted_order = posts[i]["dotted_order"]
+            if prev_dotted_order is not None and not str(

@@ -80,7 +80,6 @@ def _get_tool_call_json_schema(tool: BaseTool) -> dict:

     if hasattr(tool_schema, "model_json_schema"):
         return tool_schema.model_json_schema()
-    else:
-        return tool_schema.schema()
+    return tool_schema.schema()

@@ -599,8 +599,7 @@ def test_tracer_nested_runs_on_error() -> None:

 def _get_mock_client() -> Client:
     mock_session = MagicMock()
-    client = Client(session=mock_session, api_key="test")
-    return client
+    return Client(session=mock_session, api_key="test")


 def test_traceable_to_tracing() -> None: