271 lines
11 KiB
Python
271 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Batch Transcription Script for WhisperLive
|
|
Processes all audio files in a folder using the HTTP transcription endpoint
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import time
|
|
import argparse
|
|
import requests
|
|
from pathlib import Path
|
|
from typing import List, Dict, Optional
|
|
import logging
|
|
|
|
# Root logging setup: INFO and above, each record prefixed with a timestamp
# and the level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by the transcriber class and the CLI entry point.
logger = logging.getLogger(__name__)
|
|
|
|
class BatchTranscriber:
    """Transcribe every audio file in a folder through a WhisperLive HTTP server.

    Each file is uploaded (multipart form) to ``<server_url>/transcribe`` and
    the JSON response is written to an output folder as ``txt``, ``json``,
    ``srt`` or ``vtt``.
    """

    def __init__(self, server_url: str = "http://localhost:8080", timeout: float = 300):
        """Configure the transcriber.

        Args:
            server_url: Base URL of the server. Trailing slashes are stripped so
                endpoint URLs never contain ``//``.
            timeout: Per-request timeout in seconds for ``/transcribe`` uploads
                (defaults to 300, the value that used to be hard-coded).
        """
        self.server_url = server_url.rstrip('/')
        self.timeout = timeout
        # Lower-case extensions (with leading dot) treated as audio input.
        self.supported_formats = {'.wav', '.mp3', '.flac', '.m4a', '.ogg', '.webm'}

    def get_audio_files(self, folder_path: str) -> List[Path]:
        """Return the supported audio files directly inside *folder_path*, sorted.

        Subdirectories are not searched; the extension match is case-insensitive.

        Raises:
            FileNotFoundError: if *folder_path* does not exist.
            NotADirectoryError: if *folder_path* exists but is not a directory.
        """
        folder = Path(folder_path)
        if not folder.exists():
            raise FileNotFoundError(f"Folder not found: {folder_path}")
        if not folder.is_dir():
            raise NotADirectoryError(f"Not a folder: {folder_path}")
        return sorted(
            path for path in folder.iterdir()
            if path.is_file() and path.suffix.lower() in self.supported_formats
        )

    def transcribe_file(self, file_path: Path, language: Optional[str] = None,
                        task: str = "transcribe", model: str = "base") -> Dict:
        """Upload one audio file to ``/transcribe`` and return the decoded JSON object.

        Failures are logged and reported as a dict containing an ``'error'`` key
        (plus ``'status_code'`` for non-200 HTTP responses) instead of raising,
        so one bad file does not abort a batch.
        """
        try:
            logger.info(f"Transcribing: {file_path.name}")

            with open(file_path, 'rb') as f:
                response = requests.post(
                    f"{self.server_url}/transcribe",
                    files={'file': f},
                    data={'language': language, 'task': task, 'model': model},
                    timeout=self.timeout,
                )

            if response.status_code != 200:
                error_msg = response.text
                logger.error(f"❌ Failed to transcribe {file_path.name}: {error_msg}")
                return {'error': error_msg, 'status_code': response.status_code}

            result = response.json()
            if not isinstance(result, dict):
                # Callers test ``'error' in result``; a JSON ``null`` payload would
                # raise TypeError there and abort the whole batch.
                error_msg = ("Unexpected response payload: expected a JSON object, "
                             f"got {type(result).__name__}")
                logger.error(f"❌ Failed to transcribe {file_path.name}: {error_msg}")
                return {'error': error_msg}

            logger.info(f"✅ Successfully transcribed: {file_path.name}")
            return result

        except Exception as e:
            # Deliberately broad: this is the per-file error boundary (file I/O,
            # network errors and undecodable JSON all land here).
            logger.error(f"❌ Error transcribing {file_path.name}: {str(e)}")
            return {'error': str(e)}

    def save_transcript(self, transcript_data: Dict, output_path: Path,
                        format_type: str = "txt") -> bool:
        """Write *transcript_data* to *output_path* in *format_type*.

        Returns True only when the file was written. Returns False when the
        data carries an ``'error'`` key, when *format_type* is not one of
        ``txt``/``json``/``srt``/``vtt``, or when rendering or writing fails.
        The whole document is rendered before the file is opened, so a
        malformed transcript (e.g. no ``'segments'``) leaves no partial file.
        """
        renderers = {
            'txt': self._render_txt,
            'json': self._render_json,
            'srt': self._render_srt,
            'vtt': self._render_vtt,
        }
        try:
            if 'error' in transcript_data:
                return False

            render = renderers.get(format_type)
            if render is None:
                logger.error(f"❌ Unsupported transcript format '{format_type}' for {output_path}")
                return False

            content = render(transcript_data)
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(content)

        except Exception as e:
            logger.error(f"❌ Error saving transcript {output_path}: {str(e)}")
            return False

        logger.info(f"💾 Saved transcript: {output_path}")
        return True

    def _render_txt(self, transcript_data: Dict) -> str:
        """Plain text: a header block, then one ``[start - end] text`` line per segment."""
        # Missing or null ``info`` fields fall back to the header defaults
        # instead of printing "None" or crashing the ``:.2f`` format.
        info = transcript_data.get('info') or {}
        language = info.get('language') or 'Auto-detected'
        duration = info.get('duration') or 0
        parts = [
            f"Transcription of: {transcript_data.get('filename', 'Unknown')}\n",
            f"Language: {language}\n",
            f"Duration: {duration:.2f} seconds\n",
            "=" * 50 + "\n\n",
        ]
        parts.extend(
            f"[{segment['start']:.2f}s - {segment['end']:.2f}s] {segment['text']}\n"
            for segment in transcript_data['segments']
        )
        return "".join(parts)

    def _render_json(self, transcript_data: Dict) -> str:
        """The full server response, pretty-printed, non-ASCII kept verbatim."""
        return json.dumps(transcript_data, indent=2, ensure_ascii=False)

    def _render_srt(self, transcript_data: Dict) -> str:
        """SubRip: numbered cues (from 1) with ``HH:MM:SS,mmm`` timestamps."""
        return "".join(
            f"{i}\n{self.format_srt_time(segment['start'])} --> "
            f"{self.format_srt_time(segment['end'])}\n{segment['text']}\n\n"
            for i, segment in enumerate(transcript_data['segments'], 1)
        )

    def _render_vtt(self, transcript_data: Dict) -> str:
        """WebVTT: ``WEBVTT`` header, then cues with ``HH:MM:SS.mmm`` timestamps."""
        cues = "".join(
            f"{self.format_vtt_time(segment['start'])} --> "
            f"{self.format_vtt_time(segment['end'])}\n{segment['text']}\n\n"
            for segment in transcript_data['segments']
        )
        return "WEBVTT\n\n" + cues

    @staticmethod
    def _split_timestamp(seconds: float) -> tuple:
        """Split *seconds* into (hours, minutes, seconds, milliseconds).

        The total is rounded to the nearest millisecond once and then split
        with divmod. The previous ``int((s % 1) * 1000)`` truncated float
        error (2.3 s -> 299 ms, 1.001 s -> 0 ms). Negative input is clamped to 0.
        """
        total_ms = int(round(max(seconds, 0) * 1000))
        total_secs, millis = divmod(total_ms, 1000)
        total_mins, secs = divmod(total_secs, 60)
        hours, minutes = divmod(total_mins, 60)
        return hours, minutes, secs, millis

    def format_srt_time(self, seconds: float) -> str:
        """Format *seconds* as an SRT timestamp ``HH:MM:SS,mmm``."""
        hours, minutes, secs, millis = self._split_timestamp(seconds)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def format_vtt_time(self, seconds: float) -> str:
        """Format *seconds* as a WebVTT timestamp ``HH:MM:SS.mmm``."""
        hours, minutes, secs, millis = self._split_timestamp(seconds)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"

    @staticmethod
    def _output_names(audio_files: List[Path], format_type: str) -> Dict[Path, str]:
        """Map each input file to its transcript file name.

        Normally ``<stem>.<format>``. When several inputs share a stem
        (compared case-insensitively, for case-insensitive filesystems), e.g.
        ``talk.wav`` and ``talk.mp3``, each keeps its extension
        (``talk.wav.txt``), so one transcript cannot overwrite another.
        """
        stem_counts: Dict[str, int] = {}
        for path in audio_files:
            key = path.stem.lower()
            stem_counts[key] = stem_counts.get(key, 0) + 1
        return {
            path: (f"{path.name}.{format_type}" if stem_counts[path.stem.lower()] > 1
                   else f"{path.stem}.{format_type}")
            for path in audio_files
        }

    def batch_transcribe(self, input_folder: str, output_folder: str,
                         language: Optional[str] = None, task: str = "transcribe",
                         model: str = "base", format_type: str = "txt",
                         delay: float = 1.0) -> Dict:
        """Transcribe every supported audio file in *input_folder* into *output_folder*.

        Files are processed sequentially in sorted order. The loop sleeps
        *delay* seconds between requests; a value <= 0 means no pause.

        Returns:
            ``{'processed', 'successful', 'failed', 'files'}``, where ``files``
            holds one ``{'input', 'output', 'status'[, 'error']}`` entry per file.

        Raises:
            FileNotFoundError / NotADirectoryError: from get_audio_files() when
            *input_folder* is missing or not a folder. In that case the output
            folder is not created.
        """
        audio_files = self.get_audio_files(input_folder)

        output_path = Path(output_folder)
        output_path.mkdir(parents=True, exist_ok=True)

        if not audio_files:
            logger.warning(f"No audio files found in: {input_folder}")
            return {'processed': 0, 'successful': 0, 'failed': 0, 'files': []}

        total = len(audio_files)
        logger.info(f"Found {total} audio files to process")

        output_names = self._output_names(audio_files, format_type)
        results = {'processed': total, 'successful': 0, 'failed': 0, 'files': []}

        for i, audio_file in enumerate(audio_files, 1):
            logger.info(f"Processing {i}/{total}: {audio_file.name}")

            transcript_data = self.transcribe_file(audio_file, language, task, model)

            if 'error' in transcript_data:
                entry = {
                    'input': str(audio_file),
                    'output': None,
                    'status': 'failed',
                    'error': transcript_data.get('error', 'Unknown error'),
                }
            else:
                output_file = output_path / output_names[audio_file]
                if self.save_transcript(transcript_data, output_file, format_type):
                    entry = {'input': str(audio_file), 'output': str(output_file),
                             'status': 'success'}
                else:
                    entry = {'input': str(audio_file), 'output': str(output_file),
                             'status': 'failed', 'error': 'Failed to save transcript'}

            results['successful' if entry['status'] == 'success' else 'failed'] += 1
            results['files'].append(entry)

            # Pause between requests to avoid overwhelming the server. Skip the
            # pause after the last file and for non-positive delays, because
            # time.sleep() raises on negative values.
            if i < total and delay > 0:
                time.sleep(delay)

        return results
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description='Batch transcribe audio files using WhisperLive')
|
|
parser.add_argument('input_folder', help='Folder containing audio files')
|
|
parser.add_argument('output_folder', help='Folder to save transcripts')
|
|
parser.add_argument('--server', '-s', default='http://localhost:8080',
|
|
help='WhisperLive server URL (default: http://localhost:8080)')
|
|
parser.add_argument('--language', '-l', help='Language code (e.g., en, es, fr)')
|
|
parser.add_argument('--task', '-t', choices=['transcribe', 'translate'], default='transcribe',
|
|
help='Task to perform (default: transcribe)')
|
|
parser.add_argument('--model', '-m', default='base',
|
|
help='Model size (default: base)')
|
|
parser.add_argument('--format', '-f', choices=['txt', 'json', 'srt', 'vtt'], default='txt',
|
|
help='Output format (default: txt)')
|
|
parser.add_argument('--delay', '-d', type=float, default=1.0,
|
|
help='Delay between requests in seconds (default: 1.0)')
|
|
parser.add_argument('--verbose', '-v', action='store_true',
|
|
help='Verbose output')
|
|
|
|
args = parser.parse_args()
|
|
|
|
if args.verbose:
|
|
logging.getLogger().setLevel(logging.DEBUG)
|
|
|
|
try:
|
|
# Initialize transcriber
|
|
transcriber = BatchTranscriber(args.server)
|
|
|
|
# Check server health
|
|
try:
|
|
response = requests.get(f"{args.server}/health", timeout=5)
|
|
if response.status_code != 200:
|
|
logger.error(f"Server health check failed: {response.status_code}")
|
|
sys.exit(1)
|
|
logger.info("✅ Server health check passed")
|
|
except requests.exceptions.RequestException as e:
|
|
logger.error(f"❌ Cannot connect to server: {e}")
|
|
sys.exit(1)
|
|
|
|
# Process files
|
|
results = transcriber.batch_transcribe(
|
|
input_folder=args.input_folder,
|
|
output_folder=args.output_folder,
|
|
language=args.language,
|
|
task=args.task,
|
|
model=args.model,
|
|
format_type=args.format,
|
|
delay=args.delay
|
|
)
|
|
|
|
# Print summary
|
|
logger.info("\n" + "=" * 50)
|
|
logger.info("BATCH TRANSCRIPTION COMPLETED")
|
|
logger.info("=" * 50)
|
|
logger.info(f"Total files processed: {results['processed']}")
|
|
logger.info(f"Successful: {results['successful']}")
|
|
logger.info(f"Failed: {results['failed']}")
|
|
logger.info(f"Output folder: {args.output_folder}")
|
|
logger.info(f"Output format: {args.format}")
|
|
|
|
if results['failed'] > 0:
|
|
logger.warning("\nFailed files:")
|
|
for file_info in results['files']:
|
|
if file_info['status'] == 'failed':
|
|
logger.warning(f" - {file_info['input']}: {file_info.get('error', 'Unknown error')}")
|
|
|
|
if results['successful'] > 0:
|
|
logger.info(f"\n✅ Successfully processed {results['successful']} files!")
|
|
|
|
except KeyboardInterrupt:
|
|
logger.info("\n⚠️ Process interrupted by user")
|
|
sys.exit(1)
|
|
except Exception as e:
|
|
logger.error(f"❌ Unexpected error: {str(e)}")
|
|
sys.exit(1)
|
|
|
|
if __name__ == "__main__":
|
|
main()
|