mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-30 14:15:49 +00:00
fix(core): disallow_any_generics (#38156)
Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
committed by
GitHub
parent
b247e572b1
commit
afff89a9f7
@@ -101,7 +101,7 @@ AnyMessage = Annotated[
|
||||
"""A type representing any defined `Message` or `MessageChunk` type."""
|
||||
|
||||
|
||||
def _has_base64_data(block: dict) -> bool:
|
||||
def _has_base64_data(block: dict[str, Any]) -> bool:
|
||||
"""Check if a content block contains base64 encoded data.
|
||||
|
||||
Args:
|
||||
@@ -139,7 +139,7 @@ def _truncate(text: str, max_len: int = _XML_CONTENT_BLOCK_MAX_LEN) -> str:
|
||||
return text[:max_len] + "..."
|
||||
|
||||
|
||||
def _format_content_block_xml(block: dict) -> str | None:
|
||||
def _format_content_block_xml(block: dict[str, Any]) -> str | None:
|
||||
"""Format a content block as XML.
|
||||
|
||||
Args:
|
||||
@@ -581,14 +581,18 @@ def message_chunk_to_message(chunk: BaseMessage) -> BaseMessage:
|
||||
|
||||
|
||||
MessageLikeRepresentation = (
|
||||
BaseMessage | list[str] | tuple[str, str] | str | dict[str, Any]
|
||||
BaseMessage
|
||||
| list[str]
|
||||
| tuple[str, str | list[str | dict[str, Any]]]
|
||||
| str
|
||||
| dict[str, Any]
|
||||
)
|
||||
"""A type representing the various ways a message can be represented."""
|
||||
|
||||
|
||||
def _create_message_from_message_type(
|
||||
message_type: str,
|
||||
content: str,
|
||||
content: str | list[str | dict[str, Any]],
|
||||
name: str | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
@@ -1534,7 +1538,7 @@ def convert_to_openai_messages(
|
||||
|
||||
@overload
|
||||
def convert_to_openai_messages(
|
||||
messages: _MultipleMessages,
|
||||
messages: _MultipleMessages[Any],
|
||||
*,
|
||||
text_format: Literal["string", "block"] = "string",
|
||||
include_id: bool = False,
|
||||
@@ -1639,12 +1643,13 @@ def convert_to_openai_messages(
|
||||
|
||||
oai_messages: list[dict[str, Any]] = []
|
||||
|
||||
messages_: Sequence[MessageLikeRepresentation]
|
||||
if is_single := isinstance(messages, (BaseMessage, dict, str)):
|
||||
messages = [messages]
|
||||
messages_ = [messages]
|
||||
else:
|
||||
messages_ = cast("Sequence[MessageLikeRepresentation]", messages)
|
||||
|
||||
messages = convert_to_messages(messages)
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
for i, message in enumerate(convert_to_messages(messages_)):
|
||||
oai_msg: dict[str, Any] = {"role": _get_message_openai_role(message)}
|
||||
tool_messages: list[dict[str, Any]] = []
|
||||
content: str | list[dict[str, Any]]
|
||||
|
||||
@@ -58,7 +58,10 @@ class BasePromptTemplate(
|
||||
If not provided, all variables are assumed to be strings.
|
||||
"""
|
||||
|
||||
output_parser: BaseOutputParser | None = None
|
||||
# Ideally we would type output_parser as BaseOutputParser[Any]
|
||||
# but that makes Pydantic fail (Pydantic tries to instantiate BaseOutputParser
|
||||
# instead of using the provided output_parser...)
|
||||
output_parser: BaseOutputParser | None = None # type: ignore[type-arg]
|
||||
"""How to parse the output of calling an LLM on this formatted prompt."""
|
||||
|
||||
partial_variables: Mapping[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@@ -389,7 +389,9 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
```
|
||||
"""
|
||||
|
||||
mapper: RunnableParallel
|
||||
# Ideally we would type mapper as RunnableParallel[dict[str, Any]]
|
||||
# but this fails validation for Pydantic <2.10
|
||||
mapper: RunnableParallel # type: ignore[type-arg]
|
||||
|
||||
def __init__(self, mapper: RunnableParallel[dict[str, Any]], **kwargs: Any) -> None:
|
||||
"""Create a `RunnableAssign`.
|
||||
|
||||
@@ -92,9 +92,6 @@ strict = true
|
||||
enable_error_code = "deprecated"
|
||||
warn_unreachable = true
|
||||
|
||||
# TODO: activate for 'strict' checking
|
||||
disallow_any_generics = false
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Test functionality related to length based selector."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.example_selectors import (
|
||||
@@ -64,7 +66,7 @@ def test_selector_empty_example(
|
||||
selector: LengthBasedExampleSelector,
|
||||
) -> None:
|
||||
"""Test Empty Example result empty."""
|
||||
empty_list: list[dict] = []
|
||||
empty_list: list[dict[str, Any]] = []
|
||||
empty_selector = LengthBasedExampleSelector(
|
||||
examples=empty_list,
|
||||
example_prompt=selector.example_prompt,
|
||||
|
||||
@@ -542,7 +542,7 @@ def test_content_blocks_v1_list_content_short_circuits() -> None:
|
||||
returns it verbatim (the same object) without routing through the
|
||||
translator. Covers both `AIMessage` and `AIMessageChunk`.
|
||||
"""
|
||||
content: list = [
|
||||
content: list[str | dict[str, Any]] = [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "tool_call", "name": "foo", "args": {"a": 1}, "id": "tc_1"},
|
||||
]
|
||||
|
||||
@@ -765,7 +765,7 @@ class FakeTokenCountingModel(FakeChatModel):
|
||||
|
||||
|
||||
def test_convert_to_messages() -> None:
|
||||
message_like: list = [
|
||||
message_like: list[MessageLikeRepresentation] = [
|
||||
# BaseMessage
|
||||
SystemMessage("1"),
|
||||
SystemMessage("1.1", additional_kwargs={"__openai_role__": "developer"}),
|
||||
|
||||
@@ -460,7 +460,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
"title": "CommaSeparatedListOutputParserOutput",
|
||||
}
|
||||
|
||||
router: Runnable = RouterRunnable({})
|
||||
router = RouterRunnable[Any]({})
|
||||
|
||||
assert _schema(router.input_schema) == {
|
||||
"$ref": "#/definitions/RouterInput",
|
||||
@@ -709,7 +709,7 @@ def test_schema_complex_seq() -> None:
|
||||
|
||||
model = FakeListChatModel(responses=[""])
|
||||
|
||||
chain1: Runnable = RunnableSequence(
|
||||
chain1 = RunnableSequence[dict[str, Any], str](
|
||||
prompt1, model, StrOutputParser(), name="city_chain"
|
||||
)
|
||||
|
||||
@@ -3024,7 +3024,7 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
|
||||
input={"question": lambda x: x["question"]},
|
||||
)
|
||||
|
||||
def router(value: dict[str, Any]) -> Runnable:
|
||||
def router(value: dict[str, Any]) -> Runnable[dict[str, Any], str]:
|
||||
if value["key"] == "math":
|
||||
return itemgetter("input") | math_chain
|
||||
if value["key"] == "english":
|
||||
@@ -3046,7 +3046,7 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
|
||||
assert result2 == ["4", "2"]
|
||||
|
||||
# Test ainvoke
|
||||
async def arouter(params: dict[str, Any]) -> Runnable:
|
||||
async def arouter(params: dict[str, Any]) -> Runnable[dict[str, Any], str]:
|
||||
if params["key"] == "math":
|
||||
return itemgetter("input") | math_chain
|
||||
if params["key"] == "english":
|
||||
@@ -3925,10 +3925,10 @@ def test_each(snapshot: SnapshotAssertion) -> None:
|
||||
|
||||
|
||||
def test_recursive_lambda() -> None:
|
||||
def _simple_recursion(x: int) -> int | Runnable:
|
||||
def _simple_recursion(x: int) -> Runnable[Any, int]:
|
||||
if x < 10:
|
||||
return RunnableLambda(lambda *_: _simple_recursion(x + 1))
|
||||
return x
|
||||
return RunnableLambda(lambda *_: x)
|
||||
|
||||
runnable = RunnableLambda(_simple_recursion)
|
||||
assert runnable.invoke(5) == 10
|
||||
|
||||
@@ -309,7 +309,7 @@ class TestRunnableSequenceParallelTraceNesting:
|
||||
other_thing: Callable[
|
||||
[int], Generator[int, None, None] | AsyncGenerator[int, None]
|
||||
],
|
||||
) -> RunnableLambda:
|
||||
) -> RunnableLambda[int, int]:
|
||||
@RunnableLambda
|
||||
def my_child_function(a: int) -> int:
|
||||
return a + 2
|
||||
@@ -611,7 +611,7 @@ def test_traceable_parent_run_map_cleanup_with_sibling_children() -> None:
|
||||
with tracing_context(client=tracer.client, enabled=True):
|
||||
|
||||
@traceable
|
||||
def parent(x: dict) -> Any:
|
||||
def parent(x: dict[str, Any]) -> Any:
|
||||
return chain.invoke(x, config={"callbacks": [tracer]})
|
||||
|
||||
result = parent({"input": "hello"})
|
||||
|
||||
@@ -3919,7 +3919,7 @@ def test_tool_invoke_returns_list_of_mixin() -> None:
|
||||
"""End-to-end: a tool returning a list of ToolOutputMixin via invoke."""
|
||||
|
||||
@tool
|
||||
def multi(x: int) -> list:
|
||||
def multi(x: int) -> list[ToolMessage]:
|
||||
"""Return multiple outputs."""
|
||||
return [
|
||||
ToolMessage(f"result-{i}", tool_call_id=f"sub-{i}", name="multi")
|
||||
|
||||
@@ -440,7 +440,7 @@ def test_generation_chunk_addition_type_error() -> None:
|
||||
],
|
||||
)
|
||||
def test_merge_lists(
|
||||
left: list | None, right: list | None, expected: list | None
|
||||
left: list[Any] | None, right: list[Any] | None, expected: list[Any] | None
|
||||
) -> None:
|
||||
left_copy = deepcopy(left)
|
||||
right_copy = deepcopy(right)
|
||||
|
||||
Reference in New Issue
Block a user