WhisperLive-Server/test_http_endpoints.py

160 lines
5.1 KiB
Python

#!/usr/bin/env python3
"""
Test script for WhisperLive HTTP endpoints
This script demonstrates how to use the new HTTP API for file transcription
"""
import requests
import json
import os
from pathlib import Path
# Configuration
HTTP_BASE_URL = "http://localhost:8080"  # Adjust if using different port
# Your existing WebSocket port. Kept for reference; the HTTP checks in this
# script do not reference it.
WEBSOCKET_PORT = 5050

# The test_* functions in this module are manual checks against a live server
# (some take required positional arguments such as audio_file_path), not pytest
# tests. Opt this module out of pytest collection so a plain `pytest` run does
# not fail with "fixture 'audio_file_path' not found".
__test__ = False
def test_health_endpoint(timeout=10):
    """Check that the server's ``/health`` endpoint answers with HTTP 200.

    Args:
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` waits indefinitely on an unresponsive server.

    Returns:
        True if the endpoint returned status 200, False otherwise (including
        when the request itself fails).
    """
    print("Testing health endpoint...")
    try:
        response = requests.get(f"{HTTP_BASE_URL}/health", timeout=timeout)
    except requests.RequestException as e:
        print(f"Error: {e}")
        return False
    print(f"Status: {response.status_code}")
    try:
        print(f"Response: {response.json()}")
    except ValueError:
        # Body is not JSON (e.g. a plain-text or HTML page): show it raw rather
        # than letting the decode error turn a 200 into a reported failure.
        print(f"Response: {response.text}")
    return response.status_code == 200
def test_file_transcription(audio_file_path, language=None, task="transcribe", model="base"):
    """Upload a local audio file to ``/transcribe`` and print the result.

    The file is sent as multipart field ``file`` with form fields
    ``language``, ``task`` and ``model`` (a ``None`` language is omitted by
    ``requests`` from the multipart body).

    Args:
        audio_file_path: Path to a local audio file to upload.
        language: Language code, or None to let the server auto-detect.
        task: Task name forwarded to the server (e.g. "transcribe").
        model: Model name forwarded to the server.

    Returns:
        True if the server answered 200 with a response containing ``info``
        and ``segments``; False on a missing file, request error, non-200
        status, or unexpected response body.
    """
    print("\nTesting file transcription endpoint...")
    print(f"File: {audio_file_path}")
    print(f"Language: {language or 'auto-detect'}")
    print(f"Task: {task}")
    print(f"Model: {model}")
    if not os.path.isfile(audio_file_path):
        print(f"Error: File {audio_file_path} not found")
        return False
    data = {
        'language': language,
        'task': task,
        'model': model,
    }
    try:
        # The context manager closes the upload handle once the request
        # completes or raises.
        with open(audio_file_path, 'rb') as audio:
            response = requests.post(
                f"{HTTP_BASE_URL}/transcribe", files={'file': audio}, data=data
            )
    except (OSError, requests.RequestException) as e:
        print(f"Error: {e}")
        return False
    print(f"Status: {response.status_code}")
    if response.status_code != 200:
        print(f"Error: {response.text}")
        return False
    try:
        result = response.json()
        info = result['info']
        segments = result['segments']
        print("Transcription successful!")
        print(f"Filename: {result.get('filename')}")
        print(f"Language: {info.get('language')}")
        print(f"Duration: {info.get('duration')} seconds")
        print(f"Number of segments: {len(segments)}")
        # Show only a short preview of the transcript.
        for i, segment in enumerate(segments[:3], start=1):
            print(f"Segment {i}: [{segment['start']:.2f}s - {segment['end']:.2f}s] {segment['text']}")
        if len(segments) > 3:
            print(f"... and {len(segments) - 3} more segments")
    except (ValueError, KeyError, TypeError, AttributeError) as e:
        # Non-JSON body or a JSON shape other than the one printed above.
        print(f"Error: unexpected response format: {e}")
        return False
    return True
