mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
partners[lint]: run pyupgrade to get code in line with 3.9 standards (#30781)
Using `pyupgrade` to get all `partners` code up to 3.9 standards (mostly, fixing old `typing` imports).
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
It provides infrastructure for interacting with the Ollama service.
|
||||
"""
|
||||
|
||||
|
||||
from importlib import metadata
|
||||
|
||||
from langchain_ollama.chat_models import ChatOllama
|
||||
|
||||
@@ -1,21 +1,14 @@
|
||||
"""Ollama chat models."""
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Final,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@@ -153,7 +146,7 @@ def _parse_arguments_from_tool_call(
|
||||
|
||||
def _get_tool_calls_from_response(
|
||||
response: Mapping[str, Any],
|
||||
) -> List[ToolCall]:
|
||||
) -> list[ToolCall]:
|
||||
"""Get tool calls from ollama response."""
|
||||
tool_calls = []
|
||||
if "message" in response:
|
||||
@@ -341,7 +334,7 @@ class ChatOllama(BaseChatModel):
|
||||
model: str
|
||||
"""Model name to use."""
|
||||
|
||||
extract_reasoning: Optional[Union[bool, Tuple[str, str]]] = False
|
||||
extract_reasoning: Optional[Union[bool, tuple[str, str]]] = False
|
||||
"""Whether to extract the reasoning tokens in think blocks.
|
||||
Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
|
||||
If a tuple is supplied, they are assumed to be the (start, end) tokens.
|
||||
@@ -399,7 +392,7 @@ class ChatOllama(BaseChatModel):
|
||||
to a specific number will make the model generate the same text for
|
||||
the same prompt."""
|
||||
|
||||
stop: Optional[List[str]] = None
|
||||
stop: Optional[list[str]] = None
|
||||
"""Sets the stop tokens to use."""
|
||||
|
||||
tfs_z: Optional[float] = None
|
||||
@@ -443,10 +436,10 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def _chat_params(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
ollama_messages = self._convert_messages_to_ollama_messages(messages)
|
||||
|
||||
if self.stop is not None and stop is not None:
|
||||
@@ -499,13 +492,13 @@ class ChatOllama(BaseChatModel):
|
||||
return self
|
||||
|
||||
def _convert_messages_to_ollama_messages(
|
||||
self, messages: List[BaseMessage]
|
||||
self, messages: list[BaseMessage]
|
||||
) -> Sequence[Message]:
|
||||
ollama_messages: List = []
|
||||
ollama_messages: list = []
|
||||
for message in messages:
|
||||
role: Literal["user", "assistant", "system", "tool"]
|
||||
tool_call_id: Optional[str] = None
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None
|
||||
tool_calls: Optional[list[dict[str, Any]]] = None
|
||||
if isinstance(message, HumanMessage):
|
||||
role = "user"
|
||||
elif isinstance(message, AIMessage):
|
||||
@@ -531,7 +524,7 @@ class ChatOllama(BaseChatModel):
|
||||
if isinstance(message.content, str):
|
||||
content = message.content
|
||||
else:
|
||||
for content_part in cast(List[Dict], message.content):
|
||||
for content_part in cast(list[dict], message.content):
|
||||
if content_part.get("type") == "text":
|
||||
content += f"\n{content_part['text']}"
|
||||
elif content_part.get("type") == "tool_use":
|
||||
@@ -583,7 +576,7 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def _extract_reasoning(
|
||||
self, message_chunk: BaseMessageChunk, is_thinking: bool
|
||||
) -> Tuple[BaseMessageChunk, bool]:
|
||||
) -> tuple[BaseMessageChunk, bool]:
|
||||
"""Mutate a message chunk to extract reasoning content."""
|
||||
if not self.extract_reasoning:
|
||||
return message_chunk, is_thinking
|
||||
@@ -605,8 +598,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
async def _acreate_chat_stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
|
||||
chat_params = self._chat_params(messages, stop, **kwargs)
|
||||
@@ -619,8 +612,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def _create_chat_stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Union[Mapping[str, Any], str]]:
|
||||
chat_params = self._chat_params(messages, stop, **kwargs)
|
||||
@@ -632,8 +625,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def _chat_stream_with_aggregation(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
@@ -657,8 +650,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
async def _achat_stream_with_aggregation(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
@@ -681,7 +674,7 @@ class ChatOllama(BaseChatModel):
|
||||
return final_chunk
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
@@ -697,8 +690,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@@ -719,8 +712,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def _iterate_over_stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
is_thinking = False
|
||||
@@ -758,8 +751,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
@@ -773,8 +766,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
async def _aiterate_over_stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
is_thinking = False
|
||||
@@ -812,8 +805,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
@@ -827,8 +820,8 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@@ -854,7 +847,7 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
|
||||
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None,
|
||||
**kwargs: Any,
|
||||
@@ -877,12 +870,12 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Union[Dict, type],
|
||||
schema: Union[dict, type],
|
||||
*,
|
||||
method: Literal["function_calling", "json_mode", "json_schema"] = "json_schema",
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Ollama embeddings models."""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from ollama import AsyncClient, Client
|
||||
@@ -188,7 +188,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
"""The temperature of the model. Increasing the temperature will
|
||||
make the model answer more creatively. (Default: 0.8)"""
|
||||
|
||||
stop: Optional[List[str]] = None
|
||||
stop: Optional[list[str]] = None
|
||||
"""Sets the stop tokens to use."""
|
||||
|
||||
tfs_z: Optional[float] = None
|
||||
@@ -211,7 +211,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling Ollama."""
|
||||
return {
|
||||
"mirostat": self.mirostat,
|
||||
@@ -237,18 +237,18 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
self._async_client = AsyncClient(host=self.base_url, **client_kwargs)
|
||||
return self
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs."""
|
||||
embedded_docs = self._client.embed(
|
||||
self.model, texts, options=self._default_params, keep_alive=self.keep_alive
|
||||
)["embeddings"]
|
||||
return embedded_docs
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text."""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs."""
|
||||
embedded_docs = (
|
||||
await self._async_client.embed(
|
||||
@@ -257,6 +257,6 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
)["embeddings"]
|
||||
return embedded_docs
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text."""
|
||||
return (await self.aembed_documents([text]))[0]
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
"""Ollama large language models."""
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
@@ -89,7 +85,7 @@ class OllamaLLM(BaseLLM):
|
||||
to a specific number will make the model generate the same text for
|
||||
the same prompt."""
|
||||
|
||||
stop: Optional[List[str]] = None
|
||||
stop: Optional[list[str]] = None
|
||||
"""Sets the stop tokens to use."""
|
||||
|
||||
tfs_z: Optional[float] = None
|
||||
@@ -134,9 +130,9 @@ class OllamaLLM(BaseLLM):
|
||||
def _generate_params(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
if self.stop is not None and stop is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop is not None:
|
||||
@@ -181,7 +177,7 @@ class OllamaLLM(BaseLLM):
|
||||
return "ollama-llm"
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
params = super()._get_ls_params(stop=stop, **kwargs)
|
||||
@@ -200,7 +196,7 @@ class OllamaLLM(BaseLLM):
|
||||
async def _acreate_generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
|
||||
async for part in await self._async_client.generate(
|
||||
@@ -211,7 +207,7 @@ class OllamaLLM(BaseLLM):
|
||||
def _create_generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Union[Mapping[str, Any], str]]:
|
||||
yield from self._client.generate(
|
||||
@@ -221,7 +217,7 @@ class OllamaLLM(BaseLLM):
|
||||
async def _astream_with_aggregation(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
@@ -253,7 +249,7 @@ class OllamaLLM(BaseLLM):
|
||||
def _stream_with_aggregation(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs: Any,
|
||||
@@ -284,8 +280,8 @@ class OllamaLLM(BaseLLM):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@@ -303,8 +299,8 @@ class OllamaLLM(BaseLLM):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@@ -323,7 +319,7 @@ class OllamaLLM(BaseLLM):
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
@@ -345,7 +341,7 @@ class OllamaLLM(BaseLLM):
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
|
||||
Reference in New Issue
Block a user