[simple][test] Added test case for schema.py (#3692)

- Added unit tests for schema.py covering the utility functions and token
counting.
- Fixed a nit: according to the Hugging Face docs, the tokenizer model is GPT-2.
[link](https://huggingface.co/transformers/v4.8.2/_modules/transformers/models/gpt2/tokenization_gpt2_fast.html)
- `make lint && make format` passed locally.
- Screenshot of the new tests' run results:

<img width="1283" alt="Screenshot 2023-04-27 at 9 51 55 PM"
src="https://user-images.githubusercontent.com/62768671/235057441-c0ac3406-9541-453f-ba14-3ebb08656114.png">
This commit is contained in:
Mike Wang
2023-04-28 20:42:24 -07:00
committed by GitHub
parent 15b92d361d
commit b588446bf9
3 changed files with 115 additions and 19 deletions

View File

@@ -0,0 +1,15 @@
"""Test token counting with the default method in langchain.schema."""
from langchain.schema import _get_num_tokens_default_method
class TestTokenCountingWithGPT2Tokenizer:
    """Token-count checks for ``_get_num_tokens_default_method``.

    The class name indicates the default method is backed by a GPT-2
    tokenizer; the expected counts below are pinned to its current output.
    """

    def test_empty_token(self) -> None:
        """An empty string is counted as zero tokens."""
        token_count = _get_num_tokens_default_method("")
        assert token_count == 0

    def test_multiple_tokens(self) -> None:
        """Three space-separated letters are counted as three tokens."""
        token_count = _get_num_tokens_default_method("a b c")
        assert token_count == 3

    def test_special_tokens(self) -> None:
        """Text containing punctuation keeps its pinned token count."""
        # Pinned value: guards consistency if the default tokenizer is changed.
        token_count = _get_num_tokens_default_method("a:b_c d")
        assert token_count == 6

View File

@@ -0,0 +1,77 @@
"""Test message helper functions in langchain.schema."""
import unittest
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage,
get_buffer_string,
messages_from_dict,
messages_to_dict,
)
class TestGetBufferString(unittest.TestCase):
    """Checks for ``get_buffer_string`` using one message of each role."""

    # Shared fixtures: one message per role, reused by every test below.
    human_msg: HumanMessage = HumanMessage(content="human")
    ai_msg: AIMessage = AIMessage(content="ai")
    sys_msg: SystemMessage = SystemMessage(content="sys")

    def test_empty_input(self) -> None:
        """An empty message list renders as an empty string."""
        rendered = get_buffer_string([])
        self.assertEqual(rendered, "")

    def test_valid_single_message(self) -> None:
        """A lone human message is rendered with the ``Human`` prefix."""
        rendered = get_buffer_string([self.human_msg])
        wanted = f"Human: {self.human_msg.content}"
        self.assertEqual(rendered, wanted)

    def test_custom_human_prefix(self) -> None:
        """``human_prefix`` replaces the prefix used for human messages."""
        prefix = "H"
        rendered = get_buffer_string([self.human_msg], human_prefix=prefix)
        wanted = f"{prefix}: {self.human_msg.content}"
        self.assertEqual(rendered, wanted)

    def test_custom_ai_prefix(self) -> None:
        """``ai_prefix`` replaces the prefix used for AI messages."""
        prefix = "A"
        rendered = get_buffer_string([self.ai_msg], ai_prefix=prefix)
        wanted = f"{prefix}: {self.ai_msg.content}"
        self.assertEqual(rendered, wanted)

    def test_multiple_msg(self) -> None:
        """Several messages are rendered one per line, in input order."""
        conversation = [self.human_msg, self.ai_msg, self.sys_msg]
        wanted_lines = [
            f"Human: {self.human_msg.content}",
            f"AI: {self.ai_msg.content}",
            f"System: {self.sys_msg.content}",
        ]
        wanted = "\n".join(wanted_lines)
        rendered = get_buffer_string(conversation)
        self.assertEqual(rendered, wanted)
class TestMessageDictConversion(unittest.TestCase):
    """Round-trip check for ``messages_to_dict`` / ``messages_from_dict``."""

    # The human message carries extra kwargs so the round trip covers them too.
    human_msg: HumanMessage = HumanMessage(
        content="human",
        additional_kwargs={"key": "value"},
    )
    ai_msg: AIMessage = AIMessage(content="ai")
    sys_msg: SystemMessage = SystemMessage(content="sys")

    def test_multiple_msg(self) -> None:
        """Converting to dicts and back yields messages equal to the input."""
        originals = [self.human_msg, self.ai_msg, self.sys_msg]
        serialized = messages_to_dict(originals)
        restored = messages_from_dict(serialized)
        self.assertEqual(restored, originals)