# api/modules/queue_system.py
"""
Robust Document Processing Queue System
This module provides a Redis-based queuing system for document processing tasks
to prevent server overload and handle concurrent processing efficiently.
Features:
- Priority-based queuing (high, normal, low)
- Rate limiting per service (Tika, Docling, LLM)
- Concurrent processing limits
- Task retry mechanism with exponential backoff
- Dead letter queue for failed tasks
- Health monitoring and metrics
- Graceful shutdown handling
"""
import json
import time
import uuid
import threading
from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass, asdict
from enum import Enum
import redis
import os
from modules.logger_tool import initialise_logger
from .redis_manager import get_redis_manager, Environment
logger = initialise_logger(__name__, os.getenv("LOG_LEVEL"), os.getenv("LOG_PATH"), 'default', True)
# Custom exceptions
class QueueConnectionError(Exception):
"""Raised when the queue system cannot connect to Redis"""
pass
class QueueFullError(Exception):
"""Raised when a service queue is at capacity"""
pass
class RateLimitError(Exception):
"""Raised when rate limits are exceeded"""
pass
class TaskPriority(Enum):
HIGH = "high" # Interactive uploads, user waiting
NORMAL = "normal" # Regular batch processing
LOW = "low" # Background/cleanup tasks
class TaskStatus(Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
RETRYING = "retrying"
DEAD = "dead"
class ServiceType(Enum):
TIKA = "tika"
DOCLING = "docling"
LLM = "llm"
SPLIT_MAP = "split_map"
DOCUMENT_ANALYSIS = "document_analysis"
PAGE_IMAGES = "page_images"
@dataclass
class QueueTask:
"""Represents a task in the processing queue."""
id: str
service: ServiceType
priority: TaskPriority
file_id: str
task_type: str # e.g., "tika_metadata", "docling_frontmatter", "llm_classify"
payload: Dict[str, Any]
created_at: float
    scheduled_at: Optional[float] = None  # For delayed/retry tasks; set in __post_init__
attempts: int = 0
max_attempts: int = 3
timeout: int = 300 # seconds
callback_url: Optional[str] = None
user_id: Optional[str] = None
def __post_init__(self):
if self.scheduled_at is None:
self.scheduled_at = time.time()
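# Illustrative construction (normally done by enqueue_task below; the id and
# payload are hypothetical). Enum fields are flattened to their .value strings
# and the payload is JSON-encoded before the task is written to a Redis hash:
#
#   task = QueueTask(
#       id=str(uuid.uuid4()),
#       service=ServiceType.LLM,
#       priority=TaskPriority.NORMAL,
#       file_id="file-123",
#       task_type="llm_classify",
#       payload={"model": "default"},
#       created_at=time.time(),
#   )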
class DocumentProcessingQueue:
"""Redis-based document processing queue with rate limiting and priorities."""
    def __init__(self, environment: Optional[str] = None, redis_url: Optional[str] = None):
"""Initialize queue with Redis connection using new Redis manager.
Args:
environment: 'dev', 'prod', or 'test' (auto-detected if not provided)
            redis_url: Legacy parameter kept for backward compatibility (ignored)
"""
# Auto-detect environment from startup mode if not provided
if not environment:
environment = 'dev' if os.getenv('BACKEND_DEV_MODE', 'true').lower() == 'true' else 'prod'
# Initialize Redis manager for this environment
self.redis_manager = get_redis_manager(environment)
# Initialize Redis environment (ensures service running and connects)
if not self.redis_manager.initialize_environment():
raise ConnectionError(f"Failed to initialize Redis for {environment} environment")
# Use the managed Redis client
self.redis_client = self.redis_manager.client
self.environment = environment
logger.info(f"🎯 Queue initialized for {environment} environment (db={self.redis_manager.config.db})")
# Queue configuration
self.service_limits = {
ServiceType.TIKA: int(os.getenv('QUEUE_TIKA_LIMIT', '3')),
ServiceType.DOCLING: int(os.getenv('QUEUE_DOCLING_LIMIT', '2')),
ServiceType.LLM: int(os.getenv('QUEUE_LLM_LIMIT', '5')),
ServiceType.SPLIT_MAP: int(os.getenv('QUEUE_SPLIT_MAP_LIMIT', '10')),
ServiceType.DOCUMENT_ANALYSIS: int(os.getenv('QUEUE_DOCUMENT_ANALYSIS_LIMIT', '5')),
ServiceType.PAGE_IMAGES: int(os.getenv('QUEUE_PAGE_IMAGES_LIMIT', '3'))
}
# Rate limiting (requests per minute)
self.rate_limits = {
ServiceType.TIKA: int(os.getenv('QUEUE_TIKA_RATE', '60')),
ServiceType.DOCLING: int(os.getenv('QUEUE_DOCLING_RATE', '30')),
ServiceType.LLM: int(os.getenv('QUEUE_LLM_RATE', '120')),
ServiceType.SPLIT_MAP: int(os.getenv('QUEUE_SPLIT_MAP_RATE', '100'))
}
# Queue names
self.queue_keys = {
priority: f"queue:{priority.value}" for priority in TaskPriority
}
self.processing_key = "processing"
self.dead_letter_key = "dead_letter"
self.metrics_key = "metrics"
# Worker control
self.workers_running = {}
self.shutdown_event = threading.Event()
logger.info(f"⚙️ Queue service limits: {dict(self.service_limits)}")
def _get_task_key(self, task_id: str) -> str:
"""Get Redis key for task data."""
return f"task:{task_id}"
def _get_service_processing_key(self, service: ServiceType) -> str:
"""Get Redis key for service processing counter."""
return f"processing:{service.value}"
def _get_rate_limit_key(self, service: ServiceType) -> str:
"""Get Redis key for rate limiting."""
return f"rate_limit:{service.value}:{int(time.time() // 60)}"
def enqueue_task(self,
service: ServiceType,
task_type: str,
file_id: str,
payload: Dict[str, Any],
priority: TaskPriority = TaskPriority.NORMAL,
                     user_id: Optional[str] = None,
timeout: int = 300,
max_attempts: int = 3) -> str:
"""
Enqueue a new processing task.
Returns:
str: Task ID
"""
task_id = str(uuid.uuid4())
task = QueueTask(
id=task_id,
service=service,
priority=priority,
file_id=file_id,
task_type=task_type,
payload=payload,
created_at=time.time(),
timeout=timeout,
max_attempts=max_attempts,
user_id=user_id
)
try:
# Store task data (convert enums to strings for Redis)
task_dict = asdict(task)
task_dict['service'] = task_dict['service'].value
task_dict['priority'] = task_dict['priority'].value
task_dict['payload'] = json.dumps(task_dict['payload'])
# Convert None values to empty strings for Redis
for key, value in task_dict.items():
if value is None:
task_dict[key] = ''
self.redis_client.hset(
self._get_task_key(task_id),
mapping=task_dict
)
self.redis_client.expire(self._get_task_key(task_id), 86400) # 24 hours TTL
# Add to priority queue
queue_key = self.queue_keys[priority]
self.redis_client.lpush(queue_key, task_id)
except redis.ConnectionError as e:
logger.error(f"Redis connection failed when enqueueing task: {e}")
logger.error("Please ensure Redis is running. Start the API server with './start.sh dev' to auto-start Redis.")
raise QueueConnectionError(f"Queue system unavailable: Redis connection failed")
except Exception as e:
logger.error(f"Failed to enqueue task {task_id}: {e}")
raise
# Update metrics
self._update_metrics("enqueued", service.value, priority.value)
logger.info(f"Enqueued task {task_id}: {service.value}/{task_type} for file {file_id}")
return task_id
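    # Example (illustrative; ids and payload keys are hypothetical): chaining
    # tasks via the optional "depends_on" list, which dequeue_task honours:
    #
    #   first = queue.enqueue_task(ServiceType.TIKA, "tika_metadata",
    #                              "file-123", {"path": "/uploads/file-123"})
    #   second = queue.enqueue_task(ServiceType.LLM, "llm_classify",
    #                               "file-123", {"depends_on": [first]})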
def dequeue_task(self, timeout: int = 10) -> Optional[QueueTask]:
"""
Dequeue next available task respecting service limits and rate limits.
Args:
timeout: Blocking timeout in seconds
Returns:
QueueTask or None if no task available
"""
# Check all priority queues in order
for priority in [TaskPriority.HIGH, TaskPriority.NORMAL, TaskPriority.LOW]:
queue_key = self.queue_keys[priority]
# Non-blocking pop to check availability
task_id = self.redis_client.rpop(queue_key)
if not task_id:
continue
# Get task data
task_data = self.redis_client.hgetall(self._get_task_key(task_id))
if not task_data:
logger.warning(f"Task {task_id} data not found, skipping")
continue
# Reconstruct task object (ignore non-dataclass keys like 'status', 'result', timestamps)
task_data['service'] = ServiceType(task_data['service'])
task_data['priority'] = TaskPriority(task_data['priority'])
task_data['created_at'] = float(task_data['created_at'])
task_data['scheduled_at'] = float(task_data['scheduled_at'])
task_data['attempts'] = int(task_data['attempts'])
task_data['max_attempts'] = int(task_data['max_attempts'])
task_data['timeout'] = int(task_data['timeout'])
task_data['payload'] = json.loads(task_data['payload'])
            # Drop extraneous keys that may be present in Redis
            for k in ['status', 'result', 'completed_at', 'failed_at', 'last_error', 'final_error']:
                task_data.pop(k, None)
            # Restore None for optional fields stored as empty strings at enqueue time
            for k in ['callback_url', 'user_id']:
                if task_data.get(k) == '':
                    task_data[k] = None
            task = QueueTask(**task_data)
# Check if task is ready (for delayed/retry tasks)
if task.scheduled_at > time.time():
# Put back in queue for later
self.redis_client.lpush(queue_key, task_id)
continue
# Enforce simple dependency ordering using optional depends_on array in payload
try:
depends_on = []
if isinstance(task.payload, dict):
depends_on = task.payload.get('depends_on') or []
if isinstance(depends_on, list) and len(depends_on) > 0:
logger.info(f"Checking {len(depends_on)} dependencies for task {task_id}")
unmet = []
for dep_id in depends_on:
if not dep_id:
continue
dep_key = self._get_task_key(dep_id)
dep_status = self.redis_client.hget(dep_key, 'status') or TaskStatus.PENDING.value
if dep_status != TaskStatus.COMPLETED.value:
unmet.append((dep_id, dep_status))
if len(unmet) > 0:
# Reschedule this task a bit later to avoid tight loops
next_time = time.time() + 10
self.redis_client.hset(
self._get_task_key(task_id),
mapping={'scheduled_at': next_time}
)
# Put back for later processing
self.redis_client.lpush(queue_key, task_id)
logger.info(f"Deferring task {task_id} due to unmet dependencies: {unmet}")
continue
else:
logger.info(f"All {len(depends_on)} dependencies satisfied for task {task_id}")
except Exception as dep_e:
logger.warning(f"Dependency check failed for task {task_id}: {dep_e}")
# Check service limits
service_processing_key = self._get_service_processing_key(task.service)
current_processing = int(self.redis_client.get(service_processing_key) or 0)
if current_processing >= self.service_limits[task.service]:
# Put back in queue with delay to prevent infinite loops
logger.warning(f"🚨 SERVICE LIMIT: Task {task_id} re-queued due to service limit exceeded: {current_processing}/{self.service_limits[task.service]} for {task.service.value}")
# Add delay before re-queueing to prevent tight loops
next_time = time.time() + 5 # Wait 5 seconds before retrying
self.redis_client.hset(
self._get_task_key(task_id),
mapping={'scheduled_at': next_time}
)
self.redis_client.lpush(queue_key, task_id)
continue
# Check rate limits
rate_key = self._get_rate_limit_key(task.service)
current_rate = int(self.redis_client.get(rate_key) or 0)
if current_rate >= self.rate_limits[task.service]:
# Put back in queue with delay to prevent infinite loops
logger.warning(f"🚨 RATE LIMIT: Task {task_id} re-queued due to rate limit exceeded: {current_rate}/{self.rate_limits[task.service]} for {task.service.value}")
# Add delay before re-queueing to prevent tight loops
next_time = time.time() + 60 # Wait 60 seconds for rate limit reset
self.redis_client.hset(
self._get_task_key(task_id),
mapping={'scheduled_at': next_time}
)
self.redis_client.lpush(queue_key, task_id)
continue
# Task can be processed
# Increment processing counter
self.redis_client.incr(service_processing_key)
self.redis_client.expire(service_processing_key, 3600) # 1 hour
# Increment rate limit counter
self.redis_client.incr(rate_key)
self.redis_client.expire(rate_key, 60) # 1 minute
# Add to processing set
self.redis_client.hset(
self.processing_key,
task_id,
json.dumps({
'service': task.service.value,
'started_at': time.time(),
'worker_id': threading.current_thread().ident
})
)
# Update metrics
self._update_metrics("dequeued", task.service.value, task.priority.value)
logger.debug(f"Dequeued task {task_id}: {task.service.value}/{task.task_type}")
return task
return None
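    # Sketch of a manual consumer loop (real workers use start_worker below;
    # handle() is a hypothetical processor):
    #
    #   while True:
    #       task = queue.dequeue_task(timeout=5)
    #       if task is None:
    #           continue
    #       try:
    #           result = handle(task)
    #           queue.complete_task(task, result)
    #       except Exception as e:
    #           queue.fail_task(task, str(e))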
def complete_task(self, task: QueueTask, result: Dict[str, Any] = None):
"""Mark task as completed and clean up."""
# Remove from processing
self.redis_client.hdel(self.processing_key, task.id)
# Decrement processing counter
service_processing_key = self._get_service_processing_key(task.service)
self.redis_client.decr(service_processing_key)
# Update task status
self.redis_client.hset(
self._get_task_key(task.id),
mapping={
'status': TaskStatus.COMPLETED.value,
'completed_at': time.time(),
'result': json.dumps(result or {})
}
)
# Update metrics
self._update_metrics("completed", task.service.value, task.priority.value)
logger.info(f"Completed task {task.id}: {task.service.value}/{task.task_type}")
    def fail_task(self, task: QueueTask, error: Union[str, Exception], retry: bool = True):
        """Handle task failure with retry logic.

        ``error`` may be a plain string or an exception; comparison tasks may
        attach retry hints (e.g. ``is_progress_retry``) to the exception.
        """
# Remove from processing
self.redis_client.hdel(self.processing_key, task.id)
# Decrement processing counter
service_processing_key = self._get_service_processing_key(task.service)
self.redis_client.decr(service_processing_key)
task.attempts += 1
if retry and task.attempts < task.max_attempts:
# Enhanced retry logic with progress-aware delays for comparison tasks
if task.task_type == 'docling_comparison_analysis':
# Special handling for comparison analysis tasks
payload = getattr(task, 'payload', {})
# Check if this is a progress-aware retry
if hasattr(error, 'is_progress_retry') and error.is_progress_retry:
# Active progress - shorter delay (30s to 2 minutes)
delay = min(120, 30 + task.attempts * 15)
logger.info(f"Comparison task {task.id}: Active progress detected, shorter retry delay: {delay}s")
elif hasattr(error, 'is_alignment_retry') and error.is_alignment_retry:
# Alignment issues - medium delay (1-3 minutes)
delay = min(180, 60 + task.attempts * 20)
logger.info(f"Comparison task {task.id}: Alignment retry, medium delay: {delay}s")
elif hasattr(error, 'is_stalled_retry') and error.is_stalled_retry:
# Stalled processing - longer delay (2-10 minutes)
delay = min(600, 120 + task.attempts * 30)
logger.info(f"Comparison task {task.id}: Stalled retry, longer delay: {delay}s")
else:
# Default comparison delay - extended for large PDFs (5-20 minutes)
delay = min(1200, 300 + task.attempts * 60)
logger.info(f"Comparison task {task.id}: Standard retry, extended delay: {delay}s")
# Update progress tracking in payload for next attempt
if hasattr(error, 'current_progress'):
payload['previous_progress'] = error.current_progress
# Update the task payload in Redis
self.redis_client.hset(
self._get_task_key(task.id),
'payload', json.dumps(payload)
)
else:
# Standard retry with exponential backoff for other tasks
delay = min(300, 2 ** task.attempts * 10) # Max 5 minutes
            task.scheduled_at = time.time() + delay
            # Task status is tracked only in the Redis hash (QueueTask has no status field)
# Update task data
self.redis_client.hset(
self._get_task_key(task.id),
mapping={
'attempts': task.attempts,
'scheduled_at': task.scheduled_at,
'status': TaskStatus.RETRYING.value,
                    'last_error': str(error)
}
)
# Re-queue for retry
queue_key = self.queue_keys[task.priority]
self.redis_client.lpush(queue_key, task.id)
# Update metrics
self._update_metrics("retried", task.service.value, task.priority.value)
logger.warning(f"Retrying task {task.id} in {delay}s (attempt {task.attempts}/{task.max_attempts}): {error}")
else:
            # Move to dead letter queue (status is recorded in the Redis hash below)
self.redis_client.hset(
self._get_task_key(task.id),
mapping={
'attempts': task.attempts,
'status': TaskStatus.DEAD.value,
'failed_at': time.time(),
                    'final_error': str(error)
}
)
self.redis_client.lpush(self.dead_letter_key, task.id)
# Update metrics
self._update_metrics("failed", task.service.value, task.priority.value)
logger.error(f"Task {task.id} moved to dead letter queue after {task.attempts} attempts: {error}")
def get_queue_stats(self) -> Dict[str, Any]:
"""Get comprehensive queue statistics."""
stats = {
'queues': {},
'processing': {},
'service_limits': dict(self.service_limits),
'rate_limits': dict(self.rate_limits),
'dead_letter_count': self.redis_client.llen(self.dead_letter_key)
}
# Queue lengths
for priority in TaskPriority:
queue_key = self.queue_keys[priority]
stats['queues'][priority.value] = self.redis_client.llen(queue_key)
# Processing counts
for service in ServiceType:
service_key = self._get_service_processing_key(service)
stats['processing'][service.value] = int(self.redis_client.get(service_key) or 0)
# Total processing
stats['total_processing'] = self.redis_client.hlen(self.processing_key)
return stats
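    # Illustrative return shape (numbers are made up):
    #
    #   {
    #       "queues": {"high": 0, "normal": 4, "low": 1},
    #       "processing": {"tika": 2, "docling": 1, ...},
    #       "service_limits": {...},
    #       "rate_limits": {...},
    #       "dead_letter_count": 0,
    #       "total_processing": 3,
    #   }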
def _update_metrics(self, action: str, service: str, priority: str):
"""Update queue metrics."""
timestamp = int(time.time())
metric_key = f"{self.metrics_key}:{action}:{service}:{priority}:{timestamp // 60}"
self.redis_client.incr(metric_key)
self.redis_client.expire(metric_key, 3600) # 1 hour
def start_worker(self, worker_id: str = None, services: List[ServiceType] = None):
"""Start a queue worker thread."""
if worker_id is None:
worker_id = f"worker-{uuid.uuid4().hex[:8]}"
if services is None:
services = list(ServiceType)
def worker_loop():
logger.info(f"Starting worker {worker_id} for services: {[s.value for s in services]}")
while not self.shutdown_event.is_set():
try:
task = self.dequeue_task(timeout=5)
if task is None:
continue
                    logger.debug(f"Worker {worker_id} dequeued task {task.id}: service={task.service.value}, task_type={task.task_type}")
                    if task.service not in services:
                        # This worker doesn't handle the task's service: remove the
                        # processing-set entry and return the task to its queue for a
                        # worker that does.
                        # NOTE: the service processing counter is intentionally NOT
                        # decremented here; decrementing caused negative counters when
                        # the correct worker later picked the task up.
                        logger.warning(f"Task {task.id} ({task.service.value}) not handled by worker {worker_id} (handles: {[s.value for s in services]}); re-queueing")
                        self.redis_client.hdel(self.processing_key, task.id)
                        queue_key = self.queue_keys[task.priority]
                        self.redis_client.lpush(queue_key, task.id)
                        continue
                    # Process the task
                    self._process_task(task)
                    logger.debug(f"Worker {worker_id} finished task {task.id}")
except Exception as e:
logger.error(f"Worker {worker_id} error: {e}")
time.sleep(1)
logger.info(f"Worker {worker_id} shutting down")
thread = threading.Thread(target=worker_loop, name=f"QueueWorker-{worker_id}")
thread.daemon = True
thread.start()
self.workers_running[worker_id] = thread
return worker_id
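    # Example (illustrative worker ids): dedicate a worker to a heavyweight
    # service and share one across the lighter ones:
    #
    #   queue.start_worker(worker_id="docling-1", services=[ServiceType.DOCLING])
    #   queue.start_worker(services=[ServiceType.TIKA, ServiceType.LLM])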
def _process_task(self, task: QueueTask):
"""Process a single task (to be overridden by specific implementations)."""
logger.warning(f"No processor implemented for task {task.id}: {task.service.value}/{task.task_type}")
self.fail_task(task, "No processor implemented", retry=False)
def shutdown(self, timeout: int = 30):
"""Gracefully shutdown all workers."""
logger.info("Shutting down queue workers...")
self.shutdown_event.set()
# Wait for workers to finish
for worker_id, thread in self.workers_running.items():
thread.join(timeout=timeout)
if thread.is_alive():
logger.warning(f"Worker {worker_id} did not shut down gracefully")
self.workers_running.clear()
logger.info("Queue shutdown complete")
# Global queue instance
_queue_instance: Optional[DocumentProcessingQueue] = None
def get_queue() -> DocumentProcessingQueue:
"""Get the global queue instance."""
global _queue_instance
if _queue_instance is None:
_queue_instance = DocumentProcessingQueue()
return _queue_instance
# Convenience functions for common operations
def enqueue_tika_task(file_id: str, payload: Dict[str, Any], priority: TaskPriority = TaskPriority.NORMAL) -> str:
"""Enqueue a Tika processing task."""
return get_queue().enqueue_task(
service=ServiceType.TIKA,
task_type="metadata_extraction",
file_id=file_id,
payload=payload,
priority=priority,
timeout=int(os.getenv('TIKA_TIMEOUT', '300'))
)
def enqueue_docling_task(file_id: str, task_type: str, payload: Dict[str, Any],
priority: TaskPriority = TaskPriority.NORMAL, timeout: int = 1800,
                         max_attempts: Optional[int] = None) -> str:
"""Enqueue a Docling processing task with intelligent retry limits."""
# Auto-configure max_attempts based on task type and payload
if max_attempts is None:
if task_type == 'docling_comparison_analysis':
# Use the max_retry_attempts from payload, or default to high limit for comparisons
max_attempts = payload.get('max_retry_attempts', 50)
else:
# Standard retry limit for other docling tasks
max_attempts = 3
return get_queue().enqueue_task(
service=ServiceType.DOCLING,
task_type=task_type,
file_id=file_id,
payload=payload,
priority=priority,
timeout=timeout,
max_attempts=max_attempts
)
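# Example (illustrative id and payload): a comparison analysis task takes its
# retry ceiling from the payload, while other docling tasks default to 3:
#
#   enqueue_docling_task(
#       file_id="file-123",
#       task_type="docling_comparison_analysis",
#       payload={"max_retry_attempts": 20},
#       timeout=3600,
#   )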
def enqueue_llm_task(file_id: str, task_type: str, payload: Dict[str, Any],
priority: TaskPriority = TaskPriority.NORMAL) -> str:
"""Enqueue an LLM processing task."""
return get_queue().enqueue_task(
service=ServiceType.LLM,
task_type=task_type,
file_id=file_id,
payload=payload,
priority=priority,
timeout=int(os.getenv('LLM_TIMEOUT', '180'))
)
def enqueue_split_map_task(file_id: str, payload: Dict[str, Any],
priority: TaskPriority = TaskPriority.NORMAL) -> str:
"""Enqueue a split map generation task."""
return get_queue().enqueue_task(
service=ServiceType.SPLIT_MAP,
task_type="generate_split_map",
file_id=file_id,
payload=payload,
priority=priority,
timeout=120
)
def enqueue_document_analysis_task(file_id: str, payload: Dict[str, Any],
priority: TaskPriority = TaskPriority.NORMAL) -> str:
"""Enqueue a document structure analysis task."""
return get_queue().enqueue_task(
service=ServiceType.DOCUMENT_ANALYSIS,
task_type="document_structure_analysis",
file_id=file_id,
payload=payload,
priority=priority,
timeout=int(os.getenv('DOCUMENT_ANALYSIS_TIMEOUT', '300'))
)
def enqueue_page_images_task(file_id: str, payload: Dict[str, Any],
priority: TaskPriority = TaskPriority.NORMAL) -> str:
"""Enqueue a page images generation task."""
return get_queue().enqueue_task(
service=ServiceType.PAGE_IMAGES,
task_type="generate_page_images",
file_id=file_id,
payload=payload,
priority=priority,
timeout=int(os.getenv('PAGE_IMAGES_TIMEOUT', '600'))
)