diff --git a/libs/partners/ollama/langchain_ollama/_utils.py b/libs/partners/ollama/langchain_ollama/_utils.py
index 89eb25ee6b4..f3cd6fe9a4d 100644
--- a/libs/partners/ollama/langchain_ollama/_utils.py
+++ b/libs/partners/ollama/langchain_ollama/_utils.py
@@ -16,7 +16,9 @@ def validate_model(client: Client, model_name: str) -> None:
     """
     try:
         response = client.list()
-        model_names: list[str] = [model["name"] for model in response["models"]]
+
+        model_names: list[str] = [model["model"] for model in response["models"]]
+
         if not any(
             model_name == m or m.startswith(f"{model_name}:") for m in model_names
         ):
@@ -27,10 +29,7 @@ def validate_model(client: Client, model_name: str) -> None:
             )
             raise ValueError(msg)
     except ConnectError as e:
-        msg = (
-            "Connection to Ollama failed. Please make sure Ollama is running "
-            f"and accessible at {client._client.base_url}. "
-        )
+        msg = "Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download"  # noqa: E501
         raise ValueError(msg) from e
     except ResponseError as e:
         msg = (
diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py
index cb2a1a9e84e..b987e8492d4 100644
--- a/libs/partners/ollama/langchain_ollama/chat_models.py
+++ b/libs/partners/ollama/langchain_ollama/chat_models.py
@@ -217,17 +217,17 @@ class ChatOllama(BaseChatModel):
             `supported models <https://ollama.com/search?c=thinking>`__.
 
             - ``True``: Enables reasoning mode. The model's reasoning process will be
-            captured and returned separately in the ``additional_kwargs`` of the
-            response message, under ``reasoning_content``. The main response
-            content will not include the reasoning tags.
+              captured and returned separately in the ``additional_kwargs`` of the
+              response message, under ``reasoning_content``. The main response
+              content will not include the reasoning tags.
             - ``False``: Disables reasoning mode. The model will not perform any reasoning,
-            and the response will not include any reasoning content.
+              and the response will not include any reasoning content.
             - ``None`` (Default): The model will use its default reasoning behavior. Note
-            however, if the model's default behavior *is* to perform reasoning, think tags
-            (``<think>`` and ``</think>``) will be present within the main response content
-            unless you set ``reasoning`` to ``True``.
+              however, if the model's default behavior *is* to perform reasoning, think tags
+              (``<think>`` and ``</think>``) will be present within the main response content
+              unless you set ``reasoning`` to ``True``.
         temperature: float
-            Sampling temperature. Ranges from 0.0 to 1.0.
+            Sampling temperature. Ranges from ``0.0`` to ``1.0``.
         num_predict: Optional[int]
             Max number of tokens to generate.
@@ -343,7 +343,6 @@ class ChatOllama(BaseChatModel):
             '{"location": "Pune, India", "time_of_day": "morning"}'
 
     Tool Calling:
-
         .. code-block:: python
 
             from langchain_ollama import ChatOllama
@@ -362,6 +361,48 @@ class ChatOllama(BaseChatModel):
              'args': {'a': 45, 'b': 67},
              'id': '420c3f3b-df10-4188-945f-eb3abdb40622',
              'type': 'tool_call'}]
+
+    Thinking / Reasoning:
+        You can enable reasoning mode for models that support it by setting
+        the ``reasoning`` parameter to ``True`` in either the constructor or
+        the ``invoke``/``stream`` methods. This will enable the model to think
+        through the problem and return the reasoning process separately in the
+        ``additional_kwargs`` of the response message, under ``reasoning_content``.
+
+        If ``reasoning`` is set to ``None``, the model will use its default reasoning
+        behavior, and any reasoning content will *not* be captured under the
+        ``reasoning_content`` key, but will be present within the main response content
+        as think tags (``<think>`` and ``</think>``).
+
+        .. note::
+            This feature is only available for `models that support reasoning <https://ollama.com/search?c=thinking>`__.
+
+        .. code-block:: python
+
+            from langchain_core.messages import HumanMessage
+
+            from langchain_ollama import ChatOllama
+
+            llm = ChatOllama(
+                model="deepseek-r1:8b",
+                reasoning=True,
+            )
+
+            user_message = HumanMessage(content="how many r in the word strawberry?")
+            messages = [user_message]
+            llm.invoke(messages)
+
+            # or, on an invocation basis:
+
+            llm.invoke(messages, reasoning=True)
+            # or llm.stream(messages, reasoning=True)
+
+            # If not provided, the invocation will default to the ChatOllama reasoning
+            # param provided (None by default).
+
+        .. code-block:: python
+
+            AIMessage(content='The word "strawberry" contains **three \'r\' letters**. Here\'s a breakdown for clarity:\n\n- The spelling of "strawberry" has two parts ... be 3.\n\nTo be thorough, let\'s confirm with an online source or common knowledge.\n\nI can recall that "strawberry" has: s-t-r-a-w-b-e-r-r-y — yes, three r\'s.\n\nPerhaps it\'s misspelled by some, but standard is correct.\n\nSo I think the response should be 3.\n'}, response_metadata={'model': 'deepseek-r1:8b', 'created_at': '2025-07-08T19:33:55.891269Z', 'done': True, 'done_reason': 'stop', 'total_duration': 98232561292, 'load_duration': 28036792, 'prompt_eval_count': 10, 'prompt_eval_duration': 40171834, 'eval_count': 3615, 'eval_duration': 98163832416, 'model_name': 'deepseek-r1:8b'}, id='run--18f8269f-6a35-4a7c-826d-b89d52c753b3-0', usage_metadata={'input_tokens': 10, 'output_tokens': 3615, 'total_tokens': 3625})
+
+
     """  # noqa: E501, pylint: disable=line-too-long
 
     model: str
@@ -777,6 +818,7 @@ class ChatOllama(BaseChatModel):
         stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
+        reasoning = kwargs.get("reasoning", self.reasoning)
         for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
             if not isinstance(stream_resp, str):
                 if stream_resp.get("done") is True:
@@ -795,7 +837,7 @@ class ChatOllama(BaseChatModel):
 
                 additional_kwargs = {}
                 if (
-                    self.reasoning
+                    reasoning
                     and "message" in stream_resp
                     and (thinking_content := stream_resp["message"].get("thinking"))
                 ):
@@ -836,6 +878,7 @@ class ChatOllama(BaseChatModel):
         stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
+        reasoning = kwargs.get("reasoning", self.reasoning)
         async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
             if not isinstance(stream_resp, str):
                 if stream_resp.get("done") is True:
@@ -854,7 +897,7 @@ class ChatOllama(BaseChatModel):
 
                 additional_kwargs = {}
                 if (
-                    self.reasoning
+                    reasoning
                     and "message" in stream_resp
                     and (thinking_content := stream_resp["message"].get("thinking"))
                 ):
diff --git a/libs/partners/ollama/langchain_ollama/llms.py b/libs/partners/ollama/langchain_ollama/llms.py
index 8e17061103e..89d5040b2f9 100644
--- a/libs/partners/ollama/langchain_ollama/llms.py
+++ b/libs/partners/ollama/langchain_ollama/llms.py
@@ -38,7 +38,7 @@ class OllamaLLM(BaseLLM):
     model: str
     """Model name to use."""
 
-    reasoning: Optional[bool] = True
+    reasoning: Optional[bool] = None
     """Controls the reasoning/thinking mode for
     `supported models <https://ollama.com/search?c=thinking>`__.
@@ -272,8 +272,11 @@ class OllamaLLM(BaseLLM):
         **kwargs: Any,
     ) -> GenerationChunk:
         final_chunk = None
+        thinking_content = ""
         async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
             if not isinstance(stream_resp, str):
+                if stream_resp.get("thinking"):
+                    thinking_content += stream_resp["thinking"]
                 chunk = GenerationChunk(
                     text=stream_resp.get("response", ""),
                     generation_info=(
@@ -294,6 +297,12 @@ class OllamaLLM(BaseLLM):
             msg = "No data received from Ollama stream."
             raise ValueError(msg)
 
+        if thinking_content:
+            if final_chunk.generation_info:
+                final_chunk.generation_info["thinking"] = thinking_content
+            else:
+                final_chunk.generation_info = {"thinking": thinking_content}
+
         return final_chunk
 
     def _stream_with_aggregation(
@@ -305,8 +314,11 @@ class OllamaLLM(BaseLLM):
         **kwargs: Any,
     ) -> GenerationChunk:
         final_chunk = None
+        thinking_content = ""
         for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
             if not isinstance(stream_resp, str):
+                if stream_resp.get("thinking"):
+                    thinking_content += stream_resp["thinking"]
                 chunk = GenerationChunk(
                     text=stream_resp.get("response", ""),
                     generation_info=(
@@ -327,6 +339,12 @@ class OllamaLLM(BaseLLM):
             msg = "No data received from Ollama stream."
             raise ValueError(msg)
 
+        if thinking_content:
+            if final_chunk.generation_info:
+                final_chunk.generation_info["thinking"] = thinking_content
+            else:
+                final_chunk.generation_info = {"thinking": thinking_content}
+
         return final_chunk
 
     def _generate(
@@ -374,10 +392,11 @@ class OllamaLLM(BaseLLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
+        reasoning = kwargs.get("reasoning", self.reasoning)
         for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
             if not isinstance(stream_resp, str):
                 additional_kwargs = {}
-                if thinking_content := stream_resp.get("thinking"):
+                if reasoning and (thinking_content := stream_resp.get("thinking")):
                     additional_kwargs["reasoning_content"] = thinking_content
 
                 chunk = GenerationChunk(
@@ -404,10 +423,11 @@ class OllamaLLM(BaseLLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[GenerationChunk]:
+        reasoning = kwargs.get("reasoning", self.reasoning)
         async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
             if not isinstance(stream_resp, str):
                 additional_kwargs = {}
-                if thinking_content := stream_resp.get("thinking"):
+                if reasoning and (thinking_content := stream_resp.get("thinking")):
                     additional_kwargs["reasoning_content"] = thinking_content
 
                 chunk = GenerationChunk(
diff --git a/libs/partners/ollama/tests/integration_tests/test_llms.py b/libs/partners/ollama/tests/integration_tests/test_llms.py
index afc67ef7118..0ee236e62e5 100644
--- a/libs/partners/ollama/tests/integration_tests/test_llms.py
+++ b/libs/partners/ollama/tests/integration_tests/test_llms.py
@@ -1,18 +1,17 @@
 """Test OllamaLLM llm."""
 
 import pytest
-from langchain_core.messages import AIMessageChunk, BaseMessageChunk
+from langchain_core.outputs import GenerationChunk
 from langchain_core.runnables import RunnableConfig
 
 from langchain_ollama.llms import OllamaLLM
 
 MODEL_NAME = "llama3.1"
 
-
 SAMPLE = "What is 3^3?"
 
 
-def test_stream() -> None:
-    """Test streaming tokens from OpenAI."""
+def test_stream_text_tokens() -> None:
+    """Test streaming raw string tokens from OllamaLLM."""
     llm = OllamaLLM(model=MODEL_NAME)
 
     for token in llm.stream("I'm Pickle Rick"):
@@ -20,60 +19,52 @@ def test_stream() -> None:
 
 
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-def test_stream_no_reasoning(model: str) -> None:
-    """Test streaming with `reasoning=False`"""
+def test__stream_no_reasoning(model: str) -> None:
+    """Test low-level chunk streaming of a simple prompt with `reasoning=False`."""
     llm = OllamaLLM(model=model, num_ctx=2**12)
-    messages = [
-        {
-            "role": "user",
-            "content": SAMPLE,
-        }
-    ]
-    result = None
-    for chunk in llm.stream(messages):
-        assert isinstance(chunk, BaseMessageChunk)
-        if result is None:
-            result = chunk
-            continue
-        result += chunk
-    assert isinstance(result, AIMessageChunk)
-    assert result.content
-    assert "reasoning_content" not in result.additional_kwargs
-    # Sanity check the old behavior isn't present
-    assert "<think>" not in result.content and "</think>" not in result.content
+    result_chunk = None
+    for chunk in llm._stream(SAMPLE):
+        # Should be a GenerationChunk
+        assert isinstance(chunk, GenerationChunk)
+        if result_chunk is None:
+            result_chunk = chunk
+        else:
+            result_chunk += chunk
+
+    # The final result must be a GenerationChunk with visible content
+    assert isinstance(result_chunk, GenerationChunk)
+    assert result_chunk.text
+    # No separate reasoning_content
+    assert "reasoning_content" not in result_chunk.generation_info  # type: ignore[operator]
 
 
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-def test_reasoning_stream(model: str) -> None:
-    """Test streaming with `reasoning=True`"""
+def test__stream_with_reasoning(model: str) -> None:
+    """Test low-level chunk streaming with `reasoning=True`."""
     llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
-    messages = [
-        {
-            "role": "user",
-            "content": SAMPLE,
-        }
-    ]
-    result = None
-    for chunk in llm.stream(messages):
-        assert isinstance(chunk, BaseMessageChunk)
-        if result is None:
-            result = chunk
-            continue
-        result += chunk
-    assert isinstance(result, AIMessageChunk)
-    assert result.content
-    assert "reasoning_content" in result.additional_kwargs
-    assert len(result.additional_kwargs["reasoning_content"]) > 0
-    # Sanity check the old behavior isn't present
-    assert "<think>" not in result.content and "</think>" not in result.content
-    assert "<think>" not in result.additional_kwargs["reasoning_content"]
-    assert "</think>" not in result.additional_kwargs["reasoning_content"]
+    result_chunk = None
+    for chunk in llm._stream(SAMPLE):
+        assert isinstance(chunk, GenerationChunk)
+        if result_chunk is None:
+            result_chunk = chunk
+        else:
+            result_chunk += chunk
+
+    assert isinstance(result_chunk, GenerationChunk)
+    assert result_chunk.text
+    # Should have extracted reasoning into generation_info
+    assert "reasoning_content" in result_chunk.generation_info  # type: ignore[operator]
+    assert len(result_chunk.generation_info["reasoning_content"]) > 0  # type: ignore[index]
+    # And neither the visible nor the hidden portion contains tags
+    assert "<think>" not in result_chunk.text and "</think>" not in result_chunk.text
+    assert "<think>" not in result_chunk.generation_info["reasoning_content"]  # type: ignore[index]
+    assert "</think>" not in result_chunk.generation_info["reasoning_content"]  # type: ignore[index]
 
 
-async def test_astream() -> None:
-    """Test streaming tokens from OpenAI."""
+async def test_astream_text_tokens() -> None:
+    """Test async streaming raw string tokens from OllamaLLM."""
     llm = OllamaLLM(model=MODEL_NAME)
 
     async for token in llm.astream("I'm Pickle Rick"):
@@ -81,60 +72,44 @@ async def test_astream() -> None:
 
 
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-async def test_astream_no_reasoning(model: str) -> None:
-    """Test async streaming with `reasoning=False`"""
+async def test__astream_no_reasoning(model: str) -> None:
-    """Test low-level async chunk streaming with `reasoning=False`."""
     llm = OllamaLLM(model=model, num_ctx=2**12)
-    messages = [
-        {
-            "role": "user",
-            "content": SAMPLE,
-        }
-    ]
-    result = None
-    async for chunk in llm.astream(messages):
-        assert isinstance(chunk, BaseMessageChunk)
-        if result is None:
-            result = chunk
-            continue
-        result += chunk
-    assert isinstance(result, AIMessageChunk)
-    assert result.content
-    assert "reasoning_content" not in result.additional_kwargs
-    # Sanity check the old behavior isn't present
-    assert "<think>" not in result.content and "</think>" not in result.content
+    result_chunk = None
+    async for chunk in llm._astream(SAMPLE):
+        assert isinstance(chunk, GenerationChunk)
+        if result_chunk is None:
+            result_chunk = chunk
+        else:
+            result_chunk += chunk
+
+    assert isinstance(result_chunk, GenerationChunk)
+    assert result_chunk.text
+    assert "reasoning_content" not in result_chunk.generation_info  # type: ignore[operator]
 
 
 @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-async def test_reasoning_astream(model: str) -> None:
-    """Test async streaming with `reasoning=True`"""
+async def test__astream_with_reasoning(model: str) -> None:
+    """Test low-level async chunk streaming with `reasoning=True`."""
     llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
-    messages = [
-        {
-            "role": "user",
-            "content": SAMPLE,
-        }
-    ]
-    result = None
-    async for chunk in llm.astream(messages):
-        assert isinstance(chunk, BaseMessageChunk)
-        if result is None:
-            result = chunk
-            continue
-        result += chunk
-    assert isinstance(result, AIMessageChunk)
-    assert result.content
-    assert "reasoning_content" in result.additional_kwargs
-    assert len(result.additional_kwargs["reasoning_content"]) > 0
-    # Sanity check the old behavior isn't present
-    assert "<think>" not in result.content and "</think>" not in result.content
-    assert "<think>" not in result.additional_kwargs["reasoning_content"]
-    assert "</think>" not in result.additional_kwargs["reasoning_content"]
+    result_chunk = None
+    async for chunk in llm._astream(SAMPLE):
+        assert isinstance(chunk, GenerationChunk)
+        if result_chunk is None:
+            result_chunk = chunk
+        else:
+            result_chunk += chunk
+
+    assert isinstance(result_chunk, GenerationChunk)
+    assert result_chunk.text
+    assert "reasoning_content" in result_chunk.generation_info  # type: ignore[operator]
+    assert len(result_chunk.generation_info["reasoning_content"]) > 0  # type: ignore[index]
 
 
 async def test_abatch() -> None:
-    """Test streaming tokens from OllamaLLM."""
+    """Test async batch token generation from OllamaLLM."""
     llm = OllamaLLM(model=MODEL_NAME)
 
     result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@@ -143,7 +118,7 @@ async def test_abatch() -> None:
 
 
 async def test_abatch_tags() -> None:
-    """Test batch tokens from OllamaLLM."""
+    """Test async batch token generation with tags."""
     llm = OllamaLLM(model=MODEL_NAME)
 
     result = await llm.abatch(
@@ -154,7 +129,7 @@ async def test_abatch_tags() -> None:
 
 
 def test_batch() -> None:
-    """Test batch tokens from OllamaLLM."""
+    """Test batch token generation from OllamaLLM."""
     llm = OllamaLLM(model=MODEL_NAME)
 
     result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@@ -163,75 +138,15 @@ def test_batch() -> None:
 
 
 async def test_ainvoke() -> None:
-    """Test invoke tokens from OllamaLLM."""
+    """Test async invoke returning a string."""
     llm = OllamaLLM(model=MODEL_NAME)
 
     result = await llm.ainvoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
     assert isinstance(result, str)
 
 
-# TODO
-# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-# async def test_ainvoke_no_reasoning(model: str) -> None:
-#     """Test using async invoke with `reasoning=False`"""
-#     llm = OllamaLLM(model=model, num_ctx=2**12)
-#     message = SAMPLE
-#     result = await llm.ainvoke(message)
-#     assert result.content
-#     assert "reasoning_content" not in result.additional_kwargs
-
-#     # Sanity check the old behavior isn't present
-#     assert "<think>" not in result.content and "</think>" not in result.content
-
-
-# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-# async def test_reasoning_ainvoke(model: str) -> None:
-#     """Test invoke with `reasoning=True`"""
-#     llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
-#     message = SAMPLE
-#     result = await llm.ainvoke(message)
-#     assert result.content
-#     assert "reasoning_content" in result.additional_kwargs
-#     assert len(result.additional_kwargs["reasoning_content"]) > 0
-
-#     # Sanity check the old behavior isn't present
-#     assert "<think>" not in result.content and "</think>" not in result.content
-#     assert "<think>" not in result.additional_kwargs["reasoning_content"]
-#     assert "</think>" not in result.additional_kwargs["reasoning_content"]
-
-
 def test_invoke() -> None:
-    """Test invoke tokens from OllamaLLM."""
+    """Test sync invoke returning a string."""
     llm = OllamaLLM(model=MODEL_NAME)
 
     result = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"]))
     assert isinstance(result, str)
-
-
-# TODO
-# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-# def test_invoke_no_reasoning(model: str) -> None:
-#     """Test using invoke with `reasoning=False`"""
-#     llm = OllamaLLM(model=model, num_ctx=2**12)
-#     message = SAMPLE
-#     result = llm.invoke(message)
-#     assert result.content
-#     assert "reasoning_content" not in result.additional_kwargs
-
-#     # Sanity check the old behavior isn't present
-#     assert "<think>" not in result.content and "</think>" not in result.content
-
-
-# @pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")])
-# def test_reasoning_invoke(model: str) -> None:
-#     """Test invoke with `reasoning=True`"""
-#     llm = OllamaLLM(model=model, num_ctx=2**12, reasoning=True)
-#     message = SAMPLE
-#     result = llm.invoke(message)
-#     assert result.content
-#     assert "reasoning_content" in result.additional_kwargs
-#     assert len(result.additional_kwargs["reasoning_content"]) > 0
-
-#     # Sanity check the old behavior isn't present
-#     assert "<think>" not in result.content and "</think>" not in result.content
-#     assert "<think>" not in result.additional_kwargs["reasoning_content"]
-#     assert "</think>" not in result.additional_kwargs["reasoning_content"]
diff --git a/libs/partners/ollama/tests/unit_tests/test_llms.py b/libs/partners/ollama/tests/unit_tests/test_llms.py
index 2a0ad896f90..55116688af3 100644
--- a/libs/partners/ollama/tests/unit_tests/test_llms.py
+++ b/libs/partners/ollama/tests/unit_tests/test_llms.py
@@ -48,3 +48,24 @@ def test_validate_model_on_init(mock_validate_model: Any) -> None:
     # Test that validate_model is NOT called by default
     OllamaLLM(model=MODEL_NAME)
     mock_validate_model.assert_not_called()
+
+
+def test_reasoning_aggregation() -> None:
+    """Test that reasoning chunks are aggregated into the final response."""
+    llm = OllamaLLM(model=MODEL_NAME, reasoning=True)
+    prompts = ["some prompt"]
+    mock_stream = [
+        {"thinking": "I am thinking.", "done": False},
+        {"thinking": " Still thinking.", "done": False},
+        {"response": "Final Answer.", "done": True},
+    ]
+
+    with patch.object(llm, "_create_generate_stream") as mock_stream_method:
+        mock_stream_method.return_value = iter(mock_stream)
+        result = llm.generate(prompts)
+
+    assert result.generations[0][0].generation_info is not None
+    assert (
+        result.generations[0][0].generation_info["thinking"]
+        == "I am thinking. Still thinking."
+    )
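+
+
+def test_stream_reasoning_kwarg_overrides_instance_setting() -> None:
+    """Sketch test (an assumption, not part of the original patch): with the new
+    ``reasoning = kwargs.get("reasoning", self.reasoning)`` lookup, a per-call
+    ``reasoning=False`` should win over ``reasoning=True`` set on the instance,
+    so thinking tokens are never exposed as ``reasoning_content`` for that call.
+    Mirrors the mocking style of ``test_reasoning_aggregation`` above."""
+    llm = OllamaLLM(model=MODEL_NAME, reasoning=True)
+    mock_stream = [
+        {"thinking": "I am thinking.", "response": "", "done": False},
+        {"response": "Final Answer.", "done": True},
+    ]
+
+    with patch.object(llm, "_create_generate_stream") as mock_stream_method:
+        mock_stream_method.return_value = iter(mock_stream)
+        chunks = list(llm._stream("some prompt", reasoning=False))
+
+    # The visible answer still streams through.
+    assert "".join(chunk.text for chunk in chunks) == "Final Answer."
+    # With the per-call override, no chunk surfaces reasoning_content.
+    assert all(
+        "reasoning_content" not in (chunk.generation_info or {}) for chunk in chunks
+    )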