fix ollama streaming

Mason Daugherty
2025-08-06 13:32:41 -04:00
parent 4d261089c6
commit 8a3e049a9e
2 changed files with 33 additions and 13 deletions


@@ -275,6 +275,13 @@ class ChatOllama(BaseChatModel):
     model: str
     """Model name to use."""
 
+    streaming: bool = False
+    """Whether to use streaming for invocation.
+
+    If True, invoke will use streaming internally.
+    """
+
     reasoning: Optional[bool] = None
     """Controls the reasoning/thinking mode for supported models.
@@ -525,6 +532,8 @@ class ChatOllama(BaseChatModel):
         self,
         messages: list[MessageV1],
         stop: Optional[list[str]] = None,
+        *,
+        stream: bool = True,
         **kwargs: Any,
     ) -> dict[str, Any]:
         """Build parameters for Ollama chat API."""
@@ -560,7 +569,7 @@ class ChatOllama(BaseChatModel):
         params = {
             "messages": ollama_messages,
-            "stream": kwargs.pop("stream", True),
+            "stream": kwargs.pop("stream", stream),
             "model": kwargs.pop("model", self.model),
             "think": kwargs.pop("reasoning", self.reasoning),
             "format": kwargs.pop("format", self.format),
@@ -691,10 +700,15 @@ class ChatOllama(BaseChatModel):
             Complete AI message response.
         """
-        stream_iter = self._generate_stream(
-            messages, stop=stop, run_manager=run_manager, **kwargs
-        )
-        return generate_from_stream(stream_iter)
+        if self.streaming:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return generate_from_stream(stream_iter)
+
+        chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
+        response = self._client.chat(**chat_params)
+        return _convert_to_v1_from_ollama_format(response)
 
     async def _ainvoke(
         self,
@@ -715,10 +729,16 @@ class ChatOllama(BaseChatModel):
             Complete AI message response.
         """
-        stream_iter = self._agenerate_stream(
-            messages, stop=stop, run_manager=run_manager, **kwargs
-        )
-        return await agenerate_from_stream(stream_iter)
+        if self.streaming:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
+        # Non-streaming case: direct API call
+        chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
+        response = await self._async_client.chat(**chat_params)
+        return _convert_to_v1_from_ollama_format(response)
 
     def _stream(
         self,
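
In effect, when streaming is left off, the sync and async invoke paths above now make one non-streaming request through the Ollama client instead of iterating over chunks. A rough sketch of what that amounts to at the ollama-python level (the host, model, and prompt are assumptions for illustration, not part of this commit):

    from ollama import Client

    client = Client(host="http://localhost:11434")  # assumed default local Ollama endpoint
    response = client.chat(
        model="llama3.1",
        messages=[{"role": "user", "content": "Why is the sky blue?"}],
        stream=False,  # one complete response object instead of a chunk iterator
    )
    print(response["message"]["content"])
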


@@ -508,7 +508,7 @@ def test_load_response_with_empty_content_is_skipped(
     mock_client_class.return_value = mock_client
     mock_client.chat.return_value = iter(load_only_response)
 
-    llm = ChatOllama(model="test-model")
+    llm = ChatOllama(model="test-model", streaming=True)
 
     with (
         caplog.at_level(logging.WARNING),
@@ -539,7 +539,7 @@ def test_load_response_with_whitespace_content_is_skipped(
     mock_client_class.return_value = mock_client
     mock_client.chat.return_value = iter(load_whitespace_response)
 
-    llm = ChatOllama(model="test-model")
+    llm = ChatOllama(model="test-model", streaming=True)
 
     with (
         caplog.at_level(logging.WARNING),
@@ -579,7 +579,7 @@ def test_load_followed_by_content_response(
     mock_client_class.return_value = mock_client
     mock_client.chat.return_value = iter(load_then_content_response)
 
-    llm = ChatOllama(model="test-model")
+    llm = ChatOllama(model="test-model", streaming=True)
 
     with caplog.at_level(logging.WARNING):
         result = llm.invoke([HumanMessage("Hello")])
@@ -610,7 +610,7 @@ def test_load_response_with_actual_content_is_not_skipped(
     mock_client_class.return_value = mock_client
     mock_client.chat.return_value = iter(load_with_content_response)
 
-    llm = ChatOllama(model="test-model")
+    llm = ChatOllama(model="test-model", streaming=True)
 
     with caplog.at_level(logging.WARNING):
         result = llm.invoke([HumanMessage("Hello")])