diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index 62cc7a9112a..a95368234e0 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -1,10 +1,12 @@ """Chain that just formats a prompt and calls an LLM.""" from __future__ import annotations +import asyncio from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from pydantic import BaseModel, Extra +from langchain.callbacks.base import BaseCallbackHandler, CallbackManager from langchain.chains.base import Chain from langchain.input import get_colored_text from langchain.prompts.base import BasePromptTemplate @@ -217,3 +219,38 @@ class LLMChain(Chain, BaseModel): """Create LLMChain from LLM and template.""" prompt_template = PromptTemplate.from_template(template) return cls(llm=llm, prompt=prompt_template) + + async def as_poe_handler(self): + chain = self + + class LLMChainPoeHandler(PoeHandler): + async def get_response(self, query): + callback_handler = PoeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + chain.callback_manager = callback_manager + chain.llm.callback_manager = callback_manager + + run = asyncio.create_task(chain.acall(query)) + + while not callback_handler.done.is_set(): + token = await callback_handler.queue.get() + yield token + + await run + + return LLMChainPoeHandler() + + +class PoeCallbackHandler(BaseCallbackHandler): + def __init__(self): + self.queue = asyncio.Queue() + self.done = asyncio.Event() + + def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str]): + pass + + def on_llm_new_token(self, token: str): + self.queue.put_nowait(token) + + def on_llm_end(self, serialized: Dict[str, Any], prompts: List[str]): + self.done.set()