Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-16 15:04:13 +00:00)
community[patch]: upgrade to recent version of mypy (#21616)
This PR upgrades langchain-community to a recent version of mypy. It inserts error-code-scoped # type: ignore comments on all existing failures so the package passes the stricter checks.
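For context, a minimal sketch of the pattern being applied (the function and variable names below are hypothetical, not taken from this diff): os.getenv() returns Optional[str], so passing it straight to a parameter typed as str is an arg-type error under mypy, and an error-code-scoped ignore silences only that specific code on that line.

import os


def connect(endpoint_url: str) -> str:
    # Stand-in for a client constructor that requires a non-optional URL.
    return f"connected to {endpoint_url}"


# Clean: the default value makes the result a plain str.
print(connect(os.getenv("EXAMPLE_ENDPOINT_URL", "https://example.invalid")))

# mypy would report: Argument 1 to "connect" has incompatible type "Optional[str]";
# expected "str"  [arg-type] -- the bracketed code suppresses only that error, while
# other error categories on the same line would still be reported.
print(connect(os.getenv("EXAMPLE_ENDPOINT_URL")))  # type: ignore[arg-type]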
@@ -5,6 +5,6 @@ from langchain_community.llms.aleph_alpha import AlephAlpha

 def test_aleph_alpha_call() -> None:
     """Test valid call to cohere."""
-    llm = AlephAlpha(maximum_tokens=10)
+    llm = AlephAlpha(maximum_tokens=10)  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)
@@ -17,20 +17,20 @@ def test_anthropic_model_name_param() -> None:

 @pytest.mark.requires("anthropic")
 def test_anthropic_model_param() -> None:
-    llm = Anthropic(model="foo")
+    llm = Anthropic(model="foo")  # type: ignore[call-arg]
     assert llm.model == "foo"


 def test_anthropic_call() -> None:
     """Test valid call to anthropic."""
-    llm = Anthropic(model="claude-instant-1")
+    llm = Anthropic(model="claude-instant-1")  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)


 def test_anthropic_streaming() -> None:
     """Test streaming tokens from anthropic."""
-    llm = Anthropic(model="claude-instant-1")
+    llm = Anthropic(model="claude-instant-1")  # type: ignore[call-arg]
     generator = llm.stream("I'm Pickle Rick")

     assert isinstance(generator, Generator)
@@ -22,9 +22,9 @@ from langchain_community.llms.loading import load_llm
 def test_gpt2_call() -> None:
     """Test valid call to GPT2."""
     llm = AzureMLOnlineEndpoint(
-        endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
-        endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
-        deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
+        endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),  # type: ignore[arg-type]
+        endpoint_url=os.getenv("OSS_ENDPOINT_URL"),  # type: ignore[arg-type]
+        deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),  # type: ignore[arg-type]
         content_formatter=OSSContentFormatter(),
     )
     output = llm.invoke("Foo")
@@ -34,9 +34,9 @@ def test_gpt2_call() -> None:
 def test_hf_call() -> None:
     """Test valid call to HuggingFace Foundation Model."""
     llm = AzureMLOnlineEndpoint(
-        endpoint_api_key=os.getenv("HF_ENDPOINT_API_KEY"),
-        endpoint_url=os.getenv("HF_ENDPOINT_URL"),
-        deployment_name=os.getenv("HF_DEPLOYMENT_NAME"),
+        endpoint_api_key=os.getenv("HF_ENDPOINT_API_KEY"),  # type: ignore[arg-type]
+        endpoint_url=os.getenv("HF_ENDPOINT_URL"),  # type: ignore[arg-type]
+        deployment_name=os.getenv("HF_DEPLOYMENT_NAME"),  # type: ignore[arg-type]
         content_formatter=HFContentFormatter(),
     )
     output = llm.invoke("Foo")
@@ -46,9 +46,9 @@ def test_hf_call() -> None:
 def test_dolly_call() -> None:
     """Test valid call to dolly-v2."""
     llm = AzureMLOnlineEndpoint(
-        endpoint_api_key=os.getenv("DOLLY_ENDPOINT_API_KEY"),
-        endpoint_url=os.getenv("DOLLY_ENDPOINT_URL"),
-        deployment_name=os.getenv("DOLLY_DEPLOYMENT_NAME"),
+        endpoint_api_key=os.getenv("DOLLY_ENDPOINT_API_KEY"),  # type: ignore[arg-type]
+        endpoint_url=os.getenv("DOLLY_ENDPOINT_URL"),  # type: ignore[arg-type]
+        deployment_name=os.getenv("DOLLY_DEPLOYMENT_NAME"),  # type: ignore[arg-type]
         content_formatter=DollyContentFormatter(),
     )
     output = llm.invoke("Foo")
@@ -77,9 +77,9 @@ def test_custom_formatter() -> None:
             return response_json[0]["summary_text"]

     llm = AzureMLOnlineEndpoint(
-        endpoint_api_key=os.getenv("BART_ENDPOINT_API_KEY"),
-        endpoint_url=os.getenv("BART_ENDPOINT_URL"),
-        deployment_name=os.getenv("BART_DEPLOYMENT_NAME"),
+        endpoint_api_key=os.getenv("BART_ENDPOINT_API_KEY"),  # type: ignore[arg-type]
+        endpoint_url=os.getenv("BART_ENDPOINT_URL"),  # type: ignore[arg-type]
+        deployment_name=os.getenv("BART_DEPLOYMENT_NAME"),  # type: ignore[arg-type]
         content_formatter=CustomFormatter(),
     )
     output = llm.invoke("Foo")
@@ -90,9 +90,9 @@ def test_missing_content_formatter() -> None:
     """Test AzureML LLM without a content_formatter attribute"""
     with pytest.raises(AttributeError):
         llm = AzureMLOnlineEndpoint(
-            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
-            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
-            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
+            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),  # type: ignore[arg-type]
+            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),  # type: ignore[arg-type]
+            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),  # type: ignore[arg-type]
         )
         llm.invoke("Foo")

@@ -119,9 +119,9 @@ def test_invalid_request_format() -> None:

     with pytest.raises(HTTPError):
         llm = AzureMLOnlineEndpoint(
-            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
-            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
-            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
+            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),  # type: ignore[arg-type]
+            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),  # type: ignore[arg-type]
+            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),  # type: ignore[arg-type]
             content_formatter=CustomContentFormatter(),
         )
         llm.invoke("Foo")
@@ -131,9 +131,9 @@ def test_incorrect_url() -> None:
     """Testing AzureML Endpoint for an incorrect URL"""
     with pytest.raises(ValidationError):
         llm = AzureMLOnlineEndpoint(
-            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
+            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),  # type: ignore[arg-type]
             endpoint_url="https://endpoint.inference.com",
-            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
+            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),  # type: ignore[arg-type]
             content_formatter=OSSContentFormatter(),
         )
         llm.invoke("Foo")
@@ -142,10 +142,10 @@ def test_incorrect_url() -> None:
 def test_incorrect_api_type() -> None:
     with pytest.raises(ValidationError):
         llm = AzureMLOnlineEndpoint(
-            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
-            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
-            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
-            endpoint_api_type="serverless",
+            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),  # type: ignore[arg-type]
+            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),  # type: ignore[arg-type]
+            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),  # type: ignore[arg-type]
+            endpoint_api_type="serverless",  # type: ignore[arg-type]
             content_formatter=OSSContentFormatter(),
         )
         llm.invoke("Foo")
@@ -155,9 +155,9 @@ def test_incorrect_key() -> None:
     """Testing AzureML Endpoint for incorrect key"""
     with pytest.raises(HTTPError):
         llm = AzureMLOnlineEndpoint(
-            endpoint_api_key="incorrect-key",
-            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
-            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
+            endpoint_api_key="incorrect-key",  # type: ignore[arg-type]
+            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),  # type: ignore[arg-type]
+            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),  # type: ignore[arg-type]
             content_formatter=OSSContentFormatter(),
         )
         llm.invoke("Foo")
@@ -8,6 +8,6 @@ from langchain_community.llms.baseten import Baseten

 def test_baseten_call() -> None:
     """Test valid call to Baseten."""
-    llm = Baseten(model=os.environ["BASETEN_MODEL_ID"])
+    llm = Baseten(model=os.environ["BASETEN_MODEL_ID"])  # type: ignore[call-arg]
     output = llm.invoke("Test prompt, please respond.")
     assert isinstance(output, str)
@@ -8,7 +8,7 @@ def test_beam_call() -> None:
     llm = Beam(
         model_name="gpt2",
         name="langchain-gpt2",
-        cpu=8,
+        cpu=8,  # type: ignore[arg-type]
         memory="32Gi",
         gpu="A10G",
         python_version="python3.8",
@@ -5,6 +5,6 @@ from langchain_community.llms.cerebriumai import CerebriumAI

 def test_cerebriumai_call() -> None:
     """Test valid call to cerebriumai."""
-    llm = CerebriumAI(max_length=10)
+    llm = CerebriumAI(max_length=10)  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)
@@ -12,7 +12,7 @@ from tests.integration_tests.llms.utils import assert_llm_equality

 def test_cohere_call() -> None:
     """Test valid call to cohere."""
-    llm = Cohere(max_tokens=10)
+    llm = Cohere(max_tokens=10)  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)

@@ -20,16 +20,16 @@ def test_cohere_call() -> None:
 def test_cohere_api_key(monkeypatch: MonkeyPatch) -> None:
     """Test that cohere api key is a secret key."""
     # test initialization from init
-    assert isinstance(Cohere(cohere_api_key="1").cohere_api_key, SecretStr)
+    assert isinstance(Cohere(cohere_api_key="1").cohere_api_key, SecretStr)  # type: ignore[arg-type, call-arg]

     # test initialization from env variable
     monkeypatch.setenv("COHERE_API_KEY", "secret-api-key")
-    assert isinstance(Cohere().cohere_api_key, SecretStr)
+    assert isinstance(Cohere().cohere_api_key, SecretStr)  # type: ignore[call-arg]


 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading an Cohere LLM."""
-    llm = Cohere(max_tokens=10)
+    llm = Cohere(max_tokens=10)  # type: ignore[call-arg]
     llm.save(file_path=tmp_path / "cohere.yaml")
     loaded_llm = load_llm(tmp_path / "cohere.yaml")
     assert_llm_equality(llm, loaded_llm)
@@ -5,6 +5,6 @@ from langchain_community.llms.forefrontai import ForefrontAI

 def test_forefrontai_call() -> None:
     """Test valid call to forefrontai."""
-    llm = ForefrontAI(length=10)
+    llm = ForefrontAI(length=10)  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)
@@ -22,9 +22,9 @@ model_names = [None, "models/text-bison-001", "gemini-pro"]
 def test_google_generativeai_call(model_name: str) -> None:
     """Test valid call to Google GenerativeAI text API."""
     if model_name:
-        llm = GooglePalm(max_output_tokens=10, model_name=model_name)
+        llm = GooglePalm(max_output_tokens=10, model_name=model_name)  # type: ignore[call-arg]
     else:
-        llm = GooglePalm(max_output_tokens=10)
+        llm = GooglePalm(max_output_tokens=10)  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)
     assert llm._llm_type == "google_palm"
@@ -41,9 +41,9 @@ def test_google_generativeai_call(model_name: str) -> None:
 def test_google_generativeai_generate(model_name: str) -> None:
     n = 1 if model_name == "gemini-pro" else 2
     if model_name:
-        llm = GooglePalm(temperature=0.3, n=n, model_name=model_name)
+        llm = GooglePalm(temperature=0.3, n=n, model_name=model_name)  # type: ignore[call-arg]
     else:
-        llm = GooglePalm(temperature=0.3, n=n)
+        llm = GooglePalm(temperature=0.3, n=n)  # type: ignore[call-arg]
     output = llm.generate(["Say foo:"])
     assert isinstance(output, LLMResult)
     assert len(output.generations) == 1
@@ -51,26 +51,26 @@ def test_google_generativeai_generate(model_name: str) -> None:


 def test_google_generativeai_get_num_tokens() -> None:
-    llm = GooglePalm()
+    llm = GooglePalm()  # type: ignore[call-arg]
     output = llm.get_num_tokens("How are you?")
     assert output == 4


 async def test_google_generativeai_agenerate() -> None:
-    llm = GooglePalm(temperature=0, model_name="gemini-pro")
+    llm = GooglePalm(temperature=0, model_name="gemini-pro")  # type: ignore[call-arg]
     output = await llm.agenerate(["Please say foo:"])
     assert isinstance(output, LLMResult)


 def test_generativeai_stream() -> None:
-    llm = GooglePalm(temperature=0, model_name="gemini-pro")
+    llm = GooglePalm(temperature=0, model_name="gemini-pro")  # type: ignore[call-arg]
     outputs = list(llm.stream("Please say foo:"))
     assert isinstance(outputs[0], str)


 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading a Google PaLM LLM."""
-    llm = GooglePalm(max_output_tokens=10)
+    llm = GooglePalm(max_output_tokens=10)  # type: ignore[call-arg]
     llm.save(file_path=tmp_path / "google_palm.yaml")
     loaded_llm = load_llm(tmp_path / "google_palm.yaml")
     assert loaded_llm == llm
@@ -5,14 +5,14 @@ from langchain_community.llms.gooseai import GooseAI

 def test_gooseai_call() -> None:
     """Test valid call to gooseai."""
-    llm = GooseAI(max_tokens=10)
+    llm = GooseAI(max_tokens=10)  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)


 def test_gooseai_call_fairseq() -> None:
     """Test valid call to gooseai with fairseq model."""
-    llm = GooseAI(model_name="fairseq-1-3b", max_tokens=10)
+    llm = GooseAI(model_name="fairseq-1-3b", max_tokens=10)  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)

@@ -20,9 +20,9 @@ def test_gooseai_call_fairseq() -> None:
 def test_gooseai_stop_valid() -> None:
     """Test gooseai stop logic on valid configuration."""
     query = "write an ordered list of five items"
-    first_llm = GooseAI(stop="3", temperature=0)
+    first_llm = GooseAI(stop="3", temperature=0)  # type: ignore[call-arg]
     first_output = first_llm.invoke(query)
-    second_llm = GooseAI(temperature=0)
+    second_llm = GooseAI(temperature=0)  # type: ignore[call-arg]
     second_output = second_llm.invoke(query, stop=["3"])
     # Because it stops on new lines, shouldn't return anything
     assert first_output == second_output
@@ -11,14 +11,14 @@ from tests.integration_tests.llms.utils import assert_llm_equality

 def test_huggingface_endpoint_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
-    llm = HuggingFaceEndpoint(endpoint_url="", model_kwargs={"max_new_tokens": -1})
+    llm = HuggingFaceEndpoint(endpoint_url="", model_kwargs={"max_new_tokens": -1})  # type: ignore[call-arg]
     with pytest.raises(ValueError):
         llm.invoke("Say foo:")


 def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
     """Test saving/loading an HuggingFaceHub LLM."""
-    llm = HuggingFaceEndpoint(
+    llm = HuggingFaceEndpoint(  # type: ignore[call-arg]
         endpoint_url="", task="text-generation", model_kwargs={"max_new_tokens": 10}
     )
     llm.save(file_path=tmp_path / "hf.yaml")
@@ -28,7 +28,7 @@ def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:

 def test_huggingface_text_generation() -> None:
     """Test valid call to HuggingFace text generation model."""
-    llm = HuggingFaceEndpoint(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})
+    llm = HuggingFaceEndpoint(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     print(output)  # noqa: T201
     assert isinstance(output, str)
@@ -36,35 +36,35 @@ def test_huggingface_text_generation() -> None:

 def test_huggingface_text2text_generation() -> None:
     """Test valid call to HuggingFace text2text model."""
-    llm = HuggingFaceEndpoint(repo_id="google/flan-t5-xl")
+    llm = HuggingFaceEndpoint(repo_id="google/flan-t5-xl")  # type: ignore[call-arg]
     output = llm.invoke("The capital of New York is")
     assert output == "Albany"


 def test_huggingface_summarization() -> None:
     """Test valid call to HuggingFace summarization model."""
-    llm = HuggingFaceEndpoint(repo_id="facebook/bart-large-cnn")
+    llm = HuggingFaceEndpoint(repo_id="facebook/bart-large-cnn")  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)


 def test_huggingface_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
-    llm = HuggingFaceEndpoint(repo_id="gpt2", model_kwargs={"max_new_tokens": -1})
+    llm = HuggingFaceEndpoint(repo_id="gpt2", model_kwargs={"max_new_tokens": -1})  # type: ignore[call-arg]
     with pytest.raises(ValueError):
         llm.invoke("Say foo:")


 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading an HuggingFaceEndpoint LLM."""
-    llm = HuggingFaceEndpoint(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})
+    llm = HuggingFaceEndpoint(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})  # type: ignore[call-arg]
     llm.save(file_path=tmp_path / "hf.yaml")
     loaded_llm = load_llm(tmp_path / "hf.yaml")
     assert_llm_equality(llm, loaded_llm)


 def test_invocation_params_stop_sequences() -> None:
-    llm = HuggingFaceEndpoint()
+    llm = HuggingFaceEndpoint()  # type: ignore[call-arg]
     assert llm._default_params["stop_sequences"] == []

     runtime_stop = None
@@ -75,7 +75,7 @@ def test_invocation_params_stop_sequences() -> None:
     assert llm._invocation_params(runtime_stop)["stop_sequences"] == ["stop"]
     assert llm._default_params["stop_sequences"] == []

-    llm = HuggingFaceEndpoint(stop_sequences=["."])
+    llm = HuggingFaceEndpoint(stop_sequences=["."])  # type: ignore[call-arg]
     runtime_stop = ["stop"]
     assert llm._invocation_params(runtime_stop)["stop_sequences"] == [".", "stop"]
     assert llm._default_params["stop_sequences"] == ["."]
@@ -11,35 +11,35 @@ from tests.integration_tests.llms.utils import assert_llm_equality

 def test_huggingface_text_generation() -> None:
     """Test valid call to HuggingFace text generation model."""
-    llm = HuggingFaceHub(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})
+    llm = HuggingFaceHub(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)


 def test_huggingface_text2text_generation() -> None:
     """Test valid call to HuggingFace text2text model."""
-    llm = HuggingFaceHub(repo_id="google/flan-t5-xl")
+    llm = HuggingFaceHub(repo_id="google/flan-t5-xl")  # type: ignore[call-arg]
     output = llm.invoke("The capital of New York is")
     assert output == "Albany"


 def test_huggingface_summarization() -> None:
     """Test valid call to HuggingFace summarization model."""
-    llm = HuggingFaceHub(repo_id="facebook/bart-large-cnn")
+    llm = HuggingFaceHub(repo_id="facebook/bart-large-cnn")  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)


 def test_huggingface_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
-    llm = HuggingFaceHub(model_kwargs={"max_new_tokens": -1})
+    llm = HuggingFaceHub(model_kwargs={"max_new_tokens": -1})  # type: ignore[call-arg]
     with pytest.raises(ValueError):
         llm.invoke("Say foo:")


 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading an HuggingFaceHub LLM."""
-    llm = HuggingFaceHub(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})
+    llm = HuggingFaceHub(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})  # type: ignore[call-arg]
     llm.save(file_path=tmp_path / "hf.yaml")
     loaded_llm = load_llm(tmp_path / "hf.yaml")
     assert_llm_equality(llm, loaded_llm)
@@ -24,7 +24,7 @@ class MockLLM(LLM):

 def test_layerup_security_with_invalid_api_key() -> None:
     mock_llm = MockLLM()
-    layerup_security = LayerupSecurity(
+    layerup_security = LayerupSecurity(  # type: ignore[call-arg]
         llm=mock_llm,
         layerup_api_key="-- invalid API key --",
         layerup_api_base_url="https://api.uselayerup.com/v1",
@@ -4,14 +4,14 @@ from langchain_community.llms.minimax import Minimax

 def test_minimax_call() -> None:
     """Test valid call to minimax."""
-    llm = Minimax(max_tokens=10)
+    llm = Minimax(max_tokens=10)  # type: ignore[call-arg]
     output = llm.invoke("Hello world!")
     assert isinstance(output, str)


 def test_minimax_call_successful() -> None:
     """Test valid call to minimax."""
-    llm = Minimax()
+    llm = Minimax()  # type: ignore[call-arg]
     output = llm.invoke(
         "A chain is a serial assembly of connected pieces, called links, \
         typically made of metal, with an overall character similar to that\
@@ -13,14 +13,14 @@ from tests.integration_tests.llms.utils import assert_llm_equality

 def test_nlpcloud_call() -> None:
     """Test valid call to nlpcloud."""
-    llm = NLPCloud(max_length=10)
+    llm = NLPCloud(max_length=10)  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)


 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading an NLPCloud LLM."""
-    llm = NLPCloud(max_length=10)
+    llm = NLPCloud(max_length=10)  # type: ignore[call-arg]
     llm.save(file_path=tmp_path / "nlpcloud.yaml")
     loaded_llm = load_llm(tmp_path / "nlpcloud.yaml")
     assert_llm_equality(llm, loaded_llm)
@@ -29,10 +29,10 @@ def test_saving_loading_llm(tmp_path: Path) -> None:
 def test_nlpcloud_api_key(monkeypatch: MonkeyPatch, capsys: CaptureFixture) -> None:
     """Test that nlpcloud api key is a secret key."""
     # test initialization from init
-    assert isinstance(NLPCloud(nlpcloud_api_key="1").nlpcloud_api_key, SecretStr)
+    assert isinstance(NLPCloud(nlpcloud_api_key="1").nlpcloud_api_key, SecretStr)  # type: ignore[arg-type, call-arg]

     monkeypatch.setenv("NLPCLOUD_API_KEY", "secret-api-key")
-    llm = NLPCloud()
+    llm = NLPCloud()  # type: ignore[call-arg]
     assert isinstance(llm.nlpcloud_api_key, SecretStr)

     assert cast(SecretStr, llm.nlpcloud_api_key).get_secret_value() == "secret-api-key"
@@ -43,7 +43,7 @@ Question: ```{question}```


 def test_opaqueprompts() -> None:
-    chain = PromptTemplate.from_template(prompt_template) | OpaquePrompts(llm=OpenAI())
+    chain = PromptTemplate.from_template(prompt_template) | OpaquePrompts(llm=OpenAI())  # type: ignore[call-arg]
     output = chain.invoke(
         {
             "question": "Write a text message to remind John to do password reset \
@@ -33,7 +33,7 @@ def test_openai_llm_output_contains_model_name() -> None:
 def test_openai_stop_valid() -> None:
     """Test openai stop logic on valid configuration."""
     query = "write an ordered list of five items"
-    first_llm = OpenAI(stop="3", temperature=0)
+    first_llm = OpenAI(stop="3", temperature=0)  # type: ignore[call-arg]
     first_output = first_llm.invoke(query)
     second_llm = OpenAI(temperature=0)
     second_output = second_llm.invoke(query, stop=["3"])
@@ -43,7 +43,7 @@ def test_openai_stop_valid() -> None:

 def test_openai_stop_error() -> None:
     """Test openai stop logic on bad configuration."""
-    llm = OpenAI(stop="3", temperature=0)
+    llm = OpenAI(stop="3", temperature=0)  # type: ignore[call-arg]
     with pytest.raises(ValueError):
         llm.invoke("write an ordered list of five items", stop=["\n"])
@@ -3,6 +3,6 @@ from langchain_community.llms.openlm import OpenLM

 def test_openlm_call() -> None:
     """Test valid call to openlm."""
-    llm = OpenLM(model_name="dolly-v2-7b", max_tokens=10)
+    llm = OpenLM(model_name="dolly-v2-7b", max_tokens=10)  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)
@@ -8,8 +8,8 @@ from langchain_community.llms.pai_eas_endpoint import PaiEasEndpoint
 def test_pai_eas_v1_call() -> None:
     """Test valid call to PAI-EAS Service."""
     llm = PaiEasEndpoint(
-        eas_service_url=os.getenv("EAS_SERVICE_URL"),
-        eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
+        eas_service_url=os.getenv("EAS_SERVICE_URL"),  # type: ignore[arg-type]
+        eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),  # type: ignore[arg-type]
         version="1.0",
     )
     output = llm.invoke("Say foo:")
@@ -18,8 +18,8 @@ def test_pai_eas_v1_call() -> None:

 def test_pai_eas_v2_call() -> None:
     llm = PaiEasEndpoint(
-        eas_service_url=os.getenv("EAS_SERVICE_URL"),
-        eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
+        eas_service_url=os.getenv("EAS_SERVICE_URL"),  # type: ignore[arg-type]
+        eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),  # type: ignore[arg-type]
         version="2.0",
     )
     output = llm.invoke("Say foo:")
@@ -29,8 +29,8 @@ def test_pai_eas_v2_call() -> None:
 def test_pai_eas_v1_streaming() -> None:
     """Test streaming call to PAI-EAS Service."""
     llm = PaiEasEndpoint(
-        eas_service_url=os.getenv("EAS_SERVICE_URL"),
-        eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
+        eas_service_url=os.getenv("EAS_SERVICE_URL"),  # type: ignore[arg-type]
+        eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),  # type: ignore[arg-type]
         version="1.0",
     )
     generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["."])
@@ -45,8 +45,8 @@ def test_pai_eas_v1_streaming() -> None:

 def test_pai_eas_v2_streaming() -> None:
     llm = PaiEasEndpoint(
-        eas_service_url=os.getenv("EAS_SERVICE_URL"),
-        eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
+        eas_service_url=os.getenv("EAS_SERVICE_URL"),  # type: ignore[arg-type]
+        eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),  # type: ignore[arg-type]
         version="2.0",
     )
     generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["."])
@@ -7,14 +7,14 @@ from langchain_community.llms.petals import Petals


 def test_api_key_is_string() -> None:
-    llm = Petals(huggingface_api_key="secret-api-key")
+    llm = Petals(huggingface_api_key="secret-api-key")  # type: ignore[arg-type, call-arg]
     assert isinstance(llm.huggingface_api_key, SecretStr)


 def test_api_key_masked_when_passed_via_constructor(
     capsys: CaptureFixture,
 ) -> None:
-    llm = Petals(huggingface_api_key="secret-api-key")
+    llm = Petals(huggingface_api_key="secret-api-key")  # type: ignore[arg-type, call-arg]
     print(llm.huggingface_api_key, end="")  # noqa: T201
     captured = capsys.readouterr()

@@ -23,6 +23,6 @@ def test_api_key_masked_when_passed_via_constructor(

 def test_gooseai_call() -> None:
     """Test valid call to gooseai."""
-    llm = Petals(max_new_tokens=10)
+    llm = Petals(max_new_tokens=10)  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)
@@ -5,6 +5,6 @@ from langchain_community.llms.predictionguard import PredictionGuard

 def test_predictionguard_call() -> None:
     """Test valid call to prediction guard."""
-    llm = PredictionGuard(model="OpenAI-text-davinci-003")
+    llm = PredictionGuard(model="OpenAI-text-davinci-003")  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)
@@ -11,7 +11,7 @@ from langchain_community.llms.promptlayer_openai import PromptLayerOpenAI

 def test_promptlayer_openai_call() -> None:
     """Test valid call to promptlayer openai."""
-    llm = PromptLayerOpenAI(max_tokens=10)
+    llm = PromptLayerOpenAI(max_tokens=10)  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)

@@ -19,25 +19,25 @@ def test_promptlayer_openai_call() -> None:
 def test_promptlayer_openai_extra_kwargs() -> None:
     """Test extra kwargs to promptlayer openai."""
     # Check that foo is saved in extra_kwargs.
-    llm = PromptLayerOpenAI(foo=3, max_tokens=10)
+    llm = PromptLayerOpenAI(foo=3, max_tokens=10)  # type: ignore[call-arg]
     assert llm.max_tokens == 10
     assert llm.model_kwargs == {"foo": 3}

     # Test that if extra_kwargs are provided, they are added to it.
-    llm = PromptLayerOpenAI(foo=3, model_kwargs={"bar": 2})
+    llm = PromptLayerOpenAI(foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
     assert llm.model_kwargs == {"foo": 3, "bar": 2}

     # Test that if provided twice it errors
     with pytest.raises(ValueError):
-        PromptLayerOpenAI(foo=3, model_kwargs={"foo": 2})
+        PromptLayerOpenAI(foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]


 def test_promptlayer_openai_stop_valid() -> None:
     """Test promptlayer openai stop logic on valid configuration."""
     query = "write an ordered list of five items"
-    first_llm = PromptLayerOpenAI(stop="3", temperature=0)
+    first_llm = PromptLayerOpenAI(stop="3", temperature=0)  # type: ignore[call-arg]
     first_output = first_llm.invoke(query)
-    second_llm = PromptLayerOpenAI(temperature=0)
+    second_llm = PromptLayerOpenAI(temperature=0)  # type: ignore[call-arg]
     second_output = second_llm.invoke(query, stop=["3"])
     # Because it stops on new lines, shouldn't return anything
     assert first_output == second_output
@@ -45,14 +45,14 @@ def test_promptlayer_openai_stop_valid() -> None:

 def test_promptlayer_openai_stop_error() -> None:
     """Test promptlayer openai stop logic on bad configuration."""
-    llm = PromptLayerOpenAI(stop="3", temperature=0)
+    llm = PromptLayerOpenAI(stop="3", temperature=0)  # type: ignore[call-arg]
     with pytest.raises(ValueError):
         llm.invoke("write an ordered list of five items", stop=["\n"])


 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading an promptlayer OpenAPI LLM."""
-    llm = PromptLayerOpenAI(max_tokens=10)
+    llm = PromptLayerOpenAI(max_tokens=10)  # type: ignore[call-arg]
     llm.save(file_path=tmp_path / "openai.yaml")
     loaded_llm = load_llm(tmp_path / "openai.yaml")
     assert loaded_llm == llm
@@ -60,7 +60,7 @@ def test_saving_loading_llm(tmp_path: Path) -> None:

 def test_promptlayer_openai_streaming() -> None:
     """Test streaming tokens from promptalyer OpenAI."""
-    llm = PromptLayerOpenAI(max_tokens=10)
+    llm = PromptLayerOpenAI(max_tokens=10)  # type: ignore[call-arg]
     generator = llm.stream("I'm Pickle Rick")

     assert isinstance(generator, Generator)
@@ -10,7 +10,7 @@ from langchain_community.llms.promptlayer_openai import PromptLayerOpenAIChat

 def test_promptlayer_openai_chat_call() -> None:
     """Test valid call to promptlayer openai."""
-    llm = PromptLayerOpenAIChat(max_tokens=10)
+    llm = PromptLayerOpenAIChat(max_tokens=10)  # type: ignore[call-arg]
     output = llm.invoke("Say foo:")
     assert isinstance(output, str)

@@ -18,9 +18,9 @@ def test_promptlayer_openai_chat_call() -> None:
 def test_promptlayer_openai_chat_stop_valid() -> None:
     """Test promptlayer openai stop logic on valid configuration."""
     query = "write an ordered list of five items"
-    first_llm = PromptLayerOpenAIChat(stop="3", temperature=0)
+    first_llm = PromptLayerOpenAIChat(stop="3", temperature=0)  # type: ignore[call-arg]
     first_output = first_llm.invoke(query)
-    second_llm = PromptLayerOpenAIChat(temperature=0)
+    second_llm = PromptLayerOpenAIChat(temperature=0)  # type: ignore[call-arg]
     second_output = second_llm.invoke(query, stop=["3"])
     # Because it stops on new lines, shouldn't return anything
     assert first_output == second_output
@@ -28,14 +28,14 @@ def test_promptlayer_openai_chat_stop_valid() -> None:

 def test_promptlayer_openai_chat_stop_error() -> None:
     """Test promptlayer openai stop logic on bad configuration."""
-    llm = PromptLayerOpenAIChat(stop="3", temperature=0)
+    llm = PromptLayerOpenAIChat(stop="3", temperature=0)  # type: ignore[call-arg]
     with pytest.raises(ValueError):
         llm.invoke("write an ordered list of five items", stop=["\n"])


 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading an promptlayer OpenAPI LLM."""
-    llm = PromptLayerOpenAIChat(max_tokens=10)
+    llm = PromptLayerOpenAIChat(max_tokens=10)  # type: ignore[call-arg]
     llm.save(file_path=tmp_path / "openai.yaml")
     loaded_llm = load_llm(tmp_path / "openai.yaml")
     assert loaded_llm == llm
@@ -8,14 +8,14 @@ from langchain_community.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint

 def test_call() -> None:
     """Test valid call to qianfan."""
-    llm = QianfanLLMEndpoint()
+    llm = QianfanLLMEndpoint()  # type: ignore[call-arg]
     output = llm.invoke("write a joke")
     assert isinstance(output, str)


 def test_generate() -> None:
     """Test valid call to qianfan."""
-    llm = QianfanLLMEndpoint()
+    llm = QianfanLLMEndpoint()  # type: ignore[call-arg]
     output = llm.generate(["write a joke"])
     assert isinstance(output, LLMResult)
     assert isinstance(output.generations, list)
@@ -23,20 +23,20 @@ def test_generate() -> None:

 def test_generate_stream() -> None:
     """Test valid call to qianfan."""
-    llm = QianfanLLMEndpoint()
+    llm = QianfanLLMEndpoint()  # type: ignore[call-arg]
     output = llm.stream("write a joke")
     assert isinstance(output, Generator)


 async def test_qianfan_aio() -> None:
-    llm = QianfanLLMEndpoint(streaming=True)
+    llm = QianfanLLMEndpoint(streaming=True)  # type: ignore[call-arg]

     async for token in llm.astream("hi qianfan."):
         assert isinstance(token, str)


 def test_rate_limit() -> None:
-    llm = QianfanLLMEndpoint(model="ERNIE-Bot", init_kwargs={"query_per_second": 2})
+    llm = QianfanLLMEndpoint(model="ERNIE-Bot", init_kwargs={"query_per_second": 2})  # type: ignore[call-arg]
     assert llm.client._client._rate_limiter._sync_limiter._query_per_second == 2
     output = llm.generate(["write a joke"])
     assert isinstance(output, LLMResult)
@@ -29,11 +29,11 @@ def test_replicate_streaming_call() -> None:

 def test_replicate_model_kwargs() -> None:
     """Test simple non-streaming call to Replicate."""
-    llm = Replicate(
+    llm = Replicate(  # type: ignore[call-arg]
         model=TEST_MODEL, model_kwargs={"max_length": 100, "temperature": 0.01}
     )
     long_output = llm.invoke("What is LangChain")
-    llm = Replicate(
+    llm = Replicate(  # type: ignore[call-arg]
         model=TEST_MODEL, model_kwargs={"max_length": 10, "temperature": 0.01}
     )
     short_output = llm.invoke("What is LangChain")
@@ -34,7 +34,7 @@ Sam: Perfect! Let's keep the momentum going. Reach out if there are any
 sudden issues or support needed. Have a productive day!
 Alex: You too.
 Rhea: Thanks, bye!"""
-llm = Nebula(nebula_api_key="<your_api_key>")
+llm = Nebula(nebula_api_key="<your_api_key>")  # type: ignore[arg-type]

 instruction = """Identify the main objectives mentioned in this
 conversation."""
@@ -7,14 +7,14 @@ from langchain_community.llms.tongyi import Tongyi

 def test_tongyi_call() -> None:
     """Test valid call to tongyi."""
-    llm = Tongyi()
+    llm = Tongyi()  # type: ignore[call-arg]
     output = llm.invoke("who are you")
     assert isinstance(output, str)


 def test_tongyi_generate() -> None:
     """Test valid call to tongyi."""
-    llm = Tongyi()
+    llm = Tongyi()  # type: ignore[call-arg]
     output = llm.generate(["who are you"])
     assert isinstance(output, LLMResult)
     assert isinstance(output.generations, list)
@@ -22,7 +22,7 @@ def test_tongyi_generate() -> None:

 def test_tongyi_generate_stream() -> None:
     """Test valid call to tongyi."""
-    llm = Tongyi(streaming=True)
+    llm = Tongyi(streaming=True)  # type: ignore[call-arg]
     output = llm.generate(["who are you"])
     print(output)  # noqa: T201
     assert isinstance(output, LLMResult)
@@ -13,9 +13,9 @@ from langchain_community.llms.volcengine_maas import (


 def test_api_key_is_string() -> None:
-    llm = VolcEngineMaasBase(
-        volc_engine_maas_ak="secret-volc-ak",
-        volc_engine_maas_sk="secret-volc-sk",
+    llm = VolcEngineMaasBase(  # type: ignore[call-arg]
+        volc_engine_maas_ak="secret-volc-ak",  # type: ignore[arg-type]
+        volc_engine_maas_sk="secret-volc-sk",  # type: ignore[arg-type]
     )
     assert isinstance(llm.volc_engine_maas_ak, SecretStr)
     assert isinstance(llm.volc_engine_maas_sk, SecretStr)
@@ -24,9 +24,9 @@ def test_api_key_is_string() -> None:
 def test_api_key_masked_when_passed_via_constructor(
     capsys: CaptureFixture,
 ) -> None:
-    llm = VolcEngineMaasBase(
-        volc_engine_maas_ak="secret-volc-ak",
-        volc_engine_maas_sk="secret-volc-sk",
+    llm = VolcEngineMaasBase(  # type: ignore[call-arg]
+        volc_engine_maas_ak="secret-volc-ak",  # type: ignore[arg-type]
+        volc_engine_maas_sk="secret-volc-sk",  # type: ignore[arg-type]
     )
     print(llm.volc_engine_maas_ak, end="")  # noqa: T201
     captured = capsys.readouterr()
@@ -36,14 +36,14 @@ def test_api_key_masked_when_passed_via_constructor(

 def test_default_call() -> None:
     """Test valid call to volc engine."""
-    llm = VolcEngineMaasLLM()
+    llm = VolcEngineMaasLLM()  # type: ignore[call-arg]
     output = llm.invoke("tell me a joke")
     assert isinstance(output, str)


 def test_generate() -> None:
     """Test valid call to volc engine."""
-    llm = VolcEngineMaasLLM()
+    llm = VolcEngineMaasLLM()  # type: ignore[call-arg]
     output = llm.generate(["tell me a joke"])
     assert isinstance(output, LLMResult)
     assert isinstance(output.generations, list)
@@ -51,6 +51,6 @@ def test_generate() -> None:

 def test_generate_stream() -> None:
     """Test valid call to volc engine."""
-    llm = VolcEngineMaasLLM(streaming=True)
+    llm = VolcEngineMaasLLM(streaming=True)  # type: ignore[call-arg]
     output = llm.stream("tell me a joke")
     assert isinstance(output, Generator)
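As a usage note, the suppressions above can be checked by re-running mypy over the package. A minimal sketch (assuming mypy is installed; the target path is illustrative and depends on the checkout layout):

# Run mypy programmatically and fail if any errors remain after the ignores.
from mypy import api

stdout, stderr, exit_status = api.run(["libs/community/langchain_community"])
print(stdout)
if exit_status != 0:
    raise SystemExit(f"mypy reported errors (exit status {exit_status})")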