# (extraction artifacts removed)
"""
|
|
Robust Document Processing Queue System
|
|
|
|
This module provides a Redis-based queuing system for document processing tasks
|
|
to prevent server overload and handle concurrent processing efficiently.
|
|
|
|
Features:
|
|
- Priority-based queuing (high, normal, low)
|
|
- Rate limiting per service (Tika, Docling, LLM)
|
|
- Concurrent processing limits
|
|
- Task retry mechanism with exponential backoff
|
|
- Dead letter queue for failed tasks
|
|
- Health monitoring and metrics
|
|
- Graceful shutdown handling
|
|
"""
|
|
|
|
import json
|
|
import time
|
|
import uuid
|
|
import asyncio
|
|
import threading
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Any, Callable, Union
|
|
from urllib.parse import urlparse
|
|
from dataclasses import dataclass, asdict
|
|
from enum import Enum
|
|
import redis
|
|
import os
|
|
from modules.logger_tool import initialise_logger
|
|
from .redis_manager import get_redis_manager, Environment
|
|
|
|
# Module-wide logger; level and output path come from LOG_LEVEL / LOG_PATH env vars.
logger = initialise_logger(__name__, os.getenv("LOG_LEVEL"), os.getenv("LOG_PATH"), 'default', True)
|
|
|
|
# Custom exceptions
|
|
class QueueConnectionError(Exception):
    """The queue system could not reach its Redis backend."""
|
|
|
|
class QueueFullError(Exception):
    """A service queue has reached its capacity and cannot accept more tasks."""
|
|
|
|
class RateLimitError(Exception):
    """A per-service rate limit has been exceeded."""
|
|
|
|
class TaskPriority(Enum):
    """Relative urgency of a queued task; dequeue scans HIGH, then NORMAL, then LOW."""
    HIGH = "high"  # Interactive uploads, user waiting
    NORMAL = "normal"  # Regular batch processing
    LOW = "low"  # Background/cleanup tasks
|
|
|
|
class TaskStatus(Enum):
    """Lifecycle state of a task as recorded in its Redis hash."""
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    RETRYING = "retrying"
    # DEAD = exhausted all retry attempts; parked in the dead-letter queue
    DEAD = "dead"
|
|
|
|
class ServiceType(Enum):
    """Backend service a task is routed to; each has its own concurrency/rate limits."""
    TIKA = "tika"
    DOCLING = "docling"
    LLM = "llm"
    SPLIT_MAP = "split_map"
    DOCUMENT_ANALYSIS = "document_analysis"
    PAGE_IMAGES = "page_images"
|
|
|
|
@dataclass
class QueueTask:
    """Represents a task in the processing queue.

    Instances are serialized into a Redis hash (enums as their string values,
    payload as JSON) and reconstructed by dequeue_task.
    """
    id: str  # uuid4 string assigned at enqueue time
    service: ServiceType
    priority: TaskPriority
    file_id: str
    task_type: str  # e.g., "tika_metadata", "docling_frontmatter", "llm_classify"
    payload: Dict[str, Any]
    created_at: float  # unix timestamp (time.time())
    scheduled_at: Optional[float] = None  # For delayed/retry tasks; filled in __post_init__
    attempts: int = 0
    max_attempts: int = 3
    timeout: int = 300  # seconds
    callback_url: Optional[str] = None
    user_id: Optional[str] = None

    def __post_init__(self) -> None:
        # A task with no explicit schedule is runnable immediately.
        if self.scheduled_at is None:
            self.scheduled_at = time.time()
