From e6f0cee896592b25523f00a726f731538f704802 Mon Sep 17 00:00:00 2001
From: shroominic <34897716+shroominic@users.noreply.github.com>
Date: Tue, 26 Dec 2023 21:08:04 +0100
Subject: [PATCH] community: Async Ollama + ChatOllama (#15169)

**Description:** Adding async methods to both the Ollama LLM and ChatOllama classes to enable async streaming and async .on_llm_new_token callbacks.

**Issue:** ChatOllama does not work in combination with an AsyncCallbackManager because the .on_llm_new_token method is never awaited.
---
 .../langchain_community/chat_models/ollama.py | 105 +++++++++++-
 .../langchain_community/llms/ollama.py        | 161 +++++++++++++++++-
 2 files changed, 262 insertions(+), 4 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/ollama.py b/libs/community/langchain_community/chat_models/ollama.py
index 54aa8a8c8cf..d9ad14a26e3 100644
--- a/libs/community/langchain_community/chat_models/ollama.py
+++ b/libs/community/langchain_community/chat_models/ollama.py
@@ -1,8 +1,9 @@
 import json
-from typing import Any, Dict, Iterator, List, Optional, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -156,6 +157,20 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
             payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat/", **kwargs
         )
 
+    async def _acreate_chat_stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        payload = {
+            "messages": self._convert_messages_to_ollama_messages(messages),
+        }
+        async for stream_resp in self._acreate_stream(
+            payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat/", **kwargs
+        ):
+            yield stream_resp
+
     def _chat_stream_with_aggregation(
         self,
         messages: List[BaseMessage],
@@ -182,6 +197,32 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
 
         return final_chunk
 
+    async def _achat_stream_with_aggregation(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        verbose: bool = False,
+        **kwargs: Any,
+    ) -> ChatGenerationChunk:
+        final_chunk: Optional[ChatGenerationChunk] = None
+        async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
+            if stream_resp:
+                chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
+                if final_chunk is None:
+                    final_chunk = chunk
+                else:
+                    final_chunk += chunk
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=verbose,
+                    )
+        if final_chunk is None:
+            raise ValueError("No data received from Ollama stream.")
+
+        return final_chunk
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -219,6 +260,43 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         )
         return ChatResult(generations=[chat_generation])
 
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Call out to Ollama's generate endpoint.
+
+        Args:
+            messages: The list of base messages to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            Chat generations from the model
+
+        Example:
+            .. code-block:: python
+
+                response = ollama([
+                    HumanMessage(content="Tell me about the history of AI")
+                ])
+        """
+
+        final_chunk = await self._achat_stream_with_aggregation(
+            messages,
+            stop=stop,
+            run_manager=run_manager,
+            verbose=self.verbose,
+            **kwargs,
+        )
+        chat_generation = ChatGeneration(
+            message=AIMessage(content=final_chunk.text),
+            generation_info=final_chunk.generation_info,
+        )
+        return ChatResult(generations=[chat_generation])
+
     def _stream(
         self,
         messages: List[BaseMessage],
@@ -229,7 +307,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         try:
             for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
                 if stream_resp:
-                    chunk = _stream_response_to_chat_generation_chunk(stream_resp)
+                    chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
                     yield chunk
                     if run_manager:
                         run_manager.on_llm_new_token(
@@ -239,6 +317,29 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         except OllamaEndpointNotFoundError:
             yield from self._legacy_stream(messages, stop, **kwargs)
 
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        try:
+            async for stream_resp in self._acreate_chat_stream(
+                messages, stop, **kwargs
+            ):
+                if stream_resp:
+                    chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
+                    yield chunk
+                    if run_manager:
+                        await run_manager.on_llm_new_token(
+                            chunk.text,
+                            verbose=self.verbose,
+                        )
+        except OllamaEndpointNotFoundError:
+            async for chunk in self._legacy_astream(messages, stop, **kwargs):
+                yield chunk
+
     @deprecated("0.0.3", alternative="_stream")
     def _legacy_stream(
         self,
diff --git a/libs/community/langchain_community/llms/ollama.py b/libs/community/langchain_community/llms/ollama.py
index 85e9d72d3ea..afe0aed5711 100644
--- a/libs/community/langchain_community/llms/ollama.py
+++ b/libs/community/langchain_community/llms/ollama.py
@@ -1,8 +1,12 @@
 import json
-from typing import Any, Dict, Iterator, List, Mapping, Optional
+from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
 
+import aiohttp
 import requests
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import GenerationChunk, LLMResult
@@ -148,6 +152,22 @@ class _OllamaCommon(BaseLanguageModel):
             **kwargs,
         )
 
+    async def _acreate_generate_stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        images: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        payload = {"prompt": prompt, "images": images}
+        async for item in self._acreate_stream(
+            payload=payload,
+            stop=stop,
+            api_url=f"{self.base_url}/api/generate/",
+            **kwargs,
+        ):
+            yield item
+
     def _create_stream(
         self,
         api_url: str,
@@ -208,6 +228,64 @@ class _OllamaCommon(BaseLanguageModel):
         )
         return response.iter_lines(decode_unicode=True)
 
+    async def _acreate_stream(
+        self,
+        api_url: str,
+        payload: Any,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        if self.stop is not None and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+        elif self.stop is not None:
+            stop = self.stop
+        elif stop is None:
+            stop = []
+
+        params = self._default_params
+
+        if "model" in kwargs:
+            params["model"] = kwargs["model"]
+
+        if "options" in kwargs:
+            params["options"] = kwargs["options"]
+        else:
+            params["options"] = {
+                **params["options"],
+                "stop": stop,
+                **kwargs,
+            }
+
+        if payload.get("messages"):
+            request_payload = {"messages": payload.get("messages", []), **params}
+        else:
+            request_payload = {
+                "prompt": payload.get("prompt"),
+                "images": payload.get("images", []),
+                **params,
+            }
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                url=api_url,
+                headers={"Content-Type": "application/json"},
+                json=request_payload,
+                timeout=self.timeout,
+            ) as response:
+                if response.status != 200:
+                    if response.status == 404:
+                        raise OllamaEndpointNotFoundError(
+                            "Ollama call failed with status code 404."
+                        )
+                    else:
+                        optional_detail = (await response.json()).get("error")
+                        raise ValueError(
+                            f"Ollama call failed with status code {response.status}."
+                            f" Details: {optional_detail}"
+                        )
+                async for line in response.content:
+                    yield line.decode("utf-8")
+
     def _stream_with_aggregation(
         self,
         prompt: str,
@@ -234,6 +312,32 @@ class _OllamaCommon(BaseLanguageModel):
 
         return final_chunk
 
+    async def _astream_with_aggregation(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        verbose: bool = False,
+        **kwargs: Any,
+    ) -> GenerationChunk:
+        final_chunk: Optional[GenerationChunk] = None
+        async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
+            if stream_resp:
+                chunk = _stream_response_to_generation_chunk(stream_resp)
+                if final_chunk is None:
+                    final_chunk = chunk
+                else:
+                    final_chunk += chunk
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=verbose,
+                    )
+        if final_chunk is None:
+            raise ValueError("No data received from Ollama stream.")
+
+        return final_chunk
+
 
 class Ollama(BaseLLM, _OllamaCommon):
     """Ollama locally runs large language models.
@@ -293,6 +397,42 @@ class Ollama(BaseLLM, _OllamaCommon):
             generations.append([final_chunk])
         return LLMResult(generations=generations)
 
+    async def _agenerate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        images: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Call out to Ollama's generate endpoint.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = ollama("Tell me a joke.")
+        """
+        # TODO: add caching here.
+        generations = []
+        for prompt in prompts:
+            final_chunk = await super()._astream_with_aggregation(
+                prompt,
+                stop=stop,
+                images=images,
+                run_manager=run_manager,
+                verbose=self.verbose,
+                **kwargs,
+            )
+            generations.append([final_chunk])
+        return LLMResult(generations=generations)
+
     def _stream(
         self,
         prompt: str,
@@ -309,3 +449,20 @@ class Ollama(BaseLLM, _OllamaCommon):
                         chunk.text,
                         verbose=self.verbose,
                     )
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
+            if stream_resp:
+                chunk = _stream_response_to_generation_chunk(stream_resp)
+                yield chunk
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=self.verbose,
+                    )
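
For reviewers, a minimal end-to-end sketch of the async chat path this diff adds (not part of the patch itself): an `AsyncCallbackHandler` whose `on_llm_new_token` coroutine is awaited while `ChatOllama` streams through the new `_astream`/`_acreate_chat_stream` code. The model name `llama2` and a locally running Ollama server on the default port are assumptions.

```python
import asyncio

from langchain_community.chat_models import ChatOllama
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.messages import HumanMessage


class TokenPrinter(AsyncCallbackHandler):
    """Async handler whose on_llm_new_token is now awaited by ChatOllama."""

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)


async def main() -> None:
    # Assumes a local Ollama server and a pulled "llama2" model.
    chat = ChatOllama(model="llama2", callbacks=[TokenPrinter()])
    # The public astream() API drives the new async streaming path.
    async for _chunk in chat.astream([HumanMessage(content="Tell me a joke.")]):
        pass  # tokens are printed by the callback as they arrive


asyncio.run(main())
```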
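Similarly, a small sketch (also not part of the patch) of the LLM side, exercising the new `Ollama._agenerate` and `Ollama._astream` through the public `agenerate`/`astream` APIs; the same local-server and model-name assumptions apply.

```python
import asyncio

from langchain_community.llms import Ollama


async def main() -> None:
    llm = Ollama(model="llama2")  # assumed model name

    # agenerate() aggregates the async stream into an LLMResult.
    result = await llm.agenerate(["Tell me a joke."])
    print(result.generations[0][0].text)

    # astream() yields text chunks as they arrive from the generate endpoint.
    async for chunk in llm.astream("Now tell me a longer one."):
        print(chunk, end="", flush=True)


asyncio.run(main())
```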
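For background on what `_acreate_stream` consumes: Ollama's streaming endpoints return newline-delimited JSON, one object per line, which is why the new code iterates `response.content` line by line. A standalone aiohttp sketch of that wire format, assuming the default server address `http://localhost:11434` and a pulled `llama2` model:

```python
import asyncio
import json

import aiohttp


async def stream_generate(prompt: str) -> None:
    payload = {"model": "llama2", "prompt": prompt}  # streaming is on by default
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://localhost:11434/api/generate", json=payload
        ) as response:
            # Each line is a JSON object; the partial text lives under "response".
            async for line in response.content:
                data = json.loads(line)
                print(data.get("response", ""), end="", flush=True)
                if data.get("done"):
                    break


asyncio.run(stream_generate("Why is the sky blue?"))
```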