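"""Unit tests for the ChatReka chat model in langchain_community."""
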
import json
import os
from typing import Any, Dict, List
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from pydantic import ValidationError

from langchain_community.chat_models import ChatReka
from langchain_community.chat_models.reka import (
    convert_to_reka_messages,
    process_content,
)

os.environ["REKA_API_KEY"] = "dummy_key"
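
# NOTE: every test in this module carries the same skip marker because of a
# urllib3 version conflict with other installed dependencies; the tests still
# document the expected ChatReka behavior.
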
@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_model_param() -> None:
    llm = ChatReka(model="reka-flash")
    assert llm.model == "reka-flash"


@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_model_kwargs() -> None:
    llm = ChatReka(model_kwargs={"foo": "bar"})
    assert llm.model_kwargs == {"foo": "bar"}


@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_incorrect_field() -> None:
    """Test that providing an incorrect field raises ValidationError."""
    with pytest.raises(ValidationError):
        ChatReka(unknown_field="bar")  # type: ignore


@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_initialization() -> None:
    """Test Reka initialization."""
    # Verify that ChatReka can be initialized using a secret key provided
    # as a parameter rather than an environment variable.
    ChatReka(model="reka-flash", reka_api_key="test_key")

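
# process_content normalizes a message's content (a plain string or a list of
# typed parts) into Reka's list-of-dicts format; image_url parts given as
# {"url": ...} dicts are flattened to the bare URL string.
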
@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
@pytest.mark.parametrize(
    ("content", "expected"),
    [
        ("Hello", [{"type": "text", "text": "Hello"}]),
        (
            [
                {"type": "text", "text": "Describe this image"},
                {
                    "type": "image_url",
                    "image_url": "https://example.com/image.jpg",
                },
            ],
            [
                {"type": "text", "text": "Describe this image"},
                {"type": "image_url", "image_url": "https://example.com/image.jpg"},
            ],
        ),
        (
            [
                {"type": "text", "text": "Hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/image.jpg"},
                },
            ],
            [
                {"type": "text", "text": "Hello"},
                {"type": "image_url", "image_url": "https://example.com/image.jpg"},
            ],
        ),
    ],
)
def test_process_content(content: Any, expected: List[Dict[str, Any]]) -> None:
    result = process_content(content)
    assert result == expected

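
# convert_to_reka_messages maps LangChain chat messages onto Reka's role/content
# dicts: HumanMessage becomes "user", AIMessage becomes "assistant", and content
# is normalized into the same typed-parts format checked above.
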
@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
@pytest.mark.parametrize(
    ("messages", "expected"),
    [
        (
            [HumanMessage(content="Hello")],
            [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
        ),
        (
            [
                HumanMessage(
                    content=[
                        {"type": "text", "text": "Describe this image"},
                        {
                            "type": "image_url",
                            "image_url": "https://example.com/image.jpg",
                        },
                    ]
                ),
                AIMessage(content="It's a beautiful landscape."),
            ],
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe this image"},
                        {
                            "type": "image_url",
                            "image_url": "https://example.com/image.jpg",
                        },
                    ],
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": "It's a beautiful landscape."}
                    ],
                },
            ],
        ),
    ],
)
def test_convert_to_reka_messages(
    messages: List[BaseMessage], expected: List[Dict[str, Any]]
) -> None:
    result = convert_to_reka_messages(messages)
    assert result == expected

@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_streaming() -> None:
    llm = ChatReka(streaming=True)
    assert llm.streaming is True


@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_temperature() -> None:
    llm = ChatReka(temperature=0.5)
    assert llm.temperature == 0.5


@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_max_tokens() -> None:
    llm = ChatReka(max_tokens=100)
    assert llm.max_tokens == 100


@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_default_params() -> None:
    llm = ChatReka()
    assert llm._default_params == {
        "max_tokens": 256,
        "model": "reka-flash",
    }


@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_identifying_params() -> None:
    """Test that ChatReka identifies its default parameters correctly."""
    chat = ChatReka(model="reka-flash", temperature=0.7, max_tokens=256)
    expected_params = {
        "model": "reka-flash",
        "temperature": 0.7,
        "max_tokens": 256,
    }
    assert chat._default_params == expected_params


@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_llm_type() -> None:
    llm = ChatReka()
    assert llm._llm_type == "reka-chat"

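
# The tool-calling test below mocks the underlying reka.client.Reka client and
# checks that a tool call returned by the API is surfaced as an OpenAI-style
# entry in additional_kwargs["tool_calls"], with the call parameters
# JSON-encoded as the function arguments.
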
@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_tool_use_with_mocked_response() -> None:
    with patch("reka.client.Reka") as MockReka:
        # Mock the Reka client
        mock_client = MockReka.return_value
        mock_chat = MagicMock()
        mock_client.chat = mock_chat
        mock_response = MagicMock()
        mock_message = MagicMock()
        mock_tool_call = MagicMock()
        mock_tool_call.id = "tool_call_1"
        mock_tool_call.name = "search_tool"
        mock_tool_call.parameters = {"query": "LangChain"}
        mock_message.tool_calls = [mock_tool_call]
        mock_message.content = None
        mock_response.responses = [MagicMock(message=mock_message)]
        mock_chat.create.return_value = mock_response

        llm = ChatReka()
        messages: List[BaseMessage] = [HumanMessage(content="Tell me about LangChain")]
        result = llm._generate(messages)

        assert len(result.generations) == 1
        ai_message = result.generations[0].message
        assert ai_message.content == ""
        assert "tool_calls" in ai_message.additional_kwargs
        tool_calls = ai_message.additional_kwargs["tool_calls"]
        assert len(tool_calls) == 1
        assert tool_calls[0]["id"] == "tool_call_1"
        assert tool_calls[0]["function"]["name"] == "search_tool"
        assert tool_calls[0]["function"]["arguments"] == json.dumps(
            {"query": "LangChain"}
        )

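
# Reka has no separate system role: convert_to_reka_messages is expected to fold
# a single leading SystemMessage into the text of the first user message, joined
# by a newline.
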
@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
@pytest.mark.parametrize(
    ("messages", "expected"),
    [
        # Test single system message
        (
            [
                SystemMessage(content="You are a helpful assistant."),
                HumanMessage(content="Hello"),
            ],
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "You are a helpful assistant.\nHello"}
                    ],
                }
            ],
        ),
        # Test system message with multiple messages
        (
            [
                SystemMessage(content="You are a helpful assistant."),
                HumanMessage(content="What is 2+2?"),
                AIMessage(content="4"),
                HumanMessage(content="Thanks!"),
            ],
            [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "You are a helpful assistant.\nWhat is 2+2?",
                        }
                    ],
                },
                {"role": "assistant", "content": [{"type": "text", "text": "4"}]},
                {"role": "user", "content": [{"type": "text", "text": "Thanks!"}]},
            ],
        ),
        # Test system message with media content
        (
            [
                SystemMessage(content="Hi."),
                HumanMessage(
                    content=[
                        {"type": "text", "text": "What's in this image?"},
                        {
                            "type": "image_url",
                            "image_url": "https://example.com/image.jpg",
                        },
                    ]
                ),
            ],
            [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Hi.\nWhat's in this image?",
                        },
                        {
                            "type": "image_url",
                            "image_url": "https://example.com/image.jpg",
                        },
                    ],
                },
            ],
        ),
    ],
)
def test_system_message_handling(
    messages: List[BaseMessage], expected: List[Dict[str, Any]]
) -> None:
    """Test that system messages are handled correctly."""
    result = convert_to_reka_messages(messages)
    assert result == expected

@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_multiple_system_messages_error() -> None:
    """Test that multiple system messages raise an error."""
    messages = [
        SystemMessage(content="System message 1"),
        SystemMessage(content="System message 2"),
        HumanMessage(content="Hello"),
    ]

    with pytest.raises(ValueError, match="Multiple system messages are not supported."):
        convert_to_reka_messages(messages)

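
# get_num_tokens is expected to agree with tiktoken's cl100k_base encoding for
# plain strings, single messages, and message lists (an empty list counts as 0).
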
@pytest.mark.skip(
    reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_get_num_tokens() -> None:
    """Test that token counting works correctly for different input types."""
    llm = ChatReka()
    import tiktoken

    encoding = tiktoken.get_encoding("cl100k_base")

    # Test string input
    text = "What is the weather like today?"
    expected_tokens = len(encoding.encode(text))
    assert llm.get_num_tokens(text) == expected_tokens

    # Test BaseMessage input
    message = HumanMessage(content="What is the weather like today?")
    assert isinstance(message.content, str)
    expected_tokens = len(encoding.encode(message.content))
    assert llm.get_num_tokens(message) == expected_tokens

    # Test List[BaseMessage] input
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hi!"),
        AIMessage(content="Hello! How can I help you today?"),
    ]
    expected_tokens = sum(
        len(encoding.encode(msg.content))
        for msg in messages
        if isinstance(msg.content, str)
    )
    assert llm.get_num_tokens(messages) == expected_tokens

    # Test empty message list
    assert llm.get_num_tokens([]) == 0