mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 03:50:42 +00:00
feat: add GraphRAG framework and integrate TuGraph (#1506)
Co-authored-by: KingSkyLi <15566300566@163.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
50
dbgpt/rag/transformer/llm_extractor.py
Normal file
50
dbgpt/rag/transformer/llm_extractor.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""TripletExtractor class."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
|
||||
from dbgpt.rag.transformer.base import ExtractorBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMExtractor(ExtractorBase, ABC):
|
||||
"""LLMExtractor class."""
|
||||
|
||||
def __init__(self, llm_client: LLMClient, model_name: str, prompt_template: str):
|
||||
"""Initialize the LLMExtractor."""
|
||||
self._llm_client = llm_client
|
||||
self._model_name = model_name
|
||||
self._prompt_template = prompt_template
|
||||
|
||||
async def extract(self, text: str, limit: Optional[int] = None) -> List:
|
||||
"""Extract by LLm."""
|
||||
template = HumanPromptTemplate.from_template(self._prompt_template)
|
||||
messages = template.format_messages(text=text)
|
||||
|
||||
# use default model if needed
|
||||
if not self._model_name:
|
||||
models = await self._llm_client.models()
|
||||
if not models:
|
||||
raise Exception("No models available")
|
||||
self._model_name = models[0].model
|
||||
logger.info(f"Using model {self._model_name} to extract")
|
||||
|
||||
model_messages = ModelMessage.from_base_messages(messages)
|
||||
request = ModelRequest(model=self._model_name, messages=model_messages)
|
||||
response = await self._llm_client.generate(request=request)
|
||||
|
||||
if not response.success:
|
||||
code = str(response.error_code)
|
||||
reason = response.text
|
||||
logger.error(f"request llm failed ({code}) {reason}")
|
||||
return []
|
||||
|
||||
if limit and limit < 1:
|
||||
ValueError("optional argument limit >= 1")
|
||||
return self._parse_response(response.text, limit)
|
||||
|
||||
@abstractmethod
|
||||
def _parse_response(self, text: str, limit: Optional[int] = None) -> List:
|
||||
"""Parse llm response."""
|
Reference in New Issue
Block a user