mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 13:27:36 +00:00
"""python scripts/update_mypy_ruff.py"""

import glob
import re
import subprocess
import tomllib
from pathlib import Path

import toml

ROOT_DIR = Path(__file__).parents[1]

# Matches mypy output lines such as ``pkg/mod.py:12: error: message  [code]``
# and captures (file path, line number, error code). Written as a raw string:
# the previous non-raw literal relied on invalid escape sequences
# (``\:``, ``\d``, ``\[``) that newer Python versions warn about.
MYPY_ERROR_RE = re.compile(r"^(.*):(\d+): error:.*\[(.*)\]")


def _collect_mypy_errors(mypy_output: str) -> dict[tuple[str, str], list[str]]:
    """Group mypy error codes by ``(file path, line number)``.

    Args:
        mypy_output: Captured stdout of a ``mypy`` run.

    Returns:
        Mapping from ``(path, line_no)`` (both as strings, as printed by mypy)
        to the error codes reported on that line, de-duplicated in first-seen
        order.
    """
    to_ignore: dict[tuple[str, str], list[str]] = {}
    for log_line in mypy_output.split("\n"):
        match = MYPY_ERROR_RE.match(log_line)
        if match is None:
            continue
        error_path, line_no, error_type = match.groups()
        codes = to_ignore.setdefault((error_path, line_no), [])
        # mypy can report the same code twice on one line; list it once.
        if error_type not in codes:
            codes.append(error_type)
    return to_ignore


def _append_type_ignore(full_path: Path, line_no: int, error_types: list[str]) -> None:
    """Append ``# type: ignore[<codes>]`` to 1-based line ``line_no`` of a file.

    Files that do not exist are skipped silently.
    """
    try:
        with open(full_path, "r", encoding="utf-8") as f:
            file_lines = f.readlines()
    except FileNotFoundError:
        return
    all_errors = ", ".join(error_types)
    index = line_no - 1
    # rstrip("\n") instead of [:-1]: the last line of a file may lack a
    # trailing newline, and [:-1] would then delete a real character.
    file_lines[index] = (
        file_lines[index].rstrip("\n") + f" # type: ignore[{all_errors}]\n"
    )
    with open(full_path, "w", encoding="utf-8") as f:
        f.write("".join(file_lines))


def main() -> None:
    """Bump mypy/ruff in every ``libs/**/pyproject.toml`` and silence new errors.

    Each pyproject must define both the Poetry ``typing`` and ``lint``
    dependency groups; otherwise it is skipped unchanged. For each remaining
    package this:

    1. Pins ``mypy = "^1.10"`` and ``ruff = "^0.5"``.
    2. Re-locks and installs the package.
    3. Runs mypy.
    4. Appends ``# type: ignore[<codes>]`` to every reported line.
    5. Runs ``ruff format`` and ruff's import sorting (rule ``I``).
    """
    pattern = str(ROOT_DIR / "libs/**/pyproject.toml")
    for pyproject_path in glob.glob(pattern, recursive=True):
        print(pyproject_path)
        with open(pyproject_path, "rb") as f:
            pyproject = tomllib.load(f)
        try:
            groups = pyproject["tool"]["poetry"]["group"]
            groups["typing"]["dependencies"]["mypy"] = "^1.10"
            groups["lint"]["dependencies"]["ruff"] = "^0.5"
        except KeyError:
            # Package lacks one of the expected groups; leave its file untouched.
            continue
        with open(pyproject_path, "w") as f:
            toml.dump(pyproject, f)

        package_dir = Path(pyproject_path).parent
        completed = subprocess.run(
            "poetry lock --no-update; poetry install --with typing; "
            "poetry run mypy . --no-color",
            cwd=package_dir,
            shell=True,
            capture_output=True,
            text=True,
        )
        to_ignore = _collect_mypy_errors(completed.stdout)
        print(len(to_ignore))
        # mypy was run from package_dir, so reported paths are relative to it.
        for (error_path, line_no), error_types in to_ignore.items():
            _append_type_ignore(package_dir / error_path, int(line_no), error_types)

        # Use the explicit `ruff check` subcommand; the bare `ruff <path>`
        # form is deprecated.
        subprocess.run(
            "poetry run ruff format .; poetry run ruff check --select I --fix .",
            cwd=package_dir,
            shell=True,
            capture_output=True,
            text=True,
        )


if __name__ == "__main__":
    main()
143 lines
4.7 KiB
Python
143 lines
4.7 KiB
Python
"""Test Titan Takeoff wrapper."""
|
|
|
|
import json
from typing import Any, Type, Union

import pytest

