"""Pluggable LLM client for transcription summaries.

Phase 3: Full implementation with Anthropic, OpenAI, Ollama, OpenRouter, and Google providers.
"""

import os
import json
import logging
from typing import Optional, Dict, Any

import aiohttp

from modules.transcription.prompts import PROMPT_TEMPLATES

logger = logging.getLogger(__name__)

# Default models per provider
DEFAULT_MODELS: Dict[str, str] = {
    "anthropic": "claude-sonnet-4-6",
    "openai": "gpt-4o",
    "ollama": "llama3",
    "openrouter": "anthropic/claude-sonnet-4-6",
    "google": "gemini-2.0-flash",
}

# Timeout for LLM calls (seconds)
LLM_TIMEOUT = 120


class LLMCallResult:
    """Result from an LLM call, containing content and token usage.

    Attributes:
        content: Generated text.
        input_tokens: Prompt token count reported by the provider, if any.
        output_tokens: Completion token count reported by the provider, if any.
        raw_response: The decoded JSON body the provider returned, if kept.
    """

    def __init__(self, content: str, input_tokens: Optional[int] = None,
                 output_tokens: Optional[int] = None,
                 raw_response: Optional[Dict[str, Any]] = None):
        self.content = content
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens
        self.raw_response = raw_response

    def to_dict(self) -> Dict[str, Any]:
        """Return content and token counts; ``raw_response`` is left out."""
        return {
            "content": self.content,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
        }


async def call_llm(
    provider: str,
    model: Optional[str] = None,
    api_key: str = "",
    system_prompt: str = "",
    user_message: str = "",
) -> LLMCallResult:
    """Call an LLM to generate a summary.

    Routes to the appropriate provider implementation.

    Args:
        provider: 'anthropic', 'openai', 'ollama', 'openrouter', 'google'
            (case-insensitive, surrounding whitespace ignored)
        model: Model name; falls back to the provider default when None,
            empty or blank
        api_key: User's API key (from frontend, passed per-request)
        system_prompt: System prompt template (already filled with transcript)
        user_message: User message content

    Returns:
        LLMCallResult with generated summary text and token counts

    Raises:
        ValueError: If provider is not supported
        RuntimeError: If the provider answers with a non-200 status
        Exception: Transport-level errors raised by aiohttp (connection
            failures, timeouts) propagate unchanged
    """
    provider = provider.lower().strip()

    dispatch = {
        "anthropic": call_anthropic,
        "openai": call_openai,
        "ollama": call_ollama,
        "openrouter": call_openrouter,
        "google": call_google,
    }

    # Validate first so an unknown provider never reaches the model lookup.
    handler = dispatch.get(provider)
    if handler is None:
        raise ValueError(
            f"Unsupported provider: {provider}. "
            f"Supported: {', '.join(dispatch.keys())}"
        )

    # An empty or blank model is as unusable as None, so default it too
    # instead of forwarding "" to the provider.
    if not model or not model.strip():
        model = DEFAULT_MODELS.get(provider, "")

    # The API key is deliberately absent from every log line.
    logger.info("Calling LLM provider=%s model=%s", provider, model)
    result = await handler(
        api_key=api_key, model=model,
        system_prompt=system_prompt, user_message=user_message,
    )
    logger.info(
        "LLM call complete: provider=%s tokens_in=%s tokens_out=%s",
        provider, result.input_tokens, result.output_tokens,
    )
    return result


# ---------------------------------------------------------------------------
# Provider implementations
# ---------------------------------------------------------------------------

