fix: Add filters parameter to BaseCRUD.get_multi() method
- Fixed signature mismatch where enrollment_requests router was passing filters parameter to get_multi() but method didn't accept it - get_multi() now accepts optional filters dict and passes it to get_all()
This commit is contained in:
parent
f5eacab946
commit
d68b63cedb
113
routers/database/timetable/crud/base.py
Normal file
113
routers/database/timetable/crud/base.py
Normal file
@ -0,0 +1,113 @@
|
||||
"""Base CRUD class for timetable operations."""
|
||||
from typing import List, Optional, Dict, Any
|
||||
from uuid import UUID
|
||||
from supabase import Client
|
||||
|
||||
class BaseCRUD:
|
||||
def __init__(self, table_name: str):
|
||||
self.table_name = table_name
|
||||
|
||||
async def get_by_id(self, db: Client, id: UUID) -> Optional[Dict[str, Any]]:
|
||||
"""Get record by ID."""
|
||||
result = db.table(self.table_name).select("*").eq("id", str(id)).execute()
|
||||
return result.data[0] if result.data else None
|
||||
|
||||
async def get_by_ids(self, db: Client, ids: List[UUID]) -> List[Dict[str, Any]]:
|
||||
"""Get multiple records by IDs."""
|
||||
str_ids = [str(id) for id in ids]
|
||||
result = db.table(self.table_name).select("*").in_("id", str_ids).execute()
|
||||
return result.data or []
|
||||
|
||||
async def get_multi(
|
||||
self, db: Client,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
order_by: str = "created_at",
|
||||
ascending: bool = False,
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Alias for get_all - get all records with pagination and optional filtering."""
|
||||
return await self.get_all(db, skip=skip, limit=limit, order_by=order_by, ascending=ascending, filters=filters)
|
||||
|
||||
async def get_all(
|
||||
self, db: Client,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
order_by: str = "created_at",
|
||||
ascending: bool = False,
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get all records with pagination and optional filtering."""
|
||||
query = db.table(self.table_name).select("*")
|
||||
|
||||
# Apply filters if provided
|
||||
if filters:
|
||||
for key, value in filters.items():
|
||||
if value is not None:
|
||||
if isinstance(value, UUID):
|
||||
query = query.eq(key, str(value))
|
||||
elif isinstance(value, list):
|
||||
query = query.in_(key, [str(v) if isinstance(v, UUID) else v for v in value])
|
||||
else:
|
||||
query = query.eq(key, value)
|
||||
|
||||
query = query.order(order_by, desc=not ascending)
|
||||
query = query.range(skip, skip + limit - 1)
|
||||
result = query.execute()
|
||||
return result.data or []
|
||||
|
||||
async def count(self, db: Client, filters: Optional[Dict[str, Any]] = None) -> int:
|
||||
"""Count records with optional filters."""
|
||||
query = db.table(self.table_name).select("id", count="exact")
|
||||
if filters:
|
||||
for key, value in filters.items():
|
||||
query = query.eq(key, str(value) if isinstance(value, UUID) else value)
|
||||
result = query.execute()
|
||||
return result.count if hasattr(result, 'count') else len(result.data)
|
||||
|
||||
async def create(self, db: Client, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Create new record."""
|
||||
# Convert UUIDs to strings
|
||||
processed_data = {}
|
||||
for key, value in data.items():
|
||||
if isinstance(value, UUID):
|
||||
processed_data[key] = str(value)
|
||||
elif isinstance(value, list):
|
||||
processed_data[key] = [
|
||||
str(v) if isinstance(v, UUID) else v for v in value
|
||||
]
|
||||
else:
|
||||
processed_data[key] = value
|
||||
|
||||
result = db.table(self.table_name).insert(processed_data).execute()
|
||||
return result.data[0] if result.data else processed_data
|
||||
|
||||
async def update(
|
||||
self, db: Client, id: UUID, data: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Update existing record."""
|
||||
# Convert UUIDs to strings and remove None values
|
||||
processed_data = {}
|
||||
for key, value in data.items():
|
||||
if value is not None:
|
||||
if isinstance(value, UUID):
|
||||
processed_data[key] = str(value)
|
||||
elif isinstance(value, list):
|
||||
processed_data[key] = [
|
||||
str(v) if isinstance(v, UUID) else v for v in value
|
||||
]
|
||||
else:
|
||||
processed_data[key] = value
|
||||
|
||||
if not processed_data:
|
||||
return await self.get_by_id(db, id)
|
||||
|
||||
result = db.table(self.table_name) \
|
||||
.update(processed_data) \
|
||||
.eq("id", str(id)).execute()
|
||||
return result.data[0] if result.data else None
|
||||
|
||||
async def delete(self, db: Client, id: UUID) -> bool:
|
||||
"""Delete record by ID."""
|
||||
result = db.table(self.table_name).delete().eq("id", str(id)).execute()
|
||||
return bool(result.data)
|
||||
Loading…
x
Reference in New Issue
Block a user