|
|
|
|
class DocumentProcessingQueue:
    """Redis-based document processing queue with rate limiting and priorities."""

    def __init__(self, environment: str = None, redis_url: str = None):
        """Initialize queue with Redis connection using new Redis manager.

        Args:
            environment: 'dev', 'prod', or 'test' (auto-detected if not provided)
            redis_url: Legacy parameter for backward compatibility (currently unused)

        Raises:
            ConnectionError: if the Redis environment cannot be initialized.
        """
        # Auto-detect environment from startup mode if not provided
        if not environment:
            environment = 'dev' if os.getenv('BACKEND_DEV_MODE', 'true').lower() == 'true' else 'prod'

        # Initialize Redis manager for this environment
        self.redis_manager = get_redis_manager(environment)

        # Initialize Redis environment (ensures service running and connects)
        if not self.redis_manager.initialize_environment():
            raise ConnectionError(f"Failed to initialize Redis for {environment} environment")

        # Use the managed Redis client
        self.redis_client = self.redis_manager.client
        self.environment = environment

        logger.info(f"🎯 Queue initialized for {environment} environment (db={self.redis_manager.config.db})")

        # Maximum concurrent in-flight tasks per backing service
        self.service_limits = {
            ServiceType.TIKA: int(os.getenv('QUEUE_TIKA_LIMIT', '3')),
            ServiceType.DOCLING: int(os.getenv('QUEUE_DOCLING_LIMIT', '2')),
            ServiceType.LLM: int(os.getenv('QUEUE_LLM_LIMIT', '5')),
            ServiceType.SPLIT_MAP: int(os.getenv('QUEUE_SPLIT_MAP_LIMIT', '10')),
            ServiceType.DOCUMENT_ANALYSIS: int(os.getenv('QUEUE_DOCUMENT_ANALYSIS_LIMIT', '5')),
            ServiceType.PAGE_IMAGES: int(os.getenv('QUEUE_PAGE_IMAGES_LIMIT', '3'))
        }

        # Rate limiting (requests per minute).
        # FIX: DOCUMENT_ANALYSIS and PAGE_IMAGES previously had no entry here,
        # so dequeue_task raised KeyError on self.rate_limits[task.service]
        # whenever a task for one of those services was popped.
        self.rate_limits = {
            ServiceType.TIKA: int(os.getenv('QUEUE_TIKA_RATE', '60')),
            ServiceType.DOCLING: int(os.getenv('QUEUE_DOCLING_RATE', '30')),
            ServiceType.LLM: int(os.getenv('QUEUE_LLM_RATE', '120')),
            ServiceType.SPLIT_MAP: int(os.getenv('QUEUE_SPLIT_MAP_RATE', '100')),
            ServiceType.DOCUMENT_ANALYSIS: int(os.getenv('QUEUE_DOCUMENT_ANALYSIS_RATE', '60')),
            ServiceType.PAGE_IMAGES: int(os.getenv('QUEUE_PAGE_IMAGES_RATE', '60'))
        }

        # Redis key names: one list per priority, plus bookkeeping keys
        self.queue_keys = {
            priority: f"queue:{priority.value}" for priority in TaskPriority
        }
        self.processing_key = "processing"
        self.dead_letter_key = "dead_letter"
        self.metrics_key = "metrics"

        # Worker control
        self.workers_running = {}
        self.shutdown_event = threading.Event()

        logger.info(f"⚙️ Queue service limits: {dict(self.service_limits)}")
