from weakref import ref  # NOTE(review): unused in this module — left in place in case it is needed elsewhere; confirm before removing
from dotenv import load_dotenv, find_dotenv

# Load environment variables before anything below reads them via os.getenv/os.environ.
load_dotenv(find_dotenv())

import os
import modules.logger_tool as logger

log_name = 'api_routers_langchain_graph_qa'
log_dir = os.getenv("LOG_PATH", "/logs")  # Default path as fallback
logging = logger.get_logger(
    name=log_name,
    log_level=os.getenv("LOG_LEVEL", "DEBUG"),
    log_path=log_dir,
    log_file=log_name,
    runtime=True,
    log_format='default'
)

from fastapi import APIRouter, HTTPException
from langchain_classic.chains import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_community.chat_models import ChatOpenAI
from langchain_classic.prompts.prompt import PromptTemplate
from routers.llm.private.ollama.ollama_wrapper import OllamaWrapper
from modules.database.tools.neontology.utils import get_node_types, get_rels_by_type
from modules.database.tools.neontology.basenode import BaseNode
from modules.database.tools.neontology.baserelationship import BaseRelationship

router = APIRouter()

# Registries of node and relationship types known to the application, used to
# validate the user-supplied include_types/exclude_types filters below.
node_types = get_node_types(BaseNode)
relationship_types = get_rels_by_type(BaseRelationship)


