Integrating the Yi family of models. (#24491)

Thank you for contributing to LangChain!

- [x] **PR title**: "community:add Yi LLM", "docs:add Yi Documentation"
                          
- [x] **PR message**:
    - **Description:** This PR adds support for the Yi family of models to LangChain: a `ChatYi` chat model and a `YiLLM` LLM, both in `langchain_community`.
    - **Dependencies:** `langchain_core` and `requests`, plus the standard-library modules `contextlib`, `typing`, `logging`, and `json`.
    - **Twitter handle:** 01.AI


- [x] **Add tests and docs**: I've added the corresponding documentation
to the relevant paths
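
A quick sketch of the two new entry points (assumes a valid `YI_API_KEY` in the environment):

```python
import os

os.environ["YI_API_KEY"] = "YOUR_API_KEY"

from langchain_community.chat_models.yi import ChatYi
from langchain_community.llms import YiLLM

# Chat model: consumes a list of messages
chat = ChatYi(model="yi-large", temperature=0)

# Plain LLM: consumes a prompt string
llm = YiLLM(model="yi-large")
```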

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
Haijian Wang 2024-07-27 01:57:33 +08:00 committed by GitHub
parent ad7581751f
commit cda3025ee1
9 changed files with 856 additions and 13 deletions

@@ -0,0 +1,228 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ChatYI\n",
"\n",
"This will help you getting started with Yi [chat models](/docs/concepts/#chat-models). For detailed documentation of all ChatYi features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/lanchain_community.chat_models.yi.ChatYi.html).\n",
"\n",
"[01.AI](https://www.lingyiwanwu.com/en), founded by Dr. Kai-Fu Lee, is a global company at the forefront of AI 2.0. They offer cutting-edge large language models, including the Yi series, which range from 6B to hundreds of billions of parameters. 01.AI also provides multimodal models, an open API platform, and open-source options like Yi-34B/9B/6B and Yi-VL.\n",
"\n",
"## Overview\n",
"### Integration details\n",
"\n",
"\n",
"| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
"| [ChatYi](https://api.python.langchain.com/en/latest/chat_models/lanchain_community.chat_models.yi.ChatYi.html) | [langchain_community](https://api.python.langchain.com/en/latest/community_api_reference.html) | ✅ | ❌ | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_community?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_community?style=flat-square&label=%20) |\n",
"\n",
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | \n",
"\n",
"## Setup\n",
"\n",
"To access ChatYi models you'll need to create a/an 01.AI account, get an API key, and install the `langchain_community` integration package.\n",
"\n",
"### Credentials\n",
"\n",
"Head to [01.AI](https://platform.01.ai) to sign up to 01.AI and generate an API key. Once you've done this set the `YI_API_KEY` environment variable:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"YI_API_KEY\"] = getpass.getpass(\"Enter your Yi API key: \")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n",
"# os.environ[\"LANGSMITH_TRACING\"] = \"true\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Installation\n",
"\n",
"The LangChain __ModuleName__ integration lives in the `langchain_community` package:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain_community"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Instantiation\n",
"\n",
"Now we can instantiate our model object and generate chat completions:\n",
"\n",
"- TODO: Update model instantiation with relevant params."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.yi import ChatYi\n",
"\n",
"llm = ChatYi(\n",
" model=\"yi-large\",\n",
" temperature=0,\n",
" timeout=60,\n",
" yi_api_base=\"https://api.01.ai/v1/chat/completions\",\n",
" # other params...\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Invocation\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"Large Language Models (LLMs) have the potential to significantly impact healthcare by enhancing various aspects of patient care, research, and administrative processes. Here are some potential applications:\\n\\n1. **Clinical Documentation and Reporting**: LLMs can assist in generating patient reports and documentation by understanding and summarizing clinical notes, making the process more efficient and reducing the administrative burden on healthcare professionals.\\n\\n2. **Medical Coding and Billing**: These models can help in automating the coding process for medical billing by accurately translating clinical notes into standardized codes, reducing errors and improving billing efficiency.\\n\\n3. **Clinical Decision Support**: LLMs can analyze patient data and medical literature to provide evidence-based recommendations to healthcare providers, aiding in diagnosis and treatment planning.\\n\\n4. **Patient Education and Communication**: By simplifying medical jargon, LLMs can help in educating patients about their conditions, treatment options, and preventive care, improving patient engagement and health literacy.\\n\\n5. **Natural Language Processing (NLP) for EHRs**: LLMs can enhance NLP capabilities in Electronic Health Records (EHRs) systems, enabling better extraction of information from unstructured data, such as clinical notes, to support data-driven decision-making.\\n\\n6. **Drug Discovery and Development**: LLMs can analyze biomedical literature and clinical trial data to identify new drug candidates, predict drug interactions, and support the development of personalized medicine.\\n\\n7. **Telemedicine and Virtual Health Assistants**: Integrated into telemedicine platforms, LLMs can provide preliminary assessments and triage, offering patients basic health advice and determining the urgency of their needs, thus optimizing the utilization of healthcare resources.\\n\\n8. **Research and Literature Review**: LLMs can expedite the process of reviewing medical literature by quickly identifying relevant studies and summarizing findings, accelerating research and evidence-based practice.\\n\\n9. **Personalized Medicine**: By analyzing a patient's genetic information and medical history, LLMs can help in tailoring treatment plans and medication dosages, contributing to the advancement of personalized medicine.\\n\\n10. **Quality Improvement and Risk Assessment**: LLMs can analyze healthcare data to identify patterns that may indicate areas for quality improvement or potential risks, such as hospital-acquired infections or adverse drug events.\\n\\n11. **Mental Health Support**: LLMs can provide mental health support by offering coping strategies, mindfulness exercises, and preliminary assessments, serving as a complement to professional mental health services.\\n\\n12. **Continuing Medical Education (CME)**: LLMs can personalize CME by recommending educational content based on a healthcare provider's practice area, patient demographics, and emerging medical literature, ensuring that professionals stay updated with the latest advancements.\\n\\nWhile the applications of LLMs in healthcare are promising, it's crucial to address challenges such as data privacy, model bias, and the need for regulatory approval to ensure that these technologies are implemented safely and ethically.\", response_metadata={'token_usage': {'completion_tokens': 656, 'prompt_tokens': 40, 'total_tokens': 696}, 'model': 'yi-large'}, id='run-870850bd-e4bf-4265-8730-1736409c0acf-0')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.messages import HumanMessage, SystemMessage\n",
"\n",
"messages = [\n",
" SystemMessage(content=\"You are an AI assistant specializing in technology trends.\"),\n",
" HumanMessage(\n",
" content=\"What are the potential applications of large language models in healthcare?\"\n",
" ),\n",
"]\n",
"\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chaining\n",
"\n",
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Ich liebe das Programmieren.', response_metadata={'token_usage': {'completion_tokens': 8, 'prompt_tokens': 33, 'total_tokens': 41}, 'model': 'yi-large'}, id='run-daa3bc58-8289-4d72-a24e-80622fa90d6d-0')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all ChatYi features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.yi.ChatYi.html"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
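
The feature table above marks token-level streaming as supported, but the notebook stops at `invoke`. A minimal streaming sketch, reusing the `llm` instance from the notebook:

```python
# Stream AIMessageChunk objects as tokens arrive
for chunk in llm.stream("Summarize the Yi model series in one sentence."):
    print(chunk.content, end="", flush=True)
```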

@@ -0,0 +1,133 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Yi\n",
"[01.AI](https://www.lingyiwanwu.com/en), founded by Dr. Kai-Fu Lee, is a global company at the forefront of AI 2.0. They offer cutting-edge large language models, including the Yi series, which range from 6B to hundreds of billions of parameters. 01.AI also provides multimodal models, an open API platform, and open-source options like Yi-34B/9B/6B and Yi-VL."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Installing the langchain packages needed to use the integration\n",
"%pip install -qU langchain-community"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prerequisite\n",
"An API key is required to access Yi LLM API. Visit https://www.lingyiwanwu.com/ to get your API key. When applying for the API key, you need to specify whether it's for domestic (China) or international use."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use Yi LLM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"YI_API_KEY\"] = \"YOUR_API_KEY\"\n",
"\n",
"from langchain_community.llms import YiLLM\n",
"\n",
"# Load the model\n",
"llm = YiLLM(model=\"yi-large\")\n",
"\n",
"# You can specify the region if needed (default is \"auto\")\n",
"# llm = YiLLM(model=\"yi-large\", region=\"domestic\") # or \"international\"\n",
"\n",
"# Basic usage\n",
"res = llm.invoke(\"What's your name?\")\n",
"print(res)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate method\n",
"res = llm.generate(\n",
" prompts=[\n",
" \"Explain the concept of large language models.\",\n",
" \"What are the potential applications of AI in healthcare?\",\n",
" ]\n",
")\n",
"print(res)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Streaming\n",
"for chunk in llm.stream(\"Describe the key features of the Yi language model series.\"):\n",
" print(chunk, end=\"\", flush=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Asynchronous streaming\n",
"import asyncio\n",
"\n",
"\n",
"async def run_aio_stream():\n",
" async for chunk in llm.astream(\n",
" \"Write a brief on the future of AI according to Dr. Kai-Fu Lee's vision.\"\n",
" ):\n",
" print(chunk, end=\"\", flush=True)\n",
"\n",
"\n",
"asyncio.run(run_aio_stream())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Adjusting parameters\n",
"llm_with_params = YiLLM(\n",
" model=\"yi-large\",\n",
" temperature=0.7,\n",
" top_p=0.9,\n",
")\n",
"\n",
"res = llm_with_params(\n",
" \"Propose an innovative AI application that could benefit society.\"\n",
")\n",
"print(res)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
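
Like the chat model, `YiLLM` composes with prompt templates via LCEL. A minimal sketch, reusing the `llm` instance from the notebook:

```python
from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate.from_template("Explain {topic} in one paragraph.")
chain = prompt | llm
print(chain.invoke({"topic": "retrieval-augmented generation"}))
```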

@@ -0,0 +1,23 @@
# 01.AI

>[01.AI](https://www.lingyiwanwu.com/en), founded by Dr. Kai-Fu Lee, is a global company at the forefront of AI 2.0. They offer cutting-edge large language models, including the Yi series, which range from 6B to hundreds of billions of parameters. 01.AI also provides multimodal models, an open API platform, and open-source options like Yi-34B/9B/6B and Yi-VL.

## Installation and Setup

Register and get an API key from either the China site [here](https://platform.lingyiwanwu.com/apikeys) or the global site [here](https://platform.01.ai/apikeys), then expose it to LangChain as the `YI_API_KEY` environment variable.

## LLMs

See a [usage example](/docs/integrations/llms/yi).

```python
from langchain_community.llms import YiLLM
```
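
A minimal invocation sketch, continuing from the import above (assumes `YI_API_KEY` is set):

```python
llm = YiLLM(model="yi-large")
print(llm.invoke("What is the capital of France?"))
```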

## Chat models

See a [usage example](/docs/integrations/chat/yi).

```python
from langchain_community.chat_models import ChatYi
```
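
A minimal invocation sketch, continuing from the import above:

```python
from langchain_core.messages import HumanMessage

chat = ChatYi(model="yi-large")
print(chat.invoke([HumanMessage(content="Hello!")]).content)
```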

@@ -165,13 +165,15 @@ if TYPE_CHECKING:
from langchain_community.chat_models.yandex import (
ChatYandexGPT,
)
from langchain_community.chat_models.yi import (
ChatYi,
)
from langchain_community.chat_models.yuan2 import (
ChatYuan2,
)
from langchain_community.chat_models.zhipuai import (
ChatZhipuAI,
)
__all__ = [
"AzureChatOpenAI",
"BedrockChat",
@@ -225,6 +227,7 @@ __all__ = [
"QianfanChatEndpoint",
"SolarChat",
"VolcEngineMaasChat",
"ChatYi",
]
@@ -281,6 +284,7 @@ _module_lookup = {
"VolcEngineMaasChat": "langchain_community.chat_models.volcengine_maas",
"ChatPremAI": "langchain_community.chat_models.premai",
"ChatLlamaCpp": "langchain_community.chat_models.llamacpp",
"ChatYi": "langchain_community.chat_models.yi",
}

@@ -0,0 +1,339 @@
import json
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Type
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
logger = logging.getLogger(__name__)
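# Yi is served from two OpenAI-style chat-completions endpoints: one for the
# China platform and one for the international platform; the region setting
# below selects between them.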
DEFAULT_API_BASE_CN = "https://api.lingyiwanwu.com/v1/chat/completions"
DEFAULT_API_BASE_GLOBAL = "https://api.01.ai/v1/chat/completions"
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
else:
raise TypeError(f"Got unknown type {message}")
return message_dict
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict.get("content", "") or "")
elif role == "system":
return AIMessage(content=_dict["content"])
else:
return ChatMessage(content=_dict["content"], role=role)
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role: str = _dict["role"]
content = _dict.get("content") or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content, type=role)
@asynccontextmanager
async def aconnect_httpx_sse(
client: Any, method: str, url: str, **kwargs: Any
) -> AsyncIterator:
from httpx_sse import EventSource
async with client.stream(method, url, **kwargs) as response:
yield EventSource(response)
class ChatYi(BaseChatModel):
"""Yi chat models API."""
@property
def lc_secrets(self) -> Dict[str, str]:
return {
"yi_api_key": "YI_API_KEY",
}
@property
def lc_serializable(self) -> bool:
return True
yi_api_base: str = Field(default=DEFAULT_API_BASE_CN)
yi_api_key: SecretStr = Field(alias="api_key")
region: str = Field(default="cn") # 默认使用中国区
streaming: bool = False
request_timeout: int = Field(default=60, alias="timeout")
model: str = "yi-large"
temperature: Optional[float] = Field(default=0.7)
top_p: float = 0.7
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
class Config:
allow_population_by_field_name = True
def __init__(self, **kwargs: Any) -> None:
kwargs["yi_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
kwargs,
["yi_api_key", "api_key"],
"YI_API_KEY",
)
)
if kwargs.get("yi_api_base") is None:
region = kwargs.get("region", "cn").lower()
if region == "global":
kwargs["yi_api_base"] = DEFAULT_API_BASE_GLOBAL
else:
kwargs["yi_api_base"] = DEFAULT_API_BASE_CN
all_required_field_names = get_pydantic_field_names(self.__class__)
extra = kwargs.get("model_kwargs", {})
for field_name in list(kwargs):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
extra[field_name] = kwargs.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
kwargs["model_kwargs"] = extra
super().__init__(**kwargs)
@property
def _default_params(self) -> Dict[str, Any]:
return {
"model": self.model,
"temperature": self.temperature,
"top_p": self.top_p,
"stream": self.streaming,
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
res = self._chat(messages, **kwargs)
if res.status_code != 200:
raise ValueError(f"Error from Yi api response: {res}")
response = res.json()
return self._create_chat_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
res = self._chat(messages, stream=True, **kwargs)
if res.status_code != 200:
raise ValueError(f"Error from Yi api response: {res}")
default_chunk_class = AIMessageChunk
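# The endpoint streams Server-Sent Events: payload lines are prefixed with
# "data: " and the stream is terminated by a literal "[DONE]" sentinel.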
for chunk in res.iter_lines():
chunk = chunk.decode("utf-8").strip("\r\n")
parts = chunk.split("data: ", 1)
chunk = parts[1] if len(parts) > 1 else None
if chunk is None:
continue
if chunk == "[DONE]":
break
response = json.loads(chunk)
for m in response.get("choices"):
chunk = _convert_delta_to_message_chunk(
m.get("delta"), default_chunk_class
)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
yield cg_chunk
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
headers = self._create_headers_parameters(**kwargs)
payload = self._create_payload_parameters(messages, **kwargs)
import httpx
async with httpx.AsyncClient(
headers=headers, timeout=self.request_timeout
) as client:
response = await client.post(self.yi_api_base, json=payload)
response.raise_for_status()
return self._create_chat_result(response.json())
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
headers = self._create_headers_parameters(**kwargs)
payload = self._create_payload_parameters(messages, stream=True, **kwargs)
import httpx
async with httpx.AsyncClient(
headers=headers, timeout=self.request_timeout
) as client:
async with aconnect_httpx_sse(
client, "POST", self.yi_api_base, json=payload
) as event_source:
async for sse in event_source.aiter_sse():
chunk = json.loads(sse.data)
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], AIMessageChunk
)
finish_reason = choice.get("finish_reason", None)
generation_info = (
{"finish_reason": finish_reason}
if finish_reason is not None
else None
)
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
if run_manager:
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
if finish_reason is not None:
break
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
payload = self._create_payload_parameters(messages, **kwargs)
url = self.yi_api_base
headers = self._create_headers_parameters(**kwargs)
res = requests.post(
url=url,
timeout=self.request_timeout,
headers=headers,
json=payload,
stream=self.streaming,
)
return res
def _create_payload_parameters(
self, messages: List[BaseMessage], **kwargs: Any
) -> Dict[str, Any]:
parameters = {**self._default_params, **kwargs}
temperature = parameters.pop("temperature", 0.7)
top_p = parameters.pop("top_p", 0.7)
model = parameters.pop("model")
stream = parameters.pop("stream", False)
payload = {
"model": model,
"messages": [_convert_message_to_dict(m) for m in messages],
"top_p": top_p,
"temperature": temperature,
"stream": stream,
}
return payload
def _create_headers_parameters(self, **kwargs: Any) -> Dict[str, Any]:
parameters = {**self._default_params, **kwargs}
default_headers = parameters.pop("headers", {})
api_key = ""
if self.yi_api_key:
api_key = self.yi_api_key.get_secret_value()
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
**default_headers,
}
return headers
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for c in response["choices"]:
message = _convert_dict_to_message(c["message"])
gen = ChatGeneration(message=message)
generations.append(gen)
token_usage = response["usage"]
llm_output = {"token_usage": token_usage, "model": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
@property
def _llm_type(self) -> str:
return "yi-chat"

@@ -640,12 +640,6 @@ def _import_yuan2() -> Type[BaseLLM]:
return Yuan2
def _import_you() -> Type[BaseLLM]:
from langchain_community.llms.you import You
return You
def _import_volcengine_maas() -> Type[BaseLLM]:
from langchain_community.llms.volcengine_maas import VolcEngineMaasLLM
@@ -658,6 +652,18 @@ def _import_sparkllm() -> Type[BaseLLM]:
return SparkLLM
def _import_you() -> Type[BaseLLM]:
from langchain_community.llms.you import You
return You
def _import_yi() -> Type[BaseLLM]:
from langchain_community.llms.yi import YiLLM
return YiLLM
def __getattr__(name: str) -> Any:
if name == "AI21":
return _import_ai21()
@@ -853,18 +859,20 @@ def __getattr__(name: str) -> Any:
return _import_yandex_gpt()
elif name == "Yuan2":
return _import_yuan2()
elif name == "You":
return _import_you()
elif name == "VolcEngineMaasLLM":
return _import_volcengine_maas()
elif name == "SparkLLM":
return _import_sparkllm()
elif name == "YiLLM":
return _import_yi()
elif name == "You":
return _import_you()
elif name == "type_to_cls_dict":
# for backwards compatibility
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
k: v() for k, v in get_type_to_cls_dict().items()
}
return type_to_cls_dict
elif name == "SparkLLM":
return _import_sparkllm()
else:
raise AttributeError(f"Could not find: {name}")
@@ -967,8 +975,9 @@ __all__ = [
"Writer",
"Xinference",
"YandexGPT",
"You",
"Yuan2",
"YiLLM",
"You",
]
@@ -1065,7 +1074,8 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"qianfan_endpoint": _import_baidu_qianfan_endpoint,
"yandex_gpt": _import_yandex_gpt,
"yuan2": _import_yuan2,
"you": _import_you,
"VolcEngineMaasLLM": _import_volcengine_maas,
"SparkLLM": _import_sparkllm,
"yi": _import_yi,
"you": _import_you,
}

@@ -0,0 +1,104 @@
from __future__ import annotations
import json
import logging
from typing import Any, Dict, List, Literal, Optional
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, SecretStr
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_community.llms.utils import enforce_stop_tokens
logger = logging.getLogger(__name__)
class YiLLM(LLM):
"""Yi large language models."""
model: str = "yi-large"
temperature: float = 0.3
top_p: float = 0.95
timeout: int = 60
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
yi_api_key: Optional[SecretStr] = None
region: Literal["auto", "domestic", "international"] = "auto"
yi_api_url_domestic: str = "https://api.lingyiwanwu.com/v1/chat/completions"
yi_api_url_international: str = "https://api.01.ai/v1/chat/completions"
def __init__(self, **kwargs: Any):
kwargs["yi_api_key"] = convert_to_secret_str(
get_from_dict_or_env(kwargs, "yi_api_key", "YI_API_KEY")
)
super().__init__(**kwargs)
@property
def _default_params(self) -> Dict[str, Any]:
return {
"model": self.model,
"temperature": self.temperature,
"top_p": self.top_p,
**self.model_kwargs,
}
def _post(self, request: Any) -> Any:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.yi_api_key.get_secret_value()}", # type: ignore
}
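# Resolve the endpoint order: a pinned region uses a single URL, while
# "auto" tries the domestic endpoint first, then the international one.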
urls = []
if self.region == "domestic":
urls = [self.yi_api_url_domestic]
elif self.region == "international":
urls = [self.yi_api_url_international]
else: # auto
urls = [self.yi_api_url_domestic, self.yi_api_url_international]
for url in urls:
try:
response = requests.post(
url,
headers=headers,
json=request,
timeout=self.timeout,
)
if response.status_code == 200:
parsed_json = json.loads(response.text)
return parsed_json["choices"][0]["message"]["content"]
elif (
response.status_code != 403
): # If not a permission error, raise immediately
response.raise_for_status()
except requests.RequestException as e:
if url == urls[-1]: # If this is the last URL to try
raise ValueError(f"An error has occurred: {e}")
else:
logger.warning(f"Failed to connect to {url}, trying next URL")
continue
raise ValueError("Failed to connect to all available URLs")
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
request = self._default_params
request["messages"] = [{"role": "user", "content": prompt}]
request.update(kwargs)
text = self._post(request)
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text
@property
def _llm_type(self) -> str:
"""Return type of chat_model."""
return "yi-llm"

@@ -54,6 +54,7 @@ EXPECTED_ALL = [
"VolcEngineMaasChat",
"ChatOctoAI",
"ChatSnowflakeCortex",
"ChatYi",
]

@@ -98,6 +98,7 @@ EXPECT_ALL = [
"QianfanLLMEndpoint",
"YandexGPT",
"Yuan2",
"YiLLM",
"You",
"VolcEngineMaasLLM",
"WatsonxLLM",