Added support for Nebula Chat model (#21925)

Description: Added support for the Nebula Chat model, in addition to the existing Nebula Instruct model.
Dependencies: N/A
Twitter handle: @Symbldotai

---------

Co-authored-by: Eugene Yurtsev <eugene@langchain.dev>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
Commit: 90439b12f6 (parent: 080741d336)
Author: Sharmistha S. Gupta, 2024-08-23 15:34:32 -07:00 (committed by GitHub)
4 changed files with 572 additions and 0 deletions


@@ -0,0 +1,297 @@
{
"cells": [
{
"cell_type": "raw",
"id": "53fbf15f",
"metadata": {
"vscode": {
"languageId": "raw"
}
},
"source": [
"---\n",
"sidebar_label: Nebula (Symbl.ai)\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "bf733a38-db84-4363-89e2-de6735c37230",
"metadata": {},
"source": [
"# Nebula (Symbl.ai)\n",
"\n",
"## Overview\n",
"This notebook covers how to get started with [Nebula](https://docs.symbl.ai/docs/nebula-llm) - Symbl.ai's chat model.\n",
"\n",
"### Integration details\n",
"Head to the [API reference](https://docs.symbl.ai/reference/nebula-chat) for detailed documentation.\n",
"\n",
"### Model features: TODO"
]
},
{
"cell_type": "markdown",
"id": "3607d67e-e56c-4102-bbba-df2edc0e109e",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"### Credentials\n",
"To get started, request a [Nebula API key](https://platform.symbl.ai/#/login) and set the `NEBULA_API_KEY` environment variable:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2108b517-1e8d-473d-92fa-4f930e8072a7",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"NEBULA_API_KEY\"] = getpass.getpass()"
]
},
{
"cell_type": "markdown",
"id": "68b44357",
"metadata": {},
"source": [
"### Installation\n",
"The integration is set up in the `langchain-community` package."
]
},
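{
"cell_type": "code",
"execution_count": null,
"id": "install-nebula-deps",
"metadata": {},
"outputs": [],
"source": [
"# Minimal install sketch; assumes a pip-based environment.\n",
"%pip install -qU langchain-community"
]
},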
{
"cell_type": "markdown",
"id": "4c26754b-b3c9-4d93-8f36-43049bd943bf",
"metadata": {},
"source": [
"## Instantiation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0fdd26e7",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.symblai_nebula import ChatNebula\n",
"from langchain_core.messages import AIMessage, HumanMessage, SystemMessage"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"chat = ChatNebula(max_tokens=1024, temperature=0.5)"
]
},
{
"cell_type": "markdown",
"id": "2a915547",
"metadata": {},
"source": [
"## Invocation"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=[{'role': 'human', 'text': 'What is the capital of France?'}, {'role': 'assistant', 'text': 'The capital of France is Paris.'}])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" SystemMessage(\n",
" content=\"You are a helpful assistant that answers general knowledge questions.\"\n",
" ),\n",
" HumanMessage(content=\"What is the capital of France?\"),\n",
"]\n",
"chat.invoke(messages)"
]
},
{
"cell_type": "markdown",
"id": "9723913f",
"metadata": {},
"source": [
"### Async"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=[{'role': 'human', 'text': 'What is the capital of France?'}, {'role': 'assistant', 'text': 'The capital of France is Paris.'}])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await chat.ainvoke(messages)"
]
},
{
"cell_type": "markdown",
"id": "e0a1d3b4",
"metadata": {},
"source": [
"### Streaming"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" The capital of France is Paris."
]
}
],
"source": [
"for chunk in chat.stream(messages):\n",
" print(chunk.content, end=\"\", flush=True)"
]
},
{
"cell_type": "markdown",
"id": "9f91b7c7",
"metadata": {},
"source": [
"### Batch"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "054dc648",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[AIMessage(content=[{'role': 'human', 'text': 'What is the capital of France?'}, {'role': 'assistant', 'text': 'The capital of France is Paris.'}])]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat.batch([messages])"
]
},
{
"cell_type": "markdown",
"id": "e59a5519",
"metadata": {},
"source": [
"## Chaining"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "6455f67b",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_template(\"Tell me a joke about {topic}\")\n",
"chain = prompt | chat"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "deb1e2a1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=[{'role': 'human', 'text': 'Tell me a joke about cows'}, {'role': 'assistant', 'text': \"Sure, here's a joke about cows:\\n\\nWhy did the cow cross the road?\\n\\nTo get to the udder side!\"}])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"topic\": \"cows\"})"
]
},
{
"cell_type": "markdown",
"id": "bb9d4755",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"Check out the [API reference](https://python.langchain.com/v0.2/api_reference/community/chat_models/langchain_community.chat_models.symblai_nebula.ChatNebula.html) for more detail."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -153,6 +153,7 @@ if TYPE_CHECKING:
from langchain_community.chat_models.sparkllm import (
ChatSparkLLM,
)
from langchain_community.chat_models.symblai_nebula import ChatNebula
from langchain_community.chat_models.tongyi import (
ChatTongyi,
)
@@ -201,6 +202,7 @@ __all__ = [
"ChatMLflowAIGateway",
"ChatMaritalk",
"ChatMlflow",
"ChatNebula",
"ChatOCIGenAI",
"ChatOllama",
"ChatOpenAI",
@@ -257,6 +259,7 @@ _module_lookup = {
"ChatMLX": "langchain_community.chat_models.mlx",
"ChatMaritalk": "langchain_community.chat_models.maritalk",
"ChatMlflow": "langchain_community.chat_models.mlflow",
"ChatNebula": "langchain_community.chat_models.symblai_nebula",
"ChatOctoAI": "langchain_community.chat_models.octoai",
"ChatOCIGenAI": "langchain_community.chat_models.oci_generative_ai",
"ChatOllama": "langchain_community.chat_models.ollama",


@@ -0,0 +1,271 @@
import json
import os
from json import JSONDecodeError
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
import requests
from aiohttp import ClientSession
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr
from langchain_core.utils import convert_to_secret_str
def _convert_role(role: str) -> str:
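"""Map a LangChain message type ("ai", "human", or "chat") to the role name Nebula expects."""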
map = {"ai": "assistant", "human": "human", "chat": "human"}
if role in map:
return map[role]
else:
raise ValueError(f"Unknown role type: {role}")
def _format_nebula_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
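"""Split out the system prompt and convert the chat history into Nebula's {"role", "text"} message format."""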
system = ""
formatted_messages = []
for message in messages[:-1]:
if message.type == "system":
if isinstance(message.content, str):
system = message.content
else:
raise ValueError("System prompt must be a string")
else:
formatted_messages.append(
{
"role": _convert_role(message.type),
"text": message.content,
}
)
text = messages[-1].content
formatted_messages.append({"role": "human", "text": text})
return {"system_prompt": system, "messages": formatted_messages}
class ChatNebula(BaseChatModel):
"""`Nebula` chat large language model - https://docs.symbl.ai/docs/nebula-llm
API Reference: https://docs.symbl.ai/reference/nebula-chat
To use, set the environment variable ``NEBULA_API_KEY``,
or pass it as a named parameter to the constructor.
To request an API key, visit https://platform.symbl.ai/#/login
Example:
.. code-block:: python
from langchain_community.chat_models import ChatNebula
from langchain_core.messages import SystemMessage, HumanMessage
chat = ChatNebula(max_new_tokens=1024, temperature=0.5)
messages = [
SystemMessage(
content="You are a helpful assistant."
),
HumanMessage(
"Answer the following question. How can I help save the world."
),
]
chat.invoke(messages)
"""
max_new_tokens: int = 1024
"""Denotes the number of tokens to predict per generation."""
temperature: Optional[float] = 0
"""A non-negative float that tunes the degree of randomness in generation."""
streaming: bool = False
nebula_api_url: str = "https://api-nebula.symbl.ai"
nebula_api_key: Optional[SecretStr] = Field(None, description="Nebula API Token")
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
arbitrary_types_allowed = True
def __init__(self, **kwargs: Any) -> None:
if "nebula_api_key" in kwargs:
api_key = convert_to_secret_str(kwargs.pop("nebula_api_key"))
elif "NEBULA_API_KEY" in os.environ:
api_key = convert_to_secret_str(os.environ["NEBULA_API_KEY"])
else:
api_key = None
super().__init__(nebula_api_key=api_key, **kwargs) # type: ignore[call-arg]
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "nebula-chat"
@property
def _api_key(self) -> str:
if self.nebula_api_key:
return self.nebula_api_key.get_secret_value()
return ""
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Call out to Nebula's chat endpoint."""
url = f"{self.nebula_api_url}/v1/model/chat/streaming"
headers = {
"ApiKey": self._api_key,
"Content-Type": "application/json",
}
formatted_data = _format_nebula_messages(messages=messages)
payload: Dict[str, Any] = {
"max_new_tokens": self.max_new_tokens,
"temperature": self.temperature,
**formatted_data,
**kwargs,
}
payload = {k: v for k, v in payload.items() if v is not None}
json_payload = json.dumps(payload)
response = requests.request(
"POST", url, headers=headers, data=json_payload, stream=True
)
response.raise_for_status()
for chunk_response in response.iter_lines():
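# Each streamed line carries a 6-character "data: " prefix; strip it, parse the JSON, and read the token text from "delta".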
chunk_decoded = chunk_response.decode()[6:]
try:
chunk = json.loads(chunk_decoded)
except JSONDecodeError:
continue
token = chunk["delta"]
cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
if run_manager:
run_manager.on_llm_new_token(token, chunk=cg_chunk)
yield cg_chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
url = f"{self.nebula_api_url}/v1/model/chat/streaming"
headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
formatted_data = _format_nebula_messages(messages=messages)
payload: Dict[str, Any] = {
"max_new_tokens": self.max_new_tokens,
"temperature": self.temperature,
**formatted_data,
**kwargs,
}
payload = {k: v for k, v in payload.items() if v is not None}
json_payload = json.dumps(payload)
async with ClientSession() as session:
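# Note: aiohttp has no "stream" keyword argument (hence the type-ignore below); the body is read incrementally from response.content.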
async with session.post( # type: ignore[call-arg]
url, data=json_payload, headers=headers, stream=True
) as response:
response.raise_for_status()
async for chunk_response in response.content:
chunk_decoded = chunk_response.decode()[6:]
try:
chunk = json.loads(chunk_decoded)
except JSONDecodeError:
continue
token = chunk["delta"]
cg_chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token)
)
if run_manager:
await run_manager.on_llm_new_token(token, chunk=cg_chunk)
yield cg_chunk
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
url = f"{self.nebula_api_url}/v1/model/chat"
headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
formatted_data = _format_nebula_messages(messages=messages)
payload: Dict[str, Any] = {
"max_new_tokens": self.max_new_tokens,
"temperature": self.temperature,
**formatted_data,
**kwargs,
}
payload = {k: v for k, v in payload.items() if v is not None}
json_payload = json.dumps(payload)
response = requests.request("POST", url, headers=headers, data=json_payload)
response.raise_for_status()
data = response.json()
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=data["messages"]))],
llm_output=data,
)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
url = f"{self.nebula_api_url}/v1/model/chat"
headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
formatted_data = _format_nebula_messages(messages=messages)
payload: Dict[str, Any] = {
"max_new_tokens": self.max_new_tokens,
"temperature": self.temperature,
**formatted_data,
**kwargs,
}
payload = {k: v for k, v in payload.items() if v is not None}
json_payload = json.dumps(payload)
async with ClientSession() as session:
async with session.post(
url, data=json_payload, headers=headers
) as response:
response.raise_for_status()
data = await response.json()
return ChatResult(
generations=[
ChatGeneration(message=AIMessage(content=data["messages"]))
],
llm_output=data,
)


@@ -28,6 +28,7 @@ EXPECTED_ALL = [
"ChatMlflow",
"ChatMLflowAIGateway",
"ChatMLX",
"ChatNebula",
"ChatOCIGenAI",
"ChatOllama",
"ChatOpenAI",