mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-02 03:26:17 +00:00
Harrison/add huggingface hub (#23)
Add support for huggingface hub I could not find a good way to enforce stop tokens over the huggingface hub api - that needs to hopefully be cleaned up in the future
This commit is contained in:
19
tests/integration_tests/llms/test_huggingface_hub.py
Normal file
19
tests/integration_tests/llms/test_huggingface_hub.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Test HuggingFace API wrapper."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.llms.huggingface_hub import HuggingFaceHub
|
||||
|
||||
|
||||
def test_huggingface_call() -> None:
|
||||
"""Test valid call to HuggingFace."""
|
||||
llm = HuggingFaceHub(max_new_tokens=10)
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_huggingface_call_error() -> None:
|
||||
"""Test valid call to HuggingFace that errors."""
|
||||
llm = HuggingFaceHub(max_new_tokens=-1)
|
||||
with pytest.raises(ValueError):
|
||||
llm("Say foo:")
|
@@ -1,17 +0,0 @@
|
||||
"""Test helper functions for Cohere API."""
|
||||
|
||||
from langchain.llms.cohere import remove_stop_tokens
|
||||
|
||||
|
||||
def test_remove_stop_tokens() -> None:
|
||||
"""Test removing stop tokens when they occur."""
|
||||
text = "foo bar baz"
|
||||
output = remove_stop_tokens(text, ["moo", "baz"])
|
||||
assert output == "foo bar "
|
||||
|
||||
|
||||
def test_remove_stop_tokens_none() -> None:
|
||||
"""Test removing stop tokens when they do not occur."""
|
||||
text = "foo bar baz"
|
||||
output = remove_stop_tokens(text, ["moo"])
|
||||
assert output == "foo bar baz"
|
19
tests/unit_tests/llms/test_utils.py
Normal file
19
tests/unit_tests/llms/test_utils.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Test LLM utility functions."""
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
|
||||
def test_enforce_stop_tokens() -> None:
|
||||
"""Test removing stop tokens when they occur."""
|
||||
text = "foo bar baz"
|
||||
output = enforce_stop_tokens(text, ["moo", "baz"])
|
||||
assert output == "foo bar "
|
||||
text = "foo bar baz"
|
||||
output = enforce_stop_tokens(text, ["moo", "baz", "bar"])
|
||||
assert output == "foo "
|
||||
|
||||
|
||||
def test_enforce_stop_tokens_none() -> None:
|
||||
"""Test removing stop tokens when they do not occur."""
|
||||
text = "foo bar baz"
|
||||
output = enforce_stop_tokens(text, ["moo"])
|
||||
assert output == "foo bar baz"
|
Reference in New Issue
Block a user