Add image support for Ollama (#14713)
Support [LLaVA](https://ollama.ai/library/llava):

* Upgrade Ollama
* `ollama pull llava`

Ensure compatibility with the [image prompt template](https://github.com/langchain-ai/langchain/pull/14263).

---------

Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
parent 1075e7d6e8 · commit 42421860bc
Two file diffs suppressed because one or more lines are too long.
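As a rough usage sketch of the new behavior (not part of the commit itself): an image can be passed to `ChatOllama` as a content part on a `HumanMessage`, either as a raw base64 string or as a `data:image/jpeg;base64,<image>` data URI. The model name, prompt text, and placeholder image string below are illustrative.

```python
from langchain_community.chat_models import ChatOllama
from langchain_core.messages import HumanMessage

# Assumes a local Ollama server with `ollama pull llava` already done.
chat = ChatOllama(model="llava")  # base_url defaults to http://localhost:11434

image_b64 = "<base64-encoded image bytes>"  # placeholder, not a real image

message = HumanMessage(
    content=[
        {"type": "text", "text": "What is shown in this image?"},
        # The converter accepts either a bare base64 string or a
        # data:image/jpeg;base64,<image> data URI as the image_url value.
        {"type": "image_url", "image_url": f"data:image/jpeg;base64,{image_b64}"},
    ]
)

response = chat.invoke([message])
print(response.content)
```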
langchain_community/chat_models/ollama.py:

```diff
@@ -1,6 +1,7 @@
 import json
-from typing import Any, Iterator, List, Optional
+from typing import Any, Dict, Iterator, List, Optional, Union
 
+from langchain_core._api import deprecated
 from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
@@ -15,9 +16,10 @@ from langchain_core.messages import (
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 
-from langchain_community.llms.ollama import _OllamaCommon
+from langchain_community.llms.ollama import OllamaEndpointNotFoundError, _OllamaCommon
 
 
+@deprecated("0.0.3", alternative="_chat_stream_response_to_chat_generation_chunk")
 def _stream_response_to_chat_generation_chunk(
     stream_response: str,
 ) -> ChatGenerationChunk:
@@ -30,6 +32,20 @@ def _stream_response_to_chat_generation_chunk(
     )
 
 
+def _chat_stream_response_to_chat_generation_chunk(
+    stream_response: str,
+) -> ChatGenerationChunk:
+    """Convert a stream response to a generation chunk."""
+    parsed_response = json.loads(stream_response)
+    generation_info = parsed_response if parsed_response.get("done") is True else None
+    return ChatGenerationChunk(
+        message=AIMessageChunk(
+            content=parsed_response.get("message", {}).get("content", "")
+        ),
+        generation_info=generation_info,
+    )
+
+
 class ChatOllama(BaseChatModel, _OllamaCommon):
     """Ollama locally runs large language models.
 
@@ -52,11 +68,15 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         """Return whether this model can be serialized by Langchain."""
         return False
 
+    @deprecated("0.0.3", alternative="_convert_messages_to_ollama_messages")
     def _format_message_as_text(self, message: BaseMessage) -> str:
         if isinstance(message, ChatMessage):
             message_text = f"\n\n{message.role.capitalize()}: {message.content}"
         elif isinstance(message, HumanMessage):
-            message_text = f"[INST] {message.content} [/INST]"
+            if message.content[0].get("type") == "text":
+                message_text = f"[INST] {message.content[0]['text']} [/INST]"
+            elif message.content[0].get("type") == "image_url":
+                message_text = message.content[0]["image_url"]["url"]
         elif isinstance(message, AIMessage):
             message_text = f"{message.content}"
         elif isinstance(message, SystemMessage):
@@ -70,6 +90,98 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
             [self._format_message_as_text(message) for message in messages]
         )
 
+    def _convert_messages_to_ollama_messages(
+        self, messages: List[BaseMessage]
+    ) -> List[Dict[str, Union[str, List[str]]]]:
+        ollama_messages = []
+        for message in messages:
+            role = ""
+            if isinstance(message, HumanMessage):
+                role = "user"
+            elif isinstance(message, AIMessage):
+                role = "assistant"
+            elif isinstance(message, SystemMessage):
+                role = "system"
+            else:
+                raise ValueError("Received unsupported message type for Ollama.")
+
+            content = ""
+            images = []
+            if isinstance(message.content, str):
+                content = message.content
+            else:
+                for content_part in message.content:
+                    if content_part.get("type") == "text":
+                        content += f"\n{content_part['text']}"
+                    elif content_part.get("type") == "image_url":
+                        if isinstance(content_part.get("image_url"), str):
+                            image_url_components = content_part["image_url"].split(",")
+                            # Support data:image/jpeg;base64,<image> format
+                            # and base64 strings
+                            if len(image_url_components) > 1:
+                                images.append(image_url_components[1])
+                            else:
+                                images.append(image_url_components[0])
+                        else:
+                            raise ValueError(
+                                "Only string image_url " "content parts are supported."
+                            )
+                    else:
+                        raise ValueError(
+                            "Unsupported message content type. "
+                            "Must either have type 'text' or type 'image_url' "
+                            "with a string 'image_url' field."
+                        )
+
+            ollama_messages.append(
+                {
+                    "role": role,
+                    "content": content,
+                    "images": images,
+                }
+            )
+
+        return ollama_messages
+
+    def _create_chat_stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Iterator[str]:
+        payload = {
+            "messages": self._convert_messages_to_ollama_messages(messages),
+        }
+        yield from self._create_stream(
+            payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat/", **kwargs
+        )
+
+    def _chat_stream_with_aggregation(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        verbose: bool = False,
+        **kwargs: Any,
+    ) -> ChatGenerationChunk:
+        final_chunk: Optional[ChatGenerationChunk] = None
+        for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
+            if stream_resp:
+                chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
+                if final_chunk is None:
+                    final_chunk = chunk
+                else:
+                    final_chunk += chunk
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=verbose,
+                    )
+        if final_chunk is None:
+            raise ValueError("No data received from Ollama stream.")
+
+        return final_chunk
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -94,9 +206,12 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
             ])
         """
 
-        prompt = self._format_messages_as_text(messages)
-        final_chunk = super()._stream_with_aggregation(
-            prompt, stop=stop, run_manager=run_manager, verbose=self.verbose, **kwargs
+        final_chunk = self._chat_stream_with_aggregation(
+            messages,
+            stop=stop,
+            run_manager=run_manager,
+            verbose=self.verbose,
+            **kwargs,
         )
         chat_generation = ChatGeneration(
             message=AIMessage(content=final_chunk.text),
@@ -110,9 +225,30 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
+        try:
+            for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
+                if stream_resp:
+                    chunk = _stream_response_to_chat_generation_chunk(stream_resp)
+                    yield chunk
+                    if run_manager:
+                        run_manager.on_llm_new_token(
+                            chunk.text,
+                            verbose=self.verbose,
+                        )
+        except OllamaEndpointNotFoundError:
+            yield from self._legacy_stream(messages, stop, **kwargs)
+
+    @deprecated("0.0.3", alternative="_stream")
+    def _legacy_stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
        prompt = self._format_messages_as_text(messages)
-        for stream_resp in self._create_stream(prompt, stop, **kwargs):
+        for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
             if stream_resp:
                 chunk = _stream_response_to_chat_generation_chunk(stream_resp)
                 yield chunk
```
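The core of the image handling above is the `image_url` parsing in `_convert_messages_to_ollama_messages`: a data URI is split on the comma and only the base64 payload is forwarded, while a bare base64 string passes through unchanged. A tiny standalone sketch of that behavior (the helper name is hypothetical, not part of the commit):

```python
def extract_base64(image_url: str) -> str:  # hypothetical helper for illustration
    # Mirrors the diff: split on commas; keep the segment after the
    # data:image/...;base64, prefix if present, else the string itself.
    components = image_url.split(",")
    return components[1] if len(components) > 1 else components[0]


assert extract_base64("data:image/jpeg;base64,AAAA") == "AAAA"
assert extract_base64("AAAA") == "AAAA"
```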
langchain_community/llms/ollama.py:

```diff
@@ -20,6 +20,10 @@ def _stream_response_to_generation_chunk(
     )
 
 
+class OllamaEndpointNotFoundError(Exception):
+    """Raised when the Ollama endpoint is not found."""
+
+
 class _OllamaCommon(BaseLanguageModel):
     base_url: str = "http://localhost:11434"
     """Base url the model is hosted under."""
@@ -129,10 +133,26 @@ class _OllamaCommon(BaseLanguageModel):
         """Get the identifying parameters."""
         return {**{"model": self.model, "format": self.format}, **self._default_params}
 
-    def _create_stream(
+    def _create_generate_stream(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
+        images: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Iterator[str]:
+        payload = {"prompt": prompt, "images": images}
+        yield from self._create_stream(
+            payload=payload,
+            stop=stop,
+            api_url=f"{self.base_url}/api/generate/",
+            **kwargs,
+        )
+
+    def _create_stream(
+        self,
+        api_url: str,
+        payload: Any,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Iterator[str]:
         if self.stop is not None and stop is not None:
@@ -156,20 +176,34 @@ class _OllamaCommon(BaseLanguageModel):
             **kwargs,
         }
 
+        if payload.get("messages"):
+            request_payload = {"messages": payload.get("messages", []), **params}
+        else:
+            request_payload = {
+                "prompt": payload.get("prompt"),
+                "images": payload.get("images", []),
+                **params,
+            }
+
         response = requests.post(
-            url=f"{self.base_url}/api/generate/",
+            url=api_url,
             headers={"Content-Type": "application/json"},
-            json={"prompt": prompt, **params},
+            json=request_payload,
             stream=True,
             timeout=self.timeout,
         )
         response.encoding = "utf-8"
         if response.status_code != 200:
-            optional_detail = response.json().get("error")
-            raise ValueError(
-                f"Ollama call failed with status code {response.status_code}."
-                f" Details: {optional_detail}"
-            )
+            if response.status_code == 404:
+                raise OllamaEndpointNotFoundError(
+                    "Ollama call failed with status code 404."
+                )
+            else:
+                optional_detail = response.json().get("error")
+                raise ValueError(
+                    f"Ollama call failed with status code {response.status_code}."
+                    f" Details: {optional_detail}"
+                )
         return response.iter_lines(decode_unicode=True)
 
     def _stream_with_aggregation(
@@ -181,7 +215,7 @@ class _OllamaCommon(BaseLanguageModel):
         **kwargs: Any,
     ) -> GenerationChunk:
         final_chunk: Optional[GenerationChunk] = None
-        for stream_resp in self._create_stream(prompt, stop, **kwargs):
+        for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
             if stream_resp:
                 chunk = _stream_response_to_generation_chunk(stream_resp)
                 if final_chunk is None:
@@ -225,6 +259,7 @@ class Ollama(BaseLLM, _OllamaCommon):
         self,
         prompts: List[str],
         stop: Optional[List[str]] = None,
+        images: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
@@ -248,6 +283,7 @@ class Ollama(BaseLLM, _OllamaCommon):
             final_chunk = super()._stream_with_aggregation(
                 prompt,
                 stop=stop,
+                images=images,
                 run_manager=run_manager,
                 verbose=self.verbose,
                 **kwargs,
```
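For reference, a rough sketch of the two request shapes the refactored `_create_stream` now posts, assuming a default local Ollama server. The model name, prompt, and base64 placeholder are illustrative; in the real code the extra fields (such as `model` and `options`) come from `_default_params`.

```python
import json

import requests

base_url = "http://localhost:11434"  # default base_url from _OllamaCommon

# Chat endpoint: each message carries its images alongside role and content.
chat_payload = {
    "model": "llava",
    "messages": [
        {"role": "user", "content": "Describe this picture.", "images": ["<base64>"]},
    ],
}
chat_resp = requests.post(
    url=f"{base_url}/api/chat/",
    headers={"Content-Type": "application/json"},
    json=chat_payload,
    stream=True,
)

# Generate endpoint: a flat prompt plus a top-level images list.
generate_payload = {
    "model": "llava",
    "prompt": "Describe this picture.",
    "images": ["<base64>"],
}
generate_resp = requests.post(
    url=f"{base_url}/api/generate/",
    headers={"Content-Type": "application/json"},
    json=generate_payload,
    stream=True,
)

# Both endpoints stream newline-delimited JSON chunks until "done" is true.
for line in chat_resp.iter_lines(decode_unicode=True):
    if line:
        print(json.loads(line).get("message", {}).get("content", ""), end="")
```

Chat requests carry images per message, while generate requests take a flat prompt plus a top-level `images` list, which is why `Ollama._generate` now threads an `images` argument through `_stream_with_aggregation`.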