@router.get("/prompt")
async def query_graph(
    database: str,
    prompt: str,
    top_k: int = 30,
    model: str = "qwen2.5-coder:3b",
    temperature: float = 0,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    exclude_types: list = None,
    include_types: list = None,
    return_direct: bool = False,
    validate_cypher: bool = False,
    model_type: str = "ollama"
):
    """Answer a natural-language question against a Neo4j graph database.

    Builds a GraphCypherQAChain that (1) generates a Cypher statement from the
    user's prompt using an LLM and a schema-aware prompt template, and
    (2) answers the question from the query results.

    Args:
        database: Name of the Neo4j database to query.
        prompt: The user's natural-language question.
        top_k: Maximum number of query results passed to the QA step.
        model: Model name for both the Cypher-generation and QA LLMs.
        temperature: Sampling temperature (applied to OpenAI models only).
        verbose: Enable verbose chain logging.
        return_intermediate_steps: Include the generated Cypher and context
            in the response.
        exclude_types: Node/relationship type names to exclude from the
            schema shown to the LLM; unknown names are silently dropped.
        include_types: Node/relationship type names to restrict the schema
            to; unknown names are silently dropped.
        return_direct: Return raw Cypher results instead of the QA answer.
        validate_cypher: Ask the chain to validate the generated Cypher.
        model_type: "ollama" (default) to use the local Ollama wrapper;
            any other value falls back to ChatOpenAI.

    Returns:
        The chain output (answer, plus intermediate steps when requested).

    Raises:
        HTTPException: 500 when required connection settings are missing.
    """
    logging.info(f"Received request with prompt: {prompt}")

    # None defaults avoid the mutable-default-argument pitfall; normalise
    # them to empty lists here.
    if exclude_types is None:
        logging.info("No exclude_types provided, using default.")
        exclude_types = []
    if include_types is None:
        logging.info("No include_types provided, using default.")
        include_types = []

    # Drop any filter entries that are not known node or relationship types.
    logging.info("Validating include_types and exclude_types...")
    valid_types = set(node_types.keys()).union(set(relationship_types.keys()))
    logging.info(f"Valid types: {valid_types}")
    exclude_types = [t for t in exclude_types if t in valid_types]
    logging.info(f"Validated exclude_types: {exclude_types}")
    include_types = [t for t in include_types if t in valid_types]
    logging.info(f"Validated include_types: {include_types}")

    # Fail with a clean 500 (consistent with the Ollama checks below) instead
    # of an unhandled KeyError when Neo4j connection settings are missing.
    for required in ("APP_BOLT_URL", "USER_NEO4J", "PASSWORD_NEO4J"):
        if not os.getenv(required):
            raise HTTPException(status_code=500, detail=f"{required} not set")

    graph = Neo4jGraph(
        url=os.environ['APP_BOLT_URL'],
        username=os.environ['USER_NEO4J'],
        password=os.environ['PASSWORD_NEO4J'],
        database=database,
        enhanced_schema=True,
        sanitize=True,
    )
    logging.info("Refreshing schema...")
    graph.refresh_schema()
    logging.info("Schema refreshed.")
    schema = graph.schema
    logging.info(f"Schema: {schema}")

    CYPHER_GENERATION_TEMPLATE = """Task: Generate a Cypher statement to query a graph database.
Role: You are an assistant specializing in querying graph databases to find answers to questions about establishments, schools, and related data. The user will ask you questions about the graph database
Instructions:
1. Use only the provided relationship types and properties in the schema.
2. Do not use any other relationship types or properties that are not provided.
3. When querying for geographic entities like counties, towns, or countries, use the 'name' property, not 'code'.
4. To find relationship types, use: MATCH (n:NodeType)-[r]->(m) RETURN DISTINCT type(r)
5. Relationship labels are in uppercase, e.g. LOCATED_IN_COUNTRY
6. For broad queries use OPTIONAL MATCH to allow for null results, e.g. OPTIONAL MATCH (n:NodeType) RETURN n.name
Schema:
{schema}
Note:
1. Do not include any explanations or apologies in your responses.
2. Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
3. Do not include any text except the generated Cypher statement.
4. Do not include line break characters \\n or other formatting keys.
The question is:
{question}"""

    CYPHER_GENERATION_PROMPT = PromptTemplate(
        input_variables=["schema", "question"],
        template=CYPHER_GENERATION_TEMPLATE
    )

    if model_type == "ollama":
        ollama_host = os.getenv("HOST_OLLAMA")
        ollama_port = os.getenv("PORT_OLLAMA")
        if not ollama_host or not ollama_port:
            raise HTTPException(status_code=500, detail="Ollama host or port not set")
        # The same Ollama client serves both Cypher generation and answer
        # synthesis.
        client = OllamaWrapper(host=f'{ollama_host}:{ollama_port}', model=model)
        cypher_llm = client
        qa_llm = client
    else:
        cypher_llm = ChatOpenAI(temperature=temperature, model=model)
        qa_llm = ChatOpenAI(temperature=temperature, model=model)

    chain = GraphCypherQAChain.from_llm(
        graph=graph,
        cypher_llm=cypher_llm,
        qa_llm=qa_llm,
        top_k=top_k,
        verbose=verbose,
        cypher_prompt=CYPHER_GENERATION_PROMPT,
        return_intermediate_steps=return_intermediate_steps,
        exclude_types=exclude_types,
        include_types=include_types,
        return_direct=return_direct,
        validate_cypher=validate_cypher,
        # Required opt-in: acknowledges that LLM-generated Cypher is executed
        # against the database.
        allow_dangerous_requests=True
    )

    formatted_prompt = CYPHER_GENERATION_PROMPT.format(schema=schema, question=prompt)
    logging.info("\n\n")
    logging.info("==================================================")
    logging.info("= graph_qa.py =")
    logging.info("==================================================")
    logging.info(f"Prompt: {prompt}")
    logging.info("--------------------------------------------------")
    logging.info(f"Schema: \n{schema}\n")
    logging.info("--------------------------------------------------")
    logging.info(f"Formatted Prompt: \n{formatted_prompt}\n")
    logging.info("--------------------------------------------------")
    logging.info(f"Cypher prompt: \n{CYPHER_GENERATION_PROMPT}\n")
    logging.info("--------------------------------------------------")
    logging.info(f"Cypher template: \n{CYPHER_GENERATION_TEMPLATE}\n")
    logging.info("--------------------------------------------------")
    logging.info(f"Cypher chain: \n{chain}\n")
    logging.info("==================================================")
    # Chain.__call__ is deprecated in modern LangChain; invoke() is the
    # equivalent supported entry point and accepts a single string when the
    # chain has one input key.
    return chain.invoke(prompt)