community: Async Ollama + ChatOllama (#15169)

**Description:**
Adding async methods to both the Ollama LLM and ChatOllama to enable async
streaming and async `.on_llm_new_token` callbacks.

**Issue:**
ChatOllama is not working in combination with an AsyncCallbackManager
because the .on_llm_new_token method is not awaited.
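
For reference, a minimal usage sketch of the scenario above (not part of this change). It assumes langchain-community with this patch installed, a local Ollama server, and a pulled model; the `llama2` model name and the `TokenPrinter` handler are illustrative only:

```python
import asyncio

from langchain_community.chat_models import ChatOllama
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.messages import HumanMessage


class TokenPrinter(AsyncCallbackHandler):
    """Async handler whose on_llm_new_token is now awaited by ChatOllama."""

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)


async def main() -> None:
    # "llama2" is only an example model name; use whatever is pulled locally.
    chat = ChatOllama(model="llama2", callbacks=[TokenPrinter()])
    # astream() now drives the native async generator added in this commit,
    # so the async callback above is awaited for every streamed token.
    async for _chunk in chat.astream([HumanMessage(content="Tell me a joke.")]):
        pass  # tokens are printed by the callback as they arrive


asyncio.run(main())
```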
Author: shroominic
Date: 2023-12-26 21:08:04 +01:00 (committed by GitHub)
Commit: e6f0cee896 (parent 3154c9bc9f)
2 changed files with 262 additions and 4 deletions

libs/community/langchain_community/chat_models/ollama.py

@@ -1,8 +1,9 @@
import json
-from typing import Any, Dict, Iterator, List, Optional, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
@@ -156,6 +157,20 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat/", **kwargs
)
async def _acreate_chat_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
payload = {
"messages": self._convert_messages_to_ollama_messages(messages),
}
async for stream_resp in self._acreate_stream(
payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat/", **kwargs
):
yield stream_resp
def _chat_stream_with_aggregation(
self,
messages: List[BaseMessage],
@@ -182,6 +197,32 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
return final_chunk
async def _achat_stream_with_aggregation(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
) -> ChatGenerationChunk:
final_chunk: Optional[ChatGenerationChunk] = None
async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
if stream_resp:
chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
if final_chunk is None:
final_chunk = chunk
else:
final_chunk += chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=verbose,
)
if final_chunk is None:
raise ValueError("No data received from Ollama stream.")
return final_chunk
def _generate(
self,
messages: List[BaseMessage],
@@ -219,6 +260,43 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
)
return ChatResult(generations=[chat_generation])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Call out to Ollama's generate endpoint.
Args:
messages: The list of base messages to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
Chat generations from the model
Example:
.. code-block:: python
response = ollama([
HumanMessage(content="Tell me about the history of AI")
])
"""
final_chunk = await self._achat_stream_with_aggregation(
messages,
stop=stop,
run_manager=run_manager,
verbose=self.verbose,
**kwargs,
)
chat_generation = ChatGeneration(
message=AIMessage(content=final_chunk.text),
generation_info=final_chunk.generation_info,
)
return ChatResult(generations=[chat_generation])
def _stream(
self,
messages: List[BaseMessage],
@@ -229,7 +307,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
try:
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
if stream_resp:
-chunk = _stream_response_to_chat_generation_chunk(stream_resp)
+chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
@@ -239,6 +317,29 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
except OllamaEndpointNotFoundError:
yield from self._legacy_stream(messages, stop, **kwargs)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
try:
async for stream_resp in self._acreate_chat_stream(
messages, stop, **kwargs
):
if stream_resp:
chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
except OllamaEndpointNotFoundError:
async for chunk in self._legacy_astream(messages, stop, **kwargs):
yield chunk
@deprecated("0.0.3", alternative="_stream")
def _legacy_stream(
self,

libs/community/langchain_community/llms/ollama.py

@@ -1,8 +1,12 @@
import json
-from typing import Any, Dict, Iterator, List, Mapping, Optional
+from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
import aiohttp
import requests
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import GenerationChunk, LLMResult
@@ -148,6 +152,22 @@ class _OllamaCommon(BaseLanguageModel):
**kwargs,
)
async def _acreate_generate_stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
payload = {"prompt": prompt, "images": images}
async for item in self._acreate_stream(
payload=payload,
stop=stop,
api_url=f"{self.base_url}/api/generate/",
**kwargs,
):
yield item
def _create_stream(
self,
api_url: str,
@@ -208,6 +228,64 @@ class _OllamaCommon(BaseLanguageModel):
)
return response.iter_lines(decode_unicode=True)
async def _acreate_stream(
self,
api_url: str,
payload: Any,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop
elif stop is None:
stop = []
params = self._default_params
if "model" in kwargs:
params["model"] = kwargs["model"]
if "options" in kwargs:
params["options"] = kwargs["options"]
else:
params["options"] = {
**params["options"],
"stop": stop,
**kwargs,
}
if payload.get("messages"):
request_payload = {"messages": payload.get("messages", []), **params}
else:
request_payload = {
"prompt": payload.get("prompt"),
"images": payload.get("images", []),
**params,
}
async with aiohttp.ClientSession() as session:
async with session.post(
url=api_url,
headers={"Content-Type": "application/json"},
json=request_payload,
timeout=self.timeout,
) as response:
if response.status != 200:
if response.status == 404:
raise OllamaEndpointNotFoundError(
"Ollama call failed with status code 404."
)
else:
optional_detail = (await response.json()).get("error")
raise ValueError(
f"Ollama call failed with status code {response.status}."
f" Details: {optional_detail}"
)
async for line in response.content:
yield line.decode("utf-8")
def _stream_with_aggregation(
self,
prompt: str,
@@ -234,6 +312,32 @@ class _OllamaCommon(BaseLanguageModel):
return final_chunk
async def _astream_with_aggregation(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
) -> GenerationChunk:
final_chunk: Optional[GenerationChunk] = None
async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_generation_chunk(stream_resp)
if final_chunk is None:
final_chunk = chunk
else:
final_chunk += chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=verbose,
)
if final_chunk is None:
raise ValueError("No data received from Ollama stream.")
return final_chunk
class Ollama(BaseLLM, _OllamaCommon):
"""Ollama locally runs large language models.
@@ -293,6 +397,42 @@ class Ollama(BaseLLM, _OllamaCommon):
generations.append([final_chunk])
return LLMResult(generations=generations)
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call out to Ollama's generate endpoint.
Args:
prompts: The prompts to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = ollama("Tell me a joke.")
"""
# TODO: add caching here.
generations = []
for prompt in prompts:
final_chunk = await super()._astream_with_aggregation(
prompt,
stop=stop,
images=images,
run_manager=run_manager,
verbose=self.verbose,
**kwargs,
)
generations.append([final_chunk])
return LLMResult(generations=generations)
def _stream(
self,
prompt: str,
@@ -309,3 +449,20 @@ class Ollama(BaseLLM, _OllamaCommon):
chunk.text,
verbose=self.verbose,
)
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
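
For completeness, a similarly hedged sketch of the LLM-side async streaming enabled above (again assuming a running local Ollama server with a pulled model; the `llama2` name is illustrative):

```python
import asyncio

from langchain_community.llms import Ollama


async def main() -> None:
    llm = Ollama(model="llama2")  # example model name; use your own
    # With this commit, astream() streams tokens over the new aiohttp-based
    # _acreate_stream path instead of blocking a thread on requests.
    async for token in llm.astream("Tell me a joke."):
        print(token, end="", flush=True)


asyncio.run(main())
```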