Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-22 19:09:57 +00:00
Add konko chat model (#10380)
libs/langchain/langchain/chat_models/__init__.py
@@ -20,12 +20,12 @@ an interface where "chat messages" are the inputs and outputs.
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.chat_models.anyscale import ChatAnyscale
from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.chat_models.bedrock import BedrockChat
from langchain.chat_models.ernie import ErnieBotChat
from langchain.chat_models.fake import FakeListChatModel
from langchain.chat_models.google_palm import ChatGooglePalm
from langchain.chat_models.human import HumanInputChatModel
from langchain.chat_models.jinachat import JinaChat
from langchain.chat_models.konko import ChatKonko
from langchain.chat_models.litellm import ChatLiteLLM
from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
from langchain.chat_models.ollama import ChatOllama
@@ -36,7 +36,6 @@ from langchain.chat_models.vertexai import ChatVertexAI
__all__ = [
    "ChatOpenAI",
    "AzureChatOpenAI",
    "BedrockChat",
    "FakeListChatModel",
    "PromptLayerChatOpenAI",
    "ChatAnthropic",
@@ -49,4 +48,5 @@ __all__ = [
    "ChatAnyscale",
    "ChatLiteLLM",
    "ErnieBotChat",
    "ChatKonko",
]
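With this export in place, the new model can be imported straight from langchain.chat_models. A minimal usage sketch (mirroring the docstring example in the wrapper below; it assumes the konko package is installed and KONKO_API_KEY is set in the environment):

from langchain.chat_models import ChatKonko
from langchain.schema.messages import HumanMessage

# Default model is meta-llama/Llama-2-13b-chat-hf; KONKO_API_KEY must be set.
chat = ChatKonko(max_tokens=20)
response = chat([HumanMessage(content="Hi")])
print(response.content)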
libs/langchain/langchain/chat_models/konko.py (new file, 292 lines)
@@ -0,0 +1,292 @@
"""KonkoAI chat wrapper."""
from __future__ import annotations

import logging
import os
from typing import (
    Any,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Union,
)

import requests

from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
)
from langchain.chat_models.openai import ChatOpenAI, _convert_delta_to_message_chunk
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import AIMessageChunk, BaseMessage
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env

DEFAULT_API_BASE = "https://api.konko.ai/v1"
DEFAULT_MODEL = "meta-llama/Llama-2-13b-chat-hf"

logger = logging.getLogger(__name__)


class ChatKonko(ChatOpenAI):
    """`ChatKonko` chat large language model API.

    To use, you should have the ``konko`` python package installed, and the
    environment variables ``KONKO_API_KEY`` and ``OPENAI_API_KEY`` set with your API keys.

    Any parameters that are valid to be passed to the konko.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatKonko
            llm = ChatKonko(model="meta-llama/Llama-2-13b-chat-hf")
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"konko_api_key": "KONKO_API_KEY", "openai_api_key": "OPENAI_API_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    client: Any = None  #: :meta private:
    model: str = Field(default=DEFAULT_MODEL, alias="model")
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    openai_api_key: Optional[str] = None
    konko_api_key: Optional[str] = None
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    """Timeout for requests to Konko completion API."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Whether to stream the results or not."""
    n: int = 1
    """Number of chat completions to generate for each prompt."""
    max_tokens: int = 20
    """Maximum number of tokens to generate."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and python package exist in the environment."""
        values["konko_api_key"] = get_from_dict_or_env(
            values, "konko_api_key", "KONKO_API_KEY"
        )
        try:
            import konko

        except ImportError:
            raise ValueError(
                "Could not import konko python package. "
                "Please install it with `pip install konko`."
            )
        try:
            values["client"] = konko.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`konko` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the konko package. Try upgrading it "
                "with `pip install --upgrade konko`."
            )
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Konko API."""
        return {
            "model": self.model,
            "request_timeout": self.request_timeout,
            "max_tokens": self.max_tokens,
            "stream": self.streaming,
            "n": self.n,
            "temperature": self.temperature,
            **self.model_kwargs,
        }

    @staticmethod
    def get_available_models(
        konko_api_key: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        konko_api_base: str = DEFAULT_API_BASE,
    ) -> Set[str]:
        """Get available models from Konko API."""

        # Try to retrieve the OpenAI API key if it's not passed as an argument
        if not openai_api_key:
            try:
                openai_api_key = os.environ["OPENAI_API_KEY"]
            except KeyError:
                pass  # It's okay if it's not set, we just won't use it

        # Try to retrieve the Konko API key if it's not passed as an argument
        if not konko_api_key:
            try:
                konko_api_key = os.environ["KONKO_API_KEY"]
            except KeyError:
                raise ValueError(
                    "Konko API key must be passed as keyword argument or "
                    "set in environment variable KONKO_API_KEY."
                )

        models_url = f"{konko_api_base}/models"

        headers = {
            "Authorization": f"Bearer {konko_api_key}",
        }

        if openai_api_key:
            headers["X-OpenAI-Api-Key"] = openai_api_key

        models_response = requests.get(models_url, headers=headers)

        if models_response.status_code != 200:
            raise ValueError(
                f"Error getting models from {models_url}: "
                f"{models_response.status_code}"
            )

        return {model["id"] for model in models_response.json()["data"]}

    def completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
        # Note: the inner helper is invoked directly, so no retry wrapper is
        # applied here despite the method name.
        def _completion_with_retry(**kwargs: Any) -> Any:
            return self.client.create(**kwargs)

        return _completion_with_retry(**kwargs)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            for k, v in token_usage.items():
                if k in overall_token_usage:
                    overall_token_usage[k] += v
                else:
                    overall_token_usage[k] = v
        return {"token_usage": overall_token_usage, "model_name": self.model}

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        for chunk in self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        ):
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            finish_reason = choice.get("finish_reason")
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
            if run_manager:
                run_manager.on_llm_new_token(chunk.content, chunk=chunk)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if stream if stream is not None else self.streaming:
            generation: Optional[ChatGenerationChunk] = None
            for chunk in self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return ChatResult(generations=[generation])

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = self._client_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        generations = []
        for res in response["choices"]:
            message = convert_dict_to_message(res["message"])
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.get("finish_reason")),
            )
            generations.append(gen)
        token_usage = response.get("usage", {})
        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_name": self.model}, **self._default_params}

    @property
    def _client_params(self) -> Dict[str, Any]:
        """Get the parameters used for the konko client."""
        return {**self._default_params}

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        return {
            "model": self.model,
            **super()._get_invocation_params(stop=stop),
            **self._default_params,
            **kwargs,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "konko-chat"
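Taken together, a minimal usage sketch of the two entry points this wrapper adds, model discovery and streaming. It assumes the konko package is installed and KONKO_API_KEY is set; the model name and prompt are illustrative.

from langchain.chat_models import ChatKonko

# get_available_models() GETs https://api.konko.ai/v1/models and returns a set of model ids.
available = ChatKonko.get_available_models()
print(sorted(available))

chat = ChatKonko(model="meta-llama/Llama-2-13b-chat-hf", max_tokens=40)
# stream() drives _stream() above and yields message chunks token by token.
for chunk in chat.stream("Say hello in one short sentence"):
    print(chunk.content, end="", flush=True)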
libs/langchain/tests/integration_tests/chat_models/test_konko.py (new file, 178 lines)
@@ -0,0 +1,178 @@
"""Evaluate ChatKonko Interface."""
from typing import Any

import pytest

from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.konko import ChatKonko
from langchain.schema import (
    ChatGeneration,
    ChatResult,
    LLMResult,
)
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


def test_konko_chat_test() -> None:
    """Evaluate basic ChatKonko functionality."""
    chat_instance = ChatKonko(max_tokens=10)
    msg = HumanMessage(content="Hi")
    chat_response = chat_instance([msg])
    assert isinstance(chat_response, BaseMessage)
    assert isinstance(chat_response.content, str)


def test_konko_chat_test_openai() -> None:
    """Evaluate basic ChatKonko functionality with an OpenAI-hosted model."""
    chat_instance = ChatKonko(max_tokens=10, model="gpt-3.5-turbo")
    msg = HumanMessage(content="Hi")
    chat_response = chat_instance([msg])
    assert isinstance(chat_response, BaseMessage)
    assert isinstance(chat_response.content, str)


def test_konko_model_test() -> None:
    """Check how ChatKonko manages model_name."""
    chat_instance = ChatKonko(model="alpha")
    assert chat_instance.model == "alpha"
    chat_instance = ChatKonko(model="beta")
    assert chat_instance.model == "beta"


def test_konko_available_model_test() -> None:
    """Check retrieval of available models from the Konko API."""
    chat_instance = ChatKonko(max_tokens=10, n=2)
    res = chat_instance.get_available_models()
    assert isinstance(res, set)


def test_konko_system_msg_test() -> None:
    """Evaluate ChatKonko's handling of system messages."""
    chat_instance = ChatKonko(max_tokens=10)
    sys_msg = SystemMessage(content="Initiate user chat.")
    user_msg = HumanMessage(content="Hi there")
    chat_response = chat_instance([sys_msg, user_msg])
    assert isinstance(chat_response, BaseMessage)
    assert isinstance(chat_response.content, str)


def test_konko_generation_test() -> None:
    """Check ChatKonko's generation ability."""
    chat_instance = ChatKonko(max_tokens=10, n=2)
    msg = HumanMessage(content="Hi")
    gen_response = chat_instance.generate([[msg], [msg]])
    assert isinstance(gen_response, LLMResult)
    assert len(gen_response.generations) == 2
    for gen_list in gen_response.generations:
        assert len(gen_list) == 2
        for gen in gen_list:
            assert isinstance(gen, ChatGeneration)
            assert isinstance(gen.text, str)
            assert gen.text == gen.message.content


def test_konko_multiple_outputs_test() -> None:
    """Test multiple completions with ChatKonko."""
    chat_instance = ChatKonko(max_tokens=10, n=5)
    msg = HumanMessage(content="Hi")
    gen_response = chat_instance._generate([msg])
    assert isinstance(gen_response, ChatResult)
    assert len(gen_response.generations) == 5
    for gen in gen_response.generations:
        assert isinstance(gen.message, BaseMessage)
        assert isinstance(gen.message.content, str)


def test_konko_streaming_callback_test() -> None:
    """Evaluate streaming's token callback functionality."""
    callback_instance = FakeCallbackHandler()
    callback_mgr = CallbackManager([callback_instance])
    chat_instance = ChatKonko(
        max_tokens=10,
        streaming=True,
        temperature=0,
        callback_manager=callback_mgr,
        verbose=True,
    )
    msg = HumanMessage(content="Hi")
    chat_response = chat_instance([msg])
    assert callback_instance.llm_streams > 0
    assert isinstance(chat_response, BaseMessage)


def test_konko_streaming_info_test() -> None:
    """Ensure generation details are retained during streaming."""

    class TestCallback(FakeCallbackHandler):
        data_store: dict = {}

        def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:
            self.data_store["generation"] = args[0]

    callback_instance = TestCallback()
    callback_mgr = CallbackManager([callback_instance])
    chat_instance = ChatKonko(
        max_tokens=2,
        temperature=0,
        callback_manager=callback_mgr,
    )
    list(chat_instance.stream("hey"))
    gen_data = callback_instance.data_store["generation"]
    assert gen_data.generations[0][0].text == " Hey"


def test_konko_llm_model_name_test() -> None:
    """Check if llm_output has model info."""
    chat_instance = ChatKonko(max_tokens=10)
    msg = HumanMessage(content="Hi")
    llm_data = chat_instance.generate([[msg]])
    assert llm_data.llm_output is not None
    assert llm_data.llm_output["model_name"] == chat_instance.model


def test_konko_streaming_model_name_test() -> None:
    """Check model info during streaming."""
    chat_instance = ChatKonko(max_tokens=10, streaming=True)
    msg = HumanMessage(content="Hi")
    llm_data = chat_instance.generate([[msg]])
    assert llm_data.llm_output is not None
    assert llm_data.llm_output["model_name"] == chat_instance.model


def test_konko_streaming_param_validation_test() -> None:
    """Ensure that requesting multiple completions while streaming raises an error."""
    with pytest.raises(ValueError):
        ChatKonko(
            max_tokens=10,
            streaming=True,
            temperature=0,
            n=5,
        )


def test_konko_additional_args_test() -> None:
    """Evaluate extra arguments for ChatKonko."""
    chat_instance = ChatKonko(extra=3, max_tokens=10)
    assert chat_instance.max_tokens == 10
    assert chat_instance.model_kwargs == {"extra": 3}

    chat_instance = ChatKonko(extra=3, model_kwargs={"addition": 2})
    assert chat_instance.model_kwargs == {"extra": 3, "addition": 2}

    with pytest.raises(ValueError):
        ChatKonko(extra=3, model_kwargs={"extra": 2})

    with pytest.raises(ValueError):
        ChatKonko(model_kwargs={"temperature": 0.2})

    with pytest.raises(ValueError):
        ChatKonko(model_kwargs={"model": "text-davinci-003"})


def test_konko_token_streaming_test() -> None:
    """Check token streaming for ChatKonko."""
    chat_instance = ChatKonko(max_tokens=10)

    for token in chat_instance.stream("Just a test"):
        assert isinstance(token.content, str)
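Beyond these tests, a hedged sketch of how multi-prompt generation surfaces the aggregated token usage produced by _combine_llm_outputs; the prompts are illustrative and KONKO_API_KEY is assumed to be set.

from langchain.chat_models import ChatKonko
from langchain.schema.messages import HumanMessage

chat = ChatKonko(max_tokens=10, n=2)
result = chat.generate([[HumanMessage(content="Hi")], [HumanMessage(content="Hello")]])
# llm_output carries the summed token usage and the model name, as built above.
print(result.llm_output["token_usage"], result.llm_output["model_name"])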