community[minor]: Adding Konko Completion endpoint (#15570)

This PR introduces updates to the Konko integration with LangChain.

1. **New Endpoint Addition**: Integration of a new endpoint to utilize
completion models hosted on Konko.

2. **Chat Model Updates for Backward Compatibility**: We have updated
the chat models to ensure backward compatibility with previous OpenAI
versions.

3. **Updated Documentation**: Comprehensive documentation has been
updated to reflect these new changes, providing clear guidance on
utilizing the new features and ensuring seamless integration.

Thank you to the LangChain team for their exceptional work and for
considering this PR. Please let me know if any additional information is
needed.

---------

Co-authored-by: Shivani Modi <shivanimodi@Shivanis-MacBook-Pro.local>
Co-authored-by: Shivani Modi <shivanimodi@Shivanis-MBP.lan>
This commit is contained in:
Shivani Modi
2024-01-23 18:22:32 -08:00
committed by GitHub
parent c69f599594
commit 4e160540ff
11 changed files with 622 additions and 74 deletions

View File

@@ -3,12 +3,12 @@ from __future__ import annotations
import logging
import os
import warnings
from typing import (
Any,
Dict,
Iterator,
List,
Mapping,
Optional,
Set,
Tuple,
@@ -19,20 +19,20 @@ import requests
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_community.adapters.openai import (
convert_dict_to_message,
convert_message_to_dict,
)
from langchain_community.chat_models.openai import _convert_delta_to_message_chunk
from langchain_community.chat_models.openai import (
ChatOpenAI,
_convert_delta_to_message_chunk,
generate_from_stream,
)
from langchain_community.utils.openai import is_openai_v1
DEFAULT_API_BASE = "https://api.konko.ai/v1"
DEFAULT_MODEL = "meta-llama/Llama-2-13b-chat-hf"
@@ -40,7 +40,7 @@ DEFAULT_MODEL = "meta-llama/Llama-2-13b-chat-hf"
logger = logging.getLogger(__name__)
class ChatKonko(BaseChatModel):
class ChatKonko(ChatOpenAI):
"""`ChatKonko` Chat large language models API.
To use, you should have the ``konko`` python package installed, and the
@@ -72,10 +72,8 @@ class ChatKonko(BaseChatModel):
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[SecretStr] = None
konko_api_key: Optional[SecretStr] = None
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to Konko completion API."""
openai_api_key: Optional[str] = None
konko_api_key: Optional[str] = None
max_retries: int = 6
"""Maximum number of retries to make when generating."""
streaming: bool = False
@@ -100,13 +98,23 @@ class ChatKonko(BaseChatModel):
"Please install it with `pip install konko`."
)
try:
values["client"] = konko.ChatCompletion
if is_openai_v1():
values["client"] = konko.chat.completions
else:
values["client"] = konko.ChatCompletion
except AttributeError:
raise ValueError(
"`konko` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the konko package. Try upgrading it "
"with `pip install --upgrade konko`."
)
if not hasattr(konko, "_is_legacy_openai"):
warnings.warn(
"You are using an older version of the 'konko' package. "
"Please consider upgrading to access new features."
)
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
@@ -118,7 +126,6 @@ class ChatKonko(BaseChatModel):
"""Get the default parameters for calling Konko API."""
return {
"model": self.model,
"request_timeout": self.request_timeout,
"max_tokens": self.max_tokens,
"stream": self.streaming,
"n": self.n,
@@ -182,20 +189,6 @@ class ChatKonko(BaseChatModel):
return _completion_with_retry(**kwargs)
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
return {"token_usage": overall_token_usage, "model_name": self.model}
def _stream(
self,
messages: List[BaseMessage],
@@ -259,19 +252,6 @@ class ChatKonko(BaseChatModel):
message_dicts = [convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
    """Convert a raw chat-completion response into a ``ChatResult``.

    Args:
        response: Mapping with a ``"choices"`` list (each choice holding a
            ``"message"`` dict and optionally a ``"finish_reason"``) and an
            optional ``"usage"`` entry.

    Returns:
        A ``ChatResult`` with one ``ChatGeneration`` per choice and an
        ``llm_output`` holding the token usage and ``self.model``.
    """
    generations = [
        ChatGeneration(
            message=convert_dict_to_message(choice["message"]),
            generation_info={"finish_reason": choice.get("finish_reason")},
        )
        for choice in response["choices"]
    ]
    # A response without usage data is reported as an empty usage mapping.
    usage = response.get("usage", {})
    return ChatResult(
        generations=generations,
        llm_output={"token_usage": usage, "model_name": self.model},
    )
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""

View File

@@ -270,6 +270,12 @@ def _import_koboldai() -> Any:
return KoboldApiLLM
def _import_konko() -> Any:
    """Import the Konko LLM wrapper on demand and return the class."""
    import langchain_community.llms.konko as konko_module

    return konko_module.Konko
def _import_llamacpp() -> Any:
from langchain_community.llms.llamacpp import LlamaCpp
@@ -639,6 +645,8 @@ def __getattr__(name: str) -> Any:
return _import_javelin_ai_gateway()
elif name == "KoboldApiLLM":
return _import_koboldai()
elif name == "Konko":
return _import_konko()
elif name == "LlamaCpp":
return _import_llamacpp()
elif name == "ManifestWrapper":
@@ -780,6 +788,7 @@ __all__ = [
"HuggingFaceTextGenInference",
"HumanInputLLM",
"KoboldApiLLM",
"Konko",
"LlamaCpp",
"TextGen",
"ManifestWrapper",
@@ -868,6 +877,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"huggingface_textgen_inference": _import_huggingface_text_gen_inference,
"human-input": _import_human,
"koboldai": _import_koboldai,
"konko": _import_konko,
"llamacpp": _import_llamacpp,
"textgen": _import_textgen,
"minimax": _import_minimax,

View File

@@ -0,0 +1,200 @@
"""Wrapper around Konko AI's Completion API."""
import logging
import warnings
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
from langchain_community.utils.openai import is_openai_v1
logger = logging.getLogger(__name__)
class Konko(LLM):
    """Wrapper around Konko AI completion models.

    To use, you'll need an API key. This can be passed in as init param
    ``konko_api_key`` or set as environment variable ``KONKO_API_KEY``.

    Konko AI API reference: https://docs.konko.ai/reference/
    """

    base_url: str = "https://api.konko.ai/v1/completions"
    """Base inference API URL."""

    # NOTE(review): the key is validated on the model but this class never
    # hands it to the `konko` client; presumably the package reads
    # KONKO_API_KEY on its own -- confirm.
    konko_api_key: SecretStr
    """Konko AI API key."""

    model: str
    """Model name. Available models listed here:
        https://docs.konko.ai/reference/get_models
    """

    temperature: Optional[float] = None
    """Model temperature."""

    top_p: Optional[float] = None
    """Used to dynamically adjust the number of choices for each predicted token
        based on the cumulative probabilities.
    """

    top_k: Optional[int] = None
    """Used to limit the number of choices for the next predicted word or token. It
        specifies the maximum number of tokens to consider at each step, based on
        their probability of occurrence. This technique helps to speed up the
        generation process and can improve the quality of the generated text by
        focusing on the most likely options.
    """

    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""

    repetition_penalty: Optional[float] = None
    """A number that controls the diversity of generated text by reducing the
        likelihood of repeated sequences. Higher values decrease repetition.
    """

    logprobs: Optional[int] = None
    """An integer that specifies how many top token log probabilities are included
        in the response for each token generation step.
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the environment before the fields themselves are validated.

        Checks that the ``konko`` package is importable, warns when it looks
        outdated, and fills ``konko_api_key`` from the ``KONKO_API_KEY``
        environment variable when it was not passed explicitly.

        Args:
            values: Raw constructor keyword arguments.

        Returns:
            The (possibly completed) keyword arguments.

        Raises:
            ValueError: If the ``konko`` package cannot be imported.
        """
        import os

        try:
            import konko
        except ImportError as exc:
            raise ValueError(
                "Could not import konko python package. "
                "Please install it with `pip install konko`."
            ) from exc
        if not hasattr(konko, "_is_legacy_openai"):
            warnings.warn(
                "You are using an older version of the 'konko' package. "
                "Please consider upgrading to access new features "
                "including the completion endpoint."
            )
        # Honour the documented environment fallback; when neither source
        # provides a key, pydantic still reports the missing required field.
        if not values.get("konko_api_key"):
            env_key = os.environ.get("KONKO_API_KEY")
            if env_key:
                values["konko_api_key"] = env_key
        return values

    def construct_payload(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Build the request payload for the completion endpoint.

        Args:
            prompt: The prompt to complete.
            stop: Optional stop sequences. A single-element list is sent as a
                bare string, anything else is sent unchanged.
            **kwargs: Extra request parameters; they override the defaults.

        Returns:
            The payload with every ``None``-valued entry removed.
        """
        stop_to_use = stop[0] if stop and len(stop) == 1 else stop
        payload: Dict[str, Any] = {
            **self.default_params,
            "prompt": prompt,
            "stop": stop_to_use,
            **kwargs,
        }
        return {k: v for k, v in payload.items() if v is not None}

    @property
    def _llm_type(self) -> str:
        """Return type of model."""
        return "konko"

    @staticmethod
    def get_user_agent() -> str:
        """Return the user-agent string built from the package version."""
        from langchain_community import __version__

        return f"langchain/{__version__}"

    @property
    def default_params(self) -> Dict[str, Any]:
        """Return the request parameters configured on this instance.

        Unset (``None``) entries are dropped later by ``construct_payload``.
        """
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_tokens": self.max_tokens,
            "repetition_penalty": self.repetition_penalty,
            # Declared as a field above, so it has to reach the request too.
            "logprobs": self.logprobs,
        }

    @staticmethod
    def _outdated_package_error() -> ValueError:
        """Build the error raised when `konko` lacks the completion API."""
        return ValueError(
            "`konko` has no `Completion` attribute, this is likely "
            "due to an old version of the konko package. Try upgrading it "
            "with `pip install --upgrade konko`."
        )

    @staticmethod
    def _extract_text(response: Any, use_v1: bool) -> str:
        """Return the text of the first choice of a completion response."""
        if use_v1:
            return response.choices[0].text
        return response["choices"][0]["text"]

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Konko's text generation endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional stop sequences.
            run_manager: Callback manager for the run (not used here).
            **kwargs: Extra request parameters forwarded in the payload.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If the installed ``konko`` package does not expose the
                completion API.
        """
        import konko

        payload = self.construct_payload(prompt, stop, **kwargs)
        use_v1 = is_openai_v1()
        # Only the attribute lookup is guarded, so an AttributeError raised
        # while the request runs is not misreported as an outdated package.
        try:
            create = konko.completions.create if use_v1 else konko.Completion.create
        except AttributeError as exc:
            raise self._outdated_package_error() from exc
        return self._extract_text(create(**payload), use_v1)

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Asynchronously call out to Konko's text generation endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional stop sequences.
            run_manager: Callback manager for the run (not used here).
            **kwargs: Extra request parameters forwarded in the payload.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If the installed ``konko`` package does not expose the
                completion API.
        """
        import konko

        payload = self.construct_payload(prompt, stop, **kwargs)
        use_v1 = is_openai_v1()
        # Same narrow guard as in the synchronous path.
        try:
            if use_v1:
                create = konko.AsyncKonko().completions.create
            else:
                create = konko.Completion.acreate
        except AttributeError as exc:
            raise self._outdated_package_error() from exc
        return self._extract_text(await create(**payload), use_v1)

View File

@@ -63,7 +63,7 @@ def test_konko_chat_test() -> None:
def test_konko_chat_test_openai() -> None:
"""Evaluate basic ChatKonko functionality."""
chat_instance = ChatKonko(max_tokens=10, model="gpt-3.5-turbo")
chat_instance = ChatKonko(max_tokens=10, model="meta-llama/llama-2-70b-chat")
msg = HumanMessage(content="Hi")
chat_response = chat_instance([msg])
assert isinstance(chat_response, BaseMessage)

View File

@@ -0,0 +1,36 @@
"""Test Konko API wrapper.
In order to run this test, you need to have a Konko API key.
You'll then need to set the KONKO_API_KEY environment variable to your API key.
"""
import pytest as pytest
from langchain_community.llms import Konko
def test_konko_call() -> None:
    """Run one synchronous completion and check that a string comes back."""
    konko_llm = Konko(
        max_tokens=250,
        temperature=0.2,
        model="mistralai/mistral-7b-v0.1",
    )
    completion = konko_llm("Say foo:")
    assert konko_llm._llm_type == "konko"
    assert isinstance(completion, str)
async def test_konko_acall() -> None:
    """Run one asynchronous generation with a stop sequence."""
    konko_llm = Konko(
        max_tokens=250,
        temperature=0.2,
        model="mistralai/mistral-7b-v0.1",
    )
    result = await konko_llm.agenerate(["Say foo:"], stop=["bar"])
    assert konko_llm._llm_type == "konko"
    first_text = result.generations[0][0].text
    assert isinstance(first_text, str)
    # The stop sequence may appear at most once in the returned text.
    assert first_text.count("bar") <= 1

View File

@@ -0,0 +1,174 @@
"""Evaluate ChatKonko Interface."""
from typing import Any
import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
from langchain_community.chat_models.konko import ChatKonko
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_konko_chat_test() -> None:
    """Send a single human message and check the shape of the reply."""
    chat = ChatKonko(max_tokens=10)
    reply = chat([HumanMessage(content="Hi")])
    assert isinstance(reply, BaseMessage)
    assert isinstance(reply.content, str)
def test_konko_chat_test_openai() -> None:
    """Send a single human message to an explicitly selected model."""
    chat = ChatKonko(model="meta-llama/llama-2-70b-chat", max_tokens=10)
    reply = chat([HumanMessage(content="Hi")])
    assert isinstance(reply, BaseMessage)
    assert isinstance(reply.content, str)
def test_konko_model_test() -> None:
    """Check that the ``model`` argument is stored on the instance."""
    for model_name in ("alpha", "beta"):
        chat = ChatKonko(model=model_name)
        assert chat.model == model_name
def test_konko_available_model_test() -> None:
    """Check that ``get_available_models`` returns a set."""
    chat = ChatKonko(max_tokens=10, n=2)
    available = chat.get_available_models()
    assert isinstance(available, set)
def test_konko_system_msg_test() -> None:
    """Send a system message followed by a human message."""
    conversation = [
        SystemMessage(content="Initiate user chat."),
        HumanMessage(content="Hi there"),
    ]
    chat = ChatKonko(max_tokens=10)
    reply = chat(conversation)
    assert isinstance(reply, BaseMessage)
    assert isinstance(reply.content, str)
def test_konko_generation_test() -> None:
    """Generate for two prompts with ``n=2`` and inspect every candidate."""
    chat = ChatKonko(max_tokens=10, n=2)
    prompt = HumanMessage(content="Hi")
    result = chat.generate([[prompt], [prompt]])
    assert isinstance(result, LLMResult)
    # One list of generations per prompt, each holding ``n`` candidates.
    assert len(result.generations) == 2
    for candidates in result.generations:
        assert len(candidates) == 2
        for candidate in candidates:
            assert isinstance(candidate, ChatGeneration)
            assert isinstance(candidate.text, str)
            assert candidate.text == candidate.message.content
def test_konko_multiple_outputs_test() -> None:
    """Request five completions for a single prompt via ``_generate``."""
    chat = ChatKonko(max_tokens=10, n=5)
    result = chat._generate([HumanMessage(content="Hi")])
    assert isinstance(result, ChatResult)
    assert len(result.generations) == 5
    for candidate in result.generations:
        assert isinstance(candidate.message, BaseMessage)
        assert isinstance(candidate.message.content, str)
def test_konko_streaming_callback_test() -> None:
    """Check that streaming reports tokens to the registered callback."""
    handler = FakeCallbackHandler()
    manager = CallbackManager([handler])
    chat = ChatKonko(
        max_tokens=10,
        streaming=True,
        temperature=0,
        callback_manager=manager,
        verbose=True,
    )
    reply = chat([HumanMessage(content="Hi")])
    # At least one streamed token must have reached the handler.
    assert handler.llm_streams > 0
    assert isinstance(reply, BaseMessage)
def test_konko_streaming_info_test() -> None:
    """Check that the final generation is handed to ``on_llm_end`` when streaming."""

    class RecordingCallback(FakeCallbackHandler):
        """Handler that keeps the first positional argument of ``on_llm_end``."""

        data_store: dict = {}

        def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:
            self.data_store["generation"] = args[0]

    handler = RecordingCallback()
    chat = ChatKonko(
        max_tokens=2,
        temperature=0,
        callback_manager=CallbackManager([handler]),
    )
    # Drain the stream so that the end-of-run callback fires.
    list(chat.stream("hey"))
    recorded = handler.data_store["generation"]
    assert recorded.generations[0][0].text == " Hey"
def test_konko_llm_model_name_test() -> None:
    """Check that ``llm_output`` reports the model name."""
    chat = ChatKonko(max_tokens=10)
    result = chat.generate([[HumanMessage(content="Hi")]])
    assert result.llm_output is not None
    assert result.llm_output["model_name"] == chat.model
def test_konko_streaming_model_name_test() -> None:
    """Check that ``llm_output`` reports the model name when streaming."""
    chat = ChatKonko(max_tokens=10, streaming=True)
    result = chat.generate([[HumanMessage(content="Hi")]])
    assert result.llm_output is not None
    assert result.llm_output["model_name"] == chat.model
def test_konko_streaming_param_validation_test() -> None:
    """Check that combining ``streaming=True`` with ``n=5`` is rejected."""
    invalid_config = {
        "max_tokens": 10,
        "streaming": True,
        "temperature": 0,
        "n": 5,
    }
    with pytest.raises(ValueError):
        ChatKonko(**invalid_config)
def test_konko_additional_args_test() -> None:
    """Check how unknown keyword arguments are folded into ``model_kwargs``."""
    chat = ChatKonko(extra=3, max_tokens=10)
    assert chat.max_tokens == 10
    assert chat.model_kwargs == {"extra": 3}

    chat = ChatKonko(extra=3, model_kwargs={"addition": 2})
    assert chat.model_kwargs == {"extra": 3, "addition": 2}

    # Each of these argument sets must be rejected at construction time.
    rejected_arguments = (
        {"extra": 3, "model_kwargs": {"extra": 2}},
        {"model_kwargs": {"temperature": 0.2}},
        {"model_kwargs": {"model": "text-davinci-003"}},
    )
    for arguments in rejected_arguments:
        with pytest.raises(ValueError):
            ChatKonko(**arguments)
def test_konko_token_streaming_test() -> None:
    """Check that every streamed chunk carries string content."""
    chat = ChatKonko(max_tokens=10)
    for chunk in chat.stream("Just a test"):
        assert isinstance(chunk.content, str)

View File

@@ -0,0 +1,36 @@
"""Test Konko API wrapper.
In order to run this test, you need to have a Konko API key.
You'll then need to set the KONKO_API_KEY environment variable to your API key.
"""
import pytest as pytest
from langchain_community.llms import Konko
def test_konko_call() -> None:
    """Run one synchronous completion and check that a string comes back."""
    konko_llm = Konko(
        max_tokens=250,
        temperature=0.2,
        model="mistralai/mistral-7b-v0.1",
    )
    completion = konko_llm("Say foo:")
    assert konko_llm._llm_type == "konko"
    assert isinstance(completion, str)
async def test_konko_acall() -> None:
    """Run one asynchronous generation with a stop sequence."""
    konko_llm = Konko(
        max_tokens=250,
        temperature=0.2,
        model="mistralai/mistral-7b-v0.1",
    )
    result = await konko_llm.agenerate(["Say foo:"], stop=["bar"])
    assert konko_llm._llm_type == "konko"
    first_text = result.generations[0][0].text
    assert isinstance(first_text, str)
    # The stop sequence may appear at most once in the returned text.
    assert first_text.count("bar") <= 1

View File

@@ -41,6 +41,7 @@ EXPECT_ALL = [
"HuggingFaceTextGenInference",
"HumanInputLLM",
"KoboldApiLLM",
"Konko",
"LlamaCpp",
"TextGen",
"ManifestWrapper",