core: upgrade mypy to recent mypy (#18753)

Testing that this works per package on CI.
commit 8c71f92cb2 (parent e188d4ecb0)
Author: Eugene Yurtsev
Date: 2024-03-07 15:25:19 -05:00
Committed by: GitHub
23 changed files with 172 additions and 155 deletions
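
All of the changes below follow one pattern: call sites that a recent mypy now flags are given error-code-scoped suppressions such as # type: ignore[arg-type] or # type: ignore[call-arg] instead of bare ignores, so unrelated errors on the same lines still surface. A minimal, hypothetical sketch of that mechanism (the greet function is illustrative only, not part of this commit):

# Hypothetical example, not from this commit: how error-code-scoped ignores behave.
from typing import List

def greet(names: List[str]) -> str:
    return ", ".join(names)

# mypy reports: Argument 1 to "greet" has incompatible type "int";
# expected "List[str]"  [arg-type]
# Scoping the ignore to [arg-type] silences only that check; a different
# error code on the same line (e.g. [call-arg] for an unexpected keyword)
# would still be reported, unlike with a bare "# type: ignore".
greeting = greet(42)  # type: ignore[arg-type]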

View File

@@ -67,7 +67,7 @@ def create_chat_prompt_template() -> ChatPromptTemplate:
"""Create a chat prompt template."""
return ChatPromptTemplate(
input_variables=["foo", "bar", "context"],
-messages=create_messages(),
+messages=create_messages(), # type: ignore[arg-type]
)
@@ -191,10 +191,12 @@ def test_chat_invalid_input_variables_extra() -> None:
messages = [HumanMessage(content="foo")]
with pytest.raises(ValueError):
ChatPromptTemplate(
-messages=messages, input_variables=["foo"], validate_template=True
+messages=messages, # type: ignore[arg-type]
+input_variables=["foo"],
+validate_template=True, # type: ignore[arg-type]
)
assert (
-ChatPromptTemplate(messages=messages, input_variables=["foo"]).input_variables
+ChatPromptTemplate(messages=messages, input_variables=["foo"]).input_variables # type: ignore[arg-type]
== []
)
@@ -203,16 +205,19 @@ def test_chat_invalid_input_variables_missing() -> None:
messages = [HumanMessagePromptTemplate.from_template("{foo}")]
with pytest.raises(ValueError):
ChatPromptTemplate(
-messages=messages, input_variables=[], validate_template=True
+messages=messages, # type: ignore[arg-type]
+input_variables=[],
+validate_template=True, # type: ignore[arg-type]
)
assert ChatPromptTemplate(
-messages=messages, input_variables=[]
+messages=messages, # type: ignore[arg-type]
+input_variables=[], # type: ignore[arg-type]
).input_variables == ["foo"]
def test_infer_variables() -> None:
messages = [HumanMessagePromptTemplate.from_template("{foo}")]
-prompt = ChatPromptTemplate(messages=messages)
+prompt = ChatPromptTemplate(messages=messages) # type: ignore[arg-type, call-arg]
assert prompt.input_variables == ["foo"]
@@ -223,7 +228,7 @@ def test_chat_valid_with_partial_variables() -> None:
)
]
prompt = ChatPromptTemplate(
-messages=messages,
+messages=messages, # type: ignore[arg-type]
input_variables=["question", "context"],
partial_variables={"formatins": "some structure"},
)
@@ -237,8 +242,9 @@ def test_chat_valid_infer_variables() -> None:
"Do something with {question} using {context} giving it like {formatins}"
)
]
-prompt = ChatPromptTemplate(
-messages=messages, partial_variables={"formatins": "some structure"}
+prompt = ChatPromptTemplate( # type: ignore[call-arg]
+messages=messages, # type: ignore[arg-type]
+partial_variables={"formatins": "some structure"}, # type: ignore[arg-type]
)
assert set(prompt.input_variables) == {"question", "context"}
assert prompt.partial_variables == {"formatins": "some structure"}

View File

@@ -6,7 +6,7 @@ from langchain_core.prompts.prompt import PromptTemplate
def test_get_input_variables() -> None:
prompt_a = PromptTemplate.from_template("{foo}")
prompt_b = PromptTemplate.from_template("{bar}")
-pipeline_prompt = PipelinePromptTemplate(
+pipeline_prompt = PipelinePromptTemplate( # type: ignore[call-arg]
final_prompt=prompt_b, pipeline_prompts=[("bar", prompt_a)]
)
assert pipeline_prompt.input_variables == ["foo"]
@@ -15,7 +15,7 @@ def test_get_input_variables() -> None:
def test_simple_pipeline() -> None:
prompt_a = PromptTemplate.from_template("{foo}")
prompt_b = PromptTemplate.from_template("{bar}")
-pipeline_prompt = PipelinePromptTemplate(
+pipeline_prompt = PipelinePromptTemplate( # type: ignore[call-arg]
final_prompt=prompt_b, pipeline_prompts=[("bar", prompt_a)]
)
output = pipeline_prompt.format(foo="jim")
@@ -25,7 +25,7 @@ def test_simple_pipeline() -> None:
def test_multi_variable_pipeline() -> None:
prompt_a = PromptTemplate.from_template("{foo}")
prompt_b = PromptTemplate.from_template("okay {bar} {baz}")
-pipeline_prompt = PipelinePromptTemplate(
+pipeline_prompt = PipelinePromptTemplate( # type: ignore[call-arg]
final_prompt=prompt_b, pipeline_prompts=[("bar", prompt_a)]
)
output = pipeline_prompt.format(foo="jim", baz="deep")
@@ -37,7 +37,7 @@ def test_partial_with_chat_prompts() -> None:
input_variables=["foo"], messages=[MessagesPlaceholder(variable_name="foo")]
)
prompt_b = ChatPromptTemplate.from_template("jim {bar}")
-pipeline_prompt = PipelinePromptTemplate(
+pipeline_prompt = PipelinePromptTemplate( # type: ignore[call-arg]
final_prompt=prompt_a, pipeline_prompts=[("foo", prompt_b)]
)
assert pipeline_prompt.input_variables == ["bar"]

View File

@@ -126,7 +126,9 @@ def test_prompt_invalid_template_format() -> None:
input_variables = ["foo"]
with pytest.raises(ValueError):
PromptTemplate(
-input_variables=input_variables, template=template, template_format="bar"
+input_variables=input_variables,
+template=template,
+template_format="bar", # type: ignore[arg-type]
)

View File

@@ -758,7 +758,7 @@ def test_validation_error_handling_non_validation_error(
async def _arun(self) -> str:
return "dummy"
-_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
+_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
with pytest.raises(NotImplementedError):
_tool.run({})
@@ -820,7 +820,7 @@ async def test_async_validation_error_handling_non_validation_error(
async def _arun(self) -> str:
return "dummy"
-_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
+_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
with pytest.raises(NotImplementedError):
await _tool.arun({})

View File

@@ -53,7 +53,7 @@ def _compare_run_with_error(run: Any, expected_run: Any) -> None:
def test_tracer_llm_run() -> None:
"""Test tracer on an LLM run."""
uuid = uuid4()
-compare_run = Run(
+compare_run = Run( # type: ignore[call-arg]
id=uuid,
parent_run_id=None,
start_time=datetime.now(timezone.utc),
@@ -67,7 +67,7 @@ def test_tracer_llm_run() -> None:
child_execution_order=1,
serialized=SERIALIZED,
inputs={"prompts": []},
-outputs=LLMResult(generations=[[]]),
+outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
error=None,
run_type="llm",
trace_id=uuid,
@@ -89,7 +89,7 @@ def test_tracer_chat_model_run() -> None:
serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
)
compare_run = Run(
-id=str(run_managers[0].run_id),
+id=str(run_managers[0].run_id), # type: ignore[arg-type]
name="chat_model",
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
@@ -102,7 +102,7 @@ def test_tracer_chat_model_run() -> None:
child_execution_order=1,
serialized=SERIALIZED_CHAT,
inputs=dict(prompts=["Human: "]),
-outputs=LLMResult(generations=[[]]),
+outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
error=None,
run_type="llm",
trace_id=run_managers[0].run_id,
@@ -140,7 +140,7 @@ def test_tracer_multiple_llm_runs() -> None:
child_execution_order=1,
serialized=SERIALIZED,
inputs=dict(prompts=[]),
-outputs=LLMResult(generations=[[]]),
+outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
error=None,
run_type="llm",
trace_id=uuid,
@@ -160,8 +160,8 @@ def test_tracer_multiple_llm_runs() -> None:
def test_tracer_chain_run() -> None:
"""Test tracer on a Chain run."""
uuid = uuid4()
-compare_run = Run(
-id=str(uuid),
+compare_run = Run( # type: ignore[call-arg]
+id=str(uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@@ -190,8 +190,8 @@ def test_tracer_chain_run() -> None:
def test_tracer_tool_run() -> None:
"""Test tracer on a Tool run."""
uuid = uuid4()
-compare_run = Run(
-id=str(uuid),
+compare_run = Run( # type: ignore[call-arg]
+id=str(uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@@ -251,8 +251,8 @@ def test_tracer_nested_run() -> None:
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
tracer.on_chain_end(outputs={}, run_id=chain_uuid)
-compare_run = Run(
-id=str(chain_uuid),
+compare_run = Run( # type: ignore[call-arg]
+id=str(chain_uuid), # type: ignore[arg-type]
error=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
@@ -270,7 +270,7 @@ def test_tracer_nested_run() -> None:
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}",
child_runs=[
-Run(
+Run( # type: ignore[call-arg]
id=tool_uuid,
parent_run_id=chain_uuid,
start_time=datetime.now(timezone.utc),
@@ -290,9 +290,9 @@ def test_tracer_nested_run() -> None:
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}",
child_runs=[
-Run(
-id=str(llm_uuid1),
-parent_run_id=str(tool_uuid),
+Run( # type: ignore[call-arg]
+id=str(llm_uuid1), # type: ignore[arg-type]
+parent_run_id=str(tool_uuid), # type: ignore[arg-type]
error=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
@@ -305,16 +305,16 @@ def test_tracer_nested_run() -> None:
child_execution_order=3,
serialized=SERIALIZED,
inputs=dict(prompts=[]),
-outputs=LLMResult(generations=[[]]),
+outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}.20230101T000000000000Z{llm_uuid1}",
)
],
),
-Run(
-id=str(llm_uuid2),
-parent_run_id=str(chain_uuid),
+Run( # type: ignore[call-arg]
+id=str(llm_uuid2), # type: ignore[arg-type]
+parent_run_id=str(chain_uuid), # type: ignore[arg-type]
error=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
@@ -327,7 +327,7 @@ def test_tracer_nested_run() -> None:
child_execution_order=4,
serialized=SERIALIZED,
inputs=dict(prompts=[]),
-outputs=LLMResult(generations=[[]]),
+outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{llm_uuid2}",
@@ -344,8 +344,8 @@ def test_tracer_llm_run_on_error() -> None:
exception = Exception("test")
uuid = uuid4()
-compare_run = Run(
-id=str(uuid),
+compare_run = Run( # type: ignore[call-arg]
+id=str(uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@@ -377,8 +377,8 @@ def test_tracer_llm_run_on_error_callback() -> None:
exception = Exception("test")
uuid = uuid4()
-compare_run = Run(
-id=str(uuid),
+compare_run = Run( # type: ignore[call-arg]
+id=str(uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@@ -415,8 +415,8 @@ def test_tracer_chain_run_on_error() -> None:
exception = Exception("test")
uuid = uuid4()
-compare_run = Run(
-id=str(uuid),
+compare_run = Run( # type: ignore[call-arg]
+id=str(uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@@ -447,8 +447,8 @@ def test_tracer_tool_run_on_error() -> None:
exception = Exception("test")
uuid = uuid4()
-compare_run = Run(
-id=str(uuid),
+compare_run = Run( # type: ignore[call-arg]
+id=str(uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@@ -520,8 +520,8 @@ def test_tracer_nested_runs_on_error() -> None:
tracer.on_tool_error(exception, run_id=tool_uuid)
tracer.on_chain_error(exception, run_id=chain_uuid)
-compare_run = Run(
-id=str(chain_uuid),
+compare_run = Run( # type: ignore[call-arg]
+id=str(chain_uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@@ -539,9 +539,9 @@ def test_tracer_nested_runs_on_error() -> None:
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}",
child_runs=[
-Run(
-id=str(llm_uuid1),
-parent_run_id=str(chain_uuid),
+Run( # type: ignore[call-arg]
+id=str(llm_uuid1), # type: ignore[arg-type]
+parent_run_id=str(chain_uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@@ -554,14 +554,14 @@ def test_tracer_nested_runs_on_error() -> None:
serialized=SERIALIZED,
error=None,
inputs=dict(prompts=[]),
-outputs=LLMResult(generations=[[]], llm_output=None),
+outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{llm_uuid1}",
),
-Run(
-id=str(llm_uuid2),
-parent_run_id=str(chain_uuid),
+Run( # type: ignore[call-arg]
+id=str(llm_uuid2), # type: ignore[arg-type]
+parent_run_id=str(chain_uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@@ -574,14 +574,14 @@ def test_tracer_nested_runs_on_error() -> None:
serialized=SERIALIZED,
error=None,
inputs=dict(prompts=[]),
-outputs=LLMResult(generations=[[]], llm_output=None),
+outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{llm_uuid2}",
),
-Run(
-id=str(tool_uuid),
-parent_run_id=str(chain_uuid),
+Run( # type: ignore[call-arg]
+id=str(tool_uuid), # type: ignore[arg-type]
+parent_run_id=str(chain_uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@@ -599,9 +599,9 @@ def test_tracer_nested_runs_on_error() -> None:
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}",
child_runs=[
-Run(
-id=str(llm_uuid3),
-parent_run_id=str(tool_uuid),
+Run( # type: ignore[call-arg]
+id=str(llm_uuid3), # type: ignore[arg-type]
+parent_run_id=str(tool_uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[

View File

@@ -459,8 +459,8 @@ def test_convert_run(
sample_tracer_session_v1: TracerSessionV1,
) -> None:
"""Test converting a run to a V1 run."""
-llm_run = Run(
-id="57a08cc4-73d2-4236-8370-549099d07fad",
+llm_run = Run( # type: ignore[call-arg]
+id="57a08cc4-73d2-4236-8370-549099d07fad", # type: ignore[arg-type]
name="llm_run",
execution_order=1,
child_execution_order=1,
@@ -474,7 +474,7 @@ def test_convert_run(
run_type="llm",
)
chain_run = Run(
id="57a08cc4-73d2-4236-8371-549099d07fad",
id="57a08cc4-73d2-4236-8371-549099d07fad", # type: ignore[arg-type]
name="chain_run",
execution_order=1,
start_time=datetime.now(timezone.utc),
@@ -489,7 +489,7 @@ def test_convert_run(
)
tool_run = Run(
id="57a08cc4-73d2-4236-8372-549099d07fad",
id="57a08cc4-73d2-4236-8372-549099d07fad", # type: ignore[arg-type]
name="tool_run",
execution_order=1,
child_execution_order=1,
@@ -503,7 +503,7 @@ def test_convert_run(
run_type="tool",
)
-expected_llm_run = LLMRun(
+expected_llm_run = LLMRun( # type: ignore[call-arg]
uuid="57a08cc4-73d2-4236-8370-549099d07fad",
name="llm_run",
execution_order=1,
@@ -517,7 +517,7 @@ def test_convert_run(
extra={},
)
-expected_chain_run = ChainRun(
+expected_chain_run = ChainRun( # type: ignore[call-arg]
uuid="57a08cc4-73d2-4236-8371-549099d07fad",
name="chain_run",
execution_order=1,
@@ -533,7 +533,7 @@ def test_convert_run(
child_tool_runs=[],
extra={},
)
-expected_tool_run = ToolRun(
+expected_tool_run = ToolRun( # type: ignore[call-arg]
uuid="57a08cc4-73d2-4236-8372-549099d07fad",
name="tool_run",
execution_order=1,