community: Async Ollama + ChatOllama (#15169)
**Description:** Adds async methods to both the Ollama LLM and ChatOllama to enable async streaming and async `.on_llm_new_token` callbacks. **Issue:** ChatOllama does not work with an AsyncCallbackManager because the `.on_llm_new_token` method is never awaited.
parent 3154c9bc9f
commit e6f0cee896
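
For reference, a minimal usage sketch of the async streaming this change enables (not part of the diff; the import paths assume the `langchain_community` package layout of this release, and the model tag and handler are illustrative):

```python
import asyncio

from langchain_community.chat_models import ChatOllama
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.messages import HumanMessage


class TokenPrinter(AsyncCallbackHandler):
    """Async handler whose on_llm_new_token coroutine must be awaited."""

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)


async def main() -> None:
    # "llama2" is an illustrative model tag; use whatever model is pulled locally.
    chat = ChatOllama(model="llama2", callbacks=[TokenPrinter()])
    async for _ in chat.astream([HumanMessage(content="Tell me a joke.")]):
        pass  # tokens are printed by the async callback as they stream in


asyncio.run(main())
```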
@@ -1,8 +1,9 @@
 import json
-from typing import Any, Dict, Iterator, List, Optional, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -156,6 +157,20 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
             payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat/", **kwargs
         )
 
+    async def _acreate_chat_stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        payload = {
+            "messages": self._convert_messages_to_ollama_messages(messages),
+        }
+        async for stream_resp in self._acreate_stream(
+            payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat/", **kwargs
+        ):
+            yield stream_resp
+
     def _chat_stream_with_aggregation(
         self,
         messages: List[BaseMessage],
@@ -182,6 +197,32 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
 
         return final_chunk
 
+    async def _achat_stream_with_aggregation(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        verbose: bool = False,
+        **kwargs: Any,
+    ) -> ChatGenerationChunk:
+        final_chunk: Optional[ChatGenerationChunk] = None
+        async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
+            if stream_resp:
+                chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
+                if final_chunk is None:
+                    final_chunk = chunk
+                else:
+                    final_chunk += chunk
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=verbose,
+                    )
+        if final_chunk is None:
+            raise ValueError("No data received from Ollama stream.")
+
+        return final_chunk
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -219,6 +260,43 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         )
         return ChatResult(generations=[chat_generation])
 
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Call out to Ollama's chat endpoint.
+
+        Args:
+            messages: The list of base messages to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            Chat generations from the model.
+
+        Example:
+            .. code-block:: python
+
+                response = ollama([
+                    HumanMessage(content="Tell me about the history of AI")
+                ])
+        """
+
+        final_chunk = await self._achat_stream_with_aggregation(
+            messages,
+            stop=stop,
+            run_manager=run_manager,
+            verbose=self.verbose,
+            **kwargs,
+        )
+        chat_generation = ChatGeneration(
+            message=AIMessage(content=final_chunk.text),
+            generation_info=final_chunk.generation_info,
+        )
+        return ChatResult(generations=[chat_generation])
+
     def _stream(
         self,
         messages: List[BaseMessage],
@@ -229,7 +307,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         try:
             for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
                 if stream_resp:
-                    chunk = _stream_response_to_chat_generation_chunk(stream_resp)
+                    chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
                     yield chunk
                     if run_manager:
                         run_manager.on_llm_new_token(
@@ -239,6 +317,29 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         except OllamaEndpointNotFoundError:
             yield from self._legacy_stream(messages, stop, **kwargs)
 
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        try:
+            async for stream_resp in self._acreate_chat_stream(
+                messages, stop, **kwargs
+            ):
+                if stream_resp:
+                    chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
+                    yield chunk
+                    if run_manager:
+                        await run_manager.on_llm_new_token(
+                            chunk.text,
+                            verbose=self.verbose,
+                        )
+        except OllamaEndpointNotFoundError:
+            async for chunk in self._legacy_astream(messages, stop, **kwargs):
+                yield chunk
+
     @deprecated("0.0.3", alternative="_stream")
     def _legacy_stream(
         self,
@@ -1,8 +1,12 @@
 import json
-from typing import Any, Dict, Iterator, List, Mapping, Optional
+from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
 
+import aiohttp
 import requests
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import GenerationChunk, LLMResult
@@ -148,6 +152,22 @@ class _OllamaCommon(BaseLanguageModel):
             **kwargs,
         )
 
+    async def _acreate_generate_stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        images: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        payload = {"prompt": prompt, "images": images}
+        async for item in self._acreate_stream(
+            payload=payload,
+            stop=stop,
+            api_url=f"{self.base_url}/api/generate/",
+            **kwargs,
+        ):
+            yield item
+
     def _create_stream(
         self,
         api_url: str,
@@ -208,6 +228,64 @@ class _OllamaCommon(BaseLanguageModel):
                 )
         return response.iter_lines(decode_unicode=True)
 
+    async def _acreate_stream(
+        self,
+        api_url: str,
+        payload: Any,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        if self.stop is not None and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+        elif self.stop is not None:
+            stop = self.stop
+        elif stop is None:
+            stop = []
+
+        params = self._default_params
+
+        if "model" in kwargs:
+            params["model"] = kwargs["model"]
+
+        if "options" in kwargs:
+            params["options"] = kwargs["options"]
+        else:
+            params["options"] = {
+                **params["options"],
+                "stop": stop,
+                **kwargs,
+            }
+
+        if payload.get("messages"):
+            request_payload = {"messages": payload.get("messages", []), **params}
+        else:
+            request_payload = {
+                "prompt": payload.get("prompt"),
+                "images": payload.get("images", []),
+                **params,
+            }
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                url=api_url,
+                headers={"Content-Type": "application/json"},
+                json=request_payload,
+                timeout=self.timeout,
+            ) as response:
+                if response.status != 200:
+                    if response.status == 404:
+                        raise OllamaEndpointNotFoundError(
+                            "Ollama call failed with status code 404."
+                        )
+                    else:
+                        optional_detail = (await response.json()).get("error")
+                        raise ValueError(
+                            f"Ollama call failed with status code {response.status}."
+                            f" Details: {optional_detail}"
+                        )
+                async for line in response.content:
+                    yield line.decode("utf-8")
+
     def _stream_with_aggregation(
         self,
         prompt: str,
@@ -234,6 +312,32 @@ class _OllamaCommon(BaseLanguageModel):
 
         return final_chunk
 
+    async def _astream_with_aggregation(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        verbose: bool = False,
+        **kwargs: Any,
+    ) -> GenerationChunk:
+        final_chunk: Optional[GenerationChunk] = None
+        async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
+            if stream_resp:
+                chunk = _stream_response_to_generation_chunk(stream_resp)
+                if final_chunk is None:
+                    final_chunk = chunk
+                else:
+                    final_chunk += chunk
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=verbose,
+                    )
+        if final_chunk is None:
+            raise ValueError("No data received from Ollama stream.")
+
+        return final_chunk
+
 
 class Ollama(BaseLLM, _OllamaCommon):
     """Ollama locally runs large language models.
@@ -293,6 +397,42 @@ class Ollama(BaseLLM, _OllamaCommon):
             generations.append([final_chunk])
         return LLMResult(generations=generations)
 
+    async def _agenerate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        images: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Call out to Ollama's generate endpoint.
+
+        Args:
+            prompts: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = ollama("Tell me a joke.")
+        """
+        # TODO: add caching here.
+        generations = []
+        for prompt in prompts:
+            final_chunk = await super()._astream_with_aggregation(
+                prompt,
+                stop=stop,
+                images=images,
+                run_manager=run_manager,
+                verbose=self.verbose,
+                **kwargs,
+            )
+            generations.append([final_chunk])
+        return LLMResult(generations=generations)
+
     def _stream(
         self,
         prompt: str,
@@ -309,3 +449,20 @@ class Ollama(BaseLLM, _OllamaCommon):
                         chunk.text,
                         verbose=self.verbose,
                     )
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
+            if stream_resp:
+                chunk = _stream_response_to_generation_chunk(stream_resp)
+                yield chunk
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=self.verbose,
+                    )
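
A corresponding sketch for the completion-style `Ollama` LLM, which exercises the new `_astream` / `_acreate_generate_stream` path (again not part of the diff; import path and model tag are assumptions):

```python
import asyncio

from langchain_community.llms import Ollama


async def main() -> None:
    llm = Ollama(model="llama2")  # illustrative model tag
    # BaseLLM.astream drives the async generator added in this commit.
    async for token in llm.astream("Tell me a joke."):
        print(token, end="", flush=True)


asyncio.run(main())
```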