langchain-ollama (partners) / langchain-core: allow passing ChatMessages to Ollama (including arbitrary roles) (#30411)

Replacement for PR #30191 (@ccurme)

**Description**: currently, ChatOllama [will raise a value error if a
ChatMessage is passed to
it](https://github.com/langchain-ai/langchain/blob/master/libs/partners/ollama/langchain_ollama/chat_models.py#L514),
as described
https://github.com/langchain-ai/langchain/pull/30147#issuecomment-2708932481.

Furthermore, ollama-python is removing the limitations on valid roles
that can be passed through chat messages to a model in ollama -
https://github.com/ollama/ollama-python/pull/462#event-16917810634.

This PR removes the role limitations imposed by langchain and enables
passing langchain ChatMessages with arbitrary 'role' values through the
langchain ChatOllama class to the underlying ollama-python Client.

As this PR relies on [merged but unreleased functionality in
ollama-python](
https://github.com/ollama/ollama-python/pull/462#event-16917810634), I
have temporarily pointed the ollama package source to the main branch of
the ollama-python github repo.

Format, lint, and tests of new functionality passing. Need to resolve
issue with recently added ChatOllama tests. (Now resolved)

**Issue**: resolves #30122 (related to ollama issue
https://github.com/ollama/ollama/issues/8955)

**Dependencies**: no new dependencies

[x] PR title
[x] PR message
[x] Lint and test: format, lint, and test all running successfully and
passing

---------

Co-authored-by: Ryan Stewart <ryanstewart@Ryans-MacBook-Pro.local>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
rylativity 2025-04-18 10:07:07 -04:00 committed by GitHub
parent 0c723af4b0
commit dbf9986d44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 588 additions and 478 deletions

File diff suppressed because one or more lines are too long

View File

@ -26,6 +26,7 @@ from langchain_core.messages import (
AIMessageChunk, AIMessageChunk,
BaseMessage, BaseMessage,
BaseMessageChunk, BaseMessageChunk,
ChatMessage,
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
ToolCall, ToolCall,
@ -511,7 +512,7 @@ class ChatOllama(BaseChatModel):
) -> Sequence[Message]: ) -> Sequence[Message]:
ollama_messages: list = [] ollama_messages: list = []
for message in messages: for message in messages:
role: Literal["user", "assistant", "system", "tool"] role: str
tool_call_id: Optional[str] = None tool_call_id: Optional[str] = None
tool_calls: Optional[list[dict[str, Any]]] = None tool_calls: Optional[list[dict[str, Any]]] = None
if isinstance(message, HumanMessage): if isinstance(message, HumanMessage):
@ -528,6 +529,8 @@ class ChatOllama(BaseChatModel):
) )
elif isinstance(message, SystemMessage): elif isinstance(message, SystemMessage):
role = "system" role = "system"
elif isinstance(message, ChatMessage):
role = message.role
elif isinstance(message, ToolMessage): elif isinstance(message, ToolMessage):
role = "tool" role = "tool"
tool_call_id = message.tool_call_id tool_call_id = message.tool_call_id

View File

@ -6,7 +6,10 @@ build-backend = "pdm.backend"
authors = [] authors = []
license = { text = "MIT" } license = { text = "MIT" }
requires-python = "<4.0,>=3.9" requires-python = "<4.0,>=3.9"
dependencies = ["ollama<1,>=0.4.4", "langchain-core<1.0.0,>=0.3.52"] dependencies = [
"ollama>=0.4.8,<1.0.0",
"langchain-core<1.0.0,>=0.3.52",
]
name = "langchain-ollama" name = "langchain-ollama"
version = "0.3.2" version = "0.3.2"
description = "An integration package connecting Ollama and LangChain" description = "An integration package connecting Ollama and LangChain"

View File

