Added support for Nebula Chat model (#21925)
Description: Added support for the Nebula Chat model in addition to Nebula Instruct.
Dependencies: N/A
Twitter handle: @Symbldotai
---------
Co-authored-by: Eugene Yurtsev <eugene@langchain.dev>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent 080741d336
commit 90439b12f6
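As a quick orientation, here is a minimal usage sketch of what this change enables. It mirrors the notebook added below and is illustrative only; it assumes `NEBULA_API_KEY` is set or can be entered interactively.

import getpass
import os

from langchain_community.chat_models import ChatNebula
from langchain_core.messages import HumanMessage, SystemMessage

# Prompt for the key if it is not already exported (illustrative setup only).
if "NEBULA_API_KEY" not in os.environ:
    os.environ["NEBULA_API_KEY"] = getpass.getpass("Nebula API key: ")

chat = ChatNebula(max_new_tokens=1024, temperature=0.5)
messages = [
    SystemMessage(content="You are a helpful assistant that answers general knowledge questions."),
    HumanMessage(content="What is the capital of France?"),
]
# The returned AIMessage carries the exchange transcript in `content`.
print(chat.invoke(messages).content)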
docs/docs/integrations/chat/symblai_nebula.ipynb (new file, 297 lines)
@@ -0,0 +1,297 @@
{
 "cells": [
  {
   "cell_type": "raw",
   "id": "53fbf15f",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "---\n",
    "sidebar_label: Nebula (Symbl.ai)\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf733a38-db84-4363-89e2-de6735c37230",
   "metadata": {},
   "source": [
    "# Nebula (Symbl.ai)\n",
    "\n",
    "## Overview\n",
    "This notebook covers how to get started with [Nebula](https://docs.symbl.ai/docs/nebula-llm) - Symbl.ai's chat model.\n",
    "\n",
    "### Integration details\n",
    "Head to the [API reference](https://docs.symbl.ai/reference/nebula-chat) for detailed documentation.\n",
    "\n",
    "### Model features: TODO"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3607d67e-e56c-4102-bbba-df2edc0e109e",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "### Credentials\n",
    "To get started, request a [Nebula API key](https://platform.symbl.ai/#/login) and set the `NEBULA_API_KEY` environment variable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2108b517-1e8d-473d-92fa-4f930e8072a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "os.environ[\"NEBULA_API_KEY\"] = getpass.getpass()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68b44357",
   "metadata": {},
   "source": [
    "### Installation\n",
    "The integration is set up in the `langchain-community` package."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c26754b-b3c9-4d93-8f36-43049bd943bf",
   "metadata": {},
   "source": [
    "## Instantiation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fdd26e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.chat_models.symblai_nebula import ChatNebula\n",
    "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "chat = ChatNebula(max_new_tokens=1024, temperature=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a915547",
   "metadata": {},
   "source": [
    "## Invocation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=[{'role': 'human', 'text': 'What is the capital of France?'}, {'role': 'assistant', 'text': 'The capital of France is Paris.'}])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "messages = [\n",
    "    SystemMessage(\n",
    "        content=\"You are a helpful assistant that answers general knowledge questions.\"\n",
    "    ),\n",
    "    HumanMessage(content=\"What is the capital of France?\"),\n",
    "]\n",
    "chat.invoke(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9723913f",
   "metadata": {},
   "source": [
    "### Async"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=[{'role': 'human', 'text': 'What is the capital of France?'}, {'role': 'assistant', 'text': 'The capital of France is Paris.'}])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await chat.ainvoke(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0a1d3b4",
   "metadata": {},
   "source": [
    "### Streaming"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " The capital of France is Paris."
     ]
    }
   ],
   "source": [
    "for chunk in chat.stream(messages):\n",
    "    print(chunk.content, end=\"\", flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f91b7c7",
   "metadata": {},
   "source": [
    "### Batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "054dc648",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[AIMessage(content=[{'role': 'human', 'text': 'What is the capital of France?'}, {'role': 'assistant', 'text': 'The capital of France is Paris.'}])]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chat.batch([messages])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e59a5519",
   "metadata": {},
   "source": [
    "## Chaining"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6455f67b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "\n",
    "prompt = ChatPromptTemplate.from_template(\"Tell me a joke about {topic}\")\n",
    "chain = prompt | chat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "deb1e2a1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=[{'role': 'human', 'text': 'Tell me a joke about cows'}, {'role': 'assistant', 'text': \"Sure, here's a joke about cows:\\n\\nWhy did the cow cross the road?\\n\\nTo get to the udder side!\"}])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke({\"topic\": \"cows\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb9d4755",
   "metadata": {},
   "source": [
    "## API reference\n",
    "\n",
    "Check out the [API reference](https://python.langchain.com/v0.2/api_reference/community/chat_models/langchain_community.chat_models.symblai_nebula.ChatNebula.html) for more detail."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
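One thing worth noting from the execute_result cells above: ChatNebula returns the full exchange as a list of role/text dicts in `AIMessage.content`, not a plain string. A small helper like the following (purely illustrative, not part of this PR) pulls out just the assistant's reply:

from langchain_core.messages import AIMessage


def assistant_text(message: AIMessage) -> str:
    """Extract the assistant's text from a Nebula-style transcript content list."""
    if isinstance(message.content, str):
        return message.content
    # content looks like: [{'role': 'human', 'text': '...'}, {'role': 'assistant', 'text': '...'}]
    return " ".join(
        part["text"]
        for part in message.content
        if isinstance(part, dict) and part.get("role") == "assistant"
    )


# Example: assistant_text(chat.invoke(messages)) -> "The capital of France is Paris."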
@@ -153,6 +153,7 @@ if TYPE_CHECKING:
     from langchain_community.chat_models.sparkllm import (
         ChatSparkLLM,
     )
+    from langchain_community.chat_models.symblai_nebula import ChatNebula
     from langchain_community.chat_models.tongyi import (
         ChatTongyi,
     )
@@ -201,6 +202,7 @@ __all__ = [
     "ChatMLflowAIGateway",
     "ChatMaritalk",
     "ChatMlflow",
+    "ChatNebula",
     "ChatOCIGenAI",
     "ChatOllama",
     "ChatOpenAI",
@@ -257,6 +259,7 @@ _module_lookup = {
     "ChatMLX": "langchain_community.chat_models.mlx",
     "ChatMaritalk": "langchain_community.chat_models.maritalk",
     "ChatMlflow": "langchain_community.chat_models.mlflow",
+    "ChatNebula": "langchain_community.chat_models.symblai_nebula",
     "ChatOctoAI": "langchain_community.chat_models.octoai",
     "ChatOCIGenAI": "langchain_community.chat_models.oci_generative_ai",
     "ChatOllama": "langchain_community.chat_models.ollama",
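With the registrations above, `ChatNebula` can be imported from the package root: the `TYPE_CHECKING` import keeps static type checkers happy, while at runtime the name is resolved lazily through `_module_lookup` on first access. A short sketch (assuming this branch of `langchain-community` is installed):

# Resolved lazily via _module_lookup, so the implementation module is only
# imported when the class is actually used.
from langchain_community.chat_models import ChatNebula

chat = ChatNebula(max_new_tokens=256, temperature=0)
print(type(chat).__module__)  # langchain_community.chat_models.symblai_nebula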
libs/community/langchain_community/chat_models/symblai_nebula.py (new file, 271 lines)
@@ -0,0 +1,271 @@
import json
import os
from json import JSONDecodeError
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

import requests
from aiohttp import ClientSession
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr
from langchain_core.utils import convert_to_secret_str


def _convert_role(role: str) -> str:
    map = {"ai": "assistant", "human": "human", "chat": "human"}
    if role in map:
        return map[role]
    else:
        raise ValueError(f"Unknown role type: {role}")


def _format_nebula_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
    system = ""
    formatted_messages = []
    for message in messages[:-1]:
        if message.type == "system":
            if isinstance(message.content, str):
                system = message.content
            else:
                raise ValueError("System prompt must be a string")
        else:
            formatted_messages.append(
                {
                    "role": _convert_role(message.type),
                    "text": message.content,
                }
            )

    text = messages[-1].content
    formatted_messages.append({"role": "human", "text": text})
    return {"system_prompt": system, "messages": formatted_messages}


class ChatNebula(BaseChatModel):
    """`Nebula` chat large language model - https://docs.symbl.ai/docs/nebula-llm

    API Reference: https://docs.symbl.ai/reference/nebula-chat

    To use, set the environment variable ``NEBULA_API_KEY``,
    or pass it as a named parameter to the constructor.
    To request an API key, visit https://platform.symbl.ai/#/login

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatNebula
            from langchain_core.messages import SystemMessage, HumanMessage

            chat = ChatNebula(max_new_tokens=1024, temperature=0.5)

            messages = [
                SystemMessage(
                    content="You are a helpful assistant."
                ),
                HumanMessage(
                    "Answer the following question. How can I help save the world."
                ),
            ]
            chat.invoke(messages)
    """

    max_new_tokens: int = 1024
    """Denotes the number of tokens to predict per generation."""

    temperature: Optional[float] = 0
    """A non-negative float that tunes the degree of randomness in generation."""

    streaming: bool = False
    """Whether to stream the results."""

    nebula_api_url: str = "https://api-nebula.symbl.ai"
    """Base URL of the Nebula API."""

    nebula_api_key: Optional[SecretStr] = Field(None, description="Nebula API Token")

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True
        arbitrary_types_allowed = True

    def __init__(self, **kwargs: Any) -> None:
        if "nebula_api_key" in kwargs:
            api_key = convert_to_secret_str(kwargs.pop("nebula_api_key"))
        elif "NEBULA_API_KEY" in os.environ:
            api_key = convert_to_secret_str(os.environ["NEBULA_API_KEY"])
        else:
            api_key = None
        super().__init__(nebula_api_key=api_key, **kwargs)  # type: ignore[call-arg]

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "nebula-chat"

    @property
    def _api_key(self) -> str:
        if self.nebula_api_key:
            return self.nebula_api_key.get_secret_value()
        return ""

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Call out to Nebula's chat endpoint."""
        url = f"{self.nebula_api_url}/v1/model/chat/streaming"
        headers = {
            "ApiKey": self._api_key,
            "Content-Type": "application/json",
        }
        formatted_data = _format_nebula_messages(messages=messages)
        payload: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}
        json_payload = json.dumps(payload)

        response = requests.request(
            "POST", url, headers=headers, data=json_payload, stream=True
        )
        response.raise_for_status()

        for chunk_response in response.iter_lines():
            chunk_decoded = chunk_response.decode()[6:]
            try:
                chunk = json.loads(chunk_decoded)
            except JSONDecodeError:
                continue
            token = chunk["delta"]
            cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
            if run_manager:
                run_manager.on_llm_new_token(token, chunk=cg_chunk)
            yield cg_chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        url = f"{self.nebula_api_url}/v1/model/chat/streaming"
        headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
        formatted_data = _format_nebula_messages(messages=messages)
        payload: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}
        json_payload = json.dumps(payload)

        async with ClientSession() as session:
            async with session.post(  # type: ignore[call-arg]
                url, data=json_payload, headers=headers, stream=True
            ) as response:
                response.raise_for_status()
                async for chunk_response in response.content:
                    chunk_decoded = chunk_response.decode()[6:]
                    try:
                        chunk = json.loads(chunk_decoded)
                    except JSONDecodeError:
                        continue
                    token = chunk["delta"]
                    cg_chunk = ChatGenerationChunk(
                        message=AIMessageChunk(content=token)
                    )
                    if run_manager:
                        await run_manager.on_llm_new_token(token, chunk=cg_chunk)
                    yield cg_chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        url = f"{self.nebula_api_url}/v1/model/chat"
        headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
        formatted_data = _format_nebula_messages(messages=messages)
        payload: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}
        json_payload = json.dumps(payload)

        response = requests.request("POST", url, headers=headers, data=json_payload)
        response.raise_for_status()
        data = response.json()

        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=data["messages"]))],
            llm_output=data,
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        url = f"{self.nebula_api_url}/v1/model/chat"
        headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
        formatted_data = _format_nebula_messages(messages=messages)
        payload: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}
        json_payload = json.dumps(payload)

        async with ClientSession() as session:
            async with session.post(
                url, data=json_payload, headers=headers
            ) as response:
                response.raise_for_status()
                data = await response.json()

        return ChatResult(
            generations=[
                ChatGeneration(message=AIMessage(content=data["messages"]))
            ],
            llm_output=data,
        )
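Two implementation notes on the module above. The streaming paths strip the first six characters of each received line before JSON-decoding the `delta`, which presumably removes an SSE-style "data: " prefix from the endpoint. And the request body is built by `_format_nebula_messages`, which splits out the system prompt and always sends the final message as "human". A hedged sketch of the resulting shape for the notebook's example conversation (illustrative; `_format_nebula_messages` is the private helper defined above):

from langchain_community.chat_models.symblai_nebula import _format_nebula_messages
from langchain_core.messages import HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What is the capital of France?"),
]
payload = _format_nebula_messages(messages=messages)
# Expected shape:
# {
#     "system_prompt": "You are a helpful assistant.",
#     "messages": [{"role": "human", "text": "What is the capital of France?"}],
# }
print(payload)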
@@ -28,6 +28,7 @@ EXPECTED_ALL = [
     "ChatMlflow",
     "ChatMLflowAIGateway",
     "ChatMLX",
+    "ChatNebula",
     "ChatOCIGenAI",
     "ChatOllama",
     "ChatOpenAI",
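A quick local smoke check mirroring what the updated `EXPECTED_ALL` list asserts (a sketch, assuming `langchain-community` from this branch is installed; the authoritative check remains the unit test this hunk modifies):

from langchain_community import chat_models

assert "ChatNebula" in chat_models.__all__, "ChatNebula should be exported"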