Compare commits


1 Commit

Author: Nuno Campos
SHA1: 35484bfdcb
Message: A suggestion on how to implement a PoeHandler
Date: 2023-04-02 21:22:34 +01:00


@@ -1,10 +1,12 @@
"""Chain that just formats a prompt and calls an LLM."""
from __future__ import annotations
import asyncio
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from pydantic import BaseModel, Extra
from langchain.callbacks.base import BaseCallbackHandler, CallbackManager
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.prompts.base import BasePromptTemplate
@@ -217,3 +219,38 @@ class LLMChain(Chain, BaseModel):
"""Create LLMChain from LLM and template."""
prompt_template = PromptTemplate.from_template(template)
return cls(llm=llm, prompt=prompt_template)
async def as_poe_handler(self):
chain = self
class LLMChainPoeHandler(PoeHandler):
async def get_response(self, query):
callback_handler = PoeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chain.callback_manager = callback_manager
chain.llm.callback_manager = callback_manager
run = asyncio.create_task(chain.acall(query))
while not callback_handler.done.is_set():
token = await callback_handler.queue.get()
yield token
await run
return LLMChainPoeHandler()
class PoeCallbackHandler(BaseCallbackHandler):
def __init__(self):
self.queue = asyncio.Queue()
self.done = asyncio.Event()
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str]):
pass
def on_llm_new_token(self, token: str):
self.queue.put_nowait(token)
def on_llm_end(self, serialized: Dict[str, Any], prompts: List[str]):
self.done.set()
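
A minimal usage sketch for the suggestion follows, assuming the hunk above is applied on top of LLMChain. PoeHandler is assumed to be importable from the Poe bot package (the hunk shown does not add that import), and the model, prompt, and query string below are illustrative rather than part of the commit; the query is passed as a plain string because the suggested get_response forwards it directly to chain.acall.

import asyncio

from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate


async def main():
    # Streaming must be enabled so the LLM fires on_llm_new_token callbacks.
    llm = OpenAI(streaming=True, temperature=0)
    prompt = PromptTemplate.from_template("Answer the question: {question}")
    chain = LLMChain(llm=llm, prompt=prompt)

    # as_poe_handler() is declared async in the suggestion, so it is awaited.
    handler = await chain.as_poe_handler()

    # get_response() is an async generator that yields tokens as the
    # PoeCallbackHandler pushes them onto its queue.
    async for token in handler.get_response("What is the capital of France?"):
        print(token, end="", flush=True)


asyncio.run(main())

The design bridges LangChain's per-token callbacks into an asyncio.Queue via PoeCallbackHandler, so the Poe-facing handler can drain the queue as an async stream while the chain itself runs concurrently as a task.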