core: upgrade to a recent mypy (#18753)

Tested that this works per package on CI.
Eugene Yurtsev 2024-03-07 15:25:19 -05:00 committed by GitHub
parent e188d4ecb0
commit 8c71f92cb2
23 changed files with 172 additions and 155 deletions
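
Most of the changes in this commit add error-code-scoped "# type: ignore[...]" comments: the suppression applies only to the named error code, so any other error mypy finds on the same line still surfaces, unlike a bare "# type: ignore". Below is a minimal, hedged sketch of that pattern; the names and types are illustrative only and are not taken from this diff.

from typing import Callable, Iterator

def count_up(n: int) -> Iterator[int]:
    yield from range(n)

# Fine at runtime, but mypy reports error code "assignment" because the
# declared element type (str) does not match the function's (int).
# Scoping the ignore to that code keeps any unrelated error on this
# line visible.
handler: Callable[[int], Iterator[str]] = count_up  # type: ignore[assignment]

print(list(handler(3)))  # prints [0, 1, 2]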


@ -241,7 +241,7 @@ class ContextSet(RunnableSerializable):
):
if key is not None:
kwargs[key] = value
super().__init__(
super().__init__( # type: ignore[call-arg]
keys={
k: _coerce_set_value(v) if v is not None else None
for k, v in kwargs.items()


@ -407,12 +407,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_managers[i].on_llm_error(e, response=LLMResult(generations=[]))
raise e
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[list-item]
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
output = LLMResult(generations=generations, llm_output=llm_output) # type: ignore[arg-type]
if run_managers:
run_infos = []
for manager, flattened_output in zip(run_managers, flattened_outputs):
@ -504,7 +504,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
*[
run_manager.on_llm_end(
LLMResult(
generations=[res.generations], llm_output=res.llm_output
generations=[res.generations], # type: ignore[list-item, union-attr]
llm_output=res.llm_output, # type: ignore[list-item, union-attr]
)
)
for run_manager, res in zip(run_managers, results)
@ -513,12 +514,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
raise exceptions[0]
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[list-item, union-attr]
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
llm_output = self._combine_llm_outputs([res.llm_output for res in results]) # type: ignore[union-attr]
generations = [res.generations for res in results] # type: ignore[union-attr]
output = LLMResult(generations=generations, llm_output=llm_output) # type: ignore[arg-type]
await asyncio.gather(
*[
run_manager.on_llm_end(flattened_output)


@ -932,9 +932,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
)
]
)
run_managers = [r[0] for r in run_managers]
run_managers = [r[0] for r in run_managers] # type: ignore[misc]
output = await self._agenerate_helper(
prompts, stop, run_managers, bool(new_arg_supported), **kwargs
prompts,
stop,
run_managers, # type: ignore[arg-type]
bool(new_arg_supported),
**kwargs, # type: ignore[arg-type]
)
return output
if len(missing_prompts) > 0:
@ -951,15 +955,19 @@ class BaseLLM(BaseLanguageModel[str], ABC):
for idx in missing_prompt_idxs
]
)
run_managers = [r[0] for r in run_managers]
run_managers = [r[0] for r in run_managers] # type: ignore[misc]
new_results = await self._agenerate_helper(
missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
missing_prompts,
stop,
run_managers, # type: ignore[arg-type]
bool(new_arg_supported),
**kwargs, # type: ignore[arg-type]
)
llm_output = await aupdate_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
run_info = (
[RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
[RunInfo(run_id=run_manager.run_id) for run_manager in run_managers] # type: ignore[attr-defined]
if run_managers
else None
)


@ -51,7 +51,7 @@ class BaseMessage(Serializable):
def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain_core.prompts.chat import ChatPromptTemplate
prompt = ChatPromptTemplate(messages=[self])
prompt = ChatPromptTemplate(messages=[self]) # type: ignore[call-arg]
return prompt + other
def pretty_repr(self, html: bool = False) -> str:
@ -162,7 +162,7 @@ class BaseMessageChunk(BaseMessage):
# If both are (subclasses of) BaseMessageChunk,
# concat into a single BaseMessageChunk
return self.__class__(
return self.__class__( # type: ignore[call-arg]
id=self.id,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(


@ -91,7 +91,7 @@ class BaseMessagePromptTemplate(Serializable, ABC):
Returns:
Combined prompt template.
"""
prompt = ChatPromptTemplate(messages=[self])
prompt = ChatPromptTemplate(messages=[self]) # type: ignore[call-arg]
return prompt + other
@ -603,17 +603,17 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
"""
# Allow for easy combining
if isinstance(other, ChatPromptTemplate):
return ChatPromptTemplate(messages=self.messages + other.messages)
return ChatPromptTemplate(messages=self.messages + other.messages) # type: ignore[call-arg]
elif isinstance(
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
):
return ChatPromptTemplate(messages=self.messages + [other])
return ChatPromptTemplate(messages=self.messages + [other]) # type: ignore[call-arg]
elif isinstance(other, (list, tuple)):
_other = ChatPromptTemplate.from_messages(other)
return ChatPromptTemplate(messages=self.messages + _other.messages)
return ChatPromptTemplate(messages=self.messages + _other.messages) # type: ignore[call-arg]
elif isinstance(other, str):
prompt = HumanMessagePromptTemplate.from_template(other)
return ChatPromptTemplate(messages=self.messages + [prompt])
return ChatPromptTemplate(messages=self.messages + [prompt]) # type: ignore[call-arg]
else:
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
@ -684,7 +684,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
Returns:
a chat prompt template
"""
return cls(
return cls( # type: ignore[call-arg]
messages=[
ChatMessagePromptTemplate.from_template(template, role=role)
for role, template in string_messages


@ -255,7 +255,7 @@ class PromptTemplate(StringPromptTemplate):
return cls(
input_variables=input_variables,
template=template,
template_format=template_format,
template_format=template_format, # type: ignore[arg-type]
partial_variables=_partial_variables,
**kwargs,
)


@ -1260,7 +1260,7 @@ class Runnable(Generic[Input, Output], ABC):
output = cast(
Output,
context.run(
call_func_with_variable_args,
call_func_with_variable_args, # type: ignore[arg-type]
func, # type: ignore[arg-type]
input, # type: ignore[arg-type]
config,
@ -1888,7 +1888,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
raise ValueError(
f"RunnableSequence must have at least 2 steps, got {len(steps_flat)}"
)
super().__init__(
super().__init__( # type: ignore[call-arg]
first=steps_flat[0],
middle=list(steps_flat[1:-1]),
last=steps_flat[-1],
@ -2574,7 +2574,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
) -> None:
merged = {**__steps} if __steps is not None else {}
merged.update(kwargs)
super().__init__(
super().__init__( # type: ignore[call-arg]
steps={key: coerce_to_runnable(r) for key, r in merged.items()}
)
@ -3001,7 +3001,7 @@ class RunnableGenerator(Runnable[Input, Output]):
func_for_name: Callable = atransform
if inspect.isasyncgenfunction(transform):
self._atransform = transform
self._atransform = transform # type: ignore[assignment]
func_for_name = transform
elif inspect.isgeneratorfunction(transform):
self._transform = transform
@ -3066,7 +3066,10 @@ class RunnableGenerator(Runnable[Input, Output]):
if not hasattr(self, "_transform"):
raise NotImplementedError(f"{repr(self)} only supports async methods.")
return self._transform_stream_with_config(
input, self._transform, config, **kwargs
input,
self._transform, # type: ignore[arg-type]
config,
**kwargs, # type: ignore[arg-type]
)
def stream(
@ -3995,7 +3998,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
runnable with a custom type.
**other_kwargs: Unpacked into the base class.
"""
super().__init__(
super().__init__( # type: ignore[call-arg]
bound=bound,
kwargs=kwargs or {},
config=config or {},


@ -119,7 +119,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
runnable = coerce_to_runnable(runnable)
_branches.append((condition, runnable))
super().__init__(branches=_branches, default=default_)
super().__init__(branches=_branches, default=default_) # type: ignore[call-arg]
class Config:
arbitrary_types_allowed = True


@ -165,7 +165,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
afunc = func
func = None
super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs)
super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs) # type: ignore[call-arg]
@classmethod
def is_lc_serializable(cls) -> bool:
@ -320,7 +320,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
mapper: RunnableParallel[Dict[str, Any]]
def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None:
super().__init__(mapper=mapper, **kwargs)
super().__init__(mapper=mapper, **kwargs) # type: ignore[call-arg]
@classmethod
def is_lc_serializable(cls) -> bool:
@ -582,7 +582,7 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
keys: Union[str, List[str]]
def __init__(self, keys: Union[str, List[str]], **kwargs: Any) -> None:
super().__init__(keys=keys, **kwargs)
super().__init__(keys=keys, **kwargs) # type: ignore[call-arg]
@classmethod
def is_lc_serializable(cls) -> bool:


@ -63,7 +63,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
self,
runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]],
) -> None:
super().__init__(
super().__init__( # type: ignore[call-arg]
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
)


@ -119,7 +119,7 @@ class BaseStore(Generic[K, V], ABC):
item = await run_in_executor(None, lambda it: next(it, done), iterator)
if item is done:
break
yield item
yield item # type: ignore[misc]
ByteStore = BaseStore[str, bytes]


@ -619,7 +619,7 @@ class Tool(BaseTool):
self, name: str, func: Optional[Callable], description: str, **kwargs: Any
) -> None:
"""Initialize tool."""
super(Tool, self).__init__(
super(Tool, self).__init__( # type: ignore[call-arg]
name=name, func=func, description=description, **kwargs
)
@ -795,7 +795,7 @@ class StructuredTool(BaseTool):
name=name,
func=func,
coroutine=coroutine,
args_schema=_args_schema,
args_schema=_args_schema, # type: ignore[arg-type]
description=description,
return_direct=return_direct,
**kwargs,


@ -222,7 +222,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
# Changing this to "chat_model" may break triggering on_llm_start
run_type="chat_model",
tags=tags,
name=name,
name=name, # type: ignore[arg-type]
)
self._start_trace(chat_model_run)
self._on_chat_model_start(chat_model_run)
@ -259,7 +259,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_execution_order=execution_order,
run_type="llm",
tags=tags or [],
name=name,
name=name, # type: ignore[arg-type]
)
self._start_trace(llm_run)
self._on_llm_start(llm_run)
@ -390,7 +390,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_execution_order=execution_order,
child_runs=[],
run_type=run_type or "chain",
name=name,
name=name, # type: ignore[arg-type]
tags=tags or [],
)
self._start_trace(chain_run)
@ -498,7 +498,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_runs=[],
run_type="tool",
tags=tags or [],
name=name,
name=name, # type: ignore[arg-type]
)
self._start_trace(tool_run)
self._on_tool_start(tool_run)


@ -122,7 +122,7 @@ class LangChainTracer(BaseTracer):
child_execution_order=execution_order,
run_type="llm",
tags=tags,
name=name,
name=name, # type: ignore[arg-type]
)
self._start_trace(chat_model_run)
self._on_chat_model_start(chat_model_run)


@ -62,29 +62,29 @@ class LangChainTracerV1(BaseTracer):
else:
raise ValueError("No prompts found in LLM run inputs")
return LLMRun(
uuid=str(run.id) if run.id else None,
uuid=str(run.id) if run.id else None, # type: ignore[arg-type]
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
end_time=run.end_time, # type: ignore[arg-type]
extra=run.extra,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
serialized=run.serialized, # type: ignore[arg-type]
session_id=session.id,
error=run.error,
prompts=prompts,
response=run.outputs if run.outputs else None,
response=run.outputs if run.outputs else None, # type: ignore[arg-type]
)
if run.run_type == "chain":
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
return ChainRun(
uuid=str(run.id) if run.id else None,
uuid=str(run.id) if run.id else None, # type: ignore[arg-type]
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
end_time=run.end_time, # type: ignore[arg-type]
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
serialized=run.serialized, # type: ignore[arg-type]
session_id=session.id,
inputs=run.inputs,
outputs=run.outputs,
@ -99,13 +99,13 @@ class LangChainTracerV1(BaseTracer):
if run.run_type == "tool":
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
return ToolRun(
uuid=str(run.id) if run.id else None,
uuid=str(run.id) if run.id else None, # type: ignore[arg-type]
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
end_time=run.end_time, # type: ignore[arg-type]
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
serialized=run.serialized, # type: ignore[arg-type]
session_id=session.id,
action=str(run.serialized),
tool_input=run.inputs.get("input", ""),

libs/core/poetry.lock (generated)

@ -1329,52 +1329,49 @@ files = [
[[package]]
name = "mypy"
version = "0.991"
version = "1.8.0"
description = "Optional static typing for Python"
optional = false
python-versions = ">=3.7"
python-versions = ">=3.8"
files = [
{file = "mypy-0.991-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7d17e0a9707d0772f4a7b878f04b4fd11f6f5bcb9b3813975a9b13c9332153ab"},
{file = "mypy-0.991-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0714258640194d75677e86c786e80ccf294972cc76885d3ebbb560f11db0003d"},
{file = "mypy-0.991-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0c8f3be99e8a8bd403caa8c03be619544bc2c77a7093685dcf308c6b109426c6"},
{file = "mypy-0.991-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9ec663ed6c8f15f4ae9d3c04c989b744436c16d26580eaa760ae9dd5d662eb"},
{file = "mypy-0.991-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4307270436fd7694b41f913eb09210faff27ea4979ecbcd849e57d2da2f65305"},
{file = "mypy-0.991-cp310-cp310-win_amd64.whl", hash = "sha256:901c2c269c616e6cb0998b33d4adbb4a6af0ac4ce5cd078afd7bc95830e62c1c"},
{file = "mypy-0.991-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d13674f3fb73805ba0c45eb6c0c3053d218aa1f7abead6e446d474529aafc372"},
{file = "mypy-0.991-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1c8cd4fb70e8584ca1ed5805cbc7c017a3d1a29fb450621089ffed3e99d1857f"},
{file = "mypy-0.991-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:209ee89fbb0deed518605edddd234af80506aec932ad28d73c08f1400ef80a33"},
{file = "mypy-0.991-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37bd02ebf9d10e05b00d71302d2c2e6ca333e6c2a8584a98c00e038db8121f05"},
{file = "mypy-0.991-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:26efb2fcc6b67e4d5a55561f39176821d2adf88f2745ddc72751b7890f3194ad"},
{file = "mypy-0.991-cp311-cp311-win_amd64.whl", hash = "sha256:3a700330b567114b673cf8ee7388e949f843b356a73b5ab22dd7cff4742a5297"},
{file = "mypy-0.991-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1f7d1a520373e2272b10796c3ff721ea1a0712288cafaa95931e66aa15798813"},
{file = "mypy-0.991-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:641411733b127c3e0dab94c45af15fea99e4468f99ac88b39efb1ad677da5711"},
{file = "mypy-0.991-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3d80e36b7d7a9259b740be6d8d906221789b0d836201af4234093cae89ced0cd"},
{file = "mypy-0.991-cp37-cp37m-win_amd64.whl", hash = "sha256:e62ebaad93be3ad1a828a11e90f0e76f15449371ffeecca4a0a0b9adc99abcef"},
{file = "mypy-0.991-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:b86ce2c1866a748c0f6faca5232059f881cda6dda2a893b9a8373353cfe3715a"},
{file = "mypy-0.991-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ac6e503823143464538efda0e8e356d871557ef60ccd38f8824a4257acc18d93"},
{file = "mypy-0.991-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cca5adf694af539aeaa6ac633a7afe9bbd760df9d31be55ab780b77ab5ae8bf"},
{file = "mypy-0.991-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a12c56bf73cdab116df96e4ff39610b92a348cc99a1307e1da3c3768bbb5b135"},
{file = "mypy-0.991-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:652b651d42f155033a1967739788c436491b577b6a44e4c39fb340d0ee7f0d70"},
{file = "mypy-0.991-cp38-cp38-win_amd64.whl", hash = "sha256:4175593dc25d9da12f7de8de873a33f9b2b8bdb4e827a7cae952e5b1a342e243"},
{file = "mypy-0.991-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:98e781cd35c0acf33eb0295e8b9c55cdbef64fcb35f6d3aa2186f289bed6e80d"},
{file = "mypy-0.991-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6d7464bac72a85cb3491c7e92b5b62f3dcccb8af26826257760a552a5e244aa5"},
{file = "mypy-0.991-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c9166b3f81a10cdf9b49f2d594b21b31adadb3d5e9db9b834866c3258b695be3"},
{file = "mypy-0.991-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8472f736a5bfb159a5e36740847808f6f5b659960115ff29c7cecec1741c648"},
{file = "mypy-0.991-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5e80e758243b97b618cdf22004beb09e8a2de1af481382e4d84bc52152d1c476"},
{file = "mypy-0.991-cp39-cp39-win_amd64.whl", hash = "sha256:74e259b5c19f70d35fcc1ad3d56499065c601dfe94ff67ae48b85596b9ec1461"},
{file = "mypy-0.991-py3-none-any.whl", hash = "sha256:de32edc9b0a7e67c2775e574cb061a537660e51210fbf6006b0b36ea695ae9bb"},
{file = "mypy-0.991.tar.gz", hash = "sha256:3c0165ba8f354a6d9881809ef29f1a9318a236a6d81c690094c5df32107bde06"},
{file = "mypy-1.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485a8942f671120f76afffff70f259e1cd0f0cfe08f81c05d8816d958d4577d3"},
{file = "mypy-1.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:df9824ac11deaf007443e7ed2a4a26bebff98d2bc43c6da21b2b64185da011c4"},
{file = "mypy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2afecd6354bbfb6e0160f4e4ad9ba6e4e003b767dd80d85516e71f2e955ab50d"},
{file = "mypy-1.8.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8963b83d53ee733a6e4196954502b33567ad07dfd74851f32be18eb932fb1cb9"},
{file = "mypy-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:e46f44b54ebddbeedbd3d5b289a893219065ef805d95094d16a0af6630f5d410"},
{file = "mypy-1.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:855fe27b80375e5c5878492f0729540db47b186509c98dae341254c8f45f42ae"},
{file = "mypy-1.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4c886c6cce2d070bd7df4ec4a05a13ee20c0aa60cb587e8d1265b6c03cf91da3"},
{file = "mypy-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d19c413b3c07cbecf1f991e2221746b0d2a9410b59cb3f4fb9557f0365a1a817"},
{file = "mypy-1.8.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9261ed810972061388918c83c3f5cd46079d875026ba97380f3e3978a72f503d"},
{file = "mypy-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:51720c776d148bad2372ca21ca29256ed483aa9a4cdefefcef49006dff2a6835"},
{file = "mypy-1.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:52825b01f5c4c1c4eb0db253ec09c7aa17e1a7304d247c48b6f3599ef40db8bd"},
{file = "mypy-1.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f5ac9a4eeb1ec0f1ccdc6f326bcdb464de5f80eb07fb38b5ddd7b0de6bc61e55"},
{file = "mypy-1.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afe3fe972c645b4632c563d3f3eff1cdca2fa058f730df2b93a35e3b0c538218"},
{file = "mypy-1.8.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:42c6680d256ab35637ef88891c6bd02514ccb7e1122133ac96055ff458f93fc3"},
{file = "mypy-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:720a5ca70e136b675af3af63db533c1c8c9181314d207568bbe79051f122669e"},
{file = "mypy-1.8.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:028cf9f2cae89e202d7b6593cd98db6759379f17a319b5faf4f9978d7084cdc6"},
{file = "mypy-1.8.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4e6d97288757e1ddba10dd9549ac27982e3e74a49d8d0179fc14d4365c7add66"},
{file = "mypy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f1478736fcebb90f97e40aff11a5f253af890c845ee0c850fe80aa060a267c6"},
{file = "mypy-1.8.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42419861b43e6962a649068a61f4a4839205a3ef525b858377a960b9e2de6e0d"},
{file = "mypy-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:2b5b6c721bd4aabaadead3a5e6fa85c11c6c795e0c81a7215776ef8afc66de02"},
{file = "mypy-1.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5c1538c38584029352878a0466f03a8ee7547d7bd9f641f57a0f3017a7c905b8"},
{file = "mypy-1.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ef4be7baf08a203170f29e89d79064463b7fc7a0908b9d0d5114e8009c3a259"},
{file = "mypy-1.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7178def594014aa6c35a8ff411cf37d682f428b3b5617ca79029d8ae72f5402b"},
{file = "mypy-1.8.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ab3c84fa13c04aeeeabb2a7f67a25ef5d77ac9d6486ff33ded762ef353aa5592"},
{file = "mypy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:99b00bc72855812a60d253420d8a2eae839b0afa4938f09f4d2aa9bb4654263a"},
{file = "mypy-1.8.0-py3-none-any.whl", hash = "sha256:538fd81bb5e430cc1381a443971c0475582ff9f434c16cd46d2c66763ce85d9d"},
{file = "mypy-1.8.0.tar.gz", hash = "sha256:6ff8b244d7085a0b425b56d327b480c3b29cafbd2eff27316a004f9a7391ae07"},
]
[package.dependencies]
mypy-extensions = ">=0.4.3"
mypy-extensions = ">=1.0.0"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
typing-extensions = ">=3.10"
typing-extensions = ">=4.1.0"
[package.extras]
dmypy = ["psutil (>=4.0)"]
install-types = ["pip"]
python2 = ["typed-ast (>=1.4.0,<2)"]
mypyc = ["setuptools (>=50)"]
reports = ["lxml"]
[[package]]
@ -2936,4 +2933,4 @@ extended-testing = ["jinja2"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "092a56ee5733650e75cdacb0480d6a7fea1ff40a4a7f33500f77990a6e590ea4"
content-hash = "9d6e9c9613b31dbbe35772bf8d8d5aaba637228de7abbf4a7b271971c2a81ba9"


@ -30,7 +30,7 @@ ruff = "^0.1.5"
optional = true
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
mypy = "^1"
types-pyyaml = "^6.0.12.2"
types-requests = "^2.28.11.5"
types-jinja2 = "^2.11.9"


@ -67,7 +67,7 @@ def create_chat_prompt_template() -> ChatPromptTemplate:
"""Create a chat prompt template."""
return ChatPromptTemplate(
input_variables=["foo", "bar", "context"],
messages=create_messages(),
messages=create_messages(), # type: ignore[arg-type]
)
@ -191,10 +191,12 @@ def test_chat_invalid_input_variables_extra() -> None:
messages = [HumanMessage(content="foo")]
with pytest.raises(ValueError):
ChatPromptTemplate(
messages=messages, input_variables=["foo"], validate_template=True
messages=messages, # type: ignore[arg-type]
input_variables=["foo"],
validate_template=True, # type: ignore[arg-type]
)
assert (
ChatPromptTemplate(messages=messages, input_variables=["foo"]).input_variables
ChatPromptTemplate(messages=messages, input_variables=["foo"]).input_variables # type: ignore[arg-type]
== []
)
@ -203,16 +205,19 @@ def test_chat_invalid_input_variables_missing() -> None:
messages = [HumanMessagePromptTemplate.from_template("{foo}")]
with pytest.raises(ValueError):
ChatPromptTemplate(
messages=messages, input_variables=[], validate_template=True
messages=messages, # type: ignore[arg-type]
input_variables=[],
validate_template=True, # type: ignore[arg-type]
)
assert ChatPromptTemplate(
messages=messages, input_variables=[]
messages=messages, # type: ignore[arg-type]
input_variables=[], # type: ignore[arg-type]
).input_variables == ["foo"]
def test_infer_variables() -> None:
messages = [HumanMessagePromptTemplate.from_template("{foo}")]
prompt = ChatPromptTemplate(messages=messages)
prompt = ChatPromptTemplate(messages=messages) # type: ignore[arg-type, call-arg]
assert prompt.input_variables == ["foo"]
@ -223,7 +228,7 @@ def test_chat_valid_with_partial_variables() -> None:
)
]
prompt = ChatPromptTemplate(
messages=messages,
messages=messages, # type: ignore[arg-type]
input_variables=["question", "context"],
partial_variables={"formatins": "some structure"},
)
@ -237,8 +242,9 @@ def test_chat_valid_infer_variables() -> None:
"Do something with {question} using {context} giving it like {formatins}"
)
]
prompt = ChatPromptTemplate(
messages=messages, partial_variables={"formatins": "some structure"}
prompt = ChatPromptTemplate( # type: ignore[call-arg]
messages=messages, # type: ignore[arg-type]
partial_variables={"formatins": "some structure"}, # type: ignore[arg-type]
)
assert set(prompt.input_variables) == {"question", "context"}
assert prompt.partial_variables == {"formatins": "some structure"}


@ -6,7 +6,7 @@ from langchain_core.prompts.prompt import PromptTemplate
def test_get_input_variables() -> None:
prompt_a = PromptTemplate.from_template("{foo}")
prompt_b = PromptTemplate.from_template("{bar}")
pipeline_prompt = PipelinePromptTemplate(
pipeline_prompt = PipelinePromptTemplate( # type: ignore[call-arg]
final_prompt=prompt_b, pipeline_prompts=[("bar", prompt_a)]
)
assert pipeline_prompt.input_variables == ["foo"]
@ -15,7 +15,7 @@ def test_get_input_variables() -> None:
def test_simple_pipeline() -> None:
prompt_a = PromptTemplate.from_template("{foo}")
prompt_b = PromptTemplate.from_template("{bar}")
pipeline_prompt = PipelinePromptTemplate(
pipeline_prompt = PipelinePromptTemplate( # type: ignore[call-arg]
final_prompt=prompt_b, pipeline_prompts=[("bar", prompt_a)]
)
output = pipeline_prompt.format(foo="jim")
@ -25,7 +25,7 @@ def test_simple_pipeline() -> None:
def test_multi_variable_pipeline() -> None:
prompt_a = PromptTemplate.from_template("{foo}")
prompt_b = PromptTemplate.from_template("okay {bar} {baz}")
pipeline_prompt = PipelinePromptTemplate(
pipeline_prompt = PipelinePromptTemplate( # type: ignore[call-arg]
final_prompt=prompt_b, pipeline_prompts=[("bar", prompt_a)]
)
output = pipeline_prompt.format(foo="jim", baz="deep")
@ -37,7 +37,7 @@ def test_partial_with_chat_prompts() -> None:
input_variables=["foo"], messages=[MessagesPlaceholder(variable_name="foo")]
)
prompt_b = ChatPromptTemplate.from_template("jim {bar}")
pipeline_prompt = PipelinePromptTemplate(
pipeline_prompt = PipelinePromptTemplate( # type: ignore[call-arg]
final_prompt=prompt_a, pipeline_prompts=[("foo", prompt_b)]
)
assert pipeline_prompt.input_variables == ["bar"]


@ -126,7 +126,9 @@ def test_prompt_invalid_template_format() -> None:
input_variables = ["foo"]
with pytest.raises(ValueError):
PromptTemplate(
input_variables=input_variables, template=template, template_format="bar"
input_variables=input_variables,
template=template,
template_format="bar", # type: ignore[arg-type]
)


@ -758,7 +758,7 @@ def test_validation_error_handling_non_validation_error(
async def _arun(self) -> str:
return "dummy"
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
with pytest.raises(NotImplementedError):
_tool.run({})
@ -820,7 +820,7 @@ async def test_async_validation_error_handling_non_validation_error(
async def _arun(self) -> str:
return "dummy"
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
with pytest.raises(NotImplementedError):
await _tool.arun({})


@ -53,7 +53,7 @@ def _compare_run_with_error(run: Any, expected_run: Any) -> None:
def test_tracer_llm_run() -> None:
"""Test tracer on an LLM run."""
uuid = uuid4()
compare_run = Run(
compare_run = Run( # type: ignore[call-arg]
id=uuid,
parent_run_id=None,
start_time=datetime.now(timezone.utc),
@ -67,7 +67,7 @@ def test_tracer_llm_run() -> None:
child_execution_order=1,
serialized=SERIALIZED,
inputs={"prompts": []},
outputs=LLMResult(generations=[[]]),
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
error=None,
run_type="llm",
trace_id=uuid,
@ -89,7 +89,7 @@ def test_tracer_chat_model_run() -> None:
serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
)
compare_run = Run(
id=str(run_managers[0].run_id),
id=str(run_managers[0].run_id), # type: ignore[arg-type]
name="chat_model",
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
@ -102,7 +102,7 @@ def test_tracer_chat_model_run() -> None:
child_execution_order=1,
serialized=SERIALIZED_CHAT,
inputs=dict(prompts=["Human: "]),
outputs=LLMResult(generations=[[]]),
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
error=None,
run_type="llm",
trace_id=run_managers[0].run_id,
@ -140,7 +140,7 @@ def test_tracer_multiple_llm_runs() -> None:
child_execution_order=1,
serialized=SERIALIZED,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]]),
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
error=None,
run_type="llm",
trace_id=uuid,
@ -160,8 +160,8 @@ def test_tracer_multiple_llm_runs() -> None:
def test_tracer_chain_run() -> None:
"""Test tracer on a Chain run."""
uuid = uuid4()
compare_run = Run(
id=str(uuid),
compare_run = Run( # type: ignore[call-arg]
id=str(uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@ -190,8 +190,8 @@ def test_tracer_chain_run() -> None:
def test_tracer_tool_run() -> None:
"""Test tracer on a Tool run."""
uuid = uuid4()
compare_run = Run(
id=str(uuid),
compare_run = Run( # type: ignore[call-arg]
id=str(uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@ -251,8 +251,8 @@ def test_tracer_nested_run() -> None:
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
tracer.on_chain_end(outputs={}, run_id=chain_uuid)
compare_run = Run(
id=str(chain_uuid),
compare_run = Run( # type: ignore[call-arg]
id=str(chain_uuid), # type: ignore[arg-type]
error=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
@ -270,7 +270,7 @@ def test_tracer_nested_run() -> None:
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}",
child_runs=[
Run(
Run( # type: ignore[call-arg]
id=tool_uuid,
parent_run_id=chain_uuid,
start_time=datetime.now(timezone.utc),
@ -290,9 +290,9 @@ def test_tracer_nested_run() -> None:
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}",
child_runs=[
Run(
id=str(llm_uuid1),
parent_run_id=str(tool_uuid),
Run( # type: ignore[call-arg]
id=str(llm_uuid1), # type: ignore[arg-type]
parent_run_id=str(tool_uuid), # type: ignore[arg-type]
error=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
@ -305,16 +305,16 @@ def test_tracer_nested_run() -> None:
child_execution_order=3,
serialized=SERIALIZED,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]]),
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}.20230101T000000000000Z{llm_uuid1}",
)
],
),
Run(
id=str(llm_uuid2),
parent_run_id=str(chain_uuid),
Run( # type: ignore[call-arg]
id=str(llm_uuid2), # type: ignore[arg-type]
parent_run_id=str(chain_uuid), # type: ignore[arg-type]
error=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
@ -327,7 +327,7 @@ def test_tracer_nested_run() -> None:
child_execution_order=4,
serialized=SERIALIZED,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]]),
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{llm_uuid2}",
@ -344,8 +344,8 @@ def test_tracer_llm_run_on_error() -> None:
exception = Exception("test")
uuid = uuid4()
compare_run = Run(
id=str(uuid),
compare_run = Run( # type: ignore[call-arg]
id=str(uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@ -377,8 +377,8 @@ def test_tracer_llm_run_on_error_callback() -> None:
exception = Exception("test")
uuid = uuid4()
compare_run = Run(
id=str(uuid),
compare_run = Run( # type: ignore[call-arg]
id=str(uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@ -415,8 +415,8 @@ def test_tracer_chain_run_on_error() -> None:
exception = Exception("test")
uuid = uuid4()
compare_run = Run(
id=str(uuid),
compare_run = Run( # type: ignore[call-arg]
id=str(uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@ -447,8 +447,8 @@ def test_tracer_tool_run_on_error() -> None:
exception = Exception("test")
uuid = uuid4()
compare_run = Run(
id=str(uuid),
compare_run = Run( # type: ignore[call-arg]
id=str(uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@ -520,8 +520,8 @@ def test_tracer_nested_runs_on_error() -> None:
tracer.on_tool_error(exception, run_id=tool_uuid)
tracer.on_chain_error(exception, run_id=chain_uuid)
compare_run = Run(
id=str(chain_uuid),
compare_run = Run( # type: ignore[call-arg]
id=str(chain_uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@ -539,9 +539,9 @@ def test_tracer_nested_runs_on_error() -> None:
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}",
child_runs=[
Run(
id=str(llm_uuid1),
parent_run_id=str(chain_uuid),
Run( # type: ignore[call-arg]
id=str(llm_uuid1), # type: ignore[arg-type]
parent_run_id=str(chain_uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@ -554,14 +554,14 @@ def test_tracer_nested_runs_on_error() -> None:
serialized=SERIALIZED,
error=None,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]], llm_output=None),
outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{llm_uuid1}",
),
Run(
id=str(llm_uuid2),
parent_run_id=str(chain_uuid),
Run( # type: ignore[call-arg]
id=str(llm_uuid2), # type: ignore[arg-type]
parent_run_id=str(chain_uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@ -574,14 +574,14 @@ def test_tracer_nested_runs_on_error() -> None:
serialized=SERIALIZED,
error=None,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]], llm_output=None),
outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{llm_uuid2}",
),
Run(
id=str(tool_uuid),
parent_run_id=str(chain_uuid),
Run( # type: ignore[call-arg]
id=str(tool_uuid), # type: ignore[arg-type]
parent_run_id=str(chain_uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[
@ -599,9 +599,9 @@ def test_tracer_nested_runs_on_error() -> None:
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}",
child_runs=[
Run(
id=str(llm_uuid3),
parent_run_id=str(tool_uuid),
Run( # type: ignore[call-arg]
id=str(llm_uuid3), # type: ignore[arg-type]
parent_run_id=str(tool_uuid), # type: ignore[arg-type]
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
events=[


@ -459,8 +459,8 @@ def test_convert_run(
sample_tracer_session_v1: TracerSessionV1,
) -> None:
"""Test converting a run to a V1 run."""
llm_run = Run(
id="57a08cc4-73d2-4236-8370-549099d07fad",
llm_run = Run( # type: ignore[call-arg]
id="57a08cc4-73d2-4236-8370-549099d07fad", # type: ignore[arg-type]
name="llm_run",
execution_order=1,
child_execution_order=1,
@ -474,7 +474,7 @@ def test_convert_run(
run_type="llm",
)
chain_run = Run(
id="57a08cc4-73d2-4236-8371-549099d07fad",
id="57a08cc4-73d2-4236-8371-549099d07fad", # type: ignore[arg-type]
name="chain_run",
execution_order=1,
start_time=datetime.now(timezone.utc),
@ -489,7 +489,7 @@ def test_convert_run(
)
tool_run = Run(
id="57a08cc4-73d2-4236-8372-549099d07fad",
id="57a08cc4-73d2-4236-8372-549099d07fad", # type: ignore[arg-type]
name="tool_run",
execution_order=1,
child_execution_order=1,
@ -503,7 +503,7 @@ def test_convert_run(
run_type="tool",
)
expected_llm_run = LLMRun(
expected_llm_run = LLMRun( # type: ignore[call-arg]
uuid="57a08cc4-73d2-4236-8370-549099d07fad",
name="llm_run",
execution_order=1,
@ -517,7 +517,7 @@ def test_convert_run(
extra={},
)
expected_chain_run = ChainRun(
expected_chain_run = ChainRun( # type: ignore[call-arg]
uuid="57a08cc4-73d2-4236-8371-549099d07fad",
name="chain_run",
execution_order=1,
@ -533,7 +533,7 @@ def test_convert_run(
child_tool_runs=[],
extra={},
)
expected_tool_run = ToolRun(
expected_tool_run = ToolRun( # type: ignore[call-arg]
uuid="57a08cc4-73d2-4236-8372-549099d07fad",
name="tool_run",
execution_order=1,