Add image support for Ollama (#14713)

Support [LLaVA](https://ollama.ai/library/llava):
* Upgrade Ollama
* `ollama pull llava`

Ensure compatibility with [image prompt
template](https://github.com/langchain-ai/langchain/pull/14263)
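
Example usage, as a minimal sketch rather than part of the diff: it assumes a local Ollama server with `llava` pulled, and `photo.jpg` is a hypothetical file path. The chat path accepts a raw base64 string or a `data:image/...;base64,` URI as the `image_url` content part.

```python
import base64

from langchain_community.chat_models import ChatOllama
from langchain_core.messages import HumanMessage

chat = ChatOllama(model="llava")  # multimodal model from `ollama pull llava`

# Hypothetical local image; the new chat path accepts a base64 string or a
# data URI in the "image_url" content part.
with open("photo.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

message = HumanMessage(
    content=[
        {"type": "text", "text": "What is shown in this image?"},
        {"type": "image_url", "image_url": f"data:image/jpeg;base64,{image_b64}"},
    ]
)
print(chat.invoke([message]).content)
```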

---------

Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
Lance Martin, 2023-12-15 16:00:55 -08:00, committed by GitHub
commit 42421860bc (parent 1075e7d6e8)
4 changed files with 391 additions and 595 deletions

Diffs for two of the changed files are suppressed because one or more lines are too long.

File: `langchain_community/chat_models/ollama.py`

```diff
@@ -1,6 +1,7 @@
 import json
-from typing import Any, Iterator, List, Optional
+from typing import Any, Dict, Iterator, List, Optional, Union
 
+from langchain_core._api import deprecated
 from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
@@ -15,9 +16,10 @@ from langchain_core.messages import (
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 
-from langchain_community.llms.ollama import _OllamaCommon
+from langchain_community.llms.ollama import OllamaEndpointNotFoundError, _OllamaCommon
 
 
+@deprecated("0.0.3", alternative="_chat_stream_response_to_chat_generation_chunk")
 def _stream_response_to_chat_generation_chunk(
     stream_response: str,
 ) -> ChatGenerationChunk:
@@ -30,6 +32,20 @@ def _stream_response_to_chat_generation_chunk(
     )
 
 
+def _chat_stream_response_to_chat_generation_chunk(
+    stream_response: str,
+) -> ChatGenerationChunk:
+    """Convert a stream response to a generation chunk."""
+    parsed_response = json.loads(stream_response)
+    generation_info = parsed_response if parsed_response.get("done") is True else None
+    return ChatGenerationChunk(
+        message=AIMessageChunk(
+            content=parsed_response.get("message", {}).get("content", "")
+        ),
+        generation_info=generation_info,
+    )
+
+
 class ChatOllama(BaseChatModel, _OllamaCommon):
     """Ollama locally runs large language models.
```
```diff
@@ -52,11 +68,15 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         """Return whether this model can be serialized by Langchain."""
         return False
 
+    @deprecated("0.0.3", alternative="_convert_messages_to_ollama_messages")
     def _format_message_as_text(self, message: BaseMessage) -> str:
         if isinstance(message, ChatMessage):
             message_text = f"\n\n{message.role.capitalize()}: {message.content}"
         elif isinstance(message, HumanMessage):
-            message_text = f"[INST] {message.content} [/INST]"
+            if message.content[0].get("type") == "text":
+                message_text = f"[INST] {message.content[0]['text']} [/INST]"
+            elif message.content[0].get("type") == "image_url":
+                message_text = message.content[0]["image_url"]["url"]
         elif isinstance(message, AIMessage):
             message_text = f"{message.content}"
         elif isinstance(message, SystemMessage):
@@ -70,6 +90,98 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
             [self._format_message_as_text(message) for message in messages]
         )
 
+    def _convert_messages_to_ollama_messages(
+        self, messages: List[BaseMessage]
+    ) -> List[Dict[str, Union[str, List[str]]]]:
+        ollama_messages = []
+        for message in messages:
+            role = ""
+            if isinstance(message, HumanMessage):
+                role = "user"
+            elif isinstance(message, AIMessage):
+                role = "assistant"
+            elif isinstance(message, SystemMessage):
+                role = "system"
+            else:
+                raise ValueError("Received unsupported message type for Ollama.")
+
+            content = ""
+            images = []
+            if isinstance(message.content, str):
+                content = message.content
+            else:
+                for content_part in message.content:
+                    if content_part.get("type") == "text":
+                        content += f"\n{content_part['text']}"
+                    elif content_part.get("type") == "image_url":
+                        if isinstance(content_part.get("image_url"), str):
+                            image_url_components = content_part["image_url"].split(",")
+                            # Support data:image/jpeg;base64,<image> format
+                            # and base64 strings
+                            if len(image_url_components) > 1:
+                                images.append(image_url_components[1])
+                            else:
+                                images.append(image_url_components[0])
+                        else:
+                            raise ValueError(
+                                "Only string image_url content parts are supported."
+                            )
+                    else:
+                        raise ValueError(
+                            "Unsupported message content type. "
+                            "Must either have type 'text' or type 'image_url' "
+                            "with a string 'image_url' field."
+                        )
+
+            ollama_messages.append(
+                {
+                    "role": role,
+                    "content": content,
+                    "images": images,
+                }
+            )
+
+        return ollama_messages
+
+    def _create_chat_stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Iterator[str]:
+        payload = {
+            "messages": self._convert_messages_to_ollama_messages(messages),
+        }
+        yield from self._create_stream(
+            payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat/", **kwargs
+        )
+
+    def _chat_stream_with_aggregation(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        verbose: bool = False,
+        **kwargs: Any,
+    ) -> ChatGenerationChunk:
+        final_chunk: Optional[ChatGenerationChunk] = None
+        for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
+            if stream_resp:
+                chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
+                if final_chunk is None:
+                    final_chunk = chunk
+                else:
+                    final_chunk += chunk
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=verbose,
+                    )
+        if final_chunk is None:
+            raise ValueError("No data received from Ollama stream.")
+
+        return final_chunk
+
     def _generate(
         self,
         messages: List[BaseMessage],
```
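
To illustrate what `_convert_messages_to_ollama_messages` sends to the chat endpoint, a small sketch (placeholder base64, not output from the code): a text part lands in `content`, and a string `image_url` part has any `data:image/...;base64,` prefix split off so only the raw base64 payload goes into `images`.

```python
# Sketch of the conversion for one text + image_url human message.
image_url = "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEU"  # truncated placeholder
components = image_url.split(",")
image_b64 = components[1] if len(components) > 1 else components[0]

ollama_message = {
    "role": "user",
    "content": "\nWhat is shown in this image?",
    "images": [image_b64],
}
print(ollama_message)
```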
```diff
@@ -94,9 +206,12 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
                 ])
         """
 
-        prompt = self._format_messages_as_text(messages)
-        final_chunk = super()._stream_with_aggregation(
-            prompt, stop=stop, run_manager=run_manager, verbose=self.verbose, **kwargs
+        final_chunk = self._chat_stream_with_aggregation(
+            messages,
+            stop=stop,
+            run_manager=run_manager,
+            verbose=self.verbose,
+            **kwargs,
         )
         chat_generation = ChatGeneration(
             message=AIMessage(content=final_chunk.text),
@@ -111,8 +226,29 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        prompt = self._format_messages_as_text(messages)
-        for stream_resp in self._create_stream(prompt, stop, **kwargs):
+        try:
+            for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
+                if stream_resp:
+                    chunk = _stream_response_to_chat_generation_chunk(stream_resp)
+                    yield chunk
+                    if run_manager:
+                        run_manager.on_llm_new_token(
+                            chunk.text,
+                            verbose=self.verbose,
+                        )
+        except OllamaEndpointNotFoundError:
+            yield from self._legacy_stream(messages, stop, **kwargs)
+
+    @deprecated("0.0.3", alternative="_stream")
+    def _legacy_stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        prompt = self._format_messages_as_text(messages)
+        for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
             if stream_resp:
                 chunk = _stream_response_to_chat_generation_chunk(stream_resp)
                 yield chunk
```
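
`_stream` now targets `/api/chat/` and only drops to the deprecated prompt-based `_legacy_stream` when the server answers 404 (an Ollama build without the chat endpoint). A hedged streaming sketch, assuming a local server with `llava`:

```python
from langchain_community.chat_models import ChatOllama
from langchain_core.messages import HumanMessage

chat = ChatOllama(model="llava")
message = HumanMessage(content="Describe llamas in one sentence.")

# Tokens stream from /api/chat/; on an older Ollama the 404 is caught and the
# legacy prompt-based generate path is used instead.
for chunk in chat.stream([message]):
    print(chunk.content, end="", flush=True)
```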

File: `langchain_community/llms/ollama.py`

```diff
@@ -20,6 +20,10 @@ def _stream_response_to_generation_chunk(
     )
 
 
+class OllamaEndpointNotFoundError(Exception):
+    """Raised when the Ollama endpoint is not found."""
+
+
 class _OllamaCommon(BaseLanguageModel):
     base_url: str = "http://localhost:11434"
     """Base url the model is hosted under."""
@@ -129,10 +133,26 @@ class _OllamaCommon(BaseLanguageModel):
         """Get the identifying parameters."""
         return {**{"model": self.model, "format": self.format}, **self._default_params}
 
-    def _create_stream(
+    def _create_generate_stream(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
+        images: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Iterator[str]:
+        payload = {"prompt": prompt, "images": images}
+        yield from self._create_stream(
+            payload=payload,
+            stop=stop,
+            api_url=f"{self.base_url}/api/generate/",
+            **kwargs,
+        )
+
+    def _create_stream(
+        self,
+        api_url: str,
+        payload: Any,
+        stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Iterator[str]:
         if self.stop is not None and stop is not None:
@@ -156,15 +176,29 @@ class _OllamaCommon(BaseLanguageModel):
                 **kwargs,
             }
 
+        if payload.get("messages"):
+            request_payload = {"messages": payload.get("messages", []), **params}
+        else:
+            request_payload = {
+                "prompt": payload.get("prompt"),
+                "images": payload.get("images", []),
+                **params,
+            }
+
         response = requests.post(
-            url=f"{self.base_url}/api/generate/",
+            url=api_url,
             headers={"Content-Type": "application/json"},
-            json={"prompt": prompt, **params},
+            json=request_payload,
             stream=True,
             timeout=self.timeout,
         )
         response.encoding = "utf-8"
         if response.status_code != 200:
-            optional_detail = response.json().get("error")
-            raise ValueError(
-                f"Ollama call failed with status code {response.status_code}."
+            if response.status_code == 404:
+                raise OllamaEndpointNotFoundError(
+                    "Ollama call failed with status code 404."
+                )
+            else:
+                optional_detail = response.json().get("error")
+                raise ValueError(
+                    f"Ollama call failed with status code {response.status_code}."
```
```diff
@@ -181,7 +215,7 @@ class _OllamaCommon(BaseLanguageModel):
         **kwargs: Any,
     ) -> GenerationChunk:
         final_chunk: Optional[GenerationChunk] = None
-        for stream_resp in self._create_stream(prompt, stop, **kwargs):
+        for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
             if stream_resp:
                 chunk = _stream_response_to_generation_chunk(stream_resp)
                 if final_chunk is None:
@@ -225,6 +259,7 @@ class Ollama(BaseLLM, _OllamaCommon):
         self,
         prompts: List[str],
         stop: Optional[List[str]] = None,
+        images: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
@@ -248,6 +283,7 @@ class Ollama(BaseLLM, _OllamaCommon):
             final_chunk = super()._stream_with_aggregation(
                 prompt,
                 stop=stop,
+                images=images,
                 run_manager=run_manager,
                 verbose=self.verbose,
                 **kwargs,
```
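
For the completion-style `Ollama` LLM, the new `images` kwarg is threaded through `_generate` into `/api/generate`. A hedged sketch of passing it via `bind` (the placeholder base64 stands in for a real encoded image, and the exact invocation style may vary by LangChain version):

```python
from langchain_community.llms import Ollama

llm = Ollama(model="llava")
image_b64 = "iVBORw0KGgoAAAANSUhEU"  # placeholder; use a real base64-encoded image

# Bind the image so it is forwarded as the `images` kwarg on every call.
llm_with_image = llm.bind(images=[image_b64])
print(llm_with_image.invoke("Describe this image."))
```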