community: Async Ollama + ChatOllama (#15169)

**Description:**
Adds async methods to both the Ollama LLM and ChatOllama to enable async
streaming and async .on_llm_new_token callbacks.
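For context, a minimal sketch (not part of the diff) of the async streaming this enables on the LLM side; it assumes a local Ollama server with a pulled model, and the model name is a placeholder:

```python
import asyncio

from langchain_community.llms import Ollama


async def main() -> None:
    llm = Ollama(model="llama2")  # placeholder model name
    # With this change, .astream() is served by the new native async path
    # (_astream / _acreate_generate_stream) added in this commit.
    async for token in llm.astream("Tell me a joke."):
        print(token, end="", flush=True)


asyncio.run(main())
```

The same applies to ChatOllama through its astream/agenerate counterparts.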

**Issue:**
ChatOllama does not work with an AsyncCallbackManager because its
.on_llm_new_token coroutine is never awaited.
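A minimal sketch (not part of the diff) of the affected setup: an AsyncCallbackHandler's on_llm_new_token coroutine only runs if the model awaits it, which the new async paths below now do. Model name and prompt are placeholders.

```python
import asyncio
from typing import Any

from langchain_community.chat_models import ChatOllama
from langchain_core.callbacks import AsyncCallbackHandler


class TokenPrinter(AsyncCallbackHandler):
    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Per the issue above, this coroutine was not awaited before this
        # change, so async handlers never saw any tokens.
        print(token, end="", flush=True)


async def main() -> None:
    chat = ChatOllama(model="llama2", callbacks=[TokenPrinter()])
    await chat.ainvoke("Tell me about the history of AI")


asyncio.run(main())
```

With the awaited run_manager.on_llm_new_token calls added below, this handler receives tokens as they arrive.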
Authored by shroominic on 2023-12-26 21:08:04 +01:00, committed by GitHub
parent 3154c9bc9f
commit e6f0cee896
2 changed files with 262 additions and 4 deletions

Changed file 1 of 2:

@@ -1,8 +1,9 @@
import json
from typing import Any, Dict, Iterator, List, Optional, Union
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
@@ -156,6 +157,20 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat/", **kwargs
)
async def _acreate_chat_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
payload = {
"messages": self._convert_messages_to_ollama_messages(messages),
}
async for stream_resp in self._acreate_stream(
payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat/", **kwargs
):
yield stream_resp
def _chat_stream_with_aggregation(
self,
messages: List[BaseMessage],
@@ -182,6 +197,32 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
return final_chunk
async def _achat_stream_with_aggregation(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
) -> ChatGenerationChunk:
final_chunk: Optional[ChatGenerationChunk] = None
async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
if stream_resp:
chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
if final_chunk is None:
final_chunk = chunk
else:
final_chunk += chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=verbose,
)
if final_chunk is None:
raise ValueError("No data received from Ollama stream.")
return final_chunk
def _generate(
self,
messages: List[BaseMessage],
@@ -219,6 +260,43 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
)
return ChatResult(generations=[chat_generation])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Call out to Ollama's generate endpoint.
Args:
messages: The list of base messages to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
Chat generations from the model
Example:
.. code-block:: python
response = ollama([
HumanMessage(content="Tell me about the history of AI")
])
"""
final_chunk = await self._achat_stream_with_aggregation(
messages,
stop=stop,
run_manager=run_manager,
verbose=self.verbose,
**kwargs,
)
chat_generation = ChatGeneration(
message=AIMessage(content=final_chunk.text),
generation_info=final_chunk.generation_info,
)
return ChatResult(generations=[chat_generation])
def _stream(
self,
messages: List[BaseMessage],
@@ -229,7 +307,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
try:
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_chat_generation_chunk(stream_resp)
chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
@@ -239,6 +317,29 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
except OllamaEndpointNotFoundError:
yield from self._legacy_stream(messages, stop, **kwargs)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
try:
async for stream_resp in self._acreate_chat_stream(
messages, stop, **kwargs
):
if stream_resp:
chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
except OllamaEndpointNotFoundError:
async for chunk in self._legacy_astream(messages, stop, **kwargs):
yield chunk
@deprecated("0.0.3", alternative="_stream")
def _legacy_stream(
self,

Changed file 2 of 2:

@@ -1,8 +1,12 @@
import json
from typing import Any, Dict, Iterator, List, Mapping, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
import aiohttp
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import GenerationChunk, LLMResult
@@ -148,6 +152,22 @@ class _OllamaCommon(BaseLanguageModel):
**kwargs,
)
async def _acreate_generate_stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
payload = {"prompt": prompt, "images": images}
async for item in self._acreate_stream(
payload=payload,
stop=stop,
api_url=f"{self.base_url}/api/generate/",
**kwargs,
):
yield item
def _create_stream(
self,
api_url: str,
@@ -208,6 +228,64 @@ class _OllamaCommon(BaseLanguageModel):
)
return response.iter_lines(decode_unicode=True)
async def _acreate_stream(
self,
api_url: str,
payload: Any,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop
elif stop is None:
stop = []
params = self._default_params
if "model" in kwargs:
params["model"] = kwargs["model"]
if "options" in kwargs:
params["options"] = kwargs["options"]
else:
params["options"] = {
**params["options"],
"stop": stop,
**kwargs,
}
if payload.get("messages"):
request_payload = {"messages": payload.get("messages", []), **params}
else:
request_payload = {
"prompt": payload.get("prompt"),
"images": payload.get("images", []),
**params,
}
async with aiohttp.ClientSession() as session:
async with session.post(
url=api_url,
headers={"Content-Type": "application/json"},
json=request_payload,
timeout=self.timeout,
) as response:
if response.status != 200:
if response.status == 404:
raise OllamaEndpointNotFoundError(
"Ollama call failed with status code 404."
)
else:
optional_detail = (await response.json()).get("error")
raise ValueError(
f"Ollama call failed with status code {response.status}."
f" Details: {optional_detail}"
)
async for line in response.content:
yield line.decode("utf-8")
def _stream_with_aggregation(
self,
prompt: str,
@@ -234,6 +312,32 @@ class _OllamaCommon(BaseLanguageModel):
return final_chunk
async def _astream_with_aggregation(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
) -> GenerationChunk:
final_chunk: Optional[GenerationChunk] = None
async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_generation_chunk(stream_resp)
if final_chunk is None:
final_chunk = chunk
else:
final_chunk += chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=verbose,
)
if final_chunk is None:
raise ValueError("No data received from Ollama stream.")
return final_chunk
class Ollama(BaseLLM, _OllamaCommon):
"""Ollama locally runs large language models.
@@ -293,6 +397,42 @@ class Ollama(BaseLLM, _OllamaCommon):
generations.append([final_chunk])
return LLMResult(generations=generations)
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call out to Ollama's generate endpoint.
Args:
prompts: The prompts to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The LLMResult generated by the model.
Example:
.. code-block:: python
response = ollama("Tell me a joke.")
"""
# TODO: add caching here.
generations = []
for prompt in prompts:
final_chunk = await super()._astream_with_aggregation(
prompt,
stop=stop,
images=images,
run_manager=run_manager,
verbose=self.verbose,
**kwargs,
)
generations.append([final_chunk])
return LLMResult(generations=generations)
def _stream(
self,
prompt: str,
@@ -309,3 +449,20 @@ class Ollama(BaseLLM, _OllamaCommon):
chunk.text,
verbose=self.verbose,
)
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)