def test_url_transcription(url="https://example.com/audio.mp3", language="en",
                           task="transcribe", model="base"):
    """POST a JSON request to ``/transcribe/url`` and print the response.

    The defaults reproduce the original placeholder request. The example.com
    URL presumably does not point at real audio, so pass ``url`` to exercise
    the endpoint with an actual file.

    Args:
        url: Remote audio URL for the server to transcribe.
        language: Language code forwarded to the server.
        task: Task name forwarded to the server.
        model: Model name forwarded to the server.

    Returns:
        True if the server answered with status 200, False otherwise
        (including when the request itself fails).
    """
    print("\nTesting URL transcription endpoint...")
    payload = {
        'url': url,
        'language': language,
        'task': task,
        'model': model,
    }
    try:
        response = requests.post(f"{HTTP_BASE_URL}/transcribe/url", json=payload)
    except requests.RequestException as e:
        print(f"Error: {e}")
        return False
    print(f"Status: {response.status_code}")
    try:
        print(f"Response: {response.json()}")
    except ValueError:
        # Non-JSON body: show it raw instead of misreporting the status.
        print(f"Response: {response.text}")
    return response.status_code == 200
def test_openai_endpoint(audio_file_path):
    """Upload a file to the OpenAI-compatible ``/v1/audio/transcriptions`` route.

    Sends the file with ``model='whisper-1'`` and ``response_format='json'``
    and prints the decoded reply.

    Args:
        audio_file_path: Path to a local audio file to upload.

    Returns:
        True if the server answered 200 with a JSON body, False on a missing
        file, request error, non-200 status, or non-JSON body.
    """
    print("\nTesting OpenAI compatible endpoint...")
    print(f"File: {audio_file_path}")
    if not os.path.isfile(audio_file_path):
        print(f"Error: File {audio_file_path} not found")
        return False
    data = {
        'model': 'whisper-1',
        'response_format': 'json',
    }
    try:
        # The context manager closes the upload handle once the request
        # completes or raises.
        with open(audio_file_path, 'rb') as audio:
            response = requests.post(
                f"{HTTP_BASE_URL}/v1/audio/transcriptions",
                files={'file': audio},
                data=data,
            )
    except (OSError, requests.RequestException) as e:
        print(f"Error: {e}")
        return False
    print(f"Status: {response.status_code}")
    if response.status_code != 200:
        print(f"Error: {response.text}")
        return False
    try:
        result = response.json()
    except ValueError as e:
        print(f"Error: response is not valid JSON: {e}")
        return False
    print("OpenAI endpoint successful!")
    print(f"Response: {result}")
    return True
def main():
    """Run the HTTP endpoint checks in order and print a pass/fail summary.

    The health check runs first; if it fails, the remaining checks are
    skipped. The file-based checks run only when the sample audio file exists.

    Returns:
        True if every check that ran passed, False otherwise.
    """
    print("WhisperLive HTTP Endpoints Test")
    print("=" * 40)
    # Test health endpoint
    if not test_health_endpoint():
        print("Health check failed. Make sure the server is running.")
        return False
    results = {"health": True}
    # Test file transcription with a sample audio file
    # You can replace this with any audio file you have
    sample_audio = "assets/jfk.flac"  # Adjust path as needed (relative to the working directory)
    if os.path.exists(sample_audio):
        results["file transcription"] = test_file_transcription(
            sample_audio, language="en", task="transcribe", model="base"
        )
        results["openai endpoint"] = test_openai_endpoint(sample_audio)
    else:
        print(f"\nSample audio file not found at {sample_audio}")
        print("You can test with any audio file by calling:")
        print("test_file_transcription('path/to/your/audio.wav')")
    # Test URL transcription endpoint
    results["url transcription"] = test_url_transcription()
    print("\n" + "=" * 40)
    # Surface individual outcomes instead of discarding them, so a run where
    # every check failed is not mistaken for a clean one.
    for name, passed in results.items():
        print(f"{name}: {'PASS' if passed else 'FAIL'}")
    print("Test completed!")
    return all(results.values())


if __name__ == "__main__":
    main()