api/modules/langchain/neo4j_graph_qa.py
2025-07-11 13:52:19 +00:00

61 lines
2.8 KiB
Python

from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
import os
import modules.logger_tool as logger
log_name = 'api_modules_langchain_graph_qa'
log_dir = os.getenv("LOG_PATH", "/logs") # Default path as fallback
logging = logger.get_logger(
name=log_name,
log_level=os.getenv("LOG_LEVEL", "DEBUG"),
log_path=log_dir,
log_file=log_name,
runtime=True,
log_format='default'
)
def test_query_graph(database, prompt, top_k=20, model="gpt-4o", temperature=0, verbose=False, return_intermediate_steps=True, exclude_types=None, include_types=None, return_direct=False, validate_cypher=False, model_type="openai"):
url = f"{os.environ['APP_API_URL']}/langchain/graph_qa/prompt"
params = {
"database": database,
"prompt": prompt,
"top_k": top_k,
"model": model,
"temperature": temperature,
"verbose": verbose,
"return_intermediate_steps": return_intermediate_steps,
"exclude_types": exclude_types or [],
"include_types": include_types or [],
"return_direct": return_direct,
"validate_cypher": validate_cypher,
"model_type": model_type
}
try:
response = requests.get(url, params=params)
response.raise_for_status() # Raise an error for bad status codes
data = response.json()
logging.info("==================================================")
logging.info("= =")
logging.info("= Test Execution =")
logging.info("= =")
logging.info("==================================================")
logging.info(f"= Prompt: {data.get('query', 'N/A')}")
logging.info("= =")
logging.info(f"= Query: \n{data.get('intermediate_steps', [{'query': 'N/A'}])[0].get('query', 'N/A')}")
logging.info("= =")
logging.info("==================================================")
# Determine if the test passed or failed
response_text = data.get('result', 'N/A')
context = data.get('intermediate_steps', [{'context': 'N/A'}])[1].get('context', 'N/A')
if "I don't know" in response_text or not context:
return False, context, response_text
else:
return True, context, response_text
except requests.exceptions.RequestException as e:
logging.error("==================================================")
logging.error("= ERROR =")
logging.error("==================================================")
logging.error(f"Error: {e}")
logging.error("==================================================")
return False, None, None