WhisperLive-Server/batch_transcribe.py

271 lines
11 KiB
Python

#!/usr/bin/env python3
"""
Batch Transcription Script for WhisperLive
Processes all audio files in a folder using the HTTP transcription endpoint
"""
import os
import sys
import json
import time
import argparse
import requests
from pathlib import Path
from typing import List, Dict, Optional
import logging
# Root logging: timestamped, level-tagged records at INFO and above.
# main() raises verbosity to DEBUG on the root logger when --verbose is passed.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                    level=logging.INFO)
# Logger shared by every function and method in this script.
logger = logging.getLogger(name=__name__)
class BatchTranscriber:
    """Transcribe every audio file in a folder through a WhisperLive HTTP server.

    Files are uploaded one at a time to ``POST {server_url}/transcribe`` and the
    returned JSON is written out as txt, json, srt or vtt.
    """

    def __init__(self, server_url: str = "http://localhost:8080",
                 timeout: float = 300.0):
        """
        Args:
            server_url: Base URL of the WhisperLive server. Trailing slashes are
                stripped so endpoint paths can be appended without producing "//".
            timeout: Per-upload request timeout in seconds (default 300).
        """
        self.server_url = server_url.rstrip('/')
        self.timeout = timeout
        # Compared against Path.suffix lower-cased in get_audio_files.
        self.supported_formats = {'.wav', '.mp3', '.flac', '.m4a', '.ogg', '.webm'}

    def get_audio_files(self, folder_path: str) -> List[Path]:
        """Return the supported audio files directly inside *folder_path*, sorted.

        Sub-directories are not searched.

        Raises:
            FileNotFoundError: if *folder_path* does not exist.
            NotADirectoryError: if *folder_path* exists but is not a directory.
        """
        folder = Path(folder_path)
        if not folder.exists():
            raise FileNotFoundError(f"Folder not found: {folder_path}")
        if not folder.is_dir():
            raise NotADirectoryError(f"Not a folder: {folder_path}")
        return sorted(
            p for p in folder.iterdir()
            if p.is_file() and p.suffix.lower() in self.supported_formats
        )

    def transcribe_file(self, file_path: Path, language: Optional[str] = None,
                        task: str = "transcribe", model: str = "base") -> Dict:
        """Upload one audio file to ``/transcribe`` and return the parsed JSON.

        Returns:
            The server's JSON object on HTTP 200. On any failure a dict with an
            ``'error'`` key (plus ``'status_code'`` for non-200 responses) is
            returned instead of raising, so one bad file does not abort a batch.
        """
        logger.info(f"Transcribing: {file_path.name}")
        # NOTE: requests leaves None-valued form fields out of the request body,
        # so language=None is simply not sent.
        data = {'language': language, 'task': task, 'model': model}
        try:
            with open(file_path, 'rb') as audio:
                response = requests.post(f"{self.server_url}/transcribe",
                                         files={'file': audio}, data=data,
                                         timeout=self.timeout)
            if response.status_code != 200:
                error_msg = response.text
                logger.error(f"❌ Failed to transcribe {file_path.name}: {error_msg}")
                return {'error': error_msg, 'status_code': response.status_code}
            result = response.json()
        except Exception as e:
            # Broad on purpose: I/O, network and JSON-decoding failures for a
            # single file are reported in-band so the rest of the batch continues.
            logger.error(f"❌ Error transcribing {file_path.name}: {str(e)}")
            return {'error': str(e)}
        # Callers index the result like a dict; reject anything else up front.
        if not isinstance(result, dict):
            error_msg = f"Unexpected response type: {type(result).__name__}"
            logger.error(f"❌ Failed to transcribe {file_path.name}: {error_msg}")
            return {'error': error_msg}
        logger.info(f"✅ Successfully transcribed: {file_path.name}")
        return result

    def save_transcript(self, transcript_data: Dict, output_path: Path,
                        format_type: str = "txt") -> bool:
        """Write *transcript_data* to *output_path* as txt, json, srt or vtt.

        Returns:
            True on success. False (never raises) when the data carries an
            ``'error'`` key, *format_type* is unsupported, or rendering/writing fails.
        """
        renderers = {
            'txt': self._render_txt,
            'json': self._render_json,
            'srt': self._render_srt,
            'vtt': self._render_vtt,
        }
        try:
            if 'error' in transcript_data:
                return False
            render = renderers.get(format_type)
            if render is None:
                logger.error(f"❌ Unsupported output format: {format_type}")
                return False
            # Render completely before opening the file so malformed data
            # (e.g. missing 'segments') cannot leave a half-written transcript.
            content = render(transcript_data)
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(content)
        except Exception as e:
            logger.error(f"❌ Error saving transcript {output_path}: {str(e)}")
            return False
        logger.info(f"💾 Saved transcript: {output_path}")
        return True

    def _render_txt(self, transcript_data: Dict) -> str:
        """Plain-text transcript: a short header, then one timestamped line per segment."""
        # 'info' and its fields may be missing or null; fall back to neutral values.
        info = transcript_data.get('info') or {}
        language = info.get('language') or 'Auto-detected'
        duration = info.get('duration') or 0
        lines = [
            f"Transcription of: {transcript_data.get('filename', 'Unknown')}",
            f"Language: {language}",
            f"Duration: {duration:.2f} seconds",
            "=" * 50,
            "",
        ]
        lines.extend(
            f"[{seg['start']:.2f}s - {seg['end']:.2f}s] {seg['text']}"
            for seg in transcript_data['segments']
        )
        return "\n".join(lines) + "\n"

    def _render_json(self, transcript_data: Dict) -> str:
        """The raw server response, pretty-printed with non-ASCII kept as-is."""
        return json.dumps(transcript_data, indent=2, ensure_ascii=False)

    def _render_srt(self, transcript_data: Dict) -> str:
        """SubRip cues numbered from 1, each followed by a blank line."""
        cues = []
        for index, seg in enumerate(transcript_data['segments'], 1):
            start = self.format_srt_time(seg['start'])
            end = self.format_srt_time(seg['end'])
            cues.append(f"{index}\n{start} --> {end}\n{seg['text']}\n\n")
        return "".join(cues)

    def _render_vtt(self, transcript_data: Dict) -> str:
        """WebVTT document: the WEBVTT header, then unnumbered cues."""
        cues = ["WEBVTT\n\n"]
        for seg in transcript_data['segments']:
            start = self.format_vtt_time(seg['start'])
            end = self.format_vtt_time(seg['end'])
            cues.append(f"{start} --> {end}\n{seg['text']}\n\n")
        return "".join(cues)

    @staticmethod
    def _format_timestamp(seconds: float, ms_separator: str) -> str:
        """Render *seconds* as ``HH:MM:SS<ms_separator>mmm``.

        The value is rounded to the nearest whole millisecond once and every
        field is derived from that integer. This avoids float-truncation errors
        such as 2.3 s becoming ``,299``. Negative inputs are clamped to zero.
        """
        total_ms = max(0, round(seconds * 1000))
        hours, rem = divmod(total_ms, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        secs, millisecs = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}{ms_separator}{millisecs:03d}"

    def format_srt_time(self, seconds: float) -> str:
        """Format time for SRT subtitles (``HH:MM:SS,mmm``)."""
        return self._format_timestamp(seconds, ',')

    def format_vtt_time(self, seconds: float) -> str:
        """Format time for VTT subtitles (``HH:MM:SS.mmm``)."""
        return self._format_timestamp(seconds, '.')

    def batch_transcribe(self, input_folder: str, output_folder: str,
                         language: Optional[str] = None, task: str = "transcribe",
                         model: str = "base", format_type: str = "txt",
                         delay: float = 1.0) -> Dict:
        """Transcribe every supported audio file in *input_folder*.

        Each transcript is written to *output_folder* (created if missing) as
        ``<stem>.<format_type>``. When several inputs share a stem (e.g.
        ``talk.wav`` and ``talk.mp3``), their full names are kept instead
        (``talk.wav.txt``, ``talk.mp3.txt``) so one does not overwrite another.

        Args:
            delay: Seconds to sleep between requests; values <= 0 disable it.

        Returns:
            ``{'processed', 'successful', 'failed', 'files'}``. ``processed`` is
            the number of audio files found. ``files`` holds one
            ``{'input', 'output', 'status'[, 'error']}`` dict per file.

        Raises:
            FileNotFoundError, NotADirectoryError: from get_audio_files.
        """
        from collections import Counter  # only this method needs it

        # List inputs before creating the output folder, so a bad input path
        # does not leave an empty output directory behind.
        audio_files = self.get_audio_files(input_folder)
        output_path = Path(output_folder)
        output_path.mkdir(parents=True, exist_ok=True)
        if not audio_files:
            logger.warning(f"No audio files found in: {input_folder}")
            return {'processed': 0, 'successful': 0, 'failed': 0, 'files': []}

        total = len(audio_files)
        logger.info(f"Found {total} audio files to process")
        results = {'processed': total, 'successful': 0, 'failed': 0, 'files': []}
        # Lower-cased because Talk.wav / talk.mp3 collide on case-insensitive filesystems.
        stem_counts = Counter(p.stem.lower() for p in audio_files)

        for i, audio_file in enumerate(audio_files, 1):
            logger.info(f"Processing {i}/{total}: {audio_file.name}")
            transcript_data = self.transcribe_file(audio_file, language, task, model)
            entry = {'input': str(audio_file)}
            if 'error' in transcript_data:
                entry.update(output=None, status='failed',
                             error=transcript_data.get('error', 'Unknown error'))
            else:
                shared_stem = stem_counts[audio_file.stem.lower()] > 1
                base_name = audio_file.name if shared_stem else audio_file.stem
                output_file = output_path / f"{base_name}.{format_type}"
                if self.save_transcript(transcript_data, output_file, format_type):
                    entry.update(output=str(output_file), status='success')
                else:
                    entry.update(output=str(output_file), status='failed',
                                 error='Failed to save transcript')
            results['successful' if entry['status'] == 'success' else 'failed'] += 1
            results['files'].append(entry)
            # Pause between requests (not after the last one) to avoid
            # overwhelming the server; time.sleep rejects negative values.
            if delay > 0 and i < total:
                time.sleep(delay)
        return results
