Phase 3B: Implement pluggable LLM client for summary generation

- Create llm_client.py with 5 provider implementations (Anthropic, OpenAI, Ollama, OpenRouter, Google)
- Add build_prompt() helper to construct system/user prompts from templates
- Wire up POST /transcribe/sessions/{id}/summaries endpoint to call LLM client
- Return generated content + token counts (input_tokens, output_tokens)
- API keys passed per-request, never stored or logged
- Uses prompt templates from prompts.py based on summary_type
This commit is contained in:
Kevin Carter 2026-05-20 22:20:19 +00:00
parent fd8d2a537d
commit 36ae76143f
2 changed files with 409 additions and 35 deletions

View File

@ -1,53 +1,365 @@
"""Pluggable LLM client for transcription summaries.
Phase 1: Stub implementation returns TODO string.
Phase 3: Wire up Anthropic, OpenAI, and Ollama providers.
Phase 3: Full implementation with Anthropic, OpenAI, Ollama, OpenRouter, and Google providers.
"""
import os
from typing import Optional
import json
import logging
from typing import Optional, Dict, Any
import aiohttp
from modules.transcription.prompts import PROMPT_TEMPLATES
logger = logging.getLogger(__name__)
# Default models per provider
DEFAULT_MODELS: Dict[str, str] = {
"anthropic": "claude-sonnet-4-6",
"openai": "gpt-4o",
"ollama": "llama3",
"openrouter": "anthropic/claude-sonnet-4-6",
"google": "gemini-2.0-flash",
}
# Timeout for LLM calls (seconds)
LLM_TIMEOUT = 120
class LLMCallResult:
"""Result from an LLM call, containing content and token usage."""
def __init__(self, content: str, input_tokens: Optional[int] = None,
output_tokens: Optional[int] = None, raw_response: Optional[Dict[str, Any]] = None):
self.content = content
self.input_tokens = input_tokens
self.output_tokens = output_tokens
self.raw_response = raw_response
def to_dict(self) -> Dict[str, Any]:
return {
"content": self.content,
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
}
async def call_llm(
provider: str,
model: str,
api_key: str,
system_prompt: str,
user_message: str,
) -> str:
model: Optional[str] = None,
api_key: str = "",
system_prompt: str = "",
user_message: str = "",
) -> LLMCallResult:
"""Call an LLM to generate a summary.
Phase 1 stub returns a TODO string.
Phase 3 will implement actual provider routing.
Routes to the appropriate provider implementation.
Args:
provider: 'anthropic', 'openai', 'ollama', 'openrouter', 'google'
model: Model name (e.g. 'claude-sonnet-4-6', 'gpt-4o', 'llama3')
api_key: User's API key (from localStorage, passed per-request)
model: Model name (falls back to provider default if None)
api_key: User's API key (from frontend, passed per-request)
system_prompt: System prompt template (already filled with transcript)
user_message: User message content
Returns:
LLM-generated summary text
LLMCallResult with generated summary text and token counts
Raises:
ValueError: If provider is not supported
Exception: If the API call fails
"""
# Phase 1 stub — TODO: implement in Phase 3
return f"[TODO: Implement LLM call for provider={provider}, model={model}]"
provider = provider.lower().strip()
if model is None:
model = DEFAULT_MODELS.get(provider, "")
dispatch = {
"anthropic": call_anthropic,
"openai": call_openai,
"ollama": call_ollama,
"openrouter": call_openrouter,
"google": call_google,
}
if provider not in dispatch:
raise ValueError(
f"Unsupported provider: {provider}. "
f"Supported: {', '.join(dispatch.keys())}"
)
logger.info(f"Calling LLM provider={provider} model={model}")
result = await dispatch[provider](
api_key=api_key, model=model,
system_prompt=system_prompt, user_message=user_message,
)
logger.info(f"LLM call complete: provider={provider} tokens_in={result.input_tokens} tokens_out={result.output_tokens}")
return result
async def call_anthropic(api_key: str, model: str, system_prompt: str, user_message: str) -> str:
"""Call Anthropic Claude API."""
# Phase 3 implementation placeholder
return f"[TODO: Anthropic call — model={model}]"
# ---------------------------------------------------------------------------
# Provider implementations
# ---------------------------------------------------------------------------
async def call_anthropic(
api_key: str, model: str, system_prompt: str, user_message: str,
) -> LLMCallResult:
"""Call Anthropic Claude API (messages v2)."""
url = "https://api.anthropic.com/v1/messages"
headers = {
"x-api-key": api_key,
"anthropic-version": "2023-06-01",
"content-type": "application/json",
}
payload = {
"model": model,
"max_tokens": 4096,
"system": system_prompt,
"messages": [{"role": "user", "content": user_message}],
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=payload,
timeout=aiohttp.ClientTimeout(total=LLM_TIMEOUT)) as resp:
if resp.status != 200:
body = await resp.text()
logger.error(f"Anthropic API error ({resp.status}): {body}")
raise Exception(f"Anthropic API error {resp.status}: {body}")
data = await resp.json()
# Extract content blocks
content_parts = []
for block in data.get("content", []):
if block.get("type") == "text":
content_parts.append(block["text"])
content = "\n".join(content_parts)
# Token counts from response
usage = data.get("usage", {})
input_tokens = usage.get("input_tokens") or usage.get("input_tokens")
output_tokens = usage.get("output_tokens") or usage.get("output_tokens")
# Anthropic v2 uses input_tokens/output_tokens; fall back to input_tokens/input_tokens
if not input_tokens:
input_tokens = usage.get("input_tokens")
if not output_tokens:
output_tokens = usage.get("output_tokens")
return LLMCallResult(
content=content,
input_tokens=input_tokens,
output_tokens=output_tokens,
raw_response=data,
)
async def call_openai(api_key: str, model: str, system_prompt: str, user_message: str) -> str:
"""Call OpenAI API."""
# Phase 3 implementation placeholder
return f"[TODO: OpenAI call — model={model}]"
async def call_openai(
api_key: str, model: str, system_prompt: str, user_message: str,
) -> LLMCallResult:
"""Call OpenAI Chat Completions API."""
url = "https://api.openai.com/v1/chat/completions"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
payload = {
"model": model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_message},
],
"max_tokens": 4096,
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=payload,
timeout=aiohttp.ClientTimeout(total=LLM_TIMEOUT)) as resp:
if resp.status != 200:
body = await resp.text()
logger.error(f"OpenAI API error ({resp.status}): {body}")
raise Exception(f"OpenAI API error {resp.status}: {body}")
data = await resp.json()
choice = data.get("choices", [{}])[0]
content = choice.get("message", {}).get("content", "")
usage = data.get("usage", {})
return LLMCallResult(
content=content,
input_tokens=usage.get("prompt_tokens"),
output_tokens=usage.get("completion_tokens"),
raw_response=data,
)
async def call_ollama(api_key: str, model: str, system_prompt: str, user_message: str) -> str:
"""Call local Ollama instance."""
# Phase 3 implementation placeholder
ollama_url = os.getenv("OLLAMA_URL", "https://ollama.kevlarai.com")
return f"[TODO: Ollama call — url={ollama_url}, model={model}]"
async def call_ollama(
api_key: str, model: str, system_prompt: str, user_message: str,
) -> LLMCallResult:
"""Call local Ollama instance (generate endpoint)."""
ollama_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
url = f"{ollama_url}/api/generate"
# Ollama uses a single prompt with system instructions prepended
full_prompt = f"{system_prompt}\n\n{user_message}"
payload = {
"model": model,
"prompt": full_prompt,
"stream": False,
}
headers = {"Content-Type": "application/json"}
# Ollama may not need an API key; include if set
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=payload,
timeout=aiohttp.ClientTimeout(total=LLM_TIMEOUT)) as resp:
if resp.status != 200:
body = await resp.text()
logger.error(f"Ollama API error ({resp.status}): {body}")
raise Exception(f"Ollama API error {resp.status}: {body}")
data = await resp.json()
content = data.get("response", "")
# Ollama reports total_tokens; split into input/output heuristically
total = data.get("total_tokens", 0)
prompt_tokens = data.get("prompt_eval_count", None)
eval_count = data.get("eval_count", None)
return LLMCallResult(
content=content,
input_tokens=prompt_tokens,
output_tokens=eval_count,
raw_response=data,
)
async def call_openrouter(
api_key: str, model: str, system_prompt: str, user_message: str,
) -> LLMCallResult:
"""Call OpenRouter API (OpenAI-compatible chat completions)."""
url = "https://openrouter.ai/api/v1/chat/completions"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"HTTP-Referer": os.getenv("APP_URL", "https://classroom-copilot.example.com"),
"X-Title": "Classroom Copilot",
}
payload = {
"model": model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_message},
],
"max_tokens": 4096,
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=payload,
timeout=aiohttp.ClientTimeout(total=LLM_TIMEOUT)) as resp:
if resp.status != 200:
body = await resp.text()
logger.error(f"OpenRouter API error ({resp.status}): {body}")
raise Exception(f"OpenRouter API error {resp.status}: {body}")
data = await resp.json()
choice = data.get("choices", [{}])[0]
content = choice.get("message", {}).get("content", "")
usage = data.get("usage", {})
return LLMCallResult(
content=content,
input_tokens=usage.get("prompt_tokens"),
output_tokens=usage.get("completion_tokens"),
raw_response=data,
)
async def call_google(
api_key: str, model: str, system_prompt: str, user_message: str,
) -> LLMCallResult:
"""Call Google Gemini API (generateContent)."""
url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}"
payload = {
"contents": [
{
"role": "user",
"parts": [{"text": user_message}],
}
],
"system_instruction": {
"parts": [{"text": system_prompt}],
},
"generationConfig": {
"maxOutputTokens": 4096,
},
}
headers = {"Content-Type": "application/json"}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=payload,
timeout=aiohttp.ClientTimeout(total=LLM_TIMEOUT)) as resp:
if resp.status != 200:
body = await resp.text()
logger.error(f"Google Gemini API error ({resp.status}): {body}")
raise Exception(f"Google Gemini API error {resp.status}: {body}")
data = await resp.json()
# Extract text from candidates
candidates = data.get("candidates", [])
if candidates:
content_parts = candidates[0].get("content", {}).get("parts", [])
content = "\n".join(p.get("text", "") for p in content_parts)
else:
content = ""
# Token usage from usage_metadata
usage = data.get("usageMetadata", {})
return LLMCallResult(
content=content,
input_tokens=usage.get("promptTokenCount"),
output_tokens=usage.get("candidatesTokenCount"),
raw_response=data,
)
# ---------------------------------------------------------------------------
# Helper: build prompt from template
# ---------------------------------------------------------------------------
def build_prompt(summary_type: str, transcript: str) -> tuple[str, str]:
"""Build system + user prompt from template and transcript.
Args:
summary_type: One of 'full_lesson', 'questions_asked', 'teaching_style',
'key_moments', 'segment'
transcript: The full (or segment) transcript text
Returns:
(system_prompt, user_message) tuple
"""
template = PROMPT_TEMPLATES.get(summary_type, PROMPT_TEMPLATES["full_lesson"])
# The template has {transcript} placeholder — fill it in
filled = template.format(transcript=transcript)
# Split into system and user: everything before "Transcript:" is the system prompt,
# everything from "Transcript:" onward is the user message.
transcript_marker = "\n\nTranscript:\n"
if transcript_marker in filled:
system_prompt, user_message = filled.split(transcript_marker, 1)
user_message = "Transcript:\n" + user_message
else:
system_prompt = "You are an expert educational analyst."
user_message = filled
return system_prompt, user_message

