diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py
index 984cc0ccfcc..f9c6ca060d1 100644
--- a/libs/partners/ollama/langchain_ollama/chat_models.py
+++ b/libs/partners/ollama/langchain_ollama/chat_models.py
@@ -7,12 +7,14 @@ from typing import (
     AsyncIterator,
     Callable,
     Dict,
+    Final,
     Iterator,
     List,
     Literal,
     Mapping,
     Optional,
     Sequence,
+    Tuple,
     Type,
     Union,
     cast,
@@ -30,6 +32,7 @@ from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
     BaseMessage,
+    BaseMessageChunk,
     HumanMessage,
     SystemMessage,
     ToolCall,
@@ -57,6 +60,9 @@ from pydantic.json_schema import JsonSchemaValue
 from pydantic.v1 import BaseModel as BaseModelV1
 from typing_extensions import Self, is_typeddict
 
+DEFAULT_THINK_TOKEN_START: Final[str] = "<think>"
+DEFAULT_THINK_TOKEN_END: Final[str] = "</think>"
+
 
 def _get_usage_metadata_from_generation_info(
     generation_info: Optional[Mapping[str, Any]],
@@ -335,6 +341,13 @@ class ChatOllama(BaseChatModel):
     model: str
     """Model name to use."""
 
+    extract_reasoning: Optional[Union[bool, Tuple[str, str]]] = False
+    """Whether to extract the reasoning tokens in think blocks.
+    Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
+    If a tuple is supplied, they are assumed to be the (start, end) tokens.
+    If `extract_reasoning=True`, the tokens will default to (<think>, </think>).
+    """
+
     mirostat: Optional[int] = None
     """Enable Mirostat sampling for controlling perplexity.
     (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
@@ -568,6 +581,28 @@ class ChatOllama(BaseChatModel):
 
         return ollama_messages
 
+    def _extract_reasoning(
+        self, message_chunk: BaseMessageChunk, is_thinking: bool
+    ) -> Tuple[BaseMessageChunk, bool]:
+        """Mutate a message chunk to extract reasoning content."""
+        if not self.extract_reasoning:
+            return message_chunk, is_thinking
+        elif self.extract_reasoning is True:
+            start_token = DEFAULT_THINK_TOKEN_START
+            end_token = DEFAULT_THINK_TOKEN_END
+        else:
+            start_token, end_token = cast(tuple, self.extract_reasoning)
+        if start_token in message_chunk.content:
+            is_thinking = True
+        content = message_chunk.content
+        if is_thinking:
+            message_chunk.additional_kwargs["reasoning_content"] = content
+            message_chunk.content = ""
+        if end_token in content:
+            is_thinking = False
+
+        return message_chunk, is_thinking
+
     async def _acreate_chat_stream(
         self,
         messages: List[BaseMessage],
@@ -604,35 +639,17 @@ class ChatOllama(BaseChatModel):
         **kwargs: Any,
     ) -> ChatGenerationChunk:
         final_chunk = None
-        for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
-            if not isinstance(stream_resp, str):
-                chunk = ChatGenerationChunk(
-                    message=AIMessageChunk(
-                        content=(
-                            stream_resp["message"]["content"]
-                            if "message" in stream_resp
-                            and "content" in stream_resp["message"]
-                            else ""
-                        ),
-                        usage_metadata=_get_usage_metadata_from_generation_info(
-                            stream_resp
-                        ),
-                        tool_calls=_get_tool_calls_from_response(stream_resp),
-                    ),
-                    generation_info=(
-                        dict(stream_resp) if stream_resp.get("done") is True else None
-                    ),
+        for chunk in self._iterate_over_stream(messages, stop, **kwargs):
+            if final_chunk is None:
+                final_chunk = chunk
+            else:
+                final_chunk += chunk
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    chunk.text,
+                    chunk=chunk,
+                    verbose=verbose,
                 )
-                if final_chunk is None:
-                    final_chunk = chunk
-                else:
-                    final_chunk += chunk
-                if run_manager:
-                    run_manager.on_llm_new_token(
-                        chunk.text,
-                        chunk=chunk,
-                        verbose=verbose,
-                    )
         if final_chunk is None:
             raise ValueError("No data received from Ollama stream.")
 
@@ -647,35 +664,17 @@ class ChatOllama(BaseChatModel):
         **kwargs: Any,
     ) -> ChatGenerationChunk:
         final_chunk = None
-        async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
-            if not isinstance(stream_resp, str):
-                chunk = ChatGenerationChunk(
-                    message=AIMessageChunk(
-                        content=(
-                            stream_resp["message"]["content"]
-                            if "message" in stream_resp
-                            and "content" in stream_resp["message"]
-                            else ""
-                        ),
-                        usage_metadata=_get_usage_metadata_from_generation_info(
-                            stream_resp
-                        ),
-                        tool_calls=_get_tool_calls_from_response(stream_resp),
-                    ),
-                    generation_info=(
-                        dict(stream_resp) if stream_resp.get("done") is True else None
-                    ),
+        async for chunk in self._aiterate_over_stream(messages, stop, **kwargs):
+            if final_chunk is None:
+                final_chunk = chunk
+            else:
+                final_chunk += chunk
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    chunk.text,
+                    chunk=chunk,
+                    verbose=verbose,
                 )
-                if final_chunk is None:
-                    final_chunk = chunk
-                else:
-                    final_chunk += chunk
-                if run_manager:
-                    await run_manager.on_llm_new_token(
-                        chunk.text,
-                        chunk=chunk,
-                        verbose=verbose,
-                    )
         if final_chunk is None:
             raise ValueError("No data received from Ollama stream.")
 
@@ -712,18 +711,19 @@ class ChatOllama(BaseChatModel):
                 content=final_chunk.text,
                 usage_metadata=cast(AIMessageChunk, final_chunk.message).usage_metadata,
                 tool_calls=cast(AIMessageChunk, final_chunk.message).tool_calls,
+                additional_kwargs=final_chunk.message.additional_kwargs,
             ),
             generation_info=generation_info,
         )
         return ChatResult(generations=[chat_generation])
 
-    def _stream(
+    def _iterate_over_stream(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
+        is_thinking = False
         for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
             if not isinstance(stream_resp, str):
                 chunk = ChatGenerationChunk(
@@ -743,20 +743,35 @@ class ChatOllama(BaseChatModel):
                         dict(stream_resp) if stream_resp.get("done") is True else None
                     ),
                 )
-                if run_manager:
-                    run_manager.on_llm_new_token(
-                        chunk.text,
-                        verbose=self.verbose,
+                if self.extract_reasoning:
+                    message, is_thinking = self._extract_reasoning(
+                        chunk.message, is_thinking
                     )
+                    chunk.message = message
                 yield chunk
 
-    async def _astream(
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        for chunk in self._iterate_over_stream(messages, stop, **kwargs):
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    chunk.text,
+                    verbose=self.verbose,
+                )
+            yield chunk
+
+    async def _aiterate_over_stream(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
+        is_thinking = False
         async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
             if not isinstance(stream_resp, str):
                 chunk = ChatGenerationChunk(
@@ -776,13 +791,28 @@ class ChatOllama(BaseChatModel):
                         dict(stream_resp) if stream_resp.get("done") is True else None
                     ),
                 )
-                if run_manager:
-                    await run_manager.on_llm_new_token(
-                        chunk.text,
-                        verbose=self.verbose,
+                if self.extract_reasoning:
+                    message, is_thinking = self._extract_reasoning(
+                        chunk.message, is_thinking
                    )
+                    chunk.message = message
                 yield chunk
 
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        async for chunk in self._aiterate_over_stream(messages, stop, **kwargs):
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    chunk.text,
+                    verbose=self.verbose,
+                )
+            yield chunk
+
     async def _agenerate(
         self,
         messages: List[BaseMessage],
@@ -799,6 +829,7 @@ class ChatOllama(BaseChatModel):
                 content=final_chunk.text,
                 usage_metadata=cast(AIMessageChunk, final_chunk.message).usage_metadata,
                 tool_calls=cast(AIMessageChunk, final_chunk.message).tool_calls,
+                additional_kwargs=final_chunk.message.additional_kwargs,
             ),
             generation_info=generation_info,
         )
@@ -1083,6 +1114,7 @@ class ChatOllama(BaseChatModel):
                 #     'parsing_error': None
                 # }
         """  # noqa: E501, D301
+        _ = kwargs.pop("strict", None)
         if kwargs:
             raise ValueError(f"Received unsupported arguments {kwargs}")
         is_pydantic_schema = _is_pydantic_class(schema)
diff --git a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_reasoning.py b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_reasoning.py
new file mode 100644
index 00000000000..5e41c424a66
--- /dev/null
+++ b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_reasoning.py
@@ -0,0 +1,162 @@
+"""Ollama specific chat model integration tests for reasoning models."""
+
+import pytest
+from langchain_core.messages import AIMessageChunk, BaseMessageChunk, HumanMessage
+from pydantic import ValidationError
+
+from langchain_ollama import ChatOllama
+
+SAMPLE = "What is 3^3?"
+
+
+@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
+def test_deepseek_messages_stream_no_reasoning(model: str) -> None:
+    """Test deepseek model without parsing."""
+    llm = ChatOllama(model=model, num_ctx=2**12)
+    messages = [
+        {
+            "role": "user",
+            "content": SAMPLE,
+        }
+    ]
+    result = None
+    for chunk in llm.stream(messages):
+        assert isinstance(chunk, BaseMessageChunk)
+        if result is None:
+            result = chunk
+            continue
+        result += chunk
+    assert isinstance(result, AIMessageChunk)
+    assert result.content
+    assert "<think>" in result.content and "</think>" in result.content
+    assert "reasoning_content" not in result.additional_kwargs
+
+
+@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
+def test_deepseek_messages_stream_bool(model: str) -> None:
+    """Test deepseek model with reasoning bool=True"""
+    llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
+    messages = [
+        {
+            "role": "user",
+            "content": SAMPLE,
+        }
+    ]
+    result = None
+    for chunk in llm.stream(messages):
+        assert isinstance(chunk, BaseMessageChunk)
+        if result is None:
+            result = chunk
+            continue
+        result += chunk
+    assert isinstance(result, AIMessageChunk)
+    assert result.content
+    assert "<think>" not in result.content and "</think>" not in result.content
+    assert "reasoning_content" in result.additional_kwargs
+    assert len(result.additional_kwargs["reasoning_content"]) > 0
+    assert "<think>" in result.additional_kwargs["reasoning_content"]
+    assert "</think>" in result.additional_kwargs["reasoning_content"]
+    clean_content = (
+        result.additional_kwargs["reasoning_content"]
+        .replace("<think>", "")
+        .replace("</think>", "")
+        .strip()
+    )
+    assert len(clean_content) > 0
+
+
+@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
+def test_deepseek_messages_stream_tuple(model: str) -> None:
+    """Test deepseek model with reasoning with tuple=..."""
+    llm = ChatOllama(
+        model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
+    )
+    messages = [
+        {
+            "role": "user",
+            "content": SAMPLE,
+        }
+    ]
+    result = None
+    for chunk in llm.stream(messages):
+        assert isinstance(chunk, BaseMessageChunk)
+        if result is None:
+            result = chunk
+            continue
+        result += chunk
+    assert isinstance(result, AIMessageChunk)
+    assert result.content
+    assert "<think>" not in result.content and "</think>" not in result.content
+    assert "reasoning_content" in result.additional_kwargs
+    assert len(result.additional_kwargs["reasoning_content"]) > 0
+    assert "<think>" in result.additional_kwargs["reasoning_content"]
+    assert "</think>" in result.additional_kwargs["reasoning_content"]
+    clean_content = (
+        result.additional_kwargs["reasoning_content"]
+        .replace("<think>", "")
+        .replace("</think>", "")
+        .strip()
+    )
+    assert len(clean_content) > 0
+
+
+@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
+def test_deepseek_messages_invoke_no_reasoning(model: str) -> None:
+    """Test deepseek model without parsing using invoke."""
+    llm = ChatOllama(model=model, num_ctx=2**12)
+    message = HumanMessage(content=SAMPLE)
+    result = llm.invoke([message])
+    assert result.content
+    assert "<think>" in result.content and "</think>" in result.content
+    assert "reasoning_content" not in result.additional_kwargs
+
+
+@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
+def test_deepseek_messages_invoke_bool(model: str) -> None:
+    """Test deepseek model with reasoning bool=True using invoke"""
+    llm = ChatOllama(model=model, num_ctx=2**12, extract_reasoning=True)
+    message = HumanMessage(content=SAMPLE)
+    result = llm.invoke([message])
+    assert result.content
+    assert "<think>" not in result.content and "</think>" not in result.content
+    assert "reasoning_content" in result.additional_kwargs
+    assert len(result.additional_kwargs["reasoning_content"]) > 0
+    assert "<think>" in result.additional_kwargs["reasoning_content"]
+    assert "</think>" in result.additional_kwargs["reasoning_content"]
+    clean_content = (
+        result.additional_kwargs["reasoning_content"]
+        .replace("<think>", "")
+        .replace("</think>", "")
+        .strip()
+    )
+    assert len(clean_content) > 0
+
+
+@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
+def test_deepseek_messages_invoke_tuple(model: str) -> None:
+    """Test deepseek model with reasoning with tuple=... using invoke"""
+    llm = ChatOllama(
+        model=model, num_ctx=2**12, extract_reasoning=("<think>", "</think>")
+    )
+    message = HumanMessage(content=SAMPLE)
+    result = llm.invoke([message])
+    assert result.content
+    assert "<think>" not in result.content and "</think>" not in result.content
+    assert "reasoning_content" in result.additional_kwargs
+    assert len(result.additional_kwargs["reasoning_content"]) > 0
+    assert "<think>" in result.additional_kwargs["reasoning_content"]
+    assert "</think>" in result.additional_kwargs["reasoning_content"]
+    clean_content = (
+        result.additional_kwargs["reasoning_content"]
+        .replace("<think>", "")
+        .replace("</think>", "")
+        .strip()
+    )
+    assert len(clean_content) > 0
+
+
+@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
+def test_deepseek_invalid(model: str) -> None:
+    """Test deepseek model with reasoning raises ValidationError"""
+    with pytest.raises(ValidationError):
+        _ = ChatOllama(model=model, extract_reasoning={"invalid": "data"})  # type: ignore[arg-type]
diff --git a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py
index bc39b1319df..cccda2cc186 100644
--- a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py
+++ b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py
@@ -23,3 +23,7 @@ class TestChatOllama(ChatModelIntegrationTests):
     @property
     def supports_json_mode(self) -> bool:
         return True
+
+    @property
+    def has_tool_choice(self) -> bool:
+        return False