Mason Daugherty 2025-07-25 17:01:17 -04:00
parent 10349b019a
commit fd5c29a268
3 changed files with 343 additions and 5 deletions


@@ -68,6 +68,29 @@ from langchain_ollama._utils import validate_model
log = logging.getLogger(__name__)
def _strip_think_tags(content: str) -> str:
"""Strip ``<think>`` tags from content.
This is needed because some models have reasoning/thinking as their default
behavior and will include ``<think>`` tags even when ``reasoning=False`` is set.
Since Ollama doesn't provide a way to completely disable thinking for models
that do it by default, we must post-process the response to remove the tags
when the user has explicitly disabled reasoning.
Args:
content: The content that may contain think tags.
Returns:
Content with think tags and their contents removed.
"""
import re
# Remove everything between <think> and </think> tags, including the tags
pattern = r"<think>.*?</think>"
return re.sub(pattern, "", content, flags=re.DOTALL).strip()
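# Editor's illustrative sketch (not part of the commit): a minimal standalone
# equivalent of the stripping behavior, assuming a plain string input.
#
#     import re
#     raw = "<think>step-by-step reasoning</think>\nThe answer is 4."
#     cleaned = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()
#     assert cleaned == "The answer is 4."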
def _get_usage_metadata_from_generation_info(
generation_info: Optional[Mapping[str, Any]],
) -> Optional[UsageMetadata]:
@@ -615,6 +638,72 @@ class ChatOllama(BaseChatModel):
The async client to use for making requests.
"""
def _chat_params_v1(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Generate chat parameters with native v1 message support.
This method uses the v1-native message conversion and is preferred for handling
v1 format messages.
Args:
messages: List of messages to convert.
stop: Optional stop sequences.
**kwargs: Additional parameters.
Returns:
Dictionary of parameters for Ollama API.
"""
# TODO: make this just part of `_chat_params`?
# Probably depends on the long-run decision and on message formatting.
ollama_messages = self._convert_messages_to_ollama_messages_v1(messages)
if self.stop is not None and stop is not None:
msg = "`stop` found in both the input and default params."
raise ValueError(msg)
if self.stop is not None:
stop = self.stop
options_dict = kwargs.pop(
"options",
{
"mirostat": self.mirostat,
"mirostat_eta": self.mirostat_eta,
"mirostat_tau": self.mirostat_tau,
"num_ctx": self.num_ctx,
"num_gpu": self.num_gpu,
"num_thread": self.num_thread,
"num_predict": self.num_predict,
"repeat_last_n": self.repeat_last_n,
"repeat_penalty": self.repeat_penalty,
"temperature": self.temperature,
"seed": self.seed,
"stop": self.stop if stop is None else stop,
"tfs_z": self.tfs_z,
"top_k": self.top_k,
"top_p": self.top_p,
},
)
params = {
"messages": ollama_messages,
"stream": kwargs.pop("stream", True),
"model": kwargs.pop("model", self.model),
"think": kwargs.pop("reasoning", self.reasoning),
"format": kwargs.pop("format", self.format),
"options": Options(**options_dict),
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
**kwargs,
}
if tools := kwargs.get("tools"):
params["tools"] = tools
return params
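# Editor's illustrative sketch (not part of the commit): rough shape of the
# dict returned by ``_chat_params_v1`` for a single user message with default
# settings (the model name below is hypothetical).
#
#     {
#         "messages": [{"role": "user", "content": "Hi", "images": []}],
#         "stream": True,
#         "model": "llama3.1",
#         "think": None,            # mirrors the `reasoning` setting
#         "format": None,
#         "options": Options(...),  # sampling/runtime options collected above
#         "keep_alive": None,
#     }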
def _chat_params(
self,
messages: list[BaseMessage],
@@ -666,6 +755,34 @@ class ChatOllama(BaseChatModel):
return params
def _get_chat_params(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Get chat parameters, choosing between v0 and v1 methods.
This method automatically chooses the appropriate parameter generation method
based on whether messages contain v1 format content.
Args:
messages: List of messages to convert.
stop: Optional stop sequences.
**kwargs: Additional parameters.
Returns:
Dictionary of parameters for Ollama API.
"""
# Check if any message has v1 format content (list of content blocks)
has_v1_messages = any(isinstance(msg.content, list) for msg in messages)
if has_v1_messages:
# Use v1-native method for better handling
return self._chat_params_v1(messages, stop, **kwargs)
# Use legacy v0 method for backward compatibility
return self._chat_params(messages, stop, **kwargs)
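# Editor's illustrative sketch (not part of the commit): the v0/v1 dispatch
# keys purely off list-typed message content.
#
#     from langchain_core.messages import HumanMessage
#     v0 = HumanMessage(content="plain string")                    # -> _chat_params
#     v1 = HumanMessage(content=[{"type": "text", "text": "hi"}])  # -> _chat_params_v1
#     any(isinstance(m.content, list) for m in [v0, v1])           # True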
@model_validator(mode="after")
def _set_clients(self) -> Self:
"""Set clients to use for ollama."""
@@ -685,6 +802,179 @@ class ChatOllama(BaseChatModel):
validate_model(self._client, self.model)
return self
def _convert_messages_to_ollama_messages_v1(
self, messages: list[BaseMessage]
) -> Sequence[Message]:
"""Convert messages to Ollama format with native v1 support.
This method handles v1 format messages natively, without first converting them
to v0, and is the preferred path for v1 message handling.
Args:
messages: List of messages to convert, may include v1 format.
Returns:
Sequence of Ollama Message objects.
"""
ollama_messages: list = []
for message in messages:
# Handle v1 format messages natively (don't convert to v0)
role: str
tool_call_id: Optional[str] = None
tool_calls: Optional[list[dict[str, Any]]] = None
if isinstance(message, HumanMessage):
role = "user"
elif isinstance(message, AIMessage):
role = "assistant"
tool_calls = (
[
_lc_tool_call_to_openai_tool_call(tool_call)
for tool_call in message.tool_calls
]
if message.tool_calls
else None
)
elif isinstance(message, SystemMessage):
role = "system"
elif isinstance(message, ChatMessage):
role = message.role
elif isinstance(message, ToolMessage):
role = "tool"
tool_call_id = message.tool_call_id
else:
msg = "Received unsupported message type for Ollama."
raise ValueError(msg)
content = ""
images = []
reasoning_content = None
# Handle v1 format content (list of content blocks)
if isinstance(message.content, list):
for content_part in message.content:
if isinstance(content_part, dict):
block_type = content_part.get("type")
if block_type == "text":
content += content_part.get("text", "")
elif block_type == "reasoning":
# Extract reasoning content for separate handling
reasoning_content = content_part.get("reasoning", "")
elif block_type == "tool_call":
# Skip - handled by tool_calls property
continue
elif block_type == "image_url":
image_url = None
temp_image_url = content_part.get("image_url")
if isinstance(temp_image_url, str):
image_url = temp_image_url
elif (
isinstance(temp_image_url, dict)
and "url" in temp_image_url
and isinstance(temp_image_url["url"], str)
):
image_url = temp_image_url["url"]
else:
msg = (
"Only string image_url or dict with string 'url' "
"inside content parts are supported."
)
raise ValueError(msg)
image_url_components = image_url.split(",")
# Support data:image/jpeg;base64,<image> format
# and base64 strings
if len(image_url_components) > 1:
images.append(image_url_components[1])
else:
images.append(image_url_components[0])
elif is_data_content_block(content_part):
image = _get_image_from_data_content_block(content_part)
images.append(image)
else:
# Convert unknown content blocks to NonStandardContentBlock
# TODO what to do with these?
_convert_unknown_content_block_to_non_standard(content_part)
continue
else:
# Handle content blocks that are not dicts
# (e.g., TextContentBlock objects)
if hasattr(content_part, "type"):
if content_part.type == "text":
content += getattr(content_part, "text", "")
elif content_part.type == "reasoning":
reasoning_content = getattr(
content_part, "reasoning", ""
)
# Add other content block types as needed
# Handle v0 format content (string)
elif isinstance(message.content, str):
content = message.content
else:
# Handle other content formats if needed
for content_part in cast(list[dict], message.content):
if content_part.get("type") == "text":
content += f"\n{content_part['text']}"
elif content_part.get("type") == "tool_use":
continue
elif content_part.get("type") == "tool_call":
# Skip - handled by tool_calls property
continue
elif content_part.get("type") == "reasoning":
# Skip - handled by reasoning parameter
continue
elif content_part.get("type") == "image_url":
image_url = None
temp_image_url = content_part.get("image_url")
if isinstance(temp_image_url, str):
image_url = temp_image_url
elif (
isinstance(temp_image_url, dict)
and "url" in temp_image_url
and isinstance(temp_image_url["url"], str)
):
image_url = temp_image_url["url"]
else:
msg = (
"Only string image_url or dict with string 'url' "
"inside content parts are supported."
)
raise ValueError(msg)
image_url_components = image_url.split(",")
# Support data:image/jpeg;base64,<image> format
# and base64 strings
if len(image_url_components) > 1:
images.append(image_url_components[1])
else:
images.append(image_url_components[0])
elif is_data_content_block(content_part):
image = _get_image_from_data_content_block(content_part)
images.append(image)
else:
# Convert unknown content blocks to NonStandardContentBlock
# TODO what to do with these?
_convert_unknown_content_block_to_non_standard(content_part)
continue
# Should convert to ollama.Message once role includes tool,
# and tool_call_id is in Message
msg_: dict = {
"role": role,
"content": content,
"images": images,
}
if tool_calls:
msg_["tool_calls"] = tool_calls
if tool_call_id:
msg_["tool_call_id"] = tool_call_id
# Store reasoning content for later use if present
if reasoning_content:
msg_["_reasoning_content"] = reasoning_content
ollama_messages.append(msg_)
return ollama_messages
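# Editor's illustrative sketch (not part of the commit): image blocks are
# reduced to their base64 payload, so a data URL is split on "," and only the
# payload is kept, while a bare base64 string passes through unchanged.
#
#     url = "data:image/png;base64,iVBORw0KGgo="
#     parts = url.split(",")
#     payload = parts[1] if len(parts) > 1 else parts[0]   # "iVBORw0KGgo="
#
# A converted entry then looks roughly like
# ``{"role": "user", "content": "Describe this image.", "images": [payload]}``.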
def _convert_messages_to_ollama_messages(
self, messages: list[BaseMessage]
) -> Sequence[Message]:
@@ -922,6 +1212,10 @@ class ChatOllama(BaseChatModel):
else ""
)
# Strip think tags if reasoning is explicitly disabled
if reasoning is False:
content = _strip_think_tags(content)
# Warn and skip responses with done_reason: 'load' and empty content
# These indicate the model was loaded but no actual generation occurred
is_load_response_with_empty_content = (
@@ -1003,6 +1297,10 @@ class ChatOllama(BaseChatModel):
else ""
)
# Strip think tags if reasoning is explicitly disabled
if reasoning is False:
content = _strip_think_tags(content)
# Warn and skip responses with done_reason: 'load' and empty content
# These indicate the model was loaded but no actual generation occurred
is_load_response_with_empty_content = (


@@ -23,7 +23,7 @@ class MathAnswer(BaseModel):
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_stream_no_reasoning(model: str) -> None:
"""Test streaming with `reasoning=False`"""
llm = ChatOllama(model=model, num_ctx=2**12)
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
messages = [
{
"role": "user",
@@ -46,7 +46,7 @@ def test_stream_no_reasoning(model: str) -> None:
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_astream_no_reasoning(model: str) -> None:
"""Test async streaming with `reasoning=False`"""
llm = ChatOllama(model=model, num_ctx=2**12)
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
messages = [
{
"role": "user",
@@ -175,7 +175,7 @@ async def test_reasoning_astream(model: str) -> None:
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_invoke_no_reasoning(model: str) -> None:
"""Test using invoke with `reasoning=False`"""
llm = ChatOllama(model=model, num_ctx=2**12)
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
message = HumanMessage(content=SAMPLE)
result = llm.invoke([message])
assert result.content
@@ -189,7 +189,7 @@ def test_invoke_no_reasoning(model: str) -> None:
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
async def test_ainvoke_no_reasoning(model: str) -> None:
"""Test using async invoke with `reasoning=False`"""
llm = ChatOllama(model=model, num_ctx=2**12)
llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
message = HumanMessage(content=SAMPLE)
result = await llm.ainvoke([message])
assert result.content
@@ -256,3 +256,43 @@ async def test_reasoning_ainvoke(model: str) -> None:
assert "<think>" not in result.content and "</think>" not in result.content
assert "<think>" not in result.additional_kwargs["reasoning_content"]
assert "</think>" not in result.additional_kwargs["reasoning_content"]
@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
def test_think_tag_stripping_necessity(model: str) -> None:
"""Test that demonstrates why ``_strip_think_tags`` is necessary.
DeepSeek R1 models include reasoning/thinking as their default behavior.
When ``reasoning=False`` is set, the user explicitly wants no reasoning content,
but Ollama cannot disable thinking at the API level for these models.
Therefore, post-processing is required to strip the ``<think>`` tags.
This test documents the specific behavior that necessitates the
``_strip_think_tags`` function in the chat_models.py implementation.
"""
# Test with reasoning=None (default behavior - should include think tags)
llm_default = ChatOllama(model=model, reasoning=None, num_ctx=2**12)
message = HumanMessage(content=SAMPLE)
result_default = llm_default.invoke([message])
# With reasoning=None, the model's default behavior includes <think> tags
# This demonstrates why we need the stripping logic
assert "<think>" in result_default.content
assert "</think>" in result_default.content
assert "reasoning_content" not in result_default.additional_kwargs
# Test with reasoning=False (explicit disable - should NOT include think tags)
llm_disabled = ChatOllama(model=model, reasoning=False, num_ctx=2**12)
result_disabled = llm_disabled.invoke([message])
# With reasoning=False, think tags should be stripped from content
# This verifies that _strip_think_tags is working correctly
assert "<think>" not in result_disabled.content
assert "</think>" not in result_disabled.content
assert "reasoning_content" not in result_disabled.additional_kwargs
# Verify the difference: same model, different reasoning settings
# Default includes tags, disabled strips them
assert result_default.content != result_disabled.content


@@ -168,7 +168,7 @@ class TestChatOllama(ChatModelIntegrationTests):
with pytest.raises(ValidationError) as excinfo:
ChatOllama(model="any-model", validate_model_on_init=True)
assert "not found in Ollama" in str(excinfo.value)
assert "Failed to connect to Ollama" in str(excinfo.value)
@patch("langchain_ollama.chat_models.Client.list")
def test_init_response_error(self, mock_list: MagicMock) -> None: