DB-GPT/dbgpt/rag/transformer/llm_extractor.py

"""TripletExtractor class."""
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import List, Optional

from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
from dbgpt.rag.transformer.base import ExtractorBase

logger = logging.getLogger(__name__)


class LLMExtractor(ExtractorBase, ABC):
    """LLMExtractor class."""

    def __init__(self, llm_client: LLMClient, model_name: str, prompt_template: str):
        """Initialize the LLMExtractor."""
        self._llm_client = llm_client
        self._model_name = model_name
        self._prompt_template = prompt_template

    async def extract(self, text: str, limit: Optional[int] = None) -> List:
        """Extract by LLM."""
        return await self._extract(text, None, limit)
    async def batch_extract(
        self,
        texts: List[str],
        batch_size: int = 1,
        limit: Optional[int] = None,
    ) -> List:
        """Batch extract by LLM."""
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")

        results = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i : i + batch_size]

            # Create extraction tasks for the current batch
            extraction_tasks = [
                self._extract(text, None, limit) for text in batch_texts
            ]

            # Run the batch concurrently and wait for all tasks to complete
            batch_results = await asyncio.gather(*extraction_tasks)
            results.extend(batch_results)
        return results
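
    # Usage note (illustrative, not in the original file): with a concrete
    # subclass `extractor`, a corpus can be processed in bounded batches:
    #
    #     results = await extractor.batch_extract(chunks, batch_size=4)
    #
    # Each batch is awaited via asyncio.gather before the next one starts,
    # so at most batch_size LLM requests are in flight at any time.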
    async def _extract(
        self, text: str, history: Optional[str] = None, limit: Optional[int] = None
    ) -> List:
        """Inner extract by LLM."""
        # Validate the optional limit up front, before the LLM call
        if limit is not None and limit < 1:
            raise ValueError("optional argument limit must be >= 1")

        template = HumanPromptTemplate.from_template(self._prompt_template)
        messages = (
            template.format_messages(text=text, history=history)
            if history is not None
            else template.format_messages(text=text)
        )

        # Fall back to the first available model if none was configured
        if not self._model_name:
            models = await self._llm_client.models()
            if not models:
                raise Exception("No models available")
            self._model_name = models[0].model
            logger.info(f"Using model {self._model_name} to extract")

        model_messages = ModelMessage.from_base_messages(messages)
        request = ModelRequest(model=self._model_name, messages=model_messages)
        response = await self._llm_client.generate(request=request)

        if not response.success:
            code = str(response.error_code)
            reason = response.text
            logger.error(f"request llm failed ({code}) {reason}")
            return []
        return self._parse_response(response.text, limit)
    def truncate(self):
        """Do nothing by default."""

    def drop(self):
        """Do nothing by default."""

    @abstractmethod
    def _parse_response(self, text: str, limit: Optional[int] = None) -> List:
        """Parse llm response."""