mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 10:23:18 +00:00
fix ollama streaming
This commit is contained in:
@@ -275,6 +275,13 @@ class ChatOllama(BaseChatModel):
|
||||
model: str
|
||||
"""Model name to use."""
|
||||
|
||||
streaming: bool = False
|
||||
"""Whether to use streaming for invocation.
|
||||
|
||||
If True, invoke will use streaming internally.
|
||||
|
||||
"""
|
||||
|
||||
reasoning: Optional[bool] = None
|
||||
"""Controls the reasoning/thinking mode for supported models.
|
||||
|
||||
@@ -525,6 +532,8 @@ class ChatOllama(BaseChatModel):
|
||||
self,
|
||||
messages: list[MessageV1],
|
||||
stop: Optional[list[str]] = None,
|
||||
*,
|
||||
stream: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Build parameters for Ollama chat API."""
|
||||
@@ -560,7 +569,7 @@ class ChatOllama(BaseChatModel):
|
||||
|
||||
params = {
|
||||
"messages": ollama_messages,
|
||||
"stream": kwargs.pop("stream", True),
|
||||
"stream": kwargs.pop("stream", stream),
|
||||
"model": kwargs.pop("model", self.model),
|
||||
"think": kwargs.pop("reasoning", self.reasoning),
|
||||
"format": kwargs.pop("format", self.format),
|
||||
@@ -691,10 +700,15 @@ class ChatOllama(BaseChatModel):
|
||||
Complete AI message response.
|
||||
|
||||
"""
|
||||
stream_iter = self._generate_stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
|
||||
response = self._client.chat(**chat_params)
|
||||
return _convert_to_v1_from_ollama_format(response)
|
||||
|
||||
async def _ainvoke(
|
||||
self,
|
||||
@@ -715,10 +729,16 @@ class ChatOllama(BaseChatModel):
|
||||
Complete AI message response.
|
||||
|
||||
"""
|
||||
stream_iter = self._agenerate_stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
if self.streaming:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
# Non-streaming case: direct API call
|
||||
chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
|
||||
response = await self._async_client.chat(**chat_params)
|
||||
return _convert_to_v1_from_ollama_format(response)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
|
@@ -508,7 +508,7 @@ def test_load_response_with_empty_content_is_skipped(
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.chat.return_value = iter(load_only_response)
|
||||
|
||||
llm = ChatOllama(model="test-model")
|
||||
llm = ChatOllama(model="test-model", streaming=True)
|
||||
|
||||
with (
|
||||
caplog.at_level(logging.WARNING),
|
||||
@@ -539,7 +539,7 @@ def test_load_response_with_whitespace_content_is_skipped(
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.chat.return_value = iter(load_whitespace_response)
|
||||
|
||||
llm = ChatOllama(model="test-model")
|
||||
llm = ChatOllama(model="test-model", streaming=True)
|
||||
|
||||
with (
|
||||
caplog.at_level(logging.WARNING),
|
||||
@@ -579,7 +579,7 @@ def test_load_followed_by_content_response(
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.chat.return_value = iter(load_then_content_response)
|
||||
|
||||
llm = ChatOllama(model="test-model")
|
||||
llm = ChatOllama(model="test-model", streaming=True)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
result = llm.invoke([HumanMessage("Hello")])
|
||||
@@ -610,7 +610,7 @@ def test_load_response_with_actual_content_is_not_skipped(
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.chat.return_value = iter(load_with_content_response)
|
||||
|
||||
llm = ChatOllama(model="test-model")
|
||||
llm = ChatOllama(model="test-model", streaming=True)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
result = llm.invoke([HumanMessage("Hello")])
|
||||
|
Reference in New Issue
Block a user