"""TripletExtractor class."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Optional
|
|
|
|
from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
|
|
from dbgpt.rag.transformer.base import ExtractorBase
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
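# LLMExtractor factors out the shared flow for LLM-backed extractors: it
# formats a prompt from a template, calls the model, and delegates parsing
# of the raw response to a subclass via _parse_response().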
class LLMExtractor(ExtractorBase, ABC):
    """LLMExtractor class."""

    def __init__(self, llm_client: LLMClient, model_name: str, prompt_template: str):
        """Initialize the LLMExtractor."""
        self._llm_client = llm_client
        self._model_name = model_name
        self._prompt_template = prompt_template

    async def extract(self, text: str, limit: Optional[int] = None) -> List:
        """Extract by LLM."""
        return await self._extract(text, None, limit)

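    # batch_extract() fans a list of texts out over _extract(), running each
    # fixed-size batch concurrently before moving on to the next one.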
    async def batch_extract(
        self,
        texts: List[str],
        batch_size: int = 1,
        limit: Optional[int] = None,
    ) -> List:
        """Batch extract by LLM."""
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")

        results = []

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i : i + batch_size]

            # Create extraction tasks for the current batch
            extraction_tasks = [
                self._extract(text, None, limit) for text in batch_texts
            ]

            # Execute the batch concurrently and wait for all tasks to complete
            batch_results = await asyncio.gather(*extraction_tasks)
            results.extend(batch_results)

        return results

    async def _extract(
        self, text: str, history: Optional[str] = None, limit: Optional[int] = None
    ) -> List:
        """Inner extract by LLM."""
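        # Build the prompt messages; pass history to the template only when
        # one is given.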
        template = HumanPromptTemplate.from_template(self._prompt_template)
        messages = (
            template.format_messages(text=text, history=history)
            if history is not None
            else template.format_messages(text=text)
        )

        # Fall back to the first model the client reports if none was configured
        if not self._model_name:
            models = await self._llm_client.models()
            if not models:
                raise Exception("No models available")
            self._model_name = models[0].model
            logger.info(f"Using model {self._model_name} to extract")

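        # Convert the prompt messages into a ModelRequest and call the client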
        model_messages = ModelMessage.from_base_messages(messages)
        request = ModelRequest(model=self._model_name, messages=model_messages)
        response = await self._llm_client.generate(request=request)

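        # On failure, log the error and return an empty result instead of raising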
        if not response.success:
            code = str(response.error_code)
            reason = response.text
            logger.error(f"LLM request failed ({code}): {reason}")
            return []

        if limit and limit < 1:
            raise ValueError("optional argument limit must be >= 1")

        return self._parse_response(response.text, limit)

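    # truncate() and drop() are intentional no-ops here; stateful subclasses
    # can override them to clean up their own resources.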
    def truncate(self):
        """Do nothing by default."""

    def drop(self):
        """Do nothing by default."""

    @abstractmethod
    def _parse_response(self, text: str, limit: Optional[int] = None) -> List:
        """Parse the LLM response."""
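

# Illustrative sketch (not part of the library): a concrete subclass only
# needs to implement _parse_response. Assuming a prompt that asks the model
# to emit one extracted item per line, a hypothetical parser could be:
#
#     class LineItemExtractor(LLMExtractor):
#         """Hypothetical extractor that treats each response line as an item."""
#
#         def _parse_response(self, text: str, limit: Optional[int] = None) -> List:
#             items = [line.strip() for line in text.splitlines() if line.strip()]
#             return items[:limit] if limit is not None else items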