langchain: Bump ruff version to 0.9 (#29211)

Co-authored-by: Erick Friis <erick@langchain.dev>
Author: Christophe Bornet
Committed: 2025-01-22 01:26:39 +01:00 (via GitHub)
Parent: 2340b3154d
Commit: a004dec119
27 changed files with 307 additions and 250 deletions

View File

@@ -84,8 +84,7 @@ def initialize_agent(
             pass
     else:
         raise ValueError(
-            "Somehow both `agent` and `agent_path` are None, "
-            "this should never happen."
+            "Somehow both `agent` and `agent_path` are None, this should never happen."
         )
     return AgentExecutor.from_agent_and_tools(
         agent=agent_obj,

View File

@@ -58,8 +58,7 @@ def load_agent_from_config(
     if load_from_tools:
         if llm is None:
             raise ValueError(
-                "If `load_from_llm_and_tools` is set to True, "
-                "then LLM must be provided"
+                "If `load_from_llm_and_tools` is set to True, then LLM must be provided"
             )
         if tools is None:
             raise ValueError(

View File

@@ -41,6 +41,6 @@ class LoggingCallbackHandler(FunctionCallbackHandler):
         except TracerException:
             crumbs_str = ""
         self.function_callback(
-            f'{get_colored_text("[text]", color="blue")}'
-            f' {get_bolded_text(f"{crumbs_str}New text:")}\n{text}'
+            f"{get_colored_text('[text]', color='blue')}"
+            f" {get_bolded_text(f'{crumbs_str}New text:')}\n{text}"
         )
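
Note: the quote flipping in this hunk comes from ruff 0.9's stabilized f-string formatting, which prefers double quotes for the outer string and rewrites nested literals to single quotes where that stays legal. A minimal, self-contained sketch of the equivalence (plain str methods stand in for the LangChain helpers above):

# Both spellings evaluate to the same string; only the source quoting differs.
color = "blue"
before = f'{"[text]".upper()} in {color}'  # pre-0.9 style: outer single quotes
after = f"{'[text]'.upper()} in {color}"   # 0.9 style: outer double quotes
assert before == after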

View File

@@ -370,6 +370,7 @@ try:
         @property
         def _chain_type(self) -> str:
             return "api_chain"
+
 except ImportError:

     class APIChain:  # type: ignore[no-redef]

View File

@@ -233,7 +233,7 @@ class SimpleRequestChain(Chain):
             response = (
                 f"{api_response.status_code}: {api_response.reason}"
                 + f"\nFor {name} "
-                + f"Called with args: {args.get('params','')}"
+                + f"Called with args: {args.get('params', '')}"
             )
         else:
             try:

View File

@@ -68,7 +68,10 @@ def create_extraction_chain_pydantic(
     if not isinstance(pydantic_schemas, list):
        pydantic_schemas = [pydantic_schemas]
     prompt = ChatPromptTemplate.from_messages(
-        [("system", system_message), ("user", "{input}")]
+        [
+            ("system", system_message),
+            ("user", "{input}"),
+        ]
     )
     functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas]
     tools = [{"type": "function", "function": d} for d in functions]
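
Note: many hunks in this commit are this same mechanical change: a short collection literal expanded to one element per line, ending in a trailing comma. The two layouts are semantically identical; a sketch (the tuple values here are placeholders):

# The trailing comma in the expanded form tells the formatter to keep
# one element per line; both literals build the same list.
compact = [("system", "You are helpful"), ("user", "{input}")]
expanded = [
    ("system", "You are helpful"),
    ("user", "{input}"),
]
assert compact == expanded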

View File

@@ -33,7 +33,11 @@ refine_template = (
     "If the context isn't useful, return the original answer."
 )
 CHAT_REFINE_PROMPT = ChatPromptTemplate.from_messages(
-    [("human", "{question}"), ("ai", "{existing_answer}"), ("human", refine_template)]
+    [
+        ("human", "{question}"),
+        ("ai", "{existing_answer}"),
+        ("human", refine_template),
+    ]
 )
 REFINE_PROMPT_SELECTOR = ConditionalPromptSelector(
     default_prompt=DEFAULT_REFINE_PROMPT,
@@ -60,7 +64,10 @@ chat_qa_prompt_template = (
     "answer any questions"
 )
 CHAT_QUESTION_PROMPT = ChatPromptTemplate.from_messages(
-    [("system", chat_qa_prompt_template), ("human", "{question}")]
+    [
+        ("system", chat_qa_prompt_template),
+        ("human", "{question}"),
+    ]
 )
 QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
     default_prompt=DEFAULT_TEXT_QA_PROMPT,

View File

@@ -178,7 +178,9 @@ class SimpleSequentialChain(Chain):
         _input = inputs[self.input_key]
         color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
         for i, chain in enumerate(self.chains):
-            _input = chain.run(_input, callbacks=_run_manager.get_child(f"step_{i+1}"))
+            _input = chain.run(
+                _input, callbacks=_run_manager.get_child(f"step_{i + 1}")
+            )
             if self.strip_outputs:
                 _input = _input.strip()
             _run_manager.on_text(
@@ -196,7 +198,7 @@ class SimpleSequentialChain(Chain):
         color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
         for i, chain in enumerate(self.chains):
             _input = await chain.arun(
-                _input, callbacks=_run_manager.get_child(f"step_{i+1}")
+                _input, callbacks=_run_manager.get_child(f"step_{i + 1}")
             )
             if self.strip_outputs:
                 _input = _input.strip()
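
Note: ruff 0.9 also formats the code inside f-string replacement fields like ordinary expressions, which is why `i+1` becomes `i + 1` here and in several later hunks. A sketch showing the change is purely cosmetic:

# The rendered value is identical; only the source spacing differs.
for i in range(3):
    assert f"step_{i+1}" == f"step_{i + 1}"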

View File

@@ -590,7 +590,11 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         queued_declarative_operations = list(self._queued_declarative_operations)
         if remaining_config:
             queued_declarative_operations.append(
-                ("with_config", (), {"config": remaining_config})
+                (
+                    "with_config",
+                    (),
+                    {"config": remaining_config},
+                )
             )
         return _ConfigurableModel(
             default_config={**self._default_config, **model_params},

View File

@@ -174,8 +174,7 @@ def init_embeddings(
     if not model:
         providers = _SUPPORTED_PROVIDERS.keys()
         raise ValueError(
-            "Must specify model name. "
-            f"Supported providers are: {', '.join(providers)}"
+            f"Must specify model name. Supported providers are: {', '.join(providers)}"
         )
     provider, model_name = _infer_model_and_provider(model, provider=provider)

View File

@@ -310,7 +310,10 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
             Dict[str, Any]: The computed score.
         """
         embedded = await self.embeddings.aembed_documents(
-            [inputs["prediction"], inputs["reference"]]
+            [
+                inputs["prediction"],
+                inputs["reference"],
+            ]
         )
         vectors = np.array(embedded)
         score = self._compute_score(vectors)
@@ -427,7 +430,10 @@ class PairwiseEmbeddingDistanceEvalChain(
         """
         vectors = np.array(
             self.embeddings.embed_documents(
-                [inputs["prediction"], inputs["prediction_b"]]
+                [
+                    inputs["prediction"],
+                    inputs["prediction_b"],
+                ]
             )
         )
         score = self._compute_score(vectors)
@@ -449,7 +455,10 @@ class PairwiseEmbeddingDistanceEvalChain(
             Dict[str, Any]: The computed score.
         """
         embedded = await self.embeddings.aembed_documents(
-            [inputs["prediction"], inputs["prediction_b"]]
+            [
+                inputs["prediction"],
+                inputs["prediction_b"],
+            ]
         )
         vectors = np.array(embedded)
         score = self._compute_score(vectors)

View File

@@ -71,7 +71,10 @@ class BaseChatMemory(BaseMemory, ABC):
         """Save context from this conversation to buffer."""
         input_str, output_str = self._get_input_output(inputs, outputs)
         self.chat_memory.add_messages(
-            [HumanMessage(content=input_str), AIMessage(content=output_str)]
+            [
+                HumanMessage(content=input_str),
+                AIMessage(content=output_str),
+            ]
         )

     async def asave_context(
@@ -80,7 +83,10 @@ class BaseChatMemory(BaseMemory, ABC):
         """Save context from this conversation to buffer."""
         input_str, output_str = self._get_input_output(inputs, outputs)
         await self.chat_memory.aadd_messages(
-            [HumanMessage(content=input_str), AIMessage(content=output_str)]
+            [
+                HumanMessage(content=input_str),
+                AIMessage(content=output_str),
+            ]
         )

     def clear(self) -> None:

View File

@@ -92,7 +92,10 @@ class CohereRerank(BaseDocumentCompressor):
         result_dicts = []
         for res in results:
             result_dicts.append(
-                {"index": res.index, "relevance_score": res.relevance_score}
+                {
+                    "index": res.index,
+                    "relevance_score": res.relevance_score,
+                }
             )
         return result_dicts

View File

@@ -223,7 +223,7 @@ class EnsembleRetriever(BaseRetriever):
             retriever.invoke(
                 query,
                 patch_config(
-                    config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
+                    config, callbacks=run_manager.get_child(tag=f"retriever_{i + 1}")
                 ),
             )
             for i, retriever in enumerate(self.retrievers)
@@ -265,7 +265,8 @@ class EnsembleRetriever(BaseRetriever):
             retriever.ainvoke(
                 query,
                 patch_config(
-                    config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
+                    config,
+                    callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"),
                 ),
             )
             for i, retriever in enumerate(self.retrievers)

View File

@@ -247,8 +247,7 @@ def _get_prompt(inputs: Dict[str, Any]) -> str:
     if "prompt" in inputs:
         if not isinstance(inputs["prompt"], str):
             raise InputFormatError(
-                "Expected string for 'prompt', got"
-                f" {type(inputs['prompt']).__name__}"
+                f"Expected string for 'prompt', got {type(inputs['prompt']).__name__}"
             )
         prompts = [inputs["prompt"]]
     elif "prompts" in inputs:

View File

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.

 [[package]]
 name = "aiohappyeyeballs"
@@ -704,13 +704,13 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""}
 [[package]]
 name = "codespell"
-version = "2.3.0"
-description = "Codespell"
+version = "2.4.0"
+description = "Fix common misspellings in text files"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "codespell-2.3.0-py3-none-any.whl", hash = "sha256:a9c7cef2501c9cfede2110fd6d4e5e62296920efe9abfb84648df866e47f58d1"},
-    {file = "codespell-2.3.0.tar.gz", hash = "sha256:360c7d10f75e65f67bad720af7007e1060a5d395670ec11a7ed1fed9dd17471f"},
+    {file = "codespell-2.4.0-py3-none-any.whl", hash = "sha256:b4c5b779f747dd481587aeecb5773301183f52b94b96ed51a28126d0482eec1d"},
+    {file = "codespell-2.4.0.tar.gz", hash = "sha256:587d45b14707fb8ce51339ba4cce50ae0e98ce228ef61f3c5e160e34f681be58"},
 ]

 [package.extras]
@@ -4089,29 +4089,29 @@ files = [
 [[package]]
 name = "ruff"
-version = "0.5.7"
+version = "0.9.2"
 description = "An extremely fast Python linter and code formatter, written in Rust."
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "ruff-0.5.7-py3-none-linux_armv6l.whl", hash = "sha256:548992d342fc404ee2e15a242cdbea4f8e39a52f2e7752d0e4cbe88d2d2f416a"},
-    {file = "ruff-0.5.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:00cc8872331055ee017c4f1071a8a31ca0809ccc0657da1d154a1d2abac5c0be"},
-    {file = "ruff-0.5.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eaf3d86a1fdac1aec8a3417a63587d93f906c678bb9ed0b796da7b59c1114a1e"},
-    {file = "ruff-0.5.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a01c34400097b06cf8a6e61b35d6d456d5bd1ae6961542de18ec81eaf33b4cb8"},
-    {file = "ruff-0.5.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcc8054f1a717e2213500edaddcf1dbb0abad40d98e1bd9d0ad364f75c763eea"},
-    {file = "ruff-0.5.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f70284e73f36558ef51602254451e50dd6cc479f8b6f8413a95fcb5db4a55fc"},
-    {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:a78ad870ae3c460394fc95437d43deb5c04b5c29297815a2a1de028903f19692"},
-    {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ccd078c66a8e419475174bfe60a69adb36ce04f8d4e91b006f1329d5cd44bcf"},
-    {file = "ruff-0.5.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e31c9bad4ebf8fdb77b59cae75814440731060a09a0e0077d559a556453acbb"},
-    {file = "ruff-0.5.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d796327eed8e168164346b769dd9a27a70e0298d667b4ecee6877ce8095ec8e"},
-    {file = "ruff-0.5.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4a09ea2c3f7778cc635e7f6edf57d566a8ee8f485f3c4454db7771efb692c499"},
-    {file = "ruff-0.5.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a36d8dcf55b3a3bc353270d544fb170d75d2dff41eba5df57b4e0b67a95bb64e"},
-    {file = "ruff-0.5.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9369c218f789eefbd1b8d82a8cf25017b523ac47d96b2f531eba73770971c9e5"},
-    {file = "ruff-0.5.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b88ca3db7eb377eb24fb7c82840546fb7acef75af4a74bd36e9ceb37a890257e"},
-    {file = "ruff-0.5.7-py3-none-win32.whl", hash = "sha256:33d61fc0e902198a3e55719f4be6b375b28f860b09c281e4bdbf783c0566576a"},
-    {file = "ruff-0.5.7-py3-none-win_amd64.whl", hash = "sha256:083bbcbe6fadb93cd86709037acc510f86eed5a314203079df174c40bbbca6b3"},
-    {file = "ruff-0.5.7-py3-none-win_arm64.whl", hash = "sha256:2dca26154ff9571995107221d0aeaad0e75a77b5a682d6236cf89a58c70b76f4"},
-    {file = "ruff-0.5.7.tar.gz", hash = "sha256:8dfc0a458797f5d9fb622dd0efc52d796f23f0a1493a9527f4e49a550ae9a7e5"},
+    {file = "ruff-0.9.2-py3-none-linux_armv6l.whl", hash = "sha256:80605a039ba1454d002b32139e4970becf84b5fee3a3c3bf1c2af6f61a784347"},
+    {file = "ruff-0.9.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b9aab82bb20afd5f596527045c01e6ae25a718ff1784cb92947bff1f83068b00"},
+    {file = "ruff-0.9.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fbd337bac1cfa96be615f6efcd4bc4d077edbc127ef30e2b8ba2a27e18c054d4"},
+    {file = "ruff-0.9.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82b35259b0cbf8daa22a498018e300b9bb0174c2bbb7bcba593935158a78054d"},
+    {file = "ruff-0.9.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b6a9701d1e371bf41dca22015c3f89769da7576884d2add7317ec1ec8cb9c3c"},
+    {file = "ruff-0.9.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9cc53e68b3c5ae41e8faf83a3b89f4a5d7b2cb666dff4b366bb86ed2a85b481f"},
+    {file = "ruff-0.9.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:8efd9da7a1ee314b910da155ca7e8953094a7c10d0c0a39bfde3fcfd2a015684"},
+    {file = "ruff-0.9.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3292c5a22ea9a5f9a185e2d131dc7f98f8534a32fb6d2ee7b9944569239c648d"},
+    {file = "ruff-0.9.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a605fdcf6e8b2d39f9436d343d1f0ff70c365a1e681546de0104bef81ce88df"},
+    {file = "ruff-0.9.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c547f7f256aa366834829a08375c297fa63386cbe5f1459efaf174086b564247"},
+    {file = "ruff-0.9.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d18bba3d3353ed916e882521bc3e0af403949dbada344c20c16ea78f47af965e"},
+    {file = "ruff-0.9.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b338edc4610142355ccf6b87bd356729b62bf1bc152a2fad5b0c7dc04af77bfe"},
+    {file = "ruff-0.9.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:492a5e44ad9b22a0ea98cf72e40305cbdaf27fac0d927f8bc9e1df316dcc96eb"},
+    {file = "ruff-0.9.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:af1e9e9fe7b1f767264d26b1075ac4ad831c7db976911fa362d09b2d0356426a"},
+    {file = "ruff-0.9.2-py3-none-win32.whl", hash = "sha256:71cbe22e178c5da20e1514e1e01029c73dc09288a8028a5d3446e6bba87a5145"},
+    {file = "ruff-0.9.2-py3-none-win_amd64.whl", hash = "sha256:c5e1d6abc798419cf46eed03f54f2e0c3adb1ad4b801119dedf23fcaf69b55b5"},
+    {file = "ruff-0.9.2-py3-none-win_arm64.whl", hash = "sha256:a1b63fa24149918f8b37cef2ee6fff81f24f0d74b6f0bdc37bc3e1f2143e41c6"},
+    {file = "ruff-0.9.2.tar.gz", hash = "sha256:b5eceb334d55fae5f316f783437392642ae18e16dcf4f1858d55d3c2a0f8f5d0"},
 ]

 [[package]]
@@ -4689,13 +4689,13 @@ files = [
 [[package]]
 name = "tzdata"
-version = "2024.2"
+version = "2025.1"
 description = "Provider of IANA time zone data"
 optional = false
 python-versions = ">=2"
 files = [
-    {file = "tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd"},
-    {file = "tzdata-2024.2.tar.gz", hash = "sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc"},
+    {file = "tzdata-2025.1-py2.py3-none-any.whl", hash = "sha256:7e127113816800496f027041c570f50bcd464a020098a3b6b199517772303639"},
+    {file = "tzdata-2025.1.tar.gz", hash = "sha256:24894909e88cdb28bd1636c6887801df64cb485bd593f2fd83ef29075a81d694"},
 ]

 [[package]]
@@ -5076,4 +5076,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<4.0"
-content-hash = "7dd55a6e29a188f48f59469ce0ca25ecb2d5a85df82121e181462cddb3f14fbc"
+content-hash = "8973c8570d06948300d1030dfae674d33ba8f194746413d6930c0bb0af593b76"

View File

@@ -12,6 +12,7 @@ readme = "README.md"
 repository = "https://github.com/langchain-ai/langchain"

 [tool.ruff]
+target-version = "py39"
 exclude = [ "tests/integration_tests/examples/non-utf8-encoding.py",]

 [tool.mypy]
@@ -119,7 +120,7 @@ cassio = "^0.1.0"
 langchainhub = "^0.1.16"

 [tool.poetry.group.lint.dependencies]
-ruff = "^0.5"
+ruff = "^0.9.2"

 [[tool.poetry.group.lint.dependencies.cffi]]
 version = "<1.17.1"
 python = "<3.10"

View File

@@ -451,7 +451,10 @@ async def test_runnable_agent() -> None:
     model = GenericFakeChatModel(messages=infinite_cycle)
     template = ChatPromptTemplate.from_messages(
-        [("system", "You are Cat Agent 007"), ("human", "{question}")]
+        [
+            ("system", "You are Cat Agent 007"),
+            ("human", "{question}"),
+        ]
     )

     def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
@@ -539,12 +542,18 @@ async def test_runnable_agent_with_function_calls() -> None:
     """Test agent with intermediate agent actions."""
     # Will alternate between responding with hello and goodbye
     infinite_cycle = cycle(
-        [AIMessage(content="looking for pet..."), AIMessage(content="Found Pet")]
+        [
+            AIMessage(content="looking for pet..."),
+            AIMessage(content="Found Pet"),
+        ]
     )
     model = GenericFakeChatModel(messages=infinite_cycle)
     template = ChatPromptTemplate.from_messages(
-        [("system", "You are Cat Agent 007"), ("human", "{question}")]
+        [
+            ("system", "You are Cat Agent 007"),
+            ("human", "{question}"),
+        ]
     )

     parser_responses = cycle(
@@ -635,12 +644,18 @@ async def test_runnable_with_multi_action_per_step() -> None:
     """Test an agent that can make multiple function calls at once."""
     # Will alternate between responding with hello and goodbye
     infinite_cycle = cycle(
-        [AIMessage(content="looking for pet..."), AIMessage(content="Found Pet")]
+        [
+            AIMessage(content="looking for pet..."),
+            AIMessage(content="Found Pet"),
+        ]
     )
     model = GenericFakeChatModel(messages=infinite_cycle)
     template = ChatPromptTemplate.from_messages(
-        [("system", "You are Cat Agent 007"), ("human", "{question}")]
+        [
+            ("system", "You are Cat Agent 007"),
+            ("human", "{question}"),
+        ]
     )

     parser_responses = cycle(
@@ -861,7 +876,7 @@ async def test_openai_agent_with_streaming() -> None:
             {
                 "additional_kwargs": {
                     "function_call": {
-                        "arguments": '{"pet": ' '"cat"}',
+                        "arguments": '{"pet": "cat"}',
                         "name": "find_pet",
                     }
                 },
@@ -880,7 +895,7 @@ async def test_openai_agent_with_streaming() -> None:
             {
                 "additional_kwargs": {
                     "function_call": {
-                        "arguments": '{"pet": ' '"cat"}',
+                        "arguments": '{"pet": "cat"}',
                         "name": "find_pet",
                     }
                 },
@@ -909,10 +924,7 @@ async def test_openai_agent_with_streaming() -> None:
         "steps": [
             {
                 "action": {
-                    "log": "\n"
-                    "Invoking: `find_pet` with `{'pet': 'cat'}`\n"
-                    "\n"
-                    "\n",
+                    "log": "\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
                     "tool": "find_pet",
                     "tool_input": {"pet": "cat"},
                     "type": "AgentActionMessageLog",
@@ -1062,15 +1074,131 @@ async def test_openai_agent_tools_agent() -> None:
     # astream
     chunks = [chunk async for chunk in executor.astream({"question": "hello"})]

-    assert (
-        chunks
-        == [
-            {
-                "actions": [
-                    OpenAIToolAgentAction(
+    assert chunks == [
+        {
+            "actions": [
+                OpenAIToolAgentAction(
+                    tool="find_pet",
+                    tool_input={"pet": "cat"},
+                    log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
+                    message_log=[
+                        _AnyIdAIMessageChunk(
+                            content="",
+                            additional_kwargs={
+                                "tool_calls": [
+                                    {
+                                        "function": {
+                                            "name": "find_pet",
+                                            "arguments": '{"pet": "cat"}',
+                                        },
+                                        "id": "0",
+                                    },
+                                    {
+                                        "function": {
+                                            "name": "check_time",
+                                            "arguments": "{}",
+                                        },
+                                        "id": "1",
+                                    },
+                                ]
+                            },
+                        )
+                    ],
+                    tool_call_id="0",
+                )
+            ],
+            "messages": [
+                _AnyIdAIMessageChunk(
+                    content="",
+                    additional_kwargs={
+                        "tool_calls": [
+                            {
+                                "function": {
+                                    "name": "find_pet",
+                                    "arguments": '{"pet": "cat"}',
+                                },
+                                "id": "0",
+                            },
+                            {
+                                "function": {
+                                    "name": "check_time",
+                                    "arguments": "{}",
+                                },
+                                "id": "1",
+                            },
+                        ]
+                    },
+                )
+            ],
+        },
+        {
+            "actions": [
+                OpenAIToolAgentAction(
+                    tool="check_time",
+                    tool_input={},
+                    log="\nInvoking: `check_time` with `{}`\n\n\n",
+                    message_log=[
+                        _AnyIdAIMessageChunk(
+                            content="",
+                            additional_kwargs={
+                                "tool_calls": [
+                                    {
+                                        "function": {
+                                            "name": "find_pet",
+                                            "arguments": '{"pet": "cat"}',
+                                        },
+                                        "id": "0",
+                                    },
+                                    {
+                                        "function": {
+                                            "name": "check_time",
+                                            "arguments": "{}",
+                                        },
+                                        "id": "1",
+                                    },
+                                ]
+                            },
+                        )
+                    ],
+                    tool_call_id="1",
+                )
+            ],
+            "messages": [
+                _AnyIdAIMessageChunk(
+                    content="",
+                    additional_kwargs={
+                        "tool_calls": [
+                            {
+                                "function": {
+                                    "name": "find_pet",
+                                    "arguments": '{"pet": "cat"}',
+                                },
+                                "id": "0",
+                            },
+                            {
+                                "function": {
+                                    "name": "check_time",
+                                    "arguments": "{}",
+                                },
+                                "id": "1",
+                            },
+                        ]
+                    },
+                )
+            ],
+        },
+        {
+            "messages": [
+                FunctionMessage(
+                    content="Spying from under the bed.", name="find_pet"
+                )
+            ],
+            "steps": [
+                AgentStep(
+                    action=OpenAIToolAgentAction(
                     tool="find_pet",
                     tool_input={"pet": "cat"},
-                    log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
+                    log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",  # noqa: E501
                     message_log=[
                         _AnyIdAIMessageChunk(
                             content="",
@@ -1095,35 +1223,21 @@ async def test_openai_agent_tools_agent() -> None:
                             )
                         ],
                         tool_call_id="0",
-                    )
-                ],
-                "messages": [
-                    _AnyIdAIMessageChunk(
-                        content="",
-                        additional_kwargs={
-                            "tool_calls": [
-                                {
-                                    "function": {
-                                        "name": "find_pet",
-                                        "arguments": '{"pet": "cat"}',
-                                    },
-                                    "id": "0",
-                                },
-                                {
-                                    "function": {
-                                        "name": "check_time",
-                                        "arguments": "{}",
-                                    },
-                                    "id": "1",
-                                },
-                            ]
-                        },
-                    )
-                ],
-            },
-            {
-                "actions": [
-                    OpenAIToolAgentAction(
+                    ),
+                    observation="Spying from under the bed.",
+                )
+            ],
+        },
+        {
+            "messages": [
+                FunctionMessage(
+                    content="check_time is not a valid tool, try one of [find_pet].",  # noqa: E501
+                    name="check_time",
+                )
+            ],
+            "steps": [
+                AgentStep(
+                    action=OpenAIToolAgentAction(
                         tool="check_time",
                         tool_input={},
                         log="\nInvoking: `check_time` with `{}`\n\n\n",
@@ -1151,124 +1265,19 @@ async def test_openai_agent_tools_agent() -> None:
                             )
                         ],
                         tool_call_id="1",
-                    )
-                ],
-                "messages": [
-                    _AnyIdAIMessageChunk(
-                        content="",
-                        additional_kwargs={
-                            "tool_calls": [
-                                {
-                                    "function": {
-                                        "name": "find_pet",
-                                        "arguments": '{"pet": "cat"}',
-                                    },
-                                    "id": "0",
-                                },
-                                {
-                                    "function": {
-                                        "name": "check_time",
-                                        "arguments": "{}",
-                                    },
-                                    "id": "1",
-                                },
-                            ]
-                        },
-                    )
-                ],
-            },
-            {
-                "messages": [
-                    FunctionMessage(
-                        content="Spying from under the bed.", name="find_pet"
-                    )
-                ],
-                "steps": [
-                    AgentStep(
-                        action=OpenAIToolAgentAction(
-                            tool="find_pet",
-                            tool_input={"pet": "cat"},
-                            log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",  # noqa: E501
-                            message_log=[
-                                _AnyIdAIMessageChunk(
-                                    content="",
-                                    additional_kwargs={
-                                        "tool_calls": [
-                                            {
-                                                "function": {
-                                                    "name": "find_pet",
-                                                    "arguments": '{"pet": "cat"}',
-                                                },
-                                                "id": "0",
-                                            },
-                                            {
-                                                "function": {
-                                                    "name": "check_time",
-                                                    "arguments": "{}",
-                                                },
-                                                "id": "1",
-                                            },
-                                        ]
-                                    },
-                                )
-                            ],
-                            tool_call_id="0",
-                        ),
-                        observation="Spying from under the bed.",
-                    )
-                ],
-            },
-            {
-                "messages": [
-                    FunctionMessage(
-                        content="check_time is not a valid tool, try one of [find_pet].",  # noqa: E501
-                        name="check_time",
-                    )
-                ],
-                "steps": [
-                    AgentStep(
-                        action=OpenAIToolAgentAction(
-                            tool="check_time",
-                            tool_input={},
-                            log="\nInvoking: `check_time` with `{}`\n\n\n",
-                            message_log=[
-                                _AnyIdAIMessageChunk(
-                                    content="",
-                                    additional_kwargs={
-                                        "tool_calls": [
-                                            {
-                                                "function": {
-                                                    "name": "find_pet",
-                                                    "arguments": '{"pet": "cat"}',
-                                                },
-                                                "id": "0",
-                                            },
-                                            {
-                                                "function": {
-                                                    "name": "check_time",
-                                                    "arguments": "{}",
-                                                },
-                                                "id": "1",
-                                            },
-                                        ]
-                                    },
-                                )
-                            ],
-                            tool_call_id="1",
-                        ),
-                        observation="check_time is not a valid tool, "
-                        "try one of [find_pet].",
-                    )
-                ],
-            },
-            {
-                "messages": [
-                    AIMessage(content="The cat is spying from under the bed.")
-                ],
-                "output": "The cat is spying from under the bed.",
-            },
-        ]
-    )
+                    ),
+                    observation="check_time is not a valid tool, "
+                    "try one of [find_pet].",
+                )
+            ],
+        },
+        {
+            "messages": [
+                AIMessage(content="The cat is spying from under the bed.")
+            ],
+            "output": "The cat is spying from under the bed.",
+        },
+    ]

     # astream_log
     log_patches = [

View File

@@ -24,9 +24,7 @@ def get_action_and_input(text: str) -> Tuple[str, str]:

 def test_get_action_and_input() -> None:
     """Test getting an action from text."""
-    llm_output = (
-        "Thought: I need to search for NBA\n" "Action: Search\n" "Action Input: NBA"
-    )
+    llm_output = "Thought: I need to search for NBA\nAction: Search\nAction Input: NBA"
     action, action_input = get_action_and_input(llm_output)
     assert action == "Search"
     assert action_input == "NBA"
@@ -91,7 +89,7 @@ def test_get_action_and_input_sql_query() -> None:

 def test_get_final_answer() -> None:
     """Test getting final answer."""
-    llm_output = "Thought: I can now answer the question\n" "Final Answer: 1994"
+    llm_output = "Thought: I can now answer the question\nFinal Answer: 1994"
     action, action_input = get_action_and_input(llm_output)
     assert action == "Final Answer"
     assert action_input == "1994"
@@ -99,7 +97,7 @@ def test_get_final_answer() -> None:

 def test_get_final_answer_new_line() -> None:
     """Test getting final answer."""
-    llm_output = "Thought: I can now answer the question\n" "Final Answer:\n1994"
+    llm_output = "Thought: I can now answer the question\nFinal Answer:\n1994"
     action, action_input = get_action_and_input(llm_output)
     assert action == "Final Answer"
     assert action_input == "1994"
@@ -107,7 +105,7 @@ def test_get_final_answer_new_line() -> None:

 def test_get_final_answer_multiline() -> None:
     """Test getting final answer that is multiline."""
-    llm_output = "Thought: I can now answer the question\n" "Final Answer: 1994\n1993"
+    llm_output = "Thought: I can now answer the question\nFinal Answer: 1994\n1993"
     action, action_input = get_action_and_input(llm_output)
     assert action == "Final Answer"
     assert action_input == "1994\n1993"
@@ -115,7 +113,7 @@ def test_get_final_answer_multiline() -> None:

 def test_bad_action_input_line() -> None:
     """Test handling when no action input found."""
-    llm_output = "Thought: I need to search for NBA\n" "Action: Search\n" "Thought: NBA"
+    llm_output = "Thought: I need to search for NBA\nAction: Search\nThought: NBA"
     with pytest.raises(OutputParserException) as e_info:
         get_action_and_input(llm_output)
     assert e_info.value.observation is not None
@@ -123,9 +121,7 @@ def test_bad_action_input_line() -> None:

 def test_bad_action_line() -> None:
     """Test handling when no action found."""
-    llm_output = (
-        "Thought: I need to search for NBA\n" "Thought: Search\n" "Action Input: NBA"
-    )
+    llm_output = "Thought: I need to search for NBA\nThought: Search\nAction Input: NBA"
     with pytest.raises(OutputParserException) as e_info:
         get_action_and_input(llm_output)
     assert e_info.value.observation is not None
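
Note: these rewrites rely on adjacent string literals being concatenated at compile time, so joining them changes nothing observable; ruff 0.9 now performs that join whenever the result fits on one line. A sketch:

# Implicit concatenation: the parser glues adjacent literals into one string.
joined = "Thought: I can now answer the question\n" "Final Answer: 1994"
single = "Thought: I can now answer the question\nFinal Answer: 1994"
assert joined == single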

View File

@@ -22,6 +22,6 @@ def test_critique_parsing() -> None:
     for text in [TEXT_ONE, TEXT_TWO, TEXT_THREE]:
         critique = ConstitutionalChain._parse_critique(text)

-        assert (
-            critique.strip() == "This text is bad."
-        ), f"Failed on {text} with {critique}"
+        assert critique.strip() == "This text is bad.", (
+            f"Failed on {text} with {critique}"
+        )

View File

@@ -21,6 +21,9 @@ def test_create() -> None:
     expected_output = [Document(page_content="I know the answer!")]
     output = chain.invoke(
-        {"input": "What is the answer?", "chat_history": ["hi", "hi"]}
+        {
+            "input": "What is the answer?",
+            "chat_history": ["hi", "hi"],
+        }
     )
     assert output == expected_output

View File

@@ -33,7 +33,7 @@ def test_complex_question(fake_llm_math_chain: LLMMathChain) -> None:
     """Test complex question that should need python."""
     question = "What is the square root of 2?"
     output = fake_llm_math_chain.run(question)
-    assert output == f"Answer: {2**.5}"
+    assert output == f"Answer: {2**0.5}"


 @pytest.mark.requires("numexpr")

View File

@@ -151,21 +151,21 @@ def test_json_equality_evaluator_evaluate_lists_permutation_invariant() -> None:
     # Limit tests
     prediction = (
-        "[" + ",".join([f'{{"a": {i}, "b": {i+1}}}' for i in range(1000)]) + "]"
+        "[" + ",".join([f'{{"a": {i}, "b": {i + 1}}}' for i in range(1000)]) + "]"
     )
-    rlist = [f'{{"a": {i}, "b": {i+1}}}' for i in range(1000)]
+    rlist = [f'{{"a": {i}, "b": {i + 1}}}' for i in range(1000)]
     random.shuffle(rlist)
     reference = "[" + ",".join(rlist) + "]"
     result = evaluator.evaluate_strings(prediction=prediction, reference=reference)
     assert result == {"score": True}

     prediction = (
-        "[" + ",".join([f'{{"b": {i+1}, "a": {i}}}' for i in range(1000)]) + "]"
+        "[" + ",".join([f'{{"b": {i + 1}, "a": {i}}}' for i in range(1000)]) + "]"
     )
     reference = (
         "["
         + ",".join(
-            [f'{{"a": {i+1}, "b": {i+2}}}' for i in range(999)]
+            [f'{{"a": {i + 1}, "b": {i + 2}}}' for i in range(999)]
             + ['{"a": 1000, "b": 1001}']
         )
         + "]"

View File

@@ -8,7 +8,11 @@ from langchain_core.exceptions import OutputParserException
 from langchain.output_parsers.pandas_dataframe import PandasDataFrameOutputParser

 df = pd.DataFrame(
-    {"chicken": [1, 2, 3, 4], "veggies": [5, 4, 3, 2], "steak": [9, 8, 7, 6]}
+    {
+        "chicken": [1, 2, 3, 4],
+        "veggies": [5, 4, 3, 2],
+        "steak": [9, 8, 7, 6],
+    }
 )

 parser = PandasDataFrameOutputParser(dataframe=df)

View File

@@ -10,7 +10,10 @@ from langchain.runnables.hub import HubRunnable
 @patch("langchain.hub.pull")
 def test_hub_runnable(mock_pull: Mock) -> None:
     mock_pull.return_value = ChatPromptTemplate.from_messages(
-        [("system", "a"), ("user", "b")]
+        [
+            ("system", "a"),
+            ("user", "b"),
+        ]
     )

     basic: HubRunnable = HubRunnable("efriis/my-prompt")
@@ -21,10 +24,16 @@ def test_hub_runnable(mock_pull: Mock) -> None:
     repo_dict = {
         "efriis/my-prompt-1": ChatPromptTemplate.from_messages(
-            [("system", "a"), ("user", "1")]
+            [
+                ("system", "a"),
+                ("user", "1"),
+            ]
         ),
         "efriis/my-prompt-2": ChatPromptTemplate.from_messages(
-            [("system", "a"), ("user", "2")]
+            [
+                ("system", "a"),
+                ("user", "2"),
+            ]
         ),
     }

View File

@@ -43,9 +43,9 @@ def test_openai_functions_router(
     snapshot: SnapshotAssertion, mocker: MockerFixture
 ) -> None:
     revise = mocker.Mock(
-        side_effect=lambda kw: f'Revised draft: no more {kw["notes"]}!'
+        side_effect=lambda kw: f"Revised draft: no more {kw['notes']}!"
     )
-    accept = mocker.Mock(side_effect=lambda kw: f'Accepted draft: {kw["draft"]}!')
+    accept = mocker.Mock(side_effect=lambda kw: f"Accepted draft: {kw['draft']}!")

     router = OpenAIFunctionsRouter(
         {

View File

@@ -316,12 +316,15 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
         proj.id = "123"
         return proj

-    with mock.patch.object(
-        Client, "read_dataset", new=mock_read_dataset
-    ), mock.patch.object(Client, "list_examples", new=mock_list_examples), mock.patch(
-        "langchain.smith.evaluation.runner_utils._arun_llm_or_chain",
-        new=mock_arun_chain,
-    ), mock.patch.object(Client, "create_project", new=mock_create_project):
+    with (
+        mock.patch.object(Client, "read_dataset", new=mock_read_dataset),
+        mock.patch.object(Client, "list_examples", new=mock_list_examples),
+        mock.patch(
+            "langchain.smith.evaluation.runner_utils._arun_llm_or_chain",
+            new=mock_arun_chain,
+        ),
+        mock.patch.object(Client, "create_project", new=mock_create_project),
+    ):
         client = Client(api_url="http://localhost:1984", api_key="123")
         chain = mock.MagicMock()
         chain.input_keys = ["foothing"]
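
Note: the final hunk switches one long chained `with` to the parenthesized multi-manager form, one manager per line. The parenthesized syntax is part of the Python 3.10 grammar (CPython 3.9's PEG parser already accepts it); a runnable sketch of the pattern, with hypothetical patch targets:

import os
import time
from unittest import mock

# One `with` statement, several context managers, formatted one per line.
with (
    mock.patch.dict(os.environ, {"DEMO_FLAG": "1"}),
    mock.patch("time.time", return_value=0.0),
):
    assert os.environ["DEMO_FLAG"] == "1"
    assert time.time() == 0.0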