api/routers/langchain/neo4j_graph_qa.py
2025-11-19 19:38:09 +00:00

150 lines
6.2 KiB
Python

from weakref import ref
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
import os
import modules.logger_tool as logger
# Per-module logger configuration; LOG_PATH and LOG_LEVEL are read from the
# environment at import time.
log_name = 'api_routers_langchain_graph_qa'
# Fall back to /logs when LOG_PATH is not set.
log_dir = os.getenv("LOG_PATH", "/logs")
_logger_settings = dict(
    name=log_name,
    log_level=os.getenv("LOG_LEVEL", "DEBUG"),
    log_path=log_dir,
    log_file=log_name,
    runtime=True,
    log_format='default',
)
# NOTE: bound to the name `logging`, which shadows the stdlib module of the
# same name in this file; the handlers below call `logging.info(...)` on it.
logging = logger.get_logger(**_logger_settings)
from fastapi import APIRouter, HTTPException
from langchain_classic.chains import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_community.chat_models import ChatOpenAI
from langchain_classic.prompts.prompt import PromptTemplate
from routers.llm.private.ollama.ollama_wrapper import OllamaWrapper
from modules.database.tools.neontology.utils import get_node_types, get_rels_by_type
from modules.database.tools.neontology.basenode import BaseNode
from modules.database.tools.neontology.baserelationship import BaseRelationship
# Known node / relationship types, derived from the neontology base classes.
# The endpoint uses their keys to validate the include/exclude type filters.
node_types = get_node_types(BaseNode)
relationship_types = get_rels_by_type(BaseRelationship)

# Router that exposes the graph-QA endpoint defined below.
router = APIRouter()
# Prompt that asks the LLM to translate a natural-language question into a
# single Cypher statement. It holds no per-request state, so it is built once
# at import time rather than on every call. It is passed to
# GraphCypherQAChain as `cypher_prompt`, which fills `{schema}` and
# `{question}`.
CYPHER_GENERATION_TEMPLATE = """Task: Generate a Cypher statement to query a graph database.
Role:
You are an assistant specializing in querying graph databases to find answers to questions about establishments, schools, and related data.
The user will ask you questions about the graph database.
Instructions:
1. Use only the provided relationship types and properties in the schema.
2. Do not use any other relationship types or properties that are not provided.
3. When querying for geographic entities like counties, towns, or countries, use the 'name' property, not 'code'.
4. To find relationship types, use: MATCH (n:NodeType)-[r]->(m) RETURN DISTINCT type(r)
5. Relationship labels are in uppercase, e.g. LOCATED_IN_COUNTRY
6. For broad queries use OPTIONAL MATCH to allow for null results, e.g. OPTIONAL MATCH (n:NodeType) RETURN n.name
Schema:
{schema}
Note:
1. Do not include any explanations or apologies in your responses.
2. Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
3. Do not include any text except the generated Cypher statement.
4. Do not include line break characters or other formatting.
The question is:
{question}"""
CYPHER_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"],
    template=CYPHER_GENERATION_TEMPLATE,
)


def _filter_known_types(requested, valid_types, label):
    """Return the entries of ``requested`` that are members of ``valid_types``.

    Order and duplicates are preserved. Unknown entries are dropped and logged
    as a warning, so that a typo in a filter is visible instead of silently
    changing which parts of the schema the LLM sees.

    Args:
        requested: list of type names supplied by the caller (may be empty).
        valid_types: set of accepted type names.
        label: parameter name used in log messages.

    Returns:
        A new list containing only the known type names.
    """
    kept = [t for t in requested if t in valid_types]
    dropped = [t for t in requested if t not in valid_types]
    if dropped:
        logging.warning(f"Ignoring unknown {label}: {dropped}")
    logging.info(f"Validated {label}: {kept}")
    return kept


