mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 06:53:59 +00:00
(rfc) chat models (#1424)
Co-authored-by: Ankush Gola <ankush.gola@gmail.com>
This commit is contained in:
0
tests/integration_tests/chat_models/__init__.py
Normal file
0
tests/integration_tests/chat_models/__init__.py
Normal file
89
tests/integration_tests/chat_models/test_openai.py
Normal file
89
tests/integration_tests/chat_models/test_openai.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Test ChatOpenAI wrapper."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.schema import (
|
||||
BaseMessage,
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
HumanMessage,
|
||||
LLMResult,
|
||||
SystemMessage,
|
||||
)
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
def test_chat_openai() -> None:
    """Test ChatOpenAI wrapper."""
    # A single human message should produce a single chat message reply.
    llm = ChatOpenAI(max_tokens=10)
    reply = llm([HumanMessage(content="Hello")])
    assert isinstance(reply, BaseMessage)
    assert isinstance(reply.content, str)
|
||||
|
||||
|
||||
def test_chat_openai_system_message() -> None:
    """Test ChatOpenAI wrapper with system message."""
    llm = ChatOpenAI(max_tokens=10)
    conversation = [
        SystemMessage(content="You are to chat with the user."),
        HumanMessage(content="Hello"),
    ]
    reply = llm(conversation)
    assert isinstance(reply, BaseMessage)
    assert isinstance(reply.content, str)
|
||||
|
||||
|
||||
def test_chat_openai_generate() -> None:
    """Test ChatOpenAI wrapper with generate."""
    llm = ChatOpenAI(max_tokens=10, n=2)
    prompt = [HumanMessage(content="Hello")]
    result = llm.generate([prompt, prompt])
    assert isinstance(result, LLMResult)
    # One generation list per input prompt.
    assert len(result.generations) == 2
    for batch in result.generations:
        # n=2 completions were requested for each prompt.
        assert len(batch) == 2
        for candidate in batch:
            assert isinstance(candidate, ChatGeneration)
            assert isinstance(candidate.text, str)
            # The flat text view must mirror the message content.
            assert candidate.text == candidate.message.content
|
||||
|
||||
|
||||
def test_chat_openai_multiple_completions() -> None:
    """Test ChatOpenAI wrapper with multiple completions."""
    llm = ChatOpenAI(max_tokens=10, n=5)
    # Exercises the private single-prompt path directly.
    result = llm._generate([HumanMessage(content="Hello")])
    assert isinstance(result, ChatResult)
    assert len(result.generations) == 5
    for candidate in result.generations:
        assert isinstance(candidate.message, BaseMessage)
        assert isinstance(candidate.message.content, str)
|
||||
|
||||
|
||||
def test_chat_openai_streaming() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    handler = FakeCallbackHandler()
    llm = ChatOpenAI(
        max_tokens=10,
        streaming=True,
        temperature=0,
        callback_manager=CallbackManager([handler]),
        verbose=True,
    )
    reply = llm([HumanMessage(content="Hello")])
    # Streaming must have fired the new-token callback at least once.
    assert handler.llm_streams > 0
    assert isinstance(reply, BaseMessage)
|
||||
|
||||
|
||||
def test_chat_openai_invalid_streaming_params() -> None:
    """Test that constructing ChatOpenAI with n > 1 while streaming raises.

    Streaming yields a single completion stream, so requesting multiple
    completions (n=5) alongside streaming=True is an invalid combination
    and must be rejected at construction time with a ValueError.
    """
    with pytest.raises(ValueError):
        ChatOpenAI(
            max_tokens=10,
            streaming=True,
            temperature=0,
            n=5,
        )