|
|
|
|
def _get_task_key(self, task_id: str) -> str:
|
|
"""Get Redis key for task data."""
|
|
return f"task:{task_id}"
|
|
|
|
def _get_service_processing_key(self, service: ServiceType) -> str:
|
|
"""Get Redis key for service processing counter."""
|
|
return f"processing:{service.value}"
|
|
|
|
def _get_rate_limit_key(self, service: ServiceType) -> str:
|
|
"""Get Redis key for rate limiting."""
|
|
return f"rate_limit:{service.value}:{int(time.time() // 60)}"
|
|
|
|
def enqueue_task(self,
|
|
service: ServiceType,
|
|
task_type: str,
|
|
file_id: str,
|
|
payload: Dict[str, Any],
|
|
priority: TaskPriority = TaskPriority.NORMAL,
|
|
user_id: str = None,
|
|
timeout: int = 300,
|
|
max_attempts: int = 3) -> str:
|
|
"""
|
|
Enqueue a new processing task.
|
|
|
|
Returns:
|
|
str: Task ID
|
|
"""
|
|
task_id = str(uuid.uuid4())
|
|
task = QueueTask(
|
|
id=task_id,
|
|
service=service,
|
|
priority=priority,
|
|
file_id=file_id,
|
|
task_type=task_type,
|
|
payload=payload,
|
|
created_at=time.time(),
|
|
timeout=timeout,
|
|
max_attempts=max_attempts,
|
|
user_id=user_id
|
|
)
|
|
|
|
try:
|
|
# Store task data (convert enums to strings for Redis)
|
|
task_dict = asdict(task)
|
|
task_dict['service'] = task_dict['service'].value
|
|
task_dict['priority'] = task_dict['priority'].value
|
|
task_dict['payload'] = json.dumps(task_dict['payload'])
|
|
# Convert None values to empty strings for Redis
|
|
for key, value in task_dict.items():
|
|
if value is None:
|
|
task_dict[key] = ''
|
|
|
|
self.redis_client.hset(
|
|
self._get_task_key(task_id),
|
|
mapping=task_dict
|
|
)
|
|
self.redis_client.expire(self._get_task_key(task_id), 86400) # 24 hours TTL
|
|
|
|
# Add to priority queue
|
|
queue_key = self.queue_keys[priority]
|
|
self.redis_client.lpush(queue_key, task_id)
|
|
|
|
except redis.ConnectionError as e:
|
|
logger.error(f"Redis connection failed when enqueueing task: {e}")
|
|
logger.error("Please ensure Redis is running. Start the API server with './start.sh dev' to auto-start Redis.")
|
|
raise QueueConnectionError(f"Queue system unavailable: Redis connection failed")
|
|
except Exception as e:
|
|
logger.error(f"Failed to enqueue task {task_id}: {e}")
|
|
raise
|
|
|
|
# Update metrics
|
|
self._update_metrics("enqueued", service.value, priority.value)
|
|
|
|
logger.info(f"Enqueued task {task_id}: {service.value}/{task_type} for file {file_id}")
|
|
return task_id
|
|
|
|
    def dequeue_task(self, timeout: int = 10) -> Optional[QueueTask]:
        """
        Dequeue next available task respecting service limits and rate limits.

        Scans the HIGH, NORMAL then LOW priority lists with non-blocking pops.
        A popped task is returned only if it is due (scheduled_at has passed),
        all payload 'depends_on' dependencies are COMPLETED, and neither the
        per-service concurrency limit nor the per-minute rate limit is hit;
        otherwise it is pushed back with a deferred scheduled_at.

        Args:
            timeout: Blocking timeout in seconds
                     (NOTE(review): currently unused — rpop is non-blocking; confirm intent)

        Returns:
            QueueTask or None if no task available
        """
        # Check all priority queues in order
        for priority in [TaskPriority.HIGH, TaskPriority.NORMAL, TaskPriority.LOW]:
            queue_key = self.queue_keys[priority]

            # Non-blocking pop to check availability
            task_id = self.redis_client.rpop(queue_key)
            if not task_id:
                continue

            # Get task data
            task_data = self.redis_client.hgetall(self._get_task_key(task_id))
            if not task_data:
                # Hash expired (24h TTL) or was never written; drop the stale id
                logger.warning(f"Task {task_id} data not found, skipping")
                continue

            # Reconstruct task object (ignore non-dataclass keys like 'status', 'result', timestamps)
            task_data['service'] = ServiceType(task_data['service'])
            task_data['priority'] = TaskPriority(task_data['priority'])
            task_data['created_at'] = float(task_data['created_at'])
            task_data['scheduled_at'] = float(task_data['scheduled_at'])
            task_data['attempts'] = int(task_data['attempts'])
            task_data['max_attempts'] = int(task_data['max_attempts'])
            task_data['timeout'] = int(task_data['timeout'])
            task_data['payload'] = json.loads(task_data['payload'])

            # Drop extraneous keys that may be present in Redis
            for k in ['status', 'result', 'completed_at', 'failed_at', 'last_error', 'final_error']:
                task_data.pop(k, None)

            task = QueueTask(**task_data)

            # Check if task is ready (for delayed/retry tasks)
            if task.scheduled_at > time.time():
                # Put back in queue for later
                self.redis_client.lpush(queue_key, task_id)
                continue

            # Enforce simple dependency ordering using optional depends_on array in payload
            try:
                depends_on = []
                if isinstance(task.payload, dict):
                    depends_on = task.payload.get('depends_on') or []
                if isinstance(depends_on, list) and len(depends_on) > 0:
                    logger.info(f"Checking {len(depends_on)} dependencies for task {task_id}")
                    unmet = []
                    for dep_id in depends_on:
                        if not dep_id:
                            continue
                        dep_key = self._get_task_key(dep_id)
                        # A dependency with no recorded status counts as PENDING (unmet)
                        dep_status = self.redis_client.hget(dep_key, 'status') or TaskStatus.PENDING.value
                        if dep_status != TaskStatus.COMPLETED.value:
                            unmet.append((dep_id, dep_status))
                    if len(unmet) > 0:
                        # Reschedule this task a bit later to avoid tight loops
                        next_time = time.time() + 10
                        self.redis_client.hset(
                            self._get_task_key(task_id),
                            mapping={'scheduled_at': next_time}
                        )
                        # Put back for later processing
                        self.redis_client.lpush(queue_key, task_id)
                        logger.info(f"Deferring task {task_id} due to unmet dependencies: {unmet}")
                        continue
                    else:
                        logger.info(f"All {len(depends_on)} dependencies satisfied for task {task_id}")
            except Exception as dep_e:
                # Best-effort: a broken dependency check must not block dequeueing
                logger.warning(f"Dependency check failed for task {task_id}: {dep_e}")

            # Check service limits
            service_processing_key = self._get_service_processing_key(task.service)
            current_processing = int(self.redis_client.get(service_processing_key) or 0)

            if current_processing >= self.service_limits[task.service]:
                # Put back in queue with delay to prevent infinite loops
                logger.warning(f"🚨 SERVICE LIMIT: Task {task_id} re-queued due to service limit exceeded: {current_processing}/{self.service_limits[task.service]} for {task.service.value}")

                # Add delay before re-queueing to prevent tight loops
                next_time = time.time() + 5  # Wait 5 seconds before retrying
                self.redis_client.hset(
                    self._get_task_key(task_id),
                    mapping={'scheduled_at': next_time}
                )
                self.redis_client.lpush(queue_key, task_id)
                continue

            # Check rate limits
            rate_key = self._get_rate_limit_key(task.service)
            current_rate = int(self.redis_client.get(rate_key) or 0)

            if current_rate >= self.rate_limits[task.service]:
                # Put back in queue with delay to prevent infinite loops
                logger.warning(f"🚨 RATE LIMIT: Task {task_id} re-queued due to rate limit exceeded: {current_rate}/{self.rate_limits[task.service]} for {task.service.value}")

                # Add delay before re-queueing to prevent tight loops
                next_time = time.time() + 60  # Wait 60 seconds for rate limit reset
                self.redis_client.hset(
                    self._get_task_key(task_id),
                    mapping={'scheduled_at': next_time}
                )
                self.redis_client.lpush(queue_key, task_id)
                continue

            # Task can be processed
            # Increment processing counter
            self.redis_client.incr(service_processing_key)
            self.redis_client.expire(service_processing_key, 3600)  # 1 hour

            # Increment rate limit counter
            self.redis_client.incr(rate_key)
            self.redis_client.expire(rate_key, 60)  # 1 minute

            # Add to processing set
            self.redis_client.hset(
                self.processing_key,
                task_id,
                json.dumps({
                    'service': task.service.value,
                    'started_at': time.time(),
                    'worker_id': threading.current_thread().ident
                })
            )

            # Update metrics
            self._update_metrics("dequeued", task.service.value, task.priority.value)

            logger.debug(f"Dequeued task {task_id}: {task.service.value}/{task.task_type}")
            return task

        # All priority queues empty or no eligible task
        return None