def _build_llms(model_type, model, temperature):
    """Return ``(cypher_llm, qa_llm)`` for the requested backend.

    ``model_type == "ollama"`` uses a single OllamaWrapper for both roles. Any
    other value falls back to ``ChatOpenAI``.

    Raises:
        HTTPException: 500 if HOST_OLLAMA / PORT_OLLAMA are not set for the
            Ollama backend.
    """
    if model_type == "ollama":
        ollama_host = os.getenv("HOST_OLLAMA")
        ollama_port = os.getenv("PORT_OLLAMA")
        if not ollama_host or not ollama_port:
            raise HTTPException(status_code=500, detail="Ollama host or port not set")
        # NOTE(review): `temperature` is not forwarded to OllamaWrapper, so it
        # only affects the OpenAI path. Confirm whether the wrapper supports it.
        client = OllamaWrapper(host=f'{ollama_host}:{ollama_port}', model=model)
        return client, client
    return (
        ChatOpenAI(temperature=temperature, model=model),
        ChatOpenAI(temperature=temperature, model=model),
    )


def _connect_graph(database):
    """Open a Neo4jGraph on ``database`` using credentials from the environment.

    The constructor refreshes the schema itself (``refresh_schema=True``). The
    original code refreshed it a second time explicitly, which doubled an
    expensive operation when ``enhanced_schema=True``.

    Raises:
        HTTPException: 500 if APP_BOLT_URL, USER_NEO4J or PASSWORD_NEO4J is unset.
    """
    settings = {
        name: os.getenv(name)
        for name in ("APP_BOLT_URL", "USER_NEO4J", "PASSWORD_NEO4J")
    }
    missing = [name for name, value in settings.items() if value is None]
    if missing:
        logging.error(f"Missing Neo4j environment variables: {missing}")
        raise HTTPException(status_code=500, detail="Neo4j connection settings not set")
    logging.info("Refreshing schema...")
    graph = Neo4jGraph(
        url=settings["APP_BOLT_URL"],
        username=settings["USER_NEO4J"],
        password=settings["PASSWORD_NEO4J"],
        database=database,
        refresh_schema=True,
        enhanced_schema=True,
        sanitize=True,
    )
    logging.info("Schema refreshed.")
    return graph


def _close_graph(graph):
    """Best-effort close of the graph's driver so connections are not leaked per request."""
    close = getattr(graph, "close", None)  # older Neo4jGraph versions may lack close()
    if not callable(close):
        return
    try:
        close()
    except Exception as exc:  # cleanup must not mask the request's result/error
        logging.warning(f"Failed to close Neo4j connection: {exc}")


def _log_chain_details(prompt, schema, chain):
    """Emit a DEBUG-level dump of the prompt, schema and chain for troubleshooting."""
    # NOTE(review): this formats the prompt with the full `graph.schema`. The
    # chain presumably builds its own schema text filtered by
    # include/exclude types, so what is actually sent to the LLM may differ.
    formatted_prompt = CYPHER_GENERATION_PROMPT.format(schema=schema, question=prompt)
    logging.debug("\n\n")
    logging.debug("==================================================")
    logging.debug("= graph_qa.py =")
    logging.debug("==================================================")
    logging.debug(f"Prompt: {prompt}")
    logging.debug("--------------------------------------------------")
    logging.debug(f"Schema: \n{schema}\n")
    logging.debug("--------------------------------------------------")
    logging.debug(f"Formatted Prompt: \n{formatted_prompt}\n")
    logging.debug("--------------------------------------------------")
    logging.debug(f"Cypher prompt: \n{CYPHER_GENERATION_PROMPT}\n")
    logging.debug("--------------------------------------------------")
    logging.debug(f"Cypher template: \n{CYPHER_GENERATION_TEMPLATE}\n")
    logging.debug("--------------------------------------------------")
    logging.debug(f"Cypher chain: \n{chain}\n")
    logging.debug("==================================================")