async def call_anthropic(
    api_key: str, model: str, system_prompt: str, user_message: str,
) -> LLMCallResult:
    """Call the Anthropic Messages API.

    Args:
        api_key: Sent in the ``x-api-key`` header.
        model: Model name placed in the request payload.
        system_prompt: Sent as the ``system`` field; omitted when empty.
        user_message: Sent as the single user message.

    Returns:
        LLMCallResult whose content is the response's text blocks joined with
        newlines, with token counts read from the response's ``usage`` object.

    Raises:
        RuntimeError: If the API answers with a non-200 status.
    """
    url = "https://api.anthropic.com/v1/messages"
    headers = {
        "x-api-key": api_key,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }
    payload = {
        "model": model,
        "max_tokens": 4096,
        "messages": [{"role": "user", "content": user_message}],
    }
    # call_llm defaults system_prompt to ""; leave the field out rather than
    # sending an empty system prompt.
    if system_prompt:
        payload["system"] = system_prompt

    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload,
                                timeout=aiohttp.ClientTimeout(total=LLM_TIMEOUT)) as resp:
            if resp.status != 200:
                body = await resp.text()
                logger.error("Anthropic API error (%s): %s", resp.status, body)
                raise RuntimeError(f"Anthropic API error {resp.status}: {body}")

            data = await resp.json()

    # Keep only text blocks; .get() tolerates a text block without "text".
    text_blocks = [
        block.get("text", "")
        for block in data.get("content") or []
        if block.get("type") == "text"
    ]

    usage = data.get("usage") or {}
    return LLMCallResult(
        content="\n".join(text_blocks),
        input_tokens=usage.get("input_tokens"),
        output_tokens=usage.get("output_tokens"),
        raw_response=data,
    )


async def call_openai(
    api_key: str, model: str, system_prompt: str, user_message: str,
) -> LLMCallResult:
    """Call the OpenAI Chat Completions API.

    Args:
        api_key: Sent as a Bearer token.
        model: Model name placed in the request payload.
        system_prompt: Sent as the system message.
        user_message: Sent as the user message.

    Returns:
        LLMCallResult with the first choice's message content ("" when the
        response has no choices or the content is null) and the
        ``prompt_tokens`` / ``completion_tokens`` usage counts.

    Raises:
        RuntimeError: If the API answers with a non-200 status.
    """
    url = "https://api.openai.com/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        "max_tokens": 4096,
    }

    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload,
                                timeout=aiohttp.ClientTimeout(total=LLM_TIMEOUT)) as resp:
            if resp.status != 200:
                body = await resp.text()
                logger.error("OpenAI API error (%s): %s", resp.status, body)
                raise RuntimeError(f"OpenAI API error {resp.status}: {body}")

            data = await resp.json()

    # `or` also covers keys that are present but empty/null: "choices": []
    # would otherwise raise IndexError and "content": null would leak None.
    choices = data.get("choices") or [{}]
    message = choices[0].get("message") or {}
    content = message.get("content") or ""

    usage = data.get("usage") or {}
    return LLMCallResult(
        content=content,
        input_tokens=usage.get("prompt_tokens"),
        output_tokens=usage.get("completion_tokens"),
        raw_response=data,
    )
