153 lines
5.4 KiB
Python
153 lines
5.4 KiB
Python
from dotenv import load_dotenv, find_dotenv
|
|
load_dotenv(find_dotenv())
|
|
import os
|
|
import time
|
|
from typing import Any, Callable, Generator, Optional, Tuple
|
|
from modules.logger_tool import initialise_logger
|
|
from neo4j import GraphDatabase as gd, Driver, Session
|
|
from contextlib import contextmanager
|
|
|
|
# Module-level logger. Level and log path are read from the LOG_LEVEL and
# LOG_PATH environment variables (either may be unset, in which case None is
# passed through). NOTE(review): the meaning of the trailing 'default' and True
# arguments is defined by modules.logger_tool.initialise_logger -- confirm there.
logger = initialise_logger(__name__, os.getenv("LOG_LEVEL"), os.getenv("LOG_PATH"), 'default', True)
|
|
|
|
def _retry_with_backoff(
|
|
func,
|
|
max_attempts: int = 10, # Increased from 3 to 10
|
|
initial_delay: float = 2.0, # Increased from 1 to 2 seconds
|
|
max_total_wait: float = 60.0, # Maximum total time to wait (60 seconds)
|
|
max_delay: float = 10.0 # Maximum delay between retries
|
|
) -> any:
|
|
"""
|
|
Helper function to retry operations with exponential backoff.
|
|
|
|
Args:
|
|
func: Function to retry
|
|
max_attempts: Maximum number of retry attempts
|
|
initial_delay: Initial delay between retries in seconds
|
|
max_total_wait: Maximum total time to wait before giving up
|
|
max_delay: Maximum delay between retries
|
|
"""
|
|
attempt = 0
|
|
delay = initial_delay
|
|
start_time = time.time()
|
|
|
|
while attempt < max_attempts:
|
|
try:
|
|
return func()
|
|
except Exception as e:
|
|
attempt += 1
|
|
elapsed_time = time.time() - start_time
|
|
|
|
# Check if we've exceeded the maximum total wait time
|
|
if elapsed_time >= max_total_wait:
|
|
logger.error(f"Exceeded maximum total wait time of {max_total_wait} seconds")
|
|
raise
|
|
|
|
if attempt == max_attempts:
|
|
logger.error(f"Final attempt {attempt} failed: {e}")
|
|
raise
|
|
|
|
# Calculate next delay with exponential backoff, but cap it
|
|
delay = min(delay * 2, max_delay)
|
|
|
|
# If we're in a container initialization scenario, provide more context
|
|
if "Connection refused" in str(e):
|
|
logger.warning(
|
|
f"Attempt {attempt} failed: Connection refused. "
|
|
f"This might indicate that Neo4j is still starting up. "
|
|
f"Retrying in {delay:.1f} seconds... "
|
|
f"(Total elapsed: {elapsed_time:.1f}s)"
|
|
)
|
|
else:
|
|
logger.warning(f"Attempt {attempt} failed: {e}. Retrying in {delay:.1f} seconds...")
|
|
|
|
time.sleep(delay)
|
|
|
|
def get_driver(
    db_name: Optional[str] = None,
    url: Optional[str] = None,
    auth: Optional[Tuple[str, str]] = None,
) -> Optional[Driver]:
    """
    Create a Neo4j driver and verify that it can reach the server.

    When ``url`` is omitted it is read from the ``APP_BOLT_URL`` environment
    variable; in that case, if ``auth`` is also omitted, the credentials are
    read from ``USER_NEO4J`` / ``PASSWORD_NEO4J``. An explicitly supplied
    ``auth`` is always honoured. When ``url`` is supplied explicitly, ``auth``
    must be supplied too.

    Connection attempts (and, if ``db_name`` is given, a trivial query against
    that database) are retried through ``_retry_with_backoff``.

    Args:
        db_name: Optional database name to verify with a test query.
        url: Bolt URL of the server, or None to use the environment.
        auth: ``(username, password)`` tuple, or None (see above).

    Returns:
        A connected driver, or None if configuration is missing or the
        connection / database verification ultimately failed. Failures are
        logged rather than raised.
    """
    if url is None:
        url = os.getenv("APP_BOLT_URL")
        if not url:
            # Missing configuration is not transient; fail now instead of
            # retrying a connection that cannot succeed.
            logger.error("Neo4j URL not found in environment")
            return None
        if auth is None:
            username = os.getenv("USER_NEO4J")
            password = os.getenv("PASSWORD_NEO4J")
            if not username or not password:
                logger.error("Neo4j credentials not found in environment")
                return None
            auth = (username, password)

    if auth is None:
        logger.error("No authentication credentials provided")
        return None

    def create_driver() -> Driver:
        logger.info(f"Attempting to connect to Neo4j at {url}")
        candidate = gd.driver(url, auth=auth)
        try:
            candidate.verify_connectivity()
        except Exception:
            # Release the unusable driver before the retry helper creates a
            # new one; otherwise every failed attempt leaks a driver.
            candidate.close()
            raise
        logger.info(f"Connected to Neo4j at {url}")
        return candidate

    try:
        # Use more lenient retry parameters for initial connection
        driver = _retry_with_backoff(
            create_driver,
            max_attempts=10,
            initial_delay=2.0,
            max_total_wait=60.0,
            max_delay=10.0,
        )
    except Exception as e:
        logger.error(f"Failed to establish Neo4j connection after all retries: {e}")
        return None

    # Test the connection with the specific database
    if db_name and driver:
        def verify_database() -> None:
            with driver.session(database=db_name) as session:
                result = session.run("RETURN 'Connection successful' AS message")
                record = result.single()
                if not record or not record.get("message"):
                    raise Exception(f"Failed to verify database {db_name} connection")
                logger.info(f"Connection to Neo4j at {url} with database {db_name} successful")

        try:
            # Use more lenient retry parameters for database verification
            _retry_with_backoff(
                verify_database,
                max_attempts=10,
                initial_delay=2.0,
                max_total_wait=60.0,
                max_delay=10.0,
            )
        except Exception as e:
            logger.error(f"Failed to connect to database {db_name} after all retries: {e}")
            # The driver is connected but unusable for the requested database.
            driver.close()
            return None

    return driver
|
|
|
|
def close_driver(driver: Optional[Driver]) -> None:
    """Close ``driver`` if one was supplied; a falsy value (e.g. None) is ignored."""
    if not driver:
        return
    logger.info("Closing driver")
    driver.close()
|
|
|
|
# Global driver instance
|
|
# Starts as None; get_global_driver() assigns it on first use via get_driver().
_driver: Optional[Driver] = None
|
|
|
|
def get_global_driver() -> Optional[Driver]:
    """
    Return the shared module-level driver, creating it on first use.

    If no driver is cached yet, ``get_driver()`` is called with its defaults
    and the result (which may be None on failure) is stored and returned.
    """
    global _driver
    if _driver is not None:
        return _driver
    _driver = get_driver()
    return _driver
|
|
|
|
@contextmanager
def get_session(database: Optional[str] = None) -> Generator[Session, None, None]:
    """
    Context manager yielding a session opened on the global driver.

    Args:
        database: Name of the database to open the session against, or None.

    Yields:
        The opened session; it is closed when the ``with`` block exits,
        whether normally or through an exception.

    Raises:
        Exception: If the global driver could not be obtained.
    """
    shared_driver = get_global_driver()
    if shared_driver is None:
        raise Exception("Failed to get Neo4j driver")

    # Opened outside the try: if opening fails there is nothing to close.
    opened = shared_driver.session(database=database)
    try:
        yield opened
    finally:
        opened.close()