```python
"""python scripts/update_mypy_ruff.py"""
import glob
import re
import subprocess
import tomllib
from pathlib import Path

import toml

ROOT_DIR = Path(__file__).parents[1]


def main():
    for path in glob.glob(str(ROOT_DIR / "libs/**/pyproject.toml"), recursive=True):
        print(path)
        with open(path, "rb") as f:
            pyproject = tomllib.load(f)
        try:
            pyproject["tool"]["poetry"]["group"]["typing"]["dependencies"]["mypy"] = (
                "^1.10"
            )
            pyproject["tool"]["poetry"]["group"]["lint"]["dependencies"]["ruff"] = (
                "^0.5"
            )
        except KeyError:
            # Skip packages that don't declare mypy/ruff dev dependencies.
            continue
        with open(path, "w") as f:
            toml.dump(pyproject, f)

        cwd = "/".join(path.split("/")[:-1])

        # Re-lock, reinstall, and collect the new mypy errors for this package.
        completed = subprocess.run(
            "poetry lock --no-update; poetry install --with typing; poetry run mypy . --no-color",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )
        logs = completed.stdout.split("\n")

        # Map (file, line) -> list of mypy error codes reported on that line.
        to_ignore = {}
        for line in logs:
            match = re.match(r"^(.*):(\d+): error:.*\[(.*)\]", line)
            if match:
                error_path, line_no, error_type = match.groups()
                if (error_path, line_no) in to_ignore:
                    to_ignore[(error_path, line_no)].append(error_type)
                else:
                    to_ignore[(error_path, line_no)] = [error_type]
        print(len(to_ignore))

        # Append a targeted "# type: ignore[...]" comment to every offending line.
        for (error_path, line_no), error_types in to_ignore.items():
            all_errors = ", ".join(error_types)
            full_path = f"{cwd}/{error_path}"
            try:
                with open(full_path, "r") as f:
                    file_lines = f.readlines()
            except FileNotFoundError:
                continue
            file_lines[int(line_no) - 1] = (
                file_lines[int(line_no) - 1][:-1] + f" # type: ignore[{all_errors}]\n"
            )
            with open(full_path, "w") as f:
                f.write("".join(file_lines))

        subprocess.run(
            "poetry run ruff format .; poetry run ruff --select I --fix .",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )


if __name__ == "__main__":
    main()
```
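The script walks every `libs/**/pyproject.toml`, bumps the mypy and ruff pins, re-runs mypy in each package, and appends a targeted `# type: ignore[...]` comment to every line that now fails, before letting ruff reformat the result. As a hypothetical illustration (the names here are invented, not from the repo), a line on which mypy reports an `attr-defined` error would be rewritten like this:

```python
# Before: mypy reports 'error: ... [attr-defined]' on this line.
value = config.missing_field

# After: the script appends the reported error codes as an inline ignore.
value = config.missing_field  # type: ignore[attr-defined]
```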
"""Evaluate ChatKonko Interface."""

from typing import Any

import pytest

from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult

from langchain_community.chat_models.konko import ChatKonko
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


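# NOTE: the tests below are live integration tests: they make real API calls
# and assume valid Konko credentials are available in the environment
# (typically via the KONKO_API_KEY environment variable).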
def test_konko_chat_test() -> None:
    """Evaluate basic ChatKonko functionality."""
    chat_instance = ChatKonko(max_tokens=10)
    msg = HumanMessage(content="Hi")
    chat_response = chat_instance.invoke([msg])
    assert isinstance(chat_response, BaseMessage)
    assert isinstance(chat_response.content, str)


def test_konko_chat_test_openai() -> None:
    """Evaluate ChatKonko with an explicitly specified chat model."""
    chat_instance = ChatKonko(max_tokens=10, model="meta-llama/llama-2-70b-chat")
    msg = HumanMessage(content="Hi")
    chat_response = chat_instance.invoke([msg])
    assert isinstance(chat_response, BaseMessage)
    assert isinstance(chat_response.content, str)


def test_konko_model_test() -> None:
    """Check how ChatKonko manages model_name."""
    chat_instance = ChatKonko(model="alpha")
    assert chat_instance.model == "alpha"
    chat_instance = ChatKonko(model="beta")
    assert chat_instance.model == "beta"


def test_konko_available_model_test() -> None:
    """Check that ChatKonko can list its available models."""
    chat_instance = ChatKonko(max_tokens=10, n=2)
    res = chat_instance.get_available_models()
    assert isinstance(res, set)


def test_konko_system_msg_test() -> None:
    """Evaluate ChatKonko's handling of system messages."""
    chat_instance = ChatKonko(max_tokens=10)
    sys_msg = SystemMessage(content="Initiate user chat.")
    user_msg = HumanMessage(content="Hi there")
    chat_response = chat_instance.invoke([sys_msg, user_msg])
    assert isinstance(chat_response, BaseMessage)
    assert isinstance(chat_response.content, str)


def test_konko_generation_test() -> None:
    """Check ChatKonko's generation ability."""
    chat_instance = ChatKonko(max_tokens=10, n=2)
    msg = HumanMessage(content="Hi")
    gen_response = chat_instance.generate([[msg], [msg]])
    assert isinstance(gen_response, LLMResult)
    assert len(gen_response.generations) == 2
    for gen_list in gen_response.generations:
        assert len(gen_list) == 2
        for gen in gen_list:
            assert isinstance(gen, ChatGeneration)
            assert isinstance(gen.text, str)
            assert gen.text == gen.message.content


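# NOTE: the test below calls the private ``_generate`` method directly to
# inspect the raw ChatResult; application code should prefer ``invoke`` or
# ``generate``.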
def test_konko_multiple_outputs_test() -> None:
    """Test multiple completions with ChatKonko."""
    chat_instance = ChatKonko(max_tokens=10, n=5)
    msg = HumanMessage(content="Hi")
    gen_response = chat_instance._generate([msg])
    assert isinstance(gen_response, ChatResult)
    assert len(gen_response.generations) == 5
    for gen in gen_response.generations:
        assert isinstance(gen.message, BaseMessage)
        assert isinstance(gen.message.content, str)


def test_konko_streaming_callback_test() -> None:
    """Evaluate streaming's token callback functionality."""
    callback_instance = FakeCallbackHandler()
    callback_mgr = CallbackManager([callback_instance])
    chat_instance = ChatKonko(
        max_tokens=10,
        streaming=True,
        temperature=0,
        callback_manager=callback_mgr,
        verbose=True,
    )
    msg = HumanMessage(content="Hi")
    chat_response = chat_instance.invoke([msg])
    assert callback_instance.llm_streams > 0
    assert isinstance(chat_response, BaseMessage)


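# NOTE: the exact-token assertion below (" Hey") relies on the model returning
# a deterministic completion at temperature=0 and may be brittle across model
# or provider versions.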
def test_konko_streaming_info_test() -> None:
    """Ensure generation details are retained during streaming."""

    class TestCallback(FakeCallbackHandler):
        data_store: dict = {}

        def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:
            self.data_store["generation"] = args[0]

    callback_instance = TestCallback()
    callback_mgr = CallbackManager([callback_instance])
    chat_instance = ChatKonko(
        max_tokens=2,
        temperature=0,
        callback_manager=callback_mgr,
    )
    list(chat_instance.stream("hey"))
    gen_data = callback_instance.data_store["generation"]
    assert gen_data.generations[0][0].text == " Hey"


def test_konko_llm_model_name_test() -> None:
    """Check if llm_output has model info."""
    chat_instance = ChatKonko(max_tokens=10)
    msg = HumanMessage(content="Hi")
    llm_data = chat_instance.generate([[msg]])
    assert llm_data.llm_output is not None
    assert llm_data.llm_output["model_name"] == chat_instance.model


def test_konko_streaming_model_name_test() -> None:
    """Check model info during streaming."""
    chat_instance = ChatKonko(max_tokens=10, streaming=True)
    msg = HumanMessage(content="Hi")
    llm_data = chat_instance.generate([[msg]])
    assert llm_data.llm_output is not None
    assert llm_data.llm_output["model_name"] == chat_instance.model


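# NOTE: as with the OpenAI chat API, requesting several completions (n > 1)
# while streaming is expected to fail validation at construction time.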
def test_konko_streaming_param_validation_test() -> None:
    """Ensure that combining streaming with n > 1 raises a ValueError."""
    with pytest.raises(ValueError):
        ChatKonko(
            max_tokens=10,
            streaming=True,
            temperature=0,
            n=5,
        )


def test_konko_additional_args_test() -> None:
    """Evaluate extra arguments for ChatKonko."""
    chat_instance = ChatKonko(extra=3, max_tokens=10)  # type: ignore[call-arg]
    assert chat_instance.max_tokens == 10
    assert chat_instance.model_kwargs == {"extra": 3}

    chat_instance = ChatKonko(extra=3, model_kwargs={"addition": 2})  # type: ignore[call-arg]
    assert chat_instance.model_kwargs == {"extra": 3, "addition": 2}

    with pytest.raises(ValueError):
        ChatKonko(extra=3, model_kwargs={"extra": 2})  # type: ignore[call-arg]

    with pytest.raises(ValueError):
        ChatKonko(model_kwargs={"temperature": 0.2})

    with pytest.raises(ValueError):
        ChatKonko(model_kwargs={"model": "text-davinci-003"})


def test_konko_token_streaming_test() -> None:
    """Check token streaming for ChatKonko."""
    chat_instance = ChatKonko(max_tokens=10)

    for token in chat_instance.stream("Just a test"):
        assert isinstance(token.content, str)