fix: Add filters parameter to BaseCRUD.get_multi() method

- Fixed signature mismatch where enrollment_requests router was passing
  filters parameter to get_multi() but method didn't accept it
- get_multi() now accepts optional filters dict and passes it to get_all()
This commit is contained in:
Classroom Copilot Dev 2026-02-25 22:53:52 +00:00
parent f5eacab946
commit d68b63cedb

View File

@ -0,0 +1,113 @@
"""Base CRUD class for timetable operations."""
from typing import List, Optional, Dict, Any
from uuid import UUID
from supabase import Client
class BaseCRUD:
def __init__(self, table_name: str):
self.table_name = table_name
async def get_by_id(self, db: Client, id: UUID) -> Optional[Dict[str, Any]]:
"""Get record by ID."""
result = db.table(self.table_name).select("*").eq("id", str(id)).execute()
return result.data[0] if result.data else None
async def get_by_ids(self, db: Client, ids: List[UUID]) -> List[Dict[str, Any]]:
"""Get multiple records by IDs."""
str_ids = [str(id) for id in ids]
result = db.table(self.table_name).select("*").in_("id", str_ids).execute()
return result.data or []
async def get_multi(
self, db: Client,
skip: int = 0,
limit: int = 100,
order_by: str = "created_at",
ascending: bool = False,
filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""Alias for get_all - get all records with pagination and optional filtering."""
return await self.get_all(db, skip=skip, limit=limit, order_by=order_by, ascending=ascending, filters=filters)
async def get_all(
self, db: Client,
skip: int = 0,
limit: int = 100,
order_by: str = "created_at",
ascending: bool = False,
filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""Get all records with pagination and optional filtering."""
query = db.table(self.table_name).select("*")
# Apply filters if provided
if filters:
for key, value in filters.items():
if value is not None:
if isinstance(value, UUID):
query = query.eq(key, str(value))
elif isinstance(value, list):
query = query.in_(key, [str(v) if isinstance(v, UUID) else v for v in value])
else:
query = query.eq(key, value)
query = query.order(order_by, desc=not ascending)
query = query.range(skip, skip + limit - 1)
result = query.execute()
return result.data or []
async def count(self, db: Client, filters: Optional[Dict[str, Any]] = None) -> int:
"""Count records with optional filters."""
query = db.table(self.table_name).select("id", count="exact")
if filters:
for key, value in filters.items():
query = query.eq(key, str(value) if isinstance(value, UUID) else value)
result = query.execute()
return result.count if hasattr(result, 'count') else len(result.data)
async def create(self, db: Client, data: Dict[str, Any]) -> Dict[str, Any]:
"""Create new record."""
# Convert UUIDs to strings
processed_data = {}
for key, value in data.items():
if isinstance(value, UUID):
processed_data[key] = str(value)
elif isinstance(value, list):
processed_data[key] = [
str(v) if isinstance(v, UUID) else v for v in value
]
else:
processed_data[key] = value
result = db.table(self.table_name).insert(processed_data).execute()
return result.data[0] if result.data else processed_data
async def update(
self, db: Client, id: UUID, data: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""Update existing record."""
# Convert UUIDs to strings and remove None values
processed_data = {}
for key, value in data.items():
if value is not None:
if isinstance(value, UUID):
processed_data[key] = str(value)
elif isinstance(value, list):
processed_data[key] = [
str(v) if isinstance(v, UUID) else v for v in value
]
else:
processed_data[key] = value
if not processed_data:
return await self.get_by_id(db, id)
result = db.table(self.table_name) \
.update(processed_data) \
.eq("id", str(id)).execute()
return result.data[0] if result.data else None
async def delete(self, db: Client, id: UUID) -> bool:
"""Delete record by ID."""
result = db.table(self.table_name).delete().eq("id", str(id)).execute()
return bool(result.data)