def _check_server_health(server_url: str) -> bool:
    """Return True if ``GET {server_url}/health`` answers HTTP 200 within 5 seconds.

    Failures are logged; the caller decides whether to exit.
    """
    try:
        response = requests.get(f"{server_url}/health", timeout=5)
    except requests.exceptions.RequestException as e:
        logger.error(f"❌ Cannot connect to server: {e}")
        return False
    if response.status_code != 200:
        logger.error(f"Server health check failed: {response.status_code}")
        return False
    logger.info("✅ Server health check passed")
    return True


def _log_summary(results: Dict, output_folder: str, format_type: str) -> None:
    """Log the totals and failed files from a batch_transcribe result dict."""
    logger.info("\n" + "=" * 50)
    logger.info("BATCH TRANSCRIPTION COMPLETED")
    logger.info("=" * 50)
    logger.info(f"Total files processed: {results['processed']}")
    logger.info(f"Successful: {results['successful']}")
    logger.info(f"Failed: {results['failed']}")
    logger.info(f"Output folder: {output_folder}")
    logger.info(f"Output format: {format_type}")
    if results['failed'] > 0:
        logger.warning("\nFailed files:")
        for file_info in results.get('files', []):
            if file_info['status'] == 'failed':
                logger.warning(f"  - {file_info['input']}: {file_info.get('error', 'Unknown error')}")
    if results['successful'] > 0:
        logger.info(f"\n✅ Successfully processed {results['successful']} files!")


