langchain: Bump ruff version to 0.9 (#29211)

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Christophe Bornet 2025-01-22 01:26:39 +01:00 committed by GitHub
parent 2340b3154d
commit a004dec119
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 307 additions and 250 deletions

View File

@ -84,8 +84,7 @@ def initialize_agent(
pass
else:
raise ValueError(
"Somehow both `agent` and `agent_path` are None, "
"this should never happen."
"Somehow both `agent` and `agent_path` are None, this should never happen."
)
return AgentExecutor.from_agent_and_tools(
agent=agent_obj,

View File

@ -58,8 +58,7 @@ def load_agent_from_config(
if load_from_tools:
if llm is None:
raise ValueError(
"If `load_from_llm_and_tools` is set to True, "
"then LLM must be provided"
"If `load_from_llm_and_tools` is set to True, then LLM must be provided"
)
if tools is None:
raise ValueError(

View File

@ -41,6 +41,6 @@ class LoggingCallbackHandler(FunctionCallbackHandler):
except TracerException:
crumbs_str = ""
self.function_callback(
f'{get_colored_text("[text]", color="blue")}'
f' {get_bolded_text(f"{crumbs_str}New text:")}\n{text}'
f"{get_colored_text('[text]', color='blue')}"
f" {get_bolded_text(f'{crumbs_str}New text:')}\n{text}"
)

View File

@ -370,6 +370,7 @@ try:
@property
def _chain_type(self) -> str:
return "api_chain"
except ImportError:
class APIChain: # type: ignore[no-redef]

View File

@ -68,7 +68,10 @@ def create_extraction_chain_pydantic(
if not isinstance(pydantic_schemas, list):
pydantic_schemas = [pydantic_schemas]
prompt = ChatPromptTemplate.from_messages(
[("system", system_message), ("user", "{input}")]
[
("system", system_message),
("user", "{input}"),
]
)
functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas]
tools = [{"type": "function", "function": d} for d in functions]

View File

@ -33,7 +33,11 @@ refine_template = (
"If the context isn't useful, return the original answer."
)
CHAT_REFINE_PROMPT = ChatPromptTemplate.from_messages(
[("human", "{question}"), ("ai", "{existing_answer}"), ("human", refine_template)]
[
("human", "{question}"),
("ai", "{existing_answer}"),
("human", refine_template),
]
)
REFINE_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=DEFAULT_REFINE_PROMPT,
@ -60,7 +64,10 @@ chat_qa_prompt_template = (
"answer any questions"
)
CHAT_QUESTION_PROMPT = ChatPromptTemplate.from_messages(
[("system", chat_qa_prompt_template), ("human", "{question}")]
[
("system", chat_qa_prompt_template),
("human", "{question}"),
]
)
QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=DEFAULT_TEXT_QA_PROMPT,

View File

@ -178,7 +178,9 @@ class SimpleSequentialChain(Chain):
_input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
for i, chain in enumerate(self.chains):
_input = chain.run(_input, callbacks=_run_manager.get_child(f"step_{i+1}"))
_input = chain.run(
_input, callbacks=_run_manager.get_child(f"step_{i + 1}")
)
if self.strip_outputs:
_input = _input.strip()
_run_manager.on_text(

View File

@ -590,7 +590,11 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
queued_declarative_operations = list(self._queued_declarative_operations)
if remaining_config:
queued_declarative_operations.append(
("with_config", (), {"config": remaining_config})
(
"with_config",
(),
{"config": remaining_config},
)
)
return _ConfigurableModel(
default_config={**self._default_config, **model_params},

View File

@ -174,8 +174,7 @@ def init_embeddings(
if not model:
providers = _SUPPORTED_PROVIDERS.keys()
raise ValueError(
"Must specify model name. "
f"Supported providers are: {', '.join(providers)}"
f"Must specify model name. Supported providers are: {', '.join(providers)}"
)
provider, model_name = _infer_model_and_provider(model, provider=provider)

View File

@ -310,7 +310,10 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
Dict[str, Any]: The computed score.
"""
embedded = await self.embeddings.aembed_documents(
[inputs["prediction"], inputs["reference"]]
[
inputs["prediction"],
inputs["reference"],
]
)
vectors = np.array(embedded)
score = self._compute_score(vectors)
@ -427,7 +430,10 @@ class PairwiseEmbeddingDistanceEvalChain(
"""
vectors = np.array(
self.embeddings.embed_documents(
[inputs["prediction"], inputs["prediction_b"]]
[
inputs["prediction"],
inputs["prediction_b"],
]
)
)
score = self._compute_score(vectors)
@ -449,7 +455,10 @@ class PairwiseEmbeddingDistanceEvalChain(
Dict[str, Any]: The computed score.
"""
embedded = await self.embeddings.aembed_documents(
[inputs["prediction"], inputs["prediction_b"]]
[
inputs["prediction"],
inputs["prediction_b"],
]
)
vectors = np.array(embedded)
score = self._compute_score(vectors)

View File

@ -71,7 +71,10 @@ class BaseChatMemory(BaseMemory, ABC):
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
self.chat_memory.add_messages(
[HumanMessage(content=input_str), AIMessage(content=output_str)]
[
HumanMessage(content=input_str),
AIMessage(content=output_str),
]
)
async def asave_context(
@ -80,7 +83,10 @@ class BaseChatMemory(BaseMemory, ABC):
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
await self.chat_memory.aadd_messages(
[HumanMessage(content=input_str), AIMessage(content=output_str)]
[
HumanMessage(content=input_str),
AIMessage(content=output_str),
]
)
def clear(self) -> None:

View File

@ -92,7 +92,10 @@ class CohereRerank(BaseDocumentCompressor):
result_dicts = []
for res in results:
result_dicts.append(
{"index": res.index, "relevance_score": res.relevance_score}
{
"index": res.index,
"relevance_score": res.relevance_score,
}
)
return result_dicts

View File

@ -265,7 +265,8 @@ class EnsembleRetriever(BaseRetriever):
retriever.ainvoke(
query,
patch_config(
config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
config,
callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"),
),
)
for i, retriever in enumerate(self.retrievers)

View File

@ -247,8 +247,7 @@ def _get_prompt(inputs: Dict[str, Any]) -> str:
if "prompt" in inputs:
if not isinstance(inputs["prompt"], str):
raise InputFormatError(
"Expected string for 'prompt', got"
f" {type(inputs['prompt']).__name__}"
f"Expected string for 'prompt', got {type(inputs['prompt']).__name__}"
)
prompts = [inputs["prompt"]]
elif "prompts" in inputs:

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@ -704,13 +704,13 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""}
[[package]]
name = "codespell"
version = "2.3.0"
description = "Codespell"
version = "2.4.0"
description = "Fix common misspellings in text files"
optional = false
python-versions = ">=3.8"
files = [
{file = "codespell-2.3.0-py3-none-any.whl", hash = "sha256:a9c7cef2501c9cfede2110fd6d4e5e62296920efe9abfb84648df866e47f58d1"},
{file = "codespell-2.3.0.tar.gz", hash = "sha256:360c7d10f75e65f67bad720af7007e1060a5d395670ec11a7ed1fed9dd17471f"},
{file = "codespell-2.4.0-py3-none-any.whl", hash = "sha256:b4c5b779f747dd481587aeecb5773301183f52b94b96ed51a28126d0482eec1d"},
{file = "codespell-2.4.0.tar.gz", hash = "sha256:587d45b14707fb8ce51339ba4cce50ae0e98ce228ef61f3c5e160e34f681be58"},
]
[package.extras]
@ -4089,29 +4089,29 @@ files = [
[[package]]
name = "ruff"
version = "0.5.7"
version = "0.9.2"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
{file = "ruff-0.5.7-py3-none-linux_armv6l.whl", hash = "sha256:548992d342fc404ee2e15a242cdbea4f8e39a52f2e7752d0e4cbe88d2d2f416a"},
{file = "ruff-0.5.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:00cc8872331055ee017c4f1071a8a31ca0809ccc0657da1d154a1d2abac5c0be"},
{file = "ruff-0.5.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eaf3d86a1fdac1aec8a3417a63587d93f906c678bb9ed0b796da7b59c1114a1e"},
{file = "ruff-0.5.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a01c34400097b06cf8a6e61b35d6d456d5bd1ae6961542de18ec81eaf33b4cb8"},
{file = "ruff-0.5.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcc8054f1a717e2213500edaddcf1dbb0abad40d98e1bd9d0ad364f75c763eea"},
{file = "ruff-0.5.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f70284e73f36558ef51602254451e50dd6cc479f8b6f8413a95fcb5db4a55fc"},
{file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:a78ad870ae3c460394fc95437d43deb5c04b5c29297815a2a1de028903f19692"},
{file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ccd078c66a8e419475174bfe60a69adb36ce04f8d4e91b006f1329d5cd44bcf"},
{file = "ruff-0.5.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e31c9bad4ebf8fdb77b59cae75814440731060a09a0e0077d559a556453acbb"},
{file = "ruff-0.5.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d796327eed8e168164346b769dd9a27a70e0298d667b4ecee6877ce8095ec8e"},
{file = "ruff-0.5.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4a09ea2c3f7778cc635e7f6edf57d566a8ee8f485f3c4454db7771efb692c499"},
{file = "ruff-0.5.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a36d8dcf55b3a3bc353270d544fb170d75d2dff41eba5df57b4e0b67a95bb64e"},
{file = "ruff-0.5.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9369c218f789eefbd1b8d82a8cf25017b523ac47d96b2f531eba73770971c9e5"},
{file = "ruff-0.5.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b88ca3db7eb377eb24fb7c82840546fb7acef75af4a74bd36e9ceb37a890257e"},
{file = "ruff-0.5.7-py3-none-win32.whl", hash = "sha256:33d61fc0e902198a3e55719f4be6b375b28f860b09c281e4bdbf783c0566576a"},
{file = "ruff-0.5.7-py3-none-win_amd64.whl", hash = "sha256:083bbcbe6fadb93cd86709037acc510f86eed5a314203079df174c40bbbca6b3"},
{file = "ruff-0.5.7-py3-none-win_arm64.whl", hash = "sha256:2dca26154ff9571995107221d0aeaad0e75a77b5a682d6236cf89a58c70b76f4"},
{file = "ruff-0.5.7.tar.gz", hash = "sha256:8dfc0a458797f5d9fb622dd0efc52d796f23f0a1493a9527f4e49a550ae9a7e5"},
{file = "ruff-0.9.2-py3-none-linux_armv6l.whl", hash = "sha256:80605a039ba1454d002b32139e4970becf84b5fee3a3c3bf1c2af6f61a784347"},
{file = "ruff-0.9.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b9aab82bb20afd5f596527045c01e6ae25a718ff1784cb92947bff1f83068b00"},
{file = "ruff-0.9.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fbd337bac1cfa96be615f6efcd4bc4d077edbc127ef30e2b8ba2a27e18c054d4"},
{file = "ruff-0.9.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82b35259b0cbf8daa22a498018e300b9bb0174c2bbb7bcba593935158a78054d"},
{file = "ruff-0.9.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b6a9701d1e371bf41dca22015c3f89769da7576884d2add7317ec1ec8cb9c3c"},
{file = "ruff-0.9.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9cc53e68b3c5ae41e8faf83a3b89f4a5d7b2cb666dff4b366bb86ed2a85b481f"},
{file = "ruff-0.9.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:8efd9da7a1ee314b910da155ca7e8953094a7c10d0c0a39bfde3fcfd2a015684"},
{file = "ruff-0.9.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3292c5a22ea9a5f9a185e2d131dc7f98f8534a32fb6d2ee7b9944569239c648d"},
{file = "ruff-0.9.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a605fdcf6e8b2d39f9436d343d1f0ff70c365a1e681546de0104bef81ce88df"},
{file = "ruff-0.9.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c547f7f256aa366834829a08375c297fa63386cbe5f1459efaf174086b564247"},
{file = "ruff-0.9.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d18bba3d3353ed916e882521bc3e0af403949dbada344c20c16ea78f47af965e"},
{file = "ruff-0.9.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b338edc4610142355ccf6b87bd356729b62bf1bc152a2fad5b0c7dc04af77bfe"},
{file = "ruff-0.9.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:492a5e44ad9b22a0ea98cf72e40305cbdaf27fac0d927f8bc9e1df316dcc96eb"},
{file = "ruff-0.9.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:af1e9e9fe7b1f767264d26b1075ac4ad831c7db976911fa362d09b2d0356426a"},
{file = "ruff-0.9.2-py3-none-win32.whl", hash = "sha256:71cbe22e178c5da20e1514e1e01029c73dc09288a8028a5d3446e6bba87a5145"},
{file = "ruff-0.9.2-py3-none-win_amd64.whl", hash = "sha256:c5e1d6abc798419cf46eed03f54f2e0c3adb1ad4b801119dedf23fcaf69b55b5"},
{file = "ruff-0.9.2-py3-none-win_arm64.whl", hash = "sha256:a1b63fa24149918f8b37cef2ee6fff81f24f0d74b6f0bdc37bc3e1f2143e41c6"},
{file = "ruff-0.9.2.tar.gz", hash = "sha256:b5eceb334d55fae5f316f783437392642ae18e16dcf4f1858d55d3c2a0f8f5d0"},
]
[[package]]
@ -4689,13 +4689,13 @@ files = [
[[package]]
name = "tzdata"
version = "2024.2"
version = "2025.1"
description = "Provider of IANA time zone data"
optional = false
python-versions = ">=2"
files = [
{file = "tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd"},
{file = "tzdata-2024.2.tar.gz", hash = "sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc"},
{file = "tzdata-2025.1-py2.py3-none-any.whl", hash = "sha256:7e127113816800496f027041c570f50bcd464a020098a3b6b199517772303639"},
{file = "tzdata-2025.1.tar.gz", hash = "sha256:24894909e88cdb28bd1636c6887801df64cb485bd593f2fd83ef29075a81d694"},
]
[[package]]
@ -5076,4 +5076,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<4.0"
content-hash = "7dd55a6e29a188f48f59469ce0ca25ecb2d5a85df82121e181462cddb3f14fbc"
content-hash = "8973c8570d06948300d1030dfae674d33ba8f194746413d6930c0bb0af593b76"

View File

@ -12,6 +12,7 @@ readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
[tool.ruff]
target-version = "py39"
exclude = [ "tests/integration_tests/examples/non-utf8-encoding.py",]
[tool.mypy]
@ -119,7 +120,7 @@ cassio = "^0.1.0"
langchainhub = "^0.1.16"
[tool.poetry.group.lint.dependencies]
ruff = "^0.5"
ruff = "^0.9.2"
[[tool.poetry.group.lint.dependencies.cffi]]
version = "<1.17.1"
python = "<3.10"

View File

@ -451,7 +451,10 @@ async def test_runnable_agent() -> None:
model = GenericFakeChatModel(messages=infinite_cycle)
template = ChatPromptTemplate.from_messages(
[("system", "You are Cat Agent 007"), ("human", "{question}")]
[
("system", "You are Cat Agent 007"),
("human", "{question}"),
]
)
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
@ -539,12 +542,18 @@ async def test_runnable_agent_with_function_calls() -> None:
"""Test agent with intermediate agent actions."""
# Will alternate between responding with hello and goodbye
infinite_cycle = cycle(
[AIMessage(content="looking for pet..."), AIMessage(content="Found Pet")]
[
AIMessage(content="looking for pet..."),
AIMessage(content="Found Pet"),
]
)
model = GenericFakeChatModel(messages=infinite_cycle)
template = ChatPromptTemplate.from_messages(
[("system", "You are Cat Agent 007"), ("human", "{question}")]
[
("system", "You are Cat Agent 007"),
("human", "{question}"),
]
)
parser_responses = cycle(
@ -635,12 +644,18 @@ async def test_runnable_with_multi_action_per_step() -> None:
"""Test an agent that can make multiple function calls at once."""
# Will alternate between responding with hello and goodbye
infinite_cycle = cycle(
[AIMessage(content="looking for pet..."), AIMessage(content="Found Pet")]
[
AIMessage(content="looking for pet..."),
AIMessage(content="Found Pet"),
]
)
model = GenericFakeChatModel(messages=infinite_cycle)
template = ChatPromptTemplate.from_messages(
[("system", "You are Cat Agent 007"), ("human", "{question}")]
[
("system", "You are Cat Agent 007"),
("human", "{question}"),
]
)
parser_responses = cycle(
@ -861,7 +876,7 @@ async def test_openai_agent_with_streaming() -> None:
{
"additional_kwargs": {
"function_call": {
"arguments": '{"pet": ' '"cat"}',
"arguments": '{"pet": "cat"}',
"name": "find_pet",
}
},
@ -880,7 +895,7 @@ async def test_openai_agent_with_streaming() -> None:
{
"additional_kwargs": {
"function_call": {
"arguments": '{"pet": ' '"cat"}',
"arguments": '{"pet": "cat"}',
"name": "find_pet",
}
},
@ -909,10 +924,7 @@ async def test_openai_agent_with_streaming() -> None:
"steps": [
{
"action": {
"log": "\n"
"Invoking: `find_pet` with `{'pet': 'cat'}`\n"
"\n"
"\n",
"log": "\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
"tool": "find_pet",
"tool_input": {"pet": "cat"},
"type": "AgentActionMessageLog",
@ -1062,9 +1074,7 @@ async def test_openai_agent_tools_agent() -> None:
# astream
chunks = [chunk async for chunk in executor.astream({"question": "hello"})]
assert (
chunks
== [
assert chunks == [
{
"actions": [
OpenAIToolAgentAction(
@ -1268,7 +1278,6 @@ async def test_openai_agent_tools_agent() -> None:
"output": "The cat is spying from under the bed.",
},
]
)
# astream_log
log_patches = [

View File

@ -24,9 +24,7 @@ def get_action_and_input(text: str) -> Tuple[str, str]:
def test_get_action_and_input() -> None:
"""Test getting an action from text."""
llm_output = (
"Thought: I need to search for NBA\n" "Action: Search\n" "Action Input: NBA"
)
llm_output = "Thought: I need to search for NBA\nAction: Search\nAction Input: NBA"
action, action_input = get_action_and_input(llm_output)
assert action == "Search"
assert action_input == "NBA"
@ -91,7 +89,7 @@ def test_get_action_and_input_sql_query() -> None:
def test_get_final_answer() -> None:
"""Test getting final answer."""
llm_output = "Thought: I can now answer the question\n" "Final Answer: 1994"
llm_output = "Thought: I can now answer the question\nFinal Answer: 1994"
action, action_input = get_action_and_input(llm_output)
assert action == "Final Answer"
assert action_input == "1994"
@ -99,7 +97,7 @@ def test_get_final_answer() -> None:
def test_get_final_answer_new_line() -> None:
"""Test getting final answer."""
llm_output = "Thought: I can now answer the question\n" "Final Answer:\n1994"
llm_output = "Thought: I can now answer the question\nFinal Answer:\n1994"
action, action_input = get_action_and_input(llm_output)
assert action == "Final Answer"
assert action_input == "1994"
@ -107,7 +105,7 @@ def test_get_final_answer_new_line() -> None:
def test_get_final_answer_multiline() -> None:
"""Test getting final answer that is multiline."""
llm_output = "Thought: I can now answer the question\n" "Final Answer: 1994\n1993"
llm_output = "Thought: I can now answer the question\nFinal Answer: 1994\n1993"
action, action_input = get_action_and_input(llm_output)
assert action == "Final Answer"
assert action_input == "1994\n1993"
@ -115,7 +113,7 @@ def test_get_final_answer_multiline() -> None:
def test_bad_action_input_line() -> None:
"""Test handling when no action input found."""
llm_output = "Thought: I need to search for NBA\n" "Action: Search\n" "Thought: NBA"
llm_output = "Thought: I need to search for NBA\nAction: Search\nThought: NBA"
with pytest.raises(OutputParserException) as e_info:
get_action_and_input(llm_output)
assert e_info.value.observation is not None
@ -123,9 +121,7 @@ def test_bad_action_input_line() -> None:
def test_bad_action_line() -> None:
"""Test handling when no action found."""
llm_output = (
"Thought: I need to search for NBA\n" "Thought: Search\n" "Action Input: NBA"
)
llm_output = "Thought: I need to search for NBA\nThought: Search\nAction Input: NBA"
with pytest.raises(OutputParserException) as e_info:
get_action_and_input(llm_output)
assert e_info.value.observation is not None

View File

@ -22,6 +22,6 @@ def test_critique_parsing() -> None:
for text in [TEXT_ONE, TEXT_TWO, TEXT_THREE]:
critique = ConstitutionalChain._parse_critique(text)
assert (
critique.strip() == "This text is bad."
), f"Failed on {text} with {critique}"
assert critique.strip() == "This text is bad.", (
f"Failed on {text} with {critique}"
)

View File

@ -21,6 +21,9 @@ def test_create() -> None:
expected_output = [Document(page_content="I know the answer!")]
output = chain.invoke(
{"input": "What is the answer?", "chat_history": ["hi", "hi"]}
{
"input": "What is the answer?",
"chat_history": ["hi", "hi"],
}
)
assert output == expected_output

View File

@ -33,7 +33,7 @@ def test_complex_question(fake_llm_math_chain: LLMMathChain) -> None:
"""Test complex question that should need python."""
question = "What is the square root of 2?"
output = fake_llm_math_chain.run(question)
assert output == f"Answer: {2**.5}"
assert output == f"Answer: {2**0.5}"
@pytest.mark.requires("numexpr")

View File

@ -8,7 +8,11 @@ from langchain_core.exceptions import OutputParserException
from langchain.output_parsers.pandas_dataframe import PandasDataFrameOutputParser
df = pd.DataFrame(
{"chicken": [1, 2, 3, 4], "veggies": [5, 4, 3, 2], "steak": [9, 8, 7, 6]}
{
"chicken": [1, 2, 3, 4],
"veggies": [5, 4, 3, 2],
"steak": [9, 8, 7, 6],
}
)
parser = PandasDataFrameOutputParser(dataframe=df)

View File

@ -10,7 +10,10 @@ from langchain.runnables.hub import HubRunnable
@patch("langchain.hub.pull")
def test_hub_runnable(mock_pull: Mock) -> None:
mock_pull.return_value = ChatPromptTemplate.from_messages(
[("system", "a"), ("user", "b")]
[
("system", "a"),
("user", "b"),
]
)
basic: HubRunnable = HubRunnable("efriis/my-prompt")
@ -21,10 +24,16 @@ def test_hub_runnable(mock_pull: Mock) -> None:
repo_dict = {
"efriis/my-prompt-1": ChatPromptTemplate.from_messages(
[("system", "a"), ("user", "1")]
[
("system", "a"),
("user", "1"),
]
),
"efriis/my-prompt-2": ChatPromptTemplate.from_messages(
[("system", "a"), ("user", "2")]
[
("system", "a"),
("user", "2"),
]
),
}

View File

@ -43,9 +43,9 @@ def test_openai_functions_router(
snapshot: SnapshotAssertion, mocker: MockerFixture
) -> None:
revise = mocker.Mock(
side_effect=lambda kw: f'Revised draft: no more {kw["notes"]}!'
side_effect=lambda kw: f"Revised draft: no more {kw['notes']}!"
)
accept = mocker.Mock(side_effect=lambda kw: f'Accepted draft: {kw["draft"]}!')
accept = mocker.Mock(side_effect=lambda kw: f"Accepted draft: {kw['draft']}!")
router = OpenAIFunctionsRouter(
{

View File

@ -316,12 +316,15 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
proj.id = "123"
return proj
with mock.patch.object(
Client, "read_dataset", new=mock_read_dataset
), mock.patch.object(Client, "list_examples", new=mock_list_examples), mock.patch(
with (
mock.patch.object(Client, "read_dataset", new=mock_read_dataset),
mock.patch.object(Client, "list_examples", new=mock_list_examples),
mock.patch(
"langchain.smith.evaluation.runner_utils._arun_llm_or_chain",
new=mock_arun_chain,
), mock.patch.object(Client, "create_project", new=mock_create_project):
),
mock.patch.object(Client, "create_project", new=mock_create_project),
):
client = Client(api_url="http://localhost:1984", api_key="123")
chain = mock.MagicMock()
chain.input_keys = ["foothing"]