feat: Implement stream interface (#15875)


Major changes:

- Rename `wasm_chat.py` to `llama_edge.py`
- Rename the `WasmChatService` class to `LlamaEdgeChatService`
- Implement the `stream` interface for `LlamaEdgeChatService` (see the usage sketch below)
- Add `test_chat_wasm_service_streaming` to the integration tests
- Update `llama_edge.ipynb`
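
As a quick illustration of the new streaming interface, here is a minimal sketch mirroring the notebook and integration test below; the `service_url` is a placeholder for your own `llama-api-server` endpoint, not an address shipped with this PR.

```python
from langchain_community.chat_models.llama_edge import LlamaEdgeChatService
from langchain_core.messages import HumanMessage, SystemMessage

# Placeholder URL: point this at the llama-api-server instance you are running.
chat = LlamaEdgeChatService(service_url="http://localhost:8080", streaming=True)

messages = [
    SystemMessage(content="You are an AI assistant"),
    HumanMessage(content="What is the capital of France?"),
]

# Tokens arrive incrementally as AIMessageChunk objects.
for chunk in chat.stream(messages):
    print(chunk.content, end="", flush=True)
```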

---------

Signed-off-by: Xin Liu <sam@secondstate.io>
Commit: 5efec068c9 (parent: ec4dab0449)
Author: Xin Liu, 2024-01-12 13:32:48 +08:00, committed by GitHub
8 changed files with 299 additions and 128 deletions


@@ -0,0 +1,135 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# LlamaEdge\n",
"\n",
"[LlamaEdge](https://github.com/second-state/LlamaEdge) allows you to chat with LLMs of [GGUF](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/README.md) format both locally and via chat service.\n",
"\n",
"- `LlamaEdgeChatService` provides developers an OpenAI API compatible service to chat with LLMs via HTTP requests.\n",
"\n",
"- `LlamaEdgeChatLocal` enables developers to chat with LLMs locally (coming soon).\n",
"\n",
"Both `LlamaEdgeChatService` and `LlamaEdgeChatLocal` run on the infrastructure driven by [WasmEdge Runtime](https://wasmedge.org/), which provides a lightweight and portable WebAssembly container environment for LLM inference tasks.\n",
"\n",
"## Chat via API Service\n",
"\n",
"`LlamaEdgeChatService` works on the `llama-api-server`. Following the steps in [llama-api-server quick-start](https://github.com/second-state/llama-utils/tree/main/api-server#readme), you can host your own API service so that you can chat with any models you like on any device you have anywhere as long as the internet is available."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.llama_edge import LlamaEdgeChatService\n",
"from langchain_core.messages import HumanMessage, SystemMessage"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Chat with LLMs in the non-streaming mode"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Bot] Hello! The capital of France is Paris.\n"
]
}
],
"source": [
"# service url\n",
"service_url = \"https://b008-54-186-154-209.ngrok-free.app\"\n",
"\n",
"# create wasm-chat service instance\n",
"chat = LlamaEdgeChatService(service_url=service_url)\n",
"\n",
"# create message sequence\n",
"system_message = SystemMessage(content=\"You are an AI assistant\")\n",
"user_message = HumanMessage(content=\"What is the capital of France?\")\n",
"messages = [system_message, user_message]\n",
"\n",
"# chat with wasm-chat service\n",
"response = chat(messages)\n",
"\n",
"print(f\"[Bot] {response.content}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Chat with LLMs in the streaming mode"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Bot] Hello! I'm happy to help you with your question. The capital of Norway is Oslo.\n"
]
}
],
"source": [
"# service url\n",
"service_url = \"https://b008-54-186-154-209.ngrok-free.app\"\n",
"\n",
"# create wasm-chat service instance\n",
"chat = LlamaEdgeChatService(service_url=service_url, streaming=True)\n",
"\n",
"# create message sequence\n",
"system_message = SystemMessage(content=\"You are an AI assistant\")\n",
"user_message = HumanMessage(content=\"What is the capital of Norway?\")\n",
"messages = [\n",
" system_message,\n",
" user_message,\n",
"]\n",
"\n",
"output = \"\"\n",
"for chunk in chat.stream(messages):\n",
" # print(chunk.content, end=\"\", flush=True)\n",
" output += chunk.content\n",
"\n",
"print(f\"[Bot] {output}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -1,85 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Wasm Chat\n",
"\n",
"`Wasm-chat` allows you to chat with LLMs of [GGUF](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/README.md) format both locally and via chat service.\n",
"\n",
"- `WasmChatService` provides developers an OpenAI API compatible service to chat with LLMs via HTTP requests.\n",
"\n",
"- `WasmChatLocal` enables developers to chat with LLMs locally (coming soon).\n",
"\n",
"Both `WasmChatService` and `WasmChatLocal` run on the infrastructure driven by [WasmEdge Runtime](https://wasmedge.org/), which provides a lightweight and portable WebAssembly container environment for LLM inference tasks.\n",
"\n",
"## Chat via API Service\n",
"\n",
"`WasmChatService` provides chat services by the `llama-api-server`. Following the steps in [llama-api-server quick-start](https://github.com/second-state/llama-utils/tree/main/api-server#readme), you can host your own API service so that you can chat with any models you like on any device you have anywhere as long as the internet is available."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.wasm_chat import WasmChatService\n",
"from langchain_core.messages import AIMessage, HumanMessage, SystemMessage"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Bot] Paris\n"
]
}
],
"source": [
"# service url\n",
"service_url = \"https://b008-54-186-154-209.ngrok-free.app\"\n",
"\n",
"# create wasm-chat service instance\n",
"chat = WasmChatService(service_url=service_url)\n",
"\n",
"# create message sequence\n",
"system_message = SystemMessage(content=\"You are an AI assistant\")\n",
"user_message = HumanMessage(content=\"What is the capital of France?\")\n",
"messages = [system_message, user_message]\n",
"\n",
"# chat with wasm-chat service\n",
"response = chat(messages)\n",
"\n",
"print(f\"[Bot] {response.content}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -39,6 +39,7 @@ from langchain_community.chat_models.javelin_ai_gateway import ChatJavelinAIGateway
 from langchain_community.chat_models.jinachat import JinaChat
 from langchain_community.chat_models.konko import ChatKonko
 from langchain_community.chat_models.litellm import ChatLiteLLM
+from langchain_community.chat_models.llama_edge import LlamaEdgeChatService
 from langchain_community.chat_models.minimax import MiniMaxChat
 from langchain_community.chat_models.mlflow import ChatMlflow
 from langchain_community.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
@@ -49,12 +50,11 @@ from langchain_community.chat_models.promptlayer_openai import PromptLayerChatOpenAI
 from langchain_community.chat_models.tongyi import ChatTongyi
 from langchain_community.chat_models.vertexai import ChatVertexAI
 from langchain_community.chat_models.volcengine_maas import VolcEngineMaasChat
-from langchain_community.chat_models.wasm_chat import WasmChatService
 from langchain_community.chat_models.yandex import ChatYandexGPT
 from langchain_community.chat_models.zhipuai import ChatZhipuAI

 __all__ = [
-    "WasmChatService",
+    "LlamaEdgeChatService",
     "ChatOpenAI",
    "BedrockChat",
    "AzureChatOpenAI",


@@ -1,18 +1,26 @@
 import json
 import logging
-from typing import Any, Dict, List, Mapping, Optional
+import re
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Type

 import requests
 from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.language_models.chat_models import (
+    BaseChatModel,
+    generate_from_stream,
+)
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
+    BaseMessageChunk,
     ChatMessage,
+    ChatMessageChunk,
     HumanMessage,
+    HumanMessageChunk,
     SystemMessage,
 )
-from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import root_validator
 from langchain_core.utils import get_pydantic_field_names
@@ -45,10 +53,26 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     return message_dict


-class WasmChatService(BaseChatModel):
+def _convert_delta_to_message_chunk(
+    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    role = _dict.get("role")
+    content = _dict.get("content") or ""
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(content=content)
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role)
+    else:
+        return default_class(content=content)
+
+
+class LlamaEdgeChatService(BaseChatModel):
     """Chat with LLMs via `llama-api-server`

-    For the information about `llama-api-server`, visit https://github.com/second-state/llama-utils
+    For the information about `llama-api-server`, visit https://github.com/second-state/LlamaEdge
     """

     request_timeout: int = 60
@@ -57,6 +81,8 @@ class WasmChatService(BaseChatModel):
     """URL of WasmChat service"""
     model: str = "NA"
     """model name, default is `NA`."""
+    streaming: bool = False
+    """Whether to stream the results or not."""

     class Config:
         """Configuration for this pydantic object."""
@@ -96,6 +122,12 @@ class WasmChatService(BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._stream(
+                messages=messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return generate_from_stream(stream_iter)
+
         res = self._chat(messages, **kwargs)

         if res.status_code != 200:
@@ -105,6 +137,64 @@ class WasmChatService(BaseChatModel):
         return self._create_chat_result(response)

+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        res = self._chat(messages, **kwargs)
+
+        default_chunk_class = AIMessageChunk
+        substring = '"object":"chat.completion.chunk"}'
+        for line in res.iter_lines():
+            chunks = []
+            if line:
+                json_string = line.decode("utf-8")
+
+                # Find all positions of the substring
+                positions = [m.start() for m in re.finditer(substring, json_string)]
+                positions = [-1 * len(substring)] + positions
+
+                for i in range(len(positions) - 1):
+                    chunk = json.loads(
+                        json_string[
+                            positions[i] + len(substring) : positions[i + 1]
+                            + len(substring)
+                        ]
+                    )
+                    chunks.append(chunk)
+
+            for chunk in chunks:
+                if not isinstance(chunk, dict):
+                    chunk = chunk.dict()
+                if len(chunk["choices"]) == 0:
+                    continue
+
+                choice = chunk["choices"][0]
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+
+                if (
+                    choice.get("finish_reason") is not None
+                    and choice.get("finish_reason") == "stop"
+                ):
+                    break
+
+                finish_reason = choice.get("finish_reason")
+                generation_info = (
+                    dict(finish_reason=finish_reason)
+                    if finish_reason is not None
+                    else None
+                )
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info
+                )
+                yield chunk
+                if run_manager:
+                    run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
     def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
         if self.service_url is None:
             res = requests.models.Response()
@@ -114,10 +204,17 @@ class WasmChatService(BaseChatModel):

         service_url = f"{self.service_url}/v1/chat/completions"

-        payload = {
-            "model": self.model,
-            "messages": [_convert_message_to_dict(m) for m in messages],
-        }
+        if self.streaming:
+            payload = {
+                "model": self.model,
+                "messages": [_convert_message_to_dict(m) for m in messages],
+                "stream": self.streaming,
+            }
+        else:
+            payload = {
+                "model": self.model,
+                "messages": [_convert_message_to_dict(m) for m in messages],
+            }

         res = requests.post(
             url=service_url,
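
The trickiest part of the new `_stream` method above is that a single response line from the server may carry several concatenated `chat.completion.chunk` JSON objects rather than one object per line, so the method splits the line on the closing marker `"object":"chat.completion.chunk"}` before decoding each piece. A standalone sketch of that splitting step, with a made-up input line purely for illustration:

```python
import json
import re

# Illustrative input: two chunk objects concatenated on one line, roughly the
# shape handled by _stream (this sample payload is invented for demonstration).
line = (
    '{"choices":[{"delta":{"role":"assistant","content":"Hello"}}],'
    '"object":"chat.completion.chunk"}'
    '{"choices":[{"delta":{"content":" world"}}],'
    '"object":"chat.completion.chunk"}'
)

substring = '"object":"chat.completion.chunk"}'

# Find where each chunk ends, then slice the line into complete JSON objects.
positions = [m.start() for m in re.finditer(re.escape(substring), line)]
positions = [-1 * len(substring)] + positions

chunks = []
for i in range(len(positions) - 1):
    start = positions[i] + len(substring)
    end = positions[i + 1] + len(substring)
    chunks.append(json.loads(line[start:end]))

for chunk in chunks:
    print(chunk["choices"][0]["delta"].get("content", ""))  # "Hello", " world"
```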


@@ -0,0 +1,52 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_community.chat_models.llama_edge import LlamaEdgeChatService


@pytest.mark.enable_socket
def test_chat_wasm_service() -> None:
    """This test requires the port 8080 is not occupied."""

    # service url
    service_url = "https://b008-54-186-154-209.ngrok-free.app"

    # create wasm-chat service instance
    chat = LlamaEdgeChatService(service_url=service_url)

    # create message sequence
    system_message = SystemMessage(content="You are an AI assistant")
    user_message = HumanMessage(content="What is the capital of France?")
    messages = [system_message, user_message]

    # chat with wasm-chat service
    response = chat(messages)

    # check response
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
    assert "Paris" in response.content


@pytest.mark.enable_socket
def test_chat_wasm_service_streaming() -> None:
    """This test requires the port 8080 is not occupied."""

    # service url
    service_url = "https://b008-54-186-154-209.ngrok-free.app"

    # create wasm-chat service instance
    chat = LlamaEdgeChatService(service_url=service_url, streaming=True)

    # create message sequence
    user_message = HumanMessage(content="What is the capital of France?")
    messages = [
        user_message,
    ]

    output = ""
    for chunk in chat.stream(messages):
        print(chunk.content, end="", flush=True)
        output += chunk.content

    assert "Paris" in output


@@ -1,28 +0,0 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_community.chat_models.wasm_chat import WasmChatService


@pytest.mark.enable_socket
def test_chat_wasm_service() -> None:
    """This test requires the port 8080 is not occupied."""

    # service url
    service_url = "https://b008-54-186-154-209.ngrok-free.app"

    # create wasm-chat service instance
    chat = WasmChatService(service_url=service_url)

    # create message sequence
    system_message = SystemMessage(content="You are an AI assistant")
    user_message = HumanMessage(content="What is the capital of France?")
    messages = [system_message, user_message]

    # chat with wasm-chat service
    response = chat(messages)

    # check response
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
    assert "Paris" in response.content


@@ -33,7 +33,7 @@ EXPECTED_ALL = [
     "ChatHunyuan",
     "GigaChat",
     "VolcEngineMaasChat",
-    "WasmChatService",
+    "LlamaEdgeChatService",
     "GPTRouter",
     "ChatZhipuAI",
 ]


@@ -7,8 +7,8 @@ from langchain_core.messages import (
     SystemMessage,
 )

-from langchain_community.chat_models.wasm_chat import (
-    WasmChatService,
+from langchain_community.chat_models.llama_edge import (
+    LlamaEdgeChatService,
     _convert_dict_to_message,
     _convert_message_to_dict,
 )
@@ -64,7 +64,7 @@ def test__convert_dict_to_message_other_role() -> None:

 def test_wasm_chat_without_service_url() -> None:
-    chat = WasmChatService()
+    chat = LlamaEdgeChatService()

     # create message sequence
     system_message = SystemMessage(content="You are an AI assistant")