commit 05648af633c456d5b2565f7aa4a38105eaa4a3c1 Author: K Car Date: Sat Jun 7 13:06:08 2025 +0100 First commit of files diff --git a/.archive/Dockerfile.macos.dev b/.archive/Dockerfile.macos.dev new file mode 100644 index 0000000..4926c61 --- /dev/null +++ b/.archive/Dockerfile.macos.dev @@ -0,0 +1,51 @@ +FROM python:3.10-bookworm + +ARG DEBIAN_FRONTEND=noninteractive + +# Create log directories with proper permissions +RUN mkdir -p /app/logs && \ + touch /app/logs/whisperlive.log && \ + touch /app/logs/connections.log && \ + chmod 666 /app/logs/whisperlive.log && \ + chmod 666 /app/logs/connections.log + +# install lib required for pyaudio +RUN apt update && apt install -y portaudio19-dev && apt-get clean && rm -rf /var/lib/apt/lists/* + +# update pip to support for whl.metadata -> less downloading +RUN pip install --no-cache-dir -U "pip>=24" + +# create a working directory +WORKDIR /app + +# install the requirements for running the whisper-live server +COPY requirements/server.txt /app/ +RUN pip install -r server.txt && rm server.txt + +COPY whisper_live /app/whisper_live +COPY run_server.py /app + +# Port options +EXPOSE ${PORT_WHISPERLIVE} +EXPOSE ${PORT_WHISPERLIVE_SSL} +ARG PORT_WHISPERLIVE +ENV PORT_WHISPERLIVE=${PORT_WHISPERLIVE} +ARG PORT_WHISPERLIVE_SSL +ENV PORT_WHISPERLIVE_SSL=${PORT_WHISPERLIVE_SSL} + +# SSL options +ARG WHISPERLIVE_SSL +ENV WHISPERLIVE_SSL=${WHISPERLIVE_SSL} + +# Model options +ARG WHISPL_USE_CUSTOM_MODEL +ENV WHISPL_USE_CUSTOM_MODEL=${WHISPL_USE_CUSTOM_MODEL} +ARG FASTERWHISPER_MODEL +ENV FASTERWHISPER_MODEL=${FASTERWHISPER_MODEL} + +CMD ["sh", "-c", "\ + if [ \"$WHISPERLIVE_SSL\" = \"true\" ]; then \ + python3 -u run_server.py --port $PORT_WHISPERLIVE_SSL --backend faster_whisper --faster_whisper_custom_model_path /app/models/$FASTERWHISPER_MODEL --ssl_cert_path /app/ssl; \ + else \ + python3 -u run_server.py --port $PORT_WHISPERLIVE --backend faster_whisper --faster_whisper_custom_model_path /app/models/$FASTERWHISPER_MODEL --no_single_model; \ + fi"] diff --git a/.archive/Dockerfile.macos.prod b/.archive/Dockerfile.macos.prod new file mode 100644 index 0000000..9ab915a --- /dev/null +++ b/.archive/Dockerfile.macos.prod @@ -0,0 +1,45 @@ +FROM python:3.10-bookworm + +ARG DEBIAN_FRONTEND=noninteractive + +# Create log directories with proper permissions +RUN mkdir -p /app/logs && \ + touch /app/logs/whisperlive.log && \ + touch /app/logs/connections.log && \ + chmod 666 /app/logs/whisperlive.log && \ + chmod 666 /app/logs/connections.log + +# install lib required for pyaudio +RUN apt update && apt install -y portaudio19-dev && apt-get clean && rm -rf /var/lib/apt/lists/* + +# update pip to support for whl.metadata -> less downloading +RUN pip install --no-cache-dir -U "pip>=24" + +# create a working directory +WORKDIR /app + +# install the requirements for running the whisper-live server +COPY requirements/server.txt /app/ +RUN pip install -r server.txt && rm server.txt + +COPY whisper_live /app/whisper_live +COPY run_server.py /app + +# Copy application files +EXPOSE ${PORT_WHISPERLIVE} +EXPOSE ${PORT_WHISPERLIVE_SSL} +ARG PORT_WHISPERLIVE +ENV PORT_WHISPERLIVE=${PORT_WHISPERLIVE} +ARG PORT_WHISPERLIVE_SSL +ENV PORT_WHISPERLIVE_SSL=${PORT_WHISPERLIVE_SSL} +ARG FASTERWHISPER_MODEL +ENV FASTERWHISPER_MODEL=${FASTERWHISPER_MODEL} +ARG WHISPERLIVE_SSL +ENV WHISPERLIVE_SSL=${WHISPERLIVE_SSL} + +CMD ["sh", "-c", "\ + if [ \"$WHISPERLIVE_SSL\" = \"true\" ]; then \ + python3 -u run_server.py --port $PORT_WHISPERLIVE_SSL --backend faster_whisper 
--faster_whisper_custom_model_path /app/models/$FASTERWHISPER_MODEL --ssl_cert_path /app/ssl; \ + else \ + python3 -u run_server.py --port $PORT_WHISPERLIVE --backend faster_whisper --faster_whisper_custom_model_path /app/models/$FASTERWHISPER_MODEL; \ + fi"] diff --git a/.archive/Dockerfile.win.prod b/.archive/Dockerfile.win.prod new file mode 100644 index 0000000..79cb85a --- /dev/null +++ b/.archive/Dockerfile.win.prod @@ -0,0 +1,49 @@ +FROM python:3.10-bookworm + +ARG DEBIAN_FRONTEND=noninteractive + +# Create log directories with proper permissions +RUN mkdir -p /app/logs && \ + touch /app/logs/whisperlive.log && \ + touch /app/logs/connections.log && \ + chmod 666 /app/logs/whisperlive.log && \ + chmod 666 /app/logs/connections.log + +# install lib required for pyaudio +RUN apt update && apt install -y portaudio19-dev && apt-get clean && rm -rf /var/lib/apt/lists/* + +# update pip to support for whl.metadata -> less downloading +RUN pip install --no-cache-dir -U "pip>=24" + +# create a working directory +WORKDIR /app + +# install the requirements for running the whisper-live server +COPY requirements/server.txt /app/ +RUN pip install -r server.txt && rm server.txt + +# make the paths of the nvidia libs installed as wheels visible. equivalent to: +# export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'` +ENV LD_LIBRARY_PATH="/usr/local/lib/python3.10/site-packages/nvidia/cublas/lib:/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib" + +COPY whisper_live /app/whisper_live +COPY run_server.py /app + +# Copy application files +EXPOSE ${PORT_WHISPERLIVE} +EXPOSE ${PORT_WHISPERLIVE_SSL} +ARG PORT_WHISPERLIVE +ENV PORT_WHISPERLIVE=${PORT_WHISPERLIVE} +ARG PORT_WHISPERLIVE_SSL +ENV PORT_WHISPERLIVE_SSL=${PORT_WHISPERLIVE_SSL} +ARG FASTERWHISPER_MODEL +ENV FASTERWHISPER_MODEL=${FASTERWHISPER_MODEL} +ARG WHISPERLIVE_SSL +ENV WHISPERLIVE_SSL=${WHISPERLIVE_SSL} + +CMD ["sh", "-c", "\ + if [ \"$WHISPERLIVE_SSL\" = \"true\" ]; then \ + python3 -u run_server.py --port $PORT_WHISPERLIVE_SSL --backend faster_whisper --faster_whisper_custom_model_path /app/models/$FASTERWHISPER_MODEL --ssl_cert_path /app/ssl; \ + else \ + python3 -u run_server.py --port $PORT_WHISPERLIVE --backend faster_whisper --faster_whisper_custom_model_path /app/models/$FASTERWHISPER_MODEL; \ + fi"] diff --git a/.archive/docker-compose.yml b/.archive/docker-compose.yml new file mode 100644 index 0000000..404508a --- /dev/null +++ b/.archive/docker-compose.yml @@ -0,0 +1,1191 @@ +services: + nginx: + profiles: + - core + container_name: nginx-proxy-manager-${NGINX_MODE:-dev} + image: 'jc21/nginx-proxy-manager:latest' + ports: + - '80:80' + - '81:81' + - '443:443' + volumes: + - ./cc-volumes/nginx-proxy-manager/${BUILD_OS}/${NGINX_MODE:-dev}/data:/data + - ./cc-volumes/nginx-proxy-manager/${BUILD_OS}/${NGINX_MODE:-dev}/letsencrypt:/etc/letsencrypt + - ./cc-volumes/nginx-proxy-manager/${BUILD_OS}/${NGINX_MODE:-dev}/snippets:/snippets:ro + environment: + TZ: Europe/London + networks: + - cc-network + + keycloak: + profiles: + - core + - database + container_name: keycloak-${NGINX_MODE:-dev} + build: + context: ./cc-volumes/keycloak/${NGINX_MODE:-dev}/docker + dockerfile: Dockerfile.${BUILD_OS}.${NGINX_MODE:-dev} + args: + KC_BOOTSTRAP_ADMIN_PASSWORD: ${KEYCLOAK_ADMIN_PASSWORD} + KC_BOOTSTRAP_ADMIN_USERNAME: ${KEYCLOAK_ADMIN} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} + KC_DB: 
postgres + KC_DB_URL: jdbc:postgresql://db:5432/postgres + KC_DB_USERNAME: keycloak + KC_DB_PASSWORD: keycloak + KC_DB_SCHEMA: keycloak + KC_HOSTNAME: ${KEYCLOAK_URL} + KC_HOSTNAME_STRICT: "false" + KC_PROXY_HEADERS: xforwarded + KC_PROXY_PROTOCOL_ENABLED: "false" + KC_HTTP_ENABLED: "true" + KC_HTTPS_ENABLED: "false" + KC_HOSTNAME_ADMIN: ${KEYCLOAK_ADMIN_URL} + KC_HOSTNAME_DEBUG: "true" + KC_HEALTH_ENABLED: "true" + KC_HOSTNAME_BACKCHANNEL_DYNAMIC: "false" + KC_METRICS_ENABLED: "true" + KC_LOG_LEVEL: DEBUG + KC_HTTP_RELATIVE_PATH: / + depends_on: + db: + condition: service_healthy + restart: unless-stopped + ports: + - "${KEYCLOAK_MANAGEMENT_PORT}:9000" + - "${KEYCLOAK_PORT}:8080" + - "${KEYCLOAK_SSL_PORT}:8443" + volumes: + - ./cc-volumes/keycloak/${NGINX_MODE:-dev}/conf:/opt/keycloak/conf:ro + - ./cc-volumes/keycloak/${NGINX_MODE:-dev}/providers:/opt/keycloak/providers:ro + - ./cc-volumes/keycloak/${NGINX_MODE:-dev}/themes:/opt/keycloak/themes:ro + - ./cc-volumes/keycloak/${NGINX_MODE:-dev}/master-realm-${NGINX_MODE:-dev}-${BUILD_OS}.json:/opt/keycloak/data/import/master-realm.json:ro + - ./cc-volumes/keycloak/${NGINX_MODE:-dev}/classroomcopilot-realm-${NGINX_MODE:-dev}-${BUILD_OS}.json:/opt/keycloak/data/import/classroomcopilot-realm.json:ro + networks: + - cc-network + + oauth2-proxy-admin: + image: quay.io/oauth2-proxy/oauth2-proxy:v7.6.0 + container_name: oauth2-proxy-admin + restart: unless-stopped + environment: + OAUTH2_PROXY_PROVIDER: oidc + OAUTH2_PROXY_OIDC_ISSUER_URL: https://keycloak.classroomcopilot.test/realms/classroomcopilot + OAUTH2_PROXY_CLIENT_ID: admin-app + OAUTH2_PROXY_CLIENT_SECRET: ${KEYCLOAK_SECRET_ADMIN} + OAUTH2_PROXY_COOKIE_SECRET: ${COOKIE_SECRET_ADMIN} + OAUTH2_PROXY_COOKIE_DOMAIN: .classroomcopilot.test + OAUTH2_PROXY_UPSTREAMS: http://cc-admin:3000 + OAUTH2_PROXY_REDIRECT_URL: https://admin.classroomcopilot.test/oauth2/callback + OAUTH2_PROXY_EMAIL_DOMAINS: "*" + OAUTH2_PROXY_ALLOWED_GROUPS: "admin" + OAUTH2_PROXY_SKIP_PROVIDER_BUTTON: "true" + OAUTH2_PROXY_PASS_ACCESS_TOKEN: "true" + OAUTH2_PROXY_SET_XAUTHREQUEST: "true" + ports: + - "4181:4180" + networks: + - cc-network + + whisperlive-frontend: + profiles: + - core + - frontend + container_name: whisperlive-frontend-${NGINX_MODE:-dev} + build: + context: . 
+ dockerfile: ./whisperlive-frontend/Dockerfile + args: + BUILD_OS: ${BUILD_OS} + NGINX_MODE: ${NGINX_MODE} + environment: + - VITE_APP_URL=${APP_URL} + - VITE_APP_PROTOCOL=${APP_PROTOCOL} + - VITE_APP_NAME=${APP_NAME} + - VITE_DEV=${DEV_MODE} + - VITE_WHISPERLIVE_URL=${WHISPERLIVE_URL} + ports: + - "${PORT_WHISPERLIVE_FRONTEND}:${PORT_WHISPERLIVE_FRONTEND}" + - "${PORT_WHISPERLIVE_FRONTEND_SSL}:${PORT_WHISPERLIVE_FRONTEND_SSL}" + volumes: + - ./whisperlive-frontend:/app + - /app/node_modules + - ./cc-volumes/whisperlive/frontend/ssl/fullchain1.pem:/etc/nginx/ssl/fullchain.pem:ro + - ./cc-volumes/whisperlive/frontend/ssl/privkey1.pem:/etc/nginx/ssl/privkey.pem:ro + networks: + - cc-network + + whisperlive-win: + profiles: + - none + container_name: whisperlive-${NGINX_MODE:-dev} + build: + context: ./WhisperLive/server + dockerfile: Dockerfile.${NGINX_MODE:-dev} + args: + PORT_WHISPERLIVE: ${PORT_WHISPERLIVE} + PORT_WHISPERLIVE_SSL: ${PORT_WHISPERLIVE_SSL} + WHISPERLIVE_SSL: ${WHISPERLIVE_SSL:-false} + WHISPERLIVE_MODEL: ${WHISPERLIVE_MODEL:-base} + env_file: + - .env + environment: + WHISPERLIVE_SSL: ${WHISPERLIVE_SSL:-false} + LOG_PATH: /app/logs + NVIDIA_VISIBLE_DEVICES: all + NVIDIA_DRIVER_CAPABILITIES: compute,utility + volumes: + - ./cc-volumes/whisperlive/models:/app/models + - ./cc-volumes/whisperlive/${NGINX_MODE:-dev}/ssl:/app/ssl + - ./local/logs/whisperlive:/app/logs + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + ports: + - ${PORT_WHISPERLIVE}:${PORT_WHISPERLIVE} + - ${PORT_WHISPERLIVE_SSL}:${PORT_WHISPERLIVE_SSL} + networks: + - cc-network + + whisperlive-macos: + profiles: + - core + container_name: whisperlive-${NGINX_MODE:-dev} + build: + context: ./WhisperLive/server + dockerfile: Dockerfile.${BUILD_OS}.${NGINX_MODE:-dev} + args: + PORT_WHISPERLIVE: ${PORT_WHISPERLIVE} + PORT_WHISPERLIVE_SSL: ${PORT_WHISPERLIVE_SSL} + WHISPERLIVE_SSL: ${WHISPERLIVE_SSL:-false} + WHISPL_USE_CUSTOM_MODEL: ${WHISPL_USE_CUSTOM_MODEL:-false} + FASTERWHISPER_MODEL: ${FASTERWHISPER_MODEL:-base} + env_file: + - .env + environment: + WHISPERLIVE_SSL: ${WHISPERLIVE_SSL:-false} + LOG_PATH: /app/logs + NVIDIA_VISIBLE_DEVICES: all + NVIDIA_DRIVER_CAPABILITIES: compute,utility + volumes: + - ./local/data/whisperlive/models:/app/models + - ./local/data/whisperlive/auto-download:/root/.cache/huggingface/hub + - ./cc-volumes/whisperlive/${NGINX_MODE:-dev}/ssl:/app/ssl + - ./local/logs/whisperlive:/app/logs + deploy: + resources: + limits: + cpus: '4' + memory: 8G + ports: + - ${PORT_WHISPERLIVE}:${PORT_WHISPERLIVE} + - ${PORT_WHISPERLIVE_SSL}:${PORT_WHISPERLIVE_SSL} + networks: + - cc-network + + whisperlive-cpu: + profiles: + - none + container_name: whisperlive-cpu-${NGINX_MODE:-dev} + image: ghcr.io/collabora/whisperlive-cpu:latest + environment: + LOG_PATH: /app/logs + volumes: + - ./cc-volumes/whisperlive/models:/app/models + - ./cc-volumes/whisperlive/${NGINX_MODE:-dev}/ssl:/app/ssl + - ./local/logs/whisperlive-cpu:/app/logs + deploy: + resources: + limits: + cpus: '4' + memory: 8G + ports: + - ${PORT_WHISPERLIVE}:9090 + networks: + - cc-network + + whisperlive-gpu: + profiles: + - none + container_name: whisperlive-gpu-${NGINX_MODE:-dev} + image: ghcr.io/collabora/whisperlive-gpu:latest + environment: + LOG_PATH: /app/logs + NVIDIA_VISIBLE_DEVICES: all + NVIDIA_DRIVER_CAPABILITIES: compute,utility + volumes: + - ./cc-volumes/whisperlive/models:/app/models + - ./cc-volumes/whisperlive/${NGINX_MODE:-dev}/ssl:/app/ssl + - 
./local/logs/whisperlive-gpu:/app/logs + deploy: + resources: + limits: + cpus: '4' + memory: 16G + ports: + - ${PORT_WHISPERLIVE}:9090 + networks: + - cc-network + + solid-proxy-internal: + profiles: + - core + container_name: solid-proxy-internal-${NGINX_MODE:-dev} + image: nginx:alpine + ports: + - 3007:3007 + volumes: + - ./cc-volumes/solid-css/${NGINX_MODE:-dev}/nginx/solid-internal.conf:/etc/nginx/conf.d/default.conf:ro + - ./cc-volumes/cloudflare-origin-certs/solid_cc_cert.pem:/etc/nginx/ssl/cert.pem:ro + - ./cc-volumes/cloudflare-origin-certs/solid_cc_key.pem:/etc/nginx/ssl/key.pem:ro + - ./local/logs/${NGINX_MODE:-dev}/solid-proxy-internal:/var/log/nginx + networks: + - cc-network + + cc-marketing-site: + profiles: + - core + - frontend + container_name: cc-marketing-${NGINX_MODE:-dev} + build: + context: ./cc-marketing + dockerfile: Dockerfile.${NGINX_MODE:-dev} + env_file: + - .env + environment: + - VITE_APP_URL=${APP_URL} + - VITE_APP_SITE_URL=${SITE_URL} + - VITE_APP_APP_URL=${APP_URL} someone check + ports: + - "${PORT_MARKETING_SITE}:${PORT_MARKETING_SITE}" + - "${PORT_MARKETING_SITE_SSL}:${PORT_MARKETING_SITE_SSL}" + networks: + - cc-network + + frontend: + profiles: + - core + - frontend + container_name: frontend-${NGINX_MODE:-dev} + build: + context: ./frontend + dockerfile: Dockerfile.${NGINX_MODE:-dev} + args: + VITE_APP_URL: ${VITE_APP_URL} + environment: + - VITE_FRONTEND_SITE_URL=${SITE_URL} + - VITE_APP_PROTOCOL=${APP_PROTOCOL} + - VITE_APP_NAME=${APP_NAME} + - VITE_SUPER_ADMIN_EMAIL=${APP_AUTHOR_EMAIL} + - VITE_DEV=${DEV_MODE} + - VITE_SUPABASE_URL=${SUPABASE_URL} + - VITE_SUPABASE_ANON_KEY=${ANON_KEY} + - VITE_STRICT_MODE=${STRICT_MODE} + - APP_URL=${APP_URL} + - PORT_FRONTEND=${PORT_FRONTEND} + ports: + - "${PORT_FRONTEND}:${PORT_FRONTEND}" + volumes: + - ./frontend:/app + - /app/node_modules + networks: + - cc-network + + storybook: + profiles: + - core + - frontend + container_name: storybook-${NGINX_MODE:-dev} + build: + context: ./frontend + dockerfile: Dockerfile.storybook.macos.${NGINX_MODE:-dev} + environment: + - NODE_ENV=${NGINX_MODE:-dev} + ports: + - "${PORT_STORYBOOK:-6006}:6006" + volumes: + - ./frontend:/app + - /app/node_modules + networks: + - cc-network + depends_on: + - frontend + + cc-admin: + profiles: + - core + - frontend + container_name: cc-admin-${NGINX_MODE:-dev} + build: + context: ./cc-admin + dockerfile: Dockerfile.${NGINX_MODE:-dev} + args: + PORT: ${PORT_CC_ADMIN} + PORT_DEVTOOLS: ${PORT_CC_ADMIN_DEVTOOLS} + SUPABASE_URL: ${SUPABASE_URL} + ANON_KEY: ${ANON_KEY} + SERVICE_ROLE_KEY: ${SERVICE_ROLE_KEY} + VITE_CC_ADMIN_URL: ${CC_ADMIN_URL} + environment: + APP_URL: ${APP_URL} + PORT_CC_ADMIN: ${PORT_CC_ADMIN} + PORT_CC_ADMIN_DEVTOOLS: ${PORT_CC_ADMIN_DEVTOOLS} + env_file: + - .env + - ./cc-admin/.env.${NGINX_MODE:-dev} + ports: + - "${PORT_CC_ADMIN}:${PORT_CC_ADMIN}" + volumes: + - ./cc-admin:/app + - /app/node_modules + networks: + - cc-network + + backend: + profiles: + - core + - backend + container_name: backend-${NGINX_MODE:-dev} + build: + context: ./backend + dockerfile: Dockerfile.${BUILD_OS}.${NGINX_MODE:-dev} + env_file: + - .env + environment: + ADMIN_EMAIL: ${SUPER_ADMIN_EMAIL} + ADMIN_PASSWORD: ${SUPER_ADMIN_PASSWORD} + ADMIN_NAME: ${SUPER_ADMIN_NAME} + ADMIN_USERNAME: ${SUPER_ADMIN_USERNAME} + ADMIN_DISPLAY_NAME: ${SUPER_ADMIN_DISPLAY_NAME} + SUPABASE_URL: ${SUPABASE_URL} + SERVICE_ROLE_KEY: ${SERVICE_ROLE_KEY} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} + POSTGRES_DB: ${POSTGRES_DB} + UVICORN_TIMEOUT: 300 + volumes: + - 
/var/run/docker.sock:/var/run/docker.sock + - ./backend/:/app/backend + - ./cc-volumes/init:/init:rw + - ./local/logs/container/backend:/logs + - ./local/input:/app/local/input:rw + - ./local/output:/app/local/output:rw + ports: + - "${PORT_BACKEND}:${PORT_BACKEND}" + extra_hosts: + - "supa.classroomcopilot.test:172.23.0.1" + networks: + - cc-network + deploy: + resources: + limits: + cpus: '2' + memory: 4G + + tldraw-sync: + profiles: + - core + - backend + container_name: tldraw-sync-${NGINX_MODE:-dev} + build: + context: ./tldraw-sync + dockerfile: Dockerfile + env_file: + - .env + environment: + - LOG_PATH=/app/logs + ports: + - "5002:5002" + volumes: + - ./tldraw-sync:/app + - ./cc-volumes/tldraw-sync/bunfig.toml:/app/bunfig.toml:ro + - ./local/data/tldraw-sync/.assets:/app/.assets + - ./local/data/tldraw-sync/.rooms:/app/.rooms + - ./local/logs/container/tldraw-sync:/app/logs + networks: + - cc-network + + neo4j: + profiles: + - database + - backend + image: neo4j:enterprise + container_name: neo4j-${NGINX_MODE:-dev} + env_file: + - .env + environment: + - NEO4J_ACCEPT_LICENSE_AGREEMENT=yes + - NEO4J_PLUGINS='["apoc"]' + ports: + - ${PORT_NEO4J_HTTP}:${PORT_NEO4J_HTTP} + - ${PORT_NEO4J_HTTPS}:${PORT_NEO4J_HTTPS} + - ${PORT_NEO4J_BOLT}:${PORT_NEO4J_BOLT} + volumes: + - neo4j-data:/data + - neo4j-logs:/logs + - ./cc-volumes/neo4j/conf/${NGINX_MODE:-dev}/neo4j.conf:/conf/neo4j.conf:ro + - ./cc-volumes/cloudflare-origin-certs/graph_cc_key.pem:/certificates/https/private.key:ro + - ./cc-volumes/cloudflare-origin-certs/graph_cc_cert.pem:/certificates/https/public.crt:ro + - ./cc-volumes/letsencrypt-certs/bolt.classroomcopilot/privkey1.pem:/certificates/bolt/private.key:ro + - ./cc-volumes/letsencrypt-certs/bolt.classroomcopilot/fullchain1.pem:/certificates/bolt/public.crt:ro + - ./cc-volumes/letsencrypt-certs/bolt.classroomcopilot/fullchain1.pem:/certificates/bolt/trusted/public.crt:ro + - ./cc-volumes/neo4j/plugins:/plugins:rw + - ./local/logs/container/neo4j:/logs + healthcheck: + test: ["CMD-SHELL", "neo4j status || exit 1"] + interval: 10s + timeout: 5s + retries: 10 + networks: + - cc-network + + solid-css: + profiles: + - solid + image: solidproject/community-server:latest + container_name: solid-css-${NGINX_MODE:-dev} + restart: unless-stopped + ports: + - "${PORT_SOLID_CSS}:3000" + volumes: + - ./cc-volumes/solid-css/${NGINX_MODE:-dev}/config:/config:ro + - ./cc-volumes/solid-css/${NGINX_MODE:-dev}/data:/data + command: + - --config + - /config/docker.json + networks: + - cc-network + + redis: + profiles: + - database + - backend + image: redis:alpine + container_name: redis-${NGINX_MODE:-dev} + networks: + - cc-network + ports: + - "${PORT_REDIS:-6379}:6379" + command: redis-server --appendonly yes + volumes: + - redis-data:/data + + searxng: + profiles: + - core + - services + - backend + image: searxng/searxng + container_name: searxng-${NGINX_MODE:-dev} + ports: + - "${PORT_SEARXNG}:${PORT_SEARXNG}" + env_file: + - .env + volumes: + - ./cc-volumes/searxng/limiter.toml:/etc/searxng/limiter.toml + - ./cc-volumes/searxng/settings.yml:/etc/searxng/settings.yml + networks: + - cc-network + + mailhog: + profiles: + - core + container_name: mailhog-${NGINX_MODE:-dev} + image: mailhog/mailhog + ports: + - "${PORT_MAILHOG_SMTP}:1025" # SMTP port + - "${PORT_MAILHOG_WEB}:8025" # Web UI port + env_file: + - .env + volumes: + - ./local/logs/mailhog:/var/mailhog + - ./local/data/mailhog:/var/mailhog/mailhog + networks: + - cc-network + + postfix: + profiles: + - prod + image: 
catatnight/postfix + environment: + - maildomain=${APP_URL} + - smtp_user=user:password + ports: + - "25:25" + + minecraft-server: + profiles: + - none + image: itzg/minecraft-server + container_name: cc-minecraft-forge-${NGINX_MODE:-dev} + environment: + EULA: "TRUE" + TYPE: VANILLA + ONLINE_MODE: "false" + PROXY: "minecraft.kevlarai.com" + + # ✅ Set custom server host details + MOTD: "Welcome to KevlarAI's Minecraft Forge Server" + + # ✅ Optional extras (customize as desired) + MAX_PLAYERS: 20 + ALLOW_NETHER: "TRUE" + ENABLE_COMMAND_BLOCK: "TRUE" + DIFFICULTY: "normal" + MODE: "survival" + LEVEL_TYPE: "minecraft:default" + LEVEL: "world" + PVP: "TRUE" + ports: + - 25575:25575 + - 25565:25565 + volumes: + - ./cc-volumes/minecraft/${NGINX_MODE:-dev}/vanilla/data:/data + restart: unless-stopped + networks: + - cc-network + + # Supabase containers + studio: + profiles: + - database + - supabase + container_name: supabase-studio-${NGINX_MODE:-dev} + image: supabase/studio:20250113-83c9420 + restart: unless-stopped + healthcheck: + test: + [ + "CMD", + "node", + "-e", + "fetch('http://studio:3000/api/profile').then((r) => {if (r.status !== 200) throw new Error(r.status)})", + ] + timeout: 10s + interval: 5s + retries: 3 + depends_on: + analytics: + condition: service_healthy + ports: + - ${PORT_SUPABASE_STUDIO}:3000 + env_file: + - .env + environment: + STUDIO_PG_META_URL: http://meta:8080 + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} + DEFAULT_PROJECT_ID: "ClassroomCopilot" + DEFAULT_ORGANIZATION_NAME: ${STUDIO_DEFAULT_ORGANIZATION} + DEFAULT_PROJECT_NAME: ${STUDIO_DEFAULT_PROJECT} + OPENAI_API_KEY: ${OPENAI_API_KEY:-} + SUPABASE_URL: ${SUPABASE_URL} + SUPABASE_PUBLIC_URL: ${SUPABASE_PUBLIC_URL} + SUPABASE_ANON_KEY: ${ANON_KEY} + SUPABASE_SERVICE_KEY: ${SERVICE_ROLE_KEY} + LOGFLARE_API_KEY: ${LOGFLARE_API_KEY} + LOGFLARE_URL: http://analytics:4000 + NEXT_PUBLIC_ENABLE_LOGS: true + NEXT_ANALYTICS_BACKEND_PROVIDER: postgres + networks: + - cc-network + + kong: + profiles: + - database + - supabase + container_name: supabase-kong-${NGINX_MODE:-dev} + image: kong:2.8.1 + restart: unless-stopped + entrypoint: bash -c 'eval "echo \"$$(cat ~/temp.yml)\"" > ~/kong.yml && /docker-entrypoint.sh kong docker-start' + ports: + - ${KONG_HTTP_PORT}:8000/tcp + - ${KONG_HTTPS_PORT}:8443/tcp + depends_on: + analytics: + condition: service_healthy + env_file: + - .env + environment: + KONG_DATABASE: "off" + KONG_DECLARATIVE_CONFIG: /home/kong/kong.yml + KONG_DNS_ORDER: LAST,A,CNAME + KONG_PLUGINS: request-transformer,cors,key-auth,acl,basic-auth + KONG_NGINX_PROXY_PROXY_BUFFER_SIZE: 160k + KONG_NGINX_PROXY_PROXY_BUFFERS: 64 160k + SUPABASE_ANON_KEY: ${ANON_KEY} + SUPABASE_SERVICE_KEY: ${SERVICE_ROLE_KEY} + DASHBOARD_USERNAME: ${DASHBOARD_USERNAME} + DASHBOARD_PASSWORD: ${DASHBOARD_PASSWORD} + KONG_PROXY_ACCESS_LOG: "/dev/stdout" + KONG_ADMIN_ACCESS_LOG: "/dev/stdout" + KONG_PROXY_ERROR_LOG: "/dev/stderr" + KONG_ADMIN_ERROR_LOG: "/dev/stderr" + KONG_CORS_ORIGINS: "*" + KONG_CORS_METHODS: "GET,HEAD,PUT,PATCH,POST,DELETE,OPTIONS" + KONG_CORS_HEADERS: "DNT,X-Auth-Token,Keep-Alive,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization,apikey,x-client-info" + KONG_CORS_EXPOSED_HEADERS: "Content-Length,Content-Range" + KONG_CORS_MAX_AGE: 3600 + volumes: + - ./supabase/api/kong.yml:/home/kong/temp.yml:ro + networks: + - cc-network + + auth: + profiles: + - database + - supabase + container_name: supabase-auth-${NGINX_MODE:-dev} + image: supabase/gotrue:v2.167.0 + depends_on: + db: + 
condition: service_healthy + analytics: + condition: service_healthy + healthcheck: + test: + [ + "CMD", + "wget", + "--no-verbose", + "--tries=1", + "--spider", + "http://localhost:9999/health", + ] + timeout: 5s + interval: 5s + retries: 3 + restart: unless-stopped + env_file: + - .env + environment: + GOTRUE_API_HOST: 0.0.0.0 + GOTRUE_API_PORT: 9999 + API_EXTERNAL_URL: ${API_EXTERNAL_URL} + GOTRUE_DB_DRIVER: postgres + GOTRUE_DB_DATABASE_URL: postgres://supabase_auth_admin:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB} + GOTRUE_SITE_URL: ${SITE_URL} + GOTRUE_URI_ALLOW_LIST: ${ADDITIONAL_REDIRECT_URLS} + GOTRUE_DISABLE_SIGNUP: ${DISABLE_SIGNUP} + GOTRUE_JWT_ADMIN_ROLES: service_role + GOTRUE_JWT_AUD: authenticated + GOTRUE_JWT_DEFAULT_GROUP_NAME: authenticated + GOTRUE_JWT_EXP: ${JWT_EXPIRY} + GOTRUE_JWT_SECRET: ${JWT_SECRET} + GOTRUE_LOG_LEVEL: ${AUTH_LOG_LEVEL} + GOTRUE_SMTP_ADMIN_EMAIL: ${SMTP_ADMIN_EMAIL} + GOTRUE_SMTP_HOST: ${SMTP_HOST} + GOTRUE_SMTP_PORT: ${SMTP_PORT} + GOTRUE_SMTP_USER: ${SMTP_USER} + GOTRUE_SMTP_PASS: ${SMTP_PASS} + GOTRUE_SMTP_SENDER_NAME: ${SMTP_SENDER_NAME} + GOTRUE_MAILER_URLPATHS_INVITE: ${MAILER_URLPATHS_INVITE} + GOTRUE_MAILER_URLPATHS_CONFIRMATION: ${MAILER_URLPATHS_CONFIRMATION} + GOTRUE_MAILER_URLPATHS_RECOVERY: ${MAILER_URLPATHS_RECOVERY} + GOTRUE_MAILER_URLPATHS_EMAIL_CHANGE: ${MAILER_URLPATHS_EMAIL_CHANGE} + GOTRUE_MAILER_AUTOCONFIRM: ${ENABLE_EMAIL_AUTOCONFIRM} + GOTRUE_MAILER_SECURE_EMAIL_CHANGE_ENABLED: ${MAILER_SECURE_EMAIL_CHANGE_ENABLED} + GOTRUE_MAILER_EXTERNAL_HOSTS: "localhost,admin.localhost,kong,supabase.classroomcopilot.ai,classroomcopilot.ai" + GOTRUE_MAILER_EXTERNAL_HOSTS_ALLOW_REGEX: ".*\\.classroomcopilot\\.ai$" + GOTRUE_SMS_AUTOCONFIRM: ${ENABLE_PHONE_AUTOCONFIRM} + GOTRUE_EXTERNAL_EMAIL_ENABLED: ${ENABLE_EMAIL_SIGNUP} + GOTRUE_EXTERNAL_ANONYMOUS_USERS_ENABLED: ${ENABLE_ANONYMOUS_USERS} + GOTRUE_EXTERNAL_PHONE_ENABLED: ${ENABLE_PHONE_SIGNUP} + GOTRUE_EXTERNAL_AZURE_ENABLED: ${AZURE_ENABLED} + GOTRUE_EXTERNAL_AZURE_CLIENT_ID: ${AZURE_CLIENT_ID} + GOTRUE_EXTERNAL_AZURE_SECRET: ${AZURE_SECRET} + GOTRUE_EXTERNAL_AZURE_REDIRECT_URI: ${AZURE_REDIRECT_URI} + networks: + - cc-network + + rest: + profiles: + - database + - supabase + container_name: supabase-rest-${NGINX_MODE:-dev} + image: postgrest/postgrest:v12.2.0 + depends_on: + db: + condition: service_healthy + analytics: + condition: service_healthy + restart: unless-stopped + env_file: + - .env + environment: + PGRST_DB_URI: postgres://authenticator:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB} + PGRST_DB_SCHEMAS: ${PGRST_DB_SCHEMAS} + PGRST_DB_ANON_ROLE: anon + PGRST_JWT_SECRET: ${JWT_SECRET} + PGRST_DB_USE_LEGACY_GUCS: "false" + PGRST_APP_SETTINGS_JWT_SECRET: ${JWT_SECRET} + PGRST_APP_SETTINGS_JWT_EXP: ${JWT_EXPIRY} + command: "postgrest" + networks: + - cc-network + + realtime: + profiles: + - database + - supabase + container_name: realtime-dev-${NGINX_MODE:-dev}.supabase-realtime + image: supabase/realtime:v2.34.7 + depends_on: + db: + condition: service_healthy + analytics: + condition: service_healthy + healthcheck: + test: + [ + "CMD", + "curl", + "-sSfL", + "--head", + "-o", + "/dev/null", + "-H", + "Authorization: Bearer ${ANON_KEY}", + "http://localhost:4000/api/tenants/realtime-dev/health", + ] + timeout: 5s + interval: 5s + retries: 3 + restart: unless-stopped + env_file: + - .env + environment: + PORT: 4000 + DB_HOST: ${POSTGRES_HOST} + DB_PORT: ${POSTGRES_PORT} + DB_USER: supabase_admin + DB_PASSWORD: ${POSTGRES_PASSWORD} 
+ DB_NAME: ${POSTGRES_DB} + DB_AFTER_CONNECT_QUERY: "SET search_path TO _realtime" + DB_ENC_KEY: supabaserealtime + API_JWT_SECRET: ${JWT_SECRET} + SECRET_KEY_BASE: ${SECRET_KEY_BASE} + ERL_AFLAGS: -proto_dist inet_tcp + DNS_NODES: "''" + RLIMIT_NOFILE: "10000" + APP_NAME: realtime + SEED_SELF_HOST: true + RUN_JANITOR: true + networks: + - cc-network + + storage: + profiles: + - database + - supabase + container_name: supabase-storage-${NGINX_MODE:-dev} + image: supabase/storage-api:v1.14.5 + depends_on: + db: + condition: service_healthy + rest: + condition: service_started + imgproxy: + condition: service_started + healthcheck: + test: + [ + "CMD", + "wget", + "--no-verbose", + "--tries=1", + "--spider", + "http://storage:5000/status", + ] + timeout: 5s + interval: 5s + retries: 3 + restart: unless-stopped + env_file: + - .env + environment: + ANON_KEY: ${ANON_KEY} + SERVICE_KEY: ${SERVICE_ROLE_KEY} + POSTGREST_URL: http://rest:3000 + PGRST_JWT_SECRET: ${JWT_SECRET} + DATABASE_URL: postgres://supabase_storage_admin:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB} + FILE_SIZE_LIMIT: 52428800 + STORAGE_BACKEND: file + FILE_STORAGE_BACKEND_PATH: /var/lib/storage + TENANT_ID: stub + REGION: stub + GLOBAL_S3_BUCKET: stub + ENABLE_IMAGE_TRANSFORMATION: "true" + IMGPROXY_URL: http://imgproxy:5001 + networks: + - cc-network + + imgproxy: + profiles: + - database + - supabase + container_name: supabase-imgproxy-${NGINX_MODE:-dev} + image: darthsim/imgproxy:v3.8.0 + healthcheck: + test: ["CMD", "imgproxy", "health"] + timeout: 10s + interval: 5s + retries: 10 + env_file: + - .env + environment: + IMGPROXY_BIND: ":5001" + IMGPROXY_LOCAL_FILESYSTEM_ROOT: / + IMGPROXY_USE_ETAG: "true" + IMGPROXY_ENABLE_WEBP_DETECTION: ${IMGPROXY_ENABLE_WEBP_DETECTION} + volumes: + - ./local/data/supabase/storage-${NGINX_MODE:-dev}:/var/lib/storage:z + networks: + - cc-network + + meta: + profiles: + - database + - supabase + container_name: supabase-meta-${NGINX_MODE:-dev} + image: supabase/postgres-meta:v0.84.2 + depends_on: + db: + condition: service_healthy + analytics: + condition: service_healthy + restart: unless-stopped + env_file: + - .env + environment: + PG_META_PORT: 8080 + PG_META_DB_HOST: ${POSTGRES_HOST} + PG_META_DB_PORT: ${POSTGRES_PORT} + PG_META_DB_NAME: ${POSTGRES_DB} + PG_META_DB_USER: supabase_admin + PG_META_DB_PASSWORD: ${POSTGRES_PASSWORD} + networks: + - cc-network + + functions: + profiles: + - database + - supabase + container_name: supabase-edge-functions-${NGINX_MODE:-dev} + image: supabase/edge-runtime:v1.67.0 + restart: unless-stopped + depends_on: + analytics: + condition: service_healthy + env_file: + - .env + environment: + JWT_SECRET: ${JWT_SECRET} + SUPABASE_URL: ${SUPABASE_URL} + SUPABASE_ANON_KEY: ${ANON_KEY} + SUPABASE_SERVICE_ROLE_KEY: ${SERVICE_ROLE_KEY} + SUPABASE_DB_URL: postgresql://postgres:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB} + VERIFY_JWT: "${FUNCTIONS_VERIFY_JWT}" + volumes: + - ./supabase/functions:/home/deno/functions:Z + command: + - start + - --main-service + - /home/deno/functions/main + networks: + - cc-network + + analytics: + profiles: + - database + - supabase + container_name: supabase-analytics-${NGINX_MODE:-dev} + image: supabase/logflare:1.4.0 + healthcheck: + test: ["CMD", "curl", "http://localhost:4000/health"] + timeout: 10s + interval: 5s + retries: 10 + restart: unless-stopped + depends_on: + db: + condition: service_healthy + env_file: + - .env + environment: + LOGFLARE_NODE_HOST: 127.0.0.1 + 
DB_USERNAME: supabase_admin + DB_DATABASE: _supabase + DB_HOSTNAME: ${POSTGRES_HOST} + DB_PORT: ${POSTGRES_PORT} + DB_PASSWORD: ${POSTGRES_PASSWORD} + DB_SCHEMA: _analytics + LOGFLARE_API_KEY: ${LOGFLARE_API_KEY} + LOGFLARE_SINGLE_TENANT: true + LOGFLARE_SUPABASE_MODE: true + LOGFLARE_MIN_CLUSTER_SIZE: 1 + POSTGRES_BACKEND_URL: postgresql://supabase_admin:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/_supabase + POSTGRES_BACKEND_SCHEMA: _analytics + LOGFLARE_FEATURE_FLAG_OVERRIDE: multibackend=true + ports: + - 4000:4000 + networks: + - cc-network + + db: + profiles: + - database + - supabase + container_name: supabase-db-${NGINX_MODE:-dev} + image: supabase/postgres:15.8.1.020 + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres -h localhost || exit 1"] + interval: 10s + timeout: 5s + retries: 20 + start_period: 30s + depends_on: + vector: + condition: service_healthy + command: + - postgres + - -c + - config_file=/etc/postgresql/postgresql.conf + - -c + - log_min_messages=fatal + restart: unless-stopped + env_file: + - .env + environment: + POSTGRES_HOST: /var/run/postgresql + PGPORT: ${POSTGRES_PORT} + POSTGRES_PORT: ${POSTGRES_PORT} + PGPASSWORD: ${POSTGRES_PASSWORD} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} + PGDATABASE: ${POSTGRES_DB} + POSTGRES_DB: ${POSTGRES_DB} + JWT_SECRET: ${JWT_SECRET} + JWT_EXP: ${JWT_EXPIRY} + volumes: + - ./supabase/db/migrations/supabase/50-_supabase.sql:/docker-entrypoint-initdb.d/migrations/50-_supabase.sql + - ./supabase/db/migrations/supabase/52-realtime.sql:/docker-entrypoint-initdb.d/migrations/52-realtime.sql + - ./supabase/db/migrations/supabase/52-pooler.sql:/docker-entrypoint-initdb.d/migrations/52-pooler.sql + - ./supabase/db/migrations/supabase/52-logs.sql:/docker-entrypoint-initdb.d/migrations/52-logs.sql + - ./supabase/db/init-scripts/51-webhooks.sql:/docker-entrypoint-initdb.d/init-scripts/51-webhooks.sql + - ./supabase/db/init-scripts/52-roles.sql:/docker-entrypoint-initdb.d/init-scripts/52-roles.sql + - ./supabase/db/init-scripts/52-jwt.sql:/docker-entrypoint-initdb.d/init-scripts/52-jwt.sql + - ./supabase/db/migrations/core/60-create-databases.sql:/docker-entrypoint-initdb.d/migrations/60-create-databases.sql + - ./supabase/db/migrations/core/61-core-schema.sql:/docker-entrypoint-initdb.d/migrations/61-core-schema.sql + - ./supabase/db/migrations/core/62-functions-triggers.sql:/docker-entrypoint-initdb.d/migrations/62-functions-triggers.sql + - ./supabase/db/migrations/core/63-storage-policies.sql:/docker-entrypoint-initdb.d/migrations/63-storage-policies.sql + - ./supabase/db/migrations/core/64-initial-admin.sql:/docker-entrypoint-initdb.d/migrations/64-initial-admin.sql + - ./supabase/db/migrations/core/65-keycloak-setup.sql:/docker-entrypoint-initdb.d/migrations/65-keycloak-setup.sql + - supabase-db-data:/var/lib/postgresql/data + - supabase-db-config:/etc/postgresql-custom + networks: + - cc-network + + vector: + profiles: + - database + - supabase + container_name: supabase-vector-${NGINX_MODE:-dev} + image: timberio/vector:0.28.1-alpine + healthcheck: + test: + [ + "CMD", + "wget", + "--no-verbose", + "--tries=1", + "--spider", + "http://vector:9001/health", + ] + timeout: 10s + interval: 10s + retries: 10 + volumes: + - ./supabase/logs/vector.yml:/etc/vector/vector.yml:ro + - /var/run/docker.sock:/var/run/docker.sock:ro + env_file: + - .env + environment: + LOGFLARE_API_KEY: ${LOGFLARE_API_KEY} + command: ["--config", "/etc/vector/vector.yml"] + networks: + - cc-network + + supavisor: + profiles: + - database + - 
supabase + container_name: supabase-pooler-${NGINX_MODE:-dev} + image: supabase/supavisor:1.1.56 + healthcheck: + test: curl -sSfL --head -o /dev/null "http://127.0.0.1:4000/api/health" + interval: 10s + timeout: 10s + retries: 10 + depends_on: + db: + condition: service_healthy + analytics: + condition: service_healthy + command: + - /bin/sh + - -c + - /app/bin/migrate && /app/bin/supavisor eval "$$(cat /etc/pooler/pooler.exs)" && /app/bin/server + restart: unless-stopped + ports: + - ${POSTGRES_PORT}:5432 + - ${POOLER_PROXY_PORT_TRANSACTION}:6543 + env_file: + - .env + environment: + - PORT=4000 + - POSTGRES_PORT=${POSTGRES_PORT} + - POSTGRES_DB=${POSTGRES_DB} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD} + - DATABASE_URL=ecto://supabase_admin:${POSTGRES_PASSWORD}@db:${POSTGRES_PORT}/_supabase + - CLUSTER_POSTGRES=true + - SECRET_KEY_BASE=${SECRET_KEY_BASE} + - VAULT_ENC_KEY=${VAULT_ENC_KEY} + - API_JWT_SECRET=${JWT_SECRET} + - METRICS_JWT_SECRET=${JWT_SECRET} + - REGION=local + - ERL_AFLAGS=-proto_dist inet_tcp + - POOLER_TENANT_ID=${POOLER_TENANT_ID} + - POOLER_DEFAULT_POOL_SIZE=${POOLER_DEFAULT_POOL_SIZE} + - POOLER_MAX_CLIENT_CONN=${POOLER_MAX_CLIENT_CONN} + - POOLER_POOL_MODE=transaction + volumes: + - ./supabase/pooler/pooler.exs:/etc/pooler/pooler.exs:ro + networks: + - cc-network + + ollama: + profiles: + - none + - ai_services + container_name: ollama-${NGINX_MODE:-dev} + build: + context: ./cc-volumes/ollama/docker + dockerfile: Dockerfile.${BUILD_OS}.${NGINX_MODE:-dev} + ports: + - "${PORT_OLLAMA}:11434" + volumes: + - ./local/data/ollama:/root/.ollama + - ./local/logs/ollama:/var/log/ollama + environment: + - OLLAMA_HOST=0.0.0.0 + - OLLAMA_ORIGINS=* + networks: + - cc-network + deploy: + resources: + limits: + cpus: '4' + memory: 8G + + open-webui: + profiles: + - core + - ai_services + container_name: open-webui-${NGINX_MODE:-dev} + image: ghcr.io/open-webui/open-webui:main + ports: + - "${PORT_OPEN_WEBUI:-3333}:8080" + volumes: + - ./local/${BUILD_OS}/${NGINX_MODE:-dev}/data/open-webui:/app/backend/data + - ./local/${BUILD_OS}/${NGINX_MODE:-dev}/logs/open-webui:/app/backend/logs + environment: + - OLLAMA_LOG_LEVEL=DEBUG + - WEBUI_URL=http://open-webui.classroomcopilot.test + - DEFAULT_LOCALE=en + - DEFAULT_USER_ROLE=pending where features + - ENABLE_OAUTH_SIGNUP=true + - OAUTH_CLIENT_ID=open-webui + - OAUTH_CLIENT_SECRET=${KEYCLOAK_SECRET_OPENWEBUI} + - OAUTH_PROVIDER_NAME=Keycloak + - OAUTH_SCOPES=openid,email,profile + # Optional + - OAUTH_MERGE_ACCOUNTS_BY_EMAIL=true + - OAUTH_ROLES_CLAIM=realm_access.roles + - ENABLE_OAUTH_ROLE_MANAGEMENT=true + - OAUTH_ALLOWED_ROLES=user,admin,superadmin + - OAUTH_ADMIN_ROLES=superadmin,admin + - OAUTH_ALLOWED_DOMAINS=kevlarai.test + # Keycloak + - OPENID_PROVIDER_URL=http://keycloak.kevlarai.test/realms/ClassroomCopilot/.well-known/openid-configuration + - OLLAMA_BASE_URL=http://${HOST_OLLAMA}:11434 + - PORT=8080 + - WEBUI_PORT=8080 + - HOST=0.0.0.0 + env_file: + - .env + extra_hosts: + - "keycloak.kevlarai.test=${HOST_IP}" + networks: + - cc-network + deploy: + resources: + limits: + cpus: '2' + memory: 4G + + n8n: + profiles: + - none + - ai_services + container_name: n8n-${NGINX_MODE:-dev} + build: + context: ./cc-volumes/n8n/docker + dockerfile: Dockerfile.${BUILD_OS}.${NGINX_MODE:-dev} + ports: + - "5678:5678" + volumes: + - ./local/data/n8n:/home/node/.n8n + - ./local/logs/n8n:/home/node/.n8n/logs + environment: + - N8N_HOST=0.0.0.0 + - N8N_PORT=5678 + - N8N_PROTOCOL=http + - N8N_USER_MANAGEMENT_DISABLED=true + - 
N8N_BASIC_AUTH_ACTIVE=false + - N8N_SECURE_COOKIE=false + - NODE_ENV=production + networks: + - cc-network + deploy: + resources: + limits: + cpus: '2' + memory: 4G + + +volumes: + supabase-db-config: + driver: local + supabase-db-data: + driver: local + neo4j-data: + driver: local + neo4j-logs: + driver: local + frontend-node-modules: + driver: local + frontend-dist: + driver: local + tldraw-sync-node-modules: + driver: local + redis-data: + driver: local + jupyter-user-data: + driver: local + +networks: + cc-network: + name: cc-network + driver: bridge diff --git a/.env b/.env new file mode 100644 index 0000000..630b92e --- /dev/null +++ b/.env @@ -0,0 +1,11 @@ +# Whisper live settings +APP_WS_PROTOCOL=wss +APP_URL=kevlarai.com + +PORT_WHISPERLIVE=5050 +PORT_WHISPERLIVE_SSL=5053 +WHISPERLIVE_SSL=false + +WHISPL_USE_CUSTOM_MODEL=false +FASTERWHISPER_MODEL=faster-whisper-large-v3 +WHISPERLIVE_URL=${APP_WS_PROTOCOL}://whisperlive.${APP_URL} diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..f62a710 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,42 @@ +FROM python:3.10-bookworm + +ARG DEBIAN_FRONTEND=noninteractive + +# Create log directories with proper permissions +RUN mkdir -p /app/logs && \ + touch /app/logs/whisperlive.log && \ + touch /app/logs/connections.log && \ + chmod 666 /app/logs/whisperlive.log && \ + chmod 666 /app/logs/connections.log + +# install lib required for pyaudio +RUN apt update && apt install -y portaudio19-dev && apt-get clean && rm -rf /var/lib/apt/lists/* + +# update pip to support for whl.metadata -> less downloading +RUN pip install --no-cache-dir -U "pip>=24" + +# create a working directory +WORKDIR /app + +# install the requirements for running the whisper-live server +COPY requirements/server.txt /app/ +RUN pip install -r server.txt && rm server.txt + +# make the paths of the nvidia libs installed as wheels visible +ENV LD_LIBRARY_PATH="/usr/local/lib/python3.10/site-packages/nvidia/cublas/lib:/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib" + +COPY whisper_live /app/whisper_live +COPY run_server.py /app + +# Copy application files +EXPOSE ${PORT_WHISPERLIVE} +ARG PORT_WHISPERLIVE +ENV PORT_WHISPERLIVE=${PORT_WHISPERLIVE} +ARG FASTERWHISPER_MODEL +ENV FASTERWHISPER_MODEL=${FASTERWHISPER_MODEL} + +CMD ["python3", "-u", "run_server.py", "--port", "${PORT_WHISPERLIVE}", "--backend", "faster_whisper"] + +# CMD ["python3", "-u", "run_server.py", "--port", "${PORT_WHISPERLIVE}", "--backend", "faster_whisper", "--faster_whisper_custom_model_path", "/app/models/${FASTERWHISPER_MODEL}", "--ssl_cert_path", "/app/ssl"] + +# CMD ["python3", "-u", "run_server.py", "--port", "${PORT_WHISPERLIVE_SSL}", "--backend", "faster_whisper", "--faster_whisper_custom_model_path", "/app/models/${FASTERWHISPER_MODEL}", "--ssl_cert_path", "/app/ssl"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..375556f --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Vineet Suryan, Collabora Ltd. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..302c559 --- /dev/null +++ b/README.md @@ -0,0 +1,219 @@ +# WhisperLive + +

+  [WhisperLive banner image]
+
+  A nearly-live implementation of OpenAI's Whisper.

+ +This project is a real-time transcription application that uses the OpenAI Whisper model +to convert speech input into text output. It can be used to transcribe both live audio +input from microphone and pre-recorded audio files. + +- [Installation](#installation) +- [Getting Started](#getting-started) +- [Running the Server](#running-the-server) +- [Running the Client](#running-the-client) +- [Browser Extensions](#browser-extensions) +- [Whisper Live Server in Docker](#whisper-live-server-in-docker) +- [Future Work](#future-work) +- [Blog Posts](#blog-posts) +- [Contact](#contact) +- [Citations](#citations) + +## Installation +- Install PyAudio +```bash + bash scripts/setup.sh +``` + +- Install whisper-live from pip +```bash + pip install whisper-live +``` + +### Setting up NVIDIA/TensorRT-LLM for TensorRT backend +- Please follow [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) for setup of [NVIDIA/TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) and for building Whisper-TensorRT engine. + +## Getting Started +The server supports 3 backends `faster_whisper`, `tensorrt` and `openvino`. If running `tensorrt` backend follow [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) + +### Running the Server +- [Faster Whisper](https://github.com/SYSTRAN/faster-whisper) backend +```bash +python3 run_server.py --port 9090 \ + --backend faster_whisper + +# running with custom model and cache_dir to save auto-converted ctranslate2 models +python3 run_server.py --port 9090 \ + --backend faster_whisper \ + -fw "/path/to/custom/faster/whisper/model" + -c ~/.cache/whisper-live/ +``` + +- TensorRT backend. Currently, we recommend to only use the docker setup for TensorRT. Follow [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) which works as expected. Make sure to build your TensorRT Engines before running the server with TensorRT backend. +```bash +# Run English only model +python3 run_server.py -p 9090 \ + -b tensorrt \ + -trt /home/TensorRT-LLM/examples/whisper/whisper_small_en + +# Run Multilingual model +python3 run_server.py -p 9090 \ + -b tensorrt \ + -trt /home/TensorRT-LLM/examples/whisper/whisper_small \ + -m +``` + +- WhisperLive now supports the [OpenVINO](https://github.com/openvinotoolkit/openvino) backend for efficient inference on Intel CPUs, iGPU and dGPUs. Currently, we tested the models uploaded to [huggingface by OpenVINO](https://huggingface.co/OpenVINO?search_models=whisper). + - > **Docker Recommended:** Running WhisperLive with OpenVINO inside Docker automatically enables GPU support (iGPU/dGPU) without requiring additional host setup. + - > **Native (non-Docker) Use:** If you prefer running outside Docker, ensure the Intel drivers and OpenVINO runtime are installed and properly configured on your system. Refer to the documentation for [installing OpenVINO](https://docs.openvino.ai/2025/get-started/install-openvino.html?PACKAGE=OPENVINO_BASE&VERSION=v_2025_0_0&OP_SYSTEM=LINUX&DISTRIBUTION=PIP#). + +``` +python3 run_server.py -p 9090 -b openvino +``` + + +#### Controlling OpenMP Threads +To control the number of threads used by OpenMP, you can set the `OMP_NUM_THREADS` environment variable. This is useful for managing CPU resources and ensuring consistent performance. If not specified, `OMP_NUM_THREADS` is set to `1` by default. 
You can change this by using the `--omp_num_threads` argument: +```bash +python3 run_server.py --port 9090 \ + --backend faster_whisper \ + --omp_num_threads 4 +``` + +#### Single model mode +By default, when running the server without specifying a model, the server will instantiate a new whisper model for every client connection. This has the advantage, that the server can use different model sizes, based on the client's requested model size. On the other hand, it also means you have to wait for the model to be loaded upon client connection and you will have increased (V)RAM usage. + +When serving a custom TensorRT model using the `-trt` or a custom faster_whisper model using the `-fw` option, the server will instead only instantiate the custom model once and then reuse it for all client connections. + +If you don't want this, set `--no_single_model`. + + +### Running the Client +- Initializing the client with below parameters: + - `lang`: Language of the input audio, applicable only if using a multilingual model. + - `translate`: If set to `True` then translate from any language to `en`. + - `model`: Whisper model size. + - `use_vad`: Whether to use `Voice Activity Detection` on the server. + - `save_output_recording`: Set to True to save the microphone input as a `.wav` file during live transcription. This option is helpful for recording sessions for later playback or analysis. Defaults to `False`. + - `output_recording_filename`: Specifies the `.wav` file path where the microphone input will be saved if `save_output_recording` is set to `True`. + - `max_clients`: Specifies the maximum number of clients the server should allow. Defaults to 4. + - `max_connection_time`: Maximum connection time for each client in seconds. Defaults to 600. + - `mute_audio_playback`: Whether to mute audio playback when transcribing an audio file. Defaults to False. + +```python +from whisper_live.client import TranscriptionClient +client = TranscriptionClient( + "localhost", + 9090, + lang="en", + translate=False, + model="small", # also support hf_model => `Systran/faster-whisper-small` + use_vad=False, + save_output_recording=True, # Only used for microphone input, False by Default + output_recording_filename="./output_recording.wav", # Only used for microphone input + max_clients=4, + max_connection_time=600, + mute_audio_playback=False, # Only used for file input, False by Default +) +``` +It connects to the server running on localhost at port 9090. Using a multilingual model, language for the transcription will be automatically detected. You can also use the language option to specify the target language for the transcription, in this case, English ("en"). The translate option should be set to `True` if we want to translate from the source language to English and `False` if we want to transcribe in the source language. + +- Transcribe an audio file: +```python +client("tests/jfk.wav") +``` + +- To transcribe from microphone: +```python +client() +``` + +- To transcribe from a RTSP stream: +```python +client(rtsp_url="rtsp://admin:admin@192.168.0.1/rtsp") +``` + +- To transcribe from a HLS stream: +```python +client(hls_url="http://as-hls-ww-live.akamaized.net/pool_904/live/ww/bbc_1xtra/bbc_1xtra.isml/bbc_1xtra-audio%3d96000.norewind.m3u8") +``` + +## Browser Extensions +- Run the server with your desired backend as shown [here](https://github.com/collabora/WhisperLive?tab=readme-ov-file#running-the-server). +- Transcribe audio directly from your browser using our Chrome or Firefox extensions. 
Refer to [Audio-Transcription-Chrome](https://github.com/collabora/whisper-live/tree/main/Audio-Transcription-Chrome#readme) and https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md + +## Whisper Live Server in Docker +- GPU + - Faster-Whisper + ```bash + docker run -it --gpus all -p 9090:9090 ghcr.io/collabora/whisperlive-gpu:latest + ``` + + - TensorRT. Refer to [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) for setup and more tensorrt backend configurations. + ```bash + docker build . -f docker/Dockerfile.tensorrt -t whisperlive-tensorrt + docker run -p 9090:9090 --runtime=nvidia --entrypoint /bin/bash -it whisperlive-tensorrt + + # Build small.en engine + bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en # float16 + bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en int8 # int8 weight only quantization + bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en int4 # int4 weight only quantization + + # Run server with small.en + python3 run_server.py --port 9090 \ + --backend tensorrt \ + --trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_en_float16" + --trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_en_int8" + --trt_model_path "/app/TensorRT-LLM-examples/whisper/whisper_small_en_int4" + ``` + + - OpenVINO + ``` + docker run -it --device=/dev/dri -p 9090:9090 ghcr.io/collabora/whisperlive-openvino + ``` + +- CPU + - Faster-whisper + ```bash + docker run -it -p 9090:9090 ghcr.io/collabora/whisperlive-cpu:latest + ``` + +## Future Work +- [ ] Add translation to other languages on top of transcription. + +## Blog Posts +- [Transforming speech technology with WhisperLive](https://www.collabora.com/news-and-blog/blog/2024/05/28/transforming-speech-technology-with-whisperlive/) +- [WhisperFusion: Ultra-low latency conversations with an AI chatbot](https://www.collabora.com/news-and-blog/news-and-events/whisperfusion-ultra-low-latency-conversations-with-an-ai-chatbot.html) powered by WhisperLive +- [Breaking language barriers 2.0: Moving closer towards fully reliable, production-ready Hindi ASR](https://www.collabora.com/news-and-blog/news-and-events/breaking-language-barriers-20-moving-closer-production-ready-hindi-asr.html) which is used in WhisperLive for hindi. + +## Contact + +We are available to help you with both Open Source and proprietary AI projects. You can reach us via the Collabora website or [vineet.suryan@collabora.com](mailto:vineet.suryan@collabora.com) and [marcus.edel@collabora.com](mailto:marcus.edel@collabora.com). 
+ + +## Citations +```bibtex +@article{Whisper + title = {Robust Speech Recognition via Large-Scale Weak Supervision}, + url = {https://arxiv.org/abs/2212.04356}, + author = {Radford, Alec and Kim, Jong Wook and Xu, Tao and Brockman, Greg and McLeavey, Christine and Sutskever, Ilya}, + publisher = {arXiv}, + year = {2022}, +} +``` + +```bibtex +@misc{Silero VAD, + author = {Silero Team}, + title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier}, + year = {2021}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/snakers4/silero-vad}}, + email = {hello@silero.ai} +} \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/assets/jfk.flac b/assets/jfk.flac new file mode 100644 index 0000000..e44b7c1 Binary files /dev/null and b/assets/jfk.flac differ diff --git a/check_cudnn.py b/check_cudnn.py new file mode 100644 index 0000000..6515906 --- /dev/null +++ b/check_cudnn.py @@ -0,0 +1,16 @@ +import tensorflow as tf + +if tf.test.is_built_with_cuda(): + print("TF is built with CUDA") +else: + print("TF is not built with CUDA") + +if tf.test.is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): + print("CUDA is available in TF") +else: + print("CUDA is not available in TF") + +if tf.test.is_built_with_cudnn(): + print("cuDNN is available") +else: + print("cuDNN is not available") \ No newline at end of file diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..6d13566 --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,36 @@ +version: '3.8' + +services: + whisperlive: + container_name: whisperlive + build: + context: . + dockerfile: Dockerfile + args: + PORT_WHISPERLIVE: ${PORT_WHISPERLIVE} + FASTERWHISPER_MODEL: ${FASTERWHISPER_MODEL} + env_file: + - .env + environment: + LOG_PATH: /app/logs + NVIDIA_VISIBLE_DEVICES: all + NVIDIA_DRIVER_CAPABILITIES: compute,utility + volumes: + - ./models:/app/models + - ./ssl:/app/ssl + - ./logs:/app/logs + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + ports: + - ${PORT_WHISPERLIVE}:${PORT_WHISPERLIVE} + networks: + - audio-network + +networks: + audio-network: + driver: bridge diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..1e77442 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,35 @@ +FROM python:3.10-bookworm + +ARG DEBIAN_FRONTEND=noninteractive + +# install lib required for pyaudio +RUN apt update && apt install -y portaudio19-dev && apt-get clean && rm -rf /var/lib/apt/lists/* + +# update pip to support for whl.metadata -> less downloading +RUN pip install --no-cache-dir -U "pip>=24" + +# create a working directory +RUN mkdir /app +WORKDIR /app + +# install the requirements for running the whisper-live server +COPY requirements/server.txt /app/ +RUN pip install --no-cache-dir -r server.txt && rm server.txt + +# make the paths of the nvidia libs installed as wheels visible. 
equivalent to: +# export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'` +ENV LD_LIBRARY_PATH="/usr/local/lib/python3.10/site-packages/nvidia/cublas/lib:/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib" + +EXPOSE ${WHISPERLIVE_PORT} + +COPY whisper_live /app/whisper_live +COPY models /app/models +COPY run_server.py /app + +ARG WHISPERLIVE_PORT +ENV WHISPERLIVE_PORT=${WHISPERLIVE_PORT} + +ARG FASTERWHISPER_MODEL +ENV FASTERWHISPER_MODEL=${FASTERWHISPER_MODEL} + +CMD python3 run_server.py --port $WHISPERLIVE_PORT --backend faster_whisper --faster_whisper_custom_model_path /app/models/$FASTERWHISPER_MODEL \ No newline at end of file diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu new file mode 100644 index 0000000..6f320b4 --- /dev/null +++ b/docker/Dockerfile.cpu @@ -0,0 +1,25 @@ +FROM python:3.10-bookworm + +ARG DEBIAN_FRONTEND=noninteractive + +# install lib required for pyaudio +RUN apt update && apt install -y portaudio19-dev && apt-get clean && rm -rf /var/lib/apt/lists/* + +# update pip to support for whl.metadata -> less downloading +RUN pip install --no-cache-dir -U "pip>=24" + +# create a working directory +RUN mkdir /app +WORKDIR /app + +# install pytorch, but without the nvidia-libs that are only necessary for gpu +RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu + +# install the requirements for running the whisper-live server +COPY requirements/server.txt /app/ +RUN pip install --no-cache-dir -r server.txt && rm server.txt + +COPY whisper_live /app/whisper_live +COPY run_server.py /app + +CMD ["python", "run_server.py"] diff --git a/docker/Dockerfile.gpu b/docker/Dockerfile.gpu new file mode 100644 index 0000000..e88112e --- /dev/null +++ b/docker/Dockerfile.gpu @@ -0,0 +1,26 @@ +FROM python:3.10-bookworm + +ARG DEBIAN_FRONTEND=noninteractive + +# install lib required for pyaudio +RUN apt update && apt install -y portaudio19-dev && apt-get clean && rm -rf /var/lib/apt/lists/* + +# update pip to support for whl.metadata -> less downloading +RUN pip install --no-cache-dir -U "pip>=24" + +# create a working directory +RUN mkdir /app +WORKDIR /app + +# install the requirements for running the whisper-live server +COPY requirements/server.txt /app/ +RUN pip install --no-cache-dir -r server.txt && rm server.txt + +# make the paths of the nvidia libs installed as wheels visible. equivalent to: +# export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'` +ENV LD_LIBRARY_PATH="/usr/local/lib/python3.10/site-packages/nvidia/cublas/lib:/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib" + +COPY whisper_live /app/whisper_live +COPY run_server.py /app + +CMD ["python", "run_server.py"] diff --git a/docker/Dockerfile.tensorrt b/docker/Dockerfile.tensorrt new file mode 100644 index 0000000..1263673 --- /dev/null +++ b/docker/Dockerfile.tensorrt @@ -0,0 +1,37 @@ +FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04 +ARG DEBIAN_FRONTEND=noninteractive + +# Remove any third-party apt sources to avoid issues with expiring keys. +RUN rm -f /etc/apt/sources.list.d/*.list + +# Install some basic utilities. 
+RUN apt-get update && apt-get install -y \ + python3.10 python3-pip openmpi-bin libopenmpi-dev git wget \ + && rm -rf /var/lib/apt/lists/* + +RUN pip3 install --no-cache-dir -U tensorrt_llm==0.9.0 --extra-index-url https://pypi.nvidia.com + +WORKDIR /app + +RUN git clone -b v0.9.0 --depth 1 https://github.com/NVIDIA/TensorRT-LLM.git && \ + mv TensorRT-LLM/examples ./TensorRT-LLM-examples && \ + rm -rf TensorRT-LLM + +COPY assets/ ./assets +RUN wget -nc -P assets/ https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz + +COPY scripts/setup.sh ./ +RUN apt update && bash setup.sh && rm setup.sh + +COPY requirements/server.txt . +RUN pip install --no-cache-dir -r server.txt && rm server.txt + +COPY whisper_live ./whisper_live +COPY scripts/build_whisper_tensorrt.sh . +COPY run_server.py . + +# Build the TensorRT engine +RUN bash build_whisper_tensorrt.sh /app/TensorRT-LLM-examples small.en + +# Set the command to run the server +CMD ["python3", "run_server.py", "--port", "9090", "--backend", "tensorrt", "--trt_model_path", "/app/TensorRT-LLM-examples/whisper/whisper_small_en"] \ No newline at end of file diff --git a/docker/docker-compose.override.yml b/docker/docker-compose.override.yml new file mode 100644 index 0000000..abf7a14 --- /dev/null +++ b/docker/docker-compose.override.yml @@ -0,0 +1,28 @@ +services: + whisperlive-server: + runtime: nvidia + build: + context: ./backend/whisperlive/server + dockerfile: Dockerfile.tensorrt # Override to use Dockerfile.tensorrt + args: + WHISPERLIVE_PORT: ${WHISPERLIVE_PORT} + env_file: + - ./.env + environment: + WHISPERLIVE_PORT: ${WHISPERLIVE_PORT} + NVIDIA_VISIBLE_DEVICES: all + NVIDIA_DRIVER_CAPABILITIES: compute,utility + volumes: + - data_volume:/data + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + ports: + - ${WHISPERLIVE_PORT}:${WHISPERLIVE_PORT} + networks: + - app-network + diff --git a/docs/.nojekyll b/docs/.nojekyll new file mode 100644 index 0000000..e69de29 diff --git a/docs/doctrees/environment.pickle b/docs/doctrees/environment.pickle new file mode 100644 index 0000000..0923b66 Binary files /dev/null and b/docs/doctrees/environment.pickle differ diff --git a/docs/doctrees/index.doctree b/docs/doctrees/index.doctree new file mode 100644 index 0000000..be33d87 Binary files /dev/null and b/docs/doctrees/index.doctree differ diff --git a/docs/html/.buildinfo b/docs/html/.buildinfo new file mode 100644 index 0000000..7598238 --- /dev/null +++ b/docs/html/.buildinfo @@ -0,0 +1,4 @@ +# Sphinx build info version 1 +# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. +config: 7b818b47e6f359b937e5a2517f120d43 +tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/html/_sources/index.rst.txt b/docs/html/_sources/index.rst.txt new file mode 100644 index 0000000..acc105e --- /dev/null +++ b/docs/html/_sources/index.rst.txt @@ -0,0 +1,26 @@ +.. whisper_live documentation master file, created by + sphinx-quickstart on Fri Sep 22 11:39:30 2023. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to Whisper Live documentation! +======================================== + +.. toctree:: + :maxdepth: 2 + + +.. automodule:: whisper_live.server + :members: + +.. 
automodule:: whisper_live.client + :members: + + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/html/_static/alabaster.css b/docs/html/_static/alabaster.css new file mode 100644 index 0000000..517d0b2 --- /dev/null +++ b/docs/html/_static/alabaster.css @@ -0,0 +1,703 @@ +@import url("basic.css"); + +/* -- page layout ----------------------------------------------------------- */ + +body { + font-family: Georgia, serif; + font-size: 17px; + background-color: #fff; + color: #000; + margin: 0; + padding: 0; +} + + +div.document { + width: 940px; + margin: 30px auto 0 auto; +} + +div.documentwrapper { + float: left; + width: 100%; +} + +div.bodywrapper { + margin: 0 0 0 220px; +} + +div.sphinxsidebar { + width: 220px; + font-size: 14px; + line-height: 1.5; +} + +hr { + border: 1px solid #B1B4B6; +} + +div.body { + background-color: #fff; + color: #3E4349; + padding: 0 30px 0 30px; +} + +div.body > .section { + text-align: left; +} + +div.footer { + width: 940px; + margin: 20px auto 30px auto; + font-size: 14px; + color: #888; + text-align: right; +} + +div.footer a { + color: #888; +} + +p.caption { + font-family: inherit; + font-size: inherit; +} + + +div.relations { + display: none; +} + + +div.sphinxsidebar a { + color: #444; + text-decoration: none; + border-bottom: 1px dotted #999; +} + +div.sphinxsidebar a:hover { + border-bottom: 1px solid #999; +} + +div.sphinxsidebarwrapper { + padding: 18px 10px; +} + +div.sphinxsidebarwrapper p.logo { + padding: 0; + margin: -10px 0 0 0px; + text-align: center; +} + +div.sphinxsidebarwrapper h1.logo { + margin-top: -10px; + text-align: center; + margin-bottom: 5px; + text-align: left; +} + +div.sphinxsidebarwrapper h1.logo-name { + margin-top: 0px; +} + +div.sphinxsidebarwrapper p.blurb { + margin-top: 0; + font-style: normal; +} + +div.sphinxsidebar h3, +div.sphinxsidebar h4 { + font-family: Georgia, serif; + color: #444; + font-size: 24px; + font-weight: normal; + margin: 0 0 5px 0; + padding: 0; +} + +div.sphinxsidebar h4 { + font-size: 20px; +} + +div.sphinxsidebar h3 a { + color: #444; +} + +div.sphinxsidebar p.logo a, +div.sphinxsidebar h3 a, +div.sphinxsidebar p.logo a:hover, +div.sphinxsidebar h3 a:hover { + border: none; +} + +div.sphinxsidebar p { + color: #555; + margin: 10px 0; +} + +div.sphinxsidebar ul { + margin: 10px 0; + padding: 0; + color: #000; +} + +div.sphinxsidebar ul li.toctree-l1 > a { + font-size: 120%; +} + +div.sphinxsidebar ul li.toctree-l2 > a { + font-size: 110%; +} + +div.sphinxsidebar input { + border: 1px solid #CCC; + font-family: Georgia, serif; + font-size: 1em; +} + +div.sphinxsidebar hr { + border: none; + height: 1px; + color: #AAA; + background: #AAA; + + text-align: left; + margin-left: 0; + width: 50%; +} + +div.sphinxsidebar .badge { + border-bottom: none; +} + +div.sphinxsidebar .badge:hover { + border-bottom: none; +} + +/* To address an issue with donation coming after search */ +div.sphinxsidebar h3.donation { + margin-top: 10px; +} + +/* -- body styles ----------------------------------------------------------- */ + +a { + color: #004B6B; + text-decoration: underline; +} + +a:hover { + color: #6D4100; + text-decoration: underline; +} + +div.body h1, +div.body h2, +div.body h3, +div.body h4, +div.body h5, +div.body h6 { + font-family: Georgia, serif; + font-weight: normal; + margin: 30px 0px 10px 0px; + padding: 0; +} + +div.body h1 { margin-top: 0; padding-top: 0; font-size: 240%; } +div.body h2 { font-size: 180%; } +div.body h3 { 
font-size: 150%; } +div.body h4 { font-size: 130%; } +div.body h5 { font-size: 100%; } +div.body h6 { font-size: 100%; } + +a.headerlink { + color: #DDD; + padding: 0 4px; + text-decoration: none; +} + +a.headerlink:hover { + color: #444; + background: #EAEAEA; +} + +div.body p, div.body dd, div.body li { + line-height: 1.4em; +} + +div.admonition { + margin: 20px 0px; + padding: 10px 30px; + background-color: #EEE; + border: 1px solid #CCC; +} + +div.admonition tt.xref, div.admonition code.xref, div.admonition a tt { + background-color: #FBFBFB; + border-bottom: 1px solid #fafafa; +} + +div.admonition p.admonition-title { + font-family: Georgia, serif; + font-weight: normal; + font-size: 24px; + margin: 0 0 10px 0; + padding: 0; + line-height: 1; +} + +div.admonition p.last { + margin-bottom: 0; +} + +div.highlight { + background-color: #fff; +} + +dt:target, .highlight { + background: #FAF3E8; +} + +div.warning { + background-color: #FCC; + border: 1px solid #FAA; +} + +div.danger { + background-color: #FCC; + border: 1px solid #FAA; + -moz-box-shadow: 2px 2px 4px #D52C2C; + -webkit-box-shadow: 2px 2px 4px #D52C2C; + box-shadow: 2px 2px 4px #D52C2C; +} + +div.error { + background-color: #FCC; + border: 1px solid #FAA; + -moz-box-shadow: 2px 2px 4px #D52C2C; + -webkit-box-shadow: 2px 2px 4px #D52C2C; + box-shadow: 2px 2px 4px #D52C2C; +} + +div.caution { + background-color: #FCC; + border: 1px solid #FAA; +} + +div.attention { + background-color: #FCC; + border: 1px solid #FAA; +} + +div.important { + background-color: #EEE; + border: 1px solid #CCC; +} + +div.note { + background-color: #EEE; + border: 1px solid #CCC; +} + +div.tip { + background-color: #EEE; + border: 1px solid #CCC; +} + +div.hint { + background-color: #EEE; + border: 1px solid #CCC; +} + +div.seealso { + background-color: #EEE; + border: 1px solid #CCC; +} + +div.topic { + background-color: #EEE; +} + +p.admonition-title { + display: inline; +} + +p.admonition-title:after { + content: ":"; +} + +pre, tt, code { + font-family: 'Consolas', 'Menlo', 'DejaVu Sans Mono', 'Bitstream Vera Sans Mono', monospace; + font-size: 0.9em; +} + +.hll { + background-color: #FFC; + margin: 0 -12px; + padding: 0 12px; + display: block; +} + +img.screenshot { +} + +tt.descname, tt.descclassname, code.descname, code.descclassname { + font-size: 0.95em; +} + +tt.descname, code.descname { + padding-right: 0.08em; +} + +img.screenshot { + -moz-box-shadow: 2px 2px 4px #EEE; + -webkit-box-shadow: 2px 2px 4px #EEE; + box-shadow: 2px 2px 4px #EEE; +} + +table.docutils { + border: 1px solid #888; + -moz-box-shadow: 2px 2px 4px #EEE; + -webkit-box-shadow: 2px 2px 4px #EEE; + box-shadow: 2px 2px 4px #EEE; +} + +table.docutils td, table.docutils th { + border: 1px solid #888; + padding: 0.25em 0.7em; +} + +table.field-list, table.footnote { + border: none; + -moz-box-shadow: none; + -webkit-box-shadow: none; + box-shadow: none; +} + +table.footnote { + margin: 15px 0; + width: 100%; + border: 1px solid #EEE; + background: #FDFDFD; + font-size: 0.9em; +} + +table.footnote + table.footnote { + margin-top: -15px; + border-top: none; +} + +table.field-list th { + padding: 0 0.8em 0 0; +} + +table.field-list td { + padding: 0; +} + +table.field-list p { + margin-bottom: 0.8em; +} + +/* Cloned from + * https://github.com/sphinx-doc/sphinx/commit/ef60dbfce09286b20b7385333d63a60321784e68 + */ +.field-name { + -moz-hyphens: manual; + -ms-hyphens: manual; + -webkit-hyphens: manual; + hyphens: manual; +} + +table.footnote td.label { + width: .1px; + padding: 
0.3em 0 0.3em 0.5em; +} + +table.footnote td { + padding: 0.3em 0.5em; +} + +dl { + margin-left: 0; + margin-right: 0; + margin-top: 0; + padding: 0; +} + +dl dd { + margin-left: 30px; +} + +blockquote { + margin: 0 0 0 30px; + padding: 0; +} + +ul, ol { + /* Matches the 30px from the narrow-screen "li > ul" selector below */ + margin: 10px 0 10px 30px; + padding: 0; +} + +pre { + background: #EEE; + padding: 7px 30px; + margin: 15px 0px; + line-height: 1.3em; +} + +div.viewcode-block:target { + background: #ffd; +} + +dl pre, blockquote pre, li pre { + margin-left: 0; + padding-left: 30px; +} + +tt, code { + background-color: #ecf0f3; + color: #222; + /* padding: 1px 2px; */ +} + +tt.xref, code.xref, a tt { + background-color: #FBFBFB; + border-bottom: 1px solid #fff; +} + +a.reference { + text-decoration: none; + border-bottom: 1px dotted #004B6B; +} + +/* Don't put an underline on images */ +a.image-reference, a.image-reference:hover { + border-bottom: none; +} + +a.reference:hover { + border-bottom: 1px solid #6D4100; +} + +a.footnote-reference { + text-decoration: none; + font-size: 0.7em; + vertical-align: top; + border-bottom: 1px dotted #004B6B; +} + +a.footnote-reference:hover { + border-bottom: 1px solid #6D4100; +} + +a:hover tt, a:hover code { + background: #EEE; +} + + +@media screen and (max-width: 870px) { + + div.sphinxsidebar { + display: none; + } + + div.document { + width: 100%; + + } + + div.documentwrapper { + margin-left: 0; + margin-top: 0; + margin-right: 0; + margin-bottom: 0; + } + + div.bodywrapper { + margin-top: 0; + margin-right: 0; + margin-bottom: 0; + margin-left: 0; + } + + ul { + margin-left: 0; + } + + li > ul { + /* Matches the 30px from the "ul, ol" selector above */ + margin-left: 30px; + } + + .document { + width: auto; + } + + .footer { + width: auto; + } + + .bodywrapper { + margin: 0; + } + + .footer { + width: auto; + } + + .github { + display: none; + } + + + +} + + + +@media screen and (max-width: 875px) { + + body { + margin: 0; + padding: 20px 30px; + } + + div.documentwrapper { + float: none; + background: #fff; + } + + div.sphinxsidebar { + display: block; + float: none; + width: 102.5%; + margin: 50px -30px -20px -30px; + padding: 10px 20px; + background: #333; + color: #FFF; + } + + div.sphinxsidebar h3, div.sphinxsidebar h4, div.sphinxsidebar p, + div.sphinxsidebar h3 a { + color: #fff; + } + + div.sphinxsidebar a { + color: #AAA; + } + + div.sphinxsidebar p.logo { + display: none; + } + + div.document { + width: 100%; + margin: 0; + } + + div.footer { + display: none; + } + + div.bodywrapper { + margin: 0; + } + + div.body { + min-height: 0; + padding: 0; + } + + .rtd_doc_footer { + display: none; + } + + .document { + width: auto; + } + + .footer { + width: auto; + } + + .footer { + width: auto; + } + + .github { + display: none; + } +} + + +/* misc. */ + +.revsys-inline { + display: none!important; +} + +/* Make nested-list/multi-paragraph items look better in Releases changelog + * pages. Without this, docutils' magical list fuckery causes inconsistent + * formatting between different release sub-lists. 
+ */ +div#changelog > div.section > ul > li > p:only-child { + margin-bottom: 0; +} + +/* Hide fugly table cell borders in ..bibliography:: directive output */ +table.docutils.citation, table.docutils.citation td, table.docutils.citation th { + border: none; + /* Below needed in some edge cases; if not applied, bottom shadows appear */ + -moz-box-shadow: none; + -webkit-box-shadow: none; + box-shadow: none; +} + + +/* relbar */ + +.related { + line-height: 30px; + width: 100%; + font-size: 0.9rem; +} + +.related.top { + border-bottom: 1px solid #EEE; + margin-bottom: 20px; +} + +.related.bottom { + border-top: 1px solid #EEE; +} + +.related ul { + padding: 0; + margin: 0; + list-style: none; +} + +.related li { + display: inline; +} + +nav#rellinks { + float: right; +} + +nav#rellinks li+li:before { + content: "|"; +} + +nav#breadcrumbs li+li:before { + content: "\00BB"; +} + +/* Hide certain items when printing */ +@media print { + div.related { + display: none; + } +} \ No newline at end of file diff --git a/docs/html/_static/basic.css b/docs/html/_static/basic.css new file mode 100644 index 0000000..30fee9d --- /dev/null +++ b/docs/html/_static/basic.css @@ -0,0 +1,925 @@ +/* + * basic.css + * ~~~~~~~~~ + * + * Sphinx stylesheet -- basic theme. + * + * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS. + * :license: BSD, see LICENSE for details. + * + */ + +/* -- main layout ----------------------------------------------------------- */ + +div.clearer { + clear: both; +} + +div.section::after { + display: block; + content: ''; + clear: left; +} + +/* -- relbar ---------------------------------------------------------------- */ + +div.related { + width: 100%; + font-size: 90%; +} + +div.related h3 { + display: none; +} + +div.related ul { + margin: 0; + padding: 0 0 0 10px; + list-style: none; +} + +div.related li { + display: inline; +} + +div.related li.right { + float: right; + margin-right: 5px; +} + +/* -- sidebar --------------------------------------------------------------- */ + +div.sphinxsidebarwrapper { + padding: 10px 5px 0 10px; +} + +div.sphinxsidebar { + float: left; + width: 230px; + margin-left: -100%; + font-size: 90%; + word-wrap: break-word; + overflow-wrap : break-word; +} + +div.sphinxsidebar ul { + list-style: none; +} + +div.sphinxsidebar ul ul, +div.sphinxsidebar ul.want-points { + margin-left: 20px; + list-style: square; +} + +div.sphinxsidebar ul ul { + margin-top: 0; + margin-bottom: 0; +} + +div.sphinxsidebar form { + margin-top: 10px; +} + +div.sphinxsidebar input { + border: 1px solid #98dbcc; + font-family: sans-serif; + font-size: 1em; +} + +div.sphinxsidebar #searchbox form.search { + overflow: hidden; +} + +div.sphinxsidebar #searchbox input[type="text"] { + float: left; + width: 80%; + padding: 0.25em; + box-sizing: border-box; +} + +div.sphinxsidebar #searchbox input[type="submit"] { + float: left; + width: 20%; + border-left: none; + padding: 0.25em; + box-sizing: border-box; +} + + +img { + border: 0; + max-width: 100%; +} + +/* -- search page ----------------------------------------------------------- */ + +ul.search { + margin: 10px 0 0 20px; + padding: 0; +} + +ul.search li { + padding: 5px 0 5px 20px; + background-image: url(file.png); + background-repeat: no-repeat; + background-position: 0 7px; +} + +ul.search li a { + font-weight: bold; +} + +ul.search li p.context { + color: #888; + margin: 2px 0 0 30px; + text-align: left; +} + +ul.keywordmatches li.goodmatch a { + font-weight: bold; +} + +/* -- index page 
------------------------------------------------------------ */ + +table.contentstable { + width: 90%; + margin-left: auto; + margin-right: auto; +} + +table.contentstable p.biglink { + line-height: 150%; +} + +a.biglink { + font-size: 1.3em; +} + +span.linkdescr { + font-style: italic; + padding-top: 5px; + font-size: 90%; +} + +/* -- general index --------------------------------------------------------- */ + +table.indextable { + width: 100%; +} + +table.indextable td { + text-align: left; + vertical-align: top; +} + +table.indextable ul { + margin-top: 0; + margin-bottom: 0; + list-style-type: none; +} + +table.indextable > tbody > tr > td > ul { + padding-left: 0em; +} + +table.indextable tr.pcap { + height: 10px; +} + +table.indextable tr.cap { + margin-top: 10px; + background-color: #f2f2f2; +} + +img.toggler { + margin-right: 3px; + margin-top: 3px; + cursor: pointer; +} + +div.modindex-jumpbox { + border-top: 1px solid #ddd; + border-bottom: 1px solid #ddd; + margin: 1em 0 1em 0; + padding: 0.4em; +} + +div.genindex-jumpbox { + border-top: 1px solid #ddd; + border-bottom: 1px solid #ddd; + margin: 1em 0 1em 0; + padding: 0.4em; +} + +/* -- domain module index --------------------------------------------------- */ + +table.modindextable td { + padding: 2px; + border-collapse: collapse; +} + +/* -- general body styles --------------------------------------------------- */ + +div.body { + min-width: 360px; + max-width: 800px; +} + +div.body p, div.body dd, div.body li, div.body blockquote { + -moz-hyphens: auto; + -ms-hyphens: auto; + -webkit-hyphens: auto; + hyphens: auto; +} + +a.headerlink { + visibility: hidden; +} + +a:visited { + color: #551A8B; +} + +h1:hover > a.headerlink, +h2:hover > a.headerlink, +h3:hover > a.headerlink, +h4:hover > a.headerlink, +h5:hover > a.headerlink, +h6:hover > a.headerlink, +dt:hover > a.headerlink, +caption:hover > a.headerlink, +p.caption:hover > a.headerlink, +div.code-block-caption:hover > a.headerlink { + visibility: visible; +} + +div.body p.caption { + text-align: inherit; +} + +div.body td { + text-align: left; +} + +.first { + margin-top: 0 !important; +} + +p.rubric { + margin-top: 30px; + font-weight: bold; +} + +img.align-left, figure.align-left, .figure.align-left, object.align-left { + clear: left; + float: left; + margin-right: 1em; +} + +img.align-right, figure.align-right, .figure.align-right, object.align-right { + clear: right; + float: right; + margin-left: 1em; +} + +img.align-center, figure.align-center, .figure.align-center, object.align-center { + display: block; + margin-left: auto; + margin-right: auto; +} + +img.align-default, figure.align-default, .figure.align-default { + display: block; + margin-left: auto; + margin-right: auto; +} + +.align-left { + text-align: left; +} + +.align-center { + text-align: center; +} + +.align-default { + text-align: center; +} + +.align-right { + text-align: right; +} + +/* -- sidebars -------------------------------------------------------------- */ + +div.sidebar, +aside.sidebar { + margin: 0 0 0.5em 1em; + border: 1px solid #ddb; + padding: 7px; + background-color: #ffe; + width: 40%; + float: right; + clear: right; + overflow-x: auto; +} + +p.sidebar-title { + font-weight: bold; +} + +nav.contents, +aside.topic, +div.admonition, div.topic, blockquote { + clear: left; +} + +/* -- topics ---------------------------------------------------------------- */ + +nav.contents, +aside.topic, +div.topic { + border: 1px solid #ccc; + padding: 7px; + margin: 10px 0 10px 0; +} + +p.topic-title { 
+ font-size: 1.1em; + font-weight: bold; + margin-top: 10px; +} + +/* -- admonitions ----------------------------------------------------------- */ + +div.admonition { + margin-top: 10px; + margin-bottom: 10px; + padding: 7px; +} + +div.admonition dt { + font-weight: bold; +} + +p.admonition-title { + margin: 0px 10px 5px 0px; + font-weight: bold; +} + +div.body p.centered { + text-align: center; + margin-top: 25px; +} + +/* -- content of sidebars/topics/admonitions -------------------------------- */ + +div.sidebar > :last-child, +aside.sidebar > :last-child, +nav.contents > :last-child, +aside.topic > :last-child, +div.topic > :last-child, +div.admonition > :last-child { + margin-bottom: 0; +} + +div.sidebar::after, +aside.sidebar::after, +nav.contents::after, +aside.topic::after, +div.topic::after, +div.admonition::after, +blockquote::after { + display: block; + content: ''; + clear: both; +} + +/* -- tables ---------------------------------------------------------------- */ + +table.docutils { + margin-top: 10px; + margin-bottom: 10px; + border: 0; + border-collapse: collapse; +} + +table.align-center { + margin-left: auto; + margin-right: auto; +} + +table.align-default { + margin-left: auto; + margin-right: auto; +} + +table caption span.caption-number { + font-style: italic; +} + +table caption span.caption-text { +} + +table.docutils td, table.docutils th { + padding: 1px 8px 1px 5px; + border-top: 0; + border-left: 0; + border-right: 0; + border-bottom: 1px solid #aaa; +} + +th { + text-align: left; + padding-right: 5px; +} + +table.citation { + border-left: solid 1px gray; + margin-left: 1px; +} + +table.citation td { + border-bottom: none; +} + +th > :first-child, +td > :first-child { + margin-top: 0px; +} + +th > :last-child, +td > :last-child { + margin-bottom: 0px; +} + +/* -- figures --------------------------------------------------------------- */ + +div.figure, figure { + margin: 0.5em; + padding: 0.5em; +} + +div.figure p.caption, figcaption { + padding: 0.3em; +} + +div.figure p.caption span.caption-number, +figcaption span.caption-number { + font-style: italic; +} + +div.figure p.caption span.caption-text, +figcaption span.caption-text { +} + +/* -- field list styles ----------------------------------------------------- */ + +table.field-list td, table.field-list th { + border: 0 !important; +} + +.field-list ul { + margin: 0; + padding-left: 1em; +} + +.field-list p { + margin: 0; +} + +.field-name { + -moz-hyphens: manual; + -ms-hyphens: manual; + -webkit-hyphens: manual; + hyphens: manual; +} + +/* -- hlist styles ---------------------------------------------------------- */ + +table.hlist { + margin: 1em 0; +} + +table.hlist td { + vertical-align: top; +} + +/* -- object description styles --------------------------------------------- */ + +.sig { + font-family: 'Consolas', 'Menlo', 'DejaVu Sans Mono', 'Bitstream Vera Sans Mono', monospace; +} + +.sig-name, code.descname { + background-color: transparent; + font-weight: bold; +} + +.sig-name { + font-size: 1.1em; +} + +code.descname { + font-size: 1.2em; +} + +.sig-prename, code.descclassname { + background-color: transparent; +} + +.optional { + font-size: 1.3em; +} + +.sig-paren { + font-size: larger; +} + +.sig-param.n { + font-style: italic; +} + +/* C++ specific styling */ + +.sig-inline.c-texpr, +.sig-inline.cpp-texpr { + font-family: unset; +} + +.sig.c .k, .sig.c .kt, +.sig.cpp .k, .sig.cpp .kt { + color: #0033B3; +} + +.sig.c .m, +.sig.cpp .m { + color: #1750EB; +} + +.sig.c .s, .sig.c .sc, +.sig.cpp .s, 
.sig.cpp .sc { + color: #067D17; +} + + +/* -- other body styles ----------------------------------------------------- */ + +ol.arabic { + list-style: decimal; +} + +ol.loweralpha { + list-style: lower-alpha; +} + +ol.upperalpha { + list-style: upper-alpha; +} + +ol.lowerroman { + list-style: lower-roman; +} + +ol.upperroman { + list-style: upper-roman; +} + +:not(li) > ol > li:first-child > :first-child, +:not(li) > ul > li:first-child > :first-child { + margin-top: 0px; +} + +:not(li) > ol > li:last-child > :last-child, +:not(li) > ul > li:last-child > :last-child { + margin-bottom: 0px; +} + +ol.simple ol p, +ol.simple ul p, +ul.simple ol p, +ul.simple ul p { + margin-top: 0; +} + +ol.simple > li:not(:first-child) > p, +ul.simple > li:not(:first-child) > p { + margin-top: 0; +} + +ol.simple p, +ul.simple p { + margin-bottom: 0; +} + +aside.footnote > span, +div.citation > span { + float: left; +} +aside.footnote > span:last-of-type, +div.citation > span:last-of-type { + padding-right: 0.5em; +} +aside.footnote > p { + margin-left: 2em; +} +div.citation > p { + margin-left: 4em; +} +aside.footnote > p:last-of-type, +div.citation > p:last-of-type { + margin-bottom: 0em; +} +aside.footnote > p:last-of-type:after, +div.citation > p:last-of-type:after { + content: ""; + clear: both; +} + +dl.field-list { + display: grid; + grid-template-columns: fit-content(30%) auto; +} + +dl.field-list > dt { + font-weight: bold; + word-break: break-word; + padding-left: 0.5em; + padding-right: 5px; +} + +dl.field-list > dd { + padding-left: 0.5em; + margin-top: 0em; + margin-left: 0em; + margin-bottom: 0em; +} + +dl { + margin-bottom: 15px; +} + +dd > :first-child { + margin-top: 0px; +} + +dd ul, dd table { + margin-bottom: 10px; +} + +dd { + margin-top: 3px; + margin-bottom: 10px; + margin-left: 30px; +} + +.sig dd { + margin-top: 0px; + margin-bottom: 0px; +} + +.sig dl { + margin-top: 0px; + margin-bottom: 0px; +} + +dl > dd:last-child, +dl > dd:last-child > :last-child { + margin-bottom: 0; +} + +dt:target, span.highlighted { + background-color: #fbe54e; +} + +rect.highlighted { + fill: #fbe54e; +} + +dl.glossary dt { + font-weight: bold; + font-size: 1.1em; +} + +.versionmodified { + font-style: italic; +} + +.system-message { + background-color: #fda; + padding: 5px; + border: 3px solid red; +} + +.footnote:target { + background-color: #ffa; +} + +.line-block { + display: block; + margin-top: 1em; + margin-bottom: 1em; +} + +.line-block .line-block { + margin-top: 0; + margin-bottom: 0; + margin-left: 1.5em; +} + +.guilabel, .menuselection { + font-family: sans-serif; +} + +.accelerator { + text-decoration: underline; +} + +.classifier { + font-style: oblique; +} + +.classifier:before { + font-style: normal; + margin: 0 0.5em; + content: ":"; + display: inline-block; +} + +abbr, acronym { + border-bottom: dotted 1px; + cursor: help; +} + +.translated { + background-color: rgba(207, 255, 207, 0.2) +} + +.untranslated { + background-color: rgba(255, 207, 207, 0.2) +} + +/* -- code displays --------------------------------------------------------- */ + +pre { + overflow: auto; + overflow-y: hidden; /* fixes display issues on Chrome browsers */ +} + +pre, div[class*="highlight-"] { + clear: both; +} + +span.pre { + -moz-hyphens: none; + -ms-hyphens: none; + -webkit-hyphens: none; + hyphens: none; + white-space: nowrap; +} + +div[class*="highlight-"] { + margin: 1em 0; +} + +td.linenos pre { + border: 0; + background-color: transparent; + color: #aaa; +} + +table.highlighttable { + display: block; +} + 
+table.highlighttable tbody { + display: block; +} + +table.highlighttable tr { + display: flex; +} + +table.highlighttable td { + margin: 0; + padding: 0; +} + +table.highlighttable td.linenos { + padding-right: 0.5em; +} + +table.highlighttable td.code { + flex: 1; + overflow: hidden; +} + +.highlight .hll { + display: block; +} + +div.highlight pre, +table.highlighttable pre { + margin: 0; +} + +div.code-block-caption + div { + margin-top: 0; +} + +div.code-block-caption { + margin-top: 1em; + padding: 2px 5px; + font-size: small; +} + +div.code-block-caption code { + background-color: transparent; +} + +table.highlighttable td.linenos, +span.linenos, +div.highlight span.gp { /* gp: Generic.Prompt */ + user-select: none; + -webkit-user-select: text; /* Safari fallback only */ + -webkit-user-select: none; /* Chrome/Safari */ + -moz-user-select: none; /* Firefox */ + -ms-user-select: none; /* IE10+ */ +} + +div.code-block-caption span.caption-number { + padding: 0.1em 0.3em; + font-style: italic; +} + +div.code-block-caption span.caption-text { +} + +div.literal-block-wrapper { + margin: 1em 0; +} + +code.xref, a code { + background-color: transparent; + font-weight: bold; +} + +h1 code, h2 code, h3 code, h4 code, h5 code, h6 code { + background-color: transparent; +} + +.viewcode-link { + float: right; +} + +.viewcode-back { + float: right; + font-family: sans-serif; +} + +div.viewcode-block:target { + margin: -1px -10px; + padding: 0 10px; +} + +/* -- math display ---------------------------------------------------------- */ + +img.math { + vertical-align: middle; +} + +div.body div.math p { + text-align: center; +} + +span.eqno { + float: right; +} + +span.eqno a.headerlink { + position: absolute; + z-index: 1; +} + +div.math:hover a.headerlink { + visibility: visible; +} + +/* -- printout stylesheet --------------------------------------------------- */ + +@media print { + div.document, + div.documentwrapper, + div.bodywrapper { + margin: 0 !important; + width: 100%; + } + + div.sphinxsidebar, + div.related, + div.footer, + #top-link { + display: none; + } +} \ No newline at end of file diff --git a/docs/html/_static/custom.css b/docs/html/_static/custom.css new file mode 100644 index 0000000..2a924f1 --- /dev/null +++ b/docs/html/_static/custom.css @@ -0,0 +1 @@ +/* This file intentionally left blank. */ diff --git a/docs/html/_static/doctools.js b/docs/html/_static/doctools.js new file mode 100644 index 0000000..d06a71d --- /dev/null +++ b/docs/html/_static/doctools.js @@ -0,0 +1,156 @@ +/* + * doctools.js + * ~~~~~~~~~~~ + * + * Base JavaScript utilities for all Sphinx HTML documentation. + * + * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS. + * :license: BSD, see LICENSE for details. + * + */ +"use strict"; + +const BLACKLISTED_KEY_CONTROL_ELEMENTS = new Set([ + "TEXTAREA", + "INPUT", + "SELECT", + "BUTTON", +]); + +const _ready = (callback) => { + if (document.readyState !== "loading") { + callback(); + } else { + document.addEventListener("DOMContentLoaded", callback); + } +}; + +/** + * Small JavaScript module for the documentation. + */ +const Documentation = { + init: () => { + Documentation.initDomainIndexTable(); + Documentation.initOnKeyListeners(); + }, + + /** + * i18n support + */ + TRANSLATIONS: {}, + PLURAL_EXPR: (n) => (n === 1 ? 
0 : 1), + LOCALE: "unknown", + + // gettext and ngettext don't access this so that the functions + // can safely bound to a different name (_ = Documentation.gettext) + gettext: (string) => { + const translated = Documentation.TRANSLATIONS[string]; + switch (typeof translated) { + case "undefined": + return string; // no translation + case "string": + return translated; // translation exists + default: + return translated[0]; // (singular, plural) translation tuple exists + } + }, + + ngettext: (singular, plural, n) => { + const translated = Documentation.TRANSLATIONS[singular]; + if (typeof translated !== "undefined") + return translated[Documentation.PLURAL_EXPR(n)]; + return n === 1 ? singular : plural; + }, + + addTranslations: (catalog) => { + Object.assign(Documentation.TRANSLATIONS, catalog.messages); + Documentation.PLURAL_EXPR = new Function( + "n", + `return (${catalog.plural_expr})` + ); + Documentation.LOCALE = catalog.locale; + }, + + /** + * helper function to focus on search bar + */ + focusSearchBar: () => { + document.querySelectorAll("input[name=q]")[0]?.focus(); + }, + + /** + * Initialise the domain index toggle buttons + */ + initDomainIndexTable: () => { + const toggler = (el) => { + const idNumber = el.id.substr(7); + const toggledRows = document.querySelectorAll(`tr.cg-${idNumber}`); + if (el.src.substr(-9) === "minus.png") { + el.src = `${el.src.substr(0, el.src.length - 9)}plus.png`; + toggledRows.forEach((el) => (el.style.display = "none")); + } else { + el.src = `${el.src.substr(0, el.src.length - 8)}minus.png`; + toggledRows.forEach((el) => (el.style.display = "")); + } + }; + + const togglerElements = document.querySelectorAll("img.toggler"); + togglerElements.forEach((el) => + el.addEventListener("click", (event) => toggler(event.currentTarget)) + ); + togglerElements.forEach((el) => (el.style.display = "")); + if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) togglerElements.forEach(toggler); + }, + + initOnKeyListeners: () => { + // only install a listener if it is really needed + if ( + !DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS && + !DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS + ) + return; + + document.addEventListener("keydown", (event) => { + // bail for input elements + if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return; + // bail with special keys + if (event.altKey || event.ctrlKey || event.metaKey) return; + + if (!event.shiftKey) { + switch (event.key) { + case "ArrowLeft": + if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break; + + const prevLink = document.querySelector('link[rel="prev"]'); + if (prevLink && prevLink.href) { + window.location.href = prevLink.href; + event.preventDefault(); + } + break; + case "ArrowRight": + if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break; + + const nextLink = document.querySelector('link[rel="next"]'); + if (nextLink && nextLink.href) { + window.location.href = nextLink.href; + event.preventDefault(); + } + break; + } + } + + // some keyboard layouts may need Shift to get / + switch (event.key) { + case "/": + if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) break; + Documentation.focusSearchBar(); + event.preventDefault(); + } + }); + }, +}; + +// quick alias for translations +const _ = Documentation.gettext; + +_ready(Documentation.init); diff --git a/docs/html/_static/documentation_options.js b/docs/html/_static/documentation_options.js new file mode 100644 index 0000000..7e4c114 --- /dev/null +++ b/docs/html/_static/documentation_options.js @@ -0,0 +1,13 @@ +const 
DOCUMENTATION_OPTIONS = { + VERSION: '', + LANGUAGE: 'en', + COLLAPSE_INDEX: false, + BUILDER: 'html', + FILE_SUFFIX: '.html', + LINK_SUFFIX: '.html', + HAS_SOURCE: true, + SOURCELINK_SUFFIX: '.txt', + NAVIGATION_WITH_KEYS: false, + SHOW_SEARCH_SUMMARY: true, + ENABLE_SEARCH_SHORTCUTS: true, +}; \ No newline at end of file diff --git a/docs/html/_static/file.png b/docs/html/_static/file.png new file mode 100644 index 0000000..a858a41 Binary files /dev/null and b/docs/html/_static/file.png differ diff --git a/docs/html/_static/language_data.js b/docs/html/_static/language_data.js new file mode 100644 index 0000000..250f566 --- /dev/null +++ b/docs/html/_static/language_data.js @@ -0,0 +1,199 @@ +/* + * language_data.js + * ~~~~~~~~~~~~~~~~ + * + * This script contains the language-specific data used by searchtools.js, + * namely the list of stopwords, stemmer, scorer and splitter. + * + * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS. + * :license: BSD, see LICENSE for details. + * + */ + +var stopwords = ["a", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in", "into", "is", "it", "near", "no", "not", "of", "on", "or", "such", "that", "the", "their", "then", "there", "these", "they", "this", "to", "was", "will", "with"]; + + +/* Non-minified version is copied as a separate JS file, is available */ + +/** + * Porter Stemmer + */ +var Stemmer = function() { + + var step2list = { + ational: 'ate', + tional: 'tion', + enci: 'ence', + anci: 'ance', + izer: 'ize', + bli: 'ble', + alli: 'al', + entli: 'ent', + eli: 'e', + ousli: 'ous', + ization: 'ize', + ation: 'ate', + ator: 'ate', + alism: 'al', + iveness: 'ive', + fulness: 'ful', + ousness: 'ous', + aliti: 'al', + iviti: 'ive', + biliti: 'ble', + logi: 'log' + }; + + var step3list = { + icate: 'ic', + ative: '', + alize: 'al', + iciti: 'ic', + ical: 'ic', + ful: '', + ness: '' + }; + + var c = "[^aeiou]"; // consonant + var v = "[aeiouy]"; // vowel + var C = c + "[^aeiouy]*"; // consonant sequence + var V = v + "[aeiou]*"; // vowel sequence + + var mgr0 = "^(" + C + ")?" + V + C; // [C]VC... is m>0 + var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1 + var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1 + var s_v = "^(" + C + ")?" 
+ v; // vowel in stem + + this.stemWord = function (w) { + var stem; + var suffix; + var firstch; + var origword = w; + + if (w.length < 3) + return w; + + var re; + var re2; + var re3; + var re4; + + firstch = w.substr(0,1); + if (firstch == "y") + w = firstch.toUpperCase() + w.substr(1); + + // Step 1a + re = /^(.+?)(ss|i)es$/; + re2 = /^(.+?)([^s])s$/; + + if (re.test(w)) + w = w.replace(re,"$1$2"); + else if (re2.test(w)) + w = w.replace(re2,"$1$2"); + + // Step 1b + re = /^(.+?)eed$/; + re2 = /^(.+?)(ed|ing)$/; + if (re.test(w)) { + var fp = re.exec(w); + re = new RegExp(mgr0); + if (re.test(fp[1])) { + re = /.$/; + w = w.replace(re,""); + } + } + else if (re2.test(w)) { + var fp = re2.exec(w); + stem = fp[1]; + re2 = new RegExp(s_v); + if (re2.test(stem)) { + w = stem; + re2 = /(at|bl|iz)$/; + re3 = new RegExp("([^aeiouylsz])\\1$"); + re4 = new RegExp("^" + C + v + "[^aeiouwxy]$"); + if (re2.test(w)) + w = w + "e"; + else if (re3.test(w)) { + re = /.$/; + w = w.replace(re,""); + } + else if (re4.test(w)) + w = w + "e"; + } + } + + // Step 1c + re = /^(.+?)y$/; + if (re.test(w)) { + var fp = re.exec(w); + stem = fp[1]; + re = new RegExp(s_v); + if (re.test(stem)) + w = stem + "i"; + } + + // Step 2 + re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/; + if (re.test(w)) { + var fp = re.exec(w); + stem = fp[1]; + suffix = fp[2]; + re = new RegExp(mgr0); + if (re.test(stem)) + w = stem + step2list[suffix]; + } + + // Step 3 + re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/; + if (re.test(w)) { + var fp = re.exec(w); + stem = fp[1]; + suffix = fp[2]; + re = new RegExp(mgr0); + if (re.test(stem)) + w = stem + step3list[suffix]; + } + + // Step 4 + re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/; + re2 = /^(.+?)(s|t)(ion)$/; + if (re.test(w)) { + var fp = re.exec(w); + stem = fp[1]; + re = new RegExp(mgr1); + if (re.test(stem)) + w = stem; + } + else if (re2.test(w)) { + var fp = re2.exec(w); + stem = fp[1] + fp[2]; + re2 = new RegExp(mgr1); + if (re2.test(stem)) + w = stem; + } + + // Step 5 + re = /^(.+?)e$/; + if (re.test(w)) { + var fp = re.exec(w); + stem = fp[1]; + re = new RegExp(mgr1); + re2 = new RegExp(meq1); + re3 = new RegExp("^" + C + v + "[^aeiouwxy]$"); + if (re.test(stem) || (re2.test(stem) && !(re3.test(stem)))) + w = stem; + } + re = /ll$/; + re2 = new RegExp(mgr1); + if (re.test(w) && re2.test(w)) { + re = /.$/; + w = w.replace(re,""); + } + + // and turn initial Y back to y + if (firstch == "y") + w = firstch.toLowerCase() + w.substr(1); + return w; + } +} + diff --git a/docs/html/_static/minus.png b/docs/html/_static/minus.png new file mode 100644 index 0000000..d96755f Binary files /dev/null and b/docs/html/_static/minus.png differ diff --git a/docs/html/_static/plus.png b/docs/html/_static/plus.png new file mode 100644 index 0000000..7107cec Binary files /dev/null and b/docs/html/_static/plus.png differ diff --git a/docs/html/_static/pygments.css b/docs/html/_static/pygments.css new file mode 100644 index 0000000..57c7df3 --- /dev/null +++ b/docs/html/_static/pygments.css @@ -0,0 +1,84 @@ +pre { line-height: 125%; } +td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } +span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } +td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; 
} +span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } +.highlight .hll { background-color: #ffffcc } +.highlight { background: #f8f8f8; } +.highlight .c { color: #8f5902; font-style: italic } /* Comment */ +.highlight .err { color: #a40000; border: 1px solid #ef2929 } /* Error */ +.highlight .g { color: #000000 } /* Generic */ +.highlight .k { color: #004461; font-weight: bold } /* Keyword */ +.highlight .l { color: #000000 } /* Literal */ +.highlight .n { color: #000000 } /* Name */ +.highlight .o { color: #582800 } /* Operator */ +.highlight .x { color: #000000 } /* Other */ +.highlight .p { color: #000000; font-weight: bold } /* Punctuation */ +.highlight .ch { color: #8f5902; font-style: italic } /* Comment.Hashbang */ +.highlight .cm { color: #8f5902; font-style: italic } /* Comment.Multiline */ +.highlight .cp { color: #8f5902 } /* Comment.Preproc */ +.highlight .cpf { color: #8f5902; font-style: italic } /* Comment.PreprocFile */ +.highlight .c1 { color: #8f5902; font-style: italic } /* Comment.Single */ +.highlight .cs { color: #8f5902; font-style: italic } /* Comment.Special */ +.highlight .gd { color: #a40000 } /* Generic.Deleted */ +.highlight .ge { color: #000000; font-style: italic } /* Generic.Emph */ +.highlight .ges { color: #000000 } /* Generic.EmphStrong */ +.highlight .gr { color: #ef2929 } /* Generic.Error */ +.highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */ +.highlight .gi { color: #00A000 } /* Generic.Inserted */ +.highlight .go { color: #888888 } /* Generic.Output */ +.highlight .gp { color: #745334 } /* Generic.Prompt */ +.highlight .gs { color: #000000; font-weight: bold } /* Generic.Strong */ +.highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ +.highlight .gt { color: #a40000; font-weight: bold } /* Generic.Traceback */ +.highlight .kc { color: #004461; font-weight: bold } /* Keyword.Constant */ +.highlight .kd { color: #004461; font-weight: bold } /* Keyword.Declaration */ +.highlight .kn { color: #004461; font-weight: bold } /* Keyword.Namespace */ +.highlight .kp { color: #004461; font-weight: bold } /* Keyword.Pseudo */ +.highlight .kr { color: #004461; font-weight: bold } /* Keyword.Reserved */ +.highlight .kt { color: #004461; font-weight: bold } /* Keyword.Type */ +.highlight .ld { color: #000000 } /* Literal.Date */ +.highlight .m { color: #990000 } /* Literal.Number */ +.highlight .s { color: #4e9a06 } /* Literal.String */ +.highlight .na { color: #c4a000 } /* Name.Attribute */ +.highlight .nb { color: #004461 } /* Name.Builtin */ +.highlight .nc { color: #000000 } /* Name.Class */ +.highlight .no { color: #000000 } /* Name.Constant */ +.highlight .nd { color: #888888 } /* Name.Decorator */ +.highlight .ni { color: #ce5c00 } /* Name.Entity */ +.highlight .ne { color: #cc0000; font-weight: bold } /* Name.Exception */ +.highlight .nf { color: #000000 } /* Name.Function */ +.highlight .nl { color: #f57900 } /* Name.Label */ +.highlight .nn { color: #000000 } /* Name.Namespace */ +.highlight .nx { color: #000000 } /* Name.Other */ +.highlight .py { color: #000000 } /* Name.Property */ +.highlight .nt { color: #004461; font-weight: bold } /* Name.Tag */ +.highlight .nv { color: #000000 } /* Name.Variable */ +.highlight .ow { color: #004461; font-weight: bold } /* Operator.Word */ +.highlight .pm { color: #000000; font-weight: bold } /* Punctuation.Marker */ +.highlight .w { color: #f8f8f8; text-decoration: underline } /* Text.Whitespace */ +.highlight 
.mb { color: #990000 } /* Literal.Number.Bin */ +.highlight .mf { color: #990000 } /* Literal.Number.Float */ +.highlight .mh { color: #990000 } /* Literal.Number.Hex */ +.highlight .mi { color: #990000 } /* Literal.Number.Integer */ +.highlight .mo { color: #990000 } /* Literal.Number.Oct */ +.highlight .sa { color: #4e9a06 } /* Literal.String.Affix */ +.highlight .sb { color: #4e9a06 } /* Literal.String.Backtick */ +.highlight .sc { color: #4e9a06 } /* Literal.String.Char */ +.highlight .dl { color: #4e9a06 } /* Literal.String.Delimiter */ +.highlight .sd { color: #8f5902; font-style: italic } /* Literal.String.Doc */ +.highlight .s2 { color: #4e9a06 } /* Literal.String.Double */ +.highlight .se { color: #4e9a06 } /* Literal.String.Escape */ +.highlight .sh { color: #4e9a06 } /* Literal.String.Heredoc */ +.highlight .si { color: #4e9a06 } /* Literal.String.Interpol */ +.highlight .sx { color: #4e9a06 } /* Literal.String.Other */ +.highlight .sr { color: #4e9a06 } /* Literal.String.Regex */ +.highlight .s1 { color: #4e9a06 } /* Literal.String.Single */ +.highlight .ss { color: #4e9a06 } /* Literal.String.Symbol */ +.highlight .bp { color: #3465a4 } /* Name.Builtin.Pseudo */ +.highlight .fm { color: #000000 } /* Name.Function.Magic */ +.highlight .vc { color: #000000 } /* Name.Variable.Class */ +.highlight .vg { color: #000000 } /* Name.Variable.Global */ +.highlight .vi { color: #000000 } /* Name.Variable.Instance */ +.highlight .vm { color: #000000 } /* Name.Variable.Magic */ +.highlight .il { color: #990000 } /* Literal.Number.Integer.Long */ \ No newline at end of file diff --git a/docs/html/_static/searchtools.js b/docs/html/_static/searchtools.js new file mode 100644 index 0000000..7918c3f --- /dev/null +++ b/docs/html/_static/searchtools.js @@ -0,0 +1,574 @@ +/* + * searchtools.js + * ~~~~~~~~~~~~~~~~ + * + * Sphinx JavaScript utilities for the full-text search. + * + * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS. + * :license: BSD, see LICENSE for details. + * + */ +"use strict"; + +/** + * Simple result scoring code. + */ +if (typeof Scorer === "undefined") { + var Scorer = { + // Implement the following function to further tweak the score for each result + // The function takes a result array [docname, title, anchor, descr, score, filename] + // and returns the new score. + /* + score: result => { + const [docname, title, anchor, descr, score, filename] = result + return score + }, + */ + + // query matches the full name of an object + objNameMatch: 11, + // or matches in the last dotted part of the object name + objPartialMatch: 6, + // Additive scores depending on the priority of the object + objPrio: { + 0: 15, // used to be importantResults + 1: 5, // used to be objectResults + 2: -5, // used to be unimportantResults + }, + // Used when the priority is not in the mapping. 
+ objPrioDefault: 0, + + // query found in title + title: 15, + partialTitle: 7, + // query found in terms + term: 5, + partialTerm: 2, + }; +} + +const _removeChildren = (element) => { + while (element && element.lastChild) element.removeChild(element.lastChild); +}; + +/** + * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Regular_Expressions#escaping + */ +const _escapeRegExp = (string) => + string.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&"); // $& means the whole matched string + +const _displayItem = (item, searchTerms, highlightTerms) => { + const docBuilder = DOCUMENTATION_OPTIONS.BUILDER; + const docFileSuffix = DOCUMENTATION_OPTIONS.FILE_SUFFIX; + const docLinkSuffix = DOCUMENTATION_OPTIONS.LINK_SUFFIX; + const showSearchSummary = DOCUMENTATION_OPTIONS.SHOW_SEARCH_SUMMARY; + const contentRoot = document.documentElement.dataset.content_root; + + const [docName, title, anchor, descr, score, _filename] = item; + + let listItem = document.createElement("li"); + let requestUrl; + let linkUrl; + if (docBuilder === "dirhtml") { + // dirhtml builder + let dirname = docName + "/"; + if (dirname.match(/\/index\/$/)) + dirname = dirname.substring(0, dirname.length - 6); + else if (dirname === "index/") dirname = ""; + requestUrl = contentRoot + dirname; + linkUrl = requestUrl; + } else { + // normal html builders + requestUrl = contentRoot + docName + docFileSuffix; + linkUrl = docName + docLinkSuffix; + } + let linkEl = listItem.appendChild(document.createElement("a")); + linkEl.href = linkUrl + anchor; + linkEl.dataset.score = score; + linkEl.innerHTML = title; + if (descr) { + listItem.appendChild(document.createElement("span")).innerHTML = + " (" + descr + ")"; + // highlight search terms in the description + if (SPHINX_HIGHLIGHT_ENABLED) // set in sphinx_highlight.js + highlightTerms.forEach((term) => _highlightText(listItem, term, "highlighted")); + } + else if (showSearchSummary) + fetch(requestUrl) + .then((responseData) => responseData.text()) + .then((data) => { + if (data) + listItem.appendChild( + Search.makeSearchSummary(data, searchTerms) + ); + // highlight search terms in the summary + if (SPHINX_HIGHLIGHT_ENABLED) // set in sphinx_highlight.js + highlightTerms.forEach((term) => _highlightText(listItem, term, "highlighted")); + }); + Search.output.appendChild(listItem); +}; +const _finishSearch = (resultCount) => { + Search.stopPulse(); + Search.title.innerText = _("Search Results"); + if (!resultCount) + Search.status.innerText = Documentation.gettext( + "Your search did not match any documents. Please make sure that all words are spelled correctly and that you've selected enough categories." + ); + else + Search.status.innerText = _( + `Search finished, found ${resultCount} page(s) matching the search query.` + ); +}; +const _displayNextItem = ( + results, + resultCount, + searchTerms, + highlightTerms, +) => { + // results left, load the summary and display it + // this is intended to be dynamic (don't sub resultsCount) + if (results.length) { + _displayItem(results.pop(), searchTerms, highlightTerms); + setTimeout( + () => _displayNextItem(results, resultCount, searchTerms, highlightTerms), + 5 + ); + } + // search finished, update title and status message + else _finishSearch(resultCount); +}; + +/** + * Default splitQuery function. Can be overridden in ``sphinx.search`` with a + * custom function per language. 
+ * + * The regular expression works by splitting the string on consecutive characters + * that are not Unicode letters, numbers, underscores, or emoji characters. + * This is the same as ``\W+`` in Python, preserving the surrogate pair area. + */ +if (typeof splitQuery === "undefined") { + var splitQuery = (query) => query + .split(/[^\p{Letter}\p{Number}_\p{Emoji_Presentation}]+/gu) + .filter(term => term) // remove remaining empty strings +} + +/** + * Search Module + */ +const Search = { + _index: null, + _queued_query: null, + _pulse_status: -1, + + htmlToText: (htmlString) => { + const htmlElement = new DOMParser().parseFromString(htmlString, 'text/html'); + htmlElement.querySelectorAll(".headerlink").forEach((el) => { el.remove() }); + const docContent = htmlElement.querySelector('[role="main"]'); + if (docContent !== undefined) return docContent.textContent; + console.warn( + "Content block not found. Sphinx search tries to obtain it via '[role=main]'. Could you check your theme or template." + ); + return ""; + }, + + init: () => { + const query = new URLSearchParams(window.location.search).get("q"); + document + .querySelectorAll('input[name="q"]') + .forEach((el) => (el.value = query)); + if (query) Search.performSearch(query); + }, + + loadIndex: (url) => + (document.body.appendChild(document.createElement("script")).src = url), + + setIndex: (index) => { + Search._index = index; + if (Search._queued_query !== null) { + const query = Search._queued_query; + Search._queued_query = null; + Search.query(query); + } + }, + + hasIndex: () => Search._index !== null, + + deferQuery: (query) => (Search._queued_query = query), + + stopPulse: () => (Search._pulse_status = -1), + + startPulse: () => { + if (Search._pulse_status >= 0) return; + + const pulse = () => { + Search._pulse_status = (Search._pulse_status + 1) % 4; + Search.dots.innerText = ".".repeat(Search._pulse_status); + if (Search._pulse_status >= 0) window.setTimeout(pulse, 500); + }; + pulse(); + }, + + /** + * perform a search for something (or wait until index is loaded) + */ + performSearch: (query) => { + // create the required interface elements + const searchText = document.createElement("h2"); + searchText.textContent = _("Searching"); + const searchSummary = document.createElement("p"); + searchSummary.classList.add("search-summary"); + searchSummary.innerText = ""; + const searchList = document.createElement("ul"); + searchList.classList.add("search"); + + const out = document.getElementById("search-results"); + Search.title = out.appendChild(searchText); + Search.dots = Search.title.appendChild(document.createElement("span")); + Search.status = out.appendChild(searchSummary); + Search.output = out.appendChild(searchList); + + const searchProgress = document.getElementById("search-progress"); + // Some themes don't use the search progress node + if (searchProgress) { + searchProgress.innerText = _("Preparing search..."); + } + Search.startPulse(); + + // index already loaded, the browser was quick! 
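+    // otherwise the query is stashed by deferQuery() and executed once
+    // setIndex() receives the loaded search index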
+ if (Search.hasIndex()) Search.query(query); + else Search.deferQuery(query); + }, + + /** + * execute search (requires search index to be loaded) + */ + query: (query) => { + const filenames = Search._index.filenames; + const docNames = Search._index.docnames; + const titles = Search._index.titles; + const allTitles = Search._index.alltitles; + const indexEntries = Search._index.indexentries; + + // stem the search terms and add them to the correct list + const stemmer = new Stemmer(); + const searchTerms = new Set(); + const excludedTerms = new Set(); + const highlightTerms = new Set(); + const objectTerms = new Set(splitQuery(query.toLowerCase().trim())); + splitQuery(query.trim()).forEach((queryTerm) => { + const queryTermLower = queryTerm.toLowerCase(); + + // maybe skip this "word" + // stopwords array is from language_data.js + if ( + stopwords.indexOf(queryTermLower) !== -1 || + queryTerm.match(/^\d+$/) + ) + return; + + // stem the word + let word = stemmer.stemWord(queryTermLower); + // select the correct list + if (word[0] === "-") excludedTerms.add(word.substr(1)); + else { + searchTerms.add(word); + highlightTerms.add(queryTermLower); + } + }); + + if (SPHINX_HIGHLIGHT_ENABLED) { // set in sphinx_highlight.js + localStorage.setItem("sphinx_highlight_terms", [...highlightTerms].join(" ")) + } + + // console.debug("SEARCH: searching for:"); + // console.info("required: ", [...searchTerms]); + // console.info("excluded: ", [...excludedTerms]); + + // array of [docname, title, anchor, descr, score, filename] + let results = []; + _removeChildren(document.getElementById("search-progress")); + + const queryLower = query.toLowerCase(); + for (const [title, foundTitles] of Object.entries(allTitles)) { + if (title.toLowerCase().includes(queryLower) && (queryLower.length >= title.length/2)) { + for (const [file, id] of foundTitles) { + let score = Math.round(100 * queryLower.length / title.length) + results.push([ + docNames[file], + titles[file] !== title ? `${titles[file]} > ${title}` : title, + id !== null ? "#" + id : "", + null, + score, + filenames[file], + ]); + } + } + } + + // search for explicit entries in index directives + for (const [entry, foundEntries] of Object.entries(indexEntries)) { + if (entry.includes(queryLower) && (queryLower.length >= entry.length/2)) { + for (const [file, id] of foundEntries) { + let score = Math.round(100 * queryLower.length / entry.length) + results.push([ + docNames[file], + titles[file], + id ? "#" + id : "", + null, + score, + filenames[file], + ]); + } + } + } + + // lookup as object + objectTerms.forEach((term) => + results.push(...Search.performObjectSearch(term, objectTerms)) + ); + + // lookup as search terms in fulltext + results.push(...Search.performTermsSearch(searchTerms, excludedTerms)); + + // let the scorer override scores with a custom scoring function + if (Scorer.score) results.forEach((item) => (item[4] = Scorer.score(item))); + + // now sort the results by score (in opposite order of appearance, since the + // display function below uses pop() to retrieve items) and then + // alphabetically + results.sort((a, b) => { + const leftScore = a[4]; + const rightScore = b[4]; + if (leftScore === rightScore) { + // same score: sort alphabetically + const leftTitle = a[1].toLowerCase(); + const rightTitle = b[1].toLowerCase(); + if (leftTitle === rightTitle) return 0; + return leftTitle > rightTitle ? -1 : 1; // inverted is intentional + } + return leftScore > rightScore ? 
1 : -1; + }); + + // remove duplicate search results + // note the reversing of results, so that in the case of duplicates, the highest-scoring entry is kept + let seen = new Set(); + results = results.reverse().reduce((acc, result) => { + let resultStr = result.slice(0, 4).concat([result[5]]).map(v => String(v)).join(','); + if (!seen.has(resultStr)) { + acc.push(result); + seen.add(resultStr); + } + return acc; + }, []); + + results = results.reverse(); + + // for debugging + //Search.lastresults = results.slice(); // a copy + // console.info("search results:", Search.lastresults); + + // print the results + _displayNextItem(results, results.length, searchTerms, highlightTerms); + }, + + /** + * search for object names + */ + performObjectSearch: (object, objectTerms) => { + const filenames = Search._index.filenames; + const docNames = Search._index.docnames; + const objects = Search._index.objects; + const objNames = Search._index.objnames; + const titles = Search._index.titles; + + const results = []; + + const objectSearchCallback = (prefix, match) => { + const name = match[4] + const fullname = (prefix ? prefix + "." : "") + name; + const fullnameLower = fullname.toLowerCase(); + if (fullnameLower.indexOf(object) < 0) return; + + let score = 0; + const parts = fullnameLower.split("."); + + // check for different match types: exact matches of full name or + // "last name" (i.e. last dotted part) + if (fullnameLower === object || parts.slice(-1)[0] === object) + score += Scorer.objNameMatch; + else if (parts.slice(-1)[0].indexOf(object) > -1) + score += Scorer.objPartialMatch; // matches in last name + + const objName = objNames[match[1]][2]; + const title = titles[match[0]]; + + // If more than one term searched for, we require other words to be + // found in the name/title/description + const otherTerms = new Set(objectTerms); + otherTerms.delete(object); + if (otherTerms.size > 0) { + const haystack = `${prefix} ${name} ${objName} ${title}`.toLowerCase(); + if ( + [...otherTerms].some((otherTerm) => haystack.indexOf(otherTerm) < 0) + ) + return; + } + + let anchor = match[3]; + if (anchor === "") anchor = fullname; + else if (anchor === "-") anchor = objNames[match[1]][1] + "-" + fullname; + + const descr = objName + _(", in ") + title; + + // add custom score for some objects according to scorer + if (Scorer.objPrio.hasOwnProperty(match[2])) + score += Scorer.objPrio[match[2]]; + else score += Scorer.objPrioDefault; + + results.push([ + docNames[match[0]], + fullname, + "#" + anchor, + descr, + score, + filenames[match[0]], + ]); + }; + Object.keys(objects).forEach((prefix) => + objects[prefix].forEach((array) => + objectSearchCallback(prefix, array) + ) + ); + return results; + }, + + /** + * search for full-text terms in the index + */ + performTermsSearch: (searchTerms, excludedTerms) => { + // prepare search + const terms = Search._index.terms; + const titleTerms = Search._index.titleterms; + const filenames = Search._index.filenames; + const docNames = Search._index.docnames; + const titles = Search._index.titles; + + const scoreMap = new Map(); + const fileMap = new Map(); + + // perform the search on the required terms + searchTerms.forEach((word) => { + const files = []; + const arr = [ + { files: terms[word], score: Scorer.term }, + { files: titleTerms[word], score: Scorer.title }, + ]; + // add support for partial matches + if (word.length > 2) { + const escapedWord = _escapeRegExp(word); + Object.keys(terms).forEach((term) => { + if (term.match(escapedWord) && 
!terms[word]) + arr.push({ files: terms[term], score: Scorer.partialTerm }); + }); + Object.keys(titleTerms).forEach((term) => { + if (term.match(escapedWord) && !titleTerms[word]) + arr.push({ files: titleTerms[word], score: Scorer.partialTitle }); + }); + } + + // no match but word was a required one + if (arr.every((record) => record.files === undefined)) return; + + // found search word in contents + arr.forEach((record) => { + if (record.files === undefined) return; + + let recordFiles = record.files; + if (recordFiles.length === undefined) recordFiles = [recordFiles]; + files.push(...recordFiles); + + // set score for the word in each file + recordFiles.forEach((file) => { + if (!scoreMap.has(file)) scoreMap.set(file, {}); + scoreMap.get(file)[word] = record.score; + }); + }); + + // create the mapping + files.forEach((file) => { + if (fileMap.has(file) && fileMap.get(file).indexOf(word) === -1) + fileMap.get(file).push(word); + else fileMap.set(file, [word]); + }); + }); + + // now check if the files don't contain excluded terms + const results = []; + for (const [file, wordList] of fileMap) { + // check if all requirements are matched + + // as search terms with length < 3 are discarded + const filteredTermCount = [...searchTerms].filter( + (term) => term.length > 2 + ).length; + if ( + wordList.length !== searchTerms.size && + wordList.length !== filteredTermCount + ) + continue; + + // ensure that none of the excluded terms is in the search result + if ( + [...excludedTerms].some( + (term) => + terms[term] === file || + titleTerms[term] === file || + (terms[term] || []).includes(file) || + (titleTerms[term] || []).includes(file) + ) + ) + break; + + // select one (max) score for the file. + const score = Math.max(...wordList.map((w) => scoreMap.get(file)[w])); + // add result to the result list + results.push([ + docNames[file], + titles[file], + "", + null, + score, + filenames[file], + ]); + } + return results; + }, + + /** + * helper function to return a node containing the + * search summary for a given text. keywords is a list + * of stemmed words. + */ + makeSearchSummary: (htmlText, keywords) => { + const text = Search.htmlToText(htmlText); + if (text === "") return null; + + const textLower = text.toLowerCase(); + const actualStartPosition = [...keywords] + .map((k) => textLower.indexOf(k.toLowerCase())) + .filter((i) => i > -1) + .slice(-1)[0]; + const startWithContext = Math.max(actualStartPosition - 120, 0); + + const top = startWithContext === 0 ? "" : "..."; + const tail = startWithContext + 240 < text.length ? "..." : ""; + + let summary = document.createElement("p"); + summary.classList.add("context"); + summary.textContent = top + text.substr(startWithContext, 240).trim() + tail; + + return summary; + }, +}; + +_ready(Search.init); diff --git a/docs/html/_static/sphinx_highlight.js b/docs/html/_static/sphinx_highlight.js new file mode 100644 index 0000000..8a96c69 --- /dev/null +++ b/docs/html/_static/sphinx_highlight.js @@ -0,0 +1,154 @@ +/* Highlighting utilities for Sphinx HTML documentation. */ +"use strict"; + +const SPHINX_HIGHLIGHT_ENABLED = true + +/** + * highlight a given string on a node by wrapping it in + * span elements with the given class name. 
+ */ +const _highlight = (node, addItems, text, className) => { + if (node.nodeType === Node.TEXT_NODE) { + const val = node.nodeValue; + const parent = node.parentNode; + const pos = val.toLowerCase().indexOf(text); + if ( + pos >= 0 && + !parent.classList.contains(className) && + !parent.classList.contains("nohighlight") + ) { + let span; + + const closestNode = parent.closest("body, svg, foreignObject"); + const isInSVG = closestNode && closestNode.matches("svg"); + if (isInSVG) { + span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); + } else { + span = document.createElement("span"); + span.classList.add(className); + } + + span.appendChild(document.createTextNode(val.substr(pos, text.length))); + const rest = document.createTextNode(val.substr(pos + text.length)); + parent.insertBefore( + span, + parent.insertBefore( + rest, + node.nextSibling + ) + ); + node.nodeValue = val.substr(0, pos); + /* There may be more occurrences of search term in this node. So call this + * function recursively on the remaining fragment. + */ + _highlight(rest, addItems, text, className); + + if (isInSVG) { + const rect = document.createElementNS( + "http://www.w3.org/2000/svg", + "rect" + ); + const bbox = parent.getBBox(); + rect.x.baseVal.value = bbox.x; + rect.y.baseVal.value = bbox.y; + rect.width.baseVal.value = bbox.width; + rect.height.baseVal.value = bbox.height; + rect.setAttribute("class", className); + addItems.push({ parent: parent, target: rect }); + } + } + } else if (node.matches && !node.matches("button, select, textarea")) { + node.childNodes.forEach((el) => _highlight(el, addItems, text, className)); + } +}; +const _highlightText = (thisNode, text, className) => { + let addItems = []; + _highlight(thisNode, addItems, text, className); + addItems.forEach((obj) => + obj.parent.insertAdjacentElement("beforebegin", obj.target) + ); +}; + +/** + * Small JavaScript module for the documentation. + */ +const SphinxHighlight = { + + /** + * highlight the search words provided in localstorage in the text + */ + highlightSearchWords: () => { + if (!SPHINX_HIGHLIGHT_ENABLED) return; // bail if no highlight + + // get and clear terms from localstorage + const url = new URL(window.location); + const highlight = + localStorage.getItem("sphinx_highlight_terms") + || url.searchParams.get("highlight") + || ""; + localStorage.removeItem("sphinx_highlight_terms") + url.searchParams.delete("highlight"); + window.history.replaceState({}, "", url); + + // get individual terms from highlight string + const terms = highlight.toLowerCase().split(/\s+/).filter(x => x); + if (terms.length === 0) return; // nothing to do + + // There should never be more than one element matching "div.body" + const divBody = document.querySelectorAll("div.body"); + const body = divBody.length ? 
divBody[0] : document.querySelector("body"); + window.setTimeout(() => { + terms.forEach((term) => _highlightText(body, term, "highlighted")); + }, 10); + + const searchBox = document.getElementById("searchbox"); + if (searchBox === null) return; + searchBox.appendChild( + document + .createRange() + .createContextualFragment( + '" + ) + ); + }, + + /** + * helper function to hide the search marks again + */ + hideSearchWords: () => { + document + .querySelectorAll("#searchbox .highlight-link") + .forEach((el) => el.remove()); + document + .querySelectorAll("span.highlighted") + .forEach((el) => el.classList.remove("highlighted")); + localStorage.removeItem("sphinx_highlight_terms") + }, + + initEscapeListener: () => { + // only install a listener if it is really needed + if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) return; + + document.addEventListener("keydown", (event) => { + // bail for input elements + if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return; + // bail with special keys + if (event.shiftKey || event.altKey || event.ctrlKey || event.metaKey) return; + if (DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS && (event.key === "Escape")) { + SphinxHighlight.hideSearchWords(); + event.preventDefault(); + } + }); + }, +}; + +_ready(() => { + /* Do not call highlightSearchWords() when we are on the search page. + * It will highlight words from the *previous* search query. + */ + if (typeof Search === "undefined") SphinxHighlight.highlightSearchWords(); + SphinxHighlight.initEscapeListener(); +}); diff --git a/docs/html/genindex.html b/docs/html/genindex.html new file mode 100644 index 0000000..7937ca6 --- /dev/null +++ b/docs/html/genindex.html @@ -0,0 +1,281 @@ + + + + + + + Index — whisper_live documentation + + + + + + + + + + + + + + + + +
+
+
+ + +
+ + +

Index

+ +
+ A + | B + | C + | D + | F + | G + | M + | O + | P + | R + | S + | T + | U + | W + +
+

A

+ + +
+ +

B

+ + +
+ +

C

+ + + +
+ +

D

+ + +
+ +

F

+ + +
+ +

G

+ + + +
+ +

M

+ + +
+ +

O

+ + + +
+ +

P

+ + +
+ +

R

+ + + +
+ +

S

+ + + +
+ +

T

+ + + +
+ +

U

+ + +
+ +

W

+ + + +
    +
  • + whisper_live.client + +
  • +
  • + whisper_live.server + +
  • +
+ + + +
+ +
+
+ +
+
+ + + + + + + \ No newline at end of file diff --git a/docs/html/index.html b/docs/html/index.html new file mode 100644 index 0000000..c8b3e33 --- /dev/null +++ b/docs/html/index.html @@ -0,0 +1,468 @@ + + + + + + + + Welcome to Whisper Live documentation! — whisper_live documentation + + + + + + + + + + + + + + + + +
+
+
+ + +
+ +
+

Welcome to Whisper Live documentation!

+
+
+
+
+class whisper_live.server.ServeClient(websocket, task='transcribe', device=None, multilingual=False, language=None, client_uid=None)
+
+
Attributes:

RATE (int): The audio sampling rate (constant) set to 16000. +SERVER_READY (str): A constant message indicating that the server is ready. +DISCONNECT (str): A constant message indicating that the client should disconnect. +client_uid (str): A unique identifier for the client. +data (bytes): Accumulated audio data. +frames (bytes): Accumulated audio frames. +language (str): The language for transcription. +task (str): The task type, e.g., “transcribe.” +transcriber (WhisperModel): The Whisper model for speech-to-text. +timestamp_offset (float): The offset in audio timestamps. +frames_np (numpy.ndarray): NumPy array to store audio frames. +frames_offset (float): The offset in audio frames. +text (list): List of transcribed text segments. +current_out (str): The current incomplete transcription. +prev_out (str): The previous incomplete transcription. +t_start (float): Timestamp for the start of transcription. +exit (bool): A flag to exit the transcription thread. +same_output_threshold (int): Threshold for consecutive same output segments. +show_prev_out_thresh (int): Threshold for showing previous output segments. +add_pause_thresh (int): Threshold for adding a pause (blank) segment. +transcript (list): List of transcribed segments. +send_last_n_segments (int): Number of last segments to send to the client. +wrapper (textwrap.TextWrapper): Text wrapper for formatting text. +pick_previous_segments (int): Number of previous segments to include in the output. +websocket: The WebSocket connection for the client.

+
+
+
+
+add_frames(frame_np)
+

Add audio frames to the ongoing audio stream buffer.

+

This method is responsible for maintaining the audio stream buffer, allowing the continuous addition +of audio frames as they are received. It also ensures that the buffer does not exceed a specified size +to prevent excessive memory usage.

+

If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds +of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided +audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.

+
+
Args:

frame_np (numpy.ndarray): The audio frame data as a NumPy array.

+
+
+
+ +
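A minimal sketch of the buffer rule described above (45-second cap, oldest 30 seconds discarded), using the documented RATE, frames_np and frames_offset attributes; the real method in whisper_live/server.py may differ in detail:
`python
import numpy as np

RATE = 16000  # documented sampling rate

def add_frames(self, frame_np):
    # Once more than 45 s of audio is buffered, drop the oldest 30 s
    # and advance the frame offset accordingly.
    if self.frames_np is not None and self.frames_np.shape[0] > 45 * RATE:
        self.frames_offset += 30.0
        self.frames_np = self.frames_np[int(30 * RATE):]
    # Initialise the buffer with the first frame, otherwise append.
    if self.frames_np is None:
        self.frames_np = frame_np.copy()
    else:
        self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
`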
+
+cleanup()
+

Perform cleanup tasks before exiting the transcription service.

+

This method performs necessary cleanup tasks, including stopping the transcription thread, marking +the exit flag to indicate the transcription thread should exit gracefully, and destroying resources +associated with the transcription process.

+
+ +
+
+disconnect()
+

Notify the client of disconnection and send a disconnect message.

+

This method sends a disconnect message to the client via the WebSocket connection to notify them +that the transcription service is disconnecting gracefully.

+
+ +
+
+fill_output(output)
+

Format the current incomplete transcription output by combining it with previous complete segments. +The resulting transcription is wrapped into two lines, each containing a maximum of 50 characters.

+

It ensures that the combined transcription fits within two lines, with a maximum of 50 characters per line. +Segments are concatenated in the order they exist in the list of previous segments, with the most +recent complete segment first and older segments prepended as needed to maintain the character limit. +If a 3-second pause is detected in the previous segments, any text preceding it is discarded to ensure +the transcription starts with the most recent complete content. The resulting transcription is returned +as a single string.

+
+
Args:

output (str): The current incomplete transcription segment.

+
+
Returns:

str: A formatted transcription wrapped in two lines.

+
+
+
+ +
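A hedged sketch of the wrapping rule described above (two lines, at most 50 characters each), using the textwrap.TextWrapper named in the Attributes list; pause handling and segment pruning are omitted:
`python
import textwrap

wrapper = textwrap.TextWrapper(width=50)

def fill_output(previous_segments, current_out):
    # Combine older complete segments with the current incomplete one,
    # then keep only the last two wrapped lines (the most recent content).
    text = " ".join(previous_segments + [current_out]).strip()
    return "\n".join(wrapper.wrap(text=text)[-2:])
`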
+
+speech_to_text()
+

Process an audio stream in an infinite loop, continuously transcribing the speech.

+

This method continuously receives audio frames, performs real-time transcription, and sends +transcribed segments to the client via a WebSocket connection.

+

If the client’s language is not detected, it waits for 30 seconds of audio input to make a language prediction. +It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments +are sent to the client in real-time, and a history of segments is maintained to provide context. Pauses in speech +(no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if +there is no speech for a specified duration to indicate a pause.

+
+
Raises:

Exception: If there is an issue with audio processing or WebSocket communication.

+
+
+
+ +
+
+update_segments(segments, duration)
+

Processes the segments from whisper. Appends all the segments to the list +except for the last segment assuming that it is incomplete.

+

Updates the ongoing transcript with transcribed segments, including their start and end times. +Complete segments are appended to the transcript in chronological order. Incomplete segments +(assumed to be the last one) are processed to identify repeated content. If the same incomplete +segment is seen multiple times, it updates the offset and appends the segment to the transcript. +A threshold is used to detect repeated content and ensure it is only included once in the transcript. +The timestamp offset is updated based on the duration of processed segments. The method returns the +last processed segment, allowing it to be sent to the client for real-time updates.

+
+
Args:

segments (dict): dictionary of segments as returned by Whisper. +duration (float): duration of the current chunk.

+
+
Returns:
+
dict or None: The last processed segment with its start time, end time, and transcribed text.

Returns None if there are no valid segments to process.

+
+
+
+
+
+ +
+ +
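An intentionally simplified sketch of that bookkeeping, assuming each segment carries start, end and text fields; repeat detection and the same-output threshold are left out, so this is illustrative rather than the actual method body:
`python
def update_segments(self, segments, duration):
    # Every segment except the last is treated as complete and appended
    # to the transcript with absolute timestamps.
    offset = None
    for s in segments[:-1]:
        self.transcript.append({
            "start": self.timestamp_offset + s["start"],
            "end": self.timestamp_offset + s["end"],
            "text": s["text"],
        })
        offset = min(duration, s["end"])

    # The last segment is assumed incomplete and is returned so it can be
    # streamed to the client for real-time display.
    last = None
    if segments:
        last = {
            "start": self.timestamp_offset + segments[-1]["start"],
            "end": self.timestamp_offset + min(duration, segments[-1]["end"]),
            "text": segments[-1]["text"],
        }

    # The offset only advances by the duration of the completed segments.
    if offset is not None:
        self.timestamp_offset += offset
    return last
`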
+
+class whisper_live.server.TranscriptionServer
+

Represents a transcription server that handles incoming audio from clients.

+
+
Attributes:

RATE (int): The audio sampling rate (constant) set to 16000. +vad_model (torch.Module): The voice activity detection model. +vad_threshold (float): The voice activity detection threshold. +clients (dict): A dictionary to store connected clients. +websockets (dict): A dictionary to store WebSocket connections. +clients_start_time (dict): A dictionary to track client start times. +max_clients (int): Maximum allowed connected clients. +max_connection_time (int): Maximum allowed connection time in seconds.

+
+
+
+
+get_wait_time()
+

Calculate and return the estimated wait time for clients.

+
+
Returns:

float: The estimated wait time in minutes.

+
+
+
+ +
+
+recv_audio(websocket)
+

Receive audio chunks from a client in an infinite loop.

+

Continuously receives audio frames from a connected client +over a WebSocket connection. It processes the audio frames using a +voice activity detection (VAD) model to determine if they contain speech +or not. If the audio frame contains speech, it is added to the client’s +audio data for ASR. +If the maximum number of clients is reached, the method sends a +“WAIT” status to the client, indicating that they should wait +until a slot is available. +If a client’s connection exceeds the maximum allowed time, it will +be disconnected, and the client’s resources will be cleaned up.

+
+
Args:

websocket (WebSocket): The WebSocket connection for the client.

+
+
Raises:

Exception: If there is an error during the audio frame processing.

+
+
+
+ +
+
+run(host, port=9090)
+

Run the transcription server.

+
+
Args:

host (str): The host address to bind the server. +port (int): The port number to bind the server.

+
+
+
+ +
+ +
+
+class whisper_live.client.Client(host=None, port=None, is_multilingual=False, lang=None, translate=False)
+

Handles audio recording, streaming, and communication with a server using WebSocket.

+
+
+static bytes_to_float_array(audio_bytes)
+

Convert audio data from bytes to a NumPy float array.

+

It assumes that the audio data is in 16-bit PCM format. The audio data is normalized to +have values between -1 and 1.

+
+
Args:

audio_bytes (bytes): Audio data in bytes.

+
+
Returns:

np.ndarray: A NumPy array containing the audio data as float values normalized between -1 and 1.

+
+
+
+ +
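The documented conversion is small enough to sketch in full; this assumes little-endian 16-bit PCM input, as stated above:
`python
import numpy as np

def bytes_to_float_array(audio_bytes):
    # Interpret the raw bytes as signed 16-bit samples and scale into [-1, 1].
    raw = np.frombuffer(audio_bytes, dtype=np.int16)
    return raw.astype(np.float32) / 32768.0
`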
+
+close_websocket()
+

Close the WebSocket connection and join the WebSocket thread.

+

First attempts to close the WebSocket connection using self.client_socket.close(). After +closing the connection, it joins the WebSocket thread to ensure proper termination.

+
+ +
+
+get_client_socket()
+

Get the WebSocket client socket instance.

+
+
Returns:

WebSocketApp: The WebSocket client socket instance currently in use by the client.

+
+
+
+ +
+
+on_message(ws, message)
+

Callback function called when a message is received from the server.

+

It updates various attributes of the client based on the received message, including +recording status, language detection, and server messages. If a disconnect message +is received, it sets the recording status to False.

+
+
Args:

ws (websocket.WebSocketApp): The WebSocket client instance. +message (str): The received message from the server.

+
+
+
+ +
+
+on_open(ws)
+

Callback function called when the WebSocket connection is successfully opened.

+

Sends an initial configuration message to the server, including client UID, multilingual mode, +language selection, and task type.

+
+
Args:

ws (websocket.WebSocketApp): The WebSocket client instance.

+
+
+
+ +
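A hedged sketch of that handshake; the JSON key names below are assumptions for illustration, not taken from the source:
`python
import json

def on_open(self, ws):
    # Initial configuration message: client UID, multilingual mode,
    # language selection and task type (key names assumed).
    ws.send(json.dumps({
        "uid": self.uid,
        "multilingual": self.multilingual,
        "language": self.language,
        "task": self.task,
    }))
`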
+
+play_file(filename)
+

Play an audio file and send it to the server for processing.

+

Reads an audio file, plays it through the audio output, and simultaneously sends +the audio data to the server for processing. It uses PyAudio to create an audio +stream for playback. The audio data is read from the file in chunks, converted to +floating-point format, and sent to the server using WebSocket communication. +This method is typically used when you want to process pre-recorded audio and send it +to the server in real-time.

+
+
Args:

filename (str): The path to the audio file to be played and sent to the server.

+
+
+
+ +
+
+record(out_file='output_recording.wav')
+

Record audio data from the input stream and save it to a WAV file.

+

Continuously records audio data from the input stream, sends it to the server via a WebSocket +connection, and simultaneously saves it to multiple WAV files in chunks. It stops recording when +the RECORD_SECONDS duration is reached or when the RECORDING flag is set to False.

+

Audio data is saved in chunks to the “chunks” directory. Each chunk is saved as a separate WAV file. +The recording will continue until the specified duration is reached or until the RECORDING flag is set to False. +The recording process can be interrupted by sending a KeyboardInterrupt (e.g., pressing Ctrl+C). After recording, +the method combines all the saved audio chunks into the specified out_file.

+
+
Args:

out_file (str, optional): The name of the output WAV file to save the entire recording. Default is “output_recording.wav”.

+
+
+
+ +
+
+send_packet_to_server(message)
+

Send an audio packet to the server using WebSocket.

+
+
Args:

message (bytes): The audio data packet in bytes to be sent to the server.

+
+
+
+ +
+
+write_audio_frames_to_file(frames, file_name)
+

Write audio frames to a WAV file.

+

The WAV file is created or overwritten with the specified name. The audio frames should be +in the correct format and match the specified channel, sample width, and sample rate.

+
+
Args:

frames (bytes): The audio frames to be written to the file. +file_name (str): The name of the WAV file to which the frames will be written.

+
+
+
+ +
+
+write_output_recording(n_audio_file, out_file)
+

Combine and save recorded audio chunks into a single WAV file.

+

The individual audio chunk files are expected to be located in the “chunks” directory. Reads each chunk +file, appends its audio data to the final recording, and then deletes the chunk file. After combining +and saving, the final recording is stored in the specified out_file.

+
+
Args:

n_audio_file (int): The number of audio chunk files to combine. +out_file (str): The name of the output WAV file to save the final recording.

+
+
+
+ +
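A hedged sketch of that combining step, assuming the chunks were written as chunks/0.wav, chunks/1.wav, … (the naming scheme and WAV parameters are assumptions):
`python
import os
import wave

def write_output_recording(n_audio_file, out_file, channels=1, sampwidth=2, rate=16000):
    with wave.open(out_file, "wb") as out_wav:
        out_wav.setnchannels(channels)
        out_wav.setsampwidth(sampwidth)
        out_wav.setframerate(rate)
        for i in range(n_audio_file):
            chunk_path = os.path.join("chunks", f"{i}.wav")  # assumed naming
            if not os.path.exists(chunk_path):
                continue
            with wave.open(chunk_path, "rb") as chunk:
                out_wav.writeframes(chunk.readframes(chunk.getnframes()))
            os.remove(chunk_path)  # chunk files are deleted after merging
`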
+ +
+
+class whisper_live.client.TranscriptionClient(host, port, is_multilingual=False, lang=None, translate=False)
+

Client for handling audio transcription tasks via a WebSocket connection.

+

Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used +to send audio data for transcription to a server and receive transcribed text segments.

+
+
Args:

host (str): The hostname or IP address of the server. +port (int): The port number to connect to on the server. +is_multilingual (bool, optional): Indicates whether the transcription should support multiple languages (default is False). +lang (str, optional): The primary language for transcription (used if is_multilingual is False). Default is None, in which case English (‘en’) is used. +translate (bool, optional): Indicates whether translation tasks are required (default is False).

+
+
Attributes:

client (Client): An instance of the underlying Client class responsible for handling the WebSocket connection.

+
+
Example:

To create a TranscriptionClient and start transcription on microphone audio: +`python +transcription_client = TranscriptionClient(host="localhost", port=9090, is_multilingual=True) +transcription_client() +`

+
+
+
+ +
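The test suite added later in this commit drives the same callable with a pre-recorded file; a usage sketch under that assumption (a server on localhost:9090 and the sample asset path are assumed):
`python
from whisper_live.client import TranscriptionClient

# Transcribe a file instead of the microphone; assumes the callable
# accepts an audio path, as the tests in this commit do.
client = TranscriptionClient(host="localhost", port=9090, lang="en")
client("assets/jfk.flac")
`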
+
+whisper_live.client.resample(file: str, sr: int = 16000)
+

Adapted from https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/audio.py#L22: +Open an audio file and read it as a mono waveform, resampling as necessary, +then save the resampled audio.

+
+
Args:

file (str): The audio file to open +sr (int): The sample rate to resample the audio if necessary

+
+
Returns:

resampled_file (str): The resampled audio file

+
+
+
+ +
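A short usage sketch mirroring the client tests later in this commit; the sample file path is only illustrative:
`python
import scipy.io.wavfile
from whisper_live.client import resample

resampled_path = resample("assets/jfk.flac", sr=16000)
rate, _ = scipy.io.wavfile.read(resampled_path)  # the copy is written at 16 kHz
`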
+
+

Indices and tables

+ +
+ + +
+ +
+
+ +
+
+ + + + + + + \ No newline at end of file diff --git a/docs/html/objects.inv b/docs/html/objects.inv new file mode 100644 index 0000000..ca8d141 --- /dev/null +++ b/docs/html/objects.inv @@ -0,0 +1,5 @@ +# Sphinx inventory version 2 +# Project: whisper_live +# Version: +# The remainder of this file is compressed using zlib. +xڭUn0+\DʵJTJ^mA@A hNIvwuF8JI[̓DJ⸭l+_HYtf#SmP=:5]K @h>IvrRB'hE[tjO h#k!ǞniG*N :#:+cOw`_},Qe902f_OT \ No newline at end of file diff --git a/docs/html/py-modindex.html b/docs/html/py-modindex.html new file mode 100644 index 0000000..17f51ed --- /dev/null +++ b/docs/html/py-modindex.html @@ -0,0 +1,123 @@ + + + + + + + Python Module Index — whisper_live documentation + + + + + + + + + + + + + + + + + + + +
+
+
+ + +
+ + +

Python Module Index

+ +
+ w +
+ + + + + + + + + + + + + +
 
+ w
+ whisper_live +
    + whisper_live.client +
    + whisper_live.server +
+ + +
+ +
+
+ +
+
+ + + + + + + \ No newline at end of file diff --git a/docs/html/search.html b/docs/html/search.html new file mode 100644 index 0000000..0ae5276 --- /dev/null +++ b/docs/html/search.html @@ -0,0 +1,117 @@ + + + + + + + Search — whisper_live documentation + + + + + + + + + + + + + + + + + + + + + + +
+
+
+ + +
+ +

Search

+ + + + +

+ Searching for multiple words only shows matches that contain + all words. +

+ + +
+ + + +
+ + + +
+ +
+ + +
+ +
+
+ +
+
+ + + + + + + \ No newline at end of file diff --git a/docs/html/searchindex.js b/docs/html/searchindex.js new file mode 100644 index 0000000..d75d57a --- /dev/null +++ b/docs/html/searchindex.js @@ -0,0 +1 @@ +Search.setIndex({"docnames": ["index"], "filenames": ["index.rst"], "titles": ["Welcome to Whisper Live documentation!"], "terms": {"class": 0, "whisper_l": 0, "server": 0, "servecli": 0, "websocket": 0, "task": 0, "transcrib": 0, "devic": 0, "none": 0, "multilingu": 0, "fals": 0, "languag": 0, "client_uid": 0, "attribut": 0, "rate": 0, "int": 0, "The": 0, "audio": 0, "sampl": 0, "constant": 0, "set": 0, "16000": 0, "server_readi": 0, "str": 0, "A": 0, "messag": 0, "i": 0, "readi": 0, "disconnect": 0, "client": 0, "should": 0, "uniqu": 0, "identifi": 0, "data": 0, "byte": 0, "accumul": 0, "frame": 0, "transcript": 0, "type": 0, "e": 0, "g": 0, "whispermodel": 0, "model": 0, "speech": 0, "text": 0, "timestamp_offset": 0, "float": 0, "offset": 0, "timestamp": 0, "frames_np": 0, "numpi": 0, "ndarrai": 0, "arrai": 0, "store": 0, "frames_offset": 0, "list": 0, "segment": 0, "current_out": 0, "current": 0, "incomplet": 0, "prev_out": 0, "previou": 0, "t_start": 0, "start": 0, "exit": 0, "bool": 0, "flag": 0, "thread": 0, "same_output_threshold": 0, "threshold": 0, "consecut": 0, "same": 0, "output": 0, "show_prev_out_thresh": 0, "show": 0, "add_pause_thresh": 0, "ad": 0, "paus": 0, "blank": 0, "send_last_n_seg": 0, "number": 0, "last": 0, "send": 0, "wrapper": 0, "textwrap": 0, "textwrapp": 0, "format": 0, "pick_previous_seg": 0, "includ": 0, "connect": 0, "add_fram": 0, "frame_np": 0, "add": 0, "ongo": 0, "stream": 0, "buffer": 0, "thi": 0, "method": 0, "respons": 0, "maintain": 0, "allow": 0, "continu": 0, "addit": 0, "thei": 0, "ar": 0, "receiv": 0, "It": 0, "also": 0, "ensur": 0, "doe": 0, "exce": 0, "specifi": 0, "size": 0, "prevent": 0, "excess": 0, "memori": 0, "usag": 0, "If": 0, "45": 0, "second": 0, "discard": 0, "oldest": 0, "30": 0, "reason": 0, "empti": 0, "initi": 0, "provid": 0, "us": 0, "real": 0, "time": 0, "process": 0, "arg": 0, "cleanup": 0, "perform": 0, "befor": 0, "servic": 0, "necessari": 0, "stop": 0, "mark": 0, "gracefulli": 0, "destroi": 0, "resourc": 0, "associ": 0, "notifi": 0, "via": 0, "them": 0, "fill_output": 0, "combin": 0, "complet": 0, "result": 0, "wrap": 0, "two": 0, "line": 0, "each": 0, "contain": 0, "maximum": 0, "50": 0, "charact": 0, "fit": 0, "within": 0, "per": 0, "concaten": 0, "order": 0, "exist": 0, "most": 0, "recent": 0, "first": 0, "older": 0, "prepend": 0, "need": 0, "limit": 0, "3": 0, "detect": 0, "ani": 0, "preced": 0, "content": 0, "return": 0, "singl": 0, "string": 0, "speech_to_text": 0, "an": 0, "infinit": 0, "loop": 0, "": 0, "wait": 0, "input": 0, "make": 0, "predict": 0, "util": 0, "asr": 0, "sent": 0, "histori": 0, "context": 0, "from": 0, "handl": 0, "durat": 0, "rais": 0, "except": 0, "issu": 0, "commun": 0, "update_seg": 0, "append": 0, "all": 0, "assum": 0, "updat": 0, "end": 0, "chronolog": 0, "one": 0, "repeat": 0, "seen": 0, "multipl": 0, "onli": 0, "onc": 0, "base": 0, "dict": 0, "dictionari": 0, "chunk": 0, "its": 0, "valid": 0, "transcriptionserv": 0, "repres": 0, "incom": 0, "vad_model": 0, "torch": 0, "modul": 0, "voic": 0, "activ": 0, "vad_threshold": 0, "clients_start_tim": 0, "track": 0, "max_client": 0, "max_connection_tim": 0, "get_wait_tim": 0, "calcul": 0, "estim": 0, "minut": 0, "recv_audio": 0, "over": 0, "vad": 0, "determin": 0, "reach": 0, "statu": 0, "until": 0, "slot": 0, "avail": 0, "clean": 0, "up": 
0, "error": 0, "dure": 0, "run": 0, "host": 0, "port": 0, "9090": 0, "address": 0, "bind": 0, "transcriptioncli": 0, "is_multilingu": 0, "lang": 0, "translat": 0, "act": 0, "high": 0, "level": 0, "can": 0, "hostnam": 0, "ip": 0, "option": 0, "whether": 0, "support": 0, "default": 0, "primari": 0, "which": 0, "english": 0, "en": 0, "requir": 0, "instanc": 0, "underli": 0, "exampl": 0, "To": 0, "creat": 0, "microphon": 0, "python": 0, "transcription_cli": 0, "localhost": 0, "true": 0, "resampl": 0, "file": 0, "sr": 0, "http": 0, "github": 0, "com": 0, "openai": 0, "blob": 0, "7858aa9c08d98f75575035ecd6481f462d66ca27": 0, "py": 0, "l22": 0, "open": 0, "read": 0, "mono": 0, "waveform": 0, "save": 0, "resampled_fil": 0, "index": 0, "search": 0, "page": 0, "record": 0, "static": 0, "bytes_to_float_arrai": 0, "audio_byt": 0, "convert": 0, "16": 0, "bit": 0, "pcm": 0, "normal": 0, "have": 0, "valu": 0, "between": 0, "1": 0, "np": 0, "close_websocket": 0, "close": 0, "join": 0, "attempt": 0, "self": 0, "client_socket": 0, "after": 0, "proper": 0, "termin": 0, "get_client_socket": 0, "get": 0, "socket": 0, "websocketapp": 0, "on_messag": 0, "w": 0, "callback": 0, "function": 0, "call": 0, "when": 0, "variou": 0, "on_open": 0, "successfulli": 0, "configur": 0, "uid": 0, "mode": 0, "select": 0, "play_fil": 0, "filenam": 0, "plai": 0, "through": 0, "simultan": 0, "pyaudio": 0, "playback": 0, "point": 0, "typic": 0, "you": 0, "want": 0, "pre": 0, "path": 0, "out_fil": 0, "output_record": 0, "wav": 0, "record_second": 0, "directori": 0, "separ": 0, "interrupt": 0, "keyboardinterrupt": 0, "press": 0, "ctrl": 0, "c": 0, "name": 0, "entir": 0, "send_packet_to_serv": 0, "packet": 0, "write_audio_frames_to_fil": 0, "file_nam": 0, "write": 0, "overwritten": 0, "correct": 0, "match": 0, "channel": 0, "width": 0, "written": 0, "write_output_record": 0, "n_audio_fil": 0, "individu": 0, "expect": 0, "locat": 0, "final": 0, "delet": 0}, "objects": {"whisper_live": [[0, 0, 0, "-", "client"], [0, 0, 0, "-", "server"]], "whisper_live.client": [[0, 1, 1, "", "Client"], [0, 1, 1, "", "TranscriptionClient"], [0, 3, 1, "", "resample"]], "whisper_live.client.Client": [[0, 2, 1, "", "bytes_to_float_array"], [0, 2, 1, "", "close_websocket"], [0, 2, 1, "", "get_client_socket"], [0, 2, 1, "", "on_message"], [0, 2, 1, "", "on_open"], [0, 2, 1, "", "play_file"], [0, 2, 1, "", "record"], [0, 2, 1, "", "send_packet_to_server"], [0, 2, 1, "", "write_audio_frames_to_file"], [0, 2, 1, "", "write_output_recording"]], "whisper_live.server": [[0, 1, 1, "", "ServeClient"], [0, 1, 1, "", "TranscriptionServer"]], "whisper_live.server.ServeClient": [[0, 2, 1, "", "add_frames"], [0, 2, 1, "", "cleanup"], [0, 2, 1, "", "disconnect"], [0, 2, 1, "", "fill_output"], [0, 2, 1, "", "speech_to_text"], [0, 2, 1, "", "update_segments"]], "whisper_live.server.TranscriptionServer": [[0, 2, 1, "", "get_wait_time"], [0, 2, 1, "", "recv_audio"], [0, 2, 1, "", "run"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:method", "3": "py:function"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "method", "Python method"], "3": ["py", "function", "Python function"]}, "titleterms": {"welcom": 0, "whisper": 0, "live": 0, "document": 0, "indic": 0, "tabl": 0}, "envversion": {"sphinx.domains.c": 3, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 9, "sphinx.domains.index": 1, "sphinx.domains.javascript": 3, "sphinx.domains.math": 2, "sphinx.domains.python": 4, 
"sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx": 60}, "alltitles": {"Welcome to Whisper Live documentation!": [[0, "welcome-to-whisper-live-documentation"]], "Indices and tables": [[0, "indices-and-tables"]]}, "indexentries": {"client (class in whisper_live.client)": [[0, "whisper_live.client.Client"]], "serveclient (class in whisper_live.server)": [[0, "whisper_live.server.ServeClient"]], "transcriptionclient (class in whisper_live.client)": [[0, "whisper_live.client.TranscriptionClient"]], "transcriptionserver (class in whisper_live.server)": [[0, "whisper_live.server.TranscriptionServer"]], "add_frames() (whisper_live.server.serveclient method)": [[0, "whisper_live.server.ServeClient.add_frames"]], "bytes_to_float_array() (whisper_live.client.client static method)": [[0, "whisper_live.client.Client.bytes_to_float_array"]], "cleanup() (whisper_live.server.serveclient method)": [[0, "whisper_live.server.ServeClient.cleanup"]], "close_websocket() (whisper_live.client.client method)": [[0, "whisper_live.client.Client.close_websocket"]], "disconnect() (whisper_live.server.serveclient method)": [[0, "whisper_live.server.ServeClient.disconnect"]], "fill_output() (whisper_live.server.serveclient method)": [[0, "whisper_live.server.ServeClient.fill_output"]], "get_client_socket() (whisper_live.client.client method)": [[0, "whisper_live.client.Client.get_client_socket"]], "get_wait_time() (whisper_live.server.transcriptionserver method)": [[0, "whisper_live.server.TranscriptionServer.get_wait_time"]], "module": [[0, "module-whisper_live.client"], [0, "module-whisper_live.server"]], "on_message() (whisper_live.client.client method)": [[0, "whisper_live.client.Client.on_message"]], "on_open() (whisper_live.client.client method)": [[0, "whisper_live.client.Client.on_open"]], "play_file() (whisper_live.client.client method)": [[0, "whisper_live.client.Client.play_file"]], "record() (whisper_live.client.client method)": [[0, "whisper_live.client.Client.record"]], "recv_audio() (whisper_live.server.transcriptionserver method)": [[0, "whisper_live.server.TranscriptionServer.recv_audio"]], "resample() (in module whisper_live.client)": [[0, "whisper_live.client.resample"]], "run() (whisper_live.server.transcriptionserver method)": [[0, "whisper_live.server.TranscriptionServer.run"]], "send_packet_to_server() (whisper_live.client.client method)": [[0, "whisper_live.client.Client.send_packet_to_server"]], "speech_to_text() (whisper_live.server.serveclient method)": [[0, "whisper_live.server.ServeClient.speech_to_text"]], "update_segments() (whisper_live.server.serveclient method)": [[0, "whisper_live.server.ServeClient.update_segments"]], "whisper_live.client": [[0, "module-whisper_live.client"]], "whisper_live.server": [[0, "module-whisper_live.server"]], "write_audio_frames_to_file() (whisper_live.client.client method)": [[0, "whisper_live.client.Client.write_audio_frames_to_file"]], "write_output_recording() (whisper_live.client.client method)": [[0, "whisper_live.client.Client.write_output_recording"]]}}) \ No newline at end of file diff --git a/docs/index.html b/docs/index.html new file mode 100644 index 0000000..5896b1d --- /dev/null +++ b/docs/index.html @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/requirements/server.txt b/requirements/server.txt new file mode 100644 index 0000000..37df171 --- /dev/null +++ b/requirements/server.txt @@ -0,0 +1,13 @@ +faster-whisper==1.1.0 +websockets +onnxruntime==1.17.0 +numba +kaldialign +soundfile +scipy +av +jiwer +evaluate +numpy<2 
+openai-whisper==20240930 +tokenizers==0.20.3 \ No newline at end of file diff --git a/run_server.py b/run_server.py new file mode 100644 index 0000000..0055a1e --- /dev/null +++ b/run_server.py @@ -0,0 +1,84 @@ +import argparse +import ssl +import os +import socket + +def check_port_availability(port): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + result = sock.connect_ex(('0.0.0.0', port)) + sock.close() + return result != 0 + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--port', '-p', + type=int, + default=int(os.getenv('PORT_WHISPERLIVE')), + help="Websocket port to run the server on.") + parser.add_argument('--backend', '-b', + type=str, + default='faster_whisper', + help='Backends from ["tensorrt", "faster_whisper"]') + parser.add_argument('--faster_whisper_custom_model_path', '-fw', + type=str, default=None, + help="Custom Faster Whisper Model") + parser.add_argument('--trt_model_path', '-trt', + type=str, + default=None, + help='Whisper TensorRT model path') + parser.add_argument('--trt_multilingual', '-m', + action="store_true", + help='Boolean only for TensorRT model. True if multilingual.') + parser.add_argument('--ssl_cert_path', '-ssl', + type=str, + default=None, + help='Path to cert.pem and key.pem if ssl should be used.') + parser.add_argument('--omp_num_threads', '-omp', + type=int, + default=1, + help="Number of threads to use for OpenMP") + parser.add_argument('--no_single_model', '-nsm', + action='store_true', + help='Set this if every connection should instantiate its own model. Only relevant for custom model, passed using -trt or -fw.') + args = parser.parse_args() + + if args.backend == "tensorrt": + if args.trt_model_path is None: + raise ValueError("Please Provide a valid tensorrt model path") + + port = args.port + if not check_port_availability(port): + print(f"Warning: Port {port} might already be in use!") + + ssl_context = None + if args.ssl_cert_path is not None: + try: + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_context.load_cert_chain( + certfile=f"{args.ssl_cert_path}/cert.pem", + keyfile=f"{args.ssl_cert_path}/privkey.pem" + ) + print("SSL context created successfully") + except Exception as e: + print(f"Failed to load SSL certificates: {str(e)}") + raise + + if "OMP_NUM_THREADS" not in os.environ: + print(f"Setting OMP_NUM_THREADS to {args.omp_num_threads}") + os.environ["OMP_NUM_THREADS"] = str(args.omp_num_threads) + + from whisper_live.server import TranscriptionServer + print(f"Running server with args: {args}") + server = TranscriptionServer() + + print(f"Starting server on port {args.port} with backend {args.backend} using SSL: {args.ssl_cert_path is not None}") + server.run( + "0.0.0.0", + port=args.port, + backend=args.backend, + faster_whisper_custom_model_path=args.faster_whisper_custom_model_path, + whisper_tensorrt_path=args.trt_model_path, + trt_multilingual=args.trt_multilingual, + single_model=not args.no_single_model, + ssl_context=ssl_context + ) \ No newline at end of file diff --git a/scripts/build_whisper_tensorrt.sh b/scripts/build_whisper_tensorrt.sh new file mode 100644 index 0000000..9824803 --- /dev/null +++ b/scripts/build_whisper_tensorrt.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +download_and_build_model() { + local model_name="$1" + local model_url="" + + case "$model_name" in + "tiny.en") + model_url="https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt" + ;; + "tiny") + 
model_url="https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt" + ;; + "base.en") + model_url="https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt" + ;; + "base") + model_url="https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt" + ;; + "small.en") + model_url="https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt" + ;; + "small") + model_url="https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt" + ;; + "medium.en") + model_url="https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt" + ;; + "medium") + model_url="https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt" + ;; + "large-v1") + model_url="https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt" + ;; + "large-v2") + model_url="https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt" + ;; + "large-v3" | "large") + model_url="https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt" + ;; + *) + echo "Invalid model name: $model_name" + exit 1 + ;; + esac + + echo "Downloading $model_name..." + # wget --directory-prefix=assets "$model_url" + # echo "Download completed: ${model_name}.pt" + if [ ! -f "assets/${model_name}.pt" ]; then + wget --directory-prefix=assets "$model_url" + echo "Download completed: ${model_name}.pt" + else + echo "${model_name}.pt already exists in assets directory." + fi + + local output_dir="whisper_${model_name//./_}" + echo "$output_dir" + echo "Running build script for $model_name with output directory $output_dir" + python3 build.py --output_dir "$output_dir" --use_gpt_attention_plugin --use_gemm_plugin --use_bert_attention_plugin --model_name "$model_name" + echo "Whisper $model_name TensorRT engine built." 
+ echo "=========================================" + echo "Model is located at: $(pwd)/$output_dir" +} + +if [ "$#" -lt 1 ]; then + echo "Usage: $0 [model-name]" + exit 1 +fi + +tensorrt_examples_dir="$1" +model_name="${2:-small.en}" + +cd $1/whisper +pip install --no-deps -r requirements.txt + +download_and_build_model "$model_name" diff --git a/scripts/setup.sh b/scripts/setup.sh new file mode 100644 index 0000000..feca7d6 --- /dev/null +++ b/scripts/setup.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +apt-get update +apt-get install -y portaudio19-dev ffmpeg wget diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b264506 --- /dev/null +++ b/setup.py @@ -0,0 +1,60 @@ +import pathlib +from setuptools import find_packages, setup +from whisper_live.__version__ import __version__ + + +# The directory containing this file +HERE = pathlib.Path(__file__).parent + +# The text of the README file +README = (HERE / "README.md").read_text() + +# This call to setup() does all the work +setup( + name="whisper_live", + version=__version__, + description="A nearly-live implementation of OpenAI's Whisper.", + long_description=README, + long_description_content_type="text/markdown", + include_package_data=True, + url="https://github.com/collabora/WhisperLive", + author="Collabora Ltd", + author_email="vineet.suryan@collabora.com", + license="MIT", + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + packages=find_packages( + exclude=( + "examples", + "Audio-Transcription-Chrome", + "Audio-Transcription-Firefox", + "requirements", + "whisper-finetuning" + ) + ), + install_requires=[ + "PyAudio", + "faster-whisper==1.1.0", + "torch", + "torchaudio", + "websockets", + "onnxruntime==1.17.0", + "scipy", + "websocket-client", + "numba", + "openai-whisper==20240930", #TODO: understand this + "kaldialign", + "soundfile", + "tokenizers==0.20.3" + ], + python_requires=">=3.8" +) \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..66189e5 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,156 @@ +import json +import os +import scipy +import websocket +import copy +import unittest +from unittest.mock import patch, MagicMock +from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient +from whisper_live.utils import resample +from pathlib import Path + + +class BaseTestCase(unittest.TestCase): + @patch('whisper_live.client.websocket.WebSocketApp') + @patch('whisper_live.client.pyaudio.PyAudio') + def setUp(self, mock_pyaudio, mock_websocket): + self.mock_pyaudio_instance = MagicMock() + mock_pyaudio.return_value = self.mock_pyaudio_instance + self.mock_stream = MagicMock() + self.mock_pyaudio_instance.open.return_value = self.mock_stream + + self.mock_ws_app = mock_websocket.return_value + self.mock_ws_app.send = MagicMock() + + self.client = TranscriptionClient(host='localhost', port=9090, lang="en").client + + self.mock_pyaudio = mock_pyaudio + self.mock_websocket = mock_websocket + self.mock_audio_packet = b'\x00\x01\x02\x03' + + def tearDown(self): + 
self.client.close_websocket() + self.mock_pyaudio.stop() + self.mock_websocket.stop() + del self.client + +class TestClientWebSocketCommunication(BaseTestCase): + def test_websocket_communication(self): + expected_url = 'ws://localhost:9090' + self.mock_websocket.assert_called() + self.assertEqual(self.mock_websocket.call_args[0][0], expected_url) + + +class TestClientCallbacks(BaseTestCase): + def test_on_open(self): + expected_message = json.dumps({ + "uid": self.client.uid, + "language": self.client.language, + "task": self.client.task, + "model": self.client.model, + "use_vad": True + }) + self.client.on_open(self.mock_ws_app) + self.mock_ws_app.send.assert_called_with(expected_message) + + def test_on_message(self): + message = json.dumps( + { + "uid": self.client.uid, + "message": "SERVER_READY", + "backend": "faster_whisper" + } + ) + self.client.on_message(self.mock_ws_app, message) + + message = json.dumps({ + "uid": self.client.uid, + "segments": [ + {"start": 0, "end": 1, "text": "Test transcript"}, + {"start": 1, "end": 2, "text": "Test transcript 2"}, + {"start": 2, "end": 3, "text": "Test transcript 3"} + ] + }) + self.client.on_message(self.mock_ws_app, message) + + # Assert that the transcript was updated correctly + self.assertEqual(len(self.client.transcript), 2) + self.assertEqual(self.client.transcript[1]['text'], "Test transcript 2") + + def test_on_close(self): + close_status_code = 1000 + close_msg = "Normal closure" + self.client.on_close(self.mock_ws_app, close_status_code, close_msg) + + self.assertFalse(self.client.recording) + self.assertFalse(self.client.server_error) + self.assertFalse(self.client.waiting) + + def test_on_error(self): + error_message = "Test Error" + self.client.on_error(self.mock_ws_app, error_message) + + self.assertTrue(self.client.server_error) + self.assertEqual(self.client.error_message, error_message) + + +class TestAudioResampling(unittest.TestCase): + def test_resample_audio(self): + original_audio = "assets/jfk.flac" + expected_sr = 16000 + resampled_audio = resample(original_audio, expected_sr) + + sr, _ = scipy.io.wavfile.read(resampled_audio) + self.assertEqual(sr, expected_sr) + + os.remove(resampled_audio) + + +class TestSendingAudioPacket(BaseTestCase): + def test_send_packet(self): + self.client.send_packet_to_server(self.mock_audio_packet) + self.client.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY) + +class TestTee(BaseTestCase): + @patch('whisper_live.client.websocket.WebSocketApp') + @patch('whisper_live.client.pyaudio.PyAudio') + def setUp(self, mock_audio, mock_websocket): + super().setUp() + self.client2 = Client(host='localhost', port=9090, lang="es", translate=False, srt_file_path="transcript.srt") + self.client3 = Client(host='localhost', port=9090, lang="es", translate=True, srt_file_path="translation.srt") + # need a separate mock for each websocket + self.client3.client_socket = copy.deepcopy(self.client3.client_socket) + self.tee = TranscriptionTeeClient([self.client2, self.client3]) + + def tearDown(self): + self.tee.close_all_clients() + del self.tee + super().tearDown() + + def test_invalid_constructor(self): + with self.assertRaises(Exception) as context: + TranscriptionTeeClient([]) + + def test_multicast_unconditional(self): + self.tee.multicast_packet(self.mock_audio_packet, True) + for client in self.tee.clients: + client.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY) + + def test_multicast_conditional(self): + 
self.client2.recording = False + self.client3.recording = True + self.tee.multicast_packet(self.mock_audio_packet, False) + self.client2.client_socket.send.assert_not_called() + self.client3.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY) + + def test_close_all(self): + self.tee.close_all_clients() + for client in self.tee.clients: + client.client_socket.close.assert_called() + + def test_write_all_srt(self): + for client in self.tee.clients: + client.server_backend = "faster_whisper" + self.tee.write_all_clients_srt() + self.assertTrue(Path("transcript.srt").is_file()) + self.assertTrue(Path("translation.srt").is_file()) diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..f836be7 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,150 @@ +import subprocess +import time +import json +import unittest +from unittest import mock + +import numpy as np +import evaluate + +from websockets.exceptions import ConnectionClosed +from whisper_live.server import TranscriptionServer +from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient +from whisper.normalizers import EnglishTextNormalizer + + +class TestTranscriptionServerInitialization(unittest.TestCase): + def test_initialization(self): + server = TranscriptionServer() + self.assertEqual(server.client_manager.max_clients, 4) + self.assertEqual(server.client_manager.max_connection_time, 600) + self.assertDictEqual(server.client_manager.clients, {}) + self.assertDictEqual(server.client_manager.start_times, {}) + + +class TestGetWaitTime(unittest.TestCase): + def setUp(self): + self.server = TranscriptionServer() + self.server.client_manager.start_times = { + 'client1': time.time() - 120, + 'client2': time.time() - 300 + } + self.server.client_manager.max_connection_time = 600 + + def test_get_wait_time(self): + expected_wait_time = (600 - (time.time() - self.server.client_manager.start_times['client2'])) / 60 + print(self.server.client_manager.get_wait_time(), expected_wait_time) + self.assertAlmostEqual(self.server.client_manager.get_wait_time(), expected_wait_time, places=2) + + +class TestServerConnection(unittest.TestCase): + def setUp(self): + self.server = TranscriptionServer() + + @mock.patch('websockets.WebSocketCommonProtocol') + def test_connection(self, mock_websocket): + mock_websocket.recv.return_value = json.dumps({ + 'uid': 'test_client', + 'language': 'en', + 'task': 'transcribe', + 'model': 'tiny.en' + }) + self.server.recv_audio(mock_websocket, "faster_whisper") + + @mock.patch('websockets.WebSocketCommonProtocol') + def test_recv_audio_exception_handling(self, mock_websocket): + mock_websocket.recv.side_effect = [json.dumps({ + 'uid': 'test_client', + 'language': 'en', + 'task': 'transcribe', + 'model': 'tiny.en' + }), np.array([1, 2, 3]).tobytes()] + + with self.assertLogs(level="ERROR"): + self.server.recv_audio(mock_websocket, "faster_whisper") + + self.assertNotIn(mock_websocket, self.server.client_manager.clients) + + +class TestServerInferenceAccuracy(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.mock_pyaudio_patch = mock.patch('pyaudio.PyAudio') + cls.mock_pyaudio = cls.mock_pyaudio_patch.start() + cls.mock_pyaudio.return_value.open.return_value = mock.MagicMock() + + cls.server_process = subprocess.Popen(["python", "run_server.py"]) + time.sleep(2) + + @classmethod + def tearDownClass(cls): + cls.server_process.terminate() + cls.server_process.wait() + + def setUp(self): + self.metric = 
evaluate.load("wer") + self.normalizer = EnglishTextNormalizer() + + def check_prediction(self, srt_path): + gt = "And so my fellow Americans, ask not, what your country can do for you. Ask what you can do for your country!" + with open(srt_path, "r") as f: + lines = f.readlines() + prediction = " ".join([line.strip() for line in lines[2::4]]) + prediction_normalized = self.normalizer(prediction) + gt_normalized = self.normalizer(gt) + + # calculate WER + wer = self.metric.compute( + predictions=[prediction_normalized], + references=[gt_normalized] + ) + self.assertLess(wer, 0.05) + + def test_inference(self): + client = TranscriptionClient( + "localhost", "9090", model="base.en", lang="en", + ) + client("assets/jfk.flac") + self.check_prediction("output.srt") + + def test_simultaneous_inference(self): + client1 = Client( + "localhost", "9090", model="base.en", lang="en", srt_file_path="transcript1.srt") + client2 = Client( + "localhost", "9090", model="base.en", lang="en", srt_file_path="transcript2.srt") + tee = TranscriptionTeeClient([client1, client2]) + tee("assets/jfk.flac") + self.check_prediction("transcript1.srt") + self.check_prediction("transcript2.srt") + + +class TestExceptionHandling(unittest.TestCase): + def setUp(self): + self.server = TranscriptionServer() + + @mock.patch('websockets.WebSocketCommonProtocol') + def test_connection_closed_exception(self, mock_websocket): + mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed") + + with self.assertLogs(level="INFO") as log: + self.server.recv_audio(mock_websocket, "faster_whisper") + self.assertTrue(any("Connection closed by client" in message for message in log.output)) + + @mock.patch('websockets.WebSocketCommonProtocol') + def test_json_decode_exception(self, mock_websocket): + mock_websocket.recv.return_value = "invalid json" + + with self.assertLogs(level="ERROR") as log: + self.server.recv_audio(mock_websocket, "faster_whisper") + self.assertTrue(any("Failed to decode JSON from client" in message for message in log.output)) + + @mock.patch('websockets.WebSocketCommonProtocol') + def test_unexpected_exception_handling(self, mock_websocket): + mock_websocket.recv.side_effect = RuntimeError("Unexpected error") + + with self.assertLogs(level="ERROR") as log: + self.server.recv_audio(mock_websocket, "faster_whisper") + for message in log.output: + print(message) + print() + self.assertTrue(any("Unexpected error" in message for message in log.output)) diff --git a/tests/test_vad.py b/tests/test_vad.py new file mode 100644 index 0000000..cfc2d3a --- /dev/null +++ b/tests/test_vad.py @@ -0,0 +1,26 @@ +import unittest +import numpy as np +from whisper_live.tensorrt_utils import load_audio +from whisper_live.vad import VoiceActivityDetector + + +class TestVoiceActivityDetection(unittest.TestCase): + def setUp(self): + self.vad = VoiceActivityDetector() + self.sample_rate = 16000 + + def generate_silence(self, duration_seconds): + return np.zeros(int(self.sample_rate * duration_seconds), dtype=np.float32) + + def load_speech_segment(self, filepath): + return load_audio(filepath) + + def test_vad_silence_detection(self): + silence = self.generate_silence(3) + is_speech_present = self.vad(silence.copy()) + self.assertFalse(is_speech_present, "VAD incorrectly identified silence as speech.") + + def test_vad_speech_detection(self): + audio_tensor = load_audio("assets/jfk.flac") + is_speech_present = self.vad(audio_tensor) + self.assertTrue(is_speech_present, "VAD failed to identify speech segment.") 
diff --git a/whisper_live/__init__.py b/whisper_live/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/whisper_live/__pycache__/__init__.cpython-312.pyc b/whisper_live/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..fd07219 Binary files /dev/null and b/whisper_live/__pycache__/__init__.cpython-312.pyc differ diff --git a/whisper_live/__pycache__/__version__.cpython-312.pyc b/whisper_live/__pycache__/__version__.cpython-312.pyc new file mode 100644 index 0000000..bdde886 Binary files /dev/null and b/whisper_live/__pycache__/__version__.cpython-312.pyc differ diff --git a/whisper_live/__version__.py b/whisper_live/__version__.py new file mode 100644 index 0000000..3d26edf --- /dev/null +++ b/whisper_live/__version__.py @@ -0,0 +1 @@ +__version__ = "0.4.1" diff --git a/whisper_live/server.py b/whisper_live/server.py new file mode 100644 index 0000000..f9bb594 --- /dev/null +++ b/whisper_live/server.py @@ -0,0 +1,1139 @@ +import os +import time +import threading +import json +import functools +import logging +from enum import Enum +from typing import List, Optional + +import torch +import numpy as np +from websockets.sync.server import serve +from websockets.exceptions import ConnectionClosed +from whisper_live.vad import VoiceActivityDetector +from whisper_live.transcriber import WhisperModel +try: + from whisper_live.transcriber_tensorrt import WhisperTRTLLM +except Exception: + pass + +logging.basicConfig(level=logging.INFO) + + +class ClientManager: + def __init__(self, max_clients=4, max_connection_time=600): + """ + Initializes the ClientManager with specified limits on client connections and connection durations. + + Args: + max_clients (int, optional): The maximum number of simultaneous client connections allowed. Defaults to 4. + max_connection_time (int, optional): The maximum duration (in seconds) a client can stay connected. Defaults + to 600 seconds (10 minutes). + """ + self.clients = {} + self.start_times = {} + self.max_clients = max_clients + self.max_connection_time = max_connection_time + + def add_client(self, websocket, client): + """ + Adds a client and their connection start time to the tracking dictionaries. + + Args: + websocket: The websocket associated with the client to add. + client: The client object to be added and tracked. + """ + self.clients[websocket] = client + self.start_times[websocket] = time.time() + + def get_client(self, websocket): + """ + Retrieves a client associated with the given websocket. + + Args: + websocket: The websocket associated with the client to retrieve. + + Returns: + The client object if found, False otherwise. + """ + if websocket in self.clients: + return self.clients[websocket] + return False + + def remove_client(self, websocket): + """ + Removes a client and their connection start time from the tracking dictionaries. Performs cleanup on the + client if necessary. + + Args: + websocket: The websocket associated with the client to be removed. + """ + client = self.clients.pop(websocket, None) + if client: + client.cleanup() + self.start_times.pop(websocket, None) + + def get_wait_time(self): + """ + Calculates the estimated wait time for new clients based on the remaining connection times of current clients. + + Returns: + The estimated wait time in minutes for new clients to connect. Returns 0 if there are available slots. 
+ """ + wait_time = None + for start_time in self.start_times.values(): + current_client_time_remaining = self.max_connection_time - (time.time() - start_time) + if wait_time is None or current_client_time_remaining < wait_time: + wait_time = current_client_time_remaining + return wait_time / 60 if wait_time is not None else 0 + + def is_server_full(self, websocket, options): + """ + Checks if the server is at its maximum client capacity and sends a wait message to the client if necessary. + + Args: + websocket: The websocket of the client attempting to connect. + options: A dictionary of options that may include the client's unique identifier. + + Returns: + True if the server is full, False otherwise. + """ + if len(self.clients) >= self.max_clients: + wait_time = self.get_wait_time() + response = {"uid": options["uid"], "status": "WAIT", "message": wait_time} + websocket.send(json.dumps(response)) + return True + return False + + def is_client_timeout(self, websocket): + """ + Checks if a client has exceeded the maximum allowed connection time and disconnects them if so, issuing a warning. + + Args: + websocket: The websocket associated with the client to check. + + Returns: + True if the client's connection time has exceeded the maximum limit, False otherwise. + """ + elapsed_time = time.time() - self.start_times[websocket] + if elapsed_time >= self.max_connection_time: + self.clients[websocket].disconnect() + logging.warning(f"Client with uid '{self.clients[websocket].client_uid}' disconnected due to overtime.") + return True + return False + + +class BackendType(Enum): + FASTER_WHISPER = "faster_whisper" + TENSORRT = "tensorrt" + + @staticmethod + def valid_types() -> List[str]: + return [backend_type.value for backend_type in BackendType] + + @staticmethod + def is_valid(backend: str) -> bool: + return backend in BackendType.valid_types() + + def is_faster_whisper(self) -> bool: + return self == BackendType.FASTER_WHISPER + + def is_tensorrt(self) -> bool: + return self == BackendType.TENSORRT + + +class TranscriptionServer: + RATE = 16000 + + def __init__(self): + self.client_manager = None + self.no_voice_activity_chunks = 0 + self.use_vad = True + self.single_model = False + + def initialize_client( + self, websocket, options, faster_whisper_custom_model_path, + whisper_tensorrt_path, trt_multilingual + ): + client: Optional[ServeClientBase] = None + + if self.backend.is_tensorrt(): + try: + client = ServeClientTensorRT( + websocket, + multilingual=trt_multilingual, + language=options["language"], + task=options["task"], + client_uid=options["uid"], + model=whisper_tensorrt_path, + single_model=self.single_model, + ) + logging.info("Running TensorRT backend.") + except Exception as e: + logging.error(f"TensorRT-LLM not supported: {e}") + self.client_uid = options["uid"] + websocket.send(json.dumps({ + "uid": self.client_uid, + "status": "WARNING", + "message": "TensorRT-LLM not supported on Server yet. 
" + "Reverting to available backend: 'faster_whisper'" + })) + self.backend = BackendType.FASTER_WHISPER + + try: + if self.backend.is_faster_whisper(): + if faster_whisper_custom_model_path is not None and os.path.exists(faster_whisper_custom_model_path): + logging.info(f"Using custom model {faster_whisper_custom_model_path}") + options["model"] = faster_whisper_custom_model_path + client = ServeClientFasterWhisper( + websocket, + language=options["language"], + task=options["task"], + client_uid=options["uid"], + model=options["model"], + initial_prompt=options.get("initial_prompt"), + vad_parameters=options.get("vad_parameters"), + use_vad=self.use_vad, + single_model=self.single_model, + ) + + logging.info("Running faster_whisper backend.") + except Exception as e: + return + + if client is None: + raise ValueError(f"Backend type {self.backend.value} not recognised or not handled.") + + self.client_manager.add_client(websocket, client) + + def get_audio_from_websocket(self, websocket): + """ + Receives audio buffer from websocket and creates a numpy array out of it. + + Args: + websocket: The websocket to receive audio from. + + Returns: + A numpy array containing the audio. + """ + frame_data = websocket.recv() + if frame_data == b"END_OF_AUDIO": + return False + return np.frombuffer(frame_data, dtype=np.float32) + + def handle_new_connection(self, websocket, faster_whisper_custom_model_path, + whisper_tensorrt_path, trt_multilingual): + try: + logging.info("New client connected") + options = websocket.recv() + options = json.loads(options) + + if self.client_manager is None: + max_clients = options.get('max_clients', 4) + max_connection_time = options.get('max_connection_time', 600) + self.client_manager = ClientManager(max_clients, max_connection_time) + + self.use_vad = options.get('use_vad') + if self.client_manager.is_server_full(websocket, options): + websocket.close() + return False # Indicates that the connection should not continue + + if self.backend.is_tensorrt(): + self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE) + self.initialize_client(websocket, options, faster_whisper_custom_model_path, + whisper_tensorrt_path, trt_multilingual) + return True + except json.JSONDecodeError: + logging.error("Failed to decode JSON from client") + return False + except ConnectionClosed: + logging.info("Connection closed by client") + return False + except Exception as e: + logging.error(f"Error during new connection initialization: {str(e)}") + return False + + def process_audio_frames(self, websocket): + frame_np = self.get_audio_from_websocket(websocket) + client = self.client_manager.get_client(websocket) + if frame_np is False: + if self.backend.is_tensorrt(): + client.set_eos(True) + return False + + if self.backend.is_tensorrt(): + voice_active = self.voice_activity(websocket, frame_np) + if voice_active: + self.no_voice_activity_chunks = 0 + client.set_eos(False) + if self.use_vad and not voice_active: + return True + + client.add_frames(frame_np) + return True + + def recv_audio(self, + websocket, + backend: BackendType = BackendType.FASTER_WHISPER, + faster_whisper_custom_model_path=None, + whisper_tensorrt_path=None, + trt_multilingual=False): + """ + Receive audio chunks from a client in an infinite loop. + + Continuously receives audio frames from a connected client + over a WebSocket connection. It processes the audio frames using a + voice activity detection (VAD) model to determine if they contain speech + or not. 
If the audio frame contains speech, it is added to the client's + audio data for ASR. + If the maximum number of clients is reached, the method sends a + "WAIT" status to the client, indicating that they should wait + until a slot is available. + If a client's connection exceeds the maximum allowed time, it will + be disconnected, and the client's resources will be cleaned up. + + Args: + websocket (WebSocket): The WebSocket connection for the client. + backend (str): The backend to run the server with. + faster_whisper_custom_model_path (str): path to custom faster whisper model. + whisper_tensorrt_path (str): Required for tensorrt backend. + trt_multilingual(bool): Only used for tensorrt, True if multilingual model. + + Raises: + Exception: If there is an error during the audio frame processing. + """ + self.backend = backend + if not self.handle_new_connection(websocket, faster_whisper_custom_model_path, + whisper_tensorrt_path, trt_multilingual): + return + + try: + while not self.client_manager.is_client_timeout(websocket): + if not self.process_audio_frames(websocket): + break + except ConnectionClosed: + logging.info("Connection closed by client") + except Exception as e: + logging.error(f"Unexpected error: {str(e)}") + finally: + if self.client_manager.get_client(websocket): + self.cleanup(websocket) + websocket.close() + del websocket + + def run(self, + host, + port=int(os.getenv('PORT_WHISPERLIVE')), + backend="tensorrt", + faster_whisper_custom_model_path=None, + whisper_tensorrt_path=None, + trt_multilingual=False, + single_model=False, + ssl_context=None): + """ + Run the transcription server. + + Args: + host (str): The host address to bind the server. + port (int): The port number to bind the server. + """ + if faster_whisper_custom_model_path is not None and not os.path.exists(faster_whisper_custom_model_path): + raise ValueError(f"Custom faster_whisper model '{faster_whisper_custom_model_path}' is not a valid path.") + if whisper_tensorrt_path is not None and not os.path.exists(whisper_tensorrt_path): + raise ValueError(f"TensorRT model '{whisper_tensorrt_path}' is not a valid path.") + if single_model: + if faster_whisper_custom_model_path or whisper_tensorrt_path: + logging.info("Custom model option was provided. Switching to single model mode.") + self.single_model = True + # TODO: load model initially + else: + logging.info("Single model mode currently only works with custom models.") + if not BackendType.is_valid(backend): + raise ValueError(f"{backend} is not a valid backend type. Choose backend from {BackendType.valid_types()}") + with serve( + functools.partial( + self.recv_audio, + backend=BackendType(backend), + faster_whisper_custom_model_path=faster_whisper_custom_model_path, + whisper_tensorrt_path=whisper_tensorrt_path, + trt_multilingual=trt_multilingual + ), + host, + port, + ssl_context=ssl_context + ) as server: + server.serve_forever() + + def voice_activity(self, websocket, frame_np): + """ + Evaluates the voice activity in a given audio frame and manages the state of voice activity detection. + + This method uses the configured voice activity detection (VAD) model to assess whether the given audio frame + contains speech. If the VAD model detects no voice activity for more than three consecutive frames, + it sets an end-of-speech (EOS) flag for the associated client. This method aims to efficiently manage + speech detection to improve subsequent processing steps. + + Args: + websocket: The websocket associated with the current client. 
Used to retrieve the client object + from the client manager for state management. + frame_np (numpy.ndarray): The audio frame to be analyzed. This should be a NumPy array containing + the audio data for the current frame. + + Returns: + bool: True if voice activity is detected in the current frame, False otherwise. When returning False + after detecting no voice activity for more than three consecutive frames, it also triggers the + end-of-speech (EOS) flag for the client. + """ + if not self.vad_detector(frame_np): + self.no_voice_activity_chunks += 1 + if self.no_voice_activity_chunks > 3: + client = self.client_manager.get_client(websocket) + if not client.eos: + client.set_eos(True) + time.sleep(0.1) # Sleep 100m; wait some voice activity. + return False + return True + + def cleanup(self, websocket): + """ + Cleans up resources associated with a given client's websocket. + + Args: + websocket: The websocket associated with the client to be cleaned up. + """ + if self.client_manager.get_client(websocket): + self.client_manager.remove_client(websocket) + + +class ServeClientBase(object): + RATE = 16000 + SERVER_READY = "SERVER_READY" + DISCONNECT = "DISCONNECT" + + def __init__(self, client_uid, websocket): + self.client_uid = client_uid + self.websocket = websocket + self.frames = b"" + self.timestamp_offset = 0.0 + self.frames_np = None + self.frames_offset = 0.0 + self.text = [] + self.current_out = '' + self.prev_out = '' + self.t_start = None + self.exit = False + self.same_output_count = 0 + self.show_prev_out_thresh = 5 # if pause(no output from whisper) show previous output for 5 seconds + self.add_pause_thresh = 3 # add a blank to segment list as a pause(no speech) for 3 seconds + self.transcript = [] + self.send_last_n_segments = 10 + + # text formatting + self.pick_previous_segments = 2 + + # threading + self.lock = threading.Lock() + + def speech_to_text(self): + raise NotImplementedError + + def transcribe_audio(self): + raise NotImplementedError + + def handle_transcription_output(self): + raise NotImplementedError + + def add_frames(self, frame_np): + """ + Add audio frames to the ongoing audio stream buffer. + + This method is responsible for maintaining the audio stream buffer, allowing the continuous addition + of audio frames as they are received. It also ensures that the buffer does not exceed a specified size + to prevent excessive memory usage. + + If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds + of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided + audio frame. The audio stream buffer is used for real-time processing of audio data for transcription. + + Args: + frame_np (numpy.ndarray): The audio frame data as a NumPy array. + + """ + self.lock.acquire() + if self.frames_np is not None and self.frames_np.shape[0] > 45*self.RATE: + self.frames_offset += 30.0 + self.frames_np = self.frames_np[int(30*self.RATE):] + # check timestamp offset(should be >= self.frame_offset) + # this basically means that there is no speech as timestamp offset hasnt updated + # and is less than frame_offset + if self.timestamp_offset < self.frames_offset: + self.timestamp_offset = self.frames_offset + if self.frames_np is None: + self.frames_np = frame_np.copy() + else: + self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0) + self.lock.release() + + def clip_audio_if_no_valid_segment(self): + """ + Update the timestamp offset based on audio buffer status. 
+ Clip audio if the current chunk exceeds 30 seconds, this basically implies that + no valid segment for the last 30 seconds from whisper + """ + with self.lock: + if self.frames_np[int((self.timestamp_offset - self.frames_offset)*self.RATE):].shape[0] > 25 * self.RATE: + duration = self.frames_np.shape[0] / self.RATE + self.timestamp_offset = self.frames_offset + duration - 5 + + def get_audio_chunk_for_processing(self): + """ + Retrieves the next chunk of audio data for processing based on the current offsets. + + Calculates which part of the audio data should be processed next, based on + the difference between the current timestamp offset and the frame's offset, scaled by + the audio sample rate (RATE). It then returns this chunk of audio data along with its + duration in seconds. + + Returns: + tuple: A tuple containing: + - input_bytes (np.ndarray): The next chunk of audio data to be processed. + - duration (float): The duration of the audio chunk in seconds. + """ + with self.lock: + samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE) + input_bytes = self.frames_np[int(samples_take):].copy() + duration = input_bytes.shape[0] / self.RATE + return input_bytes, duration + + def prepare_segments(self, last_segment=None): + """ + Prepares the segments of transcribed text to be sent to the client. + + This method compiles the recent segments of transcribed text, ensuring that only the + specified number of the most recent segments are included. It also appends the most + recent segment of text if provided (which is considered incomplete because of the possibility + of the last word being truncated in the audio chunk). + + Args: + last_segment (str, optional): The most recent segment of transcribed text to be added + to the list of segments. Defaults to None. + + Returns: + list: A list of transcribed text segments to be sent to the client. + """ + segments = [] + if len(self.transcript) >= self.send_last_n_segments: + segments = self.transcript[-self.send_last_n_segments:].copy() + else: + segments = self.transcript.copy() + if last_segment is not None: + segments = segments + [last_segment] + return segments + + def get_audio_chunk_duration(self, input_bytes): + """ + Calculates the duration of the provided audio chunk. + + Args: + input_bytes (numpy.ndarray): The audio chunk for which to calculate the duration. + + Returns: + float: The duration of the audio chunk in seconds. + """ + return input_bytes.shape[0] / self.RATE + + def send_transcription_to_client(self, segments): + """ + Sends the specified transcription segments to the client over the websocket connection. + + This method formats the transcription segments into a JSON object and attempts to send + this object to the client. If an error occurs during the send operation, it logs the error. + + Returns: + segments (list): A list of transcription segments to be sent to the client. + """ + try: + self.websocket.send( + json.dumps({ + "uid": self.client_uid, + "segments": segments, + }) + ) + except Exception as e: + logging.error(f"[ERROR]: Sending data to client: {e}") + + def disconnect(self): + """ + Notify the client of disconnection and send a disconnect message. + + This method sends a disconnect message to the client via the WebSocket connection to notify them + that the transcription service is disconnecting gracefully. 
+ + """ + self.websocket.send(json.dumps({ + "uid": self.client_uid, + "message": self.DISCONNECT + })) + + def cleanup(self): + """ + Perform cleanup tasks before exiting the transcription service. + + This method performs necessary cleanup tasks, including stopping the transcription thread, marking + the exit flag to indicate the transcription thread should exit gracefully, and destroying resources + associated with the transcription process. + + """ + logging.info("Cleaning up.") + self.exit = True + + +class ServeClientTensorRT(ServeClientBase): + + SINGLE_MODEL = None + SINGLE_MODEL_LOCK = threading.Lock() + + def __init__(self, websocket, task="transcribe", multilingual=False, language=None, client_uid=None, model=None, single_model=False): + """ + Initialize a ServeClient instance. + The Whisper model is initialized based on the client's language and device availability. + The transcription thread is started upon initialization. A "SERVER_READY" message is sent + to the client to indicate that the server is ready. + + Args: + websocket (WebSocket): The WebSocket connection for the client. + task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe". + device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None. + multilingual (bool, optional): Whether the client supports multilingual transcription. Defaults to False. + language (str, optional): The language for transcription. Defaults to None. + client_uid (str, optional): A unique identifier for the client. Defaults to None. + single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False. + + """ + super().__init__(client_uid, websocket) + self.language = language if multilingual else "en" + self.task = task + self.eos = False + + if single_model: + if ServeClientTensorRT.SINGLE_MODEL is None: + self.create_model(model, multilingual) + ServeClientTensorRT.SINGLE_MODEL = self.transcriber + else: + self.transcriber = ServeClientTensorRT.SINGLE_MODEL + else: + self.create_model(model, multilingual) + + # threading + self.trans_thread = threading.Thread(target=self.speech_to_text) + self.trans_thread.start() + + self.websocket.send(json.dumps({ + "uid": self.client_uid, + "message": self.SERVER_READY, + "backend": "tensorrt" + })) + + def create_model(self, model, multilingual, warmup=True): + """ + Instantiates a new model, sets it as the transcriber and does warmup if desired. + """ + self.transcriber = WhisperTRTLLM( + model, + assets_dir="assets", + device="cuda", + is_multilingual=multilingual, + language=self.language, + task=self.task + ) + if warmup: + self.warmup() + + def warmup(self, warmup_steps=10): + """ + Warmup TensorRT since first few inferences are slow. + + Args: + warmup_steps (int): Number of steps to warm up the model for. + """ + logging.info("[INFO:] Warming up TensorRT engine..") + mel, _ = self.transcriber.log_mel_spectrogram("assets/jfk.flac") + for i in range(warmup_steps): + self.transcriber.transcribe(mel) + + def set_eos(self, eos): + """ + Sets the End of Speech (EOS) flag. + + Args: + eos (bool): The value to set for the EOS flag. + """ + self.lock.acquire() + self.eos = eos + self.lock.release() + + def handle_transcription_output(self, last_segment, duration): + """ + Handle the transcription output, updating the transcript and sending data to the client. 
+ + Args: + last_segment (str): The last segment from the whisper output which is considered to be incomplete because + of the possibility of word being truncated. + duration (float): Duration of the transcribed audio chunk. + """ + segments = self.prepare_segments({"text": last_segment}) + self.send_transcription_to_client(segments) + if self.eos: + self.update_timestamp_offset(last_segment, duration) + + def transcribe_audio(self, input_bytes): + """ + Transcribe the audio chunk and send the results to the client. + + Args: + input_bytes (np.array): The audio chunk to transcribe. + """ + if ServeClientTensorRT.SINGLE_MODEL: + ServeClientTensorRT.SINGLE_MODEL_LOCK.acquire() + logging.info(f"[WhisperTensorRT:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}") + mel, duration = self.transcriber.log_mel_spectrogram(input_bytes) + last_segment = self.transcriber.transcribe( + mel, + text_prefix=f"<|startoftranscript|><|{self.language}|><|{self.task}|><|notimestamps|>" + ) + if ServeClientTensorRT.SINGLE_MODEL: + ServeClientTensorRT.SINGLE_MODEL_LOCK.release() + if last_segment: + self.handle_transcription_output(last_segment, duration) + + def update_timestamp_offset(self, last_segment, duration): + """ + Update timestamp offset and transcript. + + Args: + last_segment (str): Last transcribed audio from the whisper model. + duration (float): Duration of the last audio chunk. + """ + if not len(self.transcript): + self.transcript.append({"text": last_segment + " "}) + elif self.transcript[-1]["text"].strip() != last_segment: + self.transcript.append({"text": last_segment + " "}) + + with self.lock: + self.timestamp_offset += duration + + def speech_to_text(self): + """ + Process an audio stream in an infinite loop, continuously transcribing the speech. + + This method continuously receives audio frames, performs real-time transcription, and sends + transcribed segments to the client via a WebSocket connection. + + If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction. + It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments + are sent to the client in real-time, and a history of segments is maintained to provide context.Pauses in speech + (no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if + there is no speech for a specified duration to indicate a pause. + + Raises: + Exception: If there is an issue with audio processing or WebSocket communication. + + """ + while True: + if self.exit: + logging.info("Exiting speech to text thread") + break + + if self.frames_np is None: + time.sleep(0.02) # wait for any audio to arrive + continue + + self.clip_audio_if_no_valid_segment() + + input_bytes, duration = self.get_audio_chunk_for_processing() + if duration < 0.4: + continue + + try: + input_sample = input_bytes.copy() + logging.info(f"[WhisperTensorRT:] Processing audio with duration: {duration}") + self.transcribe_audio(input_sample) + + except Exception as e: + logging.error(f"[ERROR]: {e}") + + +class ServeClientFasterWhisper(ServeClientBase): + + SINGLE_MODEL = None + SINGLE_MODEL_LOCK = threading.Lock() + + def __init__(self, websocket, task="transcribe", device=None, language=None, client_uid=None, model="small.en", + initial_prompt=None, vad_parameters=None, use_vad=True, single_model=False): + """ + Initialize a ServeClient instance. 
+ The Whisper model is initialized based on the client's language and device availability. + The transcription thread is started upon initialization. A "SERVER_READY" message is sent + to the client to indicate that the server is ready. + + Args: + websocket (WebSocket): The WebSocket connection for the client. + task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe". + device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None. + language (str, optional): The language for transcription. Defaults to None. + client_uid (str, optional): A unique identifier for the client. Defaults to None. + model (str, optional): The whisper model size. Defaults to 'small.en' + initial_prompt (str, optional): Prompt for whisper inference. Defaults to None. + single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False. + """ + super().__init__(client_uid, websocket) + self.model_sizes = [ + "tiny", "tiny.en", "base", "base.en", "small", "small.en", + "medium", "medium.en", "large-v2", "large-v3", "distil-small.en", + "distil-medium.en", "distil-large-v2", "distil-large-v3", + "large-v3-turbo", "turbo" + ] + + if not os.path.exists(model): + self.model_size_or_path = self.check_valid_model(model) + else: + self.model_size_or_path = model + self.language = "en" if self.model_size_or_path.endswith("en") else language + self.task = task + self.initial_prompt = initial_prompt + self.vad_parameters = vad_parameters or {"onset": 0.5} + self.no_speech_thresh = 0.45 + self.same_output_threshold = 10 + self.end_time_for_same_output = None + + device = "cuda" if torch.cuda.is_available() else "cpu" + if device == "cuda": + major, _ = torch.cuda.get_device_capability(device) + self.compute_type = "float16" if major >= 7 else "float32" + else: + self.compute_type = "int8" + + if self.model_size_or_path is None: + return + logging.info(f"Using Device={device} with precision {self.compute_type}") + + try: + if single_model: + if ServeClientFasterWhisper.SINGLE_MODEL is None: + self.create_model(device) + ServeClientFasterWhisper.SINGLE_MODEL = self.transcriber + else: + self.transcriber = ServeClientFasterWhisper.SINGLE_MODEL + else: + self.create_model(device) + except Exception as e: + logging.error(f"Failed to load model: {e}") + self.websocket.send(json.dumps({ + "uid": self.client_uid, + "status": "ERROR", + "message": f"Failed to load model: {str(self.model_size_or_path)}" + })) + self.websocket.close() + return + + self.use_vad = use_vad + + # threading + self.trans_thread = threading.Thread(target=self.speech_to_text) + self.trans_thread.start() + self.websocket.send( + json.dumps( + { + "uid": self.client_uid, + "message": self.SERVER_READY, + "backend": "faster_whisper" + } + ) + ) + + def create_model(self, device): + """ + Instantiates a new model, sets it as the transcriber. + """ + self.transcriber = WhisperModel( + self.model_size_or_path, + device=device, + compute_type=self.compute_type, + local_files_only=False, + ) + + def check_valid_model(self, model_size): + """ + Check if it's a valid whisper model size. + + Args: + model_size (str): The name of the model size to check. + + Returns: + str: The model size if valid, None otherwise. + """ + if model_size not in self.model_sizes: + self.websocket.send( + json.dumps( + { + "uid": self.client_uid, + "status": "ERROR", + "message": f"Invalid model size {model_size}. 
Available choices: {self.model_sizes}" + } + ) + ) + return None + return model_size + + def set_language(self, info): + """ + Updates the language attribute based on the detected language information. + + Args: + info (object): An object containing the detected language and its probability. This object + must have at least two attributes: `language`, a string indicating the detected + language, and `language_probability`, a float representing the confidence level + of the language detection. + """ + if info.language_probability > 0.5: + self.language = info.language + logging.info(f"Detected language {self.language} with probability {info.language_probability}") + self.websocket.send(json.dumps( + {"uid": self.client_uid, "language": self.language, "language_prob": info.language_probability})) + + def transcribe_audio(self, input_sample): + """ + Transcribes the provided audio sample using the configured transcriber instance. + + If the language has not been set, it updates the session's language based on the transcription + information. + + Args: + input_sample (np.array): The audio chunk to be transcribed. This should be a NumPy + array representing the audio data. + + Returns: + The transcription result from the transcriber. The exact format of this result + depends on the implementation of the `transcriber.transcribe` method but typically + includes the transcribed text. + """ + if ServeClientFasterWhisper.SINGLE_MODEL: + ServeClientFasterWhisper.SINGLE_MODEL_LOCK.acquire() + result, info = self.transcriber.transcribe( + input_sample, + initial_prompt=self.initial_prompt, + language=self.language, + task=self.task, + vad_filter=self.use_vad, + vad_parameters=self.vad_parameters if self.use_vad else None) + if ServeClientFasterWhisper.SINGLE_MODEL: + ServeClientFasterWhisper.SINGLE_MODEL_LOCK.release() + + if self.language is None and info is not None: + self.set_language(info) + return result + + def get_previous_output(self): + """ + Retrieves previously generated transcription outputs if no new transcription is available + from the current audio chunks. + + Checks the time since the last transcription output and, if it is within a specified + threshold, returns the most recent segments of transcribed text. It also manages + adding a pause (blank segment) to indicate a significant gap in speech based on a defined + threshold. + + Returns: + segments (list): A list of transcription segments. This may include the most recent + transcribed text segments or a blank segment to indicate a pause + in speech. + """ + segments = [] + if self.t_start is None: + self.t_start = time.time() + if time.time() - self.t_start < self.show_prev_out_thresh: + segments = self.prepare_segments() + + # add a blank if there is no speech for 3 seconds + if len(self.text) and self.text[-1] != '': + if time.time() - self.t_start > self.add_pause_thresh: + self.text.append('') + return segments + + def handle_transcription_output(self, result, duration): + """ + Handle the transcription output, updating the transcript and sending data to the client. + + Args: + result (str): The result from whisper inference i.e. the list of segments. + duration (float): Duration of the transcribed audio chunk. + """ + segments = [] + if len(result): + self.t_start = None + last_segment = self.update_segments(result, duration) + segments = self.prepare_segments(last_segment) + else: + # show previous output if there is pause i.e. 
no output from whisper + segments = self.get_previous_output() + + if len(segments): + self.send_transcription_to_client(segments) + + def speech_to_text(self): + """ + Process an audio stream in an infinite loop, continuously transcribing the speech. + + This method continuously receives audio frames, performs real-time transcription, and sends + transcribed segments to the client via a WebSocket connection. + + If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction. + It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments + are sent to the client in real-time, and a history of segments is maintained to provide context.Pauses in speech + (no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if + there is no speech for a specified duration to indicate a pause. + + Raises: + Exception: If there is an issue with audio processing or WebSocket communication. + + """ + while True: + if self.exit: + logging.info("Exiting speech to text thread") + break + + if self.frames_np is None: + continue + + self.clip_audio_if_no_valid_segment() + + input_bytes, duration = self.get_audio_chunk_for_processing() + if duration < 1.0: + time.sleep(0.1) # wait for audio chunks to arrive + continue + try: + input_sample = input_bytes.copy() + result = self.transcribe_audio(input_sample) + + if result is None or self.language is None: + self.timestamp_offset += duration + time.sleep(0.25) # wait for voice activity, result is None when no voice activity + continue + self.handle_transcription_output(result, duration) + + except Exception as e: + logging.error(f"[ERROR]: Failed to transcribe audio chunk: {e}") + time.sleep(0.01) + + def format_segment(self, start, end, text, completed=False): + """ + Formats a transcription segment with precise start and end times alongside the transcribed text. + + Args: + start (float): The start time of the transcription segment in seconds. + end (float): The end time of the transcription segment in seconds. + text (str): The transcribed text corresponding to the segment. + + Returns: + dict: A dictionary representing the formatted transcription segment, including + 'start' and 'end' times as strings with three decimal places and the 'text' + of the transcription. + """ + return { + 'start': "{:.3f}".format(start), + 'end': "{:.3f}".format(end), + 'text': text, + 'completed': completed + } + + def update_segments(self, segments, duration): + """ + Processes the segments from whisper. Appends all the segments to the list + except for the last segment assuming that it is incomplete. + + Updates the ongoing transcript with transcribed segments, including their start and end times. + Complete segments are appended to the transcript in chronological order. Incomplete segments + (assumed to be the last one) are processed to identify repeated content. If the same incomplete + segment is seen multiple times, it updates the offset and appends the segment to the transcript. + A threshold is used to detect repeated content and ensure it is only included once in the transcript. + The timestamp offset is updated based on the duration of processed segments. The method returns the + last processed segment, allowing it to be sent to the client for real-time updates. 
+ + Args: + segments(dict) : dictionary of segments as returned by whisper + duration(float): duration of the current chunk + + Returns: + dict or None: The last processed segment with its start time, end time, and transcribed text. + Returns None if there are no valid segments to process. + """ + offset = None + self.current_out = '' + last_segment = None + + # process complete segments + if len(segments) > 1 and segments[-1].no_speech_prob <= self.no_speech_thresh: + for i, s in enumerate(segments[:-1]): + text_ = s.text + self.text.append(text_) + with self.lock: + start, end = self.timestamp_offset + s.start, self.timestamp_offset + min(duration, s.end) + + if start >= end: + continue + if s.no_speech_prob > self.no_speech_thresh: + continue + + self.transcript.append(self.format_segment(start, end, text_, completed=True)) + offset = min(duration, s.end) + + # only process the last segment if it satisfies the no_speech_thresh + if segments[-1].no_speech_prob <= self.no_speech_thresh: + self.current_out += segments[-1].text + with self.lock: + last_segment = self.format_segment( + self.timestamp_offset + segments[-1].start, + self.timestamp_offset + min(duration, segments[-1].end), + self.current_out, + completed=False + ) + + if self.current_out.strip() == self.prev_out.strip() and self.current_out != '': + self.same_output_count += 1 + + # if we remove the audio because of same output on the nth reptition we might remove the + # audio thats not yet transcribed so, capturing the time when it was repeated for the first time + if self.end_time_for_same_output is None: + self.end_time_for_same_output = segments[-1].end + time.sleep(0.1) # wait for some voice activity just in case there is an unitended pause from the speaker for better punctuations. + else: + self.same_output_count = 0 + self.end_time_for_same_output = None + + # if same incomplete segment is seen multiple times then update the offset + # and append the segment to the list + if self.same_output_count > self.same_output_threshold: + if not len(self.text) or self.text[-1].strip().lower() != self.current_out.strip().lower(): + self.text.append(self.current_out) + with self.lock: + self.transcript.append(self.format_segment( + self.timestamp_offset, + self.timestamp_offset + min(duration, self.end_time_for_same_output), + self.current_out, + completed=True + )) + self.current_out = '' + offset = min(duration, self.end_time_for_same_output) + self.same_output_count = 0 + last_segment = None + self.end_time_for_same_output = None + else: + self.prev_out = self.current_out + + # update offset + if offset is not None: + with self.lock: + self.timestamp_offset += offset + + return last_segment \ No newline at end of file diff --git a/whisper_live/tensorrt_utils.py b/whisper_live/tensorrt_utils.py new file mode 100644 index 0000000..9752e7a --- /dev/null +++ b/whisper_live/tensorrt_utils.py @@ -0,0 +1,365 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from collections import defaultdict +from functools import lru_cache +from pathlib import Path +from subprocess import CalledProcessError, run +from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union + +import kaldialign +import numpy as np +import soundfile +import torch +import torch.nn.functional as F + +Pathlike = Union[str, Path] + +SAMPLE_RATE = 16000 +N_FFT = 400 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk + + +def load_audio(file: str, sr: int = SAMPLE_RATE): + """ + Open an audio file and read as mono waveform, resampling as necessary + + Parameters + ---------- + file: str + The audio file to open + + sr: int + The sample rate to resample the audio if necessary + + Returns + ------- + A NumPy array containing the audio waveform, in float32 dtype. + """ + + # This launches a subprocess to decode audio while down-mixing + # and resampling as necessary. Requires the ffmpeg CLI in PATH. + # fmt: off + cmd = [ + "ffmpeg", "-nostdin", "-threads", "0", "-i", file, "-f", "s16le", "-ac", + "1", "-acodec", "pcm_s16le", "-ar", + str(sr), "-" + ] + # fmt: on + try: + out = run(cmd, capture_output=True, check=True).stdout + except CalledProcessError as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + + +def load_audio_wav_format(wav_path): + # make sure audio in .wav format + assert wav_path.endswith( + '.wav'), f"Only support .wav format, but got {wav_path}" + waveform, sample_rate = soundfile.read(wav_path) + assert sample_rate == 16000, f"Only support 16k sample rate, but got {sample_rate}" + return waveform, sample_rate + + +def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. + """ + if torch.is_tensor(array): + if array.shape[axis] > length: + array = array.index_select(dim=axis, + index=torch.arange(length, + device=array.device)) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = F.pad(array, + [pad for sizes in pad_widths[::-1] for pad in sizes]) + else: + if array.shape[axis] > length: + array = array.take(indices=range(length), axis=axis) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = np.pad(array, pad_widths) + + return array + + +@lru_cache(maxsize=None) +def mel_filters(device, + n_mels: int, + mel_filters_dir: str = None) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
+ Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" + if mel_filters_dir is None: + mel_filters_path = os.path.join(os.path.dirname(__file__), "assets", + "mel_filters.npz") + else: + mel_filters_path = os.path.join(mel_filters_dir, "mel_filters.npz") + with np.load(mel_filters_path) as f: + return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) + + +def log_mel_spectrogram( + audio: Union[str, np.ndarray, torch.Tensor], + n_mels: int, + padding: int = 0, + device: Optional[Union[str, torch.device]] = None, + return_duration: bool = False, + mel_filters_dir: str = None, +): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 and 128 are supported + + padding: int + Number of zero samples to pad to the right + + device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + + Returns + ------- + torch.Tensor, shape = (80 or 128, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + if audio.endswith('.wav'): + audio, _ = load_audio_wav_format(audio) + else: + audio = load_audio(audio) + assert isinstance(audio, + np.ndarray), f"Unsupported audio type: {type(audio)}" + duration = audio.shape[-1] / SAMPLE_RATE + audio = pad_or_trim(audio, N_SAMPLES) + audio = audio.astype(np.float32) + audio = torch.from_numpy(audio) + + if device is not None: + audio = audio.to(device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, + N_FFT, + HOP_LENGTH, + window=window, + return_complex=True) + magnitudes = stft[..., :-1].abs()**2 + + filters = mel_filters(audio.device, n_mels, mel_filters_dir) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + if return_duration: + return log_spec, duration + else: + return log_spec + + +def store_transcripts(filename: Pathlike, texts: Iterable[Tuple[str, str, + str]]) -> None: + """Save predicted results and reference transcripts to a file. + https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py + Args: + filename: + File to save the results to. + texts: + An iterable of tuples. The first element is the cur_id, the second is + the reference transcript and the third element is the predicted result. + Returns: + Return None. + """ + with open(filename, "w") as f: + for cut_id, ref, hyp in texts: + print(f"{cut_id}:\tref={ref}", file=f) + print(f"{cut_id}:\thyp={hyp}", file=f) + + +def write_error_stats( # noqa: C901 + f: TextIO, + test_set_name: str, + results: List[Tuple[str, str]], + enable_log: bool = True, +) -> float: + """Write statistics based on predicted results and reference transcripts. + https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py + It will write the following to the given file: + + - WER + - number of insertions, deletions, substitutions, corrects and total + reference words. 
For example:: + + Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 + reference words (2337 correct) + + - The difference between the reference transcript and predicted result. + An instance is given below:: + + THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES + + The above example shows that the reference word is `EDISON`, + but it is predicted to `ADDISON` (a substitution error). + + Another example is:: + + FOR THE FIRST DAY (SIR->*) I THINK + + The reference word `SIR` is missing in the predicted + results (a deletion error). + results: + An iterable of tuples. The first element is the cur_id, the second is + the reference transcript and the third element is the predicted result. + enable_log: + If True, also print detailed WER to the console. + Otherwise, it is written only to the given file. + Returns: + Return None. + """ + subs: Dict[Tuple[str, str], int] = defaultdict(int) + ins: Dict[str, int] = defaultdict(int) + dels: Dict[str, int] = defaultdict(int) + + # `words` stores counts per word, as follows: + # corr, ref_sub, hyp_sub, ins, dels + words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) + num_corr = 0 + ERR = "*" + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + for ref_word, hyp_word in ali: + if ref_word == ERR: + ins[hyp_word] += 1 + words[hyp_word][3] += 1 + elif hyp_word == ERR: + dels[ref_word] += 1 + words[ref_word][4] += 1 + elif hyp_word != ref_word: + subs[(ref_word, hyp_word)] += 1 + words[ref_word][1] += 1 + words[hyp_word][2] += 1 + else: + words[ref_word][0] += 1 + num_corr += 1 + ref_len = sum([len(r) for _, r, _ in results]) + sub_errs = sum(subs.values()) + ins_errs = sum(ins.values()) + del_errs = sum(dels.values()) + tot_errs = sub_errs + ins_errs + del_errs + tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + + if enable_log: + logging.info(f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " + f"[{tot_errs} / {ref_len}, {ins_errs} ins, " + f"{del_errs} del, {sub_errs} sub ]") + + print(f"%WER = {tot_err_rate}", file=f) + print( + f"Errors: {ins_errs} insertions, {del_errs} deletions, " + f"{sub_errs} substitutions, over {ref_len} reference " + f"words ({num_corr} correct)", + file=f, + ) + print( + "Search below for sections starting with PER-UTT DETAILS:, " + "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", + file=f, + ) + + print("", file=f) + print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + combine_successive_errors = True + if combine_successive_errors: + ali = [[[x], [y]] for x, y in ali] + for i in range(len(ali) - 1): + if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: + ali[i + 1][0] = ali[i][0] + ali[i + 1][0] + ali[i + 1][1] = ali[i][1] + ali[i + 1][1] + ali[i] = [[], []] + ali = [[ + list(filter(lambda a: a != ERR, x)), + list(filter(lambda a: a != ERR, y)), + ] for x, y in ali] + ali = list(filter(lambda x: x != [[], []], ali)) + ali = [[ + ERR if x == [] else " ".join(x), + ERR if y == [] else " ".join(y), + ] for x, y in ali] + + print( + f"{cut_id}:\t" + " ".join((ref_word if ref_word == hyp_word else + f"({ref_word}->{hyp_word})" + for ref_word, hyp_word in ali)), + file=f, + ) + + print("", file=f) + print("SUBSTITUTIONS: count ref -> hyp", file=f) + + for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], + reverse=True): + print(f"{count} {ref} -> {hyp}", file=f) + + print("", file=f) + print("DELETIONS: count ref", file=f) + for count, ref in 
sorted([(v, k) for k, v in dels.items()], reverse=True): + print(f"{count} {ref}", file=f) + + print("", file=f) + print("INSERTIONS: count hyp", file=f) + for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): + print(f"{count} {hyp}", file=f) + + print("", file=f) + print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", + file=f) + for _, word, counts in sorted([(sum(v[1:]), k, v) + for k, v in words.items()], + reverse=True): + (corr, ref_sub, hyp_sub, ins, dels) = counts + tot_errs = ref_sub + hyp_sub + ins + dels + ref_count = corr + ref_sub + dels + hyp_count = corr + hyp_sub + ins + + print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) + return float(tot_err_rate) diff --git a/whisper_live/transcriber.py b/whisper_live/transcriber.py new file mode 100644 index 0000000..cfa2e56 --- /dev/null +++ b/whisper_live/transcriber.py @@ -0,0 +1,1889 @@ +# original https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py + +import itertools +import json +import logging +import os +import zlib + +from dataclasses import asdict, dataclass +from inspect import signature +from math import ceil +from typing import BinaryIO, Iterable, List, Optional, Tuple, Union +from warnings import warn + +import ctranslate2 +import numpy as np +import tokenizers + +from tqdm import tqdm + +from faster_whisper.audio import decode_audio, pad_or_trim +from faster_whisper.feature_extractor import FeatureExtractor +from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer +from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger +from faster_whisper.vad import ( + SpeechTimestampsMap, + VadOptions, + collect_chunks, + get_speech_timestamps, + merge_segments, +) + + +@dataclass +class Word: + start: float + end: float + word: str + probability: float + + def _asdict(self): + warn( + "Word._asdict() method is deprecated, use dataclasses.asdict(Word) instead", + DeprecationWarning, + 2, + ) + return asdict(self) + + +@dataclass +class Segment: + id: int + seek: int + start: float + end: float + text: str + tokens: List[int] + avg_logprob: float + compression_ratio: float + no_speech_prob: float + words: Optional[List[Word]] + temperature: Optional[float] + + def _asdict(self): + warn( + "Segment._asdict() method is deprecated, use dataclasses.asdict(Segment) instead", + DeprecationWarning, + 2, + ) + return asdict(self) + + +@dataclass +class TranscriptionOptions: + beam_size: int + best_of: int + patience: float + length_penalty: float + repetition_penalty: float + no_repeat_ngram_size: int + log_prob_threshold: Optional[float] + no_speech_threshold: Optional[float] + compression_ratio_threshold: Optional[float] + condition_on_previous_text: bool + prompt_reset_on_temperature: float + temperatures: List[float] + initial_prompt: Optional[Union[str, Iterable[int]]] + prefix: Optional[str] + suppress_blank: bool + suppress_tokens: Optional[List[int]] + without_timestamps: bool + max_initial_timestamp: float + word_timestamps: bool + prepend_punctuations: str + append_punctuations: str + multilingual: bool + max_new_tokens: Optional[int] + clip_timestamps: Union[str, List[float]] + hallucination_silence_threshold: Optional[float] + hotwords: Optional[str] + + +@dataclass +class TranscriptionInfo: + language: str + language_probability: float + duration: float + duration_after_vad: float + all_language_probs: Optional[List[Tuple[str, float]]] + transcription_options: TranscriptionOptions + vad_options: 
VadOptions + + +class BatchedInferencePipeline: + def __init__( + self, + model, + ): + self.model: WhisperModel = model + self.last_speech_timestamp = 0.0 + + def forward(self, features, tokenizer, chunks_metadata, options): + encoder_output, outputs = self.generate_segment_batched( + features, tokenizer, options + ) + + segmented_outputs = [] + segment_sizes = [] + for chunk_metadata, output in zip(chunks_metadata, outputs): + duration = chunk_metadata["end_time"] - chunk_metadata["start_time"] + segment_size = int(ceil(duration) * self.model.frames_per_second) + segment_sizes.append(segment_size) + ( + subsegments, + seek, + single_timestamp_ending, + ) = self.model._split_segments_by_timestamps( + tokenizer=tokenizer, + tokens=output["tokens"], + time_offset=chunk_metadata["start_time"], + segment_size=segment_size, + segment_duration=duration, + seek=0, + ) + segmented_outputs.append( + [ + dict( + text=tokenizer.decode(subsegment["tokens"]), + avg_logprob=output["avg_logprob"], + no_speech_prob=output["no_speech_prob"], + tokens=subsegment["tokens"], + start=subsegment["start"], + end=subsegment["end"], + compression_ratio=get_compression_ratio( + tokenizer.decode(subsegment["tokens"]) + ), + seek=int( + chunk_metadata["start_time"] * self.model.frames_per_second + ), + ) + for subsegment in subsegments + ] + ) + if options.word_timestamps: + self.last_speech_timestamp = self.model.add_word_timestamps( + segmented_outputs, + tokenizer, + encoder_output, + segment_sizes, + options.prepend_punctuations, + options.append_punctuations, + self.last_speech_timestamp, + ) + + return segmented_outputs + + def generate_segment_batched( + self, + features: np.ndarray, + tokenizer: Tokenizer, + options: TranscriptionOptions, + ): + batch_size = features.shape[0] + + prompt = self.model.get_prompt( + tokenizer, + previous_tokens=( + tokenizer.encode(options.initial_prompt) + if options.initial_prompt is not None + else [] + ), + without_timestamps=options.without_timestamps, + hotwords=options.hotwords, + ) + + if options.max_new_tokens is not None: + max_length = len(prompt) + options.max_new_tokens + else: + max_length = self.model.max_length + + if max_length > self.model.max_length: + raise ValueError( + f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` " + f"{max_length - len(prompt)}. Thus, the combined length of the prompt " + f"and `max_new_tokens` is: {max_length}. This exceeds the " + f"`max_length` of the Whisper model: {self.model.max_length}. " + "You should either reduce the length of your prompt, or " + "reduce the value of `max_new_tokens`, " + f"so that their combined length is less that {self.model.max_length}." 
+ ) + + encoder_output = self.model.encode(features) + prompts = [prompt.copy() for _ in range(batch_size)] + + if options.multilingual: + language_tokens = [ + tokenizer.tokenizer.token_to_id(segment_langs[0][0]) + for segment_langs in self.model.model.detect_language(encoder_output) + ] + language_token_index = prompt.index(tokenizer.language) + + for i, language_token in enumerate(language_tokens): + prompts[i][language_token_index] = language_token + + results = self.model.model.generate( + encoder_output, + prompts, + beam_size=options.beam_size, + patience=options.patience, + length_penalty=options.length_penalty, + max_length=max_length, + suppress_blank=options.suppress_blank, + suppress_tokens=options.suppress_tokens, + return_scores=True, + return_no_speech_prob=True, + sampling_temperature=options.temperatures[0], + repetition_penalty=options.repetition_penalty, + no_repeat_ngram_size=options.no_repeat_ngram_size, + ) + + output = [] + for result in results: + # return scores + seq_len = len(result.sequences_ids[0]) + cum_logprob = result.scores[0] * (seq_len**options.length_penalty) + + output.append( + dict( + avg_logprob=cum_logprob / (seq_len + 1), + no_speech_prob=result.no_speech_prob, + tokens=result.sequences_ids[0], + ) + ) + + return encoder_output, output + + def transcribe( + self, + audio: Union[str, BinaryIO, np.ndarray], + language: Optional[str] = None, + task: str = "transcribe", + log_progress: bool = False, + beam_size: int = 5, + best_of: int = 5, + patience: float = 1, + length_penalty: float = 1, + repetition_penalty: float = 1, + no_repeat_ngram_size: int = 0, + temperature: Union[float, List[float], Tuple[float, ...]] = [ + 0.0, + 0.2, + 0.4, + 0.6, + 0.8, + 1.0, + ], + compression_ratio_threshold: Optional[float] = 2.4, + log_prob_threshold: Optional[float] = -1.0, + no_speech_threshold: Optional[float] = 0.6, + condition_on_previous_text: bool = True, + prompt_reset_on_temperature: float = 0.5, + initial_prompt: Optional[Union[str, Iterable[int]]] = None, + prefix: Optional[str] = None, + suppress_blank: bool = True, + suppress_tokens: Optional[List[int]] = [-1], + without_timestamps: bool = True, + max_initial_timestamp: float = 1.0, + word_timestamps: bool = False, + prepend_punctuations: str = "\"'“¿([{-", + append_punctuations: str = "\"'.。,,!!??::”)]}、", + multilingual: bool = False, + vad_filter: bool = True, + vad_parameters: Optional[Union[dict, VadOptions]] = None, + max_new_tokens: Optional[int] = None, + chunk_length: Optional[int] = None, + clip_timestamps: Optional[List[dict]] = None, + hallucination_silence_threshold: Optional[float] = None, + batch_size: int = 8, + hotwords: Optional[str] = None, + language_detection_threshold: Optional[float] = 0.5, + language_detection_segments: int = 1, + ) -> Tuple[Iterable[Segment], TranscriptionInfo]: + """transcribe audio in chunks in batched fashion and return with language info. + + Arguments: + audio: Path to the input file (or a file-like object), or the audio waveform. + language: The language spoken in the audio. It should be a language code such + as "en" or "fr". If not set, the language will be detected in the first 30 seconds + of audio. + task: Task to execute (transcribe or translate). + log_progress: whether to show progress bar or not. + beam_size: Beam size to use for decoding. + best_of: Number of candidates when sampling with non-zero temperature. + patience: Beam search patience factor. + length_penalty: Exponential length penalty constant. 
+ repetition_penalty: Penalty applied to the score of previously generated tokens + (set > 1 to penalize). + no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable). + temperature: Temperature for sampling. If a list or tuple is passed, + only the first value is used. + initial_prompt: Optional text string or iterable of token ids to provide as a + prompt for the each window. + suppress_blank: Suppress blank outputs at the beginning of the sampling. + suppress_tokens: List of token IDs to suppress. -1 will suppress a default set + of symbols as defined in `tokenizer.non_speech_tokens()`. + without_timestamps: Only sample text tokens. + word_timestamps: Extract word-level timestamps using the cross-attention pattern + and dynamic time warping, and include the timestamps for each word in each segment. + Set as False. + prepend_punctuations: If word_timestamps is True, merge these punctuation symbols + with the next word + append_punctuations: If word_timestamps is True, merge these punctuation symbols + with the previous word + multilingual: Perform language detection on every segment. + vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio + without speech. This step is using the Silero VAD model + https://github.com/snakers4/silero-vad. + vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available + parameters and default values in the class `VadOptions`). + max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set, + the maximum will be set by the default max_length. + chunk_length: The length of audio segments. If it is not None, it will overwrite the + default chunk_length of the FeatureExtractor. + clip_timestamps: Optionally provide list of dictionaries each containing "start" and + "end" keys that specify the start and end of the voiced region within + `chunk_length` boundary. vad_filter will be ignored if clip_timestamps is used. + batch_size: the maximum number of parallel requests to model for decoding. + hotwords: + Hotwords/hint phrases to the model. Has no effect if prefix is not None. + language_detection_threshold: If the maximum probability of the language tokens is + higher than this value, the language is detected. + language_detection_segments: Number of segments to consider for the language detection. + + Unused Arguments + compression_ratio_threshold: If the gzip compression ratio is above this value, + treat as failed. + log_prob_threshold: If the average log probability over sampled tokens is + below this value, treat as failed. + no_speech_threshold: If the no_speech probability is higher than this value AND + the average log probability over sampled tokens is below `log_prob_threshold`, + consider the segment as silent. + condition_on_previous_text: If True, the previous output of the model is provided + as a prompt for the next window; disabling may make the text inconsistent across + windows, but the model becomes less prone to getting stuck in a failure loop, + such as repetition looping or timestamps going out of sync. Set as False + prompt_reset_on_temperature: Resets prompt if temperature is above this value. + Arg has effect only if condition_on_previous_text is True. Set at 0.5 + prefix: Optional text to provide as a prefix at the beginning of each window. + max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0. 
+ hallucination_silence_threshold: Optional[float] + When word_timestamps is True, skip silent periods longer than this threshold + (in seconds) when a possible hallucination is detected. set as None. + Returns: + A tuple with: + + - a generator over transcribed segments + - an instance of TranscriptionInfo + """ + + sampling_rate = self.model.feature_extractor.sampling_rate + + if multilingual and not self.model.model.is_multilingual: + self.model.logger.warning( + "The current model is English-only but the multilingual parameter is set to" + "True; setting to False instead." + ) + multilingual = False + + if not isinstance(audio, np.ndarray): + audio = decode_audio(audio, sampling_rate=sampling_rate) + duration = audio.shape[0] / sampling_rate + + chunk_length = chunk_length or self.model.feature_extractor.chunk_length + # if no segment split is provided, use vad_model and generate segments + if not clip_timestamps: + if vad_filter: + if vad_parameters is None: + vad_parameters = VadOptions( + max_speech_duration_s=chunk_length, + min_silence_duration_ms=160, + ) + elif isinstance(vad_parameters, dict): + if "max_speech_duration_s" in vad_parameters.keys(): + vad_parameters.pop("max_speech_duration_s") + + vad_parameters = VadOptions( + **vad_parameters, max_speech_duration_s=chunk_length + ) + + active_segments = get_speech_timestamps(audio, vad_parameters) + clip_timestamps = merge_segments(active_segments, vad_parameters) + # run the audio if it is less than 30 sec even without clip_timestamps + elif duration < chunk_length: + clip_timestamps = [{"start": 0, "end": audio.shape[0]}] + else: + raise RuntimeError( + "No clip timestamps found. " + "Set 'vad_filter' to True or provide 'clip_timestamps'." + ) + + duration_after_vad = ( + sum((segment["end"] - segment["start"]) for segment in clip_timestamps) + / sampling_rate + ) + + audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps) + features = ( + [self.model.feature_extractor(chunk)[..., :-1] for chunk in audio_chunks] + if duration_after_vad + else [] + ) + + all_language_probs = None + # detecting the language if not provided + if language is None: + if not self.model.model.is_multilingual: + language = "en" + language_probability = 1 + else: + ( + language, + language_probability, + all_language_probs, + ) = self.model.detect_language( + features=np.concatenate( + features + + [ + np.full((self.model.model.n_mels, 1), -1.5, dtype="float32") + ], + axis=1, + ), # add a dummy feature to account for empty audio + language_detection_segments=language_detection_segments, + language_detection_threshold=language_detection_threshold, + ) + + self.model.logger.info( + "Detected language '%s' with probability %.2f", + language, + language_probability, + ) + else: + if not self.model.model.is_multilingual and language != "en": + self.model.logger.warning( + "The current model is English-only but the language parameter is set to '%s'; " + "using 'en' instead." 
% language + ) + language = "en" + + language_probability = 1 + + tokenizer = Tokenizer( + self.model.hf_tokenizer, + self.model.model.is_multilingual, + task=task, + language=language, + ) + + features = ( + np.stack([pad_or_trim(feature) for feature in features]) if features else [] + ) + + options = TranscriptionOptions( + beam_size=beam_size, + best_of=best_of, + patience=patience, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + log_prob_threshold=log_prob_threshold, + no_speech_threshold=no_speech_threshold, + compression_ratio_threshold=compression_ratio_threshold, + temperatures=( + temperature[:1] + if isinstance(temperature, (list, tuple)) + else [temperature] + ), + initial_prompt=initial_prompt, + prefix=prefix, + suppress_blank=suppress_blank, + suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens), + prepend_punctuations=prepend_punctuations, + append_punctuations=append_punctuations, + max_new_tokens=max_new_tokens, + hotwords=hotwords, + word_timestamps=word_timestamps, + hallucination_silence_threshold=None, + condition_on_previous_text=False, + clip_timestamps=clip_timestamps, + prompt_reset_on_temperature=0.5, + multilingual=multilingual, + without_timestamps=without_timestamps, + max_initial_timestamp=0.0, + ) + + info = TranscriptionInfo( + language=language, + language_probability=language_probability, + duration=duration, + duration_after_vad=duration_after_vad, + transcription_options=options, + vad_options=vad_parameters, + all_language_probs=all_language_probs, + ) + + segments = self._batched_segments_generator( + features, + tokenizer, + chunks_metadata, + batch_size, + options, + log_progress, + ) + + return segments, info + + def _batched_segments_generator( + self, features, tokenizer, chunks_metadata, batch_size, options, log_progress + ): + pbar = tqdm(total=len(features), disable=not log_progress, position=0) + seg_idx = 0 + for i in range(0, len(features), batch_size): + results = self.forward( + features[i : i + batch_size], + tokenizer, + chunks_metadata[i : i + batch_size], + options, + ) + + for result in results: + for segment in result: + seg_idx += 1 + yield Segment( + seek=segment["seek"], + id=seg_idx, + text=segment["text"], + start=round(segment["start"], 3), + end=round(segment["end"], 3), + words=( + None + if not options.word_timestamps + else [Word(**word) for word in segment["words"]] + ), + tokens=segment["tokens"], + avg_logprob=segment["avg_logprob"], + no_speech_prob=segment["no_speech_prob"], + compression_ratio=segment["compression_ratio"], + temperature=options.temperatures[0], + ) + + pbar.update(1) + + pbar.close() + self.last_speech_timestamp = 0.0 + + +class WhisperModel: + def __init__( + self, + model_size_or_path: str, + device: str = "auto", + device_index: Union[int, List[int]] = 0, + compute_type: str = "default", + cpu_threads: int = 0, + num_workers: int = 1, + download_root: Optional[str] = None, + local_files_only: bool = False, + files: dict = None, + **model_kwargs, + ): + """Initializes the Whisper model. + + Args: + model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en, + small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1, + large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo), + a path to a converted model directory, or a CTranslate2-converted Whisper model ID from + the HF Hub. 
When a size or a model ID is configured, the converted model is downloaded + from the Hugging Face Hub. + device: Device to use for computation ("cpu", "cuda", "auto"). + device_index: Device ID to use. + The model can also be loaded on multiple GPUs by passing a list of IDs + (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel + when transcribe() is called from multiple Python threads (see also num_workers). + compute_type: Type to use for computation. + See https://opennmt.net/CTranslate2/quantization.html. + cpu_threads: Number of threads to use when running on CPU (4 by default). + A non zero value overrides the OMP_NUM_THREADS environment variable. + num_workers: When transcribe() is called from multiple Python threads, + having multiple workers enables true parallelism when running the model + (concurrent calls to self.model.generate() will run in parallel). + This can improve the global throughput at the cost of increased memory usage. + download_root: Directory where the models should be saved. If not set, the models + are saved in the standard Hugging Face cache directory. + local_files_only: If True, avoid downloading the file and return the path to the + local cached file if it exists. + files: Load model files from the memory. This argument is a dictionary mapping file names + to file contents as file-like or bytes objects. If this is set, model_path acts as an + identifier for this model. + """ + self.logger = get_logger() + + tokenizer_bytes, preprocessor_bytes = None, None + if files: + model_path = model_size_or_path + tokenizer_bytes = files.pop("tokenizer.json", None) + preprocessor_bytes = files.pop("preprocessor_config.json", None) + elif os.path.isdir(model_size_or_path): + model_path = model_size_or_path + else: + model_path = download_model( + model_size_or_path, + local_files_only=local_files_only, + cache_dir=download_root, + ) + + self.model = ctranslate2.models.Whisper( + model_path, + device=device, + device_index=device_index, + compute_type=compute_type, + intra_threads=cpu_threads, + inter_threads=num_workers, + files=files, + **model_kwargs, + ) + + tokenizer_file = os.path.join(model_path, "tokenizer.json") + if tokenizer_bytes: + self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes) + elif os.path.isfile(tokenizer_file): + self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file) + else: + self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained( + "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en") + ) + self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes) + self.feature_extractor = FeatureExtractor(**self.feat_kwargs) + self.input_stride = 2 + self.num_samples_per_token = ( + self.feature_extractor.hop_length * self.input_stride + ) + self.frames_per_second = ( + self.feature_extractor.sampling_rate // self.feature_extractor.hop_length + ) + self.tokens_per_second = ( + self.feature_extractor.sampling_rate // self.num_samples_per_token + ) + self.time_precision = 0.02 + self.max_length = 448 + + @property + def supported_languages(self) -> List[str]: + """The languages supported by the model.""" + return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"] + + def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict: + config = {} + try: + config_path = os.path.join(model_path, "preprocessor_config.json") + if preprocessor_bytes: + config = json.loads(preprocessor_bytes) + elif os.path.isfile(config_path): + with open(config_path, 
"r", encoding="utf-8") as file: + config = json.load(file) + else: + return config + valid_keys = signature(FeatureExtractor.__init__).parameters.keys() + return {k: v for k, v in config.items() if k in valid_keys} + except json.JSONDecodeError as e: + self.logger.warning("Could not load preprocessor config: %s", e) + + return config + + def transcribe( + self, + audio: Union[str, BinaryIO, np.ndarray], + language: Optional[str] = None, + task: str = "transcribe", + log_progress: bool = False, + beam_size: int = 5, + best_of: int = 5, + patience: float = 1, + length_penalty: float = 1, + repetition_penalty: float = 1, + no_repeat_ngram_size: int = 0, + temperature: Union[float, List[float], Tuple[float, ...]] = [ + 0.0, + 0.2, + 0.4, + 0.6, + 0.8, + 1.0, + ], + compression_ratio_threshold: Optional[float] = 2.4, + log_prob_threshold: Optional[float] = -1.0, + no_speech_threshold: Optional[float] = 0.6, + condition_on_previous_text: bool = True, + prompt_reset_on_temperature: float = 0.5, + initial_prompt: Optional[Union[str, Iterable[int]]] = None, + prefix: Optional[str] = None, + suppress_blank: bool = True, + suppress_tokens: Optional[List[int]] = [-1], + without_timestamps: bool = False, + max_initial_timestamp: float = 1.0, + word_timestamps: bool = False, + prepend_punctuations: str = "\"'“¿([{-", + append_punctuations: str = "\"'.。,,!!??::”)]}、", + multilingual: bool = False, + vad_filter: bool = False, + vad_parameters: Optional[Union[dict, VadOptions]] = None, + max_new_tokens: Optional[int] = None, + chunk_length: Optional[int] = None, + clip_timestamps: Union[str, List[float]] = "0", + hallucination_silence_threshold: Optional[float] = None, + hotwords: Optional[str] = None, + language_detection_threshold: Optional[float] = 0.5, + language_detection_segments: int = 1, + ) -> Tuple[Iterable[Segment], TranscriptionInfo]: + """Transcribes an input file. + + Arguments: + audio: Path to the input file (or a file-like object), or the audio waveform. + language: The language spoken in the audio. It should be a language code such + as "en" or "fr". If not set, the language will be detected in the first 30 seconds + of audio. + task: Task to execute (transcribe or translate). + log_progress: whether to show progress bar or not. + beam_size: Beam size to use for decoding. + best_of: Number of candidates when sampling with non-zero temperature. + patience: Beam search patience factor. + length_penalty: Exponential length penalty constant. + repetition_penalty: Penalty applied to the score of previously generated tokens + (set > 1 to penalize). + no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable). + temperature: Temperature for sampling. It can be a tuple of temperatures, + which will be successively used upon failures according to either + `compression_ratio_threshold` or `log_prob_threshold`. + compression_ratio_threshold: If the gzip compression ratio is above this value, + treat as failed. + log_prob_threshold: If the average log probability over sampled tokens is + below this value, treat as failed. + no_speech_threshold: If the no_speech probability is higher than this value AND + the average log probability over sampled tokens is below `log_prob_threshold`, + consider the segment as silent. 
+ condition_on_previous_text: If True, the previous output of the model is provided + as a prompt for the next window; disabling may make the text inconsistent across + windows, but the model becomes less prone to getting stuck in a failure loop, + such as repetition looping or timestamps going out of sync. + prompt_reset_on_temperature: Resets prompt if temperature is above this value. + Arg has effect only if condition_on_previous_text is True. + initial_prompt: Optional text string or iterable of token ids to provide as a + prompt for the first window. + prefix: Optional text to provide as a prefix for the first window. + suppress_blank: Suppress blank outputs at the beginning of the sampling. + suppress_tokens: List of token IDs to suppress. -1 will suppress a default set + of symbols as defined in `tokenizer.non_speech_tokens()`. + without_timestamps: Only sample text tokens. + max_initial_timestamp: The initial timestamp cannot be later than this. + word_timestamps: Extract word-level timestamps using the cross-attention pattern + and dynamic time warping, and include the timestamps for each word in each segment. + prepend_punctuations: If word_timestamps is True, merge these punctuation symbols + with the next word + append_punctuations: If word_timestamps is True, merge these punctuation symbols + with the previous word + multilingual: Perform language detection on every segment. + vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio + without speech. This step is using the Silero VAD model + https://github.com/snakers4/silero-vad. + vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available + parameters and default values in the class `VadOptions`). + max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set, + the maximum will be set by the default max_length. + chunk_length: The length of audio segments. If it is not None, it will overwrite the + default chunk_length of the FeatureExtractor. + clip_timestamps: + Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to + process. The last end timestamp defaults to the end of the file. + vad_filter will be ignored if clip_timestamps is used. + hallucination_silence_threshold: + When word_timestamps is True, skip silent periods longer than this threshold + (in seconds) when a possible hallucination is detected + hotwords: + Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None. + language_detection_threshold: If the maximum probability of the language tokens is higher + than this value, the language is detected. + language_detection_segments: Number of segments to consider for the language detection. + Returns: + A tuple with: + + - a generator over transcribed segments + - an instance of TranscriptionInfo + """ + sampling_rate = self.feature_extractor.sampling_rate + + if multilingual and not self.model.is_multilingual: + self.logger.warning( + "The current model is English-only but the multilingual parameter is set to" + "True; setting to False instead." 
+ ) + multilingual = False + + if not isinstance(audio, np.ndarray): + audio = decode_audio(audio, sampling_rate=sampling_rate) + + duration = audio.shape[0] / sampling_rate + duration_after_vad = duration + + self.logger.info( + "Processing audio with duration %s", format_timestamp(duration) + ) + + if vad_filter and clip_timestamps == "0": + if vad_parameters is None: + vad_parameters = VadOptions() + elif isinstance(vad_parameters, dict): + vad_parameters = VadOptions(**vad_parameters) + speech_chunks = get_speech_timestamps(audio, vad_parameters) + audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks) + audio = np.concatenate(audio_chunks, axis=0) + duration_after_vad = audio.shape[0] / sampling_rate + + self.logger.info( + "VAD filter removed %s of audio", + format_timestamp(duration - duration_after_vad), + ) + + if self.logger.isEnabledFor(logging.DEBUG): + self.logger.debug( + "VAD filter kept the following audio segments: %s", + ", ".join( + "[%s -> %s]" + % ( + format_timestamp(chunk["start"] / sampling_rate), + format_timestamp(chunk["end"] / sampling_rate), + ) + for chunk in speech_chunks + ), + ) + + else: + speech_chunks = None + if audio.shape[0] == 0: + return None, None + features = self.feature_extractor(audio, chunk_length=chunk_length) + + encoder_output = None + all_language_probs = None + + # detecting the language if not provided + if language is None: + if not self.model.is_multilingual: + language = "en" + language_probability = 1 + else: + start_timestamp = ( + float(clip_timestamps.split(",")[0]) + if isinstance(clip_timestamps, str) + else clip_timestamps[0] + ) + content_frames = features.shape[-1] - 1 + seek = ( + int(start_timestamp * self.frames_per_second) + if start_timestamp * self.frames_per_second < content_frames + else 0 + ) + ( + language, + language_probability, + all_language_probs, + ) = self.detect_language( + features=features[..., seek:], + language_detection_segments=language_detection_segments, + language_detection_threshold=language_detection_threshold, + ) + + self.logger.info( + "Detected language '%s' with probability %.2f", + language, + language_probability, + ) + else: + if not self.model.is_multilingual and language != "en": + self.logger.warning( + "The current model is English-only but the language parameter is set to '%s'; " + "using 'en' instead." 
% language + ) + language = "en" + + language_probability = 1 + + tokenizer = Tokenizer( + self.hf_tokenizer, + self.model.is_multilingual, + task=task, + language=language, + ) + + options = TranscriptionOptions( + beam_size=beam_size, + best_of=best_of, + patience=patience, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + log_prob_threshold=log_prob_threshold, + no_speech_threshold=no_speech_threshold, + compression_ratio_threshold=compression_ratio_threshold, + condition_on_previous_text=condition_on_previous_text, + prompt_reset_on_temperature=prompt_reset_on_temperature, + temperatures=( + temperature if isinstance(temperature, (list, tuple)) else [temperature] + ), + initial_prompt=initial_prompt, + prefix=prefix, + suppress_blank=suppress_blank, + suppress_tokens=( + get_suppressed_tokens(tokenizer, suppress_tokens) + if suppress_tokens + else suppress_tokens + ), + without_timestamps=without_timestamps, + max_initial_timestamp=max_initial_timestamp, + word_timestamps=word_timestamps, + prepend_punctuations=prepend_punctuations, + append_punctuations=append_punctuations, + multilingual=multilingual, + max_new_tokens=max_new_tokens, + clip_timestamps=clip_timestamps, + hallucination_silence_threshold=hallucination_silence_threshold, + hotwords=hotwords, + ) + + segments = self.generate_segments( + features, tokenizer, options, log_progress, encoder_output + ) + + if speech_chunks: + segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate) + + info = TranscriptionInfo( + language=language, + language_probability=language_probability, + duration=duration, + duration_after_vad=duration_after_vad, + transcription_options=options, + vad_options=vad_parameters, + all_language_probs=all_language_probs, + ) + + return segments, info + + def _split_segments_by_timestamps( + self, + tokenizer: Tokenizer, + tokens: List[int], + time_offset: float, + segment_size: int, + segment_duration: float, + seek: int, + ) -> List[List[int]]: + current_segments = [] + single_timestamp_ending = ( + len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1] + ) + + consecutive_timestamps = [ + i + for i in range(len(tokens)) + if i > 0 + and tokens[i] >= tokenizer.timestamp_begin + and tokens[i - 1] >= tokenizer.timestamp_begin + ] + + if len(consecutive_timestamps) > 0: + slices = list(consecutive_timestamps) + if single_timestamp_ending: + slices.append(len(tokens)) + + last_slice = 0 + for current_slice in slices: + sliced_tokens = tokens[last_slice:current_slice] + start_timestamp_position = sliced_tokens[0] - tokenizer.timestamp_begin + end_timestamp_position = sliced_tokens[-1] - tokenizer.timestamp_begin + start_time = ( + time_offset + start_timestamp_position * self.time_precision + ) + end_time = time_offset + end_timestamp_position * self.time_precision + + current_segments.append( + dict( + seek=seek, + start=start_time, + end=end_time, + tokens=sliced_tokens, + ) + ) + last_slice = current_slice + + if single_timestamp_ending: + # single timestamp at the end means no speech after the last timestamp. 
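+                # For intuition (illustrative numbers): seek and segment_size count mel
+                # frames (frames_per_second = 100 at 16 kHz with a 160-sample hop), while
+                # timestamp tokens advance in time_precision = 0.02 s steps, i.e.
+                # input_stride = 2 frames per step, so a timestamp 150 steps past
+                # timestamp_begin marks 3.0 s and, in the else branch below, advances
+                # seek by 150 * 2 = 300 frames.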
+ seek += segment_size + else: + # otherwise, ignore the unfinished segment and seek to the last timestamp + last_timestamp_position = ( + tokens[last_slice - 1] - tokenizer.timestamp_begin + ) + seek += last_timestamp_position * self.input_stride + + else: + duration = segment_duration + timestamps = [ + token for token in tokens if token >= tokenizer.timestamp_begin + ] + if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin: + last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin + duration = last_timestamp_position * self.time_precision + + current_segments.append( + dict( + seek=seek, + start=time_offset, + end=time_offset + duration, + tokens=tokens, + ) + ) + + seek += segment_size + + return current_segments, seek, single_timestamp_ending + + def generate_segments( + self, + features: np.ndarray, + tokenizer: Tokenizer, + options: TranscriptionOptions, + log_progress, + encoder_output: Optional[ctranslate2.StorageView] = None, + ) -> Iterable[Segment]: + content_frames = features.shape[-1] - 1 + content_duration = float(content_frames * self.feature_extractor.time_per_frame) + + if isinstance(options.clip_timestamps, str): + options.clip_timestamps = [ + float(ts) + for ts in ( + options.clip_timestamps.split(",") + if options.clip_timestamps + else [] + ) + ] + + seek_points: List[int] = [ + round(ts * self.frames_per_second) for ts in options.clip_timestamps + ] + if len(seek_points) == 0: + seek_points.append(0) + if len(seek_points) % 2 == 1: + seek_points.append(content_frames) + seek_clips: List[Tuple[int, int]] = list( + zip(seek_points[::2], seek_points[1::2]) + ) + + punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、" + + idx = 0 + clip_idx = 0 + seek = seek_clips[clip_idx][0] + all_tokens = [] + prompt_reset_since = 0 + + if options.initial_prompt is not None: + if isinstance(options.initial_prompt, str): + initial_prompt = " " + options.initial_prompt.strip() + initial_prompt_tokens = tokenizer.encode(initial_prompt) + all_tokens.extend(initial_prompt_tokens) + else: + all_tokens.extend(options.initial_prompt) + + pbar = tqdm(total=content_duration, unit="seconds", disable=not log_progress) + last_speech_timestamp = 0.0 + all_segments = [] + # NOTE: This loop is obscurely flattened to make the diff readable. + # A later commit should turn this into a simpler nested loop. 
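+        # Worked example of the clip bookkeeping used by the loop below (illustrative
+        # values): clip_timestamps="0,5.0,10.0" becomes seek_points [0, 500, 1000] at
+        # 100 frames per second, is padded to [0, 500, 1000, content_frames], and
+        # yields seek_clips [(0, 500), (1000, content_frames)].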
+ # for seek_clip_start, seek_clip_end in seek_clips: + # while seek < seek_clip_end + while clip_idx < len(seek_clips): + seek_clip_start, seek_clip_end = seek_clips[clip_idx] + if seek_clip_end > content_frames: + seek_clip_end = content_frames + if seek < seek_clip_start: + seek = seek_clip_start + if seek >= seek_clip_end: + clip_idx += 1 + if clip_idx < len(seek_clips): + seek = seek_clips[clip_idx][0] + continue + time_offset = seek * self.feature_extractor.time_per_frame + window_end_time = float( + (seek + self.feature_extractor.nb_max_frames) + * self.feature_extractor.time_per_frame + ) + segment_size = min( + self.feature_extractor.nb_max_frames, + content_frames - seek, + seek_clip_end - seek, + ) + segment = features[:, seek : seek + segment_size] + segment_duration = segment_size * self.feature_extractor.time_per_frame + segment = pad_or_trim(segment) + + if self.logger.isEnabledFor(logging.DEBUG): + self.logger.debug( + "Processing segment at %s", format_timestamp(time_offset) + ) + + previous_tokens = all_tokens[prompt_reset_since:] + + if seek > 0 or encoder_output is None: + encoder_output = self.encode(segment) + + if options.multilingual: + results = self.model.detect_language(encoder_output) + language_token, language_probability = results[0][0] + language = language_token[2:-2] + + tokenizer.language = tokenizer.tokenizer.token_to_id(language_token) + tokenizer.language_code = language + + prompt = self.get_prompt( + tokenizer, + previous_tokens, + without_timestamps=options.without_timestamps, + prefix=options.prefix if seek == 0 else None, + hotwords=options.hotwords, + ) + + ( + result, + avg_logprob, + temperature, + compression_ratio, + ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options) + + if options.no_speech_threshold is not None: + # no voice activity check + should_skip = result.no_speech_prob > options.no_speech_threshold + + if ( + options.log_prob_threshold is not None + and avg_logprob > options.log_prob_threshold + ): + # don't skip if the logprob is high enough, despite the no_speech_prob + should_skip = False + + if should_skip: + self.logger.debug( + "No speech threshold is met (%f > %f)", + result.no_speech_prob, + options.no_speech_threshold, + ) + + # fast-forward to the next segment boundary + seek += segment_size + continue + + tokens = result.sequences_ids[0] + + previous_seek = seek + + # anomalous words are very long/short/improbable + def word_anomaly_score(word: dict) -> float: + probability = word.get("probability", 0.0) + duration = word["end"] - word["start"] + score = 0.0 + if probability < 0.15: + score += 1.0 + if duration < 0.133: + score += (0.133 - duration) * 15 + if duration > 2.0: + score += duration - 2.0 + return score + + def is_segment_anomaly(segment: Optional[dict]) -> bool: + if segment is None or not segment["words"]: + return False + words = [w for w in segment["words"] if w["word"] not in punctuation] + words = words[:8] + score = sum(word_anomaly_score(w) for w in words) + return score >= 3 or score + 0.01 >= len(words) + + def next_words_segment(segments: List[dict]) -> Optional[dict]: + return next((s for s in segments if s["words"]), None) + + ( + current_segments, + seek, + single_timestamp_ending, + ) = self._split_segments_by_timestamps( + tokenizer=tokenizer, + tokens=tokens, + time_offset=time_offset, + segment_size=segment_size, + segment_duration=segment_duration, + seek=seek, + ) + + if options.word_timestamps: + self.add_word_timestamps( + [current_segments], + tokenizer, + 
encoder_output, + segment_size, + options.prepend_punctuations, + options.append_punctuations, + last_speech_timestamp=last_speech_timestamp, + ) + if not single_timestamp_ending: + last_word_end = get_end(current_segments) + if last_word_end is not None and last_word_end > time_offset: + seek = round(last_word_end * self.frames_per_second) + + # skip silence before possible hallucinations + if options.hallucination_silence_threshold is not None: + threshold = options.hallucination_silence_threshold + + # if first segment might be a hallucination, skip leading silence + first_segment = next_words_segment(current_segments) + if first_segment is not None and is_segment_anomaly(first_segment): + gap = first_segment["start"] - time_offset + if gap > threshold: + seek = previous_seek + round(gap * self.frames_per_second) + continue + + # skip silence before any possible hallucination that is surrounded + # by silence or more hallucinations + hal_last_end = last_speech_timestamp + for si in range(len(current_segments)): + segment = current_segments[si] + if not segment["words"]: + continue + if is_segment_anomaly(segment): + next_segment = next_words_segment( + current_segments[si + 1 :] + ) + if next_segment is not None: + hal_next_start = next_segment["words"][0]["start"] + else: + hal_next_start = time_offset + segment_duration + silence_before = ( + segment["start"] - hal_last_end > threshold + or segment["start"] < threshold + or segment["start"] - time_offset < 2.0 + ) + silence_after = ( + hal_next_start - segment["end"] > threshold + or is_segment_anomaly(next_segment) + or window_end_time - segment["end"] < 2.0 + ) + if silence_before and silence_after: + seek = round( + max(time_offset + 1, segment["start"]) + * self.frames_per_second + ) + if content_duration - segment["end"] < threshold: + seek = content_frames + current_segments[si:] = [] + break + hal_last_end = segment["end"] + + last_word_end = get_end(current_segments) + if last_word_end is not None: + last_speech_timestamp = last_word_end + for segment in current_segments: + tokens = segment["tokens"] + text = tokenizer.decode(tokens) + + if segment["start"] == segment["end"] or not text.strip(): + continue + + all_tokens.extend(tokens) + idx += 1 + + all_segments.append(Segment( + id=idx, + seek=previous_seek, + start=segment["start"], + end=segment["end"], + text=text, + tokens=tokens, + temperature=temperature, + avg_logprob=avg_logprob, + compression_ratio=compression_ratio, + no_speech_prob=result.no_speech_prob, + words=( + [Word(**word) for word in segment["words"]] + if options.word_timestamps + else None + ), + )) + + if ( + not options.condition_on_previous_text + or temperature > options.prompt_reset_on_temperature + ): + if options.condition_on_previous_text: + self.logger.debug( + "Reset prompt. prompt_reset_on_temperature threshold is met %f > %f", + temperature, + options.prompt_reset_on_temperature, + ) + + prompt_reset_since = len(all_tokens) + + pbar.update( + (min(content_frames, seek) - previous_seek) + * self.feature_extractor.time_per_frame, + ) + pbar.close() + return all_segments + + def encode(self, features: np.ndarray) -> ctranslate2.StorageView: + # When the model is running on multiple GPUs, the encoder output should be moved + # to the CPU since we don't know which GPU will handle the next job. 
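+        # For example, a model loaded with device="cuda" and device_index=[0, 1]
+        # (see the __init__ docstring) triggers this copy; a single-GPU or CPU model
+        # keeps the encoder output where it is.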
+ to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 + + if features.ndim == 2: + features = np.expand_dims(features, 0) + features = get_ctranslate2_storage(features) + + return self.model.encode(features, to_cpu=to_cpu) + + def generate_with_fallback( + self, + encoder_output: ctranslate2.StorageView, + prompt: List[int], + tokenizer: Tokenizer, + options: TranscriptionOptions, + ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]: + decode_result = None + all_results = [] + below_cr_threshold_results = [] + + max_initial_timestamp_index = int( + round(options.max_initial_timestamp / self.time_precision) + ) + if options.max_new_tokens is not None: + max_length = len(prompt) + options.max_new_tokens + else: + max_length = self.max_length + + if max_length > self.max_length: + raise ValueError( + f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` " + f"{max_length - len(prompt)}. Thus, the combined length of the prompt " + f"and `max_new_tokens` is: {max_length}. This exceeds the " + f"`max_length` of the Whisper model: {self.max_length}. " + "You should either reduce the length of your prompt, or " + "reduce the value of `max_new_tokens`, " + f"so that their combined length is less that {self.max_length}." + ) + + for temperature in options.temperatures: + if temperature > 0: + kwargs = { + "beam_size": 1, + "num_hypotheses": options.best_of, + "sampling_topk": 0, + "sampling_temperature": temperature, + } + else: + kwargs = { + "beam_size": options.beam_size, + "patience": options.patience, + } + + result = self.model.generate( + encoder_output, + [prompt], + length_penalty=options.length_penalty, + repetition_penalty=options.repetition_penalty, + no_repeat_ngram_size=options.no_repeat_ngram_size, + max_length=max_length, + return_scores=True, + return_no_speech_prob=True, + suppress_blank=options.suppress_blank, + suppress_tokens=options.suppress_tokens, + max_initial_timestamp_index=max_initial_timestamp_index, + **kwargs, + )[0] + + tokens = result.sequences_ids[0] + + # Recover the average log prob from the returned score. 
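+            # CTranslate2 reports a length-normalized score, so the lines below undo the
+            # normalization (cum_logprob = score * seq_len**length_penalty) and average
+            # over seq_len + 1 tokens, the +1 roughly accounting for the end-of-text
+            # token. Illustrative numbers: score = -0.25, seq_len = 20 and
+            # length_penalty = 1 give cum_logprob = -5.0 and avg_logprob ≈ -0.24.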
+ seq_len = len(tokens) + cum_logprob = result.scores[0] * (seq_len**options.length_penalty) + avg_logprob = cum_logprob / (seq_len + 1) + + text = tokenizer.decode(tokens).strip() + compression_ratio = get_compression_ratio(text) + + decode_result = ( + result, + avg_logprob, + temperature, + compression_ratio, + ) + all_results.append(decode_result) + + needs_fallback = False + + if options.compression_ratio_threshold is not None: + if compression_ratio > options.compression_ratio_threshold: + needs_fallback = True # too repetitive + + self.logger.debug( + "Compression ratio threshold is not met with temperature %.1f (%f > %f)", + temperature, + compression_ratio, + options.compression_ratio_threshold, + ) + else: + below_cr_threshold_results.append(decode_result) + + if ( + options.log_prob_threshold is not None + and avg_logprob < options.log_prob_threshold + ): + needs_fallback = True # average log probability is too low + + self.logger.debug( + "Log probability threshold is not met with temperature %.1f (%f < %f)", + temperature, + avg_logprob, + options.log_prob_threshold, + ) + + if ( + options.no_speech_threshold is not None + and result.no_speech_prob > options.no_speech_threshold + and options.log_prob_threshold is not None + and avg_logprob < options.log_prob_threshold + ): + needs_fallback = False # silence + + if not needs_fallback: + break + else: + # all failed, select the result with the highest average log probability + decode_result = max( + below_cr_threshold_results or all_results, key=lambda x: x[1] + ) + # to pass final temperature for prompt_reset_on_temperature + decode_result = ( + decode_result[0], + decode_result[1], + temperature, + decode_result[3], + ) + + return decode_result + + def get_prompt( + self, + tokenizer: Tokenizer, + previous_tokens: List[int], + without_timestamps: bool = False, + prefix: Optional[str] = None, + hotwords: Optional[str] = None, + ) -> List[int]: + prompt = [] + + if previous_tokens or (hotwords and not prefix): + prompt.append(tokenizer.sot_prev) + if hotwords and not prefix: + hotwords_tokens = tokenizer.encode(" " + hotwords.strip()) + if len(hotwords_tokens) >= self.max_length // 2: + hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1] + prompt.extend(hotwords_tokens) + if previous_tokens: + prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :]) + + prompt.extend(tokenizer.sot_sequence) + + if without_timestamps: + prompt.append(tokenizer.no_timestamps) + + if prefix: + prefix_tokens = tokenizer.encode(" " + prefix.strip()) + if len(prefix_tokens) >= self.max_length // 2: + prefix_tokens = prefix_tokens[: self.max_length // 2 - 1] + if not without_timestamps: + prompt.append(tokenizer.timestamp_begin) + prompt.extend(prefix_tokens) + + return prompt + + def add_word_timestamps( + self, + segments: List[dict], + tokenizer: Tokenizer, + encoder_output: ctranslate2.StorageView, + num_frames: int, + prepend_punctuations: str, + append_punctuations: str, + last_speech_timestamp: float, + ) -> float: + if len(segments) == 0: + return + + text_tokens = [] + text_tokens_per_segment = [] + for segment in segments: + segment_tokens = [ + [token for token in subsegment["tokens"] if token < tokenizer.eot] + for subsegment in segment + ] + text_tokens.append(list(itertools.chain.from_iterable(segment_tokens))) + text_tokens_per_segment.append(segment_tokens) + + alignments = self.find_alignment( + tokenizer, text_tokens, encoder_output, num_frames + ) + median_max_durations = [] + for alignment in alignments: + 
word_durations = np.array( + [word["end"] - word["start"] for word in alignment] + ) + word_durations = word_durations[word_durations.nonzero()] + median_duration = ( + np.median(word_durations) if len(word_durations) > 0 else 0.0 + ) + median_duration = min(0.7, float(median_duration)) + max_duration = median_duration * 2 + + # hack: truncate long words at sentence boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(word_durations) > 0: + sentence_end_marks = ".。!!??" + # ensure words at sentence boundaries + # are not longer than twice the median word duration. + for i in range(1, len(alignment)): + if alignment[i]["end"] - alignment[i]["start"] > max_duration: + if alignment[i]["word"] in sentence_end_marks: + alignment[i]["end"] = alignment[i]["start"] + max_duration + elif alignment[i - 1]["word"] in sentence_end_marks: + alignment[i]["start"] = alignment[i]["end"] - max_duration + + merge_punctuations(alignment, prepend_punctuations, append_punctuations) + median_max_durations.append((median_duration, max_duration)) + + for segment_idx, segment in enumerate(segments): + word_index = 0 + time_offset = segment[0]["seek"] / self.frames_per_second + median_duration, max_duration = median_max_durations[segment_idx] + for subsegment_idx, subsegment in enumerate(segment): + saved_tokens = 0 + words = [] + + while word_index < len(alignments[segment_idx]) and saved_tokens < len( + text_tokens_per_segment[segment_idx][subsegment_idx] + ): + timing = alignments[segment_idx][word_index] + + if timing["word"]: + words.append( + dict( + word=timing["word"], + start=round(time_offset + timing["start"], 2), + end=round(time_offset + timing["end"], 2), + probability=timing["probability"], + ) + ) + + saved_tokens += len(timing["tokens"]) + word_index += 1 + + # hack: truncate long words at segment boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(words) > 0: + # ensure the first and second word after a pause is not longer than + # twice the median word duration. + if words[0][ + "end" + ] - last_speech_timestamp > median_duration * 4 and ( + words[0]["end"] - words[0]["start"] > max_duration + or ( + len(words) > 1 + and words[1]["end"] - words[0]["start"] > max_duration * 2 + ) + ): + if ( + len(words) > 1 + and words[1]["end"] - words[1]["start"] > max_duration + ): + boundary = max( + words[1]["end"] / 2, words[1]["end"] - max_duration + ) + words[0]["end"] = words[1]["start"] = boundary + words[0]["start"] = max(0, words[0]["end"] - max_duration) + + # prefer the segment-level start timestamp if the first word is too long. + if ( + subsegment["start"] < words[0]["end"] + and subsegment["start"] - 0.5 > words[0]["start"] + ): + words[0]["start"] = max( + 0, + min(words[0]["end"] - median_duration, subsegment["start"]), + ) + else: + subsegment["start"] = words[0]["start"] + + # prefer the segment-level end timestamp if the last word is too long. 
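+                    # e.g. (illustrative) a subsegment ending at 7.2 s whose last word is
+                    # timed 6.0-9.5 s has that word's end pulled back to 7.2 s, since
+                    # max(6.0 + median_duration, 7.2) = 7.2 for a 0.3 s median duration.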
+ if ( + subsegment["end"] > words[-1]["start"] + and subsegment["end"] + 0.5 < words[-1]["end"] + ): + words[-1]["end"] = max( + words[-1]["start"] + median_duration, subsegment["end"] + ) + else: + subsegment["end"] = words[-1]["end"] + + last_speech_timestamp = subsegment["end"] + segments[segment_idx][subsegment_idx]["words"] = words + return last_speech_timestamp + + def find_alignment( + self, + tokenizer: Tokenizer, + text_tokens: List[int], + encoder_output: ctranslate2.StorageView, + num_frames: int, + median_filter_width: int = 7, + ) -> List[dict]: + if len(text_tokens) == 0: + return [] + + results = self.model.align( + encoder_output, + tokenizer.sot_sequence, + text_tokens, + num_frames, + median_filter_width=median_filter_width, + ) + return_list = [] + for result, text_token in zip(results, text_tokens): + text_token_probs = result.text_token_probs + alignments = result.alignments + text_indices = np.array([pair[0] for pair in alignments]) + time_indices = np.array([pair[1] for pair in alignments]) + + words, word_tokens = tokenizer.split_to_word_tokens( + text_token + [tokenizer.eot] + ) + if len(word_tokens) <= 1: + # return on eot only + # >>> np.pad([], (1, 0)) + # array([0.]) + # This results in crashes when we lookup jump_times with float, like + # IndexError: arrays used as indices must be of integer (or boolean) type + return_list.append([]) + continue + word_boundaries = np.pad( + np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0) + ) + if len(word_boundaries) <= 1: + return_list.append([]) + continue + + jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype( + bool + ) + jump_times = time_indices[jumps] / self.tokens_per_second + start_times = jump_times[word_boundaries[:-1]] + end_times = jump_times[word_boundaries[1:]] + word_probabilities = [ + np.mean(text_token_probs[i:j]) + for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) + ] + + return_list.append( + [ + dict( + word=word, + tokens=tokens, + start=start, + end=end, + probability=probability, + ) + for word, tokens, start, end, probability in zip( + words, word_tokens, start_times, end_times, word_probabilities + ) + ] + ) + return return_list + + def detect_language( + self, + audio: Optional[np.ndarray] = None, + features: Optional[np.ndarray] = None, + vad_filter: bool = False, + vad_parameters: Union[dict, VadOptions] = None, + language_detection_segments: int = 1, + language_detection_threshold: float = 0.5, + ) -> Tuple[str, float, List[Tuple[str, float]]]: + """ + Use Whisper to detect the language of the input audio or features. + + Arguments: + audio: Input audio signal, must be a 1D float array sampled at 16khz. + features: Input Mel spectrogram features, must be a float array with + shape (n_mels, n_frames), if `audio` is provided, the features will be ignored. + Either `audio` or `features` must be provided. + vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio + without speech. This step is using the Silero VAD model. + vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available + parameters and default values in the class `VadOptions`). + language_detection_threshold: If the maximum probability of the language tokens is + higher than this value, the language is detected. + language_detection_segments: Number of segments to consider for the language detection. + + Returns: + language: Detected language. + languege_probability: Probability of the detected language. 
+ all_language_probs: List of tuples with all language names and probabilities. + """ + assert ( + audio is not None or features is not None + ), "Either `audio` or `features` must be provided." + + if audio is not None: + if vad_filter: + speech_chunks = get_speech_timestamps(audio, vad_parameters) + audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks) + audio = np.concatenate(audio_chunks, axis=0) + + audio = audio[ + : language_detection_segments * self.feature_extractor.n_samples + ] + features = self.feature_extractor(audio) + + features = features[ + ..., : language_detection_segments * self.feature_extractor.nb_max_frames + ] + + detected_language_info = {} + for i in range(0, features.shape[-1], self.feature_extractor.nb_max_frames): + encoder_output = self.encode( + pad_or_trim(features[..., i : i + self.feature_extractor.nb_max_frames]) + ) + # results is a list of tuple[str, float] with language names and probabilities. + results = self.model.detect_language(encoder_output)[0] + + # Parse language names to strip out markers + all_language_probs = [(token[2:-2], prob) for (token, prob) in results] + # Get top language token and probability + language, language_probability = all_language_probs[0] + if language_probability > language_detection_threshold: + break + detected_language_info.setdefault(language, []).append(language_probability) + else: + # If no language detected for all segments, the majority vote of the highest + # projected languages for all segments is used to determine the language. + language = max( + detected_language_info, + key=lambda lang: len(detected_language_info[lang]), + ) + language_probability = max(detected_language_info[language]) + + return language, language_probability, all_language_probs + + +def restore_speech_timestamps( + segments: Iterable[Segment], + speech_chunks: List[dict], + sampling_rate: int, +) -> Iterable[Segment]: + ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate) + + for segment in segments: + if segment.words: + words = [] + for word in segment.words: + # Ensure the word start and end times are resolved to the same chunk. 
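+                # The chunk is chosen from the word's midpoint so both edges get the same
+                # offset back to the original timeline; roughly (illustrative figures), a
+                # word at 3.8-4.1 s in the VAD-trimmed audio whose chunk originally began
+                # 2.0 s later maps back to about 5.8-6.1 s in the source audio.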
+ middle = (word.start + word.end) / 2 + chunk_index = ts_map.get_chunk_index(middle) + word.start = ts_map.get_original_time(word.start, chunk_index) + word.end = ts_map.get_original_time(word.end, chunk_index) + words.append(word) + + segment.start = words[0].start + segment.end = words[-1].end + segment.words = words + + else: + segment.start = ts_map.get_original_time(segment.start) + segment.end = ts_map.get_original_time(segment.end) + return segments + + +def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView: + segment = np.ascontiguousarray(segment) + segment = ctranslate2.StorageView.from_array(segment) + return segment + + +def get_compression_ratio(text: str) -> float: + text_bytes = text.encode("utf-8") + return len(text_bytes) / len(zlib.compress(text_bytes)) + + +def get_suppressed_tokens( + tokenizer: Tokenizer, + suppress_tokens: Tuple[int], +) -> Optional[List[int]]: + if -1 in suppress_tokens: + suppress_tokens = [t for t in suppress_tokens if t >= 0] + suppress_tokens.extend(tokenizer.non_speech_tokens) + elif suppress_tokens is None or len(suppress_tokens) == 0: + suppress_tokens = [] # interpret empty string as an empty list + else: + assert isinstance(suppress_tokens, list), "suppress_tokens must be a list" + + suppress_tokens.extend( + [ + tokenizer.transcribe, + tokenizer.translate, + tokenizer.sot, + tokenizer.sot_prev, + tokenizer.sot_lm, + ] + ) + + return tuple(sorted(set(suppress_tokens))) + + +def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None: + # merge prepended punctuations + i = len(alignment) - 2 + j = len(alignment) - 1 + while i >= 0: + previous = alignment[i] + following = alignment[j] + if previous["word"].startswith(" ") and previous["word"].strip() in prepended: + # prepend it to the following word + following["word"] = previous["word"] + following["word"] + following["tokens"] = previous["tokens"] + following["tokens"] + previous["word"] = "" + previous["tokens"] = [] + else: + j = i + i -= 1 + + # merge appended punctuations + i = 0 + j = 1 + while j < len(alignment): + previous = alignment[i] + following = alignment[j] + if not previous["word"].endswith(" ") and following["word"] in appended: + # append it to the previous word + previous["word"] = previous["word"] + following["word"] + previous["tokens"] = previous["tokens"] + following["tokens"] + following["word"] = "" + following["tokens"] = [] + else: + i = j + j += 1 \ No newline at end of file diff --git a/whisper_live/transcriber_tensorrt.py b/whisper_live/transcriber_tensorrt.py new file mode 100644 index 0000000..aaa8cc1 --- /dev/null +++ b/whisper_live/transcriber_tensorrt.py @@ -0,0 +1,320 @@ +import json +import re +from collections import OrderedDict +from pathlib import Path +from typing import Union + +import torch +import numpy as np +import torch.nn.functional as F +from whisper.tokenizer import get_tokenizer +from whisper_live.tensorrt_utils import (mel_filters, load_audio_wav_format, pad_or_trim, load_audio) + +import tensorrt_llm +import tensorrt_llm.logger as logger +from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt, + trt_dtype_to_torch) +from tensorrt_llm.runtime import ModelConfig, SamplingConfig +from tensorrt_llm.runtime.session import Session, TensorInfo + + +SAMPLE_RATE = 16000 +N_FFT = 400 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk + + +class WhisperEncoding: + + def __init__(self, engine_dir): + self.session = 
self.get_session(engine_dir) + + def get_session(self, engine_dir): + config_path = engine_dir / 'encoder_config.json' + with open(config_path, 'r') as f: + config = json.load(f) + + dtype = config['builder_config']['precision'] + n_mels = config['builder_config']['n_mels'] + num_languages = config['builder_config']['num_languages'] + + self.dtype = dtype + self.n_mels = n_mels + self.num_languages = num_languages + + serialize_path = engine_dir / f'whisper_encoder_{self.dtype}_tp1_rank0.engine' + + with open(serialize_path, 'rb') as f: + session = Session.from_serialized_engine(f.read()) + + return session + + def get_audio_features(self, mel): + inputs = OrderedDict() + output_list = [] + + inputs.update({'x': mel}) + output_list.append( + TensorInfo('x', str_dtype_to_trt(self.dtype), mel.shape)) + + output_info = (self.session).infer_shapes(output_list) + + logger.debug(f'output info {output_info}') + outputs = { + t.name: torch.empty(tuple(t.shape), + dtype=trt_dtype_to_torch(t.dtype), + device='cuda') + for t in output_info + } + stream = torch.cuda.current_stream() + ok = self.session.run(inputs=inputs, + outputs=outputs, + stream=stream.cuda_stream) + assert ok, 'Engine execution failed' + stream.synchronize() + audio_features = outputs['output'] + return audio_features + + +class WhisperDecoding: + + def __init__(self, engine_dir, runtime_mapping, debug_mode=False): + + self.decoder_config = self.get_config(engine_dir) + self.decoder_generation_session = self.get_session( + engine_dir, runtime_mapping, debug_mode) + + def get_config(self, engine_dir): + config_path = engine_dir / 'decoder_config.json' + with open(config_path, 'r') as f: + config = json.load(f) + decoder_config = OrderedDict() + decoder_config.update(config['plugin_config']) + decoder_config.update(config['builder_config']) + return decoder_config + + def get_session(self, engine_dir, runtime_mapping, debug_mode=False): + dtype = self.decoder_config['precision'] + serialize_path = engine_dir / f'whisper_decoder_{dtype}_tp1_rank0.engine' + with open(serialize_path, "rb") as f: + decoder_engine_buffer = f.read() + + decoder_model_config = ModelConfig( + num_heads=self.decoder_config['num_heads'], + num_kv_heads=self.decoder_config['num_heads'], + hidden_size=self.decoder_config['hidden_size'], + vocab_size=self.decoder_config['vocab_size'], + num_layers=self.decoder_config['num_layers'], + gpt_attention_plugin=self.decoder_config['gpt_attention_plugin'], + remove_input_padding=self.decoder_config['remove_input_padding'], + cross_attention=self.decoder_config['cross_attention'], + has_position_embedding=self. + decoder_config['has_position_embedding'], + has_token_type_embedding=self. 
+ decoder_config['has_token_type_embedding'], + ) + decoder_generation_session = tensorrt_llm.runtime.GenerationSession( + decoder_model_config, + decoder_engine_buffer, + runtime_mapping, + debug_mode=debug_mode) + + return decoder_generation_session + + def generate(self, + decoder_input_ids, + encoder_outputs, + eot_id, + max_new_tokens=40, + num_beams=1): + encoder_input_lengths = torch.tensor( + [encoder_outputs.shape[1] for x in range(encoder_outputs.shape[0])], + dtype=torch.int32, + device='cuda') + + decoder_input_lengths = torch.tensor([ + decoder_input_ids.shape[-1] + for _ in range(decoder_input_ids.shape[0]) + ], + dtype=torch.int32, + device='cuda') + decoder_max_input_length = torch.max(decoder_input_lengths).item() + + # generation config + sampling_config = SamplingConfig(end_id=eot_id, + pad_id=eot_id, + num_beams=num_beams) + self.decoder_generation_session.setup( + decoder_input_lengths.size(0), + decoder_max_input_length, + max_new_tokens, + beam_width=num_beams, + encoder_max_input_length=encoder_outputs.shape[1]) + + torch.cuda.synchronize() + + decoder_input_ids = decoder_input_ids.type(torch.int32).cuda() + output_ids = self.decoder_generation_session.decode( + decoder_input_ids, + decoder_input_lengths, + sampling_config, + encoder_output=encoder_outputs, + encoder_input_lengths=encoder_input_lengths, + ) + torch.cuda.synchronize() + + # get the list of int from output_ids tensor + output_ids = output_ids.cpu().numpy().tolist() + return output_ids + + +class WhisperTRTLLM(object): + + def __init__(self, engine_dir, assets_dir=None, device=None, is_multilingual=False, + language="en", task="transcribe"): + world_size = 1 + runtime_rank = tensorrt_llm.mpi_rank() + runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank) + torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node) + engine_dir = Path(engine_dir) + + self.encoder = WhisperEncoding(engine_dir) + self.decoder = WhisperDecoding(engine_dir, + runtime_mapping, + debug_mode=False) + self.n_mels = self.encoder.n_mels + # self.tokenizer = get_tokenizer(num_languages=self.encoder.num_languages, + # tokenizer_dir=assets_dir) + self.device = device + self.tokenizer = get_tokenizer( + is_multilingual, + num_languages=self.encoder.num_languages, + language=language, + task=task, + ) + self.filters = mel_filters(self.device, self.encoder.n_mels, assets_dir) + + def log_mel_spectrogram( + self, + audio: Union[str, np.ndarray, torch.Tensor], + padding: int = 0, + return_duration=True + ): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 and 128 are supported + + padding: int + Number of zero samples to pad to the right + + device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + + Returns + ------- + torch.Tensor, shape = (80 or 128, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + if audio.endswith('.wav'): + audio, _ = load_audio_wav_format(audio) + else: + audio = load_audio(audio) + assert isinstance(audio, np.ndarray), f"Unsupported audio type: {type(audio)}" + duration = audio.shape[-1] / SAMPLE_RATE + audio = pad_or_trim(audio, N_SAMPLES) + audio = audio.astype(np.float32) + audio = torch.from_numpy(audio) + + if self.device is not 
None: + audio = audio.to(self.device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs()**2 + + mel_spec = self.filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + if return_duration: + return log_spec, duration + else: + return log_spec + + def process_batch( + self, + mel, + text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + num_beams=1): + prompt_id = self.tokenizer.encode( + text_prefix, allowed_special=set(self.tokenizer.special_tokens.keys())) + + prompt_id = torch.tensor(prompt_id) + batch_size = mel.shape[0] + decoder_input_ids = prompt_id.repeat(batch_size, 1) + + encoder_output = self.encoder.get_audio_features(mel) + output_ids = self.decoder.generate(decoder_input_ids, + encoder_output, + self.tokenizer.eot, + max_new_tokens=96, + num_beams=num_beams) + texts = [] + for i in range(len(output_ids)): + text = self.tokenizer.decode(output_ids[i][0]).strip() + texts.append(text) + return texts + + def transcribe( + self, + mel, + text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + dtype='float16', + batch_size=1, + num_beams=1, + ): + mel = mel.type(str_dtype_to_torch(dtype)) + mel = mel.unsqueeze(0) + predictions = self.process_batch(mel, text_prefix, num_beams) + prediction = predictions[0] + + # remove all special tokens in the prediction + prediction = re.sub(r'<\|.*?\|>', '', prediction) + return prediction.strip() + + +def decode_wav_file( + model, + mel, + text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + dtype='float16', + batch_size=1, + num_beams=1, + normalizer=None, + mel_filters_dir=None): + + mel = mel.type(str_dtype_to_torch(dtype)) + mel = mel.unsqueeze(0) + # repeat the mel spectrogram to match the batch size + mel = mel.repeat(batch_size, 1, 1) + predictions = model.process_batch(mel, text_prefix, num_beams) + prediction = predictions[0] + + # remove all special tokens in the prediction + prediction = re.sub(r'<\|.*?\|>', '', prediction) + if normalizer: + prediction = normalizer(prediction) + + return prediction.strip() diff --git a/whisper_live/utils.py b/whisper_live/utils.py new file mode 100644 index 0000000..2dd20b3 --- /dev/null +++ b/whisper_live/utils.py @@ -0,0 +1,82 @@ +import os +import textwrap +import scipy +import numpy as np +import av +from pathlib import Path + + +def clear_screen(): + """Clears the console screen.""" + os.system("cls" if os.name == "nt" else "clear") + + +def print_transcript(text): + """Prints formatted transcript text.""" + wrapper = textwrap.TextWrapper(width=60) + for line in wrapper.wrap(text="".join(text)): + print(line) + + +def format_time(s): + """Convert seconds (float) to SRT time format.""" + hours = int(s // 3600) + minutes = int((s % 3600) // 60) + seconds = int(s % 60) + milliseconds = int((s - int(s)) * 1000) + return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}" + + +def create_srt_file(segments, resampled_file): + with open(resampled_file, 'w', encoding='utf-8') as srt_file: + segment_number = 1 + for segment in segments: + start_time = format_time(float(segment['start'])) + end_time = format_time(float(segment['end'])) + text = segment['text'] + + srt_file.write(f"{segment_number}\n") + srt_file.write(f"{start_time} --> {end_time}\n") + 
srt_file.write(f"{text}\n\n") + + segment_number += 1 + + +def resample(file: str, sr: int = 16000): + """ + Resample the audio file to 16kHz. + + Args: + file (str): The audio file to open + sr (int): The sample rate to resample the audio if necessary + + Returns: + resampled_file (str): The resampled audio file + """ + container = av.open(file) + stream = next(s for s in container.streams if s.type == 'audio') + + resampler = av.AudioResampler( + format='s16', + layout='mono', + rate=sr, + ) + + resampled_file = Path(file).stem + "_resampled.wav" + output_container = av.open(resampled_file, mode='w') + output_stream = output_container.add_stream('pcm_s16le', rate=sr) + output_stream.layout = 'mono' + + for frame in container.decode(audio=0): + frame.pts = None + resampled_frames = resampler.resample(frame) + if resampled_frames is not None: + for resampled_frame in resampled_frames: + for packet in output_stream.encode(resampled_frame): + output_container.mux(packet) + + for packet in output_stream.encode(None): + output_container.mux(packet) + + output_container.close() + return resampled_file \ No newline at end of file diff --git a/whisper_live/vad.py b/whisper_live/vad.py new file mode 100644 index 0000000..01a2540 --- /dev/null +++ b/whisper_live/vad.py @@ -0,0 +1,155 @@ +# original: https://github.com/snakers4/silero-vad/blob/master/utils_vad.py + +import os +import subprocess +import torch +import numpy as np +import onnxruntime +import warnings + + +class VoiceActivityDetection(): + + def __init__(self, force_onnx_cpu=True): + path = self.download() + + opts = onnxruntime.SessionOptions() + opts.log_severity_level = 3 + + opts.inter_op_num_threads = 1 + opts.intra_op_num_threads = 1 + + if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers(): + self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts) + else: + self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts) + + self.reset_states() + self.sample_rates = [8000, 16000] + + def _validate_input(self, x, sr: int): + if x.dim() == 1: + x = x.unsqueeze(0) + if x.dim() > 2: + raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}") + + if sr != 16000 and (sr % 16000 == 0): + step = sr // 16000 + x = x[:, ::step] + sr = 16000 + + if sr not in self.sample_rates: + raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)") + if sr / x.shape[1] > 31.25: + raise ValueError("Input audio chunk is too short") + + return x, sr + + def reset_states(self, batch_size=1): + self._state = torch.zeros((2, batch_size, 128)).float() + self._context = torch.zeros(0) + self._last_sr = 0 + self._last_batch_size = 0 + + def __call__(self, x, sr: int): + + x, sr = self._validate_input(x, sr) + num_samples = 512 if sr == 16000 else 256 + + if x.shape[-1] != num_samples: + raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)") + + batch_size = x.shape[0] + context_size = 64 if sr == 16000 else 32 + + if not self._last_batch_size: + self.reset_states(batch_size) + if (self._last_sr) and (self._last_sr != sr): + self.reset_states(batch_size) + if (self._last_batch_size) and (self._last_batch_size != batch_size): + self.reset_states(batch_size) + + if not len(self._context): + self._context = torch.zeros(batch_size, context_size) + + x = torch.cat([self._context, x], dim=1) + if sr in [8000, 16000]: + ort_inputs = 
{'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')} + ort_outs = self.session.run(None, ort_inputs) + out, state = ort_outs + self._state = torch.from_numpy(state) + else: + raise ValueError() + + self._context = x[..., -context_size:] + self._last_sr = sr + self._last_batch_size = batch_size + + out = torch.from_numpy(out) + return out + + def audio_forward(self, x, sr: int): + outs = [] + x, sr = self._validate_input(x, sr) + self.reset_states() + num_samples = 512 if sr == 16000 else 256 + + if x.shape[1] % num_samples: + pad_num = num_samples - (x.shape[1] % num_samples) + x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0) + + for i in range(0, x.shape[1], num_samples): + wavs_batch = x[:, i:i+num_samples] + out_chunk = self.__call__(wavs_batch, sr) + outs.append(out_chunk) + + stacked = torch.cat(outs, dim=1) + return stacked.cpu() + + @staticmethod + def download(model_url="https://github.com/snakers4/silero-vad/raw/v5.0/files/silero_vad.onnx"): + target_dir = os.path.expanduser("~/.cache/whisper-live/") + + # Ensure the target directory exists + os.makedirs(target_dir, exist_ok=True) + + # Define the target file path + model_filename = os.path.join(target_dir, "silero_vad.onnx") + + # Check if the model file already exists + if not os.path.exists(model_filename): + # If it doesn't exist, download the model using wget + try: + subprocess.run(["wget", "-O", model_filename, model_url], check=True) + except subprocess.CalledProcessError: + print("Failed to download the model using wget.") + return model_filename + + +class VoiceActivityDetector: + def __init__(self, threshold=0.5, frame_rate=16000): + """ + Initializes the VoiceActivityDetector with a voice activity detection model and a threshold. + + Args: + threshold (float, optional): The probability threshold for detecting voice activity. Defaults to 0.5. + """ + self.model = VoiceActivityDetection() + self.threshold = threshold + self.frame_rate = frame_rate + + def __call__(self, audio_frame): + """ + Determines if the given audio frame contains speech by comparing the detected speech probability against + the threshold. + + Args: + audio_frame (np.ndarray): The audio frame to be analyzed for voice activity. It is expected to be a + NumPy array of audio samples. + + Returns: + bool: True if the speech probability exceeds the threshold, indicating the presence of voice activity; + False otherwise. 
+ """ + speech_probs = self.model.audio_forward(torch.from_numpy(audio_frame.copy()), self.frame_rate)[0] + return torch.any(speech_probs > self.threshold).item() \ No newline at end of file diff --git a/workflows/ci.yml b/workflows/ci.yml new file mode 100644 index 0000000..c42b663 --- /dev/null +++ b/workflows/ci.yml @@ -0,0 +1,169 @@ +name: Test & Build CI/CD + +on: + push: + branches: + - main + tags: + - v* + pull_request: + branches: [ main ] + types: [opened, synchronize, reopened] + +jobs: + run-tests: + runs-on: ubuntu-22.04 + strategy: + matrix: + python-version: [3.8, 3.9, '3.10', 3.11] + steps: + - uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache Python dependencies + uses: actions/cache@v2 + with: + path: | + ~/.cache/pip + !~/.cache/pip/log + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('requirements/server.txt', 'requirements/client.txt') }} + restore-keys: | + ${{ runner.os }}-pip-${{ matrix.python-version }}- + + - name: Install system dependencies + run: sudo apt-get update && sudo apt-get install -y ffmpeg portaudio19-dev + + - name: Install Python dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements/server.txt --extra-index-url https://download.pytorch.org/whl/cpu + pip install -r requirements/client.txt + + - name: Run tests + run: | + echo "Running tests with Python ${{ matrix.python-version }}" + python -m unittest discover -s tests + + check-code-format: + runs-on: ubuntu-22.04 + strategy: + matrix: + python-version: [3.8, 3.9, '3.10', 3.11] + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 + + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + + build-and-push-docker-cpu: + needs: [run-tests, check-code-format] + runs-on: ubuntu-22.04 + if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/')) + steps: + - uses: actions/checkout@v2 + + - name: Log in to GitHub Container Registry + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GHCR_TOKEN }} + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 + + - name: Build and push Docker image + uses: docker/build-push-action@v2 + with: + context: . 
+ file: docker/Dockerfile.cpu + push: true + tags: ghcr.io/collabora/whisperlive-cpu:latest + + build-and-push-docker-gpu: + needs: [run-tests, check-code-format, build-and-push-docker-cpu] + timeout-minutes: 20 + runs-on: ubuntu-22.04 + if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/')) + steps: + - uses: actions/checkout@v2 + + - name: Log in to GitHub Container Registry + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GHCR_TOKEN }} + + - name: Docker Prune + run: docker system prune -af + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 + + - name: Build and push Docker GPU image + uses: docker/build-push-action@v2 + with: + context: . + file: docker/Dockerfile.gpu + push: true + tags: ghcr.io/collabora/whisperlive-gpu:latest + + publish-to-pypi: + needs: [run-tests, check-code-format] + runs-on: ubuntu-22.04 + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') + steps: + - uses: actions/checkout@v2 + + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Cache Python dependencies + uses: actions/cache@v2 + with: + path: | + ~/.cache/pip + !~/.cache/pip/log + key: ubuntu-latest-pip-3.8-${{ hashFiles('requirements/server.txt', 'requirements/client.txt') }} + restore-keys: | + ubuntu-latest-pip-3.8- + + - name: Install system dependencies + run: sudo apt-get update && sudo apt-get install -y ffmpeg portaudio19-dev + + - name: Install Python dependencies + run: | + pip install -r requirements/server.txt + pip install -r requirements/client.txt + pip install wheel + + - name: Build package + run: python setup.py sdist bdist_wheel + + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }}
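
Editor's note: the diff above adds the resampling helper in whisper_live/utils.py and the Silero-based VAD wrapper in whisper_live/vad.py, but the commit does not show them used together. The sketch below is a minimal, hypothetical illustration of that pairing and is not part of the commit. The input path ("example.wav"), the one-second window size, and the use of scipy.io.wavfile to read the resampled file back are assumptions made for this example; resample(), VoiceActivityDetector, and the whisper_live module paths come from the diff itself.

# Hypothetical usage sketch (not part of the commit above).
# Assumes the whisper_live package from this diff is importable and that
# "example.wav" is a placeholder path to a local audio file.
import numpy as np
from scipy.io import wavfile

from whisper_live.utils import resample
from whisper_live.vad import VoiceActivityDetector

# VoiceActivityDetector wraps the Silero VAD ONNX model; per vad.py above, the
# model file is fetched to ~/.cache/whisper-live/ (via wget) on first construction.
detector = VoiceActivityDetector(threshold=0.5, frame_rate=16000)

# resample() writes "<stem>_resampled.wav" as mono 16 kHz PCM and returns its path.
resampled_path = resample("example.wav", sr=16000)
rate, samples = wavfile.read(resampled_path)

# Convert 16-bit PCM to float32 in [-1, 1]; the ONNX VAD model expects float input.
audio = samples.astype(np.float32) / 32768.0

# Scan the file in one-second windows (an arbitrary choice for this example).
# detector(chunk) returns True if any 512-sample frame inside the chunk exceeds
# the speech-probability threshold. The loop bound keeps every chunk at least
# 512 samples long, which the VAD's input validation requires at 16 kHz.
window = rate  # 16000 samples
for start in range(0, len(audio) - 511, window):
    chunk = audio[start:start + window]
    if detector(chunk):
        print(f"speech detected near {start / rate:.1f}s")

A transcription backend (for example the WhisperTRTLLM class added earlier in this diff) could then be invoked only for the windows that the detector flags, which is the general gating pattern the VAD wrapper is written for.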