community[patch]: chat model mypy fixes (#17061)

Related to #17048
This commit is contained in:
Bagatur
2024-02-05 13:42:59 -08:00
committed by GitHub
parent d93de71d08
commit 66e45e8ab7
17 changed files with 101 additions and 89 deletions

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@@ -74,10 +74,15 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
if message.content[0].get("type") == "text": # type: ignore[union-attr]
message_text = f"[INST] {message.content[0]['text']} [/INST]" # type: ignore[index]
elif message.content[0].get("type") == "image_url": # type: ignore[union-attr]
message_text = message.content[0]["image_url"]["url"] # type: ignore[index, index]
if isinstance(message.content, List):
first_content = cast(List[Dict], message.content)[0]
content_type = first_content.get("type")
if content_type == "text":
message_text = f"[INST] {first_content['text']} [/INST]"
elif content_type == "image_url":
message_text = first_content["image_url"]["url"]
else:
message_text = f"[INST] {message.content} [/INST]"
elif isinstance(message, AIMessage):
message_text = f"{message.content}"
elif isinstance(message, SystemMessage):
@@ -94,7 +99,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
def _convert_messages_to_ollama_messages(
self, messages: List[BaseMessage]
) -> List[Dict[str, Union[str, List[str]]]]:
ollama_messages = []
ollama_messages: List = []
for message in messages:
role = ""
if isinstance(message, HumanMessage):
@@ -111,12 +116,12 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
if isinstance(message.content, str):
content = message.content
else:
for content_part in message.content:
if content_part.get("type") == "text": # type: ignore[union-attr]
content += f"\n{content_part['text']}" # type: ignore[index]
elif content_part.get("type") == "image_url": # type: ignore[union-attr]
if isinstance(content_part.get("image_url"), str): # type: ignore[union-attr]
image_url_components = content_part["image_url"].split(",") # type: ignore[index]
for content_part in cast(List[Dict], message.content):
if content_part.get("type") == "text":
content += f"\n{content_part['text']}"
elif content_part.get("type") == "image_url":
if isinstance(content_part.get("image_url"), str):
image_url_components = content_part["image_url"].split(",")
# Support data:image/jpeg;base64,<image> format
# and base64 strings
if len(image_url_components) > 1:
@@ -142,7 +147,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
}
)
return ollama_messages # type: ignore[return-value]
return ollama_messages
def _create_chat_stream(
self,
@@ -324,21 +329,15 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
try:
async for stream_resp in self._acreate_chat_stream(
messages, stop, **kwargs
):
if stream_resp:
chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
except OllamaEndpointNotFoundError:
async for chunk in self._legacy_astream(messages, stop, **kwargs): # type: ignore[attr-defined]
async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
if stream_resp:
chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
@deprecated("0.0.3", alternative="_stream")
def _legacy_stream(