mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 02:50:47 +00:00
Harrison/add huggingface hub (#23)
Add support for huggingface hub I could not find a good way to enforce stop tokens over the huggingface hub api - that needs to hopefully be cleaned up in the future
This commit is contained in:
@@ -1,17 +0,0 @@
|
||||
"""Test helper functions for Cohere API."""
|
||||
|
||||
from langchain.llms.cohere import remove_stop_tokens
|
||||
|
||||
|
||||
def test_remove_stop_tokens() -> None:
|
||||
"""Test removing stop tokens when they occur."""
|
||||
text = "foo bar baz"
|
||||
output = remove_stop_tokens(text, ["moo", "baz"])
|
||||
assert output == "foo bar "
|
||||
|
||||
|
||||
def test_remove_stop_tokens_none() -> None:
|
||||
"""Test removing stop tokens when they do not occur."""
|
||||
text = "foo bar baz"
|
||||
output = remove_stop_tokens(text, ["moo"])
|
||||
assert output == "foo bar baz"
|
19
tests/unit_tests/llms/test_utils.py
Normal file
19
tests/unit_tests/llms/test_utils.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Test LLM utility functions."""
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
|
||||
def test_enforce_stop_tokens() -> None:
|
||||
"""Test removing stop tokens when they occur."""
|
||||
text = "foo bar baz"
|
||||
output = enforce_stop_tokens(text, ["moo", "baz"])
|
||||
assert output == "foo bar "
|
||||
text = "foo bar baz"
|
||||
output = enforce_stop_tokens(text, ["moo", "baz", "bar"])
|
||||
assert output == "foo "
|
||||
|
||||
|
||||
def test_enforce_stop_tokens_none() -> None:
|
||||
"""Test removing stop tokens when they do not occur."""
|
||||
text = "foo bar baz"
|
||||
output = enforce_stop_tokens(text, ["moo"])
|
||||
assert output == "foo bar baz"
|
Reference in New Issue
Block a user