mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-28 11:55:21 +00:00
Integrating the Yi family of models. (#24491)
Thank you for contributing to LangChain! - [x] **PR title**: "community:add Yi LLM", "docs:add Yi Documentation" - [x] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** This PR adds support for the Yi model to LangChain. - **Dependencies:** [langchain_core,requests,contextlib,typing,logging,json,langchain_community] - **Twitter handle:** 01.AI - [x] **Add tests and docs**: I've added the corresponding documentation to the relevant paths --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
This commit is contained in:
parent
ad7581751f
commit
cda3025ee1
228
docs/docs/integrations/chat/yi.ipynb
Normal file
228
docs/docs/integrations/chat/yi.ipynb
Normal file
@ -0,0 +1,228 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# ChatYI\n",
|
||||
"\n",
|
||||
"This will help you getting started with Yi [chat models](/docs/concepts/#chat-models). For detailed documentation of all ChatYi features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/lanchain_community.chat_models.yi.ChatYi.html).\n",
|
||||
"\n",
|
||||
"[01.AI](https://www.lingyiwanwu.com/en), founded by Dr. Kai-Fu Lee, is a global company at the forefront of AI 2.0. They offer cutting-edge large language models, including the Yi series, which range from 6B to hundreds of billions of parameters. 01.AI also provides multimodal models, an open API platform, and open-source options like Yi-34B/9B/6B and Yi-VL.\n",
|
||||
"\n",
|
||||
"## Overview\n",
|
||||
"### Integration details\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n",
|
||||
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
|
||||
"| [ChatYi](https://api.python.langchain.com/en/latest/chat_models/lanchain_community.chat_models.yi.ChatYi.html) | [langchain_community](https://api.python.langchain.com/en/latest/community_api_reference.html) | ✅ | ❌ | ❌ |  |  |\n",
|
||||
"\n",
|
||||
"### Model features\n",
|
||||
"| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
|
||||
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
|
||||
"| ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | \n",
|
||||
"\n",
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
"To access ChatYi models you'll need to create a/an 01.AI account, get an API key, and install the `langchain_community` integration package.\n",
|
||||
"\n",
|
||||
"### Credentials\n",
|
||||
"\n",
|
||||
"Head to [01.AI](https://platform.01.ai) to sign up to 01.AI and generate an API key. Once you've done this set the `YI_API_KEY` environment variable:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"YI_API_KEY\"] = getpass.getpass(\"Enter your Yi API key: \")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n",
|
||||
"# os.environ[\"LANGSMITH_TRACING\"] = \"true\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Installation\n",
|
||||
"\n",
|
||||
"The LangChain __ModuleName__ integration lives in the `langchain_community` package:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -qU langchain_community"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Instantiation\n",
|
||||
"\n",
|
||||
"Now we can instantiate our model object and generate chat completions:\n",
|
||||
"\n",
|
||||
"- TODO: Update model instantiation with relevant params."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.chat_models.yi import ChatYi\n",
|
||||
"\n",
|
||||
"llm = ChatYi(\n",
|
||||
" model=\"yi-large\",\n",
|
||||
" temperature=0,\n",
|
||||
" timeout=60,\n",
|
||||
" yi_api_base=\"https://api.01.ai/v1/chat/completions\",\n",
|
||||
" # other params...\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Invocation\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=\"Large Language Models (LLMs) have the potential to significantly impact healthcare by enhancing various aspects of patient care, research, and administrative processes. Here are some potential applications:\\n\\n1. **Clinical Documentation and Reporting**: LLMs can assist in generating patient reports and documentation by understanding and summarizing clinical notes, making the process more efficient and reducing the administrative burden on healthcare professionals.\\n\\n2. **Medical Coding and Billing**: These models can help in automating the coding process for medical billing by accurately translating clinical notes into standardized codes, reducing errors and improving billing efficiency.\\n\\n3. **Clinical Decision Support**: LLMs can analyze patient data and medical literature to provide evidence-based recommendations to healthcare providers, aiding in diagnosis and treatment planning.\\n\\n4. **Patient Education and Communication**: By simplifying medical jargon, LLMs can help in educating patients about their conditions, treatment options, and preventive care, improving patient engagement and health literacy.\\n\\n5. **Natural Language Processing (NLP) for EHRs**: LLMs can enhance NLP capabilities in Electronic Health Records (EHRs) systems, enabling better extraction of information from unstructured data, such as clinical notes, to support data-driven decision-making.\\n\\n6. **Drug Discovery and Development**: LLMs can analyze biomedical literature and clinical trial data to identify new drug candidates, predict drug interactions, and support the development of personalized medicine.\\n\\n7. **Telemedicine and Virtual Health Assistants**: Integrated into telemedicine platforms, LLMs can provide preliminary assessments and triage, offering patients basic health advice and determining the urgency of their needs, thus optimizing the utilization of healthcare resources.\\n\\n8. **Research and Literature Review**: LLMs can expedite the process of reviewing medical literature by quickly identifying relevant studies and summarizing findings, accelerating research and evidence-based practice.\\n\\n9. **Personalized Medicine**: By analyzing a patient's genetic information and medical history, LLMs can help in tailoring treatment plans and medication dosages, contributing to the advancement of personalized medicine.\\n\\n10. **Quality Improvement and Risk Assessment**: LLMs can analyze healthcare data to identify patterns that may indicate areas for quality improvement or potential risks, such as hospital-acquired infections or adverse drug events.\\n\\n11. **Mental Health Support**: LLMs can provide mental health support by offering coping strategies, mindfulness exercises, and preliminary assessments, serving as a complement to professional mental health services.\\n\\n12. **Continuing Medical Education (CME)**: LLMs can personalize CME by recommending educational content based on a healthcare provider's practice area, patient demographics, and emerging medical literature, ensuring that professionals stay updated with the latest advancements.\\n\\nWhile the applications of LLMs in healthcare are promising, it's crucial to address challenges such as data privacy, model bias, and the need for regulatory approval to ensure that these technologies are implemented safely and ethically.\", response_metadata={'token_usage': {'completion_tokens': 656, 'prompt_tokens': 40, 'total_tokens': 696}, 'model': 'yi-large'}, id='run-870850bd-e4bf-4265-8730-1736409c0acf-0')"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_core.messages import HumanMessage, SystemMessage\n",
|
||||
"\n",
|
||||
"messages = [\n",
|
||||
" SystemMessage(content=\"You are an AI assistant specializing in technology trends.\"),\n",
|
||||
" HumanMessage(\n",
|
||||
" content=\"What are the potential applications of large language models in healthcare?\"\n",
|
||||
" ),\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"ai_msg = llm.invoke(messages)\n",
|
||||
"ai_msg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Chaining\n",
|
||||
"\n",
|
||||
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='Ich liebe das Programmieren.', response_metadata={'token_usage': {'completion_tokens': 8, 'prompt_tokens': 33, 'total_tokens': 41}, 'model': 'yi-large'}, id='run-daa3bc58-8289-4d72-a24e-80622fa90d6d-0')"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_core.prompts import ChatPromptTemplate\n",
|
||||
"\n",
|
||||
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\n",
|
||||
" \"system\",\n",
|
||||
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
|
||||
" ),\n",
|
||||
" (\"human\", \"{input}\"),\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chain = prompt | llm\n",
|
||||
"chain.invoke(\n",
|
||||
" {\n",
|
||||
" \"input_language\": \"English\",\n",
|
||||
" \"output_language\": \"German\",\n",
|
||||
" \"input\": \"I love programming.\",\n",
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## API reference\n",
|
||||
"\n",
|
||||
"For detailed documentation of all ChatYi features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.yi.ChatYi.html"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
133
docs/docs/integrations/llms/yi.ipynb
Normal file
133
docs/docs/integrations/llms/yi.ipynb
Normal file
@ -0,0 +1,133 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Yi\n",
|
||||
"[01.AI](https://www.lingyiwanwu.com/en), founded by Dr. Kai-Fu Lee, is a global company at the forefront of AI 2.0. They offer cutting-edge large language models, including the Yi series, which range from 6B to hundreds of billions of parameters. 01.AI also provides multimodal models, an open API platform, and open-source options like Yi-34B/9B/6B and Yi-VL."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"## Installing the langchain packages needed to use the integration\n",
|
||||
"%pip install -qU langchain-community"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Prerequisite\n",
|
||||
"An API key is required to access Yi LLM API. Visit https://www.lingyiwanwu.com/ to get your API key. When applying for the API key, you need to specify whether it's for domestic (China) or international use."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Use Yi LLM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"YI_API_KEY\"] = \"YOUR_API_KEY\"\n",
|
||||
"\n",
|
||||
"from langchain_community.llms import YiLLM\n",
|
||||
"\n",
|
||||
"# Load the model\n",
|
||||
"llm = YiLLM(model=\"yi-large\")\n",
|
||||
"\n",
|
||||
"# You can specify the region if needed (default is \"auto\")\n",
|
||||
"# llm = YiLLM(model=\"yi-large\", region=\"domestic\") # or \"international\"\n",
|
||||
"\n",
|
||||
"# Basic usage\n",
|
||||
"res = llm.invoke(\"What's your name?\")\n",
|
||||
"print(res)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Generate method\n",
|
||||
"res = llm.generate(\n",
|
||||
" prompts=[\n",
|
||||
" \"Explain the concept of large language models.\",\n",
|
||||
" \"What are the potential applications of AI in healthcare?\",\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"print(res)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Streaming\n",
|
||||
"for chunk in llm.stream(\"Describe the key features of the Yi language model series.\"):\n",
|
||||
" print(chunk, end=\"\", flush=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Asynchronous streaming\n",
|
||||
"import asyncio\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def run_aio_stream():\n",
|
||||
" async for chunk in llm.astream(\n",
|
||||
" \"Write a brief on the future of AI according to Dr. Kai-Fu Lee's vision.\"\n",
|
||||
" ):\n",
|
||||
" print(chunk, end=\"\", flush=True)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"asyncio.run(run_aio_stream())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Adjusting parameters\n",
|
||||
"llm_with_params = YiLLM(\n",
|
||||
" model=\"yi-large\",\n",
|
||||
" temperature=0.7,\n",
|
||||
" top_p=0.9,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"res = llm_with_params(\n",
|
||||
" \"Propose an innovative AI application that could benefit society.\"\n",
|
||||
")\n",
|
||||
"print(res)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
23
docs/docs/integrations/providers/yi.mdx
Normal file
23
docs/docs/integrations/providers/yi.mdx
Normal file
@ -0,0 +1,23 @@
|
||||
# 01.AI
|
||||
|
||||
>[01.AI](https://www.lingyiwanwu.com/en), founded by Dr. Kai-Fu Lee, is a global company at the forefront of AI 2.0. They offer cutting-edge large language models, including the Yi series, which range from 6B to hundreds of billions of parameters. 01.AI also provides multimodal models, an open API platform, and open-source options like Yi-34B/9B/6B and Yi-VL.
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
Register and get an API key from either the China site [here](https://platform.lingyiwanwu.com/apikeys) or the global site [here](https://platform.01.ai/apikeys).
|
||||
|
||||
## LLMs
|
||||
|
||||
See a [usage example](/docs/integrations/llms/yi).
|
||||
|
||||
```python
|
||||
from langchain_community.llms import YiLLM
|
||||
```
|
||||
|
||||
## Chat models
|
||||
|
||||
See a [usage example](/docs/integrations/chat/yi).
|
||||
|
||||
```python
|
||||
from langchain_community.chat_models import ChatYi
|
||||
```
|
@ -165,13 +165,15 @@ if TYPE_CHECKING:
|
||||
from langchain_community.chat_models.yandex import (
|
||||
ChatYandexGPT,
|
||||
)
|
||||
from langchain_community.chat_models.yi import (
|
||||
ChatYi,
|
||||
)
|
||||
from langchain_community.chat_models.yuan2 import (
|
||||
ChatYuan2,
|
||||
)
|
||||
from langchain_community.chat_models.zhipuai import (
|
||||
ChatZhipuAI,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AzureChatOpenAI",
|
||||
"BedrockChat",
|
||||
@ -225,6 +227,7 @@ __all__ = [
|
||||
"QianfanChatEndpoint",
|
||||
"SolarChat",
|
||||
"VolcEngineMaasChat",
|
||||
"ChatYi",
|
||||
]
|
||||
|
||||
|
||||
@ -281,6 +284,7 @@ _module_lookup = {
|
||||
"VolcEngineMaasChat": "langchain_community.chat_models.volcengine_maas",
|
||||
"ChatPremAI": "langchain_community.chat_models.premai",
|
||||
"ChatLlamaCpp": "langchain_community.chat_models.llamacpp",
|
||||
"ChatYi": "langchain_community.chat_models.yi",
|
||||
}
|
||||
|
||||
|
||||
|
339
libs/community/langchain_community/chat_models/yi.py
Normal file
339
libs/community/langchain_community/chat_models/yi.py
Normal file
@ -0,0 +1,339 @@
|
||||
import json
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Type
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_API_BASE_CN = "https://api.lingyiwanwu.com/v1/chat/completions"
|
||||
DEFAULT_API_BASE_GLOBAL = "https://api.01.ai/v1/chat/completions"
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
|
||||
return message_dict
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
return AIMessage(content=_dict.get("content", "") or "")
|
||||
elif role == "system":
|
||||
return AIMessage(content=_dict["content"])
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role: str = _dict["role"]
|
||||
content = _dict.get("content") or ""
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content)
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role)
|
||||
else:
|
||||
return default_class(content=content, type=role)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def aconnect_httpx_sse(
|
||||
client: Any, method: str, url: str, **kwargs: Any
|
||||
) -> AsyncIterator:
|
||||
from httpx_sse import EventSource
|
||||
|
||||
async with client.stream(method, url, **kwargs) as response:
|
||||
yield EventSource(response)
|
||||
|
||||
|
||||
class ChatYi(BaseChatModel):
|
||||
"""Yi chat models API."""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {
|
||||
"yi_api_key": "YI_API_KEY",
|
||||
}
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
yi_api_base: str = Field(default=DEFAULT_API_BASE_CN)
|
||||
yi_api_key: SecretStr = Field(alias="api_key")
|
||||
region: str = Field(default="cn") # 默认使用中国区
|
||||
streaming: bool = False
|
||||
request_timeout: int = Field(default=60, alias="timeout")
|
||||
model: str = "yi-large"
|
||||
temperature: Optional[float] = Field(default=0.7)
|
||||
top_p: float = 0.7
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
kwargs["yi_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
kwargs,
|
||||
["yi_api_key", "api_key"],
|
||||
"YI_API_KEY",
|
||||
)
|
||||
)
|
||||
if kwargs.get("yi_api_base") is None:
|
||||
region = kwargs.get("region", "cn").lower()
|
||||
if region == "global":
|
||||
kwargs["yi_api_base"] = DEFAULT_API_BASE_GLOBAL
|
||||
else:
|
||||
kwargs["yi_api_base"] = DEFAULT_API_BASE_CN
|
||||
|
||||
all_required_field_names = get_pydantic_field_names(self.__class__)
|
||||
extra = kwargs.get("model_kwargs", {})
|
||||
for field_name in list(kwargs):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
extra[field_name] = kwargs.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
kwargs["model_kwargs"] = extra
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"model": self.model,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"stream": self.streaming,
|
||||
}
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
res = self._chat(messages, **kwargs)
|
||||
if res.status_code != 200:
|
||||
raise ValueError(f"Error from Yi api response: {res}")
|
||||
response = res.json()
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
res = self._chat(messages, stream=True, **kwargs)
|
||||
if res.status_code != 200:
|
||||
raise ValueError(f"Error from Yi api response: {res}")
|
||||
default_chunk_class = AIMessageChunk
|
||||
for chunk in res.iter_lines():
|
||||
chunk = chunk.decode("utf-8").strip("\r\n")
|
||||
parts = chunk.split("data: ", 1)
|
||||
chunk = parts[1] if len(parts) > 1 else None
|
||||
if chunk is None:
|
||||
continue
|
||||
if chunk == "[DONE]":
|
||||
break
|
||||
response = json.loads(chunk)
|
||||
for m in response.get("choices"):
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
m.get("delta"), default_chunk_class
|
||||
)
|
||||
default_chunk_class = chunk.__class__
|
||||
cg_chunk = ChatGenerationChunk(message=chunk)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
headers = self._create_headers_parameters(**kwargs)
|
||||
payload = self._create_payload_parameters(messages, **kwargs)
|
||||
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
headers=headers, timeout=self.request_timeout
|
||||
) as client:
|
||||
response = await client.post(self.yi_api_base, json=payload)
|
||||
response.raise_for_status()
|
||||
return self._create_chat_result(response.json())
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
headers = self._create_headers_parameters(**kwargs)
|
||||
payload = self._create_payload_parameters(messages, stream=True, **kwargs)
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
headers=headers, timeout=self.request_timeout
|
||||
) as client:
|
||||
async with aconnect_httpx_sse(
|
||||
client, "POST", self.yi_api_base, json=payload
|
||||
) as event_source:
|
||||
async for sse in event_source.aiter_sse():
|
||||
chunk = json.loads(sse.data)
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
choice = chunk["choices"][0]
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
choice["delta"], AIMessageChunk
|
||||
)
|
||||
finish_reason = choice.get("finish_reason", None)
|
||||
|
||||
generation_info = (
|
||||
{"finish_reason": finish_reason}
|
||||
if finish_reason is not None
|
||||
else None
|
||||
)
|
||||
chunk = ChatGenerationChunk(
|
||||
message=chunk, generation_info=generation_info
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
yield chunk
|
||||
if finish_reason is not None:
|
||||
break
|
||||
|
||||
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
|
||||
payload = self._create_payload_parameters(messages, **kwargs)
|
||||
url = self.yi_api_base
|
||||
headers = self._create_headers_parameters(**kwargs)
|
||||
|
||||
res = requests.post(
|
||||
url=url,
|
||||
timeout=self.request_timeout,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
stream=self.streaming,
|
||||
)
|
||||
return res
|
||||
|
||||
def _create_payload_parameters(
|
||||
self, messages: List[BaseMessage], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
parameters = {**self._default_params, **kwargs}
|
||||
temperature = parameters.pop("temperature", 0.7)
|
||||
top_p = parameters.pop("top_p", 0.7)
|
||||
model = parameters.pop("model")
|
||||
stream = parameters.pop("stream", False)
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [_convert_message_to_dict(m) for m in messages],
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
"stream": stream,
|
||||
}
|
||||
return payload
|
||||
|
||||
def _create_headers_parameters(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
parameters = {**self._default_params, **kwargs}
|
||||
default_headers = parameters.pop("headers", {})
|
||||
api_key = ""
|
||||
if self.yi_api_key:
|
||||
api_key = self.yi_api_key.get_secret_value()
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
**default_headers,
|
||||
}
|
||||
return headers
|
||||
|
||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = []
|
||||
for c in response["choices"]:
|
||||
message = _convert_dict_to_message(c["message"])
|
||||
gen = ChatGeneration(message=message)
|
||||
generations.append(gen)
|
||||
|
||||
token_usage = response["usage"]
|
||||
llm_output = {"token_usage": token_usage, "model": self.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "yi-chat"
|
@ -640,12 +640,6 @@ def _import_yuan2() -> Type[BaseLLM]:
|
||||
return Yuan2
|
||||
|
||||
|
||||
def _import_you() -> Type[BaseLLM]:
|
||||
from langchain_community.llms.you import You
|
||||
|
||||
return You
|
||||
|
||||
|
||||
def _import_volcengine_maas() -> Type[BaseLLM]:
|
||||
from langchain_community.llms.volcengine_maas import VolcEngineMaasLLM
|
||||
|
||||
@ -658,6 +652,18 @@ def _import_sparkllm() -> Type[BaseLLM]:
|
||||
return SparkLLM
|
||||
|
||||
|
||||
def _import_you() -> Type[BaseLLM]:
|
||||
from langchain_community.llms.you import You
|
||||
|
||||
return You
|
||||
|
||||
|
||||
def _import_yi() -> Type[BaseLLM]:
|
||||
from langchain_community.llms.yi import YiLLM
|
||||
|
||||
return YiLLM
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name == "AI21":
|
||||
return _import_ai21()
|
||||
@ -853,18 +859,20 @@ def __getattr__(name: str) -> Any:
|
||||
return _import_yandex_gpt()
|
||||
elif name == "Yuan2":
|
||||
return _import_yuan2()
|
||||
elif name == "You":
|
||||
return _import_you()
|
||||
elif name == "VolcEngineMaasLLM":
|
||||
return _import_volcengine_maas()
|
||||
elif name == "SparkLLM":
|
||||
return _import_sparkllm()
|
||||
elif name == "YiLLM":
|
||||
return _import_yi()
|
||||
elif name == "You":
|
||||
return _import_you()
|
||||
elif name == "type_to_cls_dict":
|
||||
# for backwards compatibility
|
||||
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||
k: v() for k, v in get_type_to_cls_dict().items()
|
||||
}
|
||||
return type_to_cls_dict
|
||||
elif name == "SparkLLM":
|
||||
return _import_sparkllm()
|
||||
else:
|
||||
raise AttributeError(f"Could not find: {name}")
|
||||
|
||||
@ -967,8 +975,9 @@ __all__ = [
|
||||
"Writer",
|
||||
"Xinference",
|
||||
"YandexGPT",
|
||||
"You",
|
||||
"Yuan2",
|
||||
"YiLLM",
|
||||
"You",
|
||||
]
|
||||
|
||||
|
||||
@ -1065,7 +1074,8 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
|
||||
"qianfan_endpoint": _import_baidu_qianfan_endpoint,
|
||||
"yandex_gpt": _import_yandex_gpt,
|
||||
"yuan2": _import_yuan2,
|
||||
"you": _import_you,
|
||||
"VolcEngineMaasLLM": _import_volcengine_maas,
|
||||
"SparkLLM": _import_sparkllm,
|
||||
"yi": _import_yi,
|
||||
"you": _import_you,
|
||||
}
|
||||
|
104
libs/community/langchain_community/llms/yi.py
Normal file
104
libs/community/langchain_community/llms/yi.py
Normal file
@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class YiLLM(LLM):
|
||||
"""Yi large language models."""
|
||||
|
||||
model: str = "yi-large"
|
||||
temperature: float = 0.3
|
||||
top_p: float = 0.95
|
||||
timeout: int = 60
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
yi_api_key: Optional[SecretStr] = None
|
||||
region: Literal["auto", "domestic", "international"] = "auto"
|
||||
yi_api_url_domestic: str = "https://api.lingyiwanwu.com/v1/chat/completions"
|
||||
yi_api_url_international: str = "https://api.01.ai/v1/chat/completions"
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
kwargs["yi_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(kwargs, "yi_api_key", "YI_API_KEY")
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"model": self.model,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
def _post(self, request: Any) -> Any:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.yi_api_key.get_secret_value()}", # type: ignore
|
||||
}
|
||||
|
||||
urls = []
|
||||
if self.region == "domestic":
|
||||
urls = [self.yi_api_url_domestic]
|
||||
elif self.region == "international":
|
||||
urls = [self.yi_api_url_international]
|
||||
else: # auto
|
||||
urls = [self.yi_api_url_domestic, self.yi_api_url_international]
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json=request,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
parsed_json = json.loads(response.text)
|
||||
return parsed_json["choices"][0]["message"]["content"]
|
||||
elif (
|
||||
response.status_code != 403
|
||||
): # If not a permission error, raise immediately
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as e:
|
||||
if url == urls[-1]: # If this is the last URL to try
|
||||
raise ValueError(f"An error has occurred: {e}")
|
||||
else:
|
||||
logger.warning(f"Failed to connect to {url}, trying next URL")
|
||||
continue
|
||||
|
||||
raise ValueError("Failed to connect to all available URLs")
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
request = self._default_params
|
||||
request["messages"] = [{"role": "user", "content": prompt}]
|
||||
request.update(kwargs)
|
||||
text = self._post(request)
|
||||
if stop is not None:
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
return text
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat_model."""
|
||||
return "yi-llm"
|
@ -54,6 +54,7 @@ EXPECTED_ALL = [
|
||||
"VolcEngineMaasChat",
|
||||
"ChatOctoAI",
|
||||
"ChatSnowflakeCortex",
|
||||
"ChatYi",
|
||||
]
|
||||
|
||||
|
||||
|
@ -98,6 +98,7 @@ EXPECT_ALL = [
|
||||
"QianfanLLMEndpoint",
|
||||
"YandexGPT",
|
||||
"Yuan2",
|
||||
"YiLLM",
|
||||
"You",
|
||||
"VolcEngineMaasLLM",
|
||||
"WatsonxLLM",
|
||||
|
Loading…
Reference in New Issue
Block a user