Mirror of https://github.com/hwchase17/langchain.git
Synced 2025-07-18 02:33:19 +00:00
ollama[patch]: fix model validation, ensure per-call reasoning can be set, tests (#31927)
* Update model validation due to a change in the [Ollama client](https://github.com/ollama/ollama); ensure you are running the latest version (0.9.6) to use `validate_model_on_init`.
* Add a code example and fix formatting for `ChatOllama` reasoning.
* Ensure that setting `reasoning` in invocation kwargs overrides the class-level setting.
* Add tests.
This commit is contained in:
Parent: f33a25773e
Commit: 0002b1dafa
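As a quick illustration of the per-call override described in the commit message (a minimal sketch, not part of the diff; it assumes a local Ollama server with the reasoning-capable `deepseek-r1:8b` model pulled), a `reasoning` value passed at invocation time takes precedence over the value set on the class:

    from langchain_ollama import ChatOllama

    # Class-level setting: no separate reasoning content is captured.
    llm = ChatOllama(model="deepseek-r1:8b", reasoning=False)

    # Per-call override: this single call captures the reasoning trace
    # under additional_kwargs["reasoning_content"].
    msg = llm.invoke("How many r's are in 'strawberry'?", reasoning=True)
    print(msg.additional_kwargs.get("reasoning_content"))
    print(msg.content)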
@@ -16,7 +16,9 @@ def validate_model(client: Client, model_name: str) -> None:
     """
     try:
         response = client.list()
-        model_names: list[str] = [model["name"] for model in response["models"]]
+
+        model_names: list[str] = [model["model"] for model in response["models"]]
+
         if not any(
             model_name == m or m.startswith(f"{model_name}:") for m in model_names
         ):
@@ -27,10 +29,7 @@ def validate_model(client: Client, model_name: str) -> None:
             )
             raise ValueError(msg)
     except ConnectError as e:
-        msg = (
-            "Connection to Ollama failed. Please make sure Ollama is running "
-            f"and accessible at {client._client.base_url}. "
-        )
+        msg = "Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download"  # noqa: E501
         raise ValueError(msg) from e
     except ResponseError as e:
         msg = (
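For context on the hunks above: newer Ollama clients return each installed model under a `model` key rather than `name`. A minimal sketch of the lookup the updated helper performs (assuming the official `ollama` Python package and a local server on the default port):

    from ollama import Client

    client = Client()
    # Each entry now exposes the tag under "model", e.g. "llama3.1:latest".
    installed = [m["model"] for m in client.list()["models"]]
    print("llama3.1" in installed or any(m.startswith("llama3.1:") for m in installed))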
@@ -224,10 +224,10 @@ class ChatOllama(BaseChatModel):
           and the response will not include any reasoning content.
         - ``None`` (Default): The model will use its default reasoning behavior. Note
           however, if the model's default behavior *is* to perform reasoning, think tags
-          ()``<think>`` and ``</think>``) will be present within the main response content
+          (``<think>`` and ``</think>``) will be present within the main response content
           unless you set ``reasoning`` to ``True``.
     temperature: float
-        Sampling temperature. Ranges from 0.0 to 1.0.
+        Sampling temperature. Ranges from ``0.0`` to ``1.0``.
     num_predict: Optional[int]
         Max number of tokens to generate.
 
@@ -343,7 +343,6 @@ class ChatOllama(BaseChatModel):
            '{"location": "Pune, India", "time_of_day": "morning"}'
-
    Tool Calling:

        .. code-block:: python

            from langchain_ollama import ChatOllama
@@ -362,6 +361,48 @@ class ChatOllama(BaseChatModel):
             'args': {'a': 45, 'b': 67},
             'id': '420c3f3b-df10-4188-945f-eb3abdb40622',
             'type': 'tool_call'}]
+
+    Thinking / Reasoning:
+        You can enable reasoning mode for models that support it by setting
+        the ``reasoning`` parameter to ``True`` in either the constructor or
+        the ``invoke``/``stream`` methods. This will enable the model to think
+        through the problem and return the reasoning process separately in the
+        ``additional_kwargs`` of the response message, under ``reasoning_content``.
+
+        If ``reasoning`` is set to ``None``, the model will use its default reasoning
+        behavior, and any reasoning content will *not* be captured under the
+        ``reasoning_content`` key, but will be present within the main response content
+        as think tags (``<think>`` and ``</think>``).
+
+        .. note::
+            This feature is only available for `models that support reasoning <https://ollama.com/search?c=thinking>`__.
+
+        .. code-block:: python
+
+            from typing import Any, List
+
+            from langchain_core.messages import HumanMessage
+            from langchain_ollama import ChatOllama
+
+            llm = ChatOllama(
+                model="deepseek-r1:8b",
+                reasoning=True,
+            )
+
+            user_message = HumanMessage(content="how many r in the word strawberry?")
+            messages: List[Any] = [user_message]
+            llm.invoke(messages)
+
+            # or, on an invocation basis:
+
+            llm.invoke(messages, reasoning=True)
+            # or llm.stream(messages, reasoning=True)
+
+            # If not provided, the invocation will default to the ChatOllama reasoning
+            # param provided (None by default).
+
+        .. code-block:: python
+
+            AIMessage(content='The word "strawberry" contains **three \'r\' letters**. Here\'s a breakdown for clarity:\n\n- The spelling of "strawberry" has two parts ... be 3.\n\nTo be thorough, let\'s confirm with an online source or common knowledge.\n\nI can recall that "strawberry" has: s-t-r-a-w-b-e-r-r-y — yes, three r\'s.\n\nPerhaps it\'s misspelled by some, but standard is correct.\n\nSo I think the response should be 3.\n'}, response_metadata={'model': 'deepseek-r1:8b', 'created_at': '2025-07-08T19:33:55.891269Z', 'done': True, 'done_reason': 'stop', 'total_duration': 98232561292, 'load_duration': 28036792, 'prompt_eval_count': 10, 'prompt_eval_duration': 40171834, 'eval_count': 3615, 'eval_duration': 98163832416, 'model_name': 'deepseek-r1:8b'}, id='run--18f8269f-6a35-4a7c-826d-b89d52c753b3-0', usage_metadata={'input_tokens': 10, 'output_tokens': 3615, 'total_tokens': 3625})
+
+
     """ # noqa: E501, pylint: disable=line-too-long
 
     model: str
@@ -777,6 +818,7 @@ class ChatOllama(BaseChatModel):
         stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
+        reasoning = kwargs.get("reasoning", self.reasoning)
         for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
             if not isinstance(stream_resp, str):
                 if stream_resp.get("done") is True:
@@ -795,7 +837,7 @@ class ChatOllama(BaseChatModel):
 
                 additional_kwargs = {}
                 if (
-                    self.reasoning
+                    reasoning
                     and "message" in stream_resp
                     and (thinking_content := stream_resp["message"].get("thinking"))
                 ):
@@ -836,6 +878,7 @@ class ChatOllama(BaseChatModel):
         stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
+        reasoning = kwargs.get("reasoning", self.reasoning)
         async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
             if not isinstance(stream_resp, str):
                 if stream_resp.get("done") is True:
@@ -854,7 +897,7 @@ class ChatOllama(BaseChatModel):
 
                 additional_kwargs = {}
                 if (
-                    self.reasoning
+                    reasoning
                     and "message" in stream_resp
                     and (thinking_content := stream_resp["message"].get("thinking"))
                 ):
@@ -38,7 +38,7 @@ class OllamaLLM(BaseLLM):
     model: str
     """Model name to use."""
 
-    reasoning: Optional[bool] = True
+    reasoning: Optional[bool] = None
     """Controls the reasoning/thinking mode for
     `supported models <https://ollama.com/search?c=thinking>`__.
 
@@ -272,8 +272,11 @@ class OllamaLLM(BaseLLM):
         **kwargs: Any,
     ) -> GenerationChunk:
         final_chunk = None
+        thinking_content = ""
         async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
             if not isinstance(stream_resp, str):
+                if stream_resp.get("thinking"):
+                    thinking_content += stream_resp["thinking"]
                 chunk = GenerationChunk(
                     text=stream_resp.get("response", ""),
                     generation_info=(
@@ -294,6 +297,12 @@ class OllamaLLM(BaseLLM):
             msg = "No data received from Ollama stream."
             raise ValueError(msg)
 
+        if thinking_content:
+            if final_chunk.generation_info:
+                final_chunk.generation_info["thinking"] = thinking_content
+            else:
+                final_chunk.generation_info = {"thinking": thinking_content}
+
         return final_chunk
 
     def _stream_with_aggregation(
@@ -305,8 +314,11 @@ class OllamaLLM(BaseLLM):
         **kwargs: Any,
     ) -> GenerationChunk:
         final_chunk = None
+        thinking_content = ""
         for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
             if not isinstance(stream_resp, str):
+                if stream_resp.get("thinking"):
+                    thinking_content += stream_resp["thinking"]
                 chunk = GenerationChunk(
                     text=stream_resp.get("response", ""),
                     generation_info=(
@@ -327,6 +339,12 @@ class OllamaLLM(BaseLLM):
             msg = "No data received from Ollama stream."
             raise ValueError(msg)
 
+        if thinking_content:
+            if final_chunk.generation_info:
+                final_chunk.generation_info["thinking"] = thinking_content
+            else:
+                final_chunk.generation_info = {"thinking": thinking_content}
+
         return final_chunk
 
     def _generate(
@@ -374,10 +392,11 @@ class OllamaLLM(BaseLLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
+        reasoning = kwargs.get("reasoning", self.reasoning)
         for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
             if not isinstance(stream_resp, str):
                 additional_kwargs = {}
-                if thinking_content := stream_resp.get("thinking"):
+                if reasoning and (thinking_content := stream_resp.get("thinking")):
                     additional_kwargs["reasoning_content"] = thinking_content
 
                 chunk = GenerationChunk(
@@ -404,10 +423,11 @@ class OllamaLLM(BaseLLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[GenerationChunk]:
+        reasoning = kwargs.get("reasoning", self.reasoning)
         async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
             if not isinstance(stream_resp, str):
                 additional_kwargs = {}
-                if thinking_content := stream_resp.get("thinking"):
+                if reasoning and (thinking_content := stream_resp.get("thinking")):
                     additional_kwargs["reasoning_content"] = thinking_content
 
                 chunk = GenerationChunk(
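As a rough usage sketch of the `OllamaLLM` changes above (not part of the diff; it assumes a local Ollama server with a reasoning-capable model such as `deepseek-r1:1.5b` pulled), the aggregated reasoning trace ends up under the `thinking` key of `generation_info`, and a per-call `reasoning` kwarg overrides the class-level setting just as it does for `ChatOllama`:

    from langchain_ollama import OllamaLLM

    llm = OllamaLLM(model="deepseek-r1:1.5b", reasoning=True)

    result = llm.generate(["What is 3^3?"])
    generation = result.generations[0][0]

    print(generation.text)                               # the visible answer
    print(generation.generation_info.get("thinking"))    # the aggregated reasoning trace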
@@ -1,18 +1,17 @@
 """Test OllamaLLM llm."""
 
 import pytest
-from langchain_core.messages import AIMessageChunk, BaseMessageChunk
+from langchain_core.outputs import GenerationChunk
 from langchain_core.runnables import RunnableConfig
 
 from langchain_ollama.llms import OllamaLLM
 
 MODEL_NAME = "llama3.1"
-
 SAMPLE = "What is 3^3?"
 
 
-def test_stream() -> None:
-    """Test streaming tokens from OpenAI."""
+def test_stream_text_tokens() -> None:
+    """Test streaming raw string tokens from OllamaLLM."""
     llm = OllamaLLM(model=MODEL_NAME)
 
     for token in llm.stream("I'm Pickle Rick"):
@@ -20,60 +19,52 @@ def test_stream() -> None:
 
 
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-def test_stream_no_reasoning(model: str) -> None:
-    """Test streaming with `reasoning=False`"""
+def test__stream_no_reasoning(model: str) -> None:
+    """Test low-level chunk streaming of a simple prompt with `reasoning=False`."""
     llm = OllamaLLM(model=model, num_ctx=2**12)
-    messages = [
-        {
-            "role": "user",
-            "content": SAMPLE,
-        }
-    ]
-    result = None
-    for chunk in llm.stream(messages):
-        assert isinstance(chunk, BaseMessageChunk)
-        if result is None:
-            result = chunk
-            continue
-        result += chunk
-    assert isinstance(result, AIMessageChunk)
-    assert result.content
-    assert "reasoning_content" not in result.additional_kwargs
 
-    # Sanity check the old behavior isn't present
-    assert "<think>" not in result.content and "</think>" not in result.content
+    result_chunk = None
+    for chunk in llm._stream(SAMPLE):
+        # Should be a GenerationChunk
+        assert isinstance(chunk, GenerationChunk)
+        if result_chunk is None:
+            result_chunk = chunk
+        else:
+            result_chunk += chunk
+
+    # The final result must be a GenerationChunk with visible content
+    assert isinstance(result_chunk, GenerationChunk)
+    assert result_chunk.text
+    # No separate reasoning_content
+    assert "reasoning_content" not in result_chunk.generation_info  # type: ignore[operator]
 
 
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-def test_reasoning_stream(model: str) -> None:
-    """Test streaming with `reasoning=True`"""
+def test__stream_with_reasoning(model: str) -> None:
+    """Test low-level chunk streaming with `reasoning=True`."""
     llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
-    messages = [
-        {
-            "role": "user",
-            "content": SAMPLE,
-        }
-    ]
-    result = None
-    for chunk in llm.stream(messages):
-        assert isinstance(chunk, BaseMessageChunk)
-        if result is None:
-            result = chunk
-            continue
-        result += chunk
-    assert isinstance(result, AIMessageChunk)
-    assert result.content
-    assert "reasoning_content" in result.additional_kwargs
-    assert len(result.additional_kwargs["reasoning_content"]) > 0
 
-    # Sanity check the old behavior isn't present
-    assert "<think>" not in result.content and "</think>" not in result.content
-    assert "<think>" not in result.additional_kwargs["reasoning_content"]
-    assert "</think>" not in result.additional_kwargs["reasoning_content"]
+    result_chunk = None
+    for chunk in llm._stream(SAMPLE):
+        assert isinstance(chunk, GenerationChunk)
+        if result_chunk is None:
+            result_chunk = chunk
+        else:
+            result_chunk += chunk
+
+    assert isinstance(result_chunk, GenerationChunk)
+    assert result_chunk.text
+    # Should have extracted reasoning into generation_info
+    assert "reasoning_content" in result_chunk.generation_info  # type: ignore[operator]
+    assert len(result_chunk.generation_info["reasoning_content"]) > 0  # type: ignore[index]
+    # And neither the visible nor the hidden portion contains <think> tags
+    assert "<think>" not in result_chunk.text and "</think>" not in result_chunk.text
+    assert "<think>" not in result_chunk.generation_info["reasoning_content"]  # type: ignore[index]
+    assert "</think>" not in result_chunk.generation_info["reasoning_content"]  # type: ignore[index]
 
 
-async def test_astream() -> None:
-    """Test streaming tokens from OpenAI."""
+async def test_astream_text_tokens() -> None:
+    """Test async streaming raw string tokens from OllamaLLM."""
     llm = OllamaLLM(model=MODEL_NAME)
 
     async for token in llm.astream("I'm Pickle Rick"):
@@ -81,60 +72,44 @@ async def test_astream() -> None:
 
 
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-async def test_astream_no_reasoning(model: str) -> None:
-    """Test async streaming with `reasoning=False`"""
+async def test__astream_no_reasoning(model: str) -> None:
+    """Test low-level async chunk streaming with `reasoning=False`."""
     llm = OllamaLLM(model=model, num_ctx=2**12)
-    messages = [
-        {
-            "role": "user",
-            "content": SAMPLE,
-        }
-    ]
-    result = None
-    async for chunk in llm.astream(messages):
-        assert isinstance(chunk, BaseMessageChunk)
-        if result is None:
-            result = chunk
-            continue
-        result += chunk
-    assert isinstance(result, AIMessageChunk)
-    assert result.content
-    assert "reasoning_content" not in result.additional_kwargs
 
-    # Sanity check the old behavior isn't present
-    assert "<think>" not in result.content and "</think>" not in result.content
+    result_chunk = None
+    async for chunk in llm._astream(SAMPLE):
+        assert isinstance(chunk, GenerationChunk)
+        if result_chunk is None:
+            result_chunk = chunk
+        else:
+            result_chunk += chunk
+
+    assert isinstance(result_chunk, GenerationChunk)
+    assert result_chunk.text
+    assert "reasoning_content" not in result_chunk.generation_info  # type: ignore[operator]
 
 
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-async def test_reasoning_astream(model: str) -> None:
-    """Test async streaming with `reasoning=True`"""
+async def test__astream_with_reasoning(model: str) -> None:
+    """Test low-level async chunk streaming with `reasoning=True`."""
     llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
-    messages = [
-        {
-            "role": "user",
-            "content": SAMPLE,
-        }
-    ]
-    result = None
-    async for chunk in llm.astream(messages):
-        assert isinstance(chunk, BaseMessageChunk)
-        if result is None:
-            result = chunk
-            continue
-        result += chunk
-    assert isinstance(result, AIMessageChunk)
-    assert result.content
-    assert "reasoning_content" in result.additional_kwargs
-    assert len(result.additional_kwargs["reasoning_content"]) > 0
 
-    # Sanity check the old behavior isn't present
-    assert "<think>" not in result.content and "</think>" not in result.content
-    assert "<think>" not in result.additional_kwargs["reasoning_content"]
-    assert "</think>" not in result.additional_kwargs["reasoning_content"]
+    result_chunk = None
+    async for chunk in llm._astream(SAMPLE):
+        assert isinstance(chunk, GenerationChunk)
+        if result_chunk is None:
+            result_chunk = chunk
+        else:
+            result_chunk += chunk
+
+    assert isinstance(result_chunk, GenerationChunk)
+    assert result_chunk.text
+    assert "reasoning_content" in result_chunk.generation_info  # type: ignore[operator]
+    assert len(result_chunk.generation_info["reasoning_content"]) > 0  # type: ignore[index]
 
 
 async def test_abatch() -> None:
-    """Test streaming tokens from OllamaLLM."""
+    """Test async batch token generation from OllamaLLM."""
     llm = OllamaLLM(model=MODEL_NAME)
 
     result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@@ -143,7 +118,7 @@ async def test_abatch() -> None:
 
 
 async def test_abatch_tags() -> None:
-    """Test batch tokens from OllamaLLM."""
+    """Test async batch token generation with tags."""
     llm = OllamaLLM(model=MODEL_NAME)
 
     result = await llm.abatch(
@@ -154,7 +129,7 @@ async def test_abatch_tags() -> None:
 
 
 def test_batch() -> None:
-    """Test batch tokens from OllamaLLM."""
+    """Test batch token generation from OllamaLLM."""
     llm = OllamaLLM(model=MODEL_NAME)
 
     result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@@ -163,75 +138,15 @@
 
 
 async def test_ainvoke() -> None:
-    """Test invoke tokens from OllamaLLM."""
+    """Test async invoke returning a string."""
     llm = OllamaLLM(model=MODEL_NAME)
 
     result = await llm.ainvoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
     assert isinstance(result, str)
 
 
-# TODO
-# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-# async def test_ainvoke_no_reasoning(model: str) -> None:
-#     """Test using async invoke with `reasoning=False`"""
-#     llm = OllamaLLM(model=model, num_ctx=2**12)
-#     message = SAMPLE
-#     result = await llm.ainvoke(message)
-#     assert result.content
-#     assert "reasoning_content" not in result.additional_kwargs
-
-#     # Sanity check the old behavior isn't present
-#     assert "<think>" not in result.content and "</think>" not in result.content
-
-
-# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-# async def test_reasoning_ainvoke(model: str) -> None:
-#     """Test invoke with `reasoning=True`"""
-#     llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
-#     message = SAMPLE
-#     result = await llm.ainvoke(message)
-#     assert result.content
-#     assert "reasoning_content" in result.additional_kwargs
-#     assert len(result.additional_kwargs["reasoning_content"]) > 0
-
-#     # Sanity check the old behavior isn't present
-#     assert "<think>" not in result.content and "</think>" not in result.content
-#     assert "<think>" not in result.additional_kwargs["reasoning_content"]
-#     assert "</think>" not in result.additional_kwargs["reasoning_content"]
-
-
 def test_invoke() -> None:
-    """Test invoke tokens from OllamaLLM."""
+    """Test sync invoke returning a string."""
     llm = OllamaLLM(model=MODEL_NAME)
     result = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
     assert isinstance(result, str)
-
-
-# TODO
-# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-# def test_invoke_no_reasoning(model: str) -> None:
-#     """Test using invoke with `reasoning=False`"""
-#     llm = OllamaLLM(model=model, num_ctx=2**12)
-#     message = SAMPLE
-#     result = llm.invoke(message)
-#     assert result.content
-#     assert "reasoning_content" not in result.additional_kwargs
-
-#     # Sanity check the old behavior isn't present
-#     assert "<think>" not in result.content and "</think>" not in result.content
-
-
-# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-# def test_reasoning_invoke(model: str) -> None:
-#     """Test invoke with `reasoning=True`"""
-#     llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
-#     message = SAMPLE
-#     result = llm.invoke(message)
-#     assert result.content
-#     assert "reasoning_content" in result.additional_kwargs
-#     assert len(result.additional_kwargs["reasoning_content"]) > 0
-
-#     # Sanity check the old behavior isn't present
-#     assert "<think>" not in result.content and "</think>" not in result.content
-#     assert "<think>" not in result.additional_kwargs["reasoning_content"]
-#     assert "</think>" not in result.additional_kwargs["reasoning_content"]
@@ -48,3 +48,24 @@ def test_validate_model_on_init(mock_validate_model: Any) -> None:
     # Test that validate_model is NOT called by default
     OllamaLLM(model=MODEL_NAME)
     mock_validate_model.assert_not_called()
+
+
+def test_reasoning_aggregation() -> None:
+    """Test that reasoning chunks are aggregated into final response."""
+    llm = OllamaLLM(model=MODEL_NAME, reasoning=True)
+    prompts = ["some prompt"]
+    mock_stream = [
+        {"thinking": "I am thinking.", "done": False},
+        {"thinking": " Still thinking.", "done": False},
+        {"response": "Final Answer.", "done": True},
+    ]
+
+    with patch.object(llm, "_create_generate_stream") as mock_stream_method:
+        mock_stream_method.return_value = iter(mock_stream)
+        result = llm.generate(prompts)
+
+    assert result.generations[0][0].generation_info is not None
+    assert (
+        result.generations[0][0].generation_info["thinking"]
+        == "I am thinking. Still thinking."
+    )