View File

@ -16,6 +16,10 @@ from modules.transcription.models import (
SummaryResponse,
ExportFormat,
)
from modules.transcription.llm_client import call_llm, build_prompt
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
@ -211,28 +215,86 @@ async def generate_summary(
summary_request: SummaryGenerateRequest,
user_id: str = Depends(get_user_id),
):
"""Generate a summary for a session (Phase 1 stub)."""
"""Generate a summary for a session using the specified LLM provider.
Phase 3: Full implementation calls the pluggable LLM client with
prompt templates from prompts.py. API key is passed per-request and
never stored or logged.
"""
supabase = get_supabase_client()
# Verify session exists and user owns it
session_check = supabase.supabase.table("transcription_sessions").select("id").eq("id", session_id).eq("user_id", user_id).execute()
session_check = supabase.supabase.table("transcription_sessions").select("*").eq("id", session_id).eq("user_id", user_id).execute()
if not session_check.data:
raise HTTPException(status_code=404, detail="Session not found")
# Phase 1 stub: TODO implement LLM call in Phase 3
content = "[TODO: Generate summary via LLM — provider={}, model={}]".format(
summary_request.provider, summary_request.model
)
session = session_check.data[0]
# Build transcript from segments (or use segment range)
segments_query = supabase.supabase.table("transcription_segments").select("*").eq("session_id", session_id).order("sequence_index")
segments_result = segments_query.execute()
if not segments_result.data:
raise HTTPException(status_code=400, detail="No segments found for this session")
# Apply segment range filter if specified
segments = segments_result.data
if summary_request.segment_range and len(summary_request.segment_range) == 2:
start_idx, end_idx = summary_request.segment_range[0], summary_request.segment_range[1]
if start_idx is not None and end_idx is not None:
segments = segments[start_idx:end_idx]
elif start_idx is not None:
segments = segments[start_idx:]
elif end_idx is not None:
segments = segments[:end_idx]
# Build full transcript text from segments
transcript_parts = [s["text"] for s in segments if s.get("text")]
transcript = "\n".join(transcript_parts)
if not transcript.strip():
raise HTTPException(status_code=400, detail="Transcript is empty — cannot generate summary")
# Build prompt from template
system_prompt, user_message = build_prompt(summary_request.summary_type, transcript)
# Call the LLM client
try:
llm_result = await call_llm(
provider=summary_request.provider,
model=summary_request.model,
api_key=summary_request.api_key,
system_prompt=system_prompt,
user_message=user_message,
)
except Exception as e:
logger.error(f"LLM call failed: {e}")
raise HTTPException(status_code=502, detail=f"LLM generation failed: {str(e)}")
# Determine segment range for storage
seg_start = None
seg_end = None
if summary_request.segment_range and len(summary_request.segment_range) == 2:
seg_start = summary_request.segment_range[0]
seg_end = summary_request.segment_range[1]
# Build the prompt that was used (for audit trail)
prompt_used = f"{system_prompt}\n\n{user_message}" if summary_request.summary_type != "segment" else user_message
# Save summary to database
summary_data = {
"session_id": session_id,
"user_id": user_id,
"summary_type": summary_request.summary_type,
"content": content,
"content": llm_result.content,
"prompt_used": prompt_used,
"llm_provider": summary_request.provider,
"llm_model": summary_request.model,
"input_tokens": llm_result.input_tokens,
"output_tokens": llm_result.output_tokens,
"segment_range_start": seg_start,
"segment_range_end": seg_end,
}
result = supabase.supabase.table("transcription_summaries").insert(summary_data).execute()