token chunks (#9739)

Co-authored-by: Andrew <abatutin@gmail.com>
William FH 2023-08-25 12:52:07 -07:00 committed by GitHub
parent 2ab04a4e32
commit 1960ac8d25
7 changed files with 153 additions and 17 deletions
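This commit threads an optional chunk keyword argument through on_llm_new_token, so streaming callbacks receive the full GenerationChunk or ChatGenerationChunk (text plus generation_info and, for chat models, the partial message) alongside the plain token string. A rough sketch of a handler that consumes the new argument (the ChunkCollector name is illustrative; the signature follows the diff below):

from typing import Any, List, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.output import ChatGenerationChunk, GenerationChunk


class ChunkCollector(BaseCallbackHandler):
    """Illustrative handler that records both tokens and chunk objects."""

    def __init__(self) -> None:
        self.tokens: List[str] = []
        self.chunks: List[Optional[Union[GenerationChunk, ChatGenerationChunk]]] = []

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        **kwargs: Any,
    ) -> None:
        # token is still the plain text delta; chunk, when the model
        # integration supplies it, also carries generation_info and (for
        # chat models) the partial message, e.g. function_call arguments.
        self.tokens.append(token)
        self.chunks.append(chunk)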

View File

@@ -10,7 +10,7 @@ if TYPE_CHECKING:
from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.document import Document
from langchain.schema.messages import BaseMessage
from langchain.schema.output import LLMResult
from langchain.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult
class RetrieverManagerMixin:
@@ -44,11 +44,18 @@ class LLMManagerMixin:
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on new LLM token. Only available when streaming is enabled."""
"""Run on new LLM token. Only available when streaming is enabled.
Args:
token (str): The new token.
chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
containing content and other information.
"""
def on_llm_end(
self,
@@ -316,6 +323,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
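Because handlers can now hold on to chunk objects rather than bare strings, the streamed pieces can be folded back into a single generation once the stream ends. A minimal sketch, assuming (as in langchain.schema.output) that chunks of the same type support + for concatenation:

from typing import List, Optional, Union

from langchain.schema.output import ChatGenerationChunk, GenerationChunk

ChunkType = Union[GenerationChunk, ChatGenerationChunk]


def merge_chunks(chunks: List[Optional[ChunkType]]) -> Optional[ChunkType]:
    # Chunks emitted during one run share a type, so they can be summed
    # pairwise; None entries (calls that carried no chunk) are skipped.
    merged: Optional[ChunkType] = None
    for chunk in chunks:
        if chunk is None:
            continue
        merged = chunk if merged is None else merged + chunk
    return merged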

View File

@@ -49,6 +49,7 @@ from langchain.schema import (
LLMResult,
)
from langchain.schema.messages import BaseMessage, get_buffer_string
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
if TYPE_CHECKING:
from langsmith import Client as LangSmithClient
@@ -592,6 +593,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM generates a new token.
@@ -607,6 +610,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
run_id=self.run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
chunk=chunk,
**kwargs,
)
@@ -655,6 +659,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
async def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM generates a new token.
@@ -667,6 +673,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
"on_llm_new_token",
"ignore_llm",
token,
chunk=chunk,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
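With both run managers forwarding chunk, a model integration's _stream follows the pattern the provider changes below adopt: build the chunk, yield it, then report both the text and the chunk to the run manager. A hedged sketch of a custom LLM doing this (MyLLM and its whitespace "streaming" are illustrative stand-ins for a real client):

from typing import Any, Iterator, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk


class MyLLM(LLM):
    """Illustrative streaming LLM that reports chunks to its callbacks."""

    @property
    def _llm_type(self) -> str:
        return "my-llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return "".join(
            chunk.text for chunk in self._stream(prompt, stop, run_manager, **kwargs)
        )

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        for token in prompt.split():  # stand-in for a real streaming API
            chunk = GenerationChunk(text=token)
            yield chunk
            if run_manager:
                # New in this commit: pass the chunk alongside the token.
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)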

View File

@@ -13,7 +13,12 @@ from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.tracers.schemas import Run
from langchain.load.dump import dumpd
from langchain.schema.document import Document
from langchain.schema.output import ChatGeneration, LLMResult
from langchain.schema.output import (
ChatGeneration,
ChatGenerationChunk,
GenerationChunk,
LLMResult,
)
logger = logging.getLogger(__name__)
@@ -123,6 +128,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
@@ -135,11 +141,14 @@ class BaseTracer(BaseCallbackHandler, ABC):
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != "llm":
raise TracerException(f"No LLM Run found to be traced for {run_id}")
event_kwargs: Dict[str, Any] = {"token": token}
if chunk:
event_kwargs["chunk"] = chunk
llm_run.events.append(
{
"name": "new_token",
"time": datetime.utcnow(),
"kwargs": {"token": token},
"kwargs": event_kwargs,
},
)
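Since the tracer now stores the chunk in the new_token event, downstream consumers can recover the streamed chunks from a traced run. A small sketch of reading them back out (the helper is illustrative; the event shape follows the diff above):

from typing import Any, List

from langchain.callbacks.tracers.schemas import Run


def new_token_chunks(run: Run) -> List[Any]:
    # Each streamed token is recorded as a "new_token" event; a chunk is
    # present in kwargs only when the model integration supplied one.
    return [
        event["kwargs"]["chunk"]
        for event in run.events
        if event.get("name") == "new_token" and "chunk" in event.get("kwargs", {})
    ]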

View File

@@ -318,7 +318,7 @@ class ChatOpenAI(BaseChatModel):
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
if run_manager:
run_manager.on_llm_new_token(chunk.content)
run_manager.on_llm_new_token(chunk.content, chunk=chunk)
def _generate(
self,
@@ -398,7 +398,7 @@ class ChatOpenAI(BaseChatModel):
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
if run_manager:
await run_manager.on_llm_new_token(chunk.content)
await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
async def _agenerate(
self,
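Wiring the illustrative ChunkCollector from the first sketch into a streaming chat model then looks roughly like this (the prompt is arbitrary; an OpenAI API key is assumed to be configured):

from langchain.chat_models import ChatOpenAI
from langchain.schema.messages import HumanMessage

collector = ChunkCollector()  # from the sketch above
chat = ChatOpenAI(streaming=True, temperature=0, callbacks=[collector])
chat([HumanMessage(content="Say hello in three words.")])

# Each callback now sees the ChatGenerationChunk, not just its text.
assert all(chunk is not None for chunk in collector.chunks)
full_text = "".join(collector.tokens)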

View File

@@ -289,9 +289,10 @@ class Anthropic(LLM, _AnthropicCommon):
for token in self.client.completions.create(
prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
):
yield GenerationChunk(text=token.completion)
chunk = GenerationChunk(text=token.completion)
yield chunk
if run_manager:
run_manager.on_llm_new_token(token.completion)
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
async def _astream(
self,
@@ -324,9 +325,10 @@ class Anthropic(LLM, _AnthropicCommon):
stream=True,
**params,
):
yield GenerationChunk(text=token.completion)
chunk = GenerationChunk(text=token.completion)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(token.completion)
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""

View File

@@ -297,6 +297,7 @@ class BaseOpenAI(BaseLLM):
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
@@ -320,6 +321,7 @@ class BaseOpenAI(BaseLLM):
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
@@ -825,9 +827,10 @@ class OpenAIChat(BaseLLM):
self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
yield GenerationChunk(text=token)
chunk = GenerationChunk(text=token)
yield chunk
if run_manager:
run_manager.on_llm_new_token(token)
run_manager.on_llm_new_token(token, chunk=chunk)
async def _astream(
self,
@@ -842,9 +845,10 @@ class OpenAIChat(BaseLLM):
self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
yield GenerationChunk(text=token)
chunk = GenerationChunk(text=token)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(token)
await run_manager.on_llm_new_token(token, chunk=chunk)
def _generate(
self,

View File

@@ -1,18 +1,22 @@
"""Test ChatOpenAI wrapper."""
from typing import Any
from typing import Any, List, Optional, Union
import pytest
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.manager import CallbackManager
from langchain.chains.openai_functions import (
create_openai_fn_chain,
)
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import (
ChatGeneration,
ChatResult,
LLMResult,
)
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -191,6 +195,108 @@ async def test_async_chat_openai_streaming() -> None:
assert generation.text == generation.message.content
@pytest.mark.scheduled
@pytest.mark.asyncio
async def test_async_chat_openai_streaming_with_function() -> None:
"""Test ChatOpenAI wrapper with multiple completions."""
class MyCustomAsyncHandler(AsyncCallbackHandler):
def __init__(self) -> None:
super().__init__()
self._captured_tokens: List[str] = []
self._captured_chunks: List[
Optional[Union[ChatGenerationChunk, GenerationChunk]]
] = []
def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[ChatGenerationChunk, GenerationChunk]] = None,
**kwargs: Any,
) -> Any:
self._captured_tokens.append(token)
self._captured_chunks.append(chunk)
json_schema = {
"title": "Person",
"description": "Identifying information about a person.",
"type": "object",
"properties": {
"name": {
"title": "Name",
"description": "The person's name",
"type": "string",
},
"age": {
"title": "Age",
"description": "The person's age",
"type": "integer",
},
"fav_food": {
"title": "Fav Food",
"description": "The person's favorite food",
"type": "string",
},
},
"required": ["name", "age"],
}
callback_handler = MyCustomAsyncHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatOpenAI(
max_tokens=10,
n=1,
callback_manager=callback_manager,
streaming=True,
)
prompt_msgs = [
SystemMessage(
content="You are a world class algorithm for "
"extracting information in structured formats."
),
HumanMessage(
content="Use the given format to extract "
"information from the following input:"
),
HumanMessagePromptTemplate.from_template("{input}"),
HumanMessage(content="Tips: Make sure to answer in the correct format"),
]
prompt = ChatPromptTemplate(messages=prompt_msgs)
function: Any = {
"name": "output_formatter",
"description": (
"Output formatter. Should always be used to format your response to the"
" user."
),
"parameters": json_schema,
}
chain = create_openai_fn_chain(
[function],
chat,
prompt,
output_parser=None,
)
message = HumanMessage(content="Sally is 13 years old")
response = await chain.agenerate([{"input": message}])
assert isinstance(response, LLMResult)
assert len(response.generations) == 1
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
assert len(callback_handler._captured_tokens) > 0
assert len(callback_handler._captured_chunks) > 0
assert all([chunk is not None for chunk in callback_handler._captured_chunks])
def test_chat_openai_extra_kwargs() -> None:
"""Test extra kwargs to chat openai."""
# Check that foo is saved in extra_kwargs.