150 lines
6.2 KiB
Python
150 lines
6.2 KiB
Python
from weakref import ref
|
|
from dotenv import load_dotenv, find_dotenv
|
|
# Locate the nearest .env file with find_dotenv() and load its variables into
# the process environment before the os.getenv() lookups below run.
load_dotenv(dotenv_path=find_dotenv())
|
|
import os
|
|
import modules.logger_tool as logger
|
|
# One identifier is used both as the logger name and as the log file name.
log_name = "api_routers_langchain_graph_qa"
# Log directory; falls back to "/logs" when LOG_PATH is not set.
log_dir = os.getenv("LOG_PATH", "/logs")

# Options handed to the project's logger factory (modules.logger_tool).
_logger_settings = {
    "name": log_name,
    "log_level": os.getenv("LOG_LEVEL", "DEBUG"),  # DEBUG unless LOG_LEVEL overrides it
    "log_path": log_dir,
    "log_file": log_name,
    "runtime": True,
    "log_format": "default",
}
# NOTE(review): this module-level name shadows the stdlib ``logging`` module;
# it holds whatever logger_tool.get_logger returns.
logging = logger.get_logger(**_logger_settings)
|
|
from fastapi import APIRouter, HTTPException
|
|
from langchain_classic.chains import GraphCypherQAChain
|
|
from langchain_community.graphs import Neo4jGraph
|
|
from langchain_community.chat_models import ChatOpenAI
|
|
from langchain_classic.prompts.prompt import PromptTemplate
|
|
from routers.llm.private.ollama.ollama_wrapper import OllamaWrapper
|
|
from modules.database.tools.neontology.utils import get_node_types, get_rels_by_type
|
|
from modules.database.tools.neontology.basenode import BaseNode
|
|
from modules.database.tools.neontology.baserelationship import BaseRelationship
|
|
|
|
# Router that the endpoint(s) of this module are registered on.
router = APIRouter()

# Node and relationship type mappings derived from the neontology base models,
# computed once at import time. Their keys are treated as the set of known type
# names when validating request filters.
node_types, relationship_types = (
    get_node_types(BaseNode),
    get_rels_by_type(BaseRelationship),
)
|
|
|
|
# Prompt asking the LLM to translate a natural-language question into a single
# Cypher statement: ``{schema}`` receives the graph schema and ``{question}`` the
# user's question. It does not depend on the request, so it is built once here.
CYPHER_GENERATION_TEMPLATE = """Task: Generate a Cypher statement to query a graph database.
Role:
You are an assistant specializing in querying graph databases to find answers to questions about establishments, schools, and related data.
The user will ask you questions about the graph database

Instructions:
1. Use only the provided relationship types and properties in the schema.
2. Do not use any other relationship types or properties that are not provided.
3. When querying for geographic entities like counties, towns, or countries, use the 'name' property, not 'code'.
4. To find relationship types, use: MATCH (n:NodeType)-[r]->(m) RETURN DISTINCT type(r)
5. Relationship labels are in uppercase, e.g. LOCATED_IN_COUNTRY
6. For broad queries use OPTIONAL MATCH to allow for null results, e.g. OPTIONAL MATCH (n:NodeType) RETURN n.name

Schema:
{schema}

Note:
1. Do not include any explanations or apologies in your responses.
2. Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
3. Do not include any text except the generated Cypher statement.
4. Do not include line break characters or other formatting keys.

The question is:
{question}"""

CYPHER_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"],
    template=CYPHER_GENERATION_TEMPLATE,
)


def _keep_valid_types(requested: list, valid_types: set, label: str) -> list:
    """Return the items of ``requested`` that are in ``valid_types``, logging any dropped."""
    kept = [t for t in requested if t in valid_types]
    dropped = [t for t in requested if t not in valid_types]
    if dropped:
        logging.warning(f"Ignoring unknown {label}: {dropped}")
    logging.info(f"Validated {label}: {kept}")
    return kept


def _build_llms(model_type: str, model: str, temperature: float):
    """Return the ``(cypher_llm, qa_llm)`` pair for the requested backend.

    Raises:
        HTTPException: 500 if ``model_type == "ollama"`` and HOST_OLLAMA or
            PORT_OLLAMA is not set.
    """
    if model_type == "ollama":
        ollama_host = os.getenv("HOST_OLLAMA")
        ollama_port = os.getenv("PORT_OLLAMA")
        if not ollama_host or not ollama_port:
            raise HTTPException(status_code=500, detail="Ollama host or port not set")
        # NOTE(review): ``temperature`` is not passed to OllamaWrapper; confirm
        # whether the wrapper supports it.
        client = OllamaWrapper(host=f'{ollama_host}:{ollama_port}', model=model)
        # One client both generates the Cypher and phrases the answer.
        return client, client
    # Any other model_type is treated as an OpenAI chat model.
    return (
        ChatOpenAI(temperature=temperature, model=model),
        ChatOpenAI(temperature=temperature, model=model),
    )


