@@ -1,5 +1,5 @@
 import json
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast

 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
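
Note: the new `cast` import is typing-only. `typing.cast` returns its value unchanged at runtime; it exists purely to narrow the static type so the indexing below passes mypy without the `# type: ignore` comments this diff removes. A minimal illustration (not from this diff):

    from typing import Dict, List, cast

    content: object = [{"type": "text", "text": "hi"}]
    # cast() performs no runtime check or conversion; it only tells the
    # type checker to treat `content` as List[Dict] from here on.
    parts = cast(List[Dict], content)
    print(parts[0]["text"])  # -> "hi"
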
@@ -74,10 +74,15 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         if isinstance(message, ChatMessage):
             message_text = f"\n\n{message.role.capitalize()}: {message.content}"
         elif isinstance(message, HumanMessage):
-            if message.content[0].get("type") == "text":  # type: ignore[union-attr]
-                message_text = f"[INST] {message.content[0]['text']} [/INST]"  # type: ignore[index]
-            elif message.content[0].get("type") == "image_url":  # type: ignore[union-attr]
-                message_text = message.content[0]["image_url"]["url"]  # type: ignore[index, index]
+            if isinstance(message.content, List):
+                first_content = cast(List[Dict], message.content)[0]
+                content_type = first_content.get("type")
+                if content_type == "text":
+                    message_text = f"[INST] {first_content['text']} [/INST]"
+                elif content_type == "image_url":
+                    message_text = first_content["image_url"]["url"]
+            else:
+                message_text = f"[INST] {message.content} [/INST]"
         elif isinstance(message, AIMessage):
             message_text = f"{message.content}"
         elif isinstance(message, SystemMessage):
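
For context, a sketch of the list-form `message.content` this branch handles. The exact payload is illustrative, but the `{"type": "text"}` / `{"type": "image_url"}` part shape matches what the code above inspects (and note it only looks at the first part):

    from langchain_core.messages import HumanMessage

    # String content takes the plain [INST] ... [/INST] path; list content
    # is routed through first_content.get("type") as in the diff above.
    msg = HumanMessage(
        content=[
            {"type": "text", "text": "Describe this image"},
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
        ]
    )
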
@@ -94,7 +99,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
     def _convert_messages_to_ollama_messages(
         self, messages: List[BaseMessage]
     ) -> List[Dict[str, Union[str, List[str]]]]:
-        ollama_messages = []
+        ollama_messages: List = []
         for message in messages:
             role = ""
             if isinstance(message, HumanMessage):
@@ -111,12 +116,12 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
             if isinstance(message.content, str):
                 content = message.content
             else:
-                for content_part in message.content:
-                    if content_part.get("type") == "text":  # type: ignore[union-attr]
-                        content += f"\n{content_part['text']}"  # type: ignore[index]
-                    elif content_part.get("type") == "image_url":  # type: ignore[union-attr]
-                        if isinstance(content_part.get("image_url"), str):  # type: ignore[union-attr]
-                            image_url_components = content_part["image_url"].split(",")  # type: ignore[index]
+                for content_part in cast(List[Dict], message.content):
+                    if content_part.get("type") == "text":
+                        content += f"\n{content_part['text']}"
+                    elif content_part.get("type") == "image_url":
+                        if isinstance(content_part.get("image_url"), str):
+                            image_url_components = content_part["image_url"].split(",")
                             # Support data:image/jpeg;base64,<image> format
                             # and base64 strings
                             if len(image_url_components) > 1:
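
The split on "," is what lets both data-URI and bare base64 inputs work: a data URI splits into two components and the payload is the second, while a bare base64 string yields a single component. Roughly, as a standalone sketch (`extract_base64` is a hypothetical helper, not in the diff):

    def extract_base64(image_url: str) -> str:
        # "data:image/jpeg;base64,<payload>".split(",") -> ["data:image/jpeg;base64", "<payload>"]
        # "<payload>".split(",")                        -> ["<payload>"]
        components = image_url.split(",")
        return components[1] if len(components) > 1 else components[0]
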
@@ -142,7 +147,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
                 }
             )

-        return ollama_messages  # type: ignore[return-value]
+        return ollama_messages

     def _create_chat_stream(
         self,
@@ -324,21 +329,15 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        try:
-            async for stream_resp in self._acreate_chat_stream(
-                messages, stop, **kwargs
-            ):
-                if stream_resp:
-                    chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
-                    yield chunk
-                    if run_manager:
-                        await run_manager.on_llm_new_token(
-                            chunk.text,
-                            verbose=self.verbose,
-                        )
-        except OllamaEndpointNotFoundError:
-            async for chunk in self._legacy_astream(messages, stop, **kwargs):  # type: ignore[attr-defined]
-                yield chunk
+        async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
+            if stream_resp:
+                chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
+                yield chunk
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=self.verbose,
+                    )

     @deprecated("0.0.3", alternative="_stream")
     def _legacy_stream(
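
With the try/except removed, `_astream` now surfaces `OllamaEndpointNotFoundError` directly instead of silently falling back to the deprecated `_legacy_astream`. A minimal consumer sketch, assuming a running local Ollama server and the langchain_community package layout (the model name is illustrative):

    import asyncio

    from langchain_community.chat_models import ChatOllama

    async def main() -> None:
        llm = ChatOllama(model="llama2")
        # The public astream() drives _astream() under the hood,
        # yielding message chunks as they arrive from the server.
        async for chunk in llm.astream("Why is the sky blue?"):
            print(chunk.content, end="", flush=True)

    asyncio.run(main())
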