mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-27 13:31:53 +00:00
Added support for Nebula Chat model (#21925)
Description: Added support for Nebula Chat model in addition to Nebula Instruct Dependencies: N/A Twitter handle: @Symbldotai --------- Co-authored-by: Eugene Yurtsev <eugene@langchain.dev> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
080741d336
commit
90439b12f6
297
docs/docs/integrations/chat/symblai_nebula.ipynb
Normal file
297
docs/docs/integrations/chat/symblai_nebula.ipynb
Normal file
@ -0,0 +1,297 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "raw",
|
||||
"id": "53fbf15f",
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "raw"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"---\n",
|
||||
"sidebar_label: Nebula (Symbl.ai)\n",
|
||||
"---"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bf733a38-db84-4363-89e2-de6735c37230",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Nebula (Symbl.ai)\n",
|
||||
"\n",
|
||||
"## Overview\n",
|
||||
"This notebook covers how to get started with [Nebula](https://docs.symbl.ai/docs/nebula-llm) - Symbl.ai's chat model.\n",
|
||||
"\n",
|
||||
"### Integration details\n",
|
||||
"Head to the [API reference](https://docs.symbl.ai/reference/nebula-chat) for detailed documentation.\n",
|
||||
"\n",
|
||||
"### Model features: TODO"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3607d67e-e56c-4102-bbba-df2edc0e109e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
"### Credentials\n",
|
||||
"To get started, request a [Nebula API key](https://platform.symbl.ai/#/login) and set the `NEBULA_API_KEY` environment variable:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "2108b517-1e8d-473d-92fa-4f930e8072a7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"NEBULA_API_KEY\"] = getpass.getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "68b44357",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Installation\n",
|
||||
"The integration is set up in the `langchain-community` package."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4c26754b-b3c9-4d93-8f36-43049bd943bf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Instantiation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0fdd26e7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.chat_models.symblai_nebula import ChatNebula\n",
|
||||
"from langchain_core.messages import AIMessage, HumanMessage, SystemMessage"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chat = ChatNebula(max_tokens=1024, temperature=0.5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2a915547",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Invocation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=[{'role': 'human', 'text': 'What is the capital of France?'}, {'role': 'assistant', 'text': 'The capital of France is Paris.'}])"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" SystemMessage(\n",
|
||||
" content=\"You are a helpful assistant that answers general knowledge questions.\"\n",
|
||||
" ),\n",
|
||||
" HumanMessage(content=\"What is the capital of France?\"),\n",
|
||||
"]\n",
|
||||
"chat.invoke(messages)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9723913f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Async"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=[{'role': 'human', 'text': 'What is the capital of France?'}, {'role': 'assistant', 'text': 'The capital of France is Paris.'}])"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"await chat.ainvoke(messages)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e0a1d3b4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Streaming"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" The capital of France is Paris."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for chunk in chat.stream(messages):\n",
|
||||
" print(chunk.content, end=\"\", flush=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9f91b7c7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Batch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "054dc648",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[AIMessage(content=[{'role': 'human', 'text': 'What is the capital of France?'}, {'role': 'assistant', 'text': 'The capital of France is Paris.'}])]"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chat.batch([messages])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e59a5519",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Chaining"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "6455f67b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_core.prompts import ChatPromptTemplate\n",
|
||||
"\n",
|
||||
"prompt = ChatPromptTemplate.from_template(\"Tell me a joke about {topic}\")\n",
|
||||
"chain = prompt | chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "deb1e2a1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=[{'role': 'human', 'text': 'Tell me a joke about cows'}, {'role': 'assistant', 'text': \"Sure, here's a joke about cows:\\n\\nWhy did the cow cross the road?\\n\\nTo get to the udder side!\"}])"
|
||||
]
|
||||
},
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.invoke({\"topic\": \"cows\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bb9d4755",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## API reference\n",
|
||||
"\n",
|
||||
"Check out the [API reference](https://python.langchain.com/v0.2/api_reference/community/chat_models/langchain_community.chat_models.symblai_nebula.ChatNebula.html) for more detail."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -153,6 +153,7 @@ if TYPE_CHECKING:
|
||||
from langchain_community.chat_models.sparkllm import (
|
||||
ChatSparkLLM,
|
||||
)
|
||||
from langchain_community.chat_models.symblai_nebula import ChatNebula
|
||||
from langchain_community.chat_models.tongyi import (
|
||||
ChatTongyi,
|
||||
)
|
||||
@ -201,6 +202,7 @@ __all__ = [
|
||||
"ChatMLflowAIGateway",
|
||||
"ChatMaritalk",
|
||||
"ChatMlflow",
|
||||
"ChatNebula",
|
||||
"ChatOCIGenAI",
|
||||
"ChatOllama",
|
||||
"ChatOpenAI",
|
||||
@ -257,6 +259,7 @@ _module_lookup = {
|
||||
"ChatMLX": "langchain_community.chat_models.mlx",
|
||||
"ChatMaritalk": "langchain_community.chat_models.maritalk",
|
||||
"ChatMlflow": "langchain_community.chat_models.mlflow",
|
||||
"ChatNebula": "langchain_community.chat_models.symblai_nebula",
|
||||
"ChatOctoAI": "langchain_community.chat_models.octoai",
|
||||
"ChatOCIGenAI": "langchain_community.chat_models.oci_generative_ai",
|
||||
"ChatOllama": "langchain_community.chat_models.ollama",
|
||||
|
271
libs/community/langchain_community/chat_models/symblai_nebula.py
Normal file
271
libs/community/langchain_community/chat_models/symblai_nebula.py
Normal file
@ -0,0 +1,271 @@
|
||||
import json
|
||||
import os
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
||||
|
||||
import requests
|
||||
from aiohttp import ClientSession
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr
|
||||
from langchain_core.utils import convert_to_secret_str
|
||||
|
||||
|
||||
def _convert_role(role: str) -> str:
|
||||
map = {"ai": "assistant", "human": "human", "chat": "human"}
|
||||
if role in map:
|
||||
return map[role]
|
||||
else:
|
||||
raise ValueError(f"Unknown role type: {role}")
|
||||
|
||||
|
||||
def _format_nebula_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
|
||||
system = ""
|
||||
formatted_messages = []
|
||||
for message in messages[:-1]:
|
||||
if message.type == "system":
|
||||
if isinstance(message.content, str):
|
||||
system = message.content
|
||||
else:
|
||||
raise ValueError("System prompt must be a string")
|
||||
else:
|
||||
formatted_messages.append(
|
||||
{
|
||||
"role": _convert_role(message.type),
|
||||
"text": message.content,
|
||||
}
|
||||
)
|
||||
|
||||
text = messages[-1].content
|
||||
formatted_messages.append({"role": "human", "text": text})
|
||||
return {"system_prompt": system, "messages": formatted_messages}
|
||||
|
||||
|
||||
class ChatNebula(BaseChatModel):
|
||||
"""`Nebula` chat large language model - https://docs.symbl.ai/docs/nebula-llm
|
||||
|
||||
API Reference: https://docs.symbl.ai/reference/nebula-chat
|
||||
|
||||
To use, set the environment variable ``NEBULA_API_KEY``,
|
||||
or pass it as a named parameter to the constructor.
|
||||
To request an API key, visit https://platform.symbl.ai/#/login
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatNebula
|
||||
from langchain_core.messages import SystemMessage, HumanMessage
|
||||
|
||||
chat = ChatNebula(max_new_tokens=1024, temperature=0.5)
|
||||
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content="You are a helpful assistant."
|
||||
),
|
||||
HumanMessage(
|
||||
"Answer the following question. How can I help save the world."
|
||||
),
|
||||
]
|
||||
chat.invoke(messages)
|
||||
"""
|
||||
|
||||
max_new_tokens: int = 1024
|
||||
"""Denotes the number of tokens to predict per generation."""
|
||||
|
||||
temperature: Optional[float] = 0
|
||||
"""A non-negative float that tunes the degree of randomness in generation."""
|
||||
|
||||
streaming: bool = False
|
||||
|
||||
nebula_api_url: str = "https://api-nebula.symbl.ai"
|
||||
|
||||
nebula_api_key: Optional[SecretStr] = Field(None, description="Nebula API Token")
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
if "nebula_api_key" in kwargs:
|
||||
api_key = convert_to_secret_str(kwargs.pop("nebula_api_key"))
|
||||
elif "NEBULA_API_KEY" in os.environ:
|
||||
api_key = convert_to_secret_str(os.environ["NEBULA_API_KEY"])
|
||||
else:
|
||||
api_key = None
|
||||
super().__init__(nebula_api_key=api_key, **kwargs) # type: ignore[call-arg]
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "nebula-chat"
|
||||
|
||||
@property
|
||||
def _api_key(self) -> str:
|
||||
if self.nebula_api_key:
|
||||
return self.nebula_api_key.get_secret_value()
|
||||
return ""
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
"""Call out to Nebula's chat endpoint."""
|
||||
url = f"{self.nebula_api_url}/v1/model/chat/streaming"
|
||||
headers = {
|
||||
"ApiKey": self._api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
formatted_data = _format_nebula_messages(messages=messages)
|
||||
payload: Dict[str, Any] = {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"temperature": self.temperature,
|
||||
**formatted_data,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
json_payload = json.dumps(payload)
|
||||
|
||||
response = requests.request(
|
||||
"POST", url, headers=headers, data=json_payload, stream=True
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
for chunk_response in response.iter_lines():
|
||||
chunk_decoded = chunk_response.decode()[6:]
|
||||
try:
|
||||
chunk = json.loads(chunk_decoded)
|
||||
except JSONDecodeError:
|
||||
continue
|
||||
token = chunk["delta"]
|
||||
cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(token, chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
url = f"{self.nebula_api_url}/v1/model/chat/streaming"
|
||||
headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
|
||||
formatted_data = _format_nebula_messages(messages=messages)
|
||||
payload: Dict[str, Any] = {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"temperature": self.temperature,
|
||||
**formatted_data,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
json_payload = json.dumps(payload)
|
||||
|
||||
async with ClientSession() as session:
|
||||
async with session.post( # type: ignore[call-arg]
|
||||
url, data=json_payload, headers=headers, stream=True
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
async for chunk_response in response.content:
|
||||
chunk_decoded = chunk_response.decode()[6:]
|
||||
try:
|
||||
chunk = json.loads(chunk_decoded)
|
||||
except JSONDecodeError:
|
||||
continue
|
||||
token = chunk["delta"]
|
||||
cg_chunk = ChatGenerationChunk(
|
||||
message=AIMessageChunk(content=token)
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(token, chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
url = f"{self.nebula_api_url}/v1/model/chat"
|
||||
headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
|
||||
formatted_data = _format_nebula_messages(messages=messages)
|
||||
payload: Dict[str, Any] = {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"temperature": self.temperature,
|
||||
**formatted_data,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
json_payload = json.dumps(payload)
|
||||
|
||||
response = requests.request("POST", url, headers=headers, data=json_payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=AIMessage(content=data["messages"]))],
|
||||
llm_output=data,
|
||||
)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
url = f"{self.nebula_api_url}/v1/model/chat"
|
||||
headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
|
||||
formatted_data = _format_nebula_messages(messages=messages)
|
||||
payload: Dict[str, Any] = {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"temperature": self.temperature,
|
||||
**formatted_data,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
json_payload = json.dumps(payload)
|
||||
|
||||
async with ClientSession() as session:
|
||||
async with session.post(
|
||||
url, data=json_payload, headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=AIMessage(content=data["messages"]))
|
||||
],
|
||||
llm_output=data,
|
||||
)
|
@ -28,6 +28,7 @@ EXPECTED_ALL = [
|
||||
"ChatMlflow",
|
||||
"ChatMLflowAIGateway",
|
||||
"ChatMLX",
|
||||
"ChatNebula",
|
||||
"ChatOCIGenAI",
|
||||
"ChatOllama",
|
||||
"ChatOpenAI",
|
||||
|
Loading…
Reference in New Issue
Block a user