mirror of https://github.com/hwchase17/langchain.git
fix chat models and tests
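Previously a single run was started for an entire batch of message lists, so a tracer could not tell the conversations in one `generate` call apart. `CallbackManager.on_chat_model_start` and its async counterpart now start one run per message list, each with its own `run_id`, and return a list of run managers. `BaseChatModel.generate` and `agenerate` flatten the batched `LLMResult` so that every run is closed with just its own output, and `output.run` becomes a list of `RunInfo`, one per run. The callback manager and tracer tests are updated to match.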
@@ -674,26 +674,33 @@ class CallbackManager(BaseCallbackManager):
         messages: List[List[BaseMessage]],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> CallbackManagerForLLMRun:
+    ) -> List[CallbackManagerForLLMRun]:
         """Run when LLM starts running."""
-        if run_id is None:
-            run_id = uuid4()
-        _handle_event(
-            self.handlers,
-            "on_chat_model_start",
-            "ignore_chat_model",
-            serialized,
-            messages,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            **kwargs,
-        )
-
-        # Re-use the LLM Run Manager since the outputs are treated
-        # the same for now
-        return CallbackManagerForLLMRun(
-            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
-        )
+        managers = []
+        for message in messages:
+            run_id_ = uuid4()
+            _handle_event(
+                self.handlers,
+                "on_chat_model_start",
+                "ignore_chat_model",
+                serialized,
+                [message],
+                run_id=run_id_,
+                parent_run_id=self.parent_run_id,
+                **kwargs,
+            )
+
+            managers.append(
+                CallbackManagerForLLMRun(
+                    run_id_,
+                    self.handlers,
+                    self.inheritable_handlers,
+                    self.parent_run_id,
+                )
+            )
+
+        return managers
 
     def on_chain_start(
         self,
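A minimal sketch of the new synchronous calling convention (the empty handler list and message contents are illustrative, not part of this diff):

    from langchain.callbacks.manager import CallbackManager
    from langchain.schema import HumanMessage, LLMResult

    # Any handlers (e.g. tracers) could be registered here instead.
    manager = CallbackManager(handlers=[])

    # One CallbackManagerForLLMRun per message list, each with its own run_id.
    run_managers = manager.on_chat_model_start(
        {"name": "chat_model"},
        [[HumanMessage(content="first")], [HumanMessage(content="second")]],
    )
    assert len(run_managers) == 2
    for run_manager in run_managers:
        run_manager.on_llm_end(LLMResult(generations=[[]]))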
@@ -815,23 +822,36 @@ class AsyncCallbackManager(BaseCallbackManager):
         run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        if run_id is None:
-            run_id = uuid4()
+        tasks = []
+        managers = []
 
-        await _ahandle_event(
-            self.handlers,
-            "on_chat_model_start",
-            "ignore_chat_model",
-            serialized,
-            messages,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            **kwargs,
-        )
+        for message in messages:
+            run_id_ = uuid4()
 
-        return AsyncCallbackManagerForLLMRun(
-            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
-        )
+            tasks.append(
+                _ahandle_event(
+                    self.handlers,
+                    "on_chat_model_start",
+                    "ignore_chat_model",
+                    serialized,
+                    [message],
+                    run_id=run_id_,
+                    parent_run_id=self.parent_run_id,
+                    **kwargs,
+                )
+            )
+
+            managers.append(
+                AsyncCallbackManagerForLLMRun(
+                    run_id_,
+                    self.handlers,
+                    self.inheritable_handlers,
+                    self.parent_run_id,
+                )
+            )
+
+        await asyncio.gather(*tasks)
+        return managers
 
     async def on_chain_start(
         self,
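The async manager follows the same shape, but collects the start events and dispatches them concurrently with `asyncio.gather` before returning the managers. A usage sketch under the same assumptions as above:

    import asyncio

    from langchain.callbacks.manager import AsyncCallbackManager
    from langchain.schema import HumanMessage, LLMResult

    async def main() -> None:
        manager = AsyncCallbackManager(handlers=[])
        # All on_chat_model_start events are awaited together via asyncio.gather.
        run_managers = await manager.on_chat_model_start(
            {"name": "chat_model"}, [[HumanMessage(content="hi")]]
        )
        for run_manager in run_managers:
            await run_manager.on_llm_end(LLMResult(generations=[[]]))

    asyncio.run(main())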
@@ -73,7 +73,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         callback_manager = CallbackManager.configure(
             callbacks, self.callbacks, self.verbose
         )
-        run_manager = callback_manager.on_chat_model_start(
+        run_managers = callback_manager.on_chat_model_start(
             {"name": self.__class__.__name__}, messages, invocation_params=params
         )
@@ -82,20 +82,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
         )
         try:
             results = [
-                self._generate(m, stop=stop, run_manager=run_manager)
+                self._generate(m, stop=stop, run_manager=run_managers[0])
                 if new_arg_supported
                 else self._generate(m, stop=stop)
                 for m in messages
             ]
         except (KeyboardInterrupt, Exception) as e:
-            run_manager.on_llm_error(e)
+            for run_manager in run_managers:
+                run_manager.on_llm_error(e)
             raise e
         llm_output = self._combine_llm_outputs([res.llm_output for res in results])
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
-        run_manager.on_llm_end(output)
-        if run_manager:
-            output.run = RunInfo(run_id=run_manager.run_id)
+        flattened_outputs = output.flatten()
+        for manager, flattened_output in zip(run_managers, flattened_outputs):
+            manager.on_llm_end(flattened_output)
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
         return output
 
     async def agenerate(
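`LLMResult.flatten()` is what lets each run end with only its own generations: as its use with `zip(run_managers, flattened_outputs)` implies, it splits a batched result into one single-conversation `LLMResult` per input. An illustrative check (the texts are made up):

    from langchain.schema import Generation, LLMResult

    output = LLMResult(
        generations=[[Generation(text="answer one")], [Generation(text="answer two")]]
    )
    # One flattened result per input, paired with its run manager above.
    flattened = output.flatten()
    assert len(flattened) == 2
    assert flattened[0].generations == [[Generation(text="answer one")]]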
@@ -111,7 +116,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         callback_manager = AsyncCallbackManager.configure(
             callbacks, self.callbacks, self.verbose
         )
-        run_manager = await callback_manager.on_chat_model_start(
+        run_managers = await callback_manager.on_chat_model_start(
             {"name": self.__class__.__name__}, messages, invocation_params=params
         )
@@ -121,21 +126,33 @@ class BaseChatModel(BaseLanguageModel, ABC):
         try:
             results = await asyncio.gather(
                 *[
-                    self._agenerate(m, stop=stop, run_manager=run_manager)
+                    self._agenerate(m, stop=stop, run_manager=run_managers[0])
                     if new_arg_supported
                     else self._agenerate(m, stop=stop)
                     for m in messages
                 ]
             )
         except (KeyboardInterrupt, Exception) as e:
-            await run_manager.on_llm_error(e)
+            await asyncio.gather(
+                *[run_manager.on_llm_error(e) for run_manager in run_managers]
+            )
             raise e
         llm_output = self._combine_llm_outputs([res.llm_output for res in results])
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
-        await run_manager.on_llm_end(output)
-        if run_manager:
-            output.run = RunInfo(run_id=run_manager.run_id)
+        flattened_outputs = output.flatten()
+        await asyncio.gather(
+            *[
+                run_manager.on_llm_end(flattened_output)
+                for run_manager, flattened_output in zip(
+                    run_managers, flattened_outputs
+                )
+            ]
+        )
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
         return output
 
     def generate_prompt(
@@ -18,11 +18,12 @@ def _test_callback_manager(
     manager: CallbackManager, *handlers: BaseFakeCallbackHandler
 ) -> None:
     """Test the CallbackManager."""
-    run_manager = manager.on_llm_start({}, [])
-    run_manager.on_llm_end(LLMResult(generations=[]))
-    run_manager.on_llm_error(Exception())
-    run_manager.on_llm_new_token("foo")
-    run_manager.on_text("foo")
+    run_managers = manager.on_llm_start({}, ["prompt"])
+    for run_manager in run_managers:
+        run_manager.on_llm_end(LLMResult(generations=[]))
+        run_manager.on_llm_error(Exception())
+        run_manager.on_llm_new_token("foo")
+        run_manager.on_text("foo")
 
     run_manager_chain = manager.on_chain_start({"name": "foo"}, {})
     run_manager_chain.on_chain_end({})
@@ -42,11 +43,12 @@ async def _test_callback_manager_async(
     manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
 ) -> None:
     """Test the CallbackManager."""
-    run_manager = await manager.on_llm_start({}, [])
-    await run_manager.on_llm_end(LLMResult(generations=[]))
-    await run_manager.on_llm_error(Exception())
-    await run_manager.on_llm_new_token("foo")
-    await run_manager.on_text("foo")
+    run_managers = await manager.on_llm_start({}, ["prompt"])
+    for run_manager in run_managers:
+        await run_manager.on_llm_end(LLMResult(generations=[]))
+        await run_manager.on_llm_error(Exception())
+        await run_manager.on_llm_new_token("foo")
+        await run_manager.on_text("foo")
 
     run_manager_chain = await manager.on_chain_start({"name": "foo"}, {})
     await run_manager_chain.on_chain_end({})
@@ -95,9 +97,10 @@ def test_ignore_llm() -> None:
     handler1 = FakeCallbackHandler(ignore_llm_=True)
     handler2 = FakeCallbackHandler()
     manager = CallbackManager(handlers=[handler1, handler2])
-    run_manager = manager.on_llm_start({}, [])
-    run_manager.on_llm_end(LLMResult(generations=[]))
-    run_manager.on_llm_error(Exception())
+    run_managers = manager.on_llm_start({}, ["prompt"])
+    for run_manager in run_managers:
+        run_manager.on_llm_end(LLMResult(generations=[]))
+        run_manager.on_llm_error(Exception())
     assert handler1.starts == 0
     assert handler1.ends == 0
     assert handler1.errors == 0
 
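The same one-run-per-input convention applies to plain LLMs, which is why these tests now pass a real prompt and loop: `on_llm_start` returns one run manager per prompt string. An illustrative check (prompt values are made up):

    from langchain.callbacks.manager import CallbackManager
    from langchain.schema import LLMResult

    manager = CallbackManager(handlers=[])
    run_managers = manager.on_llm_start({}, ["prompt one", "prompt two"])
    assert len(run_managers) == 2
    for run_manager in run_managers:
        run_manager.on_llm_end(LLMResult(generations=[]))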
@@ -11,7 +11,7 @@ from freezegun import freeze_time
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.tracers.base import BaseTracer, TracerException
 from langchain.callbacks.tracers.schemas import Run
-from langchain.schema import LLMResult
+from langchain.schema import HumanMessage, LLMResult
 
 
 class FakeTracer(BaseTracer):
@@ -55,9 +55,13 @@ def test_tracer_llm_run() -> None:
 @freeze_time("2023-01-01")
 def test_tracer_chat_model_run() -> None:
     """Test tracer on a Chat Model run."""
-    uuid = uuid4()
+    tracer = FakeTracer()
+    manager = CallbackManager(handlers=[tracer])
+    run_managers = manager.on_chat_model_start(
+        serialized={"name": "chat_model"}, messages=[[HumanMessage(content="")]]
+    )
     compare_run = Run(
-        id=str(uuid),
+        id=str(run_managers[0].run_id),
         name="chat_model",
         start_time=datetime.utcnow(),
         end_time=datetime.utcnow(),
@@ -65,17 +69,13 @@ def test_tracer_chat_model_run() -> None:
         execution_order=1,
         child_execution_order=1,
         serialized={"name": "chat_model"},
-        inputs=dict(prompts=[""]),
+        inputs=dict(prompts=["Human: "]),
         outputs=LLMResult(generations=[[]]),
         error=None,
         run_type="llm",
     )
-    tracer = FakeTracer()
-    manager = CallbackManager(handlers=[tracer])
-    run_manager = manager.on_chat_model_start(
-        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
-    )
-    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
+    for run_manager in run_managers:
+        run_manager.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]
 
@@ -18,7 +18,7 @@ from langchain.callbacks.tracers.langchain_v1 import (
     TracerSessionV1,
 )
 from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
-from langchain.schema import LLMResult
+from langchain.schema import HumanMessage, LLMResult
 
 TEST_SESSION_ID = 2023
 
@@ -124,9 +124,15 @@ def test_tracer_llm_run() -> None:
 @freeze_time("2023-01-01")
 def test_tracer_chat_model_run() -> None:
     """Test tracer on a Chat Model run."""
-    uuid = uuid4()
+    tracer = FakeTracer()
+
+    tracer.new_session()
+    manager = CallbackManager(handlers=[tracer])
+    run_managers = manager.on_chat_model_start(
+        serialized={"name": "chat_model"}, messages=[[HumanMessage(content="")]]
+    )
     compare_run = LLMRun(
-        uuid=str(uuid),
+        uuid=str(run_managers[0].run_id),
         parent_uuid=None,
         start_time=datetime.utcnow(),
         end_time=datetime.utcnow(),
@@ -134,19 +140,13 @@ def test_tracer_chat_model_run() -> None:
         execution_order=1,
         child_execution_order=1,
         serialized={"name": "chat_model"},
-        prompts=[""],
+        prompts=["Human: "],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
         error=None,
     )
-    tracer = FakeTracer()
-
-    tracer.new_session()
-    manager = CallbackManager(handlers=[tracer])
-    run_manager = manager.on_chat_model_start(
-        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
-    )
-    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
+    for run_manager in run_managers:
+        run_manager.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]