mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-15 17:33:53 +00:00)
ollama[patch]: fix model validation, ensure per-call reasoning can be set, tests (#31927)
* update model validation due to a change in the [Ollama client](https://github.com/ollama/ollama) - ensure you are running the latest version (0.9.6) to use `validate_model_on_init`
* add code example and fix formatting for ChatOllama reasoning
* ensure that setting `reasoning` in invocation kwargs overrides the class-level setting
* tests
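A rough usage sketch of the two behaviors above (assuming an Ollama server running locally with the model pulled, and Ollama client 0.9.6+):

```python
from langchain_ollama import ChatOllama

# Validate the model name against the local Ollama server at construction time.
llm = ChatOllama(model="deepseek-r1:8b", validate_model_on_init=True)

# Class-level default (reasoning=None): nothing is captured separately.
print(llm.invoke("What is 3^3?").content)

# Per-call override: the invocation kwarg now takes precedence over the
# class-level setting, so thinking lands in additional_kwargs["reasoning_content"].
msg = llm.invoke("What is 3^3?", reasoning=True)
print(msg.additional_kwargs.get("reasoning_content"))
```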
This commit is contained in:
parent f33a25773e
commit 0002b1dafa
@@ -16,7 +16,9 @@ def validate_model(client: Client, model_name: str) -> None:
    """
    try:
        response = client.list()
-        model_names: list[str] = [model["name"] for model in response["models"]]
+
+        model_names: list[str] = [model["model"] for model in response["models"]]
+
        if not any(
            model_name == m or m.startswith(f"{model_name}:") for m in model_names
        ):
@@ -27,10 +29,7 @@ def validate_model(client: Client, model_name: str) -> None:
            )
            raise ValueError(msg)
    except ConnectError as e:
-        msg = (
-            "Connection to Ollama failed. Please make sure Ollama is running "
-            f"and accessible at {client._client.base_url}. "
-        )
+        msg = "Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download"  # noqa: E501
        raise ValueError(msg) from e
    except ResponseError as e:
        msg = (
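For reference, the name check above accepts either an exact model name or any tag of that model; a standalone sketch of that logic (the `matches` helper and the model names are purely illustrative):

```python
def matches(model_name: str, available: list[str]) -> bool:
    # Exact match ("llama3.1") or any tag-qualified variant ("llama3.1:8b").
    return any(model_name == m or m.startswith(f"{model_name}:") for m in available)


assert matches("llama3.1", ["llama3.1:8b", "deepseek-r1:1.5b"])
assert not matches("mistral", ["llama3.1:8b"])
```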
@@ -217,17 +217,17 @@ class ChatOllama(BaseChatModel):
        `supported models <https://ollama.com/search?c=thinking>`__.

        - ``True``: Enables reasoning mode. The model's reasoning process will be
-        captured and returned separately in the ``additional_kwargs`` of the
-        response message, under ``reasoning_content``. The main response
-        content will not include the reasoning tags.
+          captured and returned separately in the ``additional_kwargs`` of the
+          response message, under ``reasoning_content``. The main response
+          content will not include the reasoning tags.
        - ``False``: Disables reasoning mode. The model will not perform any reasoning,
-        and the response will not include any reasoning content.
+          and the response will not include any reasoning content.
        - ``None`` (Default): The model will use its default reasoning behavior. Note
-        however, if the model's default behavior *is* to perform reasoning, think tags
-        ()``<think>`` and ``</think>``) will be present within the main response content
-        unless you set ``reasoning`` to ``True``.
+          however, if the model's default behavior *is* to perform reasoning, think tags
+          (``<think>`` and ``</think>``) will be present within the main response content
+          unless you set ``reasoning`` to ``True``.
    temperature: float
-        Sampling temperature. Ranges from 0.0 to 1.0.
+        Sampling temperature. Ranges from ``0.0`` to ``1.0``.
    num_predict: Optional[int]
        Max number of tokens to generate.
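The three `reasoning` settings documented in the hunk above differ mainly in where the thinking ends up; a minimal sketch of the contrast (assuming a local Ollama server and a reasoning-capable model such as deepseek-r1:8b):

```python
from langchain_ollama import ChatOllama

prompt = "What is 3^3?"

# reasoning=True: thinking is captured under additional_kwargs["reasoning_content"].
msg = ChatOllama(model="deepseek-r1:8b", reasoning=True).invoke(prompt)
print("reasoning_content" in msg.additional_kwargs)  # True

# reasoning=False: no reasoning is performed or returned.
msg = ChatOllama(model="deepseek-r1:8b", reasoning=False).invoke(prompt)
print("reasoning_content" in msg.additional_kwargs)  # False

# reasoning=None (default): model default behavior; <think> tags may appear
# directly in msg.content.
msg = ChatOllama(model="deepseek-r1:8b").invoke(prompt)
```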
@@ -343,7 +343,6 @@ class ChatOllama(BaseChatModel):
            '{"location": "Pune, India", "time_of_day": "morning"}'

-
    Tool Calling:
        .. code-block:: python

            from langchain_ollama import ChatOllama
@@ -362,6 +361,48 @@ class ChatOllama(BaseChatModel):
            'args': {'a': 45, 'b': 67},
            'id': '420c3f3b-df10-4188-945f-eb3abdb40622',
            'type': 'tool_call'}]

+    Thinking / Reasoning:
+        You can enable reasoning mode for models that support it by setting
+        the ``reasoning`` parameter to ``True`` in either the constructor or
+        the ``invoke``/``stream`` methods. This will enable the model to think
+        through the problem and return the reasoning process separately in the
+        ``additional_kwargs`` of the response message, under ``reasoning_content``.
+
+        If ``reasoning`` is set to ``None``, the model will use its default reasoning
+        behavior, and any reasoning content will *not* be captured under the
+        ``reasoning_content`` key, but will be present within the main response content
+        as think tags (``<think>`` and ``</think>``).
+
+        .. note::
+            This feature is only available for `models that support reasoning <https://ollama.com/search?c=thinking>`__.
+
+        .. code-block:: python
+
+            from langchain_core.messages import HumanMessage
+            from langchain_ollama import ChatOllama
+
+            llm = ChatOllama(
+                model="deepseek-r1:8b",
+                reasoning=True,
+            )
+
+            user_message = HumanMessage(content="how many r in the word strawberry?")
+            messages = [user_message]
+            llm.invoke(messages)
+
+            # or, on an invocation basis:
+
+            llm.invoke(messages, reasoning=True)
+            # or llm.stream(messages, reasoning=True)
+
+            # If not provided, the invocation falls back to the ChatOllama
+            # reasoning param (None by default).
+
+        .. code-block:: python
+
+            AIMessage(content='The word "strawberry" contains **three \'r\' letters**. Here\'s a breakdown for clarity:\n\n- The spelling of "strawberry" has two parts ... be 3.\n\nTo be thorough, let\'s confirm with an online source or common knowledge.\n\nI can recall that "strawberry" has: s-t-r-a-w-b-e-r-r-y — yes, three r\'s.\n\nPerhaps it\'s misspelled by some, but standard is correct.\n\nSo I think the response should be 3.\n'}, response_metadata={'model': 'deepseek-r1:8b', 'created_at': '2025-07-08T19:33:55.891269Z', 'done': True, 'done_reason': 'stop', 'total_duration': 98232561292, 'load_duration': 28036792, 'prompt_eval_count': 10, 'prompt_eval_duration': 40171834, 'eval_count': 3615, 'eval_duration': 98163832416, 'model_name': 'deepseek-r1:8b'}, id='run--18f8269f-6a35-4a7c-826d-b89d52c753b3-0', usage_metadata={'input_tokens': 10, 'output_tokens': 3615, 'total_tokens': 3625})
+
    """  # noqa: E501, pylint: disable=line-too-long

    model: str
@@ -777,6 +818,7 @@ class ChatOllama(BaseChatModel):
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
+        reasoning = kwargs.get("reasoning", self.reasoning)
        for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
            if not isinstance(stream_resp, str):
                if stream_resp.get("done") is True:
@@ -795,7 +837,7 @@ class ChatOllama(BaseChatModel):

                additional_kwargs = {}
                if (
-                    self.reasoning
+                    reasoning
                    and "message" in stream_resp
                    and (thinking_content := stream_resp["message"].get("thinking"))
                ):
@@ -836,6 +878,7 @@ class ChatOllama(BaseChatModel):
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
+        reasoning = kwargs.get("reasoning", self.reasoning)
        async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
            if not isinstance(stream_resp, str):
                if stream_resp.get("done") is True:
@@ -854,7 +897,7 @@ class ChatOllama(BaseChatModel):

                additional_kwargs = {}
                if (
-                    self.reasoning
+                    reasoning
                    and "message" in stream_resp
                    and (thinking_content := stream_resp["message"].get("thinking"))
                ):
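A hedged sketch of what this per-call override enables on the chat streaming path (assumes a local reasoning-capable model; the chunk handling follows the `reasoning_content` convention documented above):

```python
from langchain_ollama import ChatOllama

llm = ChatOllama(model="deepseek-r1:8b")  # class-level reasoning left at None

thinking_parts: list[str] = []
answer_parts: list[str] = []
for chunk in llm.stream("What is 3^3?", reasoning=True):
    # With the change above, the per-call kwarg wins over self.reasoning.
    if "reasoning_content" in chunk.additional_kwargs:
        thinking_parts.append(chunk.additional_kwargs["reasoning_content"])
    if chunk.content:
        answer_parts.append(chunk.content)

print("".join(thinking_parts))
print("".join(answer_parts))
```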
@@ -38,7 +38,7 @@ class OllamaLLM(BaseLLM):
    model: str
    """Model name to use."""

-    reasoning: Optional[bool] = True
+    reasoning: Optional[bool] = None
    """Controls the reasoning/thinking mode for
    `supported models <https://ollama.com/search?c=thinking>`__.
@@ -272,8 +272,11 @@ class OllamaLLM(BaseLLM):
        **kwargs: Any,
    ) -> GenerationChunk:
        final_chunk = None
+        thinking_content = ""
        async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
            if not isinstance(stream_resp, str):
+                if stream_resp.get("thinking"):
+                    thinking_content += stream_resp["thinking"]
                chunk = GenerationChunk(
                    text=stream_resp.get("response", ""),
                    generation_info=(
@@ -294,6 +297,12 @@ class OllamaLLM(BaseLLM):
            msg = "No data received from Ollama stream."
            raise ValueError(msg)
+
+        if thinking_content:
+            if final_chunk.generation_info:
+                final_chunk.generation_info["thinking"] = thinking_content
+            else:
+                final_chunk.generation_info = {"thinking": thinking_content}

        return final_chunk

    def _stream_with_aggregation(
@@ -305,8 +314,11 @@ class OllamaLLM(BaseLLM):
        **kwargs: Any,
    ) -> GenerationChunk:
        final_chunk = None
+        thinking_content = ""
        for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
            if not isinstance(stream_resp, str):
+                if stream_resp.get("thinking"):
+                    thinking_content += stream_resp["thinking"]
                chunk = GenerationChunk(
                    text=stream_resp.get("response", ""),
                    generation_info=(
@@ -327,6 +339,12 @@ class OllamaLLM(BaseLLM):
            msg = "No data received from Ollama stream."
            raise ValueError(msg)
+
+        if thinking_content:
+            if final_chunk.generation_info:
+                final_chunk.generation_info["thinking"] = thinking_content
+            else:
+                final_chunk.generation_info = {"thinking": thinking_content}

        return final_chunk

    def _generate(
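The aggregation above surfaces the accumulated thinking on the final chunk; a small sketch of how that looks from the caller's side (assuming `reasoning=True` and a reasoning-capable model pulled locally, consistent with the unit test at the end of this diff):

```python
from langchain_ollama import OllamaLLM

llm = OllamaLLM(model="deepseek-r1:1.5b", reasoning=True)
result = llm.generate(["What is 3^3?"])

generation = result.generations[0][0]
print(generation.text)  # the visible answer
# Aggregated reasoning, collected chunk by chunk as shown above.
print((generation.generation_info or {}).get("thinking"))
```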
@@ -374,10 +392,11 @@ class OllamaLLM(BaseLLM):
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
+        reasoning = kwargs.get("reasoning", self.reasoning)
        for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
            if not isinstance(stream_resp, str):
                additional_kwargs = {}
-                if thinking_content := stream_resp.get("thinking"):
+                if reasoning and (thinking_content := stream_resp.get("thinking")):
                    additional_kwargs["reasoning_content"] = thinking_content

                chunk = GenerationChunk(
@@ -404,10 +423,11 @@ class OllamaLLM(BaseLLM):
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
+        reasoning = kwargs.get("reasoning", self.reasoning)
        async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
            if not isinstance(stream_resp, str):
                additional_kwargs = {}
-                if thinking_content := stream_resp.get("thinking"):
+                if reasoning and (thinking_content := stream_resp.get("thinking")):
                    additional_kwargs["reasoning_content"] = thinking_content

                chunk = GenerationChunk(
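For the completion-style interface, the reasoning ends up in `generation_info` of the low-level chunks, while the public `stream()` still yields plain strings; a sketch mirroring the integration tests below, under the same model assumptions:

```python
from langchain_ollama import OllamaLLM

llm = OllamaLLM(model="deepseek-r1:1.5b")

# Public API: plain text tokens, with reasoning requested per call.
for token in llm.stream("What is 3^3?", reasoning=True):
    print(token, end="")

# Low-level API (what the tests exercise): GenerationChunk objects whose
# generation_info carries "reasoning_content" when reasoning is enabled.
for chunk in llm._stream("What is 3^3?", reasoning=True):
    info = chunk.generation_info or {}
    if "reasoning_content" in info:
        print(info["reasoning_content"], end="")
```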
@@ -1,18 +1,17 @@
"""Test OllamaLLM llm."""

import pytest
-from langchain_core.messages import AIMessageChunk, BaseMessageChunk
+from langchain_core.outputs import GenerationChunk
from langchain_core.runnables import RunnableConfig

from langchain_ollama.llms import OllamaLLM

MODEL_NAME = "llama3.1"

SAMPLE = "What is 3^3?"


-def test_stream() -> None:
-    """Test streaming tokens from OpenAI."""
+def test_stream_text_tokens() -> None:
+    """Test streaming raw string tokens from OllamaLLM."""
    llm = OllamaLLM(model=MODEL_NAME)

    for token in llm.stream("I'm Pickle Rick"):
@@ -20,60 +19,52 @@ def test_stream() -> None:


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-def test_stream_no_reasoning(model: str) -> None:
-    """Test streaming with `reasoning=False`"""
+def test__stream_no_reasoning(model: str) -> None:
+    """Test low-level chunk streaming of a simple prompt with `reasoning=False`."""
    llm = OllamaLLM(model=model, num_ctx=2**12)
-    messages = [
-        {
-            "role": "user",
-            "content": SAMPLE,
-        }
-    ]
-    result = None
-    for chunk in llm.stream(messages):
-        assert isinstance(chunk, BaseMessageChunk)
-        if result is None:
-            result = chunk
-            continue
-        result += chunk
-    assert isinstance(result, AIMessageChunk)
-    assert result.content
-    assert "reasoning_content" not in result.additional_kwargs
-
-    # Sanity check the old behavior isn't present
-    assert "<think>" not in result.content and "</think>" not in result.content
+    result_chunk = None
+    for chunk in llm._stream(SAMPLE):
+        # Should be a GenerationChunk
+        assert isinstance(chunk, GenerationChunk)
+        if result_chunk is None:
+            result_chunk = chunk
+        else:
+            result_chunk += chunk
+
+    # The final result must be a GenerationChunk with visible content
+    assert isinstance(result_chunk, GenerationChunk)
+    assert result_chunk.text
+    # No separate reasoning_content
+    assert "reasoning_content" not in result_chunk.generation_info  # type: ignore[operator]


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-def test_reasoning_stream(model: str) -> None:
-    """Test streaming with `reasoning=True`"""
+def test__stream_with_reasoning(model: str) -> None:
+    """Test low-level chunk streaming with `reasoning=True`."""
    llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
-    messages = [
-        {
-            "role": "user",
-            "content": SAMPLE,
-        }
-    ]
-    result = None
-    for chunk in llm.stream(messages):
-        assert isinstance(chunk, BaseMessageChunk)
-        if result is None:
-            result = chunk
-            continue
-        result += chunk
-    assert isinstance(result, AIMessageChunk)
-    assert result.content
-    assert "reasoning_content" in result.additional_kwargs
-    assert len(result.additional_kwargs["reasoning_content"]) > 0
-
-    # Sanity check the old behavior isn't present
-    assert "<think>" not in result.content and "</think>" not in result.content
-    assert "<think>" not in result.additional_kwargs["reasoning_content"]
-    assert "</think>" not in result.additional_kwargs["reasoning_content"]
+    result_chunk = None
+    for chunk in llm._stream(SAMPLE):
+        assert isinstance(chunk, GenerationChunk)
+        if result_chunk is None:
+            result_chunk = chunk
+        else:
+            result_chunk += chunk
+
+    assert isinstance(result_chunk, GenerationChunk)
+    assert result_chunk.text
+    # Should have extracted reasoning into generation_info
+    assert "reasoning_content" in result_chunk.generation_info  # type: ignore[operator]
+    assert len(result_chunk.generation_info["reasoning_content"]) > 0  # type: ignore[index]
+    # And neither the visible nor the hidden portion contains <think> tags
+    assert "<think>" not in result_chunk.text and "</think>" not in result_chunk.text
+    assert "<think>" not in result_chunk.generation_info["reasoning_content"]  # type: ignore[index]
+    assert "</think>" not in result_chunk.generation_info["reasoning_content"]  # type: ignore[index]


-async def test_astream() -> None:
-    """Test streaming tokens from OpenAI."""
+async def test_astream_text_tokens() -> None:
+    """Test async streaming raw string tokens from OllamaLLM."""
    llm = OllamaLLM(model=MODEL_NAME)

    async for token in llm.astream("I'm Pickle Rick"):
@@ -81,60 +72,44 @@ async def test_astream() -> None:


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-async def test_astream_no_reasoning(model: str) -> None:
-    """Test async streaming with `reasoning=False`"""
+async def test__astream_no_reasoning(model: str) -> None:
+    """Test low-level async chunk streaming with `reasoning=False`."""
    llm = OllamaLLM(model=model, num_ctx=2**12)
-    messages = [
-        {
-            "role": "user",
-            "content": SAMPLE,
-        }
-    ]
-    result = None
-    async for chunk in llm.astream(messages):
-        assert isinstance(chunk, BaseMessageChunk)
-        if result is None:
-            result = chunk
-            continue
-        result += chunk
-    assert isinstance(result, AIMessageChunk)
-    assert result.content
-    assert "reasoning_content" not in result.additional_kwargs
-
-    # Sanity check the old behavior isn't present
-    assert "<think>" not in result.content and "</think>" not in result.content
+    result_chunk = None
+    async for chunk in llm._astream(SAMPLE):
+        assert isinstance(chunk, GenerationChunk)
+        if result_chunk is None:
+            result_chunk = chunk
+        else:
+            result_chunk += chunk
+
+    assert isinstance(result_chunk, GenerationChunk)
+    assert result_chunk.text
+    assert "reasoning_content" not in result_chunk.generation_info  # type: ignore[operator]


@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-async def test_reasoning_astream(model: str) -> None:
-    """Test async streaming with `reasoning=True`"""
+async def test__astream_with_reasoning(model: str) -> None:
+    """Test low-level async chunk streaming with `reasoning=True`."""
    llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
-    messages = [
-        {
-            "role": "user",
-            "content": SAMPLE,
-        }
-    ]
-    result = None
-    async for chunk in llm.astream(messages):
-        assert isinstance(chunk, BaseMessageChunk)
-        if result is None:
-            result = chunk
-            continue
-        result += chunk
-    assert isinstance(result, AIMessageChunk)
-    assert result.content
-    assert "reasoning_content" in result.additional_kwargs
-    assert len(result.additional_kwargs["reasoning_content"]) > 0
-
-    # Sanity check the old behavior isn't present
-    assert "<think>" not in result.content and "</think>" not in result.content
-    assert "<think>" not in result.additional_kwargs["reasoning_content"]
-    assert "</think>" not in result.additional_kwargs["reasoning_content"]
+    result_chunk = None
+    async for chunk in llm._astream(SAMPLE):
+        assert isinstance(chunk, GenerationChunk)
+        if result_chunk is None:
+            result_chunk = chunk
+        else:
+            result_chunk += chunk
+
+    assert isinstance(result_chunk, GenerationChunk)
+    assert result_chunk.text
+    assert "reasoning_content" in result_chunk.generation_info  # type: ignore[operator]
+    assert len(result_chunk.generation_info["reasoning_content"]) > 0  # type: ignore[index]


async def test_abatch() -> None:
-    """Test streaming tokens from OllamaLLM."""
+    """Test async batch token generation from OllamaLLM."""
    llm = OllamaLLM(model=MODEL_NAME)

    result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@@ -143,7 +118,7 @@ async def test_abatch() -> None:


async def test_abatch_tags() -> None:
-    """Test batch tokens from OllamaLLM."""
+    """Test async batch token generation with tags."""
    llm = OllamaLLM(model=MODEL_NAME)

    result = await llm.abatch(
@@ -154,7 +129,7 @@ async def test_abatch_tags() -> None:


def test_batch() -> None:
-    """Test batch tokens from OllamaLLM."""
+    """Test batch token generation from OllamaLLM."""
    llm = OllamaLLM(model=MODEL_NAME)

    result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@@ -163,75 +138,15 @@ def test_batch() -> None:


async def test_ainvoke() -> None:
-    """Test invoke tokens from OllamaLLM."""
+    """Test async invoke returning a string."""
    llm = OllamaLLM(model=MODEL_NAME)

    result = await llm.ainvoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
    assert isinstance(result, str)
-
-
-# TODO
-# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-# async def test_ainvoke_no_reasoning(model: str) -> None:
-#     """Test using async invoke with `reasoning=False`"""
-#     llm = OllamaLLM(model=model, num_ctx=2**12)
-#     message = SAMPLE
-#     result = await llm.ainvoke(message)
-#     assert result.content
-#     assert "reasoning_content" not in result.additional_kwargs
-
-#     # Sanity check the old behavior isn't present
-#     assert "<think>" not in result.content and "</think>" not in result.content
-
-
-# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-# async def test_reasoning_ainvoke(model: str) -> None:
-#     """Test invoke with `reasoning=True`"""
-#     llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
-#     message = SAMPLE
-#     result = await llm.ainvoke(message)
-#     assert result.content
-#     assert "reasoning_content" in result.additional_kwargs
-#     assert len(result.additional_kwargs["reasoning_content"]) > 0
-
-#     # Sanity check the old behavior isn't present
-#     assert "<think>" not in result.content and "</think>" not in result.content
-#     assert "<think>" not in result.additional_kwargs["reasoning_content"]
-#     assert "</think>" not in result.additional_kwargs["reasoning_content"]


def test_invoke() -> None:
-    """Test invoke tokens from OllamaLLM."""
+    """Test sync invoke returning a string."""
    llm = OllamaLLM(model=MODEL_NAME)
    result = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
    assert isinstance(result, str)
-
-
-# TODO
-# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-# def test_invoke_no_reasoning(model: str) -> None:
-#     """Test using invoke with `reasoning=False`"""
-#     llm = OllamaLLM(model=model, num_ctx=2**12)
-#     message = SAMPLE
-#     result = llm.invoke(message)
-#     assert result.content
-#     assert "reasoning_content" not in result.additional_kwargs
-
-#     # Sanity check the old behavior isn't present
-#     assert "<think>" not in result.content and "</think>" not in result.content
-
-
-# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-# def test_reasoning_invoke(model: str) -> None:
-#     """Test invoke with `reasoning=True`"""
-#     llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
-#     message = SAMPLE
-#     result = llm.invoke(message)
-#     assert result.content
-#     assert "reasoning_content" in result.additional_kwargs
-#     assert len(result.additional_kwargs["reasoning_content"]) > 0
-
-#     # Sanity check the old behavior isn't present
-#     assert "<think>" not in result.content and "</think>" not in result.content
-#     assert "<think>" not in result.additional_kwargs["reasoning_content"]
-#     assert "</think>" not in result.additional_kwargs["reasoning_content"]
@@ -48,3 +48,24 @@ def test_validate_model_on_init(mock_validate_model: Any) -> None:
    # Test that validate_model is NOT called by default
    OllamaLLM(model=MODEL_NAME)
    mock_validate_model.assert_not_called()
+
+
+def test_reasoning_aggregation() -> None:
+    """Test that reasoning chunks are aggregated into final response."""
+    llm = OllamaLLM(model=MODEL_NAME, reasoning=True)
+    prompts = ["some prompt"]
+    mock_stream = [
+        {"thinking": "I am thinking.", "done": False},
+        {"thinking": " Still thinking.", "done": False},
+        {"response": "Final Answer.", "done": True},
+    ]
+
+    with patch.object(llm, "_create_generate_stream") as mock_stream_method:
+        mock_stream_method.return_value = iter(mock_stream)
+        result = llm.generate(prompts)
+
+    assert result.generations[0][0].generation_info is not None
+    assert (
+        result.generations[0][0].generation_info["thinking"]
+        == "I am thinking. Still thinking."
+    )