@ -1,7 +1,13 @@
"""Test chat model integration.""" """Test chat model integration."""
import json import json
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
import pytest
from httpx import Client, Request, Response
from langchain_core.messages import ChatMessage
from langchain_tests.unit_tests import ChatModelUnitTests from langchain_tests.unit_tests import ChatModelUnitTests
from langchain_ollama.chat_models import ChatOllama, _parse_arguments_from_tool_call from langchain_ollama.chat_models import ChatOllama, _parse_arguments_from_tool_call
@ -23,3 +29,38 @@ def test__parse_arguments_from_tool_call() -> None:
response = _parse_arguments_from_tool_call(raw_tool_calls[0]) response = _parse_arguments_from_tool_call(raw_tool_calls[0])
assert response is not None assert response is not None
assert isinstance(response["arg_1"], str) assert isinstance(response["arg_1"], str)
@contextmanager
def _mock_httpx_client_stream(
*args: Any, **kwargs: Any
) -> Generator[Response, Any, Any]:
yield Response(
status_code=200,
content='{"message": {"role": "assistant", "content": "The meaning ..."}}',
request=Request(method="POST", url="http://whocares:11434"),
)
def test_arbitrary_roles_accepted_in_chatmessages(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(Client, "stream", _mock_httpx_client_stream)
llm = ChatOllama(
base_url="http://whocares:11434",
model="granite3.2",
verbose=True,
format=None,
)
messages = [
ChatMessage(
role="somerandomrole",
content="I'm ok with you adding any role message now!",
),
ChatMessage(role="control", content="thinking"),
ChatMessage(role="user", content="What is the meaning of life?"),
]
llm.invoke(messages)

View File

@ -288,7 +288,7 @@ wheels = [
[[package]] [[package]]
name = "langchain-core" name = "langchain-core"
version = "0.3.52" version = "0.3.54"
source = { editable = "../../core" } source = { editable = "../../core" }
dependencies = [ dependencies = [
{ name = "jsonpatch" }, { name = "jsonpatch" },
@ -381,7 +381,7 @@ typing = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "langchain-core", editable = "../../core" }, { name = "langchain-core", editable = "../../core" },
{ name = "ollama", specifier = ">=0.4.4,<1" }, { name = "ollama", specifier = ">=0.4.8,<1.0.0" },
] ]
[package.metadata.requires-dev] [package.metadata.requires-dev]
@ -405,7 +405,7 @@ typing = [
[[package]] [[package]]
name = "langchain-tests" name = "langchain-tests"
version = "0.3.18" version = "0.3.19"
source = { editable = "../../standard-tests" } source = { editable = "../../standard-tests" }
dependencies = [ dependencies = [
{ name = "httpx" }, { name = "httpx" },
@ -625,15 +625,15 @@ wheels = [
[[package]] [[package]]
name = "ollama" name = "ollama"
version = "0.4.7" version = "0.4.8"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "httpx" }, { name = "httpx" },
{ name = "pydantic" }, { name = "pydantic" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/b0/6d/dc77539c735bbed5d0c873fb029fb86aa9f0163df169b34152914331c369/ollama-0.4.7.tar.gz", hash = "sha256:891dcbe54f55397d82d289c459de0ea897e103b86a3f1fad0fdb1895922a75ff", size = 12843 } sdist = { url = "https://files.pythonhosted.org/packages/e2/64/709dc99030f8f46ec552f0a7da73bbdcc2da58666abfec4742ccdb2e800e/ollama-0.4.8.tar.gz", hash = "sha256:1121439d49b96fa8339842965d0616eba5deb9f8c790786cdf4c0b3df4833802", size = 12972 }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/31/83/c3ffac86906c10184c88c2e916460806b072a2cfe34cdcaf3a0c0e836d39/ollama-0.4.7-py3-none-any.whl", hash = "sha256:85505663cca67a83707be5fb3aeff0ea72e67846cea5985529d8eca4366564a1", size = 13210 }, { url = "https://files.pythonhosted.org/packages/33/3f/164de150e983b3a16e8bf3d4355625e51a357e7b3b1deebe9cc1f7cb9af8/ollama-0.4.8-py3-none-any.whl", hash = "sha256:04312af2c5e72449aaebac4a2776f52ef010877c554103419d3f36066fe8af4c", size = 13325 },
] ]
[[package]] [[package]]