|
|
|
|
def complete_task(self, task: QueueTask, result: Dict[str, Any] = None):
|
|
"""Mark task as completed and clean up."""
|
|
# Remove from processing
|
|
self.redis_client.hdel(self.processing_key, task.id)
|
|
|
|
# Decrement processing counter
|
|
service_processing_key = self._get_service_processing_key(task.service)
|
|
self.redis_client.decr(service_processing_key)
|
|
|
|
# Update task status
|
|
self.redis_client.hset(
|
|
self._get_task_key(task.id),
|
|
mapping={
|
|
'status': TaskStatus.COMPLETED.value,
|
|
'completed_at': time.time(),
|
|
'result': json.dumps(result or {})
|
|
}
|
|
)
|
|
|
|
# Update metrics
|
|
self._update_metrics("completed", task.service.value, task.priority.value)
|
|
|
|
logger.info(f"Completed task {task.id}: {task.service.value}/{task.task_type}")
|
|
|
|
def fail_task(self, task: QueueTask, error: str, retry: bool = True):
|
|
"""Handle task failure with retry logic."""
|
|
# Remove from processing
|
|
self.redis_client.hdel(self.processing_key, task.id)
|
|
|
|
# Decrement processing counter
|
|
service_processing_key = self._get_service_processing_key(task.service)
|
|
self.redis_client.decr(service_processing_key)
|
|
|
|
task.attempts += 1
|
|
|
|
if retry and task.attempts < task.max_attempts:
|
|
# Enhanced retry logic with progress-aware delays for comparison tasks
|
|
if task.task_type == 'docling_comparison_analysis':
|
|
# Special handling for comparison analysis tasks
|
|
payload = getattr(task, 'payload', {})
|
|
|
|
# Check if this is a progress-aware retry
|
|
if hasattr(error, 'is_progress_retry') and error.is_progress_retry:
|
|
# Active progress - shorter delay (30s to 2 minutes)
|
|
delay = min(120, 30 + task.attempts * 15)
|
|
logger.info(f"Comparison task {task.id}: Active progress detected, shorter retry delay: {delay}s")
|
|
elif hasattr(error, 'is_alignment_retry') and error.is_alignment_retry:
|
|
# Alignment issues - medium delay (1-3 minutes)
|
|
delay = min(180, 60 + task.attempts * 20)
|
|
logger.info(f"Comparison task {task.id}: Alignment retry, medium delay: {delay}s")
|
|
elif hasattr(error, 'is_stalled_retry') and error.is_stalled_retry:
|
|
# Stalled processing - longer delay (2-10 minutes)
|
|
delay = min(600, 120 + task.attempts * 30)
|
|
logger.info(f"Comparison task {task.id}: Stalled retry, longer delay: {delay}s")
|
|
else:
|
|
# Default comparison delay - extended for large PDFs (5-20 minutes)
|
|
delay = min(1200, 300 + task.attempts * 60)
|
|
logger.info(f"Comparison task {task.id}: Standard retry, extended delay: {delay}s")
|
|
|
|
# Update progress tracking in payload for next attempt
|
|
if hasattr(error, 'current_progress'):
|
|
payload['previous_progress'] = error.current_progress
|
|
# Update the task payload in Redis
|
|
import json
|
|
self.redis_client.hset(
|
|
self._get_task_key(task.id),
|
|
'payload', json.dumps(payload)
|
|
)
|
|
else:
|
|
# Standard retry with exponential backoff for other tasks
|
|
delay = min(300, 2 ** task.attempts * 10) # Max 5 minutes
|
|
|
|
task.scheduled_at = time.time() + delay
|
|
task.status = TaskStatus.RETRYING
|
|
|
|
# Update task data
|
|
self.redis_client.hset(
|
|
self._get_task_key(task.id),
|
|
mapping={
|
|
'attempts': task.attempts,
|
|
'scheduled_at': task.scheduled_at,
|
|
'status': TaskStatus.RETRYING.value,
|
|
'last_error': error
|
|
}
|
|
)
|
|
|
|
# Re-queue for retry
|
|
queue_key = self.queue_keys[task.priority]
|
|
self.redis_client.lpush(queue_key, task.id)
|
|
|
|
# Update metrics
|
|
self._update_metrics("retried", task.service.value, task.priority.value)
|
|
|
|
logger.warning(f"Retrying task {task.id} in {delay}s (attempt {task.attempts}/{task.max_attempts}): {error}")
|
|
else:
|
|
# Move to dead letter queue
|
|
task.status = TaskStatus.DEAD
|
|
|
|
self.redis_client.hset(
|
|
self._get_task_key(task.id),
|
|
mapping={
|
|
'attempts': task.attempts,
|
|
'status': TaskStatus.DEAD.value,
|
|
'failed_at': time.time(),
|
|
'final_error': error
|
|
}
|
|
)
|
|
|
|
self.redis_client.lpush(self.dead_letter_key, task.id)
|
|
|
|
# Update metrics
|
|
self._update_metrics("failed", task.service.value, task.priority.value)
|
|
|
|
logger.error(f"Task {task.id} moved to dead letter queue after {task.attempts} attempts: {error}")
|
|
|
|
def get_queue_stats(self) -> Dict[str, Any]:
|
|
"""Get comprehensive queue statistics."""
|
|
stats = {
|
|
'queues': {},
|
|
'processing': {},
|
|
'service_limits': dict(self.service_limits),
|
|
'rate_limits': dict(self.rate_limits),
|
|
'dead_letter_count': self.redis_client.llen(self.dead_letter_key)
|
|
}
|
|
|
|
# Queue lengths
|
|
for priority in TaskPriority:
|
|
queue_key = self.queue_keys[priority]
|
|
stats['queues'][priority.value] = self.redis_client.llen(queue_key)
|
|
|
|
# Processing counts
|
|
for service in ServiceType:
|
|
service_key = self._get_service_processing_key(service)
|
|
stats['processing'][service.value] = int(self.redis_client.get(service_key) or 0)
|
|
|
|
# Total processing
|
|
stats['total_processing'] = self.redis_client.hlen(self.processing_key)
|
|
|
|
return stats
|
|
|
|
def _update_metrics(self, action: str, service: str, priority: str):
|
|
"""Update queue metrics."""
|
|
timestamp = int(time.time())
|
|
metric_key = f"{self.metrics_key}:{action}:{service}:{priority}:{timestamp // 60}"
|
|
self.redis_client.incr(metric_key)
|
|
self.redis_client.expire(metric_key, 3600) # 1 hour
|
|
|
|
def start_worker(self, worker_id: str = None, services: List[ServiceType] = None):
|
|
"""Start a queue worker thread."""
|
|
if worker_id is None:
|
|
worker_id = f"worker-{uuid.uuid4().hex[:8]}"
|
|
|
|
if services is None:
|
|
services = list(ServiceType)
|
|
|
|
def worker_loop():
|
|
logger.info(f"Starting worker {worker_id} for services: {[s.value for s in services]}")
|
|
|
|
while not self.shutdown_event.is_set():
|
|
try:
|
|
task = self.dequeue_task(timeout=5)
|
|
if task is None:
|
|
continue
|
|
|
|
# DEBUG: Log task details immediately after dequeue
|
|
logger.info(f"🔍 WORKER DEBUG: Dequeued task {task.id}, service={task.service.value}, task_type={task.task_type}")
|
|
logger.info(f"🔍 WORKER DEBUG: Worker {worker_id} handles services: {[s.value for s in services]}")
|
|
|
|
if task.service not in services:
|
|
# Put back in queue if worker doesn't handle this service
|
|
# But first clean up the processing state to avoid infinite loops
|
|
logger.warning(f"🚨 WORKER DEBUG: Task {task.id} service {task.service.value} NOT in worker services {[s.value for s in services]}")
|
|
self.redis_client.hdel(self.processing_key, task.id)
|
|
# NOTE: Do NOT decrement service_processing_key here - task will be re-processed by correct worker
|
|
# service_processing_key = self._get_service_processing_key(task.service)
|
|
# self.redis_client.decr(service_processing_key) # REMOVED: Caused negative counters
|
|
|
|
queue_key = self.queue_keys[task.priority]
|
|
self.redis_client.lpush(queue_key, task.id)
|
|
logger.warning(f"🚨 WORKER DEBUG: Task {task.id} re-queued due to service mismatch")
|
|
continue
|
|
|
|
# DEBUG: Confirm we're about to process
|
|
logger.info(f"✅ WORKER DEBUG: About to process task {task.id}")
|
|
|
|
# Process the task
|
|
self._process_task(task)
|
|
|
|
# DEBUG: Confirm processing completed
|
|
logger.info(f"✅ WORKER DEBUG: Finished processing task {task.id}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Worker {worker_id} error: {e}")
|
|
time.sleep(1)
|
|
|
|
logger.info(f"Worker {worker_id} shutting down")
|
|
|
|
thread = threading.Thread(target=worker_loop, name=f"QueueWorker-{worker_id}")
|
|
thread.daemon = True
|
|
thread.start()
|
|
|
|
self.workers_running[worker_id] = thread
|
|
return worker_id
|
|
|
|
    def _process_task(self, task: QueueTask):
        """Process a single task (to be overridden by specific implementations).

        The base implementation immediately dead-letters the task (retry=False)
        so that unhandled task types cannot loop forever through the queue.
        """
        logger.warning(f"No processor implemented for task {task.id}: {task.service.value}/{task.task_type}")
        self.fail_task(task, "No processor implemented", retry=False)
