Add image support for Ollama (#14713)

Support [LLaVA](https://ollama.ai/library/llava):
* Upgrade Ollama
* `ollama pull llava`

Ensure compatibility with [image prompt
template](https://github.com/langchain-ai/langchain/pull/14263)

---------

Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
This commit is contained in:
Lance Martin
2023-12-15 16:00:55 -08:00
committed by GitHub
parent 1075e7d6e8
commit 42421860bc
4 changed files with 391 additions and 595 deletions

View File

@@ -20,6 +20,10 @@ def _stream_response_to_generation_chunk(
)
class OllamaEndpointNotFoundError(Exception):
"""Raised when the Ollama endpoint is not found."""
class _OllamaCommon(BaseLanguageModel):
base_url: str = "http://localhost:11434"
"""Base url the model is hosted under."""
@@ -129,10 +133,26 @@ class _OllamaCommon(BaseLanguageModel):
"""Get the identifying parameters."""
return {**{"model": self.model, "format": self.format}, **self._default_params}
def _create_stream(
def _create_generate_stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
payload = {"prompt": prompt, "images": images}
yield from self._create_stream(
payload=payload,
stop=stop,
api_url=f"{self.base_url}/api/generate/",
**kwargs,
)
def _create_stream(
self,
api_url: str,
payload: Any,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
if self.stop is not None and stop is not None:
@@ -156,20 +176,34 @@ class _OllamaCommon(BaseLanguageModel):
**kwargs,
}
if payload.get("messages"):
request_payload = {"messages": payload.get("messages", []), **params}
else:
request_payload = {
"prompt": payload.get("prompt"),
"images": payload.get("images", []),
**params,
}
response = requests.post(
url=f"{self.base_url}/api/generate/",
url=api_url,
headers={"Content-Type": "application/json"},
json={"prompt": prompt, **params},
json=request_payload,
stream=True,
timeout=self.timeout,
)
response.encoding = "utf-8"
if response.status_code != 200:
optional_detail = response.json().get("error")
raise ValueError(
f"Ollama call failed with status code {response.status_code}."
f" Details: {optional_detail}"
)
if response.status_code == 404:
raise OllamaEndpointNotFoundError(
"Ollama call failed with status code 404."
)
else:
optional_detail = response.json().get("error")
raise ValueError(
f"Ollama call failed with status code {response.status_code}."
f" Details: {optional_detail}"
)
return response.iter_lines(decode_unicode=True)
def _stream_with_aggregation(
@@ -181,7 +215,7 @@ class _OllamaCommon(BaseLanguageModel):
**kwargs: Any,
) -> GenerationChunk:
final_chunk: Optional[GenerationChunk] = None
for stream_resp in self._create_stream(prompt, stop, **kwargs):
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_generation_chunk(stream_resp)
if final_chunk is None:
@@ -225,6 +259,7 @@ class Ollama(BaseLLM, _OllamaCommon):
self,
prompts: List[str],
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@@ -248,6 +283,7 @@ class Ollama(BaseLLM, _OllamaCommon):
final_chunk = super()._stream_with_aggregation(
prompt,
stop=stop,
images=images,
run_manager=run_manager,
verbose=self.verbose,
**kwargs,