from langchain_community.llms import TitanTakeoff, TitanTakeoffPro
|
|
|
|
|
|
@pytest.mark.requires("takeoff_client")
@pytest.mark.requires("pytest_httpx")
@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_call(
    httpx_mock: Any,
    streaming: bool,
    # parametrize passes the wrapper *classes*, which are instantiated below.
    takeoff_object: Type[Union[TitanTakeoff, TitanTakeoffPro]],
) -> None:
    """Test valid call to Titan Takeoff.

    Mocks the ``/generate`` endpoint (or ``/generate_stream`` when streaming).
    Calls the LLM through ``llm(...)`` and ``llm.invoke(...)``, and also
    through ``llm._stream(...)`` when streaming. Checks that every call issues
    exactly one new POST to the expected URL carrying the prompt text.
    """
    from pytest_httpx import IteratorStream

    port = 2345
    url = (
        f"http://localhost:{port}/generate_stream"
        if streaming
        else f"http://localhost:{port}/generate"
    )

    if streaming:
        httpx_mock.add_response(
            method="POST",
            url=url,
            stream=IteratorStream([b"data: ask someone else\n\n"]),
        )
    else:
        httpx_mock.add_response(
            method="POST",
            url=url,
            json={"text": "ask someone else"},
        )

    llm = takeoff_object(port=port, streaming=streaming)
    number_of_calls = 0
    for function_call in [llm, llm.invoke]:
        number_of_calls += 1
        output = function_call("What is 2 + 2?")
        assert isinstance(output, str)
        requests = httpx_mock.get_requests()
        assert len(requests) == number_of_calls
        # Check the request this call just made ([-1]), not only the first one.
        assert requests[-1].url == url
        assert json.loads(requests[-1].content)["text"] == "What is 2 + 2?"

    if streaming:
        chunks = llm._stream("What is 2 + 2?")
        for chunk in chunks:
            assert isinstance(chunk.text, str)
        requests = httpx_mock.get_requests()
        assert len(requests) == number_of_calls + 1
        assert requests[-1].url == url
        assert json.loads(requests[-1].content)["text"] == "What is 2 + 2?"
|
|
|
|
|
|
@pytest.mark.requires("pytest_httpx")
@pytest.mark.requires("takeoff_client")
@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_bad_call(
    httpx_mock: Any,
    streaming: bool,
    # parametrize passes the wrapper *classes*, which are instantiated below.
    takeoff_object: Type[Union[TitanTakeoff, TitanTakeoffPro]],
) -> None:
    """Test that an error response from Titan Takeoff raises TakeoffException.

    The mocked endpoint answers with HTTP 400. ``invoke`` must raise
    ``TakeoffException``. Exactly one request must still have been sent, to
    the expected URL and carrying the prompt text.
    """
    from takeoff_client import TakeoffException

    # No port is passed to the constructor below, so these URLs assume the
    # wrapper's default port is 3000 — TODO confirm against TitanTakeoff.
    url = (
        "http://localhost:3000/generate"
        if not streaming
        else "http://localhost:3000/generate_stream"
    )
    httpx_mock.add_response(
        method="POST", url=url, json={"text": "bad things"}, status_code=400
    )

    llm = takeoff_object(streaming=streaming)
    with pytest.raises(TakeoffException):
        llm.invoke("What is 2 + 2?")

    requests = httpx_mock.get_requests()
    assert len(requests) == 1
    assert requests[0].url == url
    assert json.loads(requests[0].content)["text"] == "What is 2 + 2?"
|
|
|
|
|
|
@pytest.mark.requires("pytest_httpx")
@pytest.mark.requires("takeoff_client")
@pytest.mark.parametrize("takeoff_object", [TitanTakeoff, TitanTakeoffPro])
def test_titan_takeoff_model_initialisation(
    httpx_mock: Any,
    # parametrize passes the wrapper *classes*, which are instantiated below.
    takeoff_object: Type[Union[TitanTakeoff, TitanTakeoffPro]],
) -> None:
    """Test that models passed at construction are created via the management API.

    The two reader configs given in ``models`` must each be POSTed, in order,
    to the management ``/reader`` endpoint. Only then may the inference
    request hit ``/generate``.
    """
    mgmt_port = 36452
    inf_port = 46253
    mgmt_url = f"http://localhost:{mgmt_port}/reader"
    gen_url = f"http://localhost:{inf_port}/generate"
    reader_1 = {
        "model_name": "test",
        "device": "cpu",
        "consumer_group": "primary",
        "max_sequence_length": 512,
        "max_batch_size": 4,
        "tensor_parallel": 3,
    }
    reader_2 = reader_1.copy()
    reader_2["model_name"] = "test2"

    httpx_mock.add_response(
        method="POST", url=mgmt_url, json={"key": "value"}, status_code=201
    )
    httpx_mock.add_response(
        method="POST", url=gen_url, json={"text": "value"}, status_code=200
    )

    llm = takeoff_object(
        port=inf_port, mgmt_port=mgmt_port, models=[reader_1, reader_2]
    )
    output = llm.invoke("What is 2 + 2?")

    assert isinstance(output, str)
    requests = httpx_mock.get_requests()
    # Two management calls (one per reader) followed by one inference call.
    assert len(requests) == 3
    # Each reader is created, in the order given, with its full config.
    for request, reader in zip(requests[:2], [reader_1, reader_2]):
        assert request.url == mgmt_url
        payload = json.loads(request.content)
        for key, value in reader.items():
            assert payload[key] == value
    # The third call goes to the generate endpoint for inference.
    assert requests[2].url == gen_url
    assert json.loads(requests[2].content)["text"] == "What is 2 + 2?"
|