|
|
|
|
def shutdown(self, timeout: int = 30):
|
|
"""Gracefully shutdown all workers."""
|
|
logger.info("Shutting down queue workers...")
|
|
self.shutdown_event.set()
|
|
|
|
# Wait for workers to finish
|
|
for worker_id, thread in self.workers_running.items():
|
|
thread.join(timeout=timeout)
|
|
if thread.is_alive():
|
|
logger.warning(f"Worker {worker_id} did not shut down gracefully")
|
|
|
|
self.workers_running.clear()
|
|
logger.info("Queue shutdown complete")
|
|
|
|
# Global queue instance (lazily created by get_queue)
_queue_instance = None


def get_queue() -> DocumentProcessingQueue:
    """Return the process-wide queue singleton, creating it on first use."""
    global _queue_instance
    if _queue_instance is not None:
        return _queue_instance
    _queue_instance = DocumentProcessingQueue()
    return _queue_instance
|
|
|
|
# Convenience functions for common operations
|
|
def enqueue_tika_task(file_id: str, payload: Dict[str, Any], priority: TaskPriority = TaskPriority.NORMAL) -> str:
    """Enqueue a Tika metadata-extraction task and return its task id."""
    tika_timeout = int(os.getenv('TIKA_TIMEOUT', '300'))
    queue = get_queue()
    return queue.enqueue_task(
        service=ServiceType.TIKA,
        task_type="metadata_extraction",
        file_id=file_id,
        payload=payload,
        priority=priority,
        timeout=tika_timeout,
    )
|
|
|
|
def enqueue_docling_task(file_id: str, task_type: str, payload: Dict[str, Any],
                         priority: TaskPriority = TaskPriority.NORMAL, timeout: int = 1800,
                         max_attempts: int = None) -> str:
    """Enqueue a Docling processing task with intelligent retry limits."""
    if max_attempts is None:
        # Comparison analyses may legitimately retry many times while a long
        # conversion progresses; all other Docling tasks get the standard cap.
        if task_type == 'docling_comparison_analysis':
            max_attempts = payload.get('max_retry_attempts', 50)
        else:
            max_attempts = 3

    queue = get_queue()
    return queue.enqueue_task(
        service=ServiceType.DOCLING,
        task_type=task_type,
        file_id=file_id,
        payload=payload,
        priority=priority,
        timeout=timeout,
        max_attempts=max_attempts,
    )
|
|
|
|
def enqueue_llm_task(file_id: str, task_type: str, payload: Dict[str, Any],
                     priority: TaskPriority = TaskPriority.NORMAL) -> str:
    """Enqueue an LLM processing task and return its task id."""
    llm_timeout = int(os.getenv('LLM_TIMEOUT', '180'))
    queue = get_queue()
    return queue.enqueue_task(
        service=ServiceType.LLM,
        task_type=task_type,
        file_id=file_id,
        payload=payload,
        priority=priority,
        timeout=llm_timeout,
    )
|
|
|
|
def enqueue_split_map_task(file_id: str, payload: Dict[str, Any],
                           priority: TaskPriority = TaskPriority.NORMAL) -> str:
    """Enqueue a split-map generation task and return its task id."""
    queue = get_queue()
    return queue.enqueue_task(
        service=ServiceType.SPLIT_MAP,
        task_type="generate_split_map",
        file_id=file_id,
        payload=payload,
        priority=priority,
        timeout=120,  # split maps are quick; short fixed timeout
    )
|
|
|
|
def enqueue_document_analysis_task(file_id: str, payload: Dict[str, Any],
                                   priority: TaskPriority = TaskPriority.NORMAL) -> str:
    """Enqueue a document structure analysis task and return its task id."""
    analysis_timeout = int(os.getenv('DOCUMENT_ANALYSIS_TIMEOUT', '300'))
    queue = get_queue()
    return queue.enqueue_task(
        service=ServiceType.DOCUMENT_ANALYSIS,
        task_type="document_structure_analysis",
        file_id=file_id,
        payload=payload,
        priority=priority,
        timeout=analysis_timeout,
    )
|
|
|
|
def enqueue_page_images_task(file_id: str, payload: Dict[str, Any],
                             priority: TaskPriority = TaskPriority.NORMAL) -> str:
    """Enqueue a page-images generation task and return its task id."""
    images_timeout = int(os.getenv('PAGE_IMAGES_TIMEOUT', '600'))
    queue = get_queue()
    return queue.enqueue_task(
        service=ServiceType.PAGE_IMAGES,
        task_type="generate_page_images",
        file_id=file_id,
        payload=payload,
        priority=priority,
        timeout=images_timeout,
    )
|