make image inputs compatible with langchain_ollama (#24619)

This commit is contained in:
Isaac Francisco 2024-07-26 17:39:57 -07:00 committed by GitHub
parent 0535d72927
commit 152427eca1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 8 deletions

View File

@ -346,7 +346,7 @@ class ChatOllama(BaseChatModel):
) -> Sequence[Message]: ) -> Sequence[Message]:
ollama_messages: List = [] ollama_messages: List = []
for message in messages: for message in messages:
role = "" role: Literal["user", "assistant", "system", "tool"]
tool_call_id: Optional[str] = None tool_call_id: Optional[str] = None
tool_calls: Optional[List[Dict[str, Any]]] = None tool_calls: Optional[List[Dict[str, Any]]] = None
if isinstance(message, HumanMessage): if isinstance(message, HumanMessage):
@ -383,11 +383,13 @@ class ChatOllama(BaseChatModel):
image_url = None image_url = None
temp_image_url = content_part.get("image_url") temp_image_url = content_part.get("image_url")
if isinstance(temp_image_url, str): if isinstance(temp_image_url, str):
image_url = content_part["image_url"]
elif (
isinstance(temp_image_url, dict) and "url" in temp_image_url
):
image_url = temp_image_url image_url = temp_image_url
elif (
isinstance(temp_image_url, dict)
and "url" in temp_image_url
and isinstance(temp_image_url["url"], str)
):
image_url = temp_image_url["url"]
else: else:
raise ValueError( raise ValueError(
"Only string image_url or dict with string 'url' " "Only string image_url or dict with string 'url' "
@ -408,15 +410,16 @@ class ChatOllama(BaseChatModel):
"Must either have type 'text' or type 'image_url' " "Must either have type 'text' or type 'image_url' "
"with a string 'image_url' field." "with a string 'image_url' field."
) )
msg = { # Should convert to ollama.Message once role includes tool, and tool_call_id is in Message # noqa: E501
msg: dict = {
"role": role, "role": role,
"content": content, "content": content,
"images": images, "images": images,
} }
if tool_calls:
msg["tool_calls"] = tool_calls # type: ignore
if tool_call_id: if tool_call_id:
msg["tool_call_id"] = tool_call_id msg["tool_call_id"] = tool_call_id
if tool_calls:
msg["tool_calls"] = tool_calls
ollama_messages.append(msg) ollama_messages.append(msg)
return ollama_messages return ollama_messages

View File

@ -2,6 +2,8 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
from langchain_ollama.chat_models import ChatOllama from langchain_ollama.chat_models import ChatOllama
@ -15,3 +17,23 @@ class TestChatOllama(ChatModelIntegrationTests):
@property @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return {"model": "llama3-groq-tool-use"} return {"model": "llama3-groq-tool-use"}
@property
def supports_image_inputs(self) -> bool:
return True
@pytest.mark.xfail(
reason=(
"Fails with 'AssertionError'. Ollama does not support 'tool_choice' yet."
)
)
def test_structured_output(self, model: BaseChatModel) -> None:
super().test_structured_output(model)
@pytest.mark.xfail(
reason=(
"Fails with 'AssertionError'. Ollama does not support 'tool_choice' yet."
)
)
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
super().test_structured_output_pydantic_2_v1(model)