fix: tighten API P0 auth and route handling

This commit is contained in:
kcar 2026-05-28 12:42:42 +01:00
parent 550d405935
commit 54760083b5
6 changed files with 309 additions and 96 deletions

View File

@ -20,6 +20,10 @@ class SupabaseBearer(HTTPBearer):
token = credentials.credentials
# Decode using the string-based verifier to avoid async dependency conflicts
payload = verify_supabase_jwt_str(token)
# Keep the bearer token available to downstream dependencies that must
# call Supabase as the user (RLS/storage policies), without requiring
# each router to decode the Authorization header again.
payload["_access_token"] = token
return payload
except Exception as e:
logger.error(f"Token verification failed: {str(e)}")

View File

@ -24,12 +24,17 @@ def _create_base_client(url: str, key: str, access_token: Optional[str] = None,
# Otherwise fall back to the API key
auth_header = f"Bearer {access_token}" if access_token else f"Bearer {key}"
headers = {
"apikey": key,
"Authorization": auth_header,
}
if options:
headers.update(options.get("headers", {}))
client_options = SyncClientOptions(
schema="public",
storage=SyncMemoryStorage(),
headers={{
"Authorization": auth_header
}}
headers=headers,
)
return create_client(url, key, options=client_options)
@ -95,4 +100,9 @@ class SupabaseAnonClient:
This enables per-user RLS enforcement via auth.uid() in the JWT.
"""
return cls(access_token=access_token)
if not access_token or not access_token.strip():
raise ValueError("access_token is required for per-user Supabase clients")
token = access_token.strip()
if token.lower().startswith("bearer "):
token = token.split(None, 1)[1]
return cls(access_token=token)

View File

@ -230,6 +230,36 @@ async def my_student_classes(
)
return {"classes": res}
@router.get("/school/students")
async def list_school_students(
    credentials: dict = Depends(SupabaseBearer()),
) -> Dict[str, Any]:
    """Return the student profiles that belong to the caller's school.

    The caller's institute is resolved from the token's ``sub`` claim via
    ``_require_institute``. Student membership rows are read first and the
    matching ``profiles`` rows are fetched in a second query. Used by admin
    to add students to a class.

    Returns:
        ``{"students": [...]}`` ordered by ``full_name``. The list is empty
        when the institute has no student memberships.
    """
    caller_id = credentials.get("sub", "")
    institute_id = _require_institute(caller_id)
    service = _sb()

    membership_query = (
        service.supabase.table("institute_memberships")
        .select("profile_id")
        .eq("institute_id", institute_id)
        .eq("role", "student")
    )
    membership_rows = membership_query.execute().data or []
    student_ids = [row["profile_id"] for row in membership_rows]
    if not student_ids:
        # No memberships: skip the profiles query instead of sending an
        # empty id list to the filter.
        return {"students": []}

    profile_query = (
        service.supabase.table("profiles")
        .select("id, full_name, display_name, email, user_type")
        .in_("id", student_ids)
        .order("full_name")
    )
    student_profiles = profile_query.execute().data or []
    return {"students": student_profiles}
@router.get("/{class_id}")
async def get_class(
@ -459,35 +489,6 @@ async def remove_teacher(
return {"status": "ok"}
# NOTE(review): this copy sits after the dynamic GET "/{class_id}" route that
# appears earlier in the module. FastAPI matches routes in registration order,
# so a request for "/school/students" would presumably be captured by
# "/{class_id}" before reaching this handler - confirm and keep only the copy
# registered ahead of the dynamic route.
@router.get("/school/students")
async def list_school_students(
credentials: dict = Depends(SupabaseBearer()),
) -> Dict[str, Any]:
"""List all students in the caller's school. Used by admin to add students to a class."""
# The caller's id is taken from the "sub" claim of the bearer-token payload.
user_id = credentials.get("sub", "")
institute_id = _require_institute(user_id)
sb = _sb()
# Step 1: collect the profile ids of every "student" membership in the institute.
members = (
sb.supabase.table("institute_memberships")
.select("profile_id")
.eq("institute_id", institute_id)
.eq("role", "student")
.execute()
.data or []
)
student_ids = [m["profile_id"] for m in members]
# No student memberships: return early without querying profiles.
if not student_ids:
return {"students": []}
# Step 2: load the matching profile rows, ordered by full_name.
profiles = (
sb.supabase.table("profiles")
.select("id, full_name, display_name, email, user_type")
.in_("id", student_ids)
.order("full_name")
.execute()
.data or []
)
return {"students": profiles}
@router.post("/{class_id}/students")
async def add_student(

View File

@ -60,6 +60,33 @@ def _resolve_institute(
return None, db, teacher_uuid
def _allowed_neo4j_dbs(user_id: str, user_email: str) -> set[str]:
"""Return Neo4j databases this user may request via lazy graph APIs."""
allowed = {f"cc.users.teacher.{user_id.replace('-', '')}"} if user_id else set()
if user_id or user_email:
_, institute_db, _ = _resolve_institute(user_id, user_email)
if institute_db:
allowed.add(institute_db)
allowed.add(f"{institute_db}.curriculum")
return allowed
def _require_allowed_neo4j_db(neo4j_db_name: str, node_type: str, section_id: str, user_id: str, user_email: str) -> None:
"""Reject arbitrary DB traversal from /graph/node/children query params."""
if not neo4j_db_name:
raise HTTPException(status_code=400, detail="neo4j_db_name is required")
if neo4j_db_name == "classroomcopilot":
if node_type.startswith("Calendar") or section_id == "calendar":
return
raise HTTPException(status_code=403, detail="Requested graph database is not allowed for this node")
if neo4j_db_name in _allowed_neo4j_dbs(user_id, user_email):
return
raise HTTPException(status_code=403, detail="Requested graph database is outside the authenticated user's scope")
def _find_teacher_institute(user_email: str) -> Tuple[Optional[str], Optional[str]]:
"""Return (institute_db_name, teacher_uuid) by matching worker_email in all institute DBs."""
if not user_email:
@ -821,7 +848,11 @@ async def get_node_children(
section_id: str = "",
credentials: dict = Depends(SupabaseBearer()),
) -> Dict[str, Any]:
user_id = credentials.get("sub", "")
user_email = credentials.get("email", "")
if not user_id:
raise HTTPException(status_code=403, detail="Could not extract user_id from token")
_require_allowed_neo4j_db(neo4j_db_name, node_type, section_id, user_id, user_email)
children = _get_children_for_node(neo4j_node_id, neo4j_db_name, node_type, section_id, user_email)
return {"status": "success", "children": children}

View File

@ -11,15 +11,117 @@ load_dotenv(find_dotenv())
import os
import json
import logging
from fastapi import APIRouter, HTTPException, Query
from typing import Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Dict, Any, Tuple
from modules.database.supabase.utils.storage import StorageAdmin
from modules.auth.supabase_bearer import SupabaseBearer
from modules.database.supabase.utils.client import SupabaseServiceRoleClient
from modules.database.supabase.utils.storage import StorageError, StorageUser
from modules.logger_tool import initialise_logger
router = APIRouter()
logger = initialise_logger(__name__, os.getenv("LOG_LEVEL"), os.getenv("LOG_PATH"), 'default', True)
# Storage buckets a snapshot path may name; _parse_snapshot_path answers 403 for any other bucket.
ALLOWED_SNAPSHOT_BUCKETS = {"cc.public.snapshots"}
# Node types that _authorize_snapshot_path treats as part of the caller's personal workspace.
PERSONAL_NODE_TYPES = {"User", "Teacher", "Developer", "SuperAdmin", "UserTeacherTimetable"}
# Calendar node types; writes are refused when they are addressed through the "classroomcopilot" database.
GLOBAL_READONLY_NODE_TYPES = {"CalendarYear", "CalendarMonth", "CalendarWeek", "CalendarDay", "CalendarTimeChunk"}
def _sb() -> SupabaseServiceRoleClient:
    """Build a fresh service-role Supabase client for server-side lookups."""
    service_client = SupabaseServiceRoleClient()
    return service_client
def _is_safe_snapshot_component(value: str) -> bool:
    """Return True when *value* contains only ASCII letters, digits, ``_`` or ``-``."""
    stripped = value.replace("_", "").replace("-", "")
    # str.isalnum() alone also accepts non-ASCII letters and numerics
    # (e.g. "é", "½", full-width digits); storage object keys built from
    # these components must stay plain ASCII.
    return stripped.isascii() and stripped.isalnum()


def _parse_snapshot_path(path: str) -> Tuple[str, str, str, str]:
    """Parse and validate ``bucket/node_type/node_id`` into a storage object path.

    Empty segments produced by leading, trailing or doubled slashes are
    ignored before the three components are checked.

    Args:
        path: Client-supplied path, e.g. ``cc.public.snapshots/User/<id>``.

    Returns:
        ``(bucket, node_type, node_id, file_path)`` where ``file_path`` is
        ``<node_type>/<node_id>/tldraw_file.json``.

    Raises:
        HTTPException: 400 for a missing or malformed path, a traversal
            component, or a node type/id containing anything other than
            ASCII letters, digits, ``_`` and ``-``; 403 when the bucket is
            not in ``ALLOWED_SNAPSHOT_BUCKETS``.
    """
    if not path:
        raise HTTPException(status_code=400, detail="Path not provided")

    path_parts = [part for part in path.split('/') if part]
    if len(path_parts) != 3:
        raise HTTPException(status_code=400, detail="Invalid path format. Expected: bucket/nodetype/node_id")

    bucket, node_type, node_id = path_parts
    if bucket not in ALLOWED_SNAPSHOT_BUCKETS:
        raise HTTPException(status_code=403, detail="Snapshot bucket is not allowed")
    # Explicit traversal guard, kept separate so the client gets a distinct error.
    if any(part == "." or ".." in part for part in path_parts):
        raise HTTPException(status_code=400, detail="Invalid path component")
    if not _is_safe_snapshot_component(node_type):
        raise HTTPException(status_code=400, detail="Invalid node type")
    if not _is_safe_snapshot_component(node_id):
        raise HTTPException(status_code=400, detail="Invalid node id")

    return bucket, node_type, node_id, f"{node_type}/{node_id}/tldraw_file.json"
def _user_scope(user_id: str) -> Dict[str, Any]:
"""Resolve Supabase/Neo4j scope for the authenticated user."""
scope: Dict[str, Any] = {
"user_id": user_id,
"teacher_db": f"cc.users.teacher.{user_id.replace('-', '')}" if user_id else "",
"institute_id": "",
"institute_db": "",
"curriculum_db": "",
}
if not user_id:
return scope
try:
sb = _sb()
prof = sb.supabase.table("profiles").select("school_id").eq("id", user_id).single().execute()
school_id = str((prof.data or {}).get("school_id") or "")
scope["institute_id"] = school_id
if school_id:
inst = sb.supabase.table("institutes").select("neo4j_uuid_string").eq("id", school_id).single().execute()
neo4j_uuid = (inst.data or {}).get("neo4j_uuid_string")
if neo4j_uuid:
scope["institute_db"] = f"cc.institutes.{neo4j_uuid}"
scope["curriculum_db"] = f"cc.institutes.{neo4j_uuid}.curriculum"
except Exception as exc:
logger.warning(f"Could not resolve TLDraw storage scope for user {user_id}: {exc}")
return scope
def _authorize_snapshot_path(path: str, db_name: str, credentials: Dict[str, Any], write: bool) -> Tuple[str, str, str, str]:
    """Authorize TLDraw snapshot access before touching Supabase Storage.

    Args:
        path: Client-supplied ``bucket/node_type/node_id`` path.
        db_name: Client-supplied database name used as access context.
        credentials: Decoded token payload; ``sub`` is the user id.
        write: True when the caller intends to modify the snapshot.

    Returns:
        ``(bucket, node_type, node_id, file_path)`` as produced by
        ``_parse_snapshot_path``.

    Raises:
        HTTPException: 403 when the token has no ``sub``, when the path lies
            outside the user's workspace or tenant, or when a write targets
            a global calendar snapshot. Errors from ``_parse_snapshot_path``
            propagate unchanged.
    """
    # NOTE(review): db_name comes from the request and is only compared with
    # the caller's own database names; nothing here ties node_id to that
    # database. Confirm that storage policies enforce per-object ownership.
    user_id = credentials.get("sub", "")
    if not user_id:
        raise HTTPException(status_code=403, detail="Could not extract user_id from token")

    bucket, node_type, node_id, file_path = _parse_snapshot_path(path)
    granted = (bucket, node_type, node_id, file_path)

    scope = _user_scope(user_id)
    candidate_dbs = (scope["teacher_db"], scope["institute_db"], scope["curriculum_db"])
    allowed_dbs = {name for name in candidate_dbs if name}

    if node_type in PERSONAL_NODE_TYPES:
        # Personal nodes: either the node is the user, or it is addressed
        # through the user's own teacher database.
        if node_id != user_id and not (db_name and db_name == scope["teacher_db"]):
            raise HTTPException(status_code=403, detail="Snapshot path is outside the authenticated user's workspace")
        return granted

    if node_type in GLOBAL_READONLY_NODE_TYPES and db_name == "classroomcopilot":
        if write:
            raise HTTPException(status_code=403, detail="Global calendar snapshots are read-only")
        return granted

    # Institute/curriculum snapshots must be accessed through the caller's institute DB.
    if not db_name or db_name not in allowed_dbs:
        raise HTTPException(status_code=403, detail="Snapshot path is outside the authenticated user's tenant")
    return granted
def _storage_for_user(credentials: Dict[str, Any]) -> StorageUser:
    """Create a ``StorageUser`` bound to the caller's own access token.

    Args:
        credentials: Decoded token payload; must carry ``_access_token``.

    Raises:
        HTTPException: 403 when ``_access_token`` is missing or empty.
    """
    bearer_token = credentials.get("_access_token")
    if bearer_token:
        return StorageUser(user_id=credentials.get("sub"), access_token=bearer_token)
    raise HTTPException(status_code=403, detail="User access token is required for storage access")
def _is_valid_tldraw_snapshot(snapshot_data: Any) -> bool:
if not isinstance(snapshot_data, dict):
return False
if not ("document" in snapshot_data and "session" in snapshot_data):
return False
document = snapshot_data.get("document")
if isinstance(document, dict) and "schema" in document:
return True
return "schemaVersion" in snapshot_data
def create_default_tldraw_content():
"""Create default tldraw content structure."""
return {
@ -101,7 +203,8 @@ def create_default_tldraw_content():
@router.get("/get_tldraw_node_file")
async def read_tldraw_node_file_from_supabase(
path: str = Query(..., description="Supabase Storage path (e.g., 'cc.public.snapshots/User/user_id')"),
db_name: str = Query(..., description="Database name for context")
db_name: str = Query(..., description="Database name for context"),
credentials: dict = Depends(SupabaseBearer()),
):
"""
Load TLDraw snapshot from Supabase Storage.
@ -116,26 +219,9 @@ async def read_tldraw_node_file_from_supabase(
logger.debug(f"Reading tldraw file from Supabase Storage for path: {path}")
logger.debug(f"Database name: {db_name}")
if not path:
raise HTTPException(status_code=400, detail="Path not provided")
try:
# Initialize Supabase Storage
storage = StorageAdmin()
# Parse the path to extract bucket and file path
# Expected format: "cc.public.snapshots/User/user_id" or "cc.public.snapshots/Teacher/teacher_id"
path_parts = path.split('/')
if len(path_parts) < 3:
raise HTTPException(status_code=400, detail="Invalid path format. Expected: bucket/nodetype/node_id")
bucket = path_parts[0] # e.g., "cc.public.snapshots"
node_type = path_parts[1] # e.g., "User", "Teacher"
node_id = path_parts[2] # e.g., "cbc309e5-4029-4c34-aab7-0aa33c563cd0"
# Construct the file path in Supabase Storage
# Format: nodetype/node_id/tldraw_file.json
file_path = f"{node_type}/{node_id}/tldraw_file.json"
bucket, node_type, node_id, file_path = _authorize_snapshot_path(path, db_name, credentials, write=False)
storage = _storage_for_user(credentials)
logger.debug(f"Bucket: {bucket}")
logger.debug(f"File path: {file_path}")
@ -147,29 +233,17 @@ async def read_tldraw_node_file_from_supabase(
# Parse JSON data
try:
snapshot_data = json.loads(file_data.decode('utf-8'))
logger.info(f"Successfully loaded tldraw snapshot from Supabase Storage: {file_path}")
except (UnicodeDecodeError, json.JSONDecodeError) as e:
logger.warning(f"Malformed TLDraw snapshot {file_path}; returning default content: {e}")
return create_default_tldraw_content()
logger.info(f"Successfully loaded tldraw snapshot from Supabase Storage: {file_path}")
if _is_valid_tldraw_snapshot(snapshot_data):
return snapshot_data
logger.warning(f"Snapshot data from {file_path} is missing required TLDraw structure. Using default structure.")
return create_default_tldraw_content()
# Ensure the snapshot has the correct structure for TLDraw
if isinstance(snapshot_data, dict) and 'document' in snapshot_data and 'session' in snapshot_data:
# Check if it has the new format (schemaVersion in document.schema)
if 'document' in snapshot_data and isinstance(snapshot_data['document'], dict) and 'schema' in snapshot_data['document']:
return snapshot_data
# Check if it has the old format (schemaVersion at root level)
elif 'schemaVersion' in snapshot_data:
return snapshot_data
else:
# Use default structure if schema is missing
logger.warning(f"Snapshot data from {file_path_in_bucket} is missing schemaVersion. Using default structure.")
return create_default_tldraw_content()
else:
# Use default structure if basic structure is missing
logger.warning(f"Snapshot data from {file_path_in_bucket} is missing top-level TLDraw keys. Using default structure.")
return create_default_tldraw_content()
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON from Supabase Storage file: {e}")
raise HTTPException(status_code=500, detail="Invalid JSON in file")
except Exception as e:
except StorageError as e:
# File doesn't exist, create default content
logger.info(f"File not found in Supabase Storage, creating default tldraw content: {file_path}")
@ -199,7 +273,8 @@ async def read_tldraw_node_file_from_supabase(
async def set_tldraw_node_file_in_supabase(
path: str = Query(..., description="Supabase Storage path (e.g., 'cc.public.snapshots/User/user_id')"),
db_name: str = Query(..., description="Database name for context"),
data: Dict[str, Any] = None
data: Dict[str, Any] = None,
credentials: dict = Depends(SupabaseBearer()),
):
"""
Save TLDraw snapshot to Supabase Storage.
@ -215,27 +290,12 @@ async def set_tldraw_node_file_in_supabase(
logger.debug(f"Saving tldraw file to Supabase Storage for path: {path}")
logger.debug(f"Database name: {db_name}")
if not path:
raise HTTPException(status_code=400, detail="Path not provided")
if not data:
raise HTTPException(status_code=400, detail="Data not provided")
try:
# Initialize Supabase Storage
storage = StorageAdmin()
# Parse the path to extract bucket and file path
path_parts = path.split('/')
if len(path_parts) < 3:
raise HTTPException(status_code=400, detail="Invalid path format. Expected: bucket/nodetype/node_id")
bucket = path_parts[0] # e.g., "cc.public.snapshots"
node_type = path_parts[1] # e.g., "User", "Teacher"
node_id = path_parts[2] # e.g., "cbc309e5-4029-4c34-aab7-0aa33c563cd0"
# Construct the file path in Supabase Storage
file_path = f"{node_type}/{node_id}/tldraw_file.json"
bucket, node_type, node_id, file_path = _authorize_snapshot_path(path, db_name, credentials, write=True)
storage = _storage_for_user(credentials)
logger.debug(f"Bucket: {bucket}")
logger.debug(f"File path: {file_path}")

View File

@ -0,0 +1,107 @@
import pytest
from fastapi import HTTPException
def test_classes_school_students_route_registered_before_dynamic_class_id():
    """The static ``/school/students`` route must be registered ahead of ``/{class_id}``."""
    from routers.database.tools.classes_router import router

    registered_paths = [route.path for route in router.routes]
    static_position = registered_paths.index('/school/students')
    dynamic_position = registered_paths.index('/{class_id}')
    assert static_position < dynamic_position
def test_supabase_anon_for_user_sets_user_authorization_header(monkeypatch):
    """``for_user`` must pass the anon key as ``apikey`` and the user JWT as the bearer token."""
    from modules.database.supabase.utils import client as client_module

    captured = {}

    def fake_create_client(url, key, options=None):
        # Record the arguments instead of building a real client.
        captured.update(url=url, key=key, options=options)
        return object()

    monkeypatch.setenv('SUPABASE_URL', 'http://supabase.test')
    monkeypatch.setenv('ANON_KEY', 'anon-key')
    monkeypatch.setattr(client_module, 'create_client', fake_create_client)

    client_module.SupabaseAnonClient.for_user('Bearer user-jwt')

    assert captured['key'] == 'anon-key'
    sent_headers = captured['options'].headers
    assert sent_headers['apikey'] == 'anon-key'
    assert sent_headers['Authorization'] == 'Bearer user-jwt'
@pytest.mark.parametrize('token', ['', ' '])
def test_supabase_anon_for_user_requires_token(token):
    """Empty or whitespace-only tokens must be rejected with ``ValueError``."""
    from modules.database.supabase.utils.client import SupabaseAnonClient as anon_client

    with pytest.raises(ValueError):
        anon_client.for_user(token)
def test_tldraw_malformed_snapshot_falls_back_to_default():
    """A snapshot without schema info is invalid, while the default content is valid."""
    from routers.database.tools import tldraw_supabase_storage as storage

    schemaless_snapshot = {'document': {}, 'session': {}}
    assert not storage._is_valid_tldraw_snapshot(schemaless_snapshot)

    default_content = storage.create_default_tldraw_content()
    assert storage._is_valid_tldraw_snapshot(default_content)
    assert default_content['document']['schema']['schemaVersion'] == 2
def test_tldraw_rejects_cross_tenant_snapshot_db(monkeypatch):
    """A db_name outside the caller's scope must be refused with 403."""
    from routers.database.tools import tldraw_supabase_storage as storage

    def fake_user_scope(user_id):
        # Fixed scope so the test never touches Supabase.
        return {
            'user_id': user_id,
            'teacher_db': f"cc.users.teacher.{user_id.replace('-', '')}",
            'institute_id': 'school-1',
            'institute_db': 'cc.institutes.allowed',
            'curriculum_db': 'cc.institutes.allowed.curriculum',
        }

    monkeypatch.setattr(storage, '_user_scope', fake_user_scope)

    with pytest.raises(HTTPException) as exc:
        storage._authorize_snapshot_path(
            'cc.public.snapshots/School/other-school',
            'cc.institutes.other',
            {'sub': 'user-1'},
            write=False,
        )
    assert exc.value.status_code == 403
def test_graph_node_children_rejects_unscoped_db(monkeypatch):
    """Requesting a database outside the allowed set must yield 403."""
    from routers.database.tools import graph_tree_router

    def fake_allowed_dbs(user_id, email):
        return {'cc.users.teacher.user1'}

    monkeypatch.setattr(graph_tree_router, '_allowed_neo4j_dbs', fake_allowed_dbs)

    request_args = (
        'cc.institutes.not-mine',
        'SubjectClass',
        'classes',
        'user-1',
        'teacher@example.test',
    )
    with pytest.raises(HTTPException) as exc:
        graph_tree_router._require_allowed_neo4j_db(*request_args)
    assert exc.value.status_code == 403
def test_graph_node_children_allows_global_calendar_only():
    """The shared database is reachable for calendar nodes and refused for anything else."""
    from routers.database.tools import graph_tree_router

    require_db = graph_tree_router._require_allowed_neo4j_db
    identity = ('user-1', 'teacher@example.test')

    # Calendar node: must pass without raising.
    require_db('classroomcopilot', 'CalendarYear', '', *identity)

    # Non-calendar node on the shared database: must be rejected.
    with pytest.raises(HTTPException) as exc:
        require_db('classroomcopilot', 'School', 'school', *identity)
    assert exc.value.status_code == 403