From 1960ac8d25c142f23a10a8203e6ccd14c8ca6be7 Mon Sep 17 00:00:00 2001
From: William FH <13333726+hinthornw@users.noreply.github.com>
Date: Fri, 25 Aug 2023 12:52:07 -0700
Subject: [PATCH] token chunks (#9739)

Co-authored-by: Andrew
---
 libs/langchain/langchain/callbacks/base.py    |  12 +-
 libs/langchain/langchain/callbacks/manager.py |   7 ++
 .../langchain/callbacks/tracers/base.py       |  13 +-
 .../langchain/langchain/chat_models/openai.py |   4 +-
 libs/langchain/langchain/llms/anthropic.py    |  10 +-
 libs/langchain/langchain/llms/openai.py       |  12 +-
 .../chat_models/test_openai.py                | 112 +++++++++++++++++-
 7 files changed, 153 insertions(+), 17 deletions(-)

diff --git a/libs/langchain/langchain/callbacks/base.py b/libs/langchain/langchain/callbacks/base.py
index c03633e2d4d..d6155536b06 100644
--- a/libs/langchain/langchain/callbacks/base.py
+++ b/libs/langchain/langchain/callbacks/base.py
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
     from langchain.schema.agent import AgentAction, AgentFinish
     from langchain.schema.document import Document
     from langchain.schema.messages import BaseMessage
-    from langchain.schema.output import LLMResult
+    from langchain.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult


 class RetrieverManagerMixin:
@@ -44,11 +44,18 @@ class LLMManagerMixin:
         self,
         token: str,
         *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        """Run on new LLM token. Only available when streaming is enabled."""
+        """Run on new LLM token. Only available when streaming is enabled.
+
+        Args:
+            token (str): The new token.
+            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
+            containing content and other information.
+        """

     def on_llm_end(
         self,
@@ -316,6 +323,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         self,
         token: str,
         *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         tags: Optional[List[str]] = None,
diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py
index c6f626204ce..3f22832de3e 100644
--- a/libs/langchain/langchain/callbacks/manager.py
+++ b/libs/langchain/langchain/callbacks/manager.py
@@ -49,6 +49,7 @@ from langchain.schema import (
     LLMResult,
 )
 from langchain.schema.messages import BaseMessage, get_buffer_string
+from langchain.schema.output import ChatGenerationChunk, GenerationChunk

 if TYPE_CHECKING:
     from langsmith import Client as LangSmithClient
@@ -592,6 +593,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
     def on_llm_new_token(
         self,
         token: str,
+        *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM generates a new token.
@@ -607,6 +610,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
             run_id=self.run_id,
             parent_run_id=self.parent_run_id,
             tags=self.tags,
+            chunk=chunk,
             **kwargs,
         )

@@ -655,6 +659,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
     async def on_llm_new_token(
         self,
         token: str,
+        *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM generates a new token.
@@ -667,6 +673,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
             "on_llm_new_token",
             "ignore_llm",
             token,
+            chunk=chunk,
             run_id=self.run_id,
             parent_run_id=self.parent_run_id,
             tags=self.tags,
diff --git a/libs/langchain/langchain/callbacks/tracers/base.py b/libs/langchain/langchain/callbacks/tracers/base.py
index 7dec527c1bb..bee30a515fd 100644
--- a/libs/langchain/langchain/callbacks/tracers/base.py
+++ b/libs/langchain/langchain/callbacks/tracers/base.py
@@ -13,7 +13,12 @@ from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.tracers.schemas import Run
 from langchain.load.dump import dumpd
 from langchain.schema.document import Document
-from langchain.schema.output import ChatGeneration, LLMResult
+from langchain.schema.output import (
+    ChatGeneration,
+    ChatGenerationChunk,
+    GenerationChunk,
+    LLMResult,
+)

 logger = logging.getLogger(__name__)

@@ -123,6 +128,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         self,
         token: str,
         *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
@@ -135,11 +141,14 @@ class BaseTracer(BaseCallbackHandler, ABC):
         llm_run = self.run_map.get(run_id_)
         if llm_run is None or llm_run.run_type != "llm":
             raise TracerException(f"No LLM Run found to be traced for {run_id}")
+        event_kwargs: Dict[str, Any] = {"token": token}
+        if chunk:
+            event_kwargs["chunk"] = chunk
         llm_run.events.append(
             {
                 "name": "new_token",
                 "time": datetime.utcnow(),
-                "kwargs": {"token": token},
+                "kwargs": event_kwargs,
             },
         )

diff --git a/libs/langchain/langchain/chat_models/openai.py b/libs/langchain/langchain/chat_models/openai.py
index 7cb1947cf88..5d944852d66 100644
--- a/libs/langchain/langchain/chat_models/openai.py
+++ b/libs/langchain/langchain/chat_models/openai.py
@@ -318,7 +318,7 @@ class ChatOpenAI(BaseChatModel):
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
             if run_manager:
-                run_manager.on_llm_new_token(chunk.content)
+                run_manager.on_llm_new_token(chunk.content, chunk=chunk)

     def _generate(
         self,
@@ -398,7 +398,7 @@ class ChatOpenAI(BaseChatModel):
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
             if run_manager:
-                await run_manager.on_llm_new_token(chunk.content)
+                await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)

     async def _agenerate(
         self,
diff --git a/libs/langchain/langchain/llms/anthropic.py b/libs/langchain/langchain/llms/anthropic.py
index 63664e07af6..f7ee0ab57e3 100644
--- a/libs/langchain/langchain/llms/anthropic.py
+++ b/libs/langchain/langchain/llms/anthropic.py
@@ -289,9 +289,10 @@ class Anthropic(LLM, _AnthropicCommon):
         for token in self.client.completions.create(
             prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
         ):
-            yield GenerationChunk(text=token.completion)
+            chunk = GenerationChunk(text=token.completion)
+            yield chunk
             if run_manager:
-                run_manager.on_llm_new_token(token.completion)
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)

     async def _astream(
         self,
@@ -324,9 +325,10 @@ class Anthropic(LLM, _AnthropicCommon):
             stream=True,
             **params,
         ):
-            yield GenerationChunk(text=token.completion)
+            chunk = GenerationChunk(text=token.completion)
+            yield chunk
             if run_manager:
-                await run_manager.on_llm_new_token(token.completion)
+                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)

     def get_num_tokens(self, text: str) -> int:
         """Calculate number of tokens."""
diff --git a/libs/langchain/langchain/llms/openai.py b/libs/langchain/langchain/llms/openai.py
index 0ab26529122..837256ec992 100644
--- a/libs/langchain/langchain/llms/openai.py
+++ b/libs/langchain/langchain/llms/openai.py
@@ -297,6 +297,7 @@ class BaseOpenAI(BaseLLM):
             if run_manager:
                 run_manager.on_llm_new_token(
                     chunk.text,
+                    chunk=chunk,
                     verbose=self.verbose,
                     logprobs=chunk.generation_info["logprobs"]
                     if chunk.generation_info
@@ -320,6 +321,7 @@ class BaseOpenAI(BaseLLM):
             if run_manager:
                 await run_manager.on_llm_new_token(
                     chunk.text,
+                    chunk=chunk,
                     verbose=self.verbose,
                     logprobs=chunk.generation_info["logprobs"]
                     if chunk.generation_info
@@ -825,9 +827,10 @@ class OpenAIChat(BaseLLM):
             self, messages=messages, run_manager=run_manager, **params
         ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
-            yield GenerationChunk(text=token)
+            chunk = GenerationChunk(text=token)
+            yield chunk
             if run_manager:
-                run_manager.on_llm_new_token(token)
+                run_manager.on_llm_new_token(token, chunk=chunk)

     async def _astream(
         self,
@@ -842,9 +845,10 @@ class OpenAIChat(BaseLLM):
             self, messages=messages, run_manager=run_manager, **params
         ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
-            yield GenerationChunk(text=token)
+            chunk = GenerationChunk(text=token)
+            yield chunk
             if run_manager:
-                await run_manager.on_llm_new_token(token)
+                await run_manager.on_llm_new_token(token, chunk=chunk)

     def _generate(
         self,
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_openai.py b/libs/langchain/tests/integration_tests/chat_models/test_openai.py
index 7637014a4c1..5c8b0e43e6e 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_openai.py
+++ b/libs/langchain/tests/integration_tests/chat_models/test_openai.py
@@ -1,18 +1,22 @@
 """Test ChatOpenAI wrapper."""
-
-
-from typing import Any
+from typing import Any, List, Optional, Union

 import pytest

+from langchain.callbacks.base import AsyncCallbackHandler
 from langchain.callbacks.manager import CallbackManager
+from langchain.chains.openai_functions import (
+    create_openai_fn_chain,
+)
 from langchain.chat_models.openai import ChatOpenAI
+from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 from langchain.schema import (
     ChatGeneration,
     ChatResult,
     LLMResult,
 )
 from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain.schema.output import ChatGenerationChunk, GenerationChunk
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

@@ -191,6 +195,108 @@ async def test_async_chat_openai_streaming() -> None:
             assert generation.text == generation.message.content


+@pytest.mark.scheduled
+@pytest.mark.asyncio
+async def test_async_chat_openai_streaming_with_function() -> None:
+    """Test ChatOpenAI wrapper with multiple completions."""
+
+    class MyCustomAsyncHandler(AsyncCallbackHandler):
+        def __init__(self) -> None:
+            super().__init__()
+            self._captured_tokens: List[str] = []
+            self._captured_chunks: List[
+                Optional[Union[ChatGenerationChunk, GenerationChunk]]
+            ] = []
+
+        def on_llm_new_token(
+            self,
+            token: str,
+            *,
+            chunk: Optional[Union[ChatGenerationChunk, GenerationChunk]] = None,
+            **kwargs: Any,
+        ) -> Any:
+            self._captured_tokens.append(token)
+            self._captured_chunks.append(chunk)
+
+    json_schema = {
+        "title": "Person",
+        "description": "Identifying information about a person.",
+        "type": "object",
+        "properties": {
+            "name": {
+                "title": "Name",
+                "description": "The person's name",
+                "type": "string",
+            },
+            "age": {
+                "title": "Age",
+                "description": "The person's age",
+                "type": "integer",
+            },
+            "fav_food": {
+                "title": "Fav Food",
+                "description": "The person's favorite food",
+                "type": "string",
+            },
+        },
+        "required": ["name", "age"],
+    }
+
+    callback_handler = MyCustomAsyncHandler()
+    callback_manager = CallbackManager([callback_handler])
+
+    chat = ChatOpenAI(
+        max_tokens=10,
+        n=1,
+        callback_manager=callback_manager,
+        streaming=True,
+    )
+
+    prompt_msgs = [
+        SystemMessage(
+            content="You are a world class algorithm for "
+            "extracting information in structured formats."
+        ),
+        HumanMessage(
+            content="Use the given format to extract "
+            "information from the following input:"
+        ),
+        HumanMessagePromptTemplate.from_template("{input}"),
+        HumanMessage(content="Tips: Make sure to answer in the correct format"),
+    ]
+    prompt = ChatPromptTemplate(messages=prompt_msgs)
+
+    function: Any = {
+        "name": "output_formatter",
+        "description": (
+            "Output formatter. Should always be used to format your response to the"
+            " user."
+        ),
+        "parameters": json_schema,
+    }
+    chain = create_openai_fn_chain(
+        [function],
+        chat,
+        prompt,
+        output_parser=None,
+    )
+
+    message = HumanMessage(content="Sally is 13 years old")
+    response = await chain.agenerate([{"input": message}])
+
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 1
+    for generations in response.generations:
+        assert len(generations) == 1
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+            assert generation.text == generation.message.content
+    assert len(callback_handler._captured_tokens) > 0
+    assert len(callback_handler._captured_chunks) > 0
+    assert all([chunk is not None for chunk in callback_handler._captured_chunks])
+
+
 def test_chat_openai_extra_kwargs() -> None:
     """Test extra kwargs to chat openai."""
     # Check that foo is saved in extra_kwargs.