Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-18 09:01:03 +00:00
Adding compatibility for OllamaFunctions with ImagePromptTemplate (#24499)
- [ ] **PR title**: "experimental: Adding compatibility for OllamaFunctions with ImagePromptTemplate"
- [ ] **PR message**:
  - **Description:** Removes the outdated `_convert_messages_to_ollama_messages` method override in the `OllamaFunctions` class so that Ollama multimodal models can be invoked with an image.
  - **Issue:** #24174

---------

Co-authored-by: Joel Akeret <joel.akeret@ti&m.com>
Co-authored-by: Isaac Francisco <78627776+isahers1@users.noreply.github.com>
Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
This commit is contained in:
parent 8f3c052db1
commit acfce30017
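What this change enables, as a minimal end-to-end sketch (not part of the commit): the model tag `bakllava`, the `Pet` schema, and the image path are illustrative assumptions, and a running Ollama server with a multimodal model pulled is required.

```python
import base64

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_experimental.llms.ollama_functions import OllamaFunctions


class Pet(BaseModel):
    """Hypothetical output schema, for illustration only."""

    species: str = Field(description="The animal shown in the picture")


# "bakllava" stands in for any multimodal model available in Ollama.
llm = OllamaFunctions(model="bakllava")

# ChatPromptTemplate turns the dict-shaped part below into an
# ImagePromptTemplate, which is exactly the path this commit fixes.
prompt = ChatPromptTemplate.from_messages(
    [("human", [{"image_url": "data:image/jpeg;base64,{image_data}"}])]
)

chain = prompt | llm.with_structured_output(schema=Pet)

with open("pet.jpg", "rb") as f:  # hypothetical local image
    image_data = base64.b64encode(f.read()).decode("utf-8")

print(chain.invoke({"image_data": image_data}))
```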
libs/experimental/langchain_experimental/llms/ollama_functions.py:

```diff
@@ -12,7 +12,6 @@ from typing import (
     TypedDict,
     TypeVar,
     Union,
-    cast,
 )
 
 from langchain_community.chat_models.ollama import ChatOllama
@@ -24,10 +23,7 @@ from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
-    HumanMessage,
-    SystemMessage,
     ToolCall,
-    ToolMessage,
 )
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.json import JsonOutputParser
@@ -282,59 +278,6 @@ class OllamaFunctions(ChatOllama):
         else:
             return llm | parser_chain
 
-    def _convert_messages_to_ollama_messages(
-        self, messages: List[BaseMessage]
-    ) -> List[Dict[str, Union[str, List[str]]]]:
-        ollama_messages: List = []
-        for message in messages:
-            role = ""
-            if isinstance(message, HumanMessage):
-                role = "user"
-            elif isinstance(message, AIMessage) or isinstance(message, ToolMessage):
-                role = "assistant"
-            elif isinstance(message, SystemMessage):
-                role = "system"
-            else:
-                raise ValueError("Received unsupported message type for Ollama.")
-
-            content = ""
-            images = []
-            if isinstance(message.content, str):
-                content = message.content
-            else:
-                for content_part in cast(List[Dict], message.content):
-                    if content_part.get("type") == "text":
-                        content += f"\n{content_part['text']}"
-                    elif content_part.get("type") == "image_url":
-                        if isinstance(content_part.get("image_url"), str):
-                            image_url_components = content_part["image_url"].split(",")
-                            # Support data:image/jpeg;base64,<image> format
-                            # and base64 strings
-                            if len(image_url_components) > 1:
-                                images.append(image_url_components[1])
-                            else:
-                                images.append(image_url_components[0])
-                        else:
-                            raise ValueError(
-                                "Only string image_url content parts are supported."
-                            )
-                    else:
-                        raise ValueError(
-                            "Unsupported message content type. "
-                            "Must either have type 'text' or type 'image_url' "
-                            "with a string 'image_url' field."
-                        )
-
-            ollama_messages.append(
-                {
-                    "role": role,
-                    "content": content,
-                    "images": images,
-                }
-            )
-
-        return ollama_messages
-
     def _generate(
         self,
         messages: List[BaseMessage],
```
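Why deleting the override fixes #24174: as the removed branch shows, it accepted only string-valued `image_url` content parts and raised `ValueError` otherwise, while `ImagePromptTemplate` renders an image as a nested dict, so image messages must instead flow through the inherited `ChatOllama` conversion. A sketch of the two shapes; the dict layout is an assumption based on `langchain_core`'s `ImageURL` type:

```python
# Accepted by the removed override: a string-valued "image_url".
string_part = {
    "type": "image_url",
    "image_url": "data:image/jpeg;base64,<b64-data>",
}

# Rejected by it ("Only string image_url content parts are supported."),
# but produced by ImagePromptTemplate: a dict with a "url" key (assumed layout).
dict_part = {
    "type": "image_url",
    "image_url": {"url": "data:image/jpeg;base64,<b64-data>"},
}
```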
libs/experimental/tests/unit_tests/test_ollama_functions.py (new file, 30 lines):

```diff
@@ -0,0 +1,30 @@
+import json
+from typing import Any
+from unittest.mock import patch
+
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.pydantic_v1 import BaseModel
+
+from langchain_experimental.llms.ollama_functions import OllamaFunctions
+
+
+class Schema(BaseModel):
+    pass
+
+
+@patch.object(OllamaFunctions, "_create_stream")
+def test_convert_image_prompt(
+    _create_stream_mock: Any,
+) -> None:
+    response = {"message": {"content": '{"tool": "Schema", "tool_input": {}}'}}
+    _create_stream_mock.return_value = [json.dumps(response)]
+
+    prompt = ChatPromptTemplate.from_messages(
+        [("human", [{"image_url": "data:image/jpeg;base64,{image_url}"}])]
+    )
+
+    lmm = prompt | OllamaFunctions().with_structured_output(schema=Schema)
+
+    schema_instance = lmm.invoke(dict(image_url=""))
+
+    assert schema_instance is not None
```
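A note on the test's design: patching `OllamaFunctions._create_stream` means no Ollama server is involved; the canned response imitates the `{"tool": ..., "tool_input": ...}` JSON that the class's default response parser expects, so the assertion exercises only the message-conversion path that previously raised on image prompts.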