fix chat models and tests

Ankush Gola
2023-06-06 22:50:30 -07:00
parent 207a7b7bbd
commit 5fa7ea2146
5 changed files with 120 additions and 80 deletions
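The sync and async callback managers' `on_chat_model_start` now start one run per message batch and return a list of run managers instead of a single one. A minimal sketch of the new calling convention (the empty handler list and message contents are illustrative, not taken from this commit):

from langchain.callbacks.manager import CallbackManager
from langchain.schema import HumanMessage

manager = CallbackManager(handlers=[])
# One CallbackManagerForLLMRun comes back per inner list of messages.
run_managers = manager.on_chat_model_start(
    {"name": "chat_model"},
    [[HumanMessage(content="hello")], [HumanMessage(content="how are you?")]],
)
assert len(run_managers) == 2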


@@ -674,26 +674,33 @@ class CallbackManager(BaseCallbackManager):
         messages: List[List[BaseMessage]],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> CallbackManagerForLLMRun:
+    ) -> List[CallbackManagerForLLMRun]:
         """Run when LLM starts running."""
-        if run_id is None:
-            run_id = uuid4()
-        _handle_event(
-            self.handlers,
-            "on_chat_model_start",
-            "ignore_chat_model",
-            serialized,
-            messages,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            **kwargs,
-        )
-        # Re-use the LLM Run Manager since the outputs are treated
-        # the same for now
-        return CallbackManagerForLLMRun(
-            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
-        )
+        managers = []
+        for message in messages:
+            run_id_ = uuid4()
+            _handle_event(
+                self.handlers,
+                "on_chat_model_start",
+                "ignore_chat_model",
+                serialized,
+                [message],
+                run_id=run_id_,
+                parent_run_id=self.parent_run_id,
+                **kwargs,
+            )
+            managers.append(
+                CallbackManagerForLLMRun(
+                    run_id_,
+                    self.handlers,
+                    self.inheritable_handlers,
+                    self.parent_run_id,
+                )
+            )
+        return managers

     def on_chain_start(
         self,
@@ -815,23 +822,36 @@ class AsyncCallbackManager(BaseCallbackManager):
         run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        if run_id is None:
-            run_id = uuid4()
+        tasks = []
+        managers = []
-        await _ahandle_event(
-            self.handlers,
-            "on_chat_model_start",
-            "ignore_chat_model",
-            serialized,
-            messages,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            **kwargs,
-        )
+        for message in messages:
+            run_id_ = uuid4()
-        return AsyncCallbackManagerForLLMRun(
-            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
-        )
+            tasks.append(
+                _ahandle_event(
+                    self.handlers,
+                    "on_chat_model_start",
+                    "ignore_chat_model",
+                    serialized,
+                    [message],
+                    run_id=run_id_,
+                    parent_run_id=self.parent_run_id,
+                    **kwargs,
+                )
+            )
+            managers.append(
+                AsyncCallbackManagerForLLMRun(
+                    run_id_,
+                    self.handlers,
+                    self.inheritable_handlers,
+                    self.parent_run_id,
+                )
+            )
+        await asyncio.gather(*tasks)
+        return managers

     async def on_chain_start(
         self,
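The async manager mirrors this but collects the `on_chat_model_start` coroutines and awaits them together with `asyncio.gather` before returning the managers. A rough usage sketch under the same assumptions as above (empty handler list, placeholder message):

import asyncio

from langchain.callbacks.manager import AsyncCallbackManager
from langchain.schema import HumanMessage, LLMResult


async def start_chat_run() -> None:
    manager = AsyncCallbackManager(handlers=[])
    run_managers = await manager.on_chat_model_start(
        {"name": "chat_model"}, [[HumanMessage(content="hi")]]
    )
    # Each batch gets its own run manager (and run_id) for follow-up events.
    for run_manager in run_managers:
        await run_manager.on_llm_end(LLMResult(generations=[[]]))


asyncio.run(start_chat_run())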


@@ -73,7 +73,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         callback_manager = CallbackManager.configure(
             callbacks, self.callbacks, self.verbose
         )
-        run_manager = callback_manager.on_chat_model_start(
+        run_managers = callback_manager.on_chat_model_start(
             {"name": self.__class__.__name__}, messages, invocation_params=params
         )
@@ -82,20 +82,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
         )
         try:
             results = [
-                self._generate(m, stop=stop, run_manager=run_manager)
+                self._generate(m, stop=stop, run_manager=run_managers[0])
                 if new_arg_supported
                 else self._generate(m, stop=stop)
                 for m in messages
             ]
         except (KeyboardInterrupt, Exception) as e:
-            run_manager.on_llm_error(e)
+            for run_manager in run_managers:
+                run_manager.on_llm_error(e)
             raise e
         llm_output = self._combine_llm_outputs([res.llm_output for res in results])
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
-        run_manager.on_llm_end(output)
-        if run_manager:
-            output.run = RunInfo(run_id=run_manager.run_id)
+        flattened_outputs = output.flatten()
+        for manager, flattened_output in zip(run_managers, flattened_outputs):
+            manager.on_llm_end(flattened_output)
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
         return output

     async def agenerate(
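`generate` now flattens the combined `LLMResult` so each run manager's `on_llm_end` only sees the generations for its own message batch. A small sketch of the pairing, assuming `LLMResult.flatten()` yields one single-batch result per inner generations list (which is what the `zip` above relies on):

from langchain.schema import LLMResult

combined = LLMResult(generations=[[], []], llm_output=None)
flattened = combined.flatten()
# One flattened result per message batch, matched to run_managers via zip().
assert len(flattened) == len(combined.generations)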
@@ -111,7 +116,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         callback_manager = AsyncCallbackManager.configure(
             callbacks, self.callbacks, self.verbose
         )
-        run_manager = await callback_manager.on_chat_model_start(
+        run_managers = await callback_manager.on_chat_model_start(
             {"name": self.__class__.__name__}, messages, invocation_params=params
         )
@@ -121,21 +126,33 @@ class BaseChatModel(BaseLanguageModel, ABC):
         try:
             results = await asyncio.gather(
                 *[
-                    self._agenerate(m, stop=stop, run_manager=run_manager)
+                    self._agenerate(m, stop=stop, run_manager=run_managers[0])
                     if new_arg_supported
                     else self._agenerate(m, stop=stop)
                     for m in messages
                 ]
             )
         except (KeyboardInterrupt, Exception) as e:
-            await run_manager.on_llm_error(e)
+            await asyncio.gather(
+                *[run_manager.on_llm_error(e) for run_manager in run_managers]
+            )
             raise e
         llm_output = self._combine_llm_outputs([res.llm_output for res in results])
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
-        await run_manager.on_llm_end(output)
-        if run_manager:
-            output.run = RunInfo(run_id=run_manager.run_id)
+        flattened_outputs = output.flatten()
+        await asyncio.gather(
+            *[
+                run_manager.on_llm_end(flattened_output)
+                for run_manager, flattened_output in zip(
+                    run_managers, flattened_outputs
+                )
+            ]
+        )
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
         return output

     def generate_prompt(
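On both the sync and async paths, `output.run` changes from a single `RunInfo` to a list with one entry per run manager, so callers that need run ids now index into it. An illustrative sketch (assuming `RunInfo` is importable from `langchain.schema` alongside the other schema types used here):

from uuid import uuid4

from langchain.schema import LLMResult, RunInfo

output = LLMResult(generations=[[], []], llm_output=None)
# One RunInfo per message batch / run manager.
output.run = [RunInfo(run_id=uuid4()), RunInfo(run_id=uuid4())]
first_run_id = output.run[0].run_id  # previously output.run.run_id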


@@ -18,11 +18,12 @@ def _test_callback_manager(
     manager: CallbackManager, *handlers: BaseFakeCallbackHandler
 ) -> None:
     """Test the CallbackManager."""
-    run_manager = manager.on_llm_start({}, [])
-    run_manager.on_llm_end(LLMResult(generations=[]))
-    run_manager.on_llm_error(Exception())
-    run_manager.on_llm_new_token("foo")
-    run_manager.on_text("foo")
+    run_managers = manager.on_llm_start({}, ["prompt"])
+    for run_manager in run_managers:
+        run_manager.on_llm_end(LLMResult(generations=[]))
+        run_manager.on_llm_error(Exception())
+        run_manager.on_llm_new_token("foo")
+        run_manager.on_text("foo")
     run_manager_chain = manager.on_chain_start({"name": "foo"}, {})
     run_manager_chain.on_chain_end({})
@@ -42,11 +43,12 @@ async def _test_callback_manager_async(
     manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
 ) -> None:
     """Test the CallbackManager."""
-    run_manager = await manager.on_llm_start({}, [])
-    await run_manager.on_llm_end(LLMResult(generations=[]))
-    await run_manager.on_llm_error(Exception())
-    await run_manager.on_llm_new_token("foo")
-    await run_manager.on_text("foo")
+    run_managers = await manager.on_llm_start({}, ["prompt"])
+    for run_manager in run_managers:
+        await run_manager.on_llm_end(LLMResult(generations=[]))
+        await run_manager.on_llm_error(Exception())
+        await run_manager.on_llm_new_token("foo")
+        await run_manager.on_text("foo")
     run_manager_chain = await manager.on_chain_start({"name": "foo"}, {})
     await run_manager_chain.on_chain_end({})
@@ -95,9 +97,10 @@ def test_ignore_llm() -> None:
     handler1 = FakeCallbackHandler(ignore_llm_=True)
     handler2 = FakeCallbackHandler()
     manager = CallbackManager(handlers=[handler1, handler2])
-    run_manager = manager.on_llm_start({}, [])
-    run_manager.on_llm_end(LLMResult(generations=[]))
-    run_manager.on_llm_error(Exception())
+    run_managers = manager.on_llm_start({}, ["prompt"])
+    for run_manager in run_managers:
+        run_manager.on_llm_end(LLMResult(generations=[]))
+        run_manager.on_llm_error(Exception())
     assert handler1.starts == 0
     assert handler1.ends == 0
     assert handler1.errors == 0
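These tests now pass a real prompt and iterate over the returned managers, which implies `on_llm_start` follows the same one-run-per-prompt convention; that manager change itself is not shown in this diff. A hedged sketch of what the updated tests assume:

from langchain.callbacks.manager import CallbackManager
from langchain.schema import LLMResult

manager = CallbackManager(handlers=[])
# The updated tests expect one run manager per prompt string.
run_managers = manager.on_llm_start({}, ["first prompt", "second prompt"])
for run_manager in run_managers:
    run_manager.on_llm_end(LLMResult(generations=[]))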


@@ -11,7 +11,7 @@ from freezegun import freeze_time
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.tracers.base import BaseTracer, TracerException
 from langchain.callbacks.tracers.schemas import Run
-from langchain.schema import LLMResult
+from langchain.schema import HumanMessage, LLMResult

 class FakeTracer(BaseTracer):
@@ -55,9 +55,13 @@ def test_tracer_llm_run() -> None:
 @freeze_time("2023-01-01")
 def test_tracer_chat_model_run() -> None:
     """Test tracer on a Chat Model run."""
-    uuid = uuid4()
+    tracer = FakeTracer()
+    manager = CallbackManager(handlers=[tracer])
+    run_managers = manager.on_chat_model_start(
+        serialized={"name": "chat_model"}, messages=[[HumanMessage(content="")]]
+    )
     compare_run = Run(
-        id=str(uuid),
+        id=str(run_managers[0].run_id),
         name="chat_model",
         start_time=datetime.utcnow(),
         end_time=datetime.utcnow(),
@@ -65,17 +69,13 @@ def test_tracer_chat_model_run() -> None:
         execution_order=1,
         child_execution_order=1,
         serialized={"name": "chat_model"},
-        inputs=dict(prompts=[""]),
+        inputs=dict(prompts=["Human: "]),
         outputs=LLMResult(generations=[[]]),
         error=None,
         run_type="llm",
     )
-    tracer = FakeTracer()
-    manager = CallbackManager(handlers=[tracer])
-    run_manager = manager.on_chat_model_start(
-        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
-    )
-    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
+    for run_manager in run_managers:
+        run_manager.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]
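The expected `prompts` value changes from `""` to `"Human: "` because the run now records the string rendering of the message batch, and an empty `HumanMessage` renders as just its role prefix. A small check of that rendering, assuming the tracer relies on `get_buffer_string` from `langchain.schema` (the tracer implementation is not part of this diff):

from langchain.schema import HumanMessage, get_buffer_string

# Messages render as "<Role>: <content>"; empty content leaves only the prefix.
assert get_buffer_string([HumanMessage(content="")]) == "Human: "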


@@ -18,7 +18,7 @@ from langchain.callbacks.tracers.langchain_v1 import (
     TracerSessionV1,
 )
 from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
-from langchain.schema import LLMResult
+from langchain.schema import HumanMessage, LLMResult

 TEST_SESSION_ID = 2023
@@ -124,9 +124,15 @@ def test_tracer_llm_run() -> None:
 @freeze_time("2023-01-01")
 def test_tracer_chat_model_run() -> None:
     """Test tracer on a Chat Model run."""
-    uuid = uuid4()
+    tracer = FakeTracer()
+    tracer.new_session()
+    manager = CallbackManager(handlers=[tracer])
+    run_managers = manager.on_chat_model_start(
+        serialized={"name": "chat_model"}, messages=[[HumanMessage(content="")]]
+    )
     compare_run = LLMRun(
-        uuid=str(uuid),
+        uuid=str(run_managers[0].run_id),
         parent_uuid=None,
         start_time=datetime.utcnow(),
         end_time=datetime.utcnow(),
@@ -134,19 +140,13 @@ def test_tracer_chat_model_run() -> None:
         execution_order=1,
         child_execution_order=1,
         serialized={"name": "chat_model"},
-        prompts=[""],
+        prompts=["Human: "],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
         error=None,
     )
-    tracer = FakeTracer()
-    tracer.new_session()
-    manager = CallbackManager(handlers=[tracer])
-    run_manager = manager.on_chat_model_start(
-        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
-    )
-    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
+    for run_manager in run_managers:
+        run_manager.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]