partners[lint]: run pyupgrade to get code in line with 3.9 standards (#30781)

Using `pyupgrade` to get all `partners` code up to 3.9 standards
(mostly, fixing old `typing` imports).
This commit is contained in:
Sydney Runkle
2025-04-11 07:18:44 -04:00
committed by GitHub
parent e72f3c26a0
commit 8c6734325b
123 changed files with 1000 additions and 1109 deletions

View File

@@ -3,7 +3,6 @@
It provides infrastructure for interacting with the Ollama service.
"""
from importlib import metadata
from langchain_ollama.chat_models import ChatOllama

View File

@@ -1,21 +1,14 @@
"""Ollama chat models."""
import json
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Final,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
@@ -153,7 +146,7 @@ def _parse_arguments_from_tool_call(
def _get_tool_calls_from_response(
response: Mapping[str, Any],
) -> List[ToolCall]:
) -> list[ToolCall]:
"""Get tool calls from ollama response."""
tool_calls = []
if "message" in response:
@@ -341,7 +334,7 @@ class ChatOllama(BaseChatModel):
model: str
"""Model name to use."""
extract_reasoning: Optional[Union[bool, Tuple[str, str]]] = False
extract_reasoning: Optional[Union[bool, tuple[str, str]]] = False
"""Whether to extract the reasoning tokens in think blocks.
Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
If a tuple is supplied, they are assumed to be the (start, end) tokens.
@@ -399,7 +392,7 @@ class ChatOllama(BaseChatModel):
to a specific number will make the model generate the same text for
the same prompt."""
stop: Optional[List[str]] = None
stop: Optional[list[str]] = None
"""Sets the stop tokens to use."""
tfs_z: Optional[float] = None
@@ -443,10 +436,10 @@ class ChatOllama(BaseChatModel):
def _chat_params(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
ollama_messages = self._convert_messages_to_ollama_messages(messages)
if self.stop is not None and stop is not None:
@@ -499,13 +492,13 @@ class ChatOllama(BaseChatModel):
return self
def _convert_messages_to_ollama_messages(
self, messages: List[BaseMessage]
self, messages: list[BaseMessage]
) -> Sequence[Message]:
ollama_messages: List = []
ollama_messages: list = []
for message in messages:
role: Literal["user", "assistant", "system", "tool"]
tool_call_id: Optional[str] = None
tool_calls: Optional[List[Dict[str, Any]]] = None
tool_calls: Optional[list[dict[str, Any]]] = None
if isinstance(message, HumanMessage):
role = "user"
elif isinstance(message, AIMessage):
@@ -531,7 +524,7 @@ class ChatOllama(BaseChatModel):
if isinstance(message.content, str):
content = message.content
else:
for content_part in cast(List[Dict], message.content):
for content_part in cast(list[dict], message.content):
if content_part.get("type") == "text":
content += f"\n{content_part['text']}"
elif content_part.get("type") == "tool_use":
@@ -583,7 +576,7 @@ class ChatOllama(BaseChatModel):
def _extract_reasoning(
self, message_chunk: BaseMessageChunk, is_thinking: bool
) -> Tuple[BaseMessageChunk, bool]:
) -> tuple[BaseMessageChunk, bool]:
"""Mutate a message chunk to extract reasoning content."""
if not self.extract_reasoning:
return message_chunk, is_thinking
@@ -605,8 +598,8 @@ class ChatOllama(BaseChatModel):
async def _acreate_chat_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
chat_params = self._chat_params(messages, stop, **kwargs)
@@ -619,8 +612,8 @@ class ChatOllama(BaseChatModel):
def _create_chat_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]:
chat_params = self._chat_params(messages, stop, **kwargs)
@@ -632,8 +625,8 @@ class ChatOllama(BaseChatModel):
def _chat_stream_with_aggregation(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
@@ -657,8 +650,8 @@ class ChatOllama(BaseChatModel):
async def _achat_stream_with_aggregation(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
@@ -681,7 +674,7 @@ class ChatOllama(BaseChatModel):
return final_chunk
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
@@ -697,8 +690,8 @@ class ChatOllama(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -719,8 +712,8 @@ class ChatOllama(BaseChatModel):
def _iterate_over_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
is_thinking = False
@@ -758,8 +751,8 @@ class ChatOllama(BaseChatModel):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
@@ -773,8 +766,8 @@ class ChatOllama(BaseChatModel):
async def _aiterate_over_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
is_thinking = False
@@ -812,8 +805,8 @@ class ChatOllama(BaseChatModel):
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
@@ -827,8 +820,8 @@ class ChatOllama(BaseChatModel):
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -854,7 +847,7 @@ class ChatOllama(BaseChatModel):
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
*,
tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None,
**kwargs: Any,
@@ -877,12 +870,12 @@ class ChatOllama(BaseChatModel):
def with_structured_output(
self,
schema: Union[Dict, type],
schema: Union[dict, type],
*,
method: Literal["function_calling", "json_mode", "json_schema"] = "json_schema",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:

View File

@@ -1,6 +1,6 @@
"""Ollama embeddings models."""
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core.embeddings import Embeddings
from ollama import AsyncClient, Client
@@ -188,7 +188,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
"""The temperature of the model. Increasing the temperature will
make the model answer more creatively. (Default: 0.8)"""
stop: Optional[List[str]] = None
stop: Optional[list[str]] = None
"""Sets the stop tokens to use."""
tfs_z: Optional[float] = None
@@ -211,7 +211,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
)
@property
def _default_params(self) -> Dict[str, Any]:
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling Ollama."""
return {
"mirostat": self.mirostat,
@@ -237,18 +237,18 @@ class OllamaEmbeddings(BaseModel, Embeddings):
self._async_client = AsyncClient(host=self.base_url, **client_kwargs)
return self
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
embedded_docs = self._client.embed(
self.model, texts, options=self._default_params, keep_alive=self.keep_alive
)["embeddings"]
return embedded_docs
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
return self.embed_documents([text])[0]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
embedded_docs = (
await self._async_client.embed(
@@ -257,6 +257,6 @@ class OllamaEmbeddings(BaseModel, Embeddings):
)["embeddings"]
return embedded_docs
async def aembed_query(self, text: str) -> List[float]:
async def aembed_query(self, text: str) -> list[float]:
"""Embed query text."""
return (await self.aembed_documents([text]))[0]

View File

@@ -1,13 +1,9 @@
"""Ollama large language models."""
from collections.abc import AsyncIterator, Iterator, Mapping
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Union,
)
@@ -89,7 +85,7 @@ class OllamaLLM(BaseLLM):
to a specific number will make the model generate the same text for
the same prompt."""
stop: Optional[List[str]] = None
stop: Optional[list[str]] = None
"""Sets the stop tokens to use."""
tfs_z: Optional[float] = None
@@ -134,9 +130,9 @@ class OllamaLLM(BaseLLM):
def _generate_params(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
@@ -181,7 +177,7 @@ class OllamaLLM(BaseLLM):
return "ollama-llm"
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = super()._get_ls_params(stop=stop, **kwargs)
@@ -200,7 +196,7 @@ class OllamaLLM(BaseLLM):
async def _acreate_generate_stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
async for part in await self._async_client.generate(
@@ -211,7 +207,7 @@ class OllamaLLM(BaseLLM):
def _create_generate_stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]:
yield from self._client.generate(
@@ -221,7 +217,7 @@ class OllamaLLM(BaseLLM):
async def _astream_with_aggregation(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
@@ -253,7 +249,7 @@ class OllamaLLM(BaseLLM):
def _stream_with_aggregation(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
@@ -284,8 +280,8 @@ class OllamaLLM(BaseLLM):
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@@ -303,8 +299,8 @@ class OllamaLLM(BaseLLM):
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@@ -323,7 +319,7 @@ class OllamaLLM(BaseLLM):
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
@@ -345,7 +341,7 @@ class OllamaLLM(BaseLLM):
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]: