Commit fd5c29a268 (parent 10349b019a), commit message: "more"
@@ -68,6 +68,29 @@ from langchain_ollama._utils import validate_model
 log = logging.getLogger(__name__)
 
 
+def _strip_think_tags(content: str) -> str:
+    """Strip ``<think>`` tags from content.
+
+    This is needed because some models have reasoning/thinking as their default
+    behavior and will include ``<think>`` tags even when ``reasoning=False`` is set.
+
+    Since Ollama doesn't provide a way to completely disable thinking for models
+    that do it by default, we must post-process the response to remove the tags
+    when the user has explicitly disabled reasoning.
+
+    Args:
+        content: The content that may contain think tags.
+
+    Returns:
+        Content with think tags and their contents removed.
+    """
+    import re
+
+    # Remove everything between <think> and </think> tags, including the tags
+    pattern = r"<think>.*?</think>"
+    return re.sub(pattern, "", content, flags=re.DOTALL).strip()
+
+
 def _get_usage_metadata_from_generation_info(
     generation_info: Optional[Mapping[str, Any]],
 ) -> Optional[UsageMetadata]:
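For reference, the stripping is a single non-greedy regex pass. Below is a minimal standalone sketch of the same behavior (it deliberately does not import the private helper; the function name here is illustrative only):

import re


def strip_think_tags(content: str) -> str:
    # Remove every <think>...</think> span (re.DOTALL lets "." match newlines),
    # then trim whatever whitespace the removal leaves behind.
    return re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()


raw = "<think>\nThe user wants 2 + 2.\n</think>\n4"
print(strip_think_tags(raw))  # prints "4"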
@@ -615,6 +638,72 @@ class ChatOllama(BaseChatModel):
     The async client to use for making requests.
     """
 
+    def _chat_params_v1(
+        self,
+        messages: list[BaseMessage],
+        stop: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """Generate chat parameters with native v1 message support.
+
+        This method uses the v1-native message conversion and is preferred for handling
+        v1 format messages.
+
+        Args:
+            messages: List of messages to convert.
+            stop: Optional stop sequences.
+            **kwargs: Additional parameters.
+
+        Returns:
+            Dictionary of parameters for Ollama API.
+        """
+        # TODO make this just part of _chat_params ?
+        # Depends on longrun decision and message formatting probably
+        ollama_messages = self._convert_messages_to_ollama_messages_v1(messages)
+
+        if self.stop is not None and stop is not None:
+            msg = "`stop` found in both the input and default params."
+            raise ValueError(msg)
+        if self.stop is not None:
+            stop = self.stop
+
+        options_dict = kwargs.pop(
+            "options",
+            {
+                "mirostat": self.mirostat,
+                "mirostat_eta": self.mirostat_eta,
+                "mirostat_tau": self.mirostat_tau,
+                "num_ctx": self.num_ctx,
+                "num_gpu": self.num_gpu,
+                "num_thread": self.num_thread,
+                "num_predict": self.num_predict,
+                "repeat_last_n": self.repeat_last_n,
+                "repeat_penalty": self.repeat_penalty,
+                "temperature": self.temperature,
+                "seed": self.seed,
+                "stop": self.stop if stop is None else stop,
+                "tfs_z": self.tfs_z,
+                "top_k": self.top_k,
+                "top_p": self.top_p,
+            },
+        )
+
+        params = {
+            "messages": ollama_messages,
+            "stream": kwargs.pop("stream", True),
+            "model": kwargs.pop("model", self.model),
+            "think": kwargs.pop("reasoning", self.reasoning),
+            "format": kwargs.pop("format", self.format),
+            "options": Options(**options_dict),
+            "keep_alive": kwargs.pop("keep_alive", self.keep_alive),
+            **kwargs,
+        }
+
+        if tools := kwargs.get("tools"):
+            params["tools"] = tools
+
+        return params
+
     def _chat_params(
         self,
         messages: list[BaseMessage],
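The precedence in this params builder is worth noting: a call-time keyword wins over the instance default (e.g. `reasoning` maps onto Ollama's `think` field), and an explicit `options` mapping replaces the per-field defaults wholesale rather than merging key by key. A minimal sketch of that pop-with-default pattern; `build_params` and `instance_reasoning` are illustrative names, not the library's API:

from typing import Any, Optional


def build_params(instance_reasoning: Optional[bool], **kwargs: Any) -> dict[str, Any]:
    # A call-time keyword wins over the instance-level default.
    think = kwargs.pop("reasoning", instance_reasoning)
    # An explicit "options" dict replaces the defaults entirely; keys are not merged.
    options = kwargs.pop("options", {"temperature": 0.0, "num_ctx": 2048})
    return {"think": think, "options": options, **kwargs}


print(build_params(None, reasoning=True))                 # call-time reasoning wins, default options
print(build_params(False, options={"temperature": 0.7}))  # options replaced wholesale, think stays False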
@@ -666,6 +755,34 @@ class ChatOllama(BaseChatModel):
 
         return params
 
+    def _get_chat_params(
+        self,
+        messages: list[BaseMessage],
+        stop: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """Get chat parameters, choosing between v0 and v1 methods.
+
+        This method automatically chooses the appropriate parameter generation method
+        based on whether messages contain v1 format content.
+
+        Args:
+            messages: List of messages to convert.
+            stop: Optional stop sequences.
+            **kwargs: Additional parameters.
+
+        Returns:
+            Dictionary of parameters for Ollama API.
+        """
+        # Check if any message has v1 format content (list of content blocks)
+        has_v1_messages = any(isinstance(msg.content, list) for msg in messages)
+
+        if has_v1_messages:
+            # Use v1-native method for better handling
+            return self._chat_params_v1(messages, stop, **kwargs)
+        # Use legacy v0 method for backward compatibility
+        return self._chat_params(messages, stop, **kwargs)
+
     @model_validator(mode="after")
     def _set_clients(self) -> Self:
         """Set clients to use for ollama."""
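The v0/v1 dispatch hinges on a single heuristic: a message whose `content` is a list of blocks is treated as v1, while plain-string content stays on the legacy path. A small sketch of that check using real `langchain_core` message classes (`uses_v1_path` and the example values are illustrative):

from langchain_core.messages import BaseMessage, HumanMessage


def uses_v1_path(messages: list[BaseMessage]) -> bool:
    # Same heuristic as the dispatcher: any list-of-blocks content means v1.
    return any(isinstance(m.content, list) for m in messages)


v0_msg = HumanMessage(content="plain string content")
v1_msg = HumanMessage(content=[{"type": "text", "text": "block content"}])

print(uses_v1_path([v0_msg]))          # False -> legacy _chat_params
print(uses_v1_path([v0_msg, v1_msg]))  # True  -> _chat_params_v1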
@@ -685,6 +802,179 @@ class ChatOllama(BaseChatModel):
             validate_model(self._client, self.model)
         return self
 
+    def _convert_messages_to_ollama_messages_v1(
+        self, messages: list[BaseMessage]
+    ) -> Sequence[Message]:
+        """Convert messages to Ollama format with native v1 support.
+
+        This method handles v1 format messages natively without converting to v0 first.
+        This is the preferred method for v1 message handling.
+
+        Args:
+            messages: List of messages to convert, may include v1 format.
+
+        Returns:
+            Sequence of Ollama Message objects.
+        """
+        ollama_messages: list = []
+        for message in messages:
+            # Handle v1 format messages natively (don't convert to v0)
+            role: str
+            tool_call_id: Optional[str] = None
+            tool_calls: Optional[list[dict[str, Any]]] = None
+            if isinstance(message, HumanMessage):
+                role = "user"
+            elif isinstance(message, AIMessage):
+                role = "assistant"
+                tool_calls = (
+                    [
+                        _lc_tool_call_to_openai_tool_call(tool_call)
+                        for tool_call in message.tool_calls
+                    ]
+                    if message.tool_calls
+                    else None
+                )
+            elif isinstance(message, SystemMessage):
+                role = "system"
+            elif isinstance(message, ChatMessage):
+                role = message.role
+            elif isinstance(message, ToolMessage):
+                role = "tool"
+                tool_call_id = message.tool_call_id
+            else:
+                msg = "Received unsupported message type for Ollama."
+                raise ValueError(msg)
+
+            content = ""
+            images = []
+            reasoning_content = None
+
+            # Handle v1 format content (list of content blocks)
+            if isinstance(message.content, list):
+                for content_part in message.content:
+                    if isinstance(content_part, dict):
+                        block_type = content_part.get("type")
+                        if block_type == "text":
+                            content += content_part.get("text", "")
+                        elif block_type == "reasoning":
+                            # Extract reasoning content for separate handling
+                            reasoning_content = content_part.get("reasoning", "")
+                        elif block_type == "tool_call":
+                            # Skip - handled by tool_calls property
+                            continue
+                        elif block_type == "image_url":
+                            image_url = None
+                            temp_image_url = content_part.get("image_url")
+                            if isinstance(temp_image_url, str):
+                                image_url = temp_image_url
+                            elif (
+                                isinstance(temp_image_url, dict)
+                                and "url" in temp_image_url
+                                and isinstance(temp_image_url["url"], str)
+                            ):
+                                image_url = temp_image_url["url"]
+                            else:
+                                msg = (
+                                    "Only string image_url or dict with string 'url' "
+                                    "inside content parts are supported."
+                                )
+                                raise ValueError(msg)
+
+                            image_url_components = image_url.split(",")
+                            # Support data:image/jpeg;base64,<image> format
+                            # and base64 strings
+                            if len(image_url_components) > 1:
+                                images.append(image_url_components[1])
+                            else:
+                                images.append(image_url_components[0])
+                        elif is_data_content_block(content_part):
+                            image = _get_image_from_data_content_block(content_part)
+                            images.append(image)
+                        else:
+                            # Convert unknown content blocks to NonStandardContentBlock
+                            # TODO what to do with these?
+                            _convert_unknown_content_block_to_non_standard(content_part)
+                            continue
+                    else:
+                        # Handle content blocks that are not dicts
+                        # (e.g., TextContentBlock objects)
+                        if hasattr(content_part, "type"):
+                            if content_part.type == "text":
+                                content += getattr(content_part, "text", "")
+                            elif content_part.type == "reasoning":
+                                reasoning_content = getattr(
+                                    content_part, "reasoning", ""
+                                )
+                            # Add other content block types as needed
+
+            # Handle v0 format content (string)
+            elif isinstance(message.content, str):
+                content = message.content
+            else:
+                # Handle other content formats if needed
+                for content_part in cast(list[dict], message.content):
+                    if content_part.get("type") == "text":
+                        content += f"\n{content_part['text']}"
+                    elif content_part.get("type") == "tool_use":
+                        continue
+                    elif content_part.get("type") == "tool_call":
+                        # Skip - handled by tool_calls property
+                        continue
+                    elif content_part.get("type") == "reasoning":
+                        # Skip - handled by reasoning parameter
+                        continue
+                    elif content_part.get("type") == "image_url":
+                        image_url = None
+                        temp_image_url = content_part.get("image_url")
+                        if isinstance(temp_image_url, str):
+                            image_url = temp_image_url
+                        elif (
+                            isinstance(temp_image_url, dict)
+                            and "url" in temp_image_url
+                            and isinstance(temp_image_url["url"], str)
+                        ):
+                            image_url = temp_image_url["url"]
+                        else:
+                            msg = (
+                                "Only string image_url or dict with string 'url' "
+                                "inside content parts are supported."
+                            )
+                            raise ValueError(msg)
+
+                        image_url_components = image_url.split(",")
+                        # Support data:image/jpeg;base64,<image> format
+                        # and base64 strings
+                        if len(image_url_components) > 1:
+                            images.append(image_url_components[1])
+                        else:
+                            images.append(image_url_components[0])
+                    elif is_data_content_block(content_part):
+                        image = _get_image_from_data_content_block(content_part)
+                        images.append(image)
+                    else:
+                        # Convert unknown content blocks to NonStandardContentBlock
+                        # TODO what to do with these?
+                        _convert_unknown_content_block_to_non_standard(content_part)
+                        continue
+
+            # Should convert to ollama.Message once role includes tool,
+            # and tool_call_id is in Message
+            msg_: dict = {
+                "role": role,
+                "content": content,
+                "images": images,
+            }
+            if tool_calls:
+                msg_["tool_calls"] = tool_calls
+            if tool_call_id:
+                msg_["tool_call_id"] = tool_call_id
+            # Store reasoning content for later use if present
+            if reasoning_content:
+                msg_["_reasoning_content"] = reasoning_content
+            ollama_messages.append(msg_)
+
+        return ollama_messages
+
     def _convert_messages_to_ollama_messages(
         self, messages: list[BaseMessage]
     ) -> Sequence[Message]:
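The net effect of the converter is to flatten each message's content blocks into the flat dict shape Ollama expects (role, content, images, plus optional tool fields). A simplified sketch of just the text/image flattening for one user message, assuming image blocks carry a base64 data URL (`flatten_v1_content` is an illustrative name, not part of the library):

def flatten_v1_content(blocks: list[dict]) -> dict:
    # Illustrative only: collapse v1 content blocks into the flat fields Ollama expects.
    content = ""
    images: list[str] = []
    for block in blocks:
        if block.get("type") == "text":
            content += block.get("text", "")
        elif block.get("type") == "image_url":
            raw = block["image_url"]
            url = raw["url"] if isinstance(raw, dict) else raw
            # Keep only the base64 payload from a data:image/...;base64,<data> URL.
            images.append(url.split(",")[1] if "," in url else url)
    return {"role": "user", "content": content, "images": images}


blocks = [
    {"type": "text", "text": "What is in this picture?"},
    {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,AAAA"}},
]
print(flatten_v1_content(blocks))
# {'role': 'user', 'content': 'What is in this picture?', 'images': ['AAAA']}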
@@ -922,6 +1212,10 @@ class ChatOllama(BaseChatModel):
                     else ""
                 )
 
+                # Strip think tags if reasoning is explicitly disabled
+                if reasoning is False:
+                    content = _strip_think_tags(content)
+
                 # Warn and skip responses with done_reason: 'load' and empty content
                 # These indicate the model was loaded but no actual generation occurred
                 is_load_response_with_empty_content = (
@@ -1003,6 +1297,10 @@ class ChatOllama(BaseChatModel):
                     else ""
                 )
 
+                # Strip think tags if reasoning is explicitly disabled
+                if reasoning is False:
+                    content = _strip_think_tags(content)
+
                 # Warn and skip responses with done_reason: 'load' and empty content
                 # These indicate the model was loaded but no actual generation occurred
                 is_load_response_with_empty_content = (
@@ -23,7 +23,7 @@ class MathAnswer(BaseModel):
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
 def test_stream_no_reasoning(model: str) -> None:
     """Test streaming with `reasoning=False`"""
-    llm = ChatOllama(model=model, num_ctx=2**12)
+    llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
     messages = [
         {
             "role": "user",
@@ -46,7 +46,7 @@ def test_stream_no_reasoning(model: str) -> None:
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
 async def test_astream_no_reasoning(model: str) -> None:
     """Test async streaming with `reasoning=False`"""
-    llm = ChatOllama(model=model, num_ctx=2**12)
+    llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
     messages = [
         {
             "role": "user",
@@ -175,7 +175,7 @@ async def test_reasoning_astream(model: str) -> None:
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
 def test_invoke_no_reasoning(model: str) -> None:
     """Test using invoke with `reasoning=False`"""
-    llm = ChatOllama(model=model, num_ctx=2**12)
+    llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
     message = HumanMessage(content=SAMPLE)
     result = llm.invoke([message])
     assert result.content
@@ -189,7 +189,7 @@ def test_invoke_no_reasoning(model: str) -> None:
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
 async def test_ainvoke_no_reasoning(model: str) -> None:
     """Test using async invoke with `reasoning=False`"""
-    llm = ChatOllama(model=model, num_ctx=2**12)
+    llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False)
     message = HumanMessage(content=SAMPLE)
     result = await llm.ainvoke([message])
     assert result.content
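These four tests previously constructed the model without the flag they were meant to exercise; they now pass `reasoning=False` explicitly. For end users the same setup looks roughly like this (a sketch assuming a running local Ollama server with the deepseek-r1:1.5b model pulled):

from langchain_ollama import ChatOllama

# Assumes a running local Ollama server with the deepseek-r1:1.5b model pulled.
llm = ChatOllama(model="deepseek-r1:1.5b", reasoning=False)
result = llm.invoke("What is 7 * 6?")
# Per this change, any <think>...</think> block the model emits anyway is
# stripped from result.content before it reaches the caller.
print(result.content)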
@@ -256,3 +256,43 @@ async def test_reasoning_ainvoke(model: str) -> None:
     assert "<think>" not in result.content and "</think>" not in result.content
     assert "<think>" not in result.additional_kwargs["reasoning_content"]
     assert "</think>" not in result.additional_kwargs["reasoning_content"]
+
+
+@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
+def test_think_tag_stripping_necessity(model: str) -> None:
+    """Test that demonstrates why ``_strip_think_tags`` is necessary.
+
+    DeepSeek R1 models include reasoning/thinking as their default behavior.
+    When ``reasoning=False`` is set, the user explicitly wants no reasoning content,
+    but Ollama cannot disable thinking at the API level for these models.
+    Therefore, post-processing is required to strip the ``<think>`` tags.
+
+    This test documents the specific behavior that necessitates the
+    ``_strip_think_tags`` function in the chat_models.py implementation.
+    """
+    # Test with reasoning=None (default behavior - should include think tags)
+    llm_default = ChatOllama(model=model, reasoning=None, num_ctx=2**12)
+    message = HumanMessage(content=SAMPLE)
+
+    result_default = llm_default.invoke([message])
+
+    # With reasoning=None, the model's default behavior includes <think> tags
+    # This demonstrates why we need the stripping logic
+    assert "<think>" in result_default.content
+    assert "</think>" in result_default.content
+    assert "reasoning_content" not in result_default.additional_kwargs
+
+    # Test with reasoning=False (explicit disable - should NOT include think tags)
+    llm_disabled = ChatOllama(model=model, reasoning=False, num_ctx=2**12)
+
+    result_disabled = llm_disabled.invoke([message])
+
+    # With reasoning=False, think tags should be stripped from content
+    # This verifies that _strip_think_tags is working correctly
+    assert "<think>" not in result_disabled.content
+    assert "</think>" not in result_disabled.content
+    assert "reasoning_content" not in result_disabled.additional_kwargs
+
+    # Verify the difference: same model, different reasoning settings
+    # Default includes tags, disabled strips them
+    assert result_default.content != result_disabled.content
@@ -168,7 +168,7 @@ class TestChatOllama(ChatModelIntegrationTests):
 
         with pytest.raises(ValidationError) as excinfo:
             ChatOllama(model="any-model", validate_model_on_init=True)
-        assert "not found in Ollama" in str(excinfo.value)
+        assert "Failed to connect to Ollama" in str(excinfo.value)
 
     @patch("langchain_ollama.chat_models.Client.list")
     def test_init_response_error(self, mock_list: MagicMock) -> None: