diff --git a/routers/database/timetable/crud/base.py b/routers/database/timetable/crud/base.py new file mode 100644 index 0000000..ee53858 --- /dev/null +++ b/routers/database/timetable/crud/base.py @@ -0,0 +1,113 @@ +"""Base CRUD class for timetable operations.""" +from typing import List, Optional, Dict, Any +from uuid import UUID +from supabase import Client + +class BaseCRUD: + def __init__(self, table_name: str): + self.table_name = table_name + + async def get_by_id(self, db: Client, id: UUID) -> Optional[Dict[str, Any]]: + """Get record by ID.""" + result = db.table(self.table_name).select("*").eq("id", str(id)).execute() + return result.data[0] if result.data else None + + async def get_by_ids(self, db: Client, ids: List[UUID]) -> List[Dict[str, Any]]: + """Get multiple records by IDs.""" + str_ids = [str(id) for id in ids] + result = db.table(self.table_name).select("*").in_("id", str_ids).execute() + return result.data or [] + + async def get_multi( + self, db: Client, + skip: int = 0, + limit: int = 100, + order_by: str = "created_at", + ascending: bool = False, + filters: Optional[Dict[str, Any]] = None + ) -> List[Dict[str, Any]]: + """Alias for get_all - get all records with pagination and optional filtering.""" + return await self.get_all(db, skip=skip, limit=limit, order_by=order_by, ascending=ascending, filters=filters) + + async def get_all( + self, db: Client, + skip: int = 0, + limit: int = 100, + order_by: str = "created_at", + ascending: bool = False, + filters: Optional[Dict[str, Any]] = None + ) -> List[Dict[str, Any]]: + """Get all records with pagination and optional filtering.""" + query = db.table(self.table_name).select("*") + + # Apply filters if provided + if filters: + for key, value in filters.items(): + if value is not None: + if isinstance(value, UUID): + query = query.eq(key, str(value)) + elif isinstance(value, list): + query = query.in_(key, [str(v) if isinstance(v, UUID) else v for v in value]) + else: + query = query.eq(key, value) + + query = query.order(order_by, desc=not ascending) + query = query.range(skip, skip + limit - 1) + result = query.execute() + return result.data or [] + + async def count(self, db: Client, filters: Optional[Dict[str, Any]] = None) -> int: + """Count records with optional filters.""" + query = db.table(self.table_name).select("id", count="exact") + if filters: + for key, value in filters.items(): + query = query.eq(key, str(value) if isinstance(value, UUID) else value) + result = query.execute() + return result.count if hasattr(result, 'count') else len(result.data) + + async def create(self, db: Client, data: Dict[str, Any]) -> Dict[str, Any]: + """Create new record.""" + # Convert UUIDs to strings + processed_data = {} + for key, value in data.items(): + if isinstance(value, UUID): + processed_data[key] = str(value) + elif isinstance(value, list): + processed_data[key] = [ + str(v) if isinstance(v, UUID) else v for v in value + ] + else: + processed_data[key] = value + + result = db.table(self.table_name).insert(processed_data).execute() + return result.data[0] if result.data else processed_data + + async def update( + self, db: Client, id: UUID, data: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + """Update existing record.""" + # Convert UUIDs to strings and remove None values + processed_data = {} + for key, value in data.items(): + if value is not None: + if isinstance(value, UUID): + processed_data[key] = str(value) + elif isinstance(value, list): + processed_data[key] = [ + str(v) if isinstance(v, UUID) else v for v in value + ] + else: + processed_data[key] = value + + if not processed_data: + return await self.get_by_id(db, id) + + result = db.table(self.table_name) \ + .update(processed_data) \ + .eq("id", str(id)).execute() + return result.data[0] if result.data else None + + async def delete(self, db: Client, id: UUID) -> bool: + """Delete record by ID.""" + result = db.table(self.table_name).delete().eq("id", str(id)).execute() + return bool(result.data)