Compare commits

...

1 Commits

Author SHA1 Message Date
Nuno Campos
b7d5d67402 A suggestion on how to implement PoeHandler for BaseChatModel 2023-04-02 22:09:20 +01:00

View File

@@ -6,7 +6,11 @@ from pydantic import BaseModel, Extra, Field, validator
import langchain
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
CallbackManager,
)
from langchain.schema import (
AIMessage,
BaseLanguageModel,
@@ -16,6 +20,7 @@ from langchain.schema import (
HumanMessage,
LLMResult,
PromptValue,
SystemMessage,
)
@@ -131,6 +136,53 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
result = self([HumanMessage(content=message)], stop=stop)
return result.content
async def as_poe_handler(self):
model = self
class LLMChainPoeHandler(PoeHandler):
async def get_response(self, query):
callback_handler = PoeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
model.callback_manager = callback_manager
run = asyncio.create_task(
model([poe_msg_to_lc_msg(msg) for msg in query.query])
)
while not callback_handler.done.is_set():
token = await callback_handler.queue.get()
yield token
await run
return LLMChainPoeHandler()
class PoeCallbackHandler(BaseCallbackHandler):
def __init__(self):
self.queue = asyncio.Queue()
self.done = asyncio.Event()
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str]):
pass
def on_llm_new_token(self, token: str):
self.queue.put_nowait(token)
def on_llm_end(self, serialized: Dict[str, Any], prompts: List[str]):
self.done.set()
def poe_msg_to_lc_msg(msg: ProtocolMessage) -> BaseMessage:
if msg.type == "human":
return HumanMessage(content=msg.text)
elif msg.type == "bot" or msg.type == "assistant":
return AIMessage(content=msg.text)
elif msg.type == "system":
return SystemMessage(content=msg.text)
else:
raise ValueError(f"Unknown message type: {msg.type}")
class SimpleChatModel(BaseChatModel):
def _generate(