|
@@ -60,7 +60,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
|
||||
execution_order=3,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult([[]]),
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
)
|
||||
],
|
||||
@@ -74,7 +74,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
|
||||
execution_order=4,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult([[]]),
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
),
|
||||
],
|
||||
@@ -86,10 +86,10 @@ def _perform_nested_run(tracer: BaseTracer) -> None:
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_tool_end("test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_chain_end(outputs={})
|
||||
|
||||
|
||||
@@ -209,7 +209,7 @@ def test_tracer_llm_run() -> None:
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult([[]]),
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
@@ -217,7 +217,7 @@ def test_tracer_llm_run() -> None:
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@@ -237,7 +237,7 @@ def test_tracer_llm_run_errors_no_start() -> None:
|
||||
|
||||
tracer.new_session()
|
||||
with pytest.raises(TracerException):
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
@@ -251,7 +251,7 @@ def test_tracer_multiple_llm_runs() -> None:
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult([[]]),
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
@@ -261,7 +261,7 @@ def test_tracer_multiple_llm_runs() -> None:
|
||||
num_runs = 10
|
||||
for _ in range(num_runs):
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
|
||||
assert tracer.runs == [compare_run] * num_runs
|
||||
|
||||
@@ -409,9 +409,9 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
for _ in range(3):
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_error(exception)
|
||||
|
@@ -31,7 +31,7 @@ def test_caching() -> None:
|
||||
[Generation(text="fizz")],
|
||||
]
|
||||
expected_output = LLMResult(
|
||||
expected_generations,
|
||||
generations=expected_generations,
|
||||
llm_output=None,
|
||||
)
|
||||
assert output == expected_output
|
||||
@@ -69,7 +69,7 @@ def test_custom_caching() -> None:
|
||||
[Generation(text="fizz")],
|
||||
]
|
||||
expected_output = LLMResult(
|
||||
expected_generations,
|
||||
generations=expected_generations,
|
||||
llm_output=None,
|
||||
)
|
||||
assert output == expected_output
|
||||
|
91
tests/unit_tests/prompts/test_chat.py
Normal file
91
tests/unit_tests/prompts/test_chat.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from typing import List
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.prompts.chat import (
|
||||
AIMessagePromptTemplate,
|
||||
BaseMessagePromptTemplate,
|
||||
ChatMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
ChatPromptValue,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
|
||||
|
||||
def create_messages() -> List[BaseMessagePromptTemplate]:
    """Create one message prompt template of each supported kind.

    Returns a system, human, AI, and generic chat message template, in
    that order, sharing the ``foo``/``bar``/``context`` input variables.
    """
    return [
        SystemMessagePromptTemplate(
            prompt=PromptTemplate(
                template="Here's some context: {context}",
                input_variables=["context"],
            )
        ),
        HumanMessagePromptTemplate(
            prompt=PromptTemplate(
                template="Hello {foo}, I'm {bar}. Thanks for the {context}",
                input_variables=["foo", "bar", "context"],
            )
        ),
        AIMessagePromptTemplate(
            prompt=PromptTemplate(
                template="I'm an AI. I'm {foo}. I'm {bar}.",
                input_variables=["foo", "bar"],
            )
        ),
        ChatMessagePromptTemplate(
            # Generic message type carries an explicit role.
            role="test",
            prompt=PromptTemplate(
                template="I'm a generic message. I'm {foo}. I'm {bar}.",
                input_variables=["foo", "bar"],
            ),
        ),
    ]
|
||||
|
||||
|
||||
def create_chat_prompt_template() -> ChatPromptTemplate:
    """Create a chat prompt template."""
    messages = create_messages()
    return ChatPromptTemplate(
        input_variables=["foo", "bar", "context"],
        messages=messages,
    )
|
||||
|
||||
|
||||
def test_chat_prompt_template() -> None:
    """Test chat prompt template."""
    template = create_chat_prompt_template()
    value = template.format_prompt(foo="foo", bar="bar", context="context")
    assert isinstance(value, ChatPromptValue)

    # All four message templates should be rendered, in order.
    rendered = value.to_messages()
    assert len(rendered) == 4
    expected_contents = [
        "Here's some context: context",
        "Hello foo, I'm bar. Thanks for the context",
        "I'm an AI. I'm foo. I'm bar.",
        "I'm a generic message. I'm foo. I'm bar.",
    ]
    for message, content in zip(rendered, expected_contents):
        assert message.content == content

    expected = (
        '[SystemMessage(content="Here\'s some context: context", '
        'additional_kwargs={}), HumanMessage(content="Hello foo, '
        "I'm bar. Thanks for the context\", additional_kwargs={}), "
        "AIMessage(content=\"I'm an AI. I'm foo. I'm bar.\", additional_kwargs={}), "
        "ChatMessage(content=\"I'm a generic message. I'm foo. I'm bar.\","
        " additional_kwargs={}, role='test')]"
    )
    assert value.to_string() == expected
    # format() should agree with format_prompt().to_string().
    assert template.format(foo="foo", bar="bar", context="context") == expected
|
||||
|
||||
|
||||
def test_chat_prompt_template_from_messages() -> None:
    """Test creating a chat prompt template from messages."""
    template = ChatPromptTemplate.from_messages(create_messages())
    # Input variables are inferred from the message templates.
    assert sorted(template.input_variables) == sorted(["context", "foo", "bar"])
    assert len(template.messages) == 4
|
Reference in New Issue
Block a user