Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-04 12:39:32 +00:00)

fix ollama streaming
@@ -275,6 +275,13 @@ class ChatOllama(BaseChatModel):
     model: str
     """Model name to use."""
 
+    streaming: bool = False
+    """Whether to use streaming for invocation.
+
+    If True, invoke will use streaming internally.
+
+    """
+
     reasoning: Optional[bool] = None
     """Controls the reasoning/thinking mode for supported models.
 
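For orientation, a minimal usage sketch of the new flag (not part of the commit); the model name is a placeholder and a locally running Ollama server is assumed:

from langchain_ollama import ChatOllama

# Default after this change: invoke() issues a single non-streaming
# request to the Ollama chat endpoint.
llm = ChatOllama(model="llama3.1")
print(llm.invoke("Why is the sky blue?"))

# Opting in keeps the previous behaviour: invoke() streams internally
# and assembles the chunks into one complete message.
streaming_llm = ChatOllama(model="llama3.1", streaming=True)
print(streaming_llm.invoke("Why is the sky blue?"))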
@@ -525,6 +532,8 @@ class ChatOllama(BaseChatModel):
         self,
         messages: list[MessageV1],
         stop: Optional[list[str]] = None,
+        *,
+        stream: bool = True,
         **kwargs: Any,
     ) -> dict[str, Any]:
         """Build parameters for Ollama chat API."""
@@ -560,7 +569,7 @@ class ChatOllama(BaseChatModel):
 
         params = {
             "messages": ollama_messages,
-            "stream": kwargs.pop("stream", True),
+            "stream": kwargs.pop("stream", stream),
             "model": kwargs.pop("model", self.model),
             "think": kwargs.pop("reasoning", self.reasoning),
             "format": kwargs.pop("format", self.format),
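With the new keyword, the non-streaming invoke path can ask _chat_params for a single response instead of always defaulting the request's "stream" field to True. A minimal sketch of that default-threading, written as a hypothetical standalone helper (build_chat_params is illustrative, not the actual method):

from typing import Any


def build_chat_params(
    messages: list[dict[str, Any]],
    model: str,
    *,
    stream: bool = True,
    **kwargs: Any,
) -> dict[str, Any]:
    # Mirrors the kwargs.pop("stream", stream) change above: the explicit
    # keyword now supplies the default for the request's "stream" field.
    return {
        "messages": messages,
        "model": kwargs.pop("model", model),
        "stream": kwargs.pop("stream", stream),
        **kwargs,
    }


# The non-streaming invoke path can now request a single response ...
assert build_chat_params([], "llama3.1", stream=False)["stream"] is False
# ... while the streaming code paths keep the old default of True.
assert build_chat_params([], "llama3.1")["stream"] is True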
@@ -691,11 +700,16 @@ class ChatOllama(BaseChatModel):
             Complete AI message response.
 
         """
-        stream_iter = self._generate_stream(
-            messages, stop=stop, run_manager=run_manager, **kwargs
-        )
-        return generate_from_stream(stream_iter)
+        if self.streaming:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return generate_from_stream(stream_iter)
+
+        chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
+        response = self._client.chat(**chat_params)
+        return _convert_to_v1_from_ollama_format(response)
 
     async def _ainvoke(
         self,
         messages: list[MessageV1],
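The flag only changes how invoke() obtains its result; calling stream() still goes through _stream and streams regardless. A small sketch under the same assumptions (placeholder model name, local Ollama server):

from langchain_ollama import ChatOllama

llm = ChatOllama(model="llama3.1")  # streaming left at its False default

# invoke() now takes the direct, non-streaming request path ...
print(llm.invoke("Name three prime numbers."))

# ... while stream() is unchanged and still yields chunks as they arrive.
for chunk in llm.stream("Name three prime numbers."):
    print(chunk)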
@@ -715,11 +729,17 @@ class ChatOllama(BaseChatModel):
             Complete AI message response.
 
         """
-        stream_iter = self._agenerate_stream(
-            messages, stop=stop, run_manager=run_manager, **kwargs
-        )
-        return await agenerate_from_stream(stream_iter)
+        if self.streaming:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
+        # Non-streaming case: direct API call
+        chat_params = self._chat_params(messages, stop, stream=False, **kwargs)
+        response = await self._async_client.chat(**chat_params)
+        return _convert_to_v1_from_ollama_format(response)
 
     def _stream(
         self,
         messages: list[MessageV1],
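The async path mirrors the sync one, using _astream, the async client, and agenerate_from_stream. A usage sketch under the same assumptions (placeholder model name, local Ollama server):

import asyncio

from langchain_ollama import ChatOllama


async def main() -> None:
    # streaming left at its False default: ainvoke() makes one direct,
    # non-streaming request instead of stitching stream chunks together.
    llm = ChatOllama(model="llama3.1")
    print(await llm.ainvoke("Summarize the Ollama chat API in one sentence."))


asyncio.run(main())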
@@ -508,7 +508,7 @@ def test_load_response_with_empty_content_is_skipped(
         mock_client_class.return_value = mock_client
         mock_client.chat.return_value = iter(load_only_response)
 
-        llm = ChatOllama(model="test-model")
+        llm = ChatOllama(model="test-model", streaming=True)
 
         with (
             caplog.at_level(logging.WARNING),
@@ -539,7 +539,7 @@ def test_load_response_with_whitespace_content_is_skipped(
         mock_client_class.return_value = mock_client
         mock_client.chat.return_value = iter(load_whitespace_response)
 
-        llm = ChatOllama(model="test-model")
+        llm = ChatOllama(model="test-model", streaming=True)
 
         with (
             caplog.at_level(logging.WARNING),
@@ -579,7 +579,7 @@ def test_load_followed_by_content_response(
         mock_client_class.return_value = mock_client
         mock_client.chat.return_value = iter(load_then_content_response)
 
-        llm = ChatOllama(model="test-model")
+        llm = ChatOllama(model="test-model", streaming=True)
 
         with caplog.at_level(logging.WARNING):
             result = llm.invoke([HumanMessage("Hello")])
@@ -610,7 +610,7 @@ def test_load_response_with_actual_content_is_not_skipped(
         mock_client_class.return_value = mock_client
         mock_client.chat.return_value = iter(load_with_content_response)
 
-        llm = ChatOllama(model="test-model")
+        llm = ChatOllama(model="test-model", streaming=True)
 
         with caplog.at_level(logging.WARNING):
             result = llm.invoke([HumanMessage("Hello")])
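These tests feed the mocked client an iterator of chunk dicts, i.e. a streaming-style reply, so they now pass streaming=True to keep invoke() on the streaming path. For contrast, illustrative shapes of the two reply styles; the field names below are assumptions for illustration, not the exact test fixtures:

from typing import Any

# What the streaming path consumes: an iterable of chunk dicts.
streaming_reply: list[dict[str, Any]] = [
    {"message": {"role": "assistant", "content": "Hel"}, "done": False},
    {"message": {"role": "assistant", "content": "lo"}, "done": True},
]

# What the new non-streaming invoke() path expects: a single response dict.
non_streaming_reply: dict[str, Any] = {
    "message": {"role": "assistant", "content": "Hello"},
    "done": True,
}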