def _connect_graph(database: str) -> Neo4jGraph:
    """Open a Neo4jGraph on ``database`` using APP_BOLT_URL / USER_NEO4J / PASSWORD_NEO4J.

    Raises:
        HTTPException: 500 if any of those environment variables is unset.
    """
    try:
        url = os.environ['APP_BOLT_URL']
        username = os.environ['USER_NEO4J']
        password = os.environ['PASSWORD_NEO4J']
    except KeyError as err:
        logging.error(f"Missing Neo4j environment variable: {err.args[0]}")
        raise HTTPException(status_code=500, detail="Neo4j connection settings not set") from err
    # The Neo4jGraph constructor already refreshes the schema (its
    # refresh_schema argument defaults to True), so no second
    # graph.refresh_schema() call is made afterwards.
    return Neo4jGraph(
        url=url,
        username=username,
        password=password,
        database=database,
        enhanced_schema=True,
        sanitize=True,
    )


def _run_graph_qa(database: str, prompt: str, model: str, model_type: str,
                  temperature: float, chain_options: dict) -> dict:
    """Blocking worker: build the LLMs, connect to Neo4j, run the chain, then close the connection."""
    # Resolve the LLM backend first so a configuration error fails before any
    # database work is done.
    cypher_llm, qa_llm = _build_llms(model_type, model, temperature)

    logging.info("Connecting to Neo4j and refreshing schema...")
    graph = _connect_graph(database)
    logging.info("Schema refreshed.")
    try:
        schema = graph.schema

        # SECURITY NOTE(review): allow_dangerous_requests=True lets LLM-generated
        # Cypher, derived from an untrusted ``prompt``, run with the configured
        # credentials. Use a read-only Neo4j user for USER_NEO4J.
        chain = GraphCypherQAChain.from_llm(
            graph=graph,
            cypher_llm=cypher_llm,
            qa_llm=qa_llm,
            cypher_prompt=CYPHER_GENERATION_PROMPT,
            allow_dangerous_requests=True,
            **chain_options,
        )

        # Debug preview of the Cypher-generation prompt. It uses the unfiltered
        # graph schema; when include/exclude types are set, the chain builds its
        # own filtered schema, so the LLM may see less than this preview shows.
        formatted_prompt = CYPHER_GENERATION_PROMPT.format(schema=schema, question=prompt)
        logging.debug(f"Formatted Prompt: \n{formatted_prompt}\n")
        logging.debug(f"Cypher chain: \n{chain}\n")

        # invoke() is the Runnable entry point; Chain.__call__ is deprecated.
        return chain.invoke({chain.input_key: prompt})
    finally:
        # Release the per-request driver explicitly rather than relying on GC.
        # getattr guards against Neo4jGraph versions that lack close().
        close = getattr(graph, "close", None)
        if callable(close):
            close()


@router.get("/prompt")
async def query_graph(
    database: str, prompt: str, top_k: int = 30, model: str = "qwen2.5-coder:3b", temperature: float = 0,
    verbose: bool = False, return_intermediate_steps: bool = False, exclude_types: list = None, include_types: list = None,
    return_direct: bool = False, validate_cypher: bool = False, model_type: str = "ollama"
):
    """Answer ``prompt`` against the Neo4j ``database`` with a GraphCypherQAChain.

    Args:
        database: Neo4j database name to connect to.
        prompt: Natural-language question.
        top_k, verbose, return_intermediate_steps, return_direct, validate_cypher:
            Forwarded unchanged to ``GraphCypherQAChain.from_llm``.
        model: Model name for the selected backend.
        temperature: Sampling temperature; only applied to the OpenAI backend.
        exclude_types, include_types: Type names to exclude from, or restrict,
            the schema. Names not in ``node_types``/``relationship_types`` are
            dropped. At most one of the two may be non-empty.
        model_type: "ollama" uses OllamaWrapper; any other value uses ChatOpenAI.

    Returns:
        The output mapping produced by ``GraphCypherQAChain.invoke``.

    Raises:
        HTTPException: 400 if both include_types and exclude_types are non-empty
            after validation; 500 if Ollama or Neo4j settings are missing.

    NOTE(review): FastAPI treats bare ``list`` parameters as request-body fields
    rather than query parameters, so GET clients cannot pass the type filters in
    the query string. Declaring them with ``fastapi.Query(None)`` would change
    that; confirm the intended contract before changing it.
    """
    # Local import: only needed to move blocking work off the event loop.
    import asyncio

    logging.info(f"Received request with prompt: {prompt}")

    if exclude_types is None:
        logging.info("No exclude_types provided, using default.")
        exclude_types = []
    if include_types is None:
        logging.info("No include_types provided, using default.")
        include_types = []

    logging.info("Validating include_types and exclude_types...")
    valid_types = set(node_types.keys()) | set(relationship_types.keys())
    logging.info(f"Valid types: {valid_types}")
    exclude_types = _keep_valid_types(exclude_types, valid_types, "exclude_types")
    include_types = _keep_valid_types(include_types, valid_types, "include_types")

    # GraphCypherQAChain.from_llm rejects both filters at once with a ValueError,
    # which would surface as a 500. Report it as a client error instead.
    if exclude_types and include_types:
        raise HTTPException(
            status_code=400,
            detail="Provide either include_types or exclude_types, not both",
        )

    chain_options = {
        "top_k": top_k,
        "verbose": verbose,
        "return_intermediate_steps": return_intermediate_steps,
        "exclude_types": exclude_types,
        "include_types": include_types,
        "return_direct": return_direct,
        "validate_cypher": validate_cypher,
    }

    # The Neo4j driver and the LLM clients are synchronous. Run them in a worker
    # thread so this async endpoint does not block the event loop for other requests.
    return await asyncio.to_thread(
        _run_graph_qa, database, prompt, model, model_type, temperature, chain_options
    )
|
|
|
|
|