From 5efec068c96195428415c5b1a513db17c98727d5 Mon Sep 17 00:00:00 2001 From: Xin Liu Date: Fri, 12 Jan 2024 13:32:48 +0800 Subject: [PATCH] feat: Implement `stream` interface (#15875) Major changes: - Rename `wasm_chat.py` to `llama_edge.py` - Rename the `WasmChatService` class to `ChatService` - Implement the `stream` interface for `ChatService` - Add `test_chat_wasm_service_streaming` in the integration test - Update `llama_edge.ipynb` --------- Signed-off-by: Xin Liu --- docs/docs/integrations/chat/llama_edge.ipynb | 135 ++++++++++++++++++ docs/docs/integrations/chat/wasm_chat.ipynb | 85 ----------- .../chat_models/__init__.py | 4 +- .../{wasm_chat.py => llama_edge.py} | 115 +++++++++++++-- .../chat_models/test_llama_edge.py | 52 +++++++ .../chat_models/test_wasm_chat.py | 28 ---- .../unit_tests/chat_models/test_imports.py | 2 +- .../{test_wasmchat.py => test_llama_edge.py} | 6 +- 8 files changed, 299 insertions(+), 128 deletions(-) create mode 100644 docs/docs/integrations/chat/llama_edge.ipynb delete mode 100644 docs/docs/integrations/chat/wasm_chat.ipynb rename libs/community/langchain_community/chat_models/{wasm_chat.py => llama_edge.py} (54%) create mode 100644 libs/community/tests/integration_tests/chat_models/test_llama_edge.py delete mode 100644 libs/community/tests/integration_tests/chat_models/test_wasm_chat.py rename libs/community/tests/unit_tests/chat_models/{test_wasmchat.py => test_llama_edge.py} (95%) diff --git a/docs/docs/integrations/chat/llama_edge.ipynb b/docs/docs/integrations/chat/llama_edge.ipynb new file mode 100644 index 00000000000..003439844a4 --- /dev/null +++ b/docs/docs/integrations/chat/llama_edge.ipynb @@ -0,0 +1,135 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# LlamaEdge\n", + "\n", + "[LlamaEdge](https://github.com/second-state/LlamaEdge) allows you to chat with LLMs of [GGUF](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/README.md) format both locally and via chat service.\n", + "\n", + "- `LlamaEdgeChatService` provides developers an OpenAI API compatible service to chat with LLMs via HTTP requests.\n", + "\n", + "- `LlamaEdgeChatLocal` enables developers to chat with LLMs locally (coming soon).\n", + "\n", + "Both `LlamaEdgeChatService` and `LlamaEdgeChatLocal` run on the infrastructure driven by [WasmEdge Runtime](https://wasmedge.org/), which provides a lightweight and portable WebAssembly container environment for LLM inference tasks.\n", + "\n", + "## Chat via API Service\n", + "\n", + "`LlamaEdgeChatService` works on the `llama-api-server`. Following the steps in [llama-api-server quick-start](https://github.com/second-state/llama-utils/tree/main/api-server#readme), you can host your own API service so that you can chat with any models you like on any device you have anywhere as long as the internet is available." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.chat_models.llama_edge import LlamaEdgeChatService\n", + "from langchain_core.messages import HumanMessage, SystemMessage" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Chat with LLMs in the non-streaming mode" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Bot] Hello! 
The capital of France is Paris.\n" + ] + } + ], + "source": [ + "# service url\n", + "service_url = \"https://b008-54-186-154-209.ngrok-free.app\"\n", + "\n", + "# create wasm-chat service instance\n", + "chat = LlamaEdgeChatService(service_url=service_url)\n", + "\n", + "# create message sequence\n", + "system_message = SystemMessage(content=\"You are an AI assistant\")\n", + "user_message = HumanMessage(content=\"What is the capital of France?\")\n", + "messages = [system_message, user_message]\n", + "\n", + "# chat with wasm-chat service\n", + "response = chat(messages)\n", + "\n", + "print(f\"[Bot] {response.content}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Chat with LLMs in the streaming mode" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Bot] Hello! I'm happy to help you with your question. The capital of Norway is Oslo.\n" + ] + } + ], + "source": [ + "# service url\n", + "service_url = \"https://b008-54-186-154-209.ngrok-free.app\"\n", + "\n", + "# create wasm-chat service instance\n", + "chat = LlamaEdgeChatService(service_url=service_url, streaming=True)\n", + "\n", + "# create message sequence\n", + "system_message = SystemMessage(content=\"You are an AI assistant\")\n", + "user_message = HumanMessage(content=\"What is the capital of Norway?\")\n", + "messages = [\n", + " system_message,\n", + " user_message,\n", + "]\n", + "\n", + "output = \"\"\n", + "for chunk in chat.stream(messages):\n", + " # print(chunk.content, end=\"\", flush=True)\n", + " output += chunk.content\n", + "\n", + "print(f\"[Bot] {output}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/docs/integrations/chat/wasm_chat.ipynb b/docs/docs/integrations/chat/wasm_chat.ipynb deleted file mode 100644 index ed7859d3b97..00000000000 --- a/docs/docs/integrations/chat/wasm_chat.ipynb +++ /dev/null @@ -1,85 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Wasm Chat\n", - "\n", - "`Wasm-chat` allows you to chat with LLMs of [GGUF](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/README.md) format both locally and via chat service.\n", - "\n", - "- `WasmChatService` provides developers an OpenAI API compatible service to chat with LLMs via HTTP requests.\n", - "\n", - "- `WasmChatLocal` enables developers to chat with LLMs locally (coming soon).\n", - "\n", - "Both `WasmChatService` and `WasmChatLocal` run on the infrastructure driven by [WasmEdge Runtime](https://wasmedge.org/), which provides a lightweight and portable WebAssembly container environment for LLM inference tasks.\n", - "\n", - "## Chat via API Service\n", - "\n", - "`WasmChatService` provides chat services by the `llama-api-server`. Following the steps in [llama-api-server quick-start](https://github.com/second-state/llama-utils/tree/main/api-server#readme), you can host your own API service so that you can chat with any models you like on any device you have anywhere as long as the internet is available." 
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain_community.chat_models.wasm_chat import WasmChatService\n", - "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[Bot] Paris\n" - ] - } - ], - "source": [ - "# service url\n", - "service_url = \"https://b008-54-186-154-209.ngrok-free.app\"\n", - "\n", - "# create wasm-chat service instance\n", - "chat = WasmChatService(service_url=service_url)\n", - "\n", - "# create message sequence\n", - "system_message = SystemMessage(content=\"You are an AI assistant\")\n", - "user_message = HumanMessage(content=\"What is the capital of France?\")\n", - "messages = [system_message, user_message]\n", - "\n", - "# chat with wasm-chat service\n", - "response = chat(messages)\n", - "\n", - "print(f\"[Bot] {response.content}\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.7" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index a2145a825bc..aec1fcb8784 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -39,6 +39,7 @@ from langchain_community.chat_models.javelin_ai_gateway import ChatJavelinAIGate from langchain_community.chat_models.jinachat import JinaChat from langchain_community.chat_models.konko import ChatKonko from langchain_community.chat_models.litellm import ChatLiteLLM +from langchain_community.chat_models.llama_edge import LlamaEdgeChatService from langchain_community.chat_models.minimax import MiniMaxChat from langchain_community.chat_models.mlflow import ChatMlflow from langchain_community.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway @@ -49,12 +50,11 @@ from langchain_community.chat_models.promptlayer_openai import PromptLayerChatOp from langchain_community.chat_models.tongyi import ChatTongyi from langchain_community.chat_models.vertexai import ChatVertexAI from langchain_community.chat_models.volcengine_maas import VolcEngineMaasChat -from langchain_community.chat_models.wasm_chat import WasmChatService from langchain_community.chat_models.yandex import ChatYandexGPT from langchain_community.chat_models.zhipuai import ChatZhipuAI __all__ = [ - "WasmChatService", + "LlamaEdgeChatService", "ChatOpenAI", "BedrockChat", "AzureChatOpenAI", diff --git a/libs/community/langchain_community/chat_models/wasm_chat.py b/libs/community/langchain_community/chat_models/llama_edge.py similarity index 54% rename from libs/community/langchain_community/chat_models/wasm_chat.py rename to libs/community/langchain_community/chat_models/llama_edge.py index fb0826502c3..5cd8d72edc9 100644 --- a/libs/community/langchain_community/chat_models/wasm_chat.py +++ b/libs/community/langchain_community/chat_models/llama_edge.py @@ -1,18 +1,26 @@ import json import logging -from typing import Any, Dict, List, Mapping, Optional +import re +from typing import Any, 
Dict, Iterator, List, Mapping, Optional, Type import requests from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.language_models.chat_models import ( + BaseChatModel, + generate_from_stream, +) from langchain_core.messages import ( AIMessage, + AIMessageChunk, BaseMessage, + BaseMessageChunk, ChatMessage, + ChatMessageChunk, HumanMessage, + HumanMessageChunk, SystemMessage, ) -from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import root_validator from langchain_core.utils import get_pydantic_field_names @@ -45,10 +53,26 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: return message_dict -class WasmChatService(BaseChatModel): +def _convert_delta_to_message_chunk( + _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] +) -> BaseMessageChunk: + role = _dict.get("role") + content = _dict.get("content") or "" + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content) + elif role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk(content=content) + elif role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role) + else: + return default_class(content=content) + + +class LlamaEdgeChatService(BaseChatModel): """Chat with LLMs via `llama-api-server` - For the information about `llama-api-server`, visit https://github.com/second-state/llama-utils + For the information about `llama-api-server`, visit https://github.com/second-state/LlamaEdge """ request_timeout: int = 60 @@ -57,6 +81,8 @@ class WasmChatService(BaseChatModel): """URL of WasmChat service""" model: str = "NA" """model name, default is `NA`.""" + streaming: bool = False + """Whether to stream the results or not.""" class Config: """Configuration for this pydantic object.""" @@ -96,6 +122,12 @@ class WasmChatService(BaseChatModel): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: + if self.streaming: + stream_iter = self._stream( + messages=messages, stop=stop, run_manager=run_manager, **kwargs + ) + return generate_from_stream(stream_iter) + res = self._chat(messages, **kwargs) if res.status_code != 200: @@ -105,6 +137,64 @@ class WasmChatService(BaseChatModel): return self._create_chat_result(response) + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + res = self._chat(messages, **kwargs) + + default_chunk_class = AIMessageChunk + substring = '"object":"chat.completion.chunk"}' + for line in res.iter_lines(): + chunks = [] + if line: + json_string = line.decode("utf-8") + + # Find all positions of the substring + positions = [m.start() for m in re.finditer(substring, json_string)] + positions = [-1 * len(substring)] + positions + + for i in range(len(positions) - 1): + chunk = json.loads( + json_string[ + positions[i] + len(substring) : positions[i + 1] + + len(substring) + ] + ) + chunks.append(chunk) + + for chunk in chunks: + if not isinstance(chunk, dict): + chunk = chunk.dict() + if len(chunk["choices"]) == 0: + continue + + choice = chunk["choices"][0] + chunk = _convert_delta_to_message_chunk( + choice["delta"], default_chunk_class + ) + if ( + choice.get("finish_reason") is not None 
+ and choice.get("finish_reason") == "stop" + ): + break + finish_reason = choice.get("finish_reason") + generation_info = ( + dict(finish_reason=finish_reason) + if finish_reason is not None + else None + ) + default_chunk_class = chunk.__class__ + chunk = ChatGenerationChunk( + message=chunk, generation_info=generation_info + ) + yield chunk + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) + def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response: if self.service_url is None: res = requests.models.Response() @@ -114,10 +204,17 @@ class WasmChatService(BaseChatModel): service_url = f"{self.service_url}/v1/chat/completions" - payload = { - "model": self.model, - "messages": [_convert_message_to_dict(m) for m in messages], - } + if self.streaming: + payload = { + "model": self.model, + "messages": [_convert_message_to_dict(m) for m in messages], + "stream": self.streaming, + } + else: + payload = { + "model": self.model, + "messages": [_convert_message_to_dict(m) for m in messages], + } res = requests.post( url=service_url, diff --git a/libs/community/tests/integration_tests/chat_models/test_llama_edge.py b/libs/community/tests/integration_tests/chat_models/test_llama_edge.py new file mode 100644 index 00000000000..50919902dd4 --- /dev/null +++ b/libs/community/tests/integration_tests/chat_models/test_llama_edge.py @@ -0,0 +1,52 @@ +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage + +from langchain_community.chat_models.llama_edge import LlamaEdgeChatService + + +@pytest.mark.enable_socket +def test_chat_wasm_service() -> None: + """This test requires the port 8080 is not occupied.""" + + # service url + service_url = "https://b008-54-186-154-209.ngrok-free.app" + + # create wasm-chat service instance + chat = LlamaEdgeChatService(service_url=service_url) + + # create message sequence + system_message = SystemMessage(content="You are an AI assistant") + user_message = HumanMessage(content="What is the capital of France?") + messages = [system_message, user_message] + + # chat with wasm-chat service + response = chat(messages) + + # check response + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert "Paris" in response.content + + +@pytest.mark.enable_socket +def test_chat_wasm_service_streaming() -> None: + """This test requires the port 8080 is not occupied.""" + + # service url + service_url = "https://b008-54-186-154-209.ngrok-free.app" + + # create wasm-chat service instance + chat = LlamaEdgeChatService(service_url=service_url, streaming=True) + + # create message sequence + user_message = HumanMessage(content="What is the capital of France?") + messages = [ + user_message, + ] + + output = "" + for chunk in chat.stream(messages): + print(chunk.content, end="", flush=True) + output += chunk.content + + assert "Paris" in output diff --git a/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py b/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py deleted file mode 100644 index 369908a2d26..00000000000 --- a/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py +++ /dev/null @@ -1,28 +0,0 @@ -import pytest -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage - -from langchain_community.chat_models.wasm_chat import WasmChatService - - -@pytest.mark.enable_socket -def test_chat_wasm_service() -> None: - """This test requires the port 8080 is not occupied.""" - - # service url - service_url = 
"https://b008-54-186-154-209.ngrok-free.app" - - # create wasm-chat service instance - chat = WasmChatService(service_url=service_url) - - # create message sequence - system_message = SystemMessage(content="You are an AI assistant") - user_message = HumanMessage(content="What is the capital of France?") - messages = [system_message, user_message] - - # chat with wasm-chat service - response = chat(messages) - - # check response - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) - assert "Paris" in response.content diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index fd27b932f73..031fb96e893 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -33,7 +33,7 @@ EXPECTED_ALL = [ "ChatHunyuan", "GigaChat", "VolcEngineMaasChat", - "WasmChatService", + "LlamaEdgeChatService", "GPTRouter", "ChatZhipuAI", ] diff --git a/libs/community/tests/unit_tests/chat_models/test_wasmchat.py b/libs/community/tests/unit_tests/chat_models/test_llama_edge.py similarity index 95% rename from libs/community/tests/unit_tests/chat_models/test_wasmchat.py rename to libs/community/tests/unit_tests/chat_models/test_llama_edge.py index 6905130de67..4bcf0fde6e1 100644 --- a/libs/community/tests/unit_tests/chat_models/test_wasmchat.py +++ b/libs/community/tests/unit_tests/chat_models/test_llama_edge.py @@ -7,8 +7,8 @@ from langchain_core.messages import ( SystemMessage, ) -from langchain_community.chat_models.wasm_chat import ( - WasmChatService, +from langchain_community.chat_models.llama_edge import ( + LlamaEdgeChatService, _convert_dict_to_message, _convert_message_to_dict, ) @@ -64,7 +64,7 @@ def test__convert_dict_to_message_other_role() -> None: def test_wasm_chat_without_service_url() -> None: - chat = WasmChatService() + chat = LlamaEdgeChatService() # create message sequence system_message = SystemMessage(content="You are an AI assistant")