Add image support for Ollama (#14713)

Support [LLaVA](https://ollama.ai/library/llava):
* Upgrade Ollama
* `ollama pull llava`

Ensure compatibility with [image prompt
template](https://github.com/langchain-ai/langchain/pull/14263)

---------

Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
This commit is contained in:
Lance Martin 2023-12-15 16:00:55 -08:00 committed by GitHub
parent 1075e7d6e8
commit 42421860bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 391 additions and 595 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,7 @@
import json
from typing import Any, Iterator, List, Optional
from typing import Any, Dict, Iterator, List, Optional, Union
from langchain_core._api import deprecated
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
@ -15,9 +16,10 @@ from langchain_core.messages import (
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_community.llms.ollama import _OllamaCommon
from langchain_community.llms.ollama import OllamaEndpointNotFoundError, _OllamaCommon
@deprecated("0.0.3", alternative="_chat_stream_response_to_chat_generation_chunk")
def _stream_response_to_chat_generation_chunk(
stream_response: str,
) -> ChatGenerationChunk:
@ -30,6 +32,20 @@ def _stream_response_to_chat_generation_chunk(
)
def _chat_stream_response_to_chat_generation_chunk(
stream_response: str,
) -> ChatGenerationChunk:
"""Convert a stream response to a generation chunk."""
parsed_response = json.loads(stream_response)
generation_info = parsed_response if parsed_response.get("done") is True else None
return ChatGenerationChunk(
message=AIMessageChunk(
content=parsed_response.get("message", {}).get("content", "")
),
generation_info=generation_info,
)
class ChatOllama(BaseChatModel, _OllamaCommon):
"""Ollama locally runs large language models.
@ -52,11 +68,15 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
"""Return whether this model can be serialized by Langchain."""
return False
@deprecated("0.0.3", alternative="_convert_messages_to_ollama_messages")
def _format_message_as_text(self, message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
message_text = f"[INST] {message.content} [/INST]"
if message.content[0].get("type") == "text":
message_text = f"[INST] {message.content[0]['text']} [/INST]"
elif message.content[0].get("type") == "image_url":
message_text = message.content[0]["image_url"]["url"]
elif isinstance(message, AIMessage):
message_text = f"{message.content}"
elif isinstance(message, SystemMessage):
@ -70,6 +90,98 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
[self._format_message_as_text(message) for message in messages]
)
def _convert_messages_to_ollama_messages(
self, messages: List[BaseMessage]
) -> List[Dict[str, Union[str, List[str]]]]:
ollama_messages = []
for message in messages:
role = ""
if isinstance(message, HumanMessage):
role = "user"
elif isinstance(message, AIMessage):
role = "assistant"
elif isinstance(message, SystemMessage):
role = "system"
else:
raise ValueError("Received unsupported message type for Ollama.")
content = ""
images = []
if isinstance(message.content, str):
content = message.content
else:
for content_part in message.content:
if content_part.get("type") == "text":
content += f"\n{content_part['text']}"
elif content_part.get("type") == "image_url":
if isinstance(content_part.get("image_url"), str):
image_url_components = content_part["image_url"].split(",")
# Support data:image/jpeg;base64,<image> format
# and base64 strings
if len(image_url_components) > 1:
images.append(image_url_components[1])
else:
images.append(image_url_components[0])
else:
raise ValueError(
"Only string image_url " "content parts are supported."
)
else:
raise ValueError(
"Unsupported message content type. "
"Must either have type 'text' or type 'image_url' "
"with a string 'image_url' field."
)
ollama_messages.append(
{
"role": role,
"content": content,
"images": images,
}
)
return ollama_messages
def _create_chat_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
payload = {
"messages": self._convert_messages_to_ollama_messages(messages),
}
yield from self._create_stream(
payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat/", **kwargs
)
def _chat_stream_with_aggregation(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
) -> ChatGenerationChunk:
final_chunk: Optional[ChatGenerationChunk] = None
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
if stream_resp:
chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
if final_chunk is None:
final_chunk = chunk
else:
final_chunk += chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=verbose,
)
if final_chunk is None:
raise ValueError("No data received from Ollama stream.")
return final_chunk
def _generate(
self,
messages: List[BaseMessage],
@ -94,9 +206,12 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
])
"""
prompt = self._format_messages_as_text(messages)
final_chunk = super()._stream_with_aggregation(
prompt, stop=stop, run_manager=run_manager, verbose=self.verbose, **kwargs
final_chunk = self._chat_stream_with_aggregation(
messages,
stop=stop,
run_manager=run_manager,
verbose=self.verbose,
**kwargs,
)
chat_generation = ChatGeneration(
message=AIMessage(content=final_chunk.text),
@ -110,9 +225,30 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
try:
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_chat_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
except OllamaEndpointNotFoundError:
yield from self._legacy_stream(messages, stop, **kwargs)
@deprecated("0.0.3", alternative="_stream")
def _legacy_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
prompt = self._format_messages_as_text(messages)
for stream_resp in self._create_stream(prompt, stop, **kwargs):
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_chat_generation_chunk(stream_resp)
yield chunk

View File

@ -20,6 +20,10 @@ def _stream_response_to_generation_chunk(
)
class OllamaEndpointNotFoundError(Exception):
"""Raised when the Ollama endpoint is not found."""
class _OllamaCommon(BaseLanguageModel):
base_url: str = "http://localhost:11434"
"""Base url the model is hosted under."""
@ -129,10 +133,26 @@ class _OllamaCommon(BaseLanguageModel):
"""Get the identifying parameters."""
return {**{"model": self.model, "format": self.format}, **self._default_params}
def _create_stream(
def _create_generate_stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
payload = {"prompt": prompt, "images": images}
yield from self._create_stream(
payload=payload,
stop=stop,
api_url=f"{self.base_url}/api/generate/",
**kwargs,
)
def _create_stream(
self,
api_url: str,
payload: Any,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
if self.stop is not None and stop is not None:
@ -156,20 +176,34 @@ class _OllamaCommon(BaseLanguageModel):
**kwargs,
}
if payload.get("messages"):
request_payload = {"messages": payload.get("messages", []), **params}
else:
request_payload = {
"prompt": payload.get("prompt"),
"images": payload.get("images", []),
**params,
}
response = requests.post(
url=f"{self.base_url}/api/generate/",
url=api_url,
headers={"Content-Type": "application/json"},
json={"prompt": prompt, **params},
json=request_payload,
stream=True,
timeout=self.timeout,
)
response.encoding = "utf-8"
if response.status_code != 200:
optional_detail = response.json().get("error")
raise ValueError(
f"Ollama call failed with status code {response.status_code}."
f" Details: {optional_detail}"
)
if response.status_code == 404:
raise OllamaEndpointNotFoundError(
"Ollama call failed with status code 404."
)
else:
optional_detail = response.json().get("error")
raise ValueError(
f"Ollama call failed with status code {response.status_code}."
f" Details: {optional_detail}"
)
return response.iter_lines(decode_unicode=True)
def _stream_with_aggregation(
@ -181,7 +215,7 @@ class _OllamaCommon(BaseLanguageModel):
**kwargs: Any,
) -> GenerationChunk:
final_chunk: Optional[GenerationChunk] = None
for stream_resp in self._create_stream(prompt, stop, **kwargs):
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_generation_chunk(stream_resp)
if final_chunk is None:
@ -225,6 +259,7 @@ class Ollama(BaseLLM, _OllamaCommon):
self,
prompts: List[str],
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@ -248,6 +283,7 @@ class Ollama(BaseLLM, _OllamaCommon):
final_chunk = super()._stream_with_aggregation(
prompt,
stop=stop,
images=images,
run_manager=run_manager,
verbose=self.verbose,
**kwargs,