diff --git a/modules/auth/supabase_bearer.py b/modules/auth/supabase_bearer.py index 94034f3..af2fb6c 100644 --- a/modules/auth/supabase_bearer.py +++ b/modules/auth/supabase_bearer.py @@ -20,6 +20,10 @@ class SupabaseBearer(HTTPBearer): token = credentials.credentials # Decode using the string-based verifier to avoid async dependency conflicts payload = verify_supabase_jwt_str(token) + # Keep the bearer token available to downstream dependencies that must + # call Supabase as the user (RLS/storage policies), without requiring + # each router to decode the Authorization header again. + payload["_access_token"] = token return payload except Exception as e: logger.error(f"Token verification failed: {str(e)}") diff --git a/modules/database/supabase/utils/client.py b/modules/database/supabase/utils/client.py index bc8ec97..2dcd6b2 100644 --- a/modules/database/supabase/utils/client.py +++ b/modules/database/supabase/utils/client.py @@ -24,12 +24,17 @@ def _create_base_client(url: str, key: str, access_token: Optional[str] = None, # Otherwise fall back to the API key auth_header = f"Bearer {access_token}" if access_token else f"Bearer {key}" + headers = { + "apikey": key, + "Authorization": auth_header, + } + if options: + headers.update(options.get("headers", {})) + client_options = SyncClientOptions( schema="public", storage=SyncMemoryStorage(), - headers={ - "Authorization": auth_header - } + headers=headers, ) return create_client(url, key, options=client_options) @@ -95,4 +100,9 @@ class SupabaseAnonClient: This enables per-user RLS enforcement via auth.uid() in the JWT. """ - return cls(access_token=access_token) + if not access_token or not access_token.strip(): + raise ValueError("access_token is required for per-user Supabase clients") + token = access_token.strip() + if token.lower().startswith("bearer "): + token = token.split(None, 1)[1] + return cls(access_token=token) diff --git a/routers/database/tools/classes_router.py b/routers/database/tools/classes_router.py index 68c6fd1..a8b82d5 100644 --- a/routers/database/tools/classes_router.py +++ b/routers/database/tools/classes_router.py @@ -230,6 +230,36 @@ async def my_student_classes( ) return {"classes": res} +@router.get("/school/students") +async def list_school_students( + credentials: dict = Depends(SupabaseBearer()), +) -> Dict[str, Any]: + """List all students in the caller's school. Used by admin to add students to a class.""" + user_id = credentials.get("sub", "") + institute_id = _require_institute(user_id) + sb = _sb() + members = ( + sb.supabase.table("institute_memberships") + .select("profile_id") + .eq("institute_id", institute_id) + .eq("role", "student") + .execute() + .data or [] + ) + student_ids = [m["profile_id"] for m in members] + if not student_ids: + return {"students": []} + profiles = ( + sb.supabase.table("profiles") + .select("id, full_name, display_name, email, user_type") + .in_("id", student_ids) + .order("full_name") + .execute() + .data or [] + ) + return {"students": profiles} + + @router.get("/{class_id}") async def get_class( @@ -459,35 +489,6 @@ async def remove_teacher( return {"status": "ok"} -@router.get("/school/students") -async def list_school_students( - credentials: dict = Depends(SupabaseBearer()), -) -> Dict[str, Any]: - """List all students in the caller's school. Used by admin to add students to a class.""" - user_id = credentials.get("sub", "") - institute_id = _require_institute(user_id) - sb = _sb() - members = ( - sb.supabase.table("institute_memberships") - .select("profile_id") - .eq("institute_id", institute_id) - .eq("role", "student") - .execute() - .data or [] - ) - student_ids = [m["profile_id"] for m in members] - if not student_ids: - return {"students": []} - profiles = ( - sb.supabase.table("profiles") - .select("id, full_name, display_name, email, user_type") - .in_("id", student_ids) - .order("full_name") - .execute() - .data or [] - ) - return {"students": profiles} - @router.post("/{class_id}/students") async def add_student( diff --git a/routers/database/tools/graph_tree_router.py b/routers/database/tools/graph_tree_router.py index 19a8b5f..62e1cbc 100644 --- a/routers/database/tools/graph_tree_router.py +++ b/routers/database/tools/graph_tree_router.py @@ -60,6 +60,33 @@ def _resolve_institute( return None, db, teacher_uuid +def _allowed_neo4j_dbs(user_id: str, user_email: str) -> set[str]: + """Return Neo4j databases this user may request via lazy graph APIs.""" + allowed = {f"cc.users.teacher.{user_id.replace('-', '')}"} if user_id else set() + if user_id or user_email: + _, institute_db, _ = _resolve_institute(user_id, user_email) + if institute_db: + allowed.add(institute_db) + allowed.add(f"{institute_db}.curriculum") + return allowed + + +def _require_allowed_neo4j_db(neo4j_db_name: str, node_type: str, section_id: str, user_id: str, user_email: str) -> None: + """Reject arbitrary DB traversal from /graph/node/children query params.""" + if not neo4j_db_name: + raise HTTPException(status_code=400, detail="neo4j_db_name is required") + + if neo4j_db_name == "classroomcopilot": + if node_type.startswith("Calendar") or section_id == "calendar": + return + raise HTTPException(status_code=403, detail="Requested graph database is not allowed for this node") + + if neo4j_db_name in _allowed_neo4j_dbs(user_id, user_email): + return + + raise HTTPException(status_code=403, detail="Requested graph database is outside the authenticated user's scope") + + def _find_teacher_institute(user_email: str) -> Tuple[Optional[str], Optional[str]]: """Return (institute_db_name, teacher_uuid) by matching worker_email in all institute DBs.""" if not user_email: @@ -821,7 +848,11 @@ async def get_node_children( section_id: str = "", credentials: dict = Depends(SupabaseBearer()), ) -> Dict[str, Any]: + user_id = credentials.get("sub", "") user_email = credentials.get("email", "") + if not user_id: + raise HTTPException(status_code=403, detail="Could not extract user_id from token") + _require_allowed_neo4j_db(neo4j_db_name, node_type, section_id, user_id, user_email) children = _get_children_for_node(neo4j_node_id, neo4j_db_name, node_type, section_id, user_email) return {"status": "success", "children": children} diff --git a/routers/database/tools/tldraw_supabase_storage.py b/routers/database/tools/tldraw_supabase_storage.py index c9f08c8..bf857bf 100644 --- a/routers/database/tools/tldraw_supabase_storage.py +++ b/routers/database/tools/tldraw_supabase_storage.py @@ -11,15 +11,117 @@ load_dotenv(find_dotenv()) import os import json import logging -from fastapi import APIRouter, HTTPException, Query -from typing import Dict, Any +from fastapi import APIRouter, Depends, HTTPException, Query +from typing import Dict, Any, Tuple -from modules.database.supabase.utils.storage import StorageAdmin +from modules.auth.supabase_bearer import SupabaseBearer +from modules.database.supabase.utils.client import SupabaseServiceRoleClient +from modules.database.supabase.utils.storage import StorageError, StorageUser from modules.logger_tool import initialise_logger router = APIRouter() logger = initialise_logger(__name__, os.getenv("LOG_LEVEL"), os.getenv("LOG_PATH"), 'default', True) +ALLOWED_SNAPSHOT_BUCKETS = {"cc.public.snapshots"} +PERSONAL_NODE_TYPES = {"User", "Teacher", "Developer", "SuperAdmin", "UserTeacherTimetable"} +GLOBAL_READONLY_NODE_TYPES = {"CalendarYear", "CalendarMonth", "CalendarWeek", "CalendarDay", "CalendarTimeChunk"} + + +def _sb() -> SupabaseServiceRoleClient: + return SupabaseServiceRoleClient() + + +def _parse_snapshot_path(path: str) -> Tuple[str, str, str, str]: + """Parse and validate bucket/node_type/node_id into a storage object path.""" + if not path: + raise HTTPException(status_code=400, detail="Path not provided") + path_parts = [part for part in path.split('/') if part] + if len(path_parts) != 3: + raise HTTPException(status_code=400, detail="Invalid path format. Expected: bucket/nodetype/node_id") + + bucket, node_type, node_id = path_parts + if bucket not in ALLOWED_SNAPSHOT_BUCKETS: + raise HTTPException(status_code=403, detail="Snapshot bucket is not allowed") + if any(part in {".", ".."} or ".." in part for part in path_parts): + raise HTTPException(status_code=400, detail="Invalid path component") + if not node_type.replace("_", "").replace("-", "").isalnum(): + raise HTTPException(status_code=400, detail="Invalid node type") + if not node_id.replace("_", "").replace("-", "").isalnum(): + raise HTTPException(status_code=400, detail="Invalid node id") + + return bucket, node_type, node_id, f"{node_type}/{node_id}/tldraw_file.json" + + +def _user_scope(user_id: str) -> Dict[str, Any]: + """Resolve Supabase/Neo4j scope for the authenticated user.""" + scope: Dict[str, Any] = { + "user_id": user_id, + "teacher_db": f"cc.users.teacher.{user_id.replace('-', '')}" if user_id else "", + "institute_id": "", + "institute_db": "", + "curriculum_db": "", + } + if not user_id: + return scope + try: + sb = _sb() + prof = sb.supabase.table("profiles").select("school_id").eq("id", user_id).single().execute() + school_id = str((prof.data or {}).get("school_id") or "") + scope["institute_id"] = school_id + if school_id: + inst = sb.supabase.table("institutes").select("neo4j_uuid_string").eq("id", school_id).single().execute() + neo4j_uuid = (inst.data or {}).get("neo4j_uuid_string") + if neo4j_uuid: + scope["institute_db"] = f"cc.institutes.{neo4j_uuid}" + scope["curriculum_db"] = f"cc.institutes.{neo4j_uuid}.curriculum" + except Exception as exc: + logger.warning(f"Could not resolve TLDraw storage scope for user {user_id}: {exc}") + return scope + + +def _authorize_snapshot_path(path: str, db_name: str, credentials: Dict[str, Any], write: bool) -> Tuple[str, str, str, str]: + """Authorize TLDraw snapshot access before touching Supabase Storage.""" + user_id = credentials.get("sub", "") + if not user_id: + raise HTTPException(status_code=403, detail="Could not extract user_id from token") + bucket, node_type, node_id, file_path = _parse_snapshot_path(path) + scope = _user_scope(user_id) + allowed_dbs = {db for db in (scope["teacher_db"], scope["institute_db"], scope["curriculum_db"]) if db} + + if node_type in PERSONAL_NODE_TYPES: + if node_id == user_id or (db_name and db_name == scope["teacher_db"]): + return bucket, node_type, node_id, file_path + raise HTTPException(status_code=403, detail="Snapshot path is outside the authenticated user's workspace") + + if node_type in GLOBAL_READONLY_NODE_TYPES and db_name == "classroomcopilot": + if write: + raise HTTPException(status_code=403, detail="Global calendar snapshots are read-only") + return bucket, node_type, node_id, file_path + + # Institute/curriculum snapshots must be accessed through the caller's institute DB. + if db_name and db_name in allowed_dbs: + return bucket, node_type, node_id, file_path + + raise HTTPException(status_code=403, detail="Snapshot path is outside the authenticated user's tenant") + + +def _storage_for_user(credentials: Dict[str, Any]) -> StorageUser: + access_token = credentials.get("_access_token") + if not access_token: + raise HTTPException(status_code=403, detail="User access token is required for storage access") + return StorageUser(user_id=credentials.get("sub"), access_token=access_token) + + +def _is_valid_tldraw_snapshot(snapshot_data: Any) -> bool: + if not isinstance(snapshot_data, dict): + return False + if not ("document" in snapshot_data and "session" in snapshot_data): + return False + document = snapshot_data.get("document") + if isinstance(document, dict) and "schema" in document: + return True + return "schemaVersion" in snapshot_data + def create_default_tldraw_content(): """Create default tldraw content structure.""" return { @@ -101,7 +203,8 @@ def create_default_tldraw_content(): @router.get("/get_tldraw_node_file") async def read_tldraw_node_file_from_supabase( path: str = Query(..., description="Supabase Storage path (e.g., 'cc.public.snapshots/User/user_id')"), - db_name: str = Query(..., description="Database name for context") + db_name: str = Query(..., description="Database name for context"), + credentials: dict = Depends(SupabaseBearer()), ): """ Load TLDraw snapshot from Supabase Storage. @@ -116,26 +219,9 @@ async def read_tldraw_node_file_from_supabase( logger.debug(f"Reading tldraw file from Supabase Storage for path: {path}") logger.debug(f"Database name: {db_name}") - if not path: - raise HTTPException(status_code=400, detail="Path not provided") - try: - # Initialize Supabase Storage - storage = StorageAdmin() - - # Parse the path to extract bucket and file path - # Expected format: "cc.public.snapshots/User/user_id" or "cc.public.snapshots/Teacher/teacher_id" - path_parts = path.split('/') - if len(path_parts) < 3: - raise HTTPException(status_code=400, detail="Invalid path format. Expected: bucket/nodetype/node_id") - - bucket = path_parts[0] # e.g., "cc.public.snapshots" - node_type = path_parts[1] # e.g., "User", "Teacher" - node_id = path_parts[2] # e.g., "cbc309e5-4029-4c34-aab7-0aa33c563cd0" - - # Construct the file path in Supabase Storage - # Format: nodetype/node_id/tldraw_file.json - file_path = f"{node_type}/{node_id}/tldraw_file.json" + bucket, node_type, node_id, file_path = _authorize_snapshot_path(path, db_name, credentials, write=False) + storage = _storage_for_user(credentials) logger.debug(f"Bucket: {bucket}") logger.debug(f"File path: {file_path}") @@ -147,29 +233,17 @@ async def read_tldraw_node_file_from_supabase( # Parse JSON data try: snapshot_data = json.loads(file_data.decode('utf-8')) - logger.info(f"Successfully loaded tldraw snapshot from Supabase Storage: {file_path}") + except (UnicodeDecodeError, json.JSONDecodeError) as e: + logger.warning(f"Malformed TLDraw snapshot {file_path}; returning default content: {e}") + return create_default_tldraw_content() + + logger.info(f"Successfully loaded tldraw snapshot from Supabase Storage: {file_path}") + if _is_valid_tldraw_snapshot(snapshot_data): + return snapshot_data + logger.warning(f"Snapshot data from {file_path} is missing required TLDraw structure. Using default structure.") + return create_default_tldraw_content() - # Ensure the snapshot has the correct structure for TLDraw - if isinstance(snapshot_data, dict) and 'document' in snapshot_data and 'session' in snapshot_data: - # Check if it has the new format (schemaVersion in document.schema) - if 'document' in snapshot_data and isinstance(snapshot_data['document'], dict) and 'schema' in snapshot_data['document']: - return snapshot_data - # Check if it has the old format (schemaVersion at root level) - elif 'schemaVersion' in snapshot_data: - return snapshot_data - else: - # Use default structure if schema is missing - logger.warning(f"Snapshot data from {file_path_in_bucket} is missing schemaVersion. Using default structure.") - return create_default_tldraw_content() - else: - # Use default structure if basic structure is missing - logger.warning(f"Snapshot data from {file_path_in_bucket} is missing top-level TLDraw keys. Using default structure.") - return create_default_tldraw_content() - except json.JSONDecodeError as e: - logger.error(f"Failed to parse JSON from Supabase Storage file: {e}") - raise HTTPException(status_code=500, detail="Invalid JSON in file") - - except Exception as e: + except StorageError as e: # File doesn't exist, create default content logger.info(f"File not found in Supabase Storage, creating default tldraw content: {file_path}") @@ -199,7 +273,8 @@ async def read_tldraw_node_file_from_supabase( async def set_tldraw_node_file_in_supabase( path: str = Query(..., description="Supabase Storage path (e.g., 'cc.public.snapshots/User/user_id')"), db_name: str = Query(..., description="Database name for context"), - data: Dict[str, Any] = None + data: Dict[str, Any] = None, + credentials: dict = Depends(SupabaseBearer()), ): """ Save TLDraw snapshot to Supabase Storage. @@ -215,27 +290,12 @@ async def set_tldraw_node_file_in_supabase( logger.debug(f"Saving tldraw file to Supabase Storage for path: {path}") logger.debug(f"Database name: {db_name}") - if not path: - raise HTTPException(status_code=400, detail="Path not provided") - if not data: raise HTTPException(status_code=400, detail="Data not provided") try: - # Initialize Supabase Storage - storage = StorageAdmin() - - # Parse the path to extract bucket and file path - path_parts = path.split('/') - if len(path_parts) < 3: - raise HTTPException(status_code=400, detail="Invalid path format. Expected: bucket/nodetype/node_id") - - bucket = path_parts[0] # e.g., "cc.public.snapshots" - node_type = path_parts[1] # e.g., "User", "Teacher" - node_id = path_parts[2] # e.g., "cbc309e5-4029-4c34-aab7-0aa33c563cd0" - - # Construct the file path in Supabase Storage - file_path = f"{node_type}/{node_id}/tldraw_file.json" + bucket, node_type, node_id, file_path = _authorize_snapshot_path(path, db_name, credentials, write=True) + storage = _storage_for_user(credentials) logger.debug(f"Bucket: {bucket}") logger.debug(f"File path: {file_path}") diff --git a/tests/test_p0_api_security.py b/tests/test_p0_api_security.py new file mode 100644 index 0000000..5914e92 --- /dev/null +++ b/tests/test_p0_api_security.py @@ -0,0 +1,107 @@ +import pytest +from fastapi import HTTPException + + +def test_classes_school_students_route_registered_before_dynamic_class_id(): + from routers.database.tools.classes_router import router + + paths = [route.path for route in router.routes] + assert paths.index('/school/students') < paths.index('/{class_id}') + + +def test_supabase_anon_for_user_sets_user_authorization_header(monkeypatch): + from modules.database.supabase.utils import client as client_module + + captured = {} + + def fake_create_client(url, key, options=None): + captured['url'] = url + captured['key'] = key + captured['options'] = options + return object() + + monkeypatch.setenv('SUPABASE_URL', 'http://supabase.test') + monkeypatch.setenv('ANON_KEY', 'anon-key') + monkeypatch.setattr(client_module, 'create_client', fake_create_client) + + client_module.SupabaseAnonClient.for_user('Bearer user-jwt') + + assert captured['key'] == 'anon-key' + assert captured['options'].headers['apikey'] == 'anon-key' + assert captured['options'].headers['Authorization'] == 'Bearer user-jwt' + + +@pytest.mark.parametrize('token', ['', ' ']) +def test_supabase_anon_for_user_requires_token(token): + from modules.database.supabase.utils.client import SupabaseAnonClient + + with pytest.raises(ValueError): + SupabaseAnonClient.for_user(token) + + +def test_tldraw_malformed_snapshot_falls_back_to_default(): + from routers.database.tools import tldraw_supabase_storage as storage + + assert not storage._is_valid_tldraw_snapshot({'document': {}, 'session': {}}) + default = storage.create_default_tldraw_content() + assert storage._is_valid_tldraw_snapshot(default) + assert default['document']['schema']['schemaVersion'] == 2 + + +def test_tldraw_rejects_cross_tenant_snapshot_db(monkeypatch): + from routers.database.tools import tldraw_supabase_storage as storage + + monkeypatch.setattr(storage, '_user_scope', lambda user_id: { + 'user_id': user_id, + 'teacher_db': f"cc.users.teacher.{user_id.replace('-', '')}", + 'institute_id': 'school-1', + 'institute_db': 'cc.institutes.allowed', + 'curriculum_db': 'cc.institutes.allowed.curriculum', + }) + + with pytest.raises(HTTPException) as exc: + storage._authorize_snapshot_path( + 'cc.public.snapshots/School/other-school', + 'cc.institutes.other', + {'sub': 'user-1'}, + write=False, + ) + assert exc.value.status_code == 403 + + +def test_graph_node_children_rejects_unscoped_db(monkeypatch): + from routers.database.tools import graph_tree_router + + monkeypatch.setattr(graph_tree_router, '_allowed_neo4j_dbs', lambda user_id, email: {'cc.users.teacher.user1'}) + + with pytest.raises(HTTPException) as exc: + graph_tree_router._require_allowed_neo4j_db( + 'cc.institutes.not-mine', + 'SubjectClass', + 'classes', + 'user-1', + 'teacher@example.test', + ) + assert exc.value.status_code == 403 + + +def test_graph_node_children_allows_global_calendar_only(): + from routers.database.tools import graph_tree_router + + graph_tree_router._require_allowed_neo4j_db( + 'classroomcopilot', + 'CalendarYear', + '', + 'user-1', + 'teacher@example.test', + ) + + with pytest.raises(HTTPException) as exc: + graph_tree_router._require_allowed_neo4j_db( + 'classroomcopilot', + 'School', + 'school', + 'user-1', + 'teacher@example.test', + ) + assert exc.value.status_code == 403