mirror of https://github.com/hwchase17/langchain.git (synced 2025-05-21 23:17:48 +00:00)
parent 2ab04a4e32
commit 1960ac8d25
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
     from langchain.schema.agent import AgentAction, AgentFinish
     from langchain.schema.document import Document
     from langchain.schema.messages import BaseMessage
-    from langchain.schema.output import LLMResult
+    from langchain.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult
 
 
 class RetrieverManagerMixin:
@@ -44,11 +44,18 @@ class LLMManagerMixin:
         self,
         token: str,
         *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        """Run on new LLM token. Only available when streaming is enabled."""
+        """Run on new LLM token. Only available when streaming is enabled.
+
+        Args:
+            token (str): The new token.
+            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
+            containing content and other information.
+        """
 
     def on_llm_end(
         self,
@@ -316,6 +323,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         self,
         token: str,
         *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         tags: Optional[List[str]] = None,
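Note: the hunks above add the new optional `chunk` keyword to `on_llm_new_token`. As a minimal sketch, assuming a langchain build that includes this commit, a handler that wants the richer chunk objects only needs to accept the new keyword. The class name and attribute names below are illustrative, not part of the commit:

from typing import Any, List, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.output import ChatGenerationChunk, GenerationChunk


class ChunkCollectingHandler(BaseCallbackHandler):
    """Illustrative handler that records streamed tokens and their chunks."""

    def __init__(self) -> None:
        self.tokens: List[str] = []
        self.chunks: List[Optional[Union[GenerationChunk, ChatGenerationChunk]]] = []

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        **kwargs: Any,
    ) -> Any:
        self.tokens.append(token)
        # chunk is None for integrations that have not been updated to pass it
        self.chunks.append(chunk)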
@@ -49,6 +49,7 @@ from langchain.schema import (
     LLMResult,
 )
 from langchain.schema.messages import BaseMessage, get_buffer_string
+from langchain.schema.output import ChatGenerationChunk, GenerationChunk
 
 if TYPE_CHECKING:
     from langsmith import Client as LangSmithClient
@@ -592,6 +593,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
     def on_llm_new_token(
         self,
         token: str,
+        *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM generates a new token.
@@ -607,6 +610,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
             run_id=self.run_id,
             parent_run_id=self.parent_run_id,
             tags=self.tags,
+            chunk=chunk,
             **kwargs,
         )
 
@@ -655,6 +659,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
     async def on_llm_new_token(
         self,
         token: str,
+        *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        **kwargs: Any,
     ) -> None:
         """Run when LLM generates a new token.
@@ -667,6 +673,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
             "on_llm_new_token",
             "ignore_llm",
             token,
+            chunk=chunk,
             run_id=self.run_id,
             parent_run_id=self.parent_run_id,
             tags=self.tags,
@@ -13,7 +13,12 @@ from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.tracers.schemas import Run
 from langchain.load.dump import dumpd
 from langchain.schema.document import Document
-from langchain.schema.output import ChatGeneration, LLMResult
+from langchain.schema.output import (
+    ChatGeneration,
+    ChatGenerationChunk,
+    GenerationChunk,
+    LLMResult,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -123,6 +128,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         self,
         token: str,
         *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
@@ -135,11 +141,14 @@ class BaseTracer(BaseCallbackHandler, ABC):
         llm_run = self.run_map.get(run_id_)
         if llm_run is None or llm_run.run_type != "llm":
             raise TracerException(f"No LLM Run found to be traced for {run_id}")
+        event_kwargs: Dict[str, Any] = {"token": token}
+        if chunk:
+            event_kwargs["chunk"] = chunk
         llm_run.events.append(
             {
                 "name": "new_token",
                 "time": datetime.utcnow(),
-                "kwargs": {"token": token},
+                "kwargs": event_kwargs,
             },
         )
 
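For orientation, the tracer hunks above mean that each streamed token now records the chunk alongside the raw text whenever one is supplied. A rough sketch of the resulting event dict, assuming the caller passed a chunk (the token text here is a made-up example value):

from datetime import datetime

from langchain.schema.output import GenerationChunk

example_chunk = GenerationChunk(text="Hello")
new_token_event = {
    "name": "new_token",
    "time": datetime.utcnow(),
    # "chunk" is only present when the integration passed one; otherwise the
    # kwargs contain just the token, exactly as before this change.
    "kwargs": {"token": "Hello", "chunk": example_chunk},
}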
@@ -318,7 +318,7 @@ class ChatOpenAI(BaseChatModel):
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
             if run_manager:
-                run_manager.on_llm_new_token(chunk.content)
+                run_manager.on_llm_new_token(chunk.content, chunk=chunk)
 
     def _generate(
         self,
@@ -398,7 +398,7 @@ class ChatOpenAI(BaseChatModel):
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
             if run_manager:
-                await run_manager.on_llm_new_token(chunk.content)
+                await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
 
     async def _agenerate(
         self,
@@ -289,9 +289,10 @@ class Anthropic(LLM, _AnthropicCommon):
         for token in self.client.completions.create(
             prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
         ):
-            yield GenerationChunk(text=token.completion)
+            chunk = GenerationChunk(text=token.completion)
+            yield chunk
             if run_manager:
-                run_manager.on_llm_new_token(token.completion)
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
 
     async def _astream(
         self,
@@ -324,9 +325,10 @@ class Anthropic(LLM, _AnthropicCommon):
             stream=True,
             **params,
         ):
-            yield GenerationChunk(text=token.completion)
+            chunk = GenerationChunk(text=token.completion)
+            yield chunk
             if run_manager:
-                await run_manager.on_llm_new_token(token.completion)
+                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
 
     def get_num_tokens(self, text: str) -> int:
         """Calculate number of tokens."""
@@ -297,6 +297,7 @@ class BaseOpenAI(BaseLLM):
             if run_manager:
                 run_manager.on_llm_new_token(
                     chunk.text,
+                    chunk=chunk,
                     verbose=self.verbose,
                     logprobs=chunk.generation_info["logprobs"]
                     if chunk.generation_info
@@ -320,6 +321,7 @@ class BaseOpenAI(BaseLLM):
             if run_manager:
                 await run_manager.on_llm_new_token(
                     chunk.text,
+                    chunk=chunk,
                     verbose=self.verbose,
                     logprobs=chunk.generation_info["logprobs"]
                     if chunk.generation_info
@@ -825,9 +827,10 @@ class OpenAIChat(BaseLLM):
             self, messages=messages, run_manager=run_manager, **params
         ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
-            yield GenerationChunk(text=token)
+            chunk = GenerationChunk(text=token)
+            yield chunk
             if run_manager:
-                run_manager.on_llm_new_token(token)
+                run_manager.on_llm_new_token(token, chunk=chunk)
 
     async def _astream(
         self,
@@ -842,9 +845,10 @@ class OpenAIChat(BaseLLM):
             self, messages=messages, run_manager=run_manager, **params
         ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
-            yield GenerationChunk(text=token)
+            chunk = GenerationChunk(text=token)
+            yield chunk
             if run_manager:
-                await run_manager.on_llm_new_token(token)
+                await run_manager.on_llm_new_token(token, chunk=chunk)
 
     def _generate(
         self,
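The provider hunks above all follow one pattern: build the GenerationChunk first, yield it, then hand both the raw token text and the chunk to the run manager. A hedged sketch of that pattern in a hypothetical custom LLM's _stream method (the EchoLLM class and its behaviour are illustrative only, not part of the commit):

from typing import Any, Iterator, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk


class EchoLLM(LLM):
    """Toy LLM that streams the prompt back one word at a time (illustrative only)."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return prompt

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        for word in prompt.split():
            # Construct the chunk once, yield it to the caller, and report the
            # same object to the callbacks so handlers see both token and chunk.
            chunk = GenerationChunk(text=word)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)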
@@ -1,18 +1,22 @@
 """Test ChatOpenAI wrapper."""
 
-from typing import Any
+from typing import Any, List, Optional, Union
 
 import pytest
 
+from langchain.callbacks.base import AsyncCallbackHandler
 from langchain.callbacks.manager import CallbackManager
+from langchain.chains.openai_functions import (
+    create_openai_fn_chain,
+)
 from langchain.chat_models.openai import ChatOpenAI
+from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 from langchain.schema import (
     ChatGeneration,
     ChatResult,
     LLMResult,
 )
 from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain.schema.output import ChatGenerationChunk, GenerationChunk
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 
@@ -191,6 +195,108 @@ async def test_async_chat_openai_streaming() -> None:
         assert generation.text == generation.message.content
 
 
+@pytest.mark.scheduled
+@pytest.mark.asyncio
+async def test_async_chat_openai_streaming_with_function() -> None:
+    """Test ChatOpenAI wrapper with multiple completions."""
+
+    class MyCustomAsyncHandler(AsyncCallbackHandler):
+        def __init__(self) -> None:
+            super().__init__()
+            self._captured_tokens: List[str] = []
+            self._captured_chunks: List[
+                Optional[Union[ChatGenerationChunk, GenerationChunk]]
+            ] = []
+
+        def on_llm_new_token(
+            self,
+            token: str,
+            *,
+            chunk: Optional[Union[ChatGenerationChunk, GenerationChunk]] = None,
+            **kwargs: Any,
+        ) -> Any:
+            self._captured_tokens.append(token)
+            self._captured_chunks.append(chunk)
+
+    json_schema = {
+        "title": "Person",
+        "description": "Identifying information about a person.",
+        "type": "object",
+        "properties": {
+            "name": {
+                "title": "Name",
+                "description": "The person's name",
+                "type": "string",
+            },
+            "age": {
+                "title": "Age",
+                "description": "The person's age",
+                "type": "integer",
+            },
+            "fav_food": {
+                "title": "Fav Food",
+                "description": "The person's favorite food",
+                "type": "string",
+            },
+        },
+        "required": ["name", "age"],
+    }
+
+    callback_handler = MyCustomAsyncHandler()
+    callback_manager = CallbackManager([callback_handler])
+
+    chat = ChatOpenAI(
+        max_tokens=10,
+        n=1,
+        callback_manager=callback_manager,
+        streaming=True,
+    )
+
+    prompt_msgs = [
+        SystemMessage(
+            content="You are a world class algorithm for "
+            "extracting information in structured formats."
+        ),
+        HumanMessage(
+            content="Use the given format to extract "
+            "information from the following input:"
+        ),
+        HumanMessagePromptTemplate.from_template("{input}"),
+        HumanMessage(content="Tips: Make sure to answer in the correct format"),
+    ]
+    prompt = ChatPromptTemplate(messages=prompt_msgs)
+
+    function: Any = {
+        "name": "output_formatter",
+        "description": (
+            "Output formatter. Should always be used to format your response to the"
+            " user."
+        ),
+        "parameters": json_schema,
+    }
+    chain = create_openai_fn_chain(
+        [function],
+        chat,
+        prompt,
+        output_parser=None,
+    )
+
+    message = HumanMessage(content="Sally is 13 years old")
+    response = await chain.agenerate([{"input": message}])
+
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 1
+    for generations in response.generations:
+        assert len(generations) == 1
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+            assert generation.text == generation.message.content
+    assert len(callback_handler._captured_tokens) > 0
+    assert len(callback_handler._captured_chunks) > 0
+    assert all([chunk is not None for chunk in callback_handler._captured_chunks])
+
+
 def test_chat_openai_extra_kwargs() -> None:
     """Test extra kwargs to chat openai."""
     # Check that foo is saved in extra_kwargs.
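For completeness, a hedged synchronous usage sketch that exercises the new argument end to end, reusing the ChunkCollectingHandler sketched earlier. It assumes a langchain build with this commit, a valid OpenAI API key, and network access; the prompt and max_tokens are placeholders:

from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema.messages import HumanMessage

handler = ChunkCollectingHandler()  # defined in the earlier sketch
chat = ChatOpenAI(
    max_tokens=10,
    streaming=True,
    callback_manager=CallbackManager([handler]),
)
chat([HumanMessage(content="Say hello")])
# After this commit, ChatOpenAI passes a ChatGenerationChunk for every streamed token.
assert all(c is not None for c in handler.chunks)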