token chunks (#9739)

Co-authored-by: Andrew <abatutin@gmail.com>
Author: William FH, 2023-08-25 12:52:07 -07:00 (committed by GitHub)
parent 2ab04a4e32
commit 1960ac8d25
7 changed files with 153 additions and 17 deletions

View File

@@ -10,7 +10,7 @@ if TYPE_CHECKING:
     from langchain.schema.agent import AgentAction, AgentFinish
     from langchain.schema.document import Document
     from langchain.schema.messages import BaseMessage
-    from langchain.schema.output import LLMResult
+    from langchain.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult


 class RetrieverManagerMixin:
@@ -44,11 +44,18 @@ class LLMManagerMixin:
         self,
         token: str,
         *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        """Run on new LLM token. Only available when streaming is enabled."""
+        """Run on new LLM token. Only available when streaming is enabled.
+
+        Args:
+            token (str): The new token.
+            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
+                containing content and other information.
+        """

     def on_llm_end(
         self,
@@ -316,6 +323,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         self,
         token: str,
         *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         tags: Optional[List[str]] = None,
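
The hunks above give the synchronous and asynchronous on_llm_new_token hooks a keyword-only chunk argument that defaults to None, so existing handlers keep working unchanged while new ones can opt into the richer object. A minimal sketch of such a handler (the class name and print logic are illustrative, not part of this commit):

from typing import Any, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.output import ChatGenerationChunk, GenerationChunk


class ChunkAwareHandler(BaseCallbackHandler):
    """Illustrative handler that looks at the streamed chunk, not just the token text."""

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        **kwargs: Any,
    ) -> None:
        # chunk may be None when an integration has not been updated to pass it.
        if chunk is not None:
            print(type(chunk).__name__)
        print(token, end="", flush=True)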

View File

@@ -49,6 +49,7 @@ from langchain.schema import (
     LLMResult,
 )
 from langchain.schema.messages import BaseMessage, get_buffer_string
+from langchain.schema.output import ChatGenerationChunk, GenerationChunk

 if TYPE_CHECKING:
     from langsmith import Client as LangSmithClient
@@ -592,6 +593,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
     def on_llm_new_token(
         self,
         token: str,
+        *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM generates a new token.
@@ -607,6 +610,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
             run_id=self.run_id,
             parent_run_id=self.parent_run_id,
             tags=self.tags,
+            chunk=chunk,
             **kwargs,
         )
@@ -655,6 +659,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
     async def on_llm_new_token(
         self,
         token: str,
+        *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM generates a new token.
@@ -667,6 +673,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
             "on_llm_new_token",
             "ignore_llm",
             token,
+            chunk=chunk,
             run_id=self.run_id,
             parent_run_id=self.parent_run_id,
             tags=self.tags,

View File

@@ -13,7 +13,12 @@ from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.tracers.schemas import Run
 from langchain.load.dump import dumpd
 from langchain.schema.document import Document
-from langchain.schema.output import ChatGeneration, LLMResult
+from langchain.schema.output import (
+    ChatGeneration,
+    ChatGenerationChunk,
+    GenerationChunk,
+    LLMResult,
+)

 logger = logging.getLogger(__name__)
@@ -123,6 +128,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         self,
         token: str,
         *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
@@ -135,11 +141,14 @@ class BaseTracer(BaseCallbackHandler, ABC):
         llm_run = self.run_map.get(run_id_)
         if llm_run is None or llm_run.run_type != "llm":
             raise TracerException(f"No LLM Run found to be traced for {run_id}")
+        event_kwargs: Dict[str, Any] = {"token": token}
+        if chunk:
+            event_kwargs["chunk"] = chunk
         llm_run.events.append(
             {
                 "name": "new_token",
                 "time": datetime.utcnow(),
-                "kwargs": {"token": token},
+                "kwargs": event_kwargs,
             },
         )
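
With the tracer change, a supplied chunk is recorded in the run's event log alongside the token text. A sketch of the event entry that on_llm_new_token now appends (the token value is invented; when no chunk is passed, kwargs holds only the token, as before):

from datetime import datetime

from langchain.schema.output import GenerationChunk

# Approximate shape of one entry appended to llm_run.events after this change.
event = {
    "name": "new_token",
    "time": datetime.utcnow(),
    "kwargs": {
        "token": "Hello",
        "chunk": GenerationChunk(text="Hello"),
    },
}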

View File

@@ -318,7 +318,7 @@ class ChatOpenAI(BaseChatModel):
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
             if run_manager:
-                run_manager.on_llm_new_token(chunk.content)
+                run_manager.on_llm_new_token(chunk.content, chunk=chunk)

     def _generate(
         self,
@@ -398,7 +398,7 @@ class ChatOpenAI(BaseChatModel):
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
             if run_manager:
-                await run_manager.on_llm_new_token(chunk.content)
+                await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)

     async def _agenerate(
         self,
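
End to end, a streaming ChatOpenAI call now hands each streamed chunk to registered handlers along with the token text. A rough usage sketch, assuming a valid OPENAI_API_KEY is configured (the collector class is a throwaway illustration, not part of this commit):

from typing import Any, List, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema.messages import HumanMessage
from langchain.schema.output import ChatGenerationChunk, GenerationChunk


class ChunkCollector(BaseCallbackHandler):
    """Keeps every chunk the model streams back."""

    def __init__(self) -> None:
        super().__init__()
        self.chunks: List[Optional[Union[GenerationChunk, ChatGenerationChunk]]] = []

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        **kwargs: Any,
    ) -> None:
        self.chunks.append(chunk)


collector = ChunkCollector()
chat = ChatOpenAI(
    streaming=True,
    max_tokens=10,
    callback_manager=CallbackManager([collector]),
)
chat([HumanMessage(content="Say hello")])
# collector.chunks now holds the streamed chunk objects rather than bare strings.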

View File

@@ -289,9 +289,10 @@ class Anthropic(LLM, _AnthropicCommon):
         for token in self.client.completions.create(
             prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
         ):
-            yield GenerationChunk(text=token.completion)
+            chunk = GenerationChunk(text=token.completion)
+            yield chunk
             if run_manager:
-                run_manager.on_llm_new_token(token.completion)
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)

     async def _astream(
         self,
@@ -324,9 +325,10 @@ class Anthropic(LLM, _AnthropicCommon):
             stream=True,
             **params,
         ):
-            yield GenerationChunk(text=token.completion)
+            chunk = GenerationChunk(text=token.completion)
+            yield chunk
             if run_manager:
-                await run_manager.on_llm_new_token(token.completion)
+                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)

     def get_num_tokens(self, text: str) -> int:
         """Calculate number of tokens."""

View File

@@ -297,6 +297,7 @@ class BaseOpenAI(BaseLLM):
             if run_manager:
                 run_manager.on_llm_new_token(
                     chunk.text,
+                    chunk=chunk,
                     verbose=self.verbose,
                     logprobs=chunk.generation_info["logprobs"]
                     if chunk.generation_info
@@ -320,6 +321,7 @@ class BaseOpenAI(BaseLLM):
             if run_manager:
                 await run_manager.on_llm_new_token(
                     chunk.text,
+                    chunk=chunk,
                     verbose=self.verbose,
                     logprobs=chunk.generation_info["logprobs"]
                     if chunk.generation_info
@@ -825,9 +827,10 @@ class OpenAIChat(BaseLLM):
             self, messages=messages, run_manager=run_manager, **params
         ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
-            yield GenerationChunk(text=token)
+            chunk = GenerationChunk(text=token)
+            yield chunk
             if run_manager:
-                run_manager.on_llm_new_token(token)
+                run_manager.on_llm_new_token(token, chunk=chunk)

     async def _astream(
         self,
@@ -842,9 +845,10 @@ class OpenAIChat(BaseLLM):
             self, messages=messages, run_manager=run_manager, **params
         ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
-            yield GenerationChunk(text=token)
+            chunk = GenerationChunk(text=token)
+            yield chunk
             if run_manager:
-                await run_manager.on_llm_new_token(token)
+                await run_manager.on_llm_new_token(token, chunk=chunk)

     def _generate(
         self,
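
Across the Anthropic and OpenAI integrations above, the pattern is identical: build the GenerationChunk once, yield it, then pass both the token text and the chunk to the run manager. A sketch of a hypothetical custom LLM following the same pattern (EchoLLM and its word-splitting are invented for illustration; only the run-manager call mirrors this commit):

from typing import Any, Iterator, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk


class EchoLLM(LLM):
    """Toy LLM that streams the prompt back word by word."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return prompt

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        for word in prompt.split():
            chunk = GenerationChunk(text=word + " ")
            yield chunk
            if run_manager:
                # Forward the raw text and the chunk, mirroring the built-in
                # integrations updated in this commit.
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)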

View File

@@ -1,18 +1,22 @@
 """Test ChatOpenAI wrapper."""
-from typing import Any
+from typing import Any, List, Optional, Union

 import pytest

+from langchain.callbacks.base import AsyncCallbackHandler
 from langchain.callbacks.manager import CallbackManager
+from langchain.chains.openai_functions import (
+    create_openai_fn_chain,
+)
 from langchain.chat_models.openai import ChatOpenAI
+from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 from langchain.schema import (
     ChatGeneration,
     ChatResult,
     LLMResult,
 )
 from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain.schema.output import ChatGenerationChunk, GenerationChunk
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -191,6 +195,108 @@ async def test_async_chat_openai_streaming() -> None:
             assert generation.text == generation.message.content


+@pytest.mark.scheduled
+@pytest.mark.asyncio
+async def test_async_chat_openai_streaming_with_function() -> None:
+    """Test ChatOpenAI wrapper with multiple completions."""
+
+    class MyCustomAsyncHandler(AsyncCallbackHandler):
+        def __init__(self) -> None:
+            super().__init__()
+            self._captured_tokens: List[str] = []
+            self._captured_chunks: List[
+                Optional[Union[ChatGenerationChunk, GenerationChunk]]
+            ] = []
+
+        def on_llm_new_token(
+            self,
+            token: str,
+            *,
+            chunk: Optional[Union[ChatGenerationChunk, GenerationChunk]] = None,
+            **kwargs: Any,
+        ) -> Any:
+            self._captured_tokens.append(token)
+            self._captured_chunks.append(chunk)
+
+    json_schema = {
+        "title": "Person",
+        "description": "Identifying information about a person.",
+        "type": "object",
+        "properties": {
+            "name": {
+                "title": "Name",
+                "description": "The person's name",
+                "type": "string",
+            },
+            "age": {
+                "title": "Age",
+                "description": "The person's age",
+                "type": "integer",
+            },
+            "fav_food": {
+                "title": "Fav Food",
+                "description": "The person's favorite food",
+                "type": "string",
+            },
+        },
+        "required": ["name", "age"],
+    }
+
+    callback_handler = MyCustomAsyncHandler()
+    callback_manager = CallbackManager([callback_handler])
+
+    chat = ChatOpenAI(
+        max_tokens=10,
+        n=1,
+        callback_manager=callback_manager,
+        streaming=True,
+    )
+
+    prompt_msgs = [
+        SystemMessage(
+            content="You are a world class algorithm for "
+            "extracting information in structured formats."
+        ),
+        HumanMessage(
+            content="Use the given format to extract "
+            "information from the following input:"
+        ),
+        HumanMessagePromptTemplate.from_template("{input}"),
+        HumanMessage(content="Tips: Make sure to answer in the correct format"),
+    ]
+    prompt = ChatPromptTemplate(messages=prompt_msgs)
+
+    function: Any = {
+        "name": "output_formatter",
+        "description": (
+            "Output formatter. Should always be used to format your response to the"
+            " user."
+        ),
+        "parameters": json_schema,
+    }
+    chain = create_openai_fn_chain(
+        [function],
+        chat,
+        prompt,
+        output_parser=None,
+    )
+
+    message = HumanMessage(content="Sally is 13 years old")
+    response = await chain.agenerate([{"input": message}])
+
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 1
+    for generations in response.generations:
+        assert len(generations) == 1
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+            assert generation.text == generation.message.content
+    assert len(callback_handler._captured_tokens) > 0
+    assert len(callback_handler._captured_chunks) > 0
+    assert all([chunk is not None for chunk in callback_handler._captured_chunks])
+
+
 def test_chat_openai_extra_kwargs() -> None:
     """Test extra kwargs to chat openai."""
     # Check that foo is saved in extra_kwargs.
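
A closing note on why the callbacks receive chunks rather than only strings: generation chunks are designed to be merged, so a handler that captures them (like the one in the new test) can rebuild the full generation, metadata included, instead of merely concatenating text. A small sketch, assuming GenerationChunk instances can be combined with + as in the library's other streaming code paths (the chunk values are invented):

from langchain.schema.output import GenerationChunk

# Chunks as a handler might have captured them during streaming.
captured = [
    GenerationChunk(text="Hello"),
    GenerationChunk(text=", "),
    GenerationChunk(text="world"),
]

merged = captured[0]
for chunk in captured[1:]:
    merged += chunk  # assumed to concatenate text and merge generation_info

print(merged.text)  # -> "Hello, world"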