def main():
    """Command-line entry point: parse arguments, check the server, run the batch.

    Exits with status 1 if the server is unreachable or unhealthy, on Ctrl-C, or
    on an unexpected error. argparse exits with status 2 on invalid arguments.
    """
    parser = argparse.ArgumentParser(description='Batch transcribe audio files using WhisperLive')
    parser.add_argument('input_folder', help='Folder containing audio files')
    parser.add_argument('output_folder', help='Folder to save transcripts')
    parser.add_argument('--server', '-s', default='http://localhost:8080',
                        help='WhisperLive server URL (default: http://localhost:8080)')
    parser.add_argument('--language', '-l', help='Language code (e.g., en, es, fr)')
    parser.add_argument('--task', '-t', choices=['transcribe', 'translate'], default='transcribe',
                        help='Task to perform (default: transcribe)')
    parser.add_argument('--model', '-m', default='base',
                        help='Model size (default: base)')
    parser.add_argument('--format', '-f', choices=['txt', 'json', 'srt', 'vtt'], default='txt',
                        help='Output format (default: txt)')
    parser.add_argument('--delay', '-d', type=float, default=1.0,
                        help='Delay between requests in seconds (default: 1.0)')
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='Verbose output')
    args = parser.parse_args()
    # time.sleep() rejects negative values; without this check the error would
    # only surface as an "unexpected error" after the first file was processed.
    if args.delay < 0:
        parser.error("--delay must be non-negative")
    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)
    try:
        transcriber = BatchTranscriber(args.server)
        # Check the same base URL the uploads will use.
        if not _check_server_health(transcriber.server_url):
            sys.exit(1)
        results = transcriber.batch_transcribe(
            input_folder=args.input_folder,
            output_folder=args.output_folder,
            language=args.language,
            task=args.task,
            model=args.model,
            format_type=args.format,
            delay=args.delay
        )
        _log_summary(results, args.output_folder, args.format)
    except KeyboardInterrupt:
        logger.info("\n⚠️ Process interrupted by user")
        sys.exit(1)
    except Exception as e:
        # SystemExit is not an Exception subclass, so the exits above pass through.
        # With --verbose, include the traceback to aid debugging.
        logger.error(f"❌ Unexpected error: {str(e)}", exc_info=args.verbose)
        sys.exit(1)
# Run the command-line interface only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()