mirror of https://github.com/hwchase17/langchain.git (synced 2025-05-21 15:07:35 +00:00)
parent 2ab04a4e32
commit 1960ac8d25
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
     from langchain.schema.agent import AgentAction, AgentFinish
     from langchain.schema.document import Document
     from langchain.schema.messages import BaseMessage
-    from langchain.schema.output import LLMResult
+    from langchain.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult


 class RetrieverManagerMixin:
@@ -44,11 +44,18 @@ class LLMManagerMixin:
         self,
         token: str,
         *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        """Run on new LLM token. Only available when streaming is enabled."""
+        """Run on new LLM token. Only available when streaming is enabled.
+
+        Args:
+            token (str): The new token.
+            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
+                containing content and other information.
+        """

     def on_llm_end(
         self,
@@ -316,6 +323,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         self,
         token: str,
         *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         tags: Optional[List[str]] = None,
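For context, a minimal sketch of a handler built on the updated base signature (not part of this commit; the class name and attributes are illustrative):

from typing import Any, List, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.output import ChatGenerationChunk, GenerationChunk


class ChunkCollectingHandler(BaseCallbackHandler):
    """Illustrative handler that records both tokens and chunk objects."""

    def __init__(self) -> None:
        self.tokens: List[str] = []
        self.chunks: List[Optional[Union[GenerationChunk, ChatGenerationChunk]]] = []

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        **kwargs: Any,
    ) -> None:
        # `token` is still the plain string; `chunk`, when a model supplies it,
        # also carries generation_info and (for chat models) the message delta.
        self.tokens.append(token)
        self.chunks.append(chunk)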
@@ -49,6 +49,7 @@ from langchain.schema import (
     LLMResult,
 )
 from langchain.schema.messages import BaseMessage, get_buffer_string
+from langchain.schema.output import ChatGenerationChunk, GenerationChunk

 if TYPE_CHECKING:
     from langsmith import Client as LangSmithClient
@@ -592,6 +593,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
     def on_llm_new_token(
         self,
         token: str,
+        *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM generates a new token.
@@ -607,6 +610,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
             run_id=self.run_id,
             parent_run_id=self.parent_run_id,
             tags=self.tags,
+            chunk=chunk,
             **kwargs,
         )

@@ -655,6 +659,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
     async def on_llm_new_token(
         self,
         token: str,
+        *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM generates a new token.
@@ -667,6 +673,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
             "on_llm_new_token",
             "ignore_llm",
             token,
+            chunk=chunk,
             run_id=self.run_id,
             parent_run_id=self.parent_run_id,
             tags=self.tags,
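The same pattern the rest of this commit applies to the built-in integrations can be used from a custom LLM: build the chunk, yield it, then hand the identical object to the run manager, which now forwards it to every handler. A hedged sketch (EchoLLM and its tokenization are illustrative, not part of langchain):

from typing import Any, Iterator, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk


class EchoLLM(LLM):
    """Illustrative LLM that streams the prompt back word by word."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return prompt

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        for word in prompt.split():
            chunk = GenerationChunk(text=word + " ")
            yield chunk
            if run_manager:
                # New in this change: the run manager accepts the chunk keyword
                # and passes it through to each handler's on_llm_new_token.
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)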
@@ -13,7 +13,12 @@ from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.tracers.schemas import Run
 from langchain.load.dump import dumpd
 from langchain.schema.document import Document
-from langchain.schema.output import ChatGeneration, LLMResult
+from langchain.schema.output import (
+    ChatGeneration,
+    ChatGenerationChunk,
+    GenerationChunk,
+    LLMResult,
+)

 logger = logging.getLogger(__name__)

@@ -123,6 +128,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         self,
         token: str,
         *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
@@ -135,11 +141,14 @@ class BaseTracer(BaseCallbackHandler, ABC):
         llm_run = self.run_map.get(run_id_)
         if llm_run is None or llm_run.run_type != "llm":
             raise TracerException(f"No LLM Run found to be traced for {run_id}")
+        event_kwargs: Dict[str, Any] = {"token": token}
+        if chunk:
+            event_kwargs["chunk"] = chunk
         llm_run.events.append(
             {
                 "name": "new_token",
                 "time": datetime.utcnow(),
-                "kwargs": {"token": token},
+                "kwargs": event_kwargs,
             },
         )

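Because the tracer now stores the chunk alongside the token, streamed runs can be inspected after the fact. A hedged sketch of reading those events back (the helper and variable names are illustrative):

from typing import Any, List

from langchain.callbacks.tracers.schemas import Run


def streamed_chunks(llm_run: Run) -> List[Any]:
    # After this change, each "new_token" event's kwargs may include the chunk
    # object in addition to the plain token string.
    return [
        event["kwargs"].get("chunk")
        for event in llm_run.events
        if event["name"] == "new_token"
    ]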
@@ -318,7 +318,7 @@ class ChatOpenAI(BaseChatModel):
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
             if run_manager:
-                run_manager.on_llm_new_token(chunk.content)
+                run_manager.on_llm_new_token(chunk.content, chunk=chunk)

     def _generate(
         self,
@@ -398,7 +398,7 @@ class ChatOpenAI(BaseChatModel):
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
             if run_manager:
-                await run_manager.on_llm_new_token(chunk.content)
+                await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)

     async def _agenerate(
         self,
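For chat streaming the chunk is what makes function-calling deltas visible to handlers: the text delta is often empty while the arguments arrive under additional_kwargs (this is what the new integration test below exercises). A hedged sketch of a handler relying on that; the class name is illustrative, and getattr is used defensively rather than assuming an exact chunk type:

from typing import Any, List, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.output import ChatGenerationChunk, GenerationChunk


class FunctionCallStreamHandler(BaseCallbackHandler):
    """Illustrative handler that collects streamed function-call deltas."""

    def __init__(self) -> None:
        self.function_call_parts: List[dict] = []

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        **kwargs: Any,
    ) -> None:
        # For chat models the interesting part of the delta is the message,
        # not the (often empty) text content.
        message = getattr(chunk, "message", chunk)
        additional_kwargs = getattr(message, "additional_kwargs", {}) or {}
        delta = additional_kwargs.get("function_call")
        if delta:
            self.function_call_parts.append(delta)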
@@ -289,9 +289,10 @@ class Anthropic(LLM, _AnthropicCommon):
         for token in self.client.completions.create(
             prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
         ):
-            yield GenerationChunk(text=token.completion)
+            chunk = GenerationChunk(text=token.completion)
+            yield chunk
             if run_manager:
-                run_manager.on_llm_new_token(token.completion)
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)

     async def _astream(
         self,
@@ -324,9 +325,10 @@ class Anthropic(LLM, _AnthropicCommon):
             stream=True,
             **params,
         ):
-            yield GenerationChunk(text=token.completion)
+            chunk = GenerationChunk(text=token.completion)
+            yield chunk
             if run_manager:
-                await run_manager.on_llm_new_token(token.completion)
+                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)

     def get_num_tokens(self, text: str) -> int:
         """Calculate number of tokens."""
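A hedged usage sketch of the Anthropic streaming path above, reusing the ChunkCollectingHandler sketched earlier (an ANTHROPIC_API_KEY in the environment and the anthropic package are assumed; none of this is part of the commit):

from langchain.llms import Anthropic

handler = ChunkCollectingHandler()
llm = Anthropic(streaming=True, callbacks=[handler])
llm("Tell me a short joke")
# With streaming enabled, each streamed token should trigger on_llm_new_token,
# so handler.tokens collects the raw completion strings and handler.chunks the
# matching GenerationChunk objects.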
@@ -297,6 +297,7 @@ class BaseOpenAI(BaseLLM):
             if run_manager:
                 run_manager.on_llm_new_token(
                     chunk.text,
+                    chunk=chunk,
                     verbose=self.verbose,
                     logprobs=chunk.generation_info["logprobs"]
                     if chunk.generation_info
@@ -320,6 +321,7 @@ class BaseOpenAI(BaseLLM):
             if run_manager:
                 await run_manager.on_llm_new_token(
                     chunk.text,
+                    chunk=chunk,
                     verbose=self.verbose,
                     logprobs=chunk.generation_info["logprobs"]
                     if chunk.generation_info
@@ -825,9 +827,10 @@ class OpenAIChat(BaseLLM):
             self, messages=messages, run_manager=run_manager, **params
         ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
-            yield GenerationChunk(text=token)
+            chunk = GenerationChunk(text=token)
+            yield chunk
             if run_manager:
-                run_manager.on_llm_new_token(token)
+                run_manager.on_llm_new_token(token, chunk=chunk)

     async def _astream(
         self,
@@ -842,9 +845,10 @@ class OpenAIChat(BaseLLM):
             self, messages=messages, run_manager=run_manager, **params
         ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
-            yield GenerationChunk(text=token)
+            chunk = GenerationChunk(text=token)
+            yield chunk
             if run_manager:
-                await run_manager.on_llm_new_token(token)
+                await run_manager.on_llm_new_token(token, chunk=chunk)

     def _generate(
         self,
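BaseOpenAI keeps forwarding verbose and logprobs as extra keywords next to the new chunk argument, so a handler can consume both. A hedged sketch (the class name is illustrative):

from typing import Any, List, Optional, Tuple, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.output import ChatGenerationChunk, GenerationChunk


class LogprobAwareHandler(BaseCallbackHandler):
    """Illustrative handler pairing each token with its logprobs, if any."""

    def __init__(self) -> None:
        self.scored_tokens: List[Tuple[str, Any]] = []

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        **kwargs: Any,
    ) -> None:
        # BaseOpenAI passes logprobs (taken from chunk.generation_info) as a
        # plain keyword argument alongside the chunk.
        self.scored_tokens.append((token, kwargs.get("logprobs")))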
@@ -1,18 +1,22 @@
 """Test ChatOpenAI wrapper."""


-from typing import Any
+from typing import Any, List, Optional, Union

 import pytest

+from langchain.callbacks.base import AsyncCallbackHandler
 from langchain.callbacks.manager import CallbackManager
 from langchain.chains.openai_functions import (
     create_openai_fn_chain,
 )
 from langchain.chat_models.openai import ChatOpenAI
 from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 from langchain.schema import (
     ChatGeneration,
     ChatResult,
     LLMResult,
 )
 from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain.schema.output import ChatGenerationChunk, GenerationChunk
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


@@ -191,6 +195,108 @@ async def test_async_chat_openai_streaming() -> None:
             assert generation.text == generation.message.content


+@pytest.mark.scheduled
+@pytest.mark.asyncio
+async def test_async_chat_openai_streaming_with_function() -> None:
+    """Test ChatOpenAI wrapper with multiple completions."""
+
+    class MyCustomAsyncHandler(AsyncCallbackHandler):
+        def __init__(self) -> None:
+            super().__init__()
+            self._captured_tokens: List[str] = []
+            self._captured_chunks: List[
+                Optional[Union[ChatGenerationChunk, GenerationChunk]]
+            ] = []
+
+        def on_llm_new_token(
+            self,
+            token: str,
+            *,
+            chunk: Optional[Union[ChatGenerationChunk, GenerationChunk]] = None,
+            **kwargs: Any,
+        ) -> Any:
+            self._captured_tokens.append(token)
+            self._captured_chunks.append(chunk)
+
+    json_schema = {
+        "title": "Person",
+        "description": "Identifying information about a person.",
+        "type": "object",
+        "properties": {
+            "name": {
+                "title": "Name",
+                "description": "The person's name",
+                "type": "string",
+            },
+            "age": {
+                "title": "Age",
+                "description": "The person's age",
+                "type": "integer",
+            },
+            "fav_food": {
+                "title": "Fav Food",
+                "description": "The person's favorite food",
+                "type": "string",
+            },
+        },
+        "required": ["name", "age"],
+    }
+
+    callback_handler = MyCustomAsyncHandler()
+    callback_manager = CallbackManager([callback_handler])
+
+    chat = ChatOpenAI(
+        max_tokens=10,
+        n=1,
+        callback_manager=callback_manager,
+        streaming=True,
+    )
+
+    prompt_msgs = [
+        SystemMessage(
+            content="You are a world class algorithm for "
+            "extracting information in structured formats."
+        ),
+        HumanMessage(
+            content="Use the given format to extract "
+            "information from the following input:"
+        ),
+        HumanMessagePromptTemplate.from_template("{input}"),
+        HumanMessage(content="Tips: Make sure to answer in the correct format"),
+    ]
+    prompt = ChatPromptTemplate(messages=prompt_msgs)
+
+    function: Any = {
+        "name": "output_formatter",
+        "description": (
+            "Output formatter. Should always be used to format your response to the"
+            " user."
+        ),
+        "parameters": json_schema,
+    }
+    chain = create_openai_fn_chain(
+        [function],
+        chat,
+        prompt,
+        output_parser=None,
+    )
+
+    message = HumanMessage(content="Sally is 13 years old")
+    response = await chain.agenerate([{"input": message}])
+
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 1
+    for generations in response.generations:
+        assert len(generations) == 1
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+            assert generation.text == generation.message.content
+    assert len(callback_handler._captured_tokens) > 0
+    assert len(callback_handler._captured_chunks) > 0
+    assert all([chunk is not None for chunk in callback_handler._captured_chunks])
+
+
 def test_chat_openai_extra_kwargs() -> None:
     """Test extra kwargs to chat openai."""
     # Check that foo is saved in extra_kwargs.