"https://ollama.kevlarai.com") - return f"[TODO: Ollama call — url={ollama_url}, model={model}]" +async def call_ollama( + api_key: str, model: str, system_prompt: str, user_message: str, +) -> LLMCallResult: + """Call local Ollama instance (generate endpoint).""" + ollama_url = os.getenv("OLLAMA_URL", "http://localhost:11434") + url = f"{ollama_url}/api/generate" + + # Ollama uses a single prompt with system instructions prepended + full_prompt = f"{system_prompt}\n\n{user_message}" + + payload = { + "model": model, + "prompt": full_prompt, + "stream": False, + } + + headers = {"Content-Type": "application/json"} + # Ollama may not need an API key; include if set + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers, json=payload, + timeout=aiohttp.ClientTimeout(total=LLM_TIMEOUT)) as resp: + if resp.status != 200: + body = await resp.text() + logger.error(f"Ollama API error ({resp.status}): {body}") + raise Exception(f"Ollama API error {resp.status}: {body}") + + data = await resp.json() + + content = data.get("response", "") + # Ollama reports total_tokens; split into input/output heuristically + total = data.get("total_tokens", 0) + prompt_tokens = data.get("prompt_eval_count", None) + eval_count = data.get("eval_count", None) + + return LLMCallResult( + content=content, + input_tokens=prompt_tokens, + output_tokens=eval_count, + raw_response=data, + ) + + +async def call_openrouter( + api_key: str, model: str, system_prompt: str, user_message: str, +) -> LLMCallResult: + """Call OpenRouter API (OpenAI-compatible chat completions).""" + url = "https://openrouter.ai/api/v1/chat/completions" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "HTTP-Referer": os.getenv("APP_URL", "https://classroom-copilot.example.com"), + "X-Title": "Classroom Copilot", + } + payload = { + "model": model, + "messages": [ + {"role": 
"system", "content": system_prompt}, + {"role": "user", "content": user_message}, + ], + "max_tokens": 4096, + } + + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers, json=payload, + timeout=aiohttp.ClientTimeout(total=LLM_TIMEOUT)) as resp: + if resp.status != 200: + body = await resp.text() + logger.error(f"OpenRouter API error ({resp.status}): {body}") + raise Exception(f"OpenRouter API error {resp.status}: {body}") + + data = await resp.json() + + choice = data.get("choices", [{}])[0] + content = choice.get("message", {}).get("content", "") + + usage = data.get("usage", {}) + return LLMCallResult( + content=content, + input_tokens=usage.get("prompt_tokens"), + output_tokens=usage.get("completion_tokens"), + raw_response=data, + ) + + +async def call_google( + api_key: str, model: str, system_prompt: str, user_message: str, +) -> LLMCallResult: + """Call Google Gemini API (generateContent).""" + url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}" + + payload = { + "contents": [ + { + "role": "user", + "parts": [{"text": user_message}], + } + ], + "system_instruction": { + "parts": [{"text": system_prompt}], + }, + "generationConfig": { + "maxOutputTokens": 4096, + }, + } + + headers = {"Content-Type": "application/json"} + + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers, json=payload, + timeout=aiohttp.ClientTimeout(total=LLM_TIMEOUT)) as resp: + if resp.status != 200: + body = await resp.text() + logger.error(f"Google Gemini API error ({resp.status}): {body}") + raise Exception(f"Google Gemini API error {resp.status}: {body}") + + data = await resp.json() + + # Extract text from candidates + candidates = data.get("candidates", []) + if candidates: + content_parts = candidates[0].get("content", {}).get("parts", []) + content = "\n".join(p.get("text", "") for p in content_parts) + else: + content = "" + + # Token usage 
from usage_metadata + usage = data.get("usageMetadata", {}) + return LLMCallResult( + content=content, + input_tokens=usage.get("promptTokenCount"), + output_tokens=usage.get("candidatesTokenCount"), + raw_response=data, + ) + + +# --------------------------------------------------------------------------- +# Helper: build prompt from template +# --------------------------------------------------------------------------- + +def build_prompt(summary_type: str, transcript: str) -> tuple[str, str]: + """Build system + user prompt from template and transcript. + + Args: + summary_type: One of 'full_lesson', 'questions_asked', 'teaching_style', + 'key_moments', 'segment' + transcript: The full (or segment) transcript text + + Returns: + (system_prompt, user_message) tuple + """ + template = PROMPT_TEMPLATES.get(summary_type, PROMPT_TEMPLATES["full_lesson"]) + # The template has {transcript} placeholder — fill it in + filled = template.format(transcript=transcript) + + # Split into system and user: everything before "Transcript:" is the system prompt, + # everything from "Transcript:" onward is the user message. + transcript_marker = "\n\nTranscript:\n" + if transcript_marker in filled: + system_prompt, user_message = filled.split(transcript_marker, 1) + user_message = "Transcript:\n" + user_message + else: + system_prompt = "You are an expert educational analyst." 
+ user_message = filled + + return system_prompt, user_message diff --git a/routers/transcribe/sessions.py b/routers/transcribe/sessions.py index 7952395..3d507cb 100644 --- a/routers/transcribe/sessions.py +++ b/routers/transcribe/sessions.py @@ -16,6 +16,10 @@ from modules.transcription.models import ( SummaryResponse, ExportFormat, ) +from modules.transcription.llm_client import call_llm, build_prompt +import logging + +logger = logging.getLogger(__name__) router = APIRouter() @@ -211,28 +215,86 @@ async def generate_summary( summary_request: SummaryGenerateRequest, user_id: str = Depends(get_user_id), ): - """Generate a summary for a session (Phase 1 stub).""" + """Generate a summary for a session using the specified LLM provider. + + Phase 3: Full implementation — calls the pluggable LLM client with + prompt templates from prompts.py. API key is passed per-request and + never stored or logged. + """ supabase = get_supabase_client() # Verify session exists and user owns it - session_check = supabase.supabase.table("transcription_sessions").select("id").eq("id", session_id).eq("user_id", user_id).execute() + session_check = supabase.supabase.table("transcription_sessions").select("*").eq("id", session_id).eq("user_id", user_id).execute() if not session_check.data: raise HTTPException(status_code=404, detail="Session not found") - # Phase 1 stub: TODO implement LLM call in Phase 3 - content = "[TODO: Generate summary via LLM — provider={}, model={}]".format( - summary_request.provider, summary_request.model - ) + session = session_check.data[0] + + # Build transcript from segments (or use segment range) + segments_query = supabase.supabase.table("transcription_segments").select("*").eq("session_id", session_id).order("sequence_index") + segments_result = segments_query.execute() + + if not segments_result.data: + raise HTTPException(status_code=400, detail="No segments found for this session") + + # Apply segment range filter if specified + segments = 
segments_result.data + if summary_request.segment_range and len(summary_request.segment_range) == 2: + start_idx, end_idx = summary_request.segment_range[0], summary_request.segment_range[1] + if start_idx is not None and end_idx is not None: + segments = segments[start_idx:end_idx] + elif start_idx is not None: + segments = segments[start_idx:] + elif end_idx is not None: + segments = segments[:end_idx] + + # Build full transcript text from segments + transcript_parts = [s["text"] for s in segments if s.get("text")] + transcript = "\n".join(transcript_parts) + + if not transcript.strip(): + raise HTTPException(status_code=400, detail="Transcript is empty — cannot generate summary") + + # Build prompt from template + system_prompt, user_message = build_prompt(summary_request.summary_type, transcript) + + # Call the LLM client + try: + llm_result = await call_llm( + provider=summary_request.provider, + model=summary_request.model, + api_key=summary_request.api_key, + system_prompt=system_prompt, + user_message=user_message, + ) + except Exception as e: + logger.error(f"LLM call failed: {e}") + raise HTTPException(status_code=502, detail=f"LLM generation failed: {str(e)}") + + # Determine segment range for storage + seg_start = None + seg_end = None + if summary_request.segment_range and len(summary_request.segment_range) == 2: + seg_start = summary_request.segment_range[0] + seg_end = summary_request.segment_range[1] + + # Build the prompt that was used (for audit trail) + prompt_used = f"{system_prompt}\n\n{user_message}" if summary_request.summary_type != "segment" else user_message # Save summary to database summary_data = { "session_id": session_id, "user_id": user_id, "summary_type": summary_request.summary_type, - "content": content, + "content": llm_result.content, + "prompt_used": prompt_used, "llm_provider": summary_request.provider, "llm_model": summary_request.model, + "input_tokens": llm_result.input_tokens, + "output_tokens": llm_result.output_tokens, + 
"segment_range_start": seg_start, + "segment_range_end": seg_end, } result = supabase.supabase.table("transcription_summaries").insert(summary_data).execute()