diff --git a/src/server/auth.test.ts b/src/server/auth.test.ts new file mode 100644 index 0000000..b9d996c --- /dev/null +++ b/src/server/auth.test.ts @@ -0,0 +1,114 @@ +import { describe, test, expect } from "bun:test"; +import { createHmac } from "node:crypto"; + +// Replicate the base64url encode function +function base64UrlEncode(data: Buffer | string): string { + const buf = typeof data === "string" ? Buffer.from(data) : data; + return buf.toString("base64") + .replace(/\+/g, "-") + .replace(/\//g, "_") + .replace(/=/g, ""); +} + +// Helper: create a valid HS256 JWT +function makeJwt(payload: Record, secret: string): string { + const header = base64UrlEncode(JSON.stringify({ alg: "HS256", typ: "JWT" })); + const body = base64UrlEncode(JSON.stringify(payload)); + const sig = createHmac("sha256", secret) + .update(`${header}.${body}`) + .digest(); + const signature = base64UrlEncode(sig); + return `${header}.${body}.${signature}`; +} + +// Inline the validateJwt for testing (import would need bun module resolution) +function base64UrlDecode(str: string): Buffer { + let base64 = str.replace(/-/g, "+").replace(/_/g, "/"); + const pad = base64.length % 4; + if (pad) base64 += "=".repeat(4 - pad); + return Buffer.from(base64, "base64"); +} + +import { createHmac as chmac, timingSafeEqual } from "node:crypto"; + +function validateJwt(token: string, secret: string): { valid: boolean; payload?: any; error?: string } { + try { + const parts = token.split("."); + if (parts.length !== 3) return { valid: false, error: "Invalid token format" }; + + const [headerB64, payloadB64, signatureB64] = parts; + const expectedSignature = chmac("sha256", secret).update(`${headerB64}.${payloadB64}`).digest(); + const receivedSignature = base64UrlDecode(signatureB64); + + if (expectedSignature.length !== receivedSignature.length || !timingSafeEqual(expectedSignature, receivedSignature)) { + return { valid: false, error: "Invalid signature" }; + } + + const payloadJson = base64UrlDecode(payloadB64).toString("utf-8"); + const payload = JSON.parse(payloadJson); + + if (payload.aud !== "tlsync") return { valid: false, error: "Invalid audience" }; + + const now = Math.floor(Date.now() / 1000); + if (!payload.exp || payload.exp < now) return { valid: false, error: "Token expired" }; + + return { valid: true, payload }; + } catch (error) { + return { valid: false, error: "Failed to validate token" }; + } +} + +const SECRET = "test-secret-key-12345"; + +describe("validateJwt", () => { + test("accepts valid token", () => { + const now = Math.floor(Date.now() / 1000); + const token = makeJwt({ sub: "user-1", aud: "tlsync", iat: now, exp: now + 300 }, SECRET); + const result = validateJwt(token, SECRET); + expect(result.valid).toBe(true); + expect(result.payload?.sub).toBe("user-1"); + expect(result.payload?.aud).toBe("tlsync"); + }); + + test("rejects expired token", () => { + const now = Math.floor(Date.now() / 1000); + const token = makeJwt({ sub: "user-1", aud: "tlsync", iat: now - 600, exp: now - 300 }, SECRET); + const result = validateJwt(token, SECRET); + expect(result.valid).toBe(false); + expect(result.error).toBe("Token expired"); + }); + + test("rejects token with wrong audience", () => { + const now = Math.floor(Date.now() / 1000); + const token = makeJwt({ sub: "user-1", aud: "wrong", iat: now, exp: now + 300 }, SECRET); + const result = validateJwt(token, SECRET); + expect(result.valid).toBe(false); + expect(result.error).toBe("Invalid audience"); + }); + + test("rejects token with wrong secret", () => { + const now = Math.floor(Date.now() / 1000); + const token = makeJwt({ sub: "user-1", aud: "tlsync", iat: now, exp: now + 300 }, "wrong-secret"); + const result = validateJwt(token, SECRET); + expect(result.valid).toBe(false); + expect(result.error).toBe("Invalid signature"); + }); + + test("rejects malformed token", () => { + const result = validateJwt("not.a.jwt", SECRET); + expect(result.valid).toBe(false); + }); + + test("rejects empty string", () => { + const result = validateJwt("", SECRET); + expect(result.valid).toBe(false); + }); + + test("rejects token missing exp claim", () => { + const now = Math.floor(Date.now() / 1000); + const token = makeJwt({ sub: "user-1", aud: "tlsync", iat: now }, SECRET); + const result = validateJwt(token, SECRET); + expect(result.valid).toBe(false); + expect(result.error).toBe("Token expired"); + }); +}); diff --git a/src/server/auth.ts b/src/server/auth.ts new file mode 100644 index 0000000..c936a05 --- /dev/null +++ b/src/server/auth.ts @@ -0,0 +1,71 @@ +import { createHmac, timingSafeEqual } from "node:crypto"; +import { logger } from "../logger"; + +function base64UrlDecode(str: string): Buffer { + // Convert base64url to base64 + let base64 = str.replace(/-/g, "+").replace(/_/g, "/"); + // Add padding + const pad = base64.length % 4; + if (pad) { + base64 += "=".repeat(4 - pad); + } + return Buffer.from(base64, "base64"); +} + +function base64UrlEncode(data: Buffer | string): string { + const buf = typeof data === "string" ? Buffer.from(data) : data; + return buf.toString("base64") + .replace(/\+/g, "-") + .replace(/\//g, "_") + .replace(/=/g, ""); +} + +interface JwtPayload { + sub: string; + aud: string; + iat: number; + exp: number; + jti?: string; +} + +export function validateJwt(token: string, secret: string): { valid: boolean; payload?: JwtPayload; error?: string } { + try { + const parts = token.split("."); + if (parts.length !== 3) { + return { valid: false, error: "Invalid token format" }; + } + + const [headerB64, payloadB64, signatureB64] = parts; + + // Verify signature + const expectedSignature = createHmac("sha256", secret) + .update(`${headerB64}.${payloadB64}`) + .digest(); + + const receivedSignature = base64UrlDecode(signatureB64); + + if (expectedSignature.length !== receivedSignature.length || !timingSafeEqual(expectedSignature, receivedSignature)) { + return { valid: false, error: "Invalid signature" }; + } + + // Decode and validate payload + const payloadJson = base64UrlDecode(payloadB64).toString("utf-8"); + const payload: JwtPayload = JSON.parse(payloadJson); + + // Validate audience + if (payload.aud !== "tlsync") { + return { valid: false, error: "Invalid audience" }; + } + + // Validate expiration + const now = Math.floor(Date.now() / 1000); + if (!payload.exp || payload.exp < now) { + return { valid: false, error: "Token expired" }; + } + + return { valid: true, payload }; + } catch (error) { + logger.error("JWT validation error:", error); + return { valid: false, error: "Failed to validate token" }; + } +} diff --git a/src/server/server.bun.ts b/src/server/server.bun.ts index 1f99c7f..62d8dcf 100644 --- a/src/server/server.bun.ts +++ b/src/server/server.bun.ts @@ -8,6 +8,7 @@ import { makeOrLoadRoom } from './rooms' import { unfurl } from './unfurl' import { server_schema_default } from './schema' import { logger } from './../logger' +import { validateJwt } from './auth' logger.info('Environment variables:', { PORT_TLDRAW_SYNC: process.env.PORT_TLDRAW_SYNC, @@ -75,11 +76,19 @@ const router: RouterType = Router() } if (TLSYNC_SECRET) { - const token = (req.query as any).token - if (token !== TLSYNC_SECRET) { - logger.warn(`Unauthorized connection attempt from IP: ${ip}`) + const token = (req.query as any).token as string | undefined + if (!token) { + logger.warn(`Missing token from IP: ${ip}`) return new Response('Unauthorized', { status: 401 }) } + + const result = validateJwt(token, TLSYNC_SECRET) + if (!result.valid) { + logger.warn(`Unauthorized connection attempt from IP: ${ip}, reason: ${result.error}`) + return new Response('Unauthorized', { status: 401 }) + } + + logger.debug(`Verified JWT for subject: ${result.payload?.sub}`) } const { roomId } = req.params