def _run_graph_qa(database, prompt, *, top_k, model, temperature, verbose,
                  return_intermediate_steps, exclude_types, include_types,
                  return_direct, validate_cypher, model_type):
    """Build the Cypher QA chain and run it synchronously (blocking I/O).

    The LLMs are built first so that a missing LLM configuration fails before
    any database connection is opened. The graph connection is always closed,
    even if the chain raises.
    """
    cypher_llm, qa_llm = _build_llms(model_type, model, temperature)
    graph = _connect_graph(database)
    try:
        chain = GraphCypherQAChain.from_llm(
            graph=graph,
            cypher_llm=cypher_llm,
            qa_llm=qa_llm,
            top_k=top_k,
            verbose=verbose,
            cypher_prompt=CYPHER_GENERATION_PROMPT,
            return_intermediate_steps=return_intermediate_steps,
            exclude_types=exclude_types,
            include_types=include_types,
            return_direct=return_direct,
            validate_cypher=validate_cypher,
            # SECURITY NOTE(review): the LLM-generated Cypher is executed as-is
            # with the configured credentials, and `prompt` comes from an
            # untrusted HTTP caller. A crafted prompt could yield write/delete
            # queries. Use a read-only Neo4j user for APP_BOLT_URL.
            allow_dangerous_requests=True,
        )
        _log_chain_details(prompt, graph.schema, chain)
        # `invoke` replaces the deprecated `chain(prompt)` call and returns the
        # same output dict (inputs plus chain outputs).
        return chain.invoke({chain.input_key: prompt})
    finally:
        _close_graph(graph)


# NOTE(review): FastAPI treats bare `list` parameters without `Query(...)` as
# request-body fields, so `exclude_types` / `include_types` must be sent as a
# JSON body on this GET route. The signature is left unchanged for client
# compatibility; consider `Query(None)` in a versioned change.
@router.get("/prompt")
async def query_graph(
    database: str, prompt: str, top_k: int = 30, model: str = "qwen2.5-coder:3b", temperature: float = 0,
    verbose: bool = False, return_intermediate_steps: bool = False, exclude_types: list = None, include_types: list = None,
    return_direct: bool = False, validate_cypher: bool = False, model_type: str = "ollama"
):
    """Answer a natural-language question by generating and running Cypher.

    An LLM (Ollama when ``model_type == "ollama"``, otherwise OpenAI) turns
    ``prompt`` into a Cypher query against ``database``. The query is run
    through LangChain's GraphCypherQAChain, and the chain's output dict is
    returned.

    ``exclude_types`` / ``include_types`` are filtered to names known to the
    neontology registries. Unknown names are ignored, and at most one of the
    two filters may be non-empty.

    Raises:
        HTTPException: 400 if both type filters are non-empty after
            validation; 500 if Ollama or Neo4j settings are missing from the
            environment.
    """
    import asyncio

    logging.info(f"Received request with prompt: {prompt}")
    if exclude_types is None:
        logging.info("No exclude_types provided, using default.")
        exclude_types = []
    if include_types is None:
        logging.info("No include_types provided, using default.")
        include_types = []

    logging.info("Validating include_types and exclude_types...")
    valid_types = set(node_types.keys()).union(set(relationship_types.keys()))
    logging.info(f"Valid types: {valid_types}")
    exclude_types = _filter_known_types(exclude_types, valid_types, "exclude_types")
    include_types = _filter_known_types(include_types, valid_types, "include_types")
    # GraphCypherQAChain.from_llm rejects having both filters set. Report it
    # as a client error instead of letting the ValueError become a 500.
    if exclude_types and include_types:
        raise HTTPException(
            status_code=400,
            detail="Provide either exclude_types or include_types, not both.",
        )

    # NOTE(review): `database` is caller-controlled, so any database reachable
    # with the configured credentials can be queried. Restrict it if that is
    # not intended.
    # The Neo4j driver and LLM clients are blocking. Run them in a worker
    # thread so this async endpoint does not stall the event loop for the
    # whole request.
    return await asyncio.to_thread(
        _run_graph_qa,
        database,
        prompt,
        top_k=top_k,
        model=model,
        temperature=temperature,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        exclude_types=exclude_types,
        include_types=include_types,
        return_direct=return_direct,
        validate_cypher=validate_cypher,
        model_type=model_type,
    )