[simple][test] Added test case for schema.py (#3692)
- Added unit tests for schema.py covering the utility functions and token counting.
- Fixed a nit: based on the Hugging Face docs, the tokenizer model is GPT-2, not GPT-3 ([link](https://huggingface.co/transformers/v4.8.2/_modules/transformers/models/gpt2/tokenization_gpt2_fast.html)).
- `make lint && make format` pass locally.
- Screenshot of the new tests running: ![Screenshot 2023-04-27 at 9 51 55 PM](https://user-images.githubusercontent.com/62768671/235057441-c0ac3406-9541-453f-ba14-3ebb08656114.png)
parent 15b92d361d
commit b588446bf9
langchain/schema.py:

```diff
@@ -181,6 +181,28 @@ class PromptValue(BaseModel, ABC):
         """Return prompt as messages."""
 
 
+def _get_num_tokens_default_method(text: str) -> int:
+    """Get the number of tokens present in the text."""
+    # TODO: this method may not be exact.
+    # TODO: this method may differ based on model (eg codex).
+    try:
+        from transformers import GPT2TokenizerFast
+    except ImportError:
+        raise ValueError(
+            "Could not import transformers python package. "
+            "This is needed in order to calculate get_num_tokens. "
+            "Please install it with `pip install transformers`."
+        )
+    # create a GPT-2 tokenizer instance
+    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+
+    # tokenize the text using the GPT-2 tokenizer
+    tokenized_text = tokenizer.tokenize(text)
+
+    # calculate the number of tokens in the tokenized text
+    return len(tokenized_text)
+
+
 class BaseLanguageModel(BaseModel, ABC):
     @abstractmethod
    def generate_prompt(
```
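For context on what the new helper does in practice, here is a minimal usage sketch (not part of the diff), assuming the `transformers` package is installed; the expected counts match the integration tests added below.

```python
# Minimal usage sketch for the new helper -- not part of the diff.
# Assumes `pip install transformers` (the helper raises ValueError otherwise).
from langchain.schema import _get_num_tokens_default_method

print(_get_num_tokens_default_method(""))       # 0: empty string yields no tokens
print(_get_num_tokens_default_method("a b c"))  # 3: one GPT-2 token per word here
```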
langchain/schema.py (continued):

```diff
@@ -195,25 +217,7 @@ class BaseLanguageModel(BaseModel, ABC):
         """Take in a list of prompt values and return an LLMResult."""
 
     def get_num_tokens(self, text: str) -> int:
         """Get the number of tokens present in the text."""
-        # TODO: this method may not be exact.
-        # TODO: this method may differ based on model (eg codex).
-        try:
-            from transformers import GPT2TokenizerFast
-        except ImportError:
-            raise ValueError(
-                "Could not import transformers python package. "
-                "This is needed in order to calculate get_num_tokens. "
-                "Please install it with `pip install transformers`."
-            )
-        # create a GPT-3 tokenizer instance
-        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-
-        # tokenize the text using the GPT-3 tokenizer
-        tokenized_text = tokenizer.tokenize(text)
-
-        # calculate the number of tokens in the tokenized text
-        return len(tokenized_text)
+        return _get_num_tokens_default_method(text)
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
         """Get the number of tokens in the message."""
```
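Pulling the body out to module scope is what makes it testable without instantiating an abstract model: `get_num_tokens` is now a thin wrapper, so any subclass that does not override it gets the GPT-2 count. A sketch of the equivalence (the expected value assumes the `gpt2` vocabulary, under which "hello world" encodes to two tokens):

```python
from langchain.schema import _get_num_tokens_default_method

# Any BaseLanguageModel subclass that keeps the default get_num_tokens
# now returns exactly this value.
assert _get_num_tokens_default_method("hello world") == 2
```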
tests/integration_tests/test_schema.py (new file, 15 lines):

```diff
@@ -0,0 +1,15 @@
+"""Test formatting functionality."""
+
+from langchain.schema import _get_num_tokens_default_method
+
+
+class TestTokenCountingWithGPT2Tokenizer:
+    def test_empty_token(self) -> None:
+        assert _get_num_tokens_default_method("") == 0
+
+    def test_multiple_tokens(self) -> None:
+        assert _get_num_tokens_default_method("a b c") == 3
+
+    def test_special_tokens(self) -> None:
+        # test for consistency when the default tokenizer is changed
+        assert _get_num_tokens_default_method("a:b_c d") == 6
```
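The `test_special_tokens` expectation of 6 follows from GPT-2's byte-pair pre-tokenization, which splits at the punctuation characters. A sketch of what the tokenizer emits for that input (token strings shown are what the `gpt2` vocabulary is expected to produce, with `Ġ` marking a leading space):

```python
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# "a:b_c d" splits at ':' and '_', and ' d' becomes a single token:
print(tokenizer.tokenize("a:b_c d"))  # ['a', ':', 'b', '_', 'c', 'Ġd'] -> 6 tokens
```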
tests/unit_tests/test_schema.py (new file, 77 lines):

```diff
@@ -0,0 +1,77 @@
+"""Test formatting functionality."""
+
+import unittest
+
+from langchain.schema import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+    get_buffer_string,
+    messages_from_dict,
+    messages_to_dict,
+)
+
+
+class TestGetBufferString(unittest.TestCase):
+    human_msg: HumanMessage = HumanMessage(content="human")
+    ai_msg: AIMessage = AIMessage(content="ai")
+    sys_msg: SystemMessage = SystemMessage(content="sys")
+
+    def test_empty_input(self) -> None:
+        self.assertEqual(get_buffer_string([]), "")
+
+    def test_valid_single_message(self) -> None:
+        expected_output = f"Human: {self.human_msg.content}"
+        self.assertEqual(
+            get_buffer_string([self.human_msg]),
+            expected_output,
+        )
+
+    def test_custom_human_prefix(self) -> None:
+        prefix = "H"
+        expected_output = f"{prefix}: {self.human_msg.content}"
+        self.assertEqual(
+            get_buffer_string([self.human_msg], human_prefix="H"),
+            expected_output,
+        )
+
+    def test_custom_ai_prefix(self) -> None:
+        prefix = "A"
+        expected_output = f"{prefix}: {self.ai_msg.content}"
+        self.assertEqual(
+            get_buffer_string([self.ai_msg], ai_prefix="A"),
+            expected_output,
+        )
+
+    def test_multiple_msg(self) -> None:
+        msgs = [self.human_msg, self.ai_msg, self.sys_msg]
+        expected_output = "\n".join(
+            [
+                f"Human: {self.human_msg.content}",
+                f"AI: {self.ai_msg.content}",
+                f"System: {self.sys_msg.content}",
+            ]
+        )
+        self.assertEqual(
+            get_buffer_string(msgs),
+            expected_output,
+        )
+
+
+class TestMessageDictConversion(unittest.TestCase):
+    human_msg: HumanMessage = HumanMessage(
+        content="human", additional_kwargs={"key": "value"}
+    )
+    ai_msg: AIMessage = AIMessage(content="ai")
+    sys_msg: SystemMessage = SystemMessage(content="sys")
+
+    def test_multiple_msg(self) -> None:
+        msgs = [
+            self.human_msg,
+            self.ai_msg,
+            self.sys_msg,
+        ]
+        self.assertEqual(
+            messages_from_dict(messages_to_dict(msgs)),
+            msgs,
+        )
```
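For readers unfamiliar with the helpers under test, a short sketch of what they produce (the serialized dict layout is illustrative; the exact keys come from `messages_to_dict` in this version of `langchain.schema`):

```python
from langchain.schema import AIMessage, HumanMessage, get_buffer_string, messages_to_dict

msgs = [HumanMessage(content="hi"), AIMessage(content="hello")]

# get_buffer_string joins messages with role prefixes, one per line:
print(get_buffer_string(msgs))
# Human: hi
# AI: hello

# messages_to_dict serializes each message to a plain dict; the first
# entry's type is "human", and messages_from_dict inverts the round trip.
print(messages_to_dict(msgs)[0]["type"])
```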