diff --git a/modules/database/tools/neo4j_driver_tools.py b/modules/database/tools/neo4j_driver_tools.py index 61d48f6..a7efa9a 100644 --- a/modules/database/tools/neo4j_driver_tools.py +++ b/modules/database/tools/neo4j_driver_tools.py @@ -18,7 +18,7 @@ def _retry_with_backoff( ) -> any: """ Helper function to retry operations with exponential backoff. - + Args: func: Function to retry max_attempts: Maximum number of retry attempts @@ -29,26 +29,26 @@ def _retry_with_backoff( attempt = 0 delay = initial_delay start_time = time.time() - + while attempt < max_attempts: try: return func() except Exception as e: attempt += 1 elapsed_time = time.time() - start_time - + # Check if we've exceeded the maximum total wait time if elapsed_time >= max_total_wait: logger.error(f"Exceeded maximum total wait time of {max_total_wait} seconds") raise - + if attempt == max_attempts: logger.error(f"Final attempt {attempt} failed: {e}") raise - + # Calculate next delay with exponential backoff, but cap it delay = min(delay * 2, max_delay) - + # If we're in a container initialization scenario, provide more context if "Connection refused" in str(e): logger.warning( @@ -59,7 +59,7 @@ def _retry_with_backoff( ) else: logger.warning(f"Attempt {attempt} failed: {e}. Retrying in {delay:.1f} seconds...") - + time.sleep(delay) def get_driver(db_name: Optional[str] = None, url: Optional[str] = None, auth: Optional[Tuple[str, str]] = None) -> Optional[Driver]: @@ -71,7 +71,7 @@ def get_driver(db_name: Optional[str] = None, url: Optional[str] = None, auth: O logger.error("Neo4j credentials not found in environment") return None auth = (username, password) - + if auth is None: logger.error("No authentication credentials provided") return None @@ -95,7 +95,7 @@ def get_driver(db_name: Optional[str] = None, url: Optional[str] = None, auth: O except Exception as e: logger.error(f"Failed to establish Neo4j connection after all retries: {e}") return None - + # Test the connection with the specific database if db_name and driver: def verify_database(): @@ -127,27 +127,50 @@ def close_driver(driver: Optional[Driver]) -> None: logger.info("Closing driver") driver.close() -# Global driver instance +# Global driver instance — None means not yet initialised, _driver_unavailable=True means connection failed _driver: Optional[Driver] = None +_driver_unavailable: bool = False def get_global_driver() -> Optional[Driver]: - """Get or create the global Neo4j driver instance.""" - global _driver + """Get or create the global Neo4j driver instance. + + Caches both success and failure so a broken Neo4j connection causes + a single 60-second retry at startup, then fast-fails on every + subsequent call instead of hanging for 60s each time. + """ + global _driver, _driver_unavailable + if _driver_unavailable: + return None if _driver is None: _driver = get_driver() + if _driver is None: + _driver_unavailable = True + logger.error("Neo4j driver unavailable — all subsequent Neo4j calls will fail fast until process restarts") return _driver +def reset_global_driver() -> None: + """Reset the cached driver, forcing a reconnection attempt on the next call. + + Call this if Neo4j becomes available after the process started. + """ + global _driver, _driver_unavailable + if _driver: + close_driver(_driver) + _driver = None + _driver_unavailable = False + logger.info("Global Neo4j driver reset — will reconnect on next request") + @contextmanager def get_session(database: Optional[str] = None) -> Generator[Session, None, None]: """Get a Neo4j session using the global driver.""" driver = get_global_driver() if driver is None: raise Exception("Failed to get Neo4j driver") - + session = None try: session = driver.session(database=database) yield session finally: if session: - session.close() \ No newline at end of file + session.close()