import os
import time
import threading
import json
import functools
import logging
from enum import Enum
from typing import List, Optional

import torch
import numpy as np
from websockets.sync.server import serve
from websockets.exceptions import ConnectionClosed
from whisper_live.vad import VoiceActivityDetector
from whisper_live.transcriber import WhisperModel
try:
    from whisper_live.transcriber_tensorrt import WhisperTRTLLM
except Exception:
    pass

logging.basicConfig(level=logging.INFO)


class ClientManager:
    def __init__(self, max_clients=4, max_connection_time=600):
        """
        Initializes the ClientManager with specified limits on client connections and connection durations.

        Args:
            max_clients (int, optional): The maximum number of simultaneous client connections allowed. Defaults to 4.
            max_connection_time (int, optional): The maximum duration (in seconds) a client can stay connected.
                Defaults to 600 seconds (10 minutes).
        """
        self.clients = {}
        self.start_times = {}
        self.max_clients = max_clients
        self.max_connection_time = max_connection_time

    def add_client(self, websocket, client):
        """
        Adds a client and their connection start time to the tracking dictionaries.

        Args:
            websocket: The websocket associated with the client to add.
            client: The client object to be added and tracked.
        """
        self.clients[websocket] = client
        self.start_times[websocket] = time.time()

    def get_client(self, websocket):
        """
        Retrieves a client associated with the given websocket.

        Args:
            websocket: The websocket associated with the client to retrieve.

        Returns:
            The client object if found, False otherwise.
        """
        if websocket in self.clients:
            return self.clients[websocket]
        return False

    def remove_client(self, websocket):
        """
        Removes a client and their connection start time from the tracking dictionaries. Performs cleanup on the
        client if necessary.

        Args:
            websocket: The websocket associated with the client to be removed.
        """
        client = self.clients.pop(websocket, None)
        if client:
            client.cleanup()
        self.start_times.pop(websocket, None)

    def get_wait_time(self):
        """
        Calculates the estimated wait time for new clients based on the remaining connection times of current clients.

        Returns:
            The estimated wait time in minutes for new clients to connect. Returns 0 if there are available slots.
        """
        wait_time = None
        for start_time in self.start_times.values():
            current_client_time_remaining = self.max_connection_time - (time.time() - start_time)
            if wait_time is None or current_client_time_remaining < wait_time:
                wait_time = current_client_time_remaining
        return wait_time / 60 if wait_time is not None else 0

    def is_server_full(self, websocket, options):
        """
        Checks if the server is at its maximum client capacity and sends a wait message to the client if necessary.

        Args:
            websocket: The websocket of the client attempting to connect.
            options: A dictionary of options that may include the client's unique identifier.

        Returns:
            True if the server is full, False otherwise.
        """
        if len(self.clients) >= self.max_clients:
            wait_time = self.get_wait_time()
            response = {"uid": options["uid"], "status": "WAIT", "message": wait_time}
            websocket.send(json.dumps(response))
            return True
        return False

    def is_client_timeout(self, websocket):
        """
        Checks if a client has exceeded the maximum allowed connection time and disconnects them if so, issuing a warning.

        Args:
            websocket: The websocket associated with the client to check.

        Returns:
            True if the client's connection time has exceeded the maximum limit, False otherwise.
        """
        elapsed_time = time.time() - self.start_times[websocket]
        if elapsed_time >= self.max_connection_time:
            self.clients[websocket].disconnect()
            logging.warning(f"Client with uid '{self.clients[websocket].client_uid}' disconnected due to overtime.")
            return True
        return False
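
# A minimal usage sketch (illustrative only, not part of the server's control flow):
# the manager tracks one client object per websocket and reports capacity. The
# `ws` and `client` objects below are hypothetical stand-ins.
#
#     manager = ClientManager(max_clients=2, max_connection_time=600)
#     manager.add_client(ws, client)               # start tracking a connection
#     manager.get_client(ws)                       # -> client, or False if unknown
#     manager.is_server_full(ws, {"uid": "abc"})   # sends {"status": "WAIT", "message": <minutes>} when full
#     manager.remove_client(ws)                    # calls client.cleanup() and stops tracking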


class BackendType(Enum):
    FASTER_WHISPER = "faster_whisper"
    TENSORRT = "tensorrt"

    @staticmethod
    def valid_types() -> List[str]:
        return [backend_type.value for backend_type in BackendType]

    @staticmethod
    def is_valid(backend: str) -> bool:
        return backend in BackendType.valid_types()

    def is_faster_whisper(self) -> bool:
        return self == BackendType.FASTER_WHISPER

    def is_tensorrt(self) -> bool:
        return self == BackendType.TENSORRT
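
# Illustrative sketch of how a backend string (e.g. from a CLI flag or config)
# can be validated and turned into a BackendType before starting the server:
#
#     if not BackendType.is_valid("faster_whisper"):
#         raise ValueError(f"Choose one of {BackendType.valid_types()}")
#     backend = BackendType("faster_whisper")
#     assert backend.is_faster_whisper()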


class TranscriptionServer:
    RATE = 16000

    def __init__(self):
        self.client_manager = None
        self.no_voice_activity_chunks = 0
        self.use_vad = True
        self.single_model = False

    def initialize_client(
            self, websocket, options, faster_whisper_custom_model_path,
            whisper_tensorrt_path, trt_multilingual
    ):
        """
        Instantiate a backend-specific client for the websocket connection and register it with the client manager.
        Falls back from TensorRT to faster_whisper if the TensorRT backend cannot be initialized.
        """
        client: Optional[ServeClientBase] = None

        if self.backend.is_tensorrt():
            try:
                client = ServeClientTensorRT(
                    websocket,
                    multilingual=trt_multilingual,
                    language=options["language"],
                    task=options["task"],
                    client_uid=options["uid"],
                    model=whisper_tensorrt_path,
                    single_model=self.single_model,
                )
                logging.info("Running TensorRT backend.")
            except Exception as e:
                logging.error(f"TensorRT-LLM not supported: {e}")
                self.client_uid = options["uid"]
                websocket.send(json.dumps({
                    "uid": self.client_uid,
                    "status": "WARNING",
                    "message": "TensorRT-LLM not supported on Server yet. "
                               "Reverting to available backend: 'faster_whisper'"
                }))
                self.backend = BackendType.FASTER_WHISPER

        try:
            if self.backend.is_faster_whisper():
                if faster_whisper_custom_model_path is not None and os.path.exists(faster_whisper_custom_model_path):
                    logging.info(f"Using custom model {faster_whisper_custom_model_path}")
                    options["model"] = faster_whisper_custom_model_path
                client = ServeClientFasterWhisper(
                    websocket,
                    language=options["language"],
                    task=options["task"],
                    client_uid=options["uid"],
                    model=options["model"],
                    initial_prompt=options.get("initial_prompt"),
                    vad_parameters=options.get("vad_parameters"),
                    use_vad=self.use_vad,
                    single_model=self.single_model,
                )

                logging.info("Running faster_whisper backend.")
        except Exception as e:
            logging.error(f"Failed to initialize faster_whisper client: {e}")
            return

        if client is None:
            raise ValueError(f"Backend type {self.backend.value} not recognised or not handled.")

        self.client_manager.add_client(websocket, client)

    def get_audio_from_websocket(self, websocket):
        """
        Receives audio buffer from websocket and creates a numpy array out of it.

        Args:
            websocket: The websocket to receive audio from.

        Returns:
            A numpy array containing the audio, or False if the client sent the
            "END_OF_AUDIO" marker to signal the end of the stream.
        """
        frame_data = websocket.recv()
        if frame_data == b"END_OF_AUDIO":
            return False
        return np.frombuffer(frame_data, dtype=np.float32)
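
    # Illustrative sketch of the client-side framing this method expects (the real
    # client lives elsewhere in whisper_live): each websocket message carries raw
    # float32 PCM samples at 16 kHz (RATE), and the literal bytes b"END_OF_AUDIO"
    # terminate the stream.
    #
    #     chunk = np.zeros(4096, dtype=np.float32)   # one audio chunk
    #     ws.send(chunk.tobytes())                   # -> np.frombuffer() on the server
    #     ws.send(b"END_OF_AUDIO")                   # -> get_audio_from_websocket() returns False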

    def handle_new_connection(self, websocket, faster_whisper_custom_model_path,
                              whisper_tensorrt_path, trt_multilingual):
        """
        Read the initial JSON options message from a newly connected client, set up the client manager
        and (for the TensorRT backend) the VAD detector, and initialize a backend client.

        Returns:
            True if the connection was set up successfully, False if it should not continue.
        """
        try:
            logging.info("New client connected")
            options = websocket.recv()
            options = json.loads(options)

            if self.client_manager is None:
                max_clients = options.get('max_clients', 4)
                max_connection_time = options.get('max_connection_time', 600)
                self.client_manager = ClientManager(max_clients, max_connection_time)

            self.use_vad = options.get('use_vad')
            if self.client_manager.is_server_full(websocket, options):
                websocket.close()
                return False  # Indicates that the connection should not continue

            if self.backend.is_tensorrt():
                self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)
            self.initialize_client(websocket, options, faster_whisper_custom_model_path,
                                   whisper_tensorrt_path, trt_multilingual)
            return True
        except json.JSONDecodeError:
            logging.error("Failed to decode JSON from client")
            return False
        except ConnectionClosed:
            logging.info("Connection closed by client")
            return False
        except Exception as e:
            logging.error(f"Error during new connection initialization: {str(e)}")
            return False
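
    # Illustrative sketch of the first message a client sends after connecting,
    # based on the keys read above and in initialize_client(); the values shown
    # are examples, not defaults:
    #
    #     ws.send(json.dumps({
    #         "uid": "client-1",            # unique client identifier
    #         "language": "en",             # or None to auto-detect
    #         "task": "transcribe",
    #         "model": "small.en",          # faster_whisper model size or path
    #         "use_vad": True,
    #         "max_clients": 4,             # optional; only honored for the first connection
    #         "max_connection_time": 600,   # optional; seconds
    #     }))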

    def process_audio_frames(self, websocket):
        """
        Receive one audio message from the websocket and feed it to the client's audio buffer.

        Returns:
            False when the client signalled end of audio, True otherwise.
        """
        frame_np = self.get_audio_from_websocket(websocket)
        client = self.client_manager.get_client(websocket)
        if frame_np is False:
            if self.backend.is_tensorrt():
                client.set_eos(True)
            return False

        if self.backend.is_tensorrt():
            voice_active = self.voice_activity(websocket, frame_np)
            if voice_active:
                self.no_voice_activity_chunks = 0
                client.set_eos(False)
            if self.use_vad and not voice_active:
                return True

        client.add_frames(frame_np)
        return True

    def recv_audio(self,
                   websocket,
                   backend: BackendType = BackendType.FASTER_WHISPER,
                   faster_whisper_custom_model_path=None,
                   whisper_tensorrt_path=None,
                   trt_multilingual=False):
        """
        Receive audio chunks from a client in an infinite loop.

        Continuously receives audio frames from a connected client
        over a WebSocket connection. It processes the audio frames using a
        voice activity detection (VAD) model to determine if they contain speech
        or not. If the audio frame contains speech, it is added to the client's
        audio data for ASR.
        If the maximum number of clients is reached, the method sends a
        "WAIT" status to the client, indicating that they should wait
        until a slot is available.
        If a client's connection exceeds the maximum allowed time, it will
        be disconnected, and the client's resources will be cleaned up.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.
            backend (BackendType): The backend to run the server with.
            faster_whisper_custom_model_path (str): Path to a custom faster_whisper model.
            whisper_tensorrt_path (str): Path to the TensorRT model; required for the tensorrt backend.
            trt_multilingual (bool): Only used for tensorrt, True if the model is multilingual.

        Raises:
            Exception: If there is an error during the audio frame processing.
        """
        self.backend = backend
        if not self.handle_new_connection(websocket, faster_whisper_custom_model_path,
                                          whisper_tensorrt_path, trt_multilingual):
            return

        try:
            while not self.client_manager.is_client_timeout(websocket):
                if not self.process_audio_frames(websocket):
                    break
        except ConnectionClosed:
            logging.info("Connection closed by client")
        except Exception as e:
            logging.error(f"Unexpected error: {str(e)}")
        finally:
            if self.client_manager.get_client(websocket):
                self.cleanup(websocket)
                websocket.close()
            del websocket

    def run(self,
            host,
            port=int(os.getenv('PORT_WHISPERLIVE', '9090')),  # assumed fallback port if PORT_WHISPERLIVE is unset
            backend="tensorrt",
            faster_whisper_custom_model_path=None,
            whisper_tensorrt_path=None,
            trt_multilingual=False,
            single_model=False,
            ssl_context=None):
        """
        Run the transcription server.

        Args:
            host (str): The host address to bind the server.
            port (int): The port number to bind the server. Defaults to the PORT_WHISPERLIVE
                environment variable (falling back to 9090 if it is unset).
        """
        if faster_whisper_custom_model_path is not None and not os.path.exists(faster_whisper_custom_model_path):
            raise ValueError(f"Custom faster_whisper model '{faster_whisper_custom_model_path}' is not a valid path.")
        if whisper_tensorrt_path is not None and not os.path.exists(whisper_tensorrt_path):
            raise ValueError(f"TensorRT model '{whisper_tensorrt_path}' is not a valid path.")
        if single_model:
            if faster_whisper_custom_model_path or whisper_tensorrt_path:
                logging.info("Custom model option was provided. Switching to single model mode.")
                self.single_model = True
                # TODO: load model initially
            else:
                logging.info("Single model mode currently only works with custom models.")
        if not BackendType.is_valid(backend):
            raise ValueError(f"{backend} is not a valid backend type. Choose backend from {BackendType.valid_types()}")
        with serve(
            functools.partial(
                self.recv_audio,
                backend=BackendType(backend),
                faster_whisper_custom_model_path=faster_whisper_custom_model_path,
                whisper_tensorrt_path=whisper_tensorrt_path,
                trt_multilingual=trt_multilingual
            ),
            host,
            port,
            ssl_context=ssl_context
        ) as server:
            server.serve_forever()
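
    # Illustrative sketch of enabling TLS for the websocket server (an assumption,
    # not part of this module): build a standard-library SSLContext and pass it in.
    # The certificate and key filenames below are hypothetical.
    #
    #     import ssl
    #     ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    #     ctx.load_cert_chain("server.crt", "server.key")
    #     TranscriptionServer().run("0.0.0.0", backend="faster_whisper", ssl_context=ctx)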

    def voice_activity(self, websocket, frame_np):
        """
        Evaluates the voice activity in a given audio frame and manages the state of voice activity detection.

        This method uses the configured voice activity detection (VAD) model to assess whether the given audio frame
        contains speech. If the VAD model detects no voice activity for more than three consecutive frames,
        it sets an end-of-speech (EOS) flag for the associated client. This method aims to efficiently manage
        speech detection to improve subsequent processing steps.

        Args:
            websocket: The websocket associated with the current client. Used to retrieve the client object
                from the client manager for state management.
            frame_np (numpy.ndarray): The audio frame to be analyzed. This should be a NumPy array containing
                the audio data for the current frame.

        Returns:
            bool: True if voice activity is detected in the current frame, False otherwise. When returning False
                after detecting no voice activity for more than three consecutive frames, it also triggers the
                end-of-speech (EOS) flag for the client.
        """
        if not self.vad_detector(frame_np):
            self.no_voice_activity_chunks += 1
            if self.no_voice_activity_chunks > 3:
                client = self.client_manager.get_client(websocket)
                if not client.eos:
                    client.set_eos(True)
                time.sleep(0.1)  # Sleep 100 ms; wait for some voice activity.
            return False
        return True

    def cleanup(self, websocket):
        """
        Cleans up resources associated with a given client's websocket.

        Args:
            websocket: The websocket associated with the client to be cleaned up.
        """
        if self.client_manager.get_client(websocket):
            self.client_manager.remove_client(websocket)


class ServeClientBase(object):
    RATE = 16000
    SERVER_READY = "SERVER_READY"
    DISCONNECT = "DISCONNECT"

    def __init__(self, client_uid, websocket):
        self.client_uid = client_uid
        self.websocket = websocket
        self.frames = b""
        self.timestamp_offset = 0.0
        self.frames_np = None
        self.frames_offset = 0.0
        self.text = []
        self.current_out = ''
        self.prev_out = ''
        self.t_start = None
        self.exit = False
        self.same_output_count = 0
        self.show_prev_out_thresh = 5   # if pause (no output from whisper), show previous output for 5 seconds
        self.add_pause_thresh = 3       # add a blank to the segment list as a pause (no speech) after 3 seconds
        self.transcript = []
        self.send_last_n_segments = 10

        # text formatting
        self.pick_previous_segments = 2

        # threading
        self.lock = threading.Lock()

    def speech_to_text(self):
        raise NotImplementedError

    def transcribe_audio(self):
        raise NotImplementedError

    def handle_transcription_output(self):
        raise NotImplementedError

    def add_frames(self, frame_np):
        """
        Add audio frames to the ongoing audio stream buffer.

        This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
        of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
        to prevent excessive memory usage.

        If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
        of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided
        audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.

        Args:
            frame_np (numpy.ndarray): The audio frame data as a NumPy array.
        """
        self.lock.acquire()
        if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
            self.frames_offset += 30.0
            self.frames_np = self.frames_np[int(30 * self.RATE):]
            # check timestamp offset (it should be >= self.frames_offset);
            # this basically means that there is no speech, as the timestamp offset hasn't updated
            # and is less than frames_offset
            if self.timestamp_offset < self.frames_offset:
                self.timestamp_offset = self.frames_offset
        if self.frames_np is None:
            self.frames_np = frame_np.copy()
        else:
            self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
        self.lock.release()
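
    # Buffer bookkeeping sketch: at RATE = 16000 the 45-second cap corresponds to
    # 45 * 16000 = 720,000 samples; when it is exceeded, the oldest 30 seconds
    # (480,000 samples) are dropped and frames_offset advances by 30.0, so the
    # timestamp arithmetic in the methods below stays aligned with the trimmed buffer.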

    def clip_audio_if_no_valid_segment(self):
        """
        Update the timestamp offset based on audio buffer status.
        Clip audio if the current unprocessed chunk exceeds 25 seconds; this basically implies that
        whisper produced no valid segment for that audio, so everything but the last 5 seconds is skipped.
        """
        with self.lock:
            if self.frames_np[int((self.timestamp_offset - self.frames_offset) * self.RATE):].shape[0] > 25 * self.RATE:
                duration = self.frames_np.shape[0] / self.RATE
                self.timestamp_offset = self.frames_offset + duration - 5

    def get_audio_chunk_for_processing(self):
        """
        Retrieves the next chunk of audio data for processing based on the current offsets.

        Calculates which part of the audio data should be processed next, based on
        the difference between the current timestamp offset and the frame's offset, scaled by
        the audio sample rate (RATE). It then returns this chunk of audio data along with its
        duration in seconds.

        Returns:
            tuple: A tuple containing:
                - input_bytes (np.ndarray): The next chunk of audio data to be processed.
                - duration (float): The duration of the audio chunk in seconds.
        """
        with self.lock:
            samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
            input_bytes = self.frames_np[int(samples_take):].copy()
        duration = input_bytes.shape[0] / self.RATE
        return input_bytes, duration

    def prepare_segments(self, last_segment=None):
        """
        Prepares the segments of transcribed text to be sent to the client.

        This method compiles the recent segments of transcribed text, ensuring that only the
        specified number of the most recent segments are included. It also appends the most
        recent segment of text if provided (which is considered incomplete because of the possibility
        of the last word being truncated in the audio chunk).

        Args:
            last_segment (str, optional): The most recent segment of transcribed text to be added
                to the list of segments. Defaults to None.

        Returns:
            list: A list of transcribed text segments to be sent to the client.
        """
        segments = []
        if len(self.transcript) >= self.send_last_n_segments:
            segments = self.transcript[-self.send_last_n_segments:].copy()
        else:
            segments = self.transcript.copy()
        if last_segment is not None:
            segments = segments + [last_segment]
        return segments

    def get_audio_chunk_duration(self, input_bytes):
        """
        Calculates the duration of the provided audio chunk.

        Args:
            input_bytes (numpy.ndarray): The audio chunk for which to calculate the duration.

        Returns:
            float: The duration of the audio chunk in seconds.
        """
        return input_bytes.shape[0] / self.RATE

    def send_transcription_to_client(self, segments):
        """
        Sends the specified transcription segments to the client over the websocket connection.

        This method formats the transcription segments into a JSON object and attempts to send
        this object to the client. If an error occurs during the send operation, it logs the error.

        Args:
            segments (list): A list of transcription segments to be sent to the client.
        """
        try:
            self.websocket.send(
                json.dumps({
                    "uid": self.client_uid,
                    "segments": segments,
                })
            )
        except Exception as e:
            logging.error(f"[ERROR]: Sending data to client: {e}")

    def disconnect(self):
        """
        Notify the client of disconnection and send a disconnect message.

        This method sends a disconnect message to the client via the WebSocket connection to notify them
        that the transcription service is disconnecting gracefully.
        """
        self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.DISCONNECT
        }))

    def cleanup(self):
        """
        Perform cleanup tasks before exiting the transcription service.

        This method performs necessary cleanup tasks, including stopping the transcription thread, marking
        the exit flag to indicate the transcription thread should exit gracefully, and destroying resources
        associated with the transcription process.
        """
        logging.info("Cleaning up.")
        self.exit = True


class ServeClientTensorRT(ServeClientBase):

    SINGLE_MODEL = None
    SINGLE_MODEL_LOCK = threading.Lock()

    def __init__(self, websocket, task="transcribe", multilingual=False, language=None, client_uid=None, model=None, single_model=False):
        """
        Initialize a ServeClient instance.
        The Whisper model is initialized based on the client's language and device availability.
        The transcription thread is started upon initialization. A "SERVER_READY" message is sent
        to the client to indicate that the server is ready.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.
            task (str, optional): The task type, e.g., "transcribe". Defaults to "transcribe".
            multilingual (bool, optional): Whether the client supports multilingual transcription. Defaults to False.
            language (str, optional): The language for transcription. Defaults to None.
            client_uid (str, optional): A unique identifier for the client. Defaults to None.
            model (str, optional): Path to the TensorRT Whisper model. Defaults to None.
            single_model (bool, optional): Whether to share a single model instance across all client
                connections instead of creating one per client. Defaults to False.
        """
        super().__init__(client_uid, websocket)
        self.language = language if multilingual else "en"
        self.task = task
        self.eos = False

        if single_model:
            if ServeClientTensorRT.SINGLE_MODEL is None:
                self.create_model(model, multilingual)
                ServeClientTensorRT.SINGLE_MODEL = self.transcriber
            else:
                self.transcriber = ServeClientTensorRT.SINGLE_MODEL
        else:
            self.create_model(model, multilingual)

        # threading
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()

        self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.SERVER_READY,
            "backend": "tensorrt"
        }))

    def create_model(self, model, multilingual, warmup=True):
        """
        Instantiates a new model, sets it as the transcriber and does warmup if desired.
        """
        self.transcriber = WhisperTRTLLM(
            model,
            assets_dir="assets",
            device="cuda",
            is_multilingual=multilingual,
            language=self.language,
            task=self.task
        )
        if warmup:
            self.warmup()

    def warmup(self, warmup_steps=10):
        """
        Warm up TensorRT since the first few inferences are slow.

        Args:
            warmup_steps (int): Number of steps to warm up the model for.
        """
        logging.info("[INFO:] Warming up TensorRT engine..")
        mel, _ = self.transcriber.log_mel_spectrogram("assets/jfk.flac")
        for _ in range(warmup_steps):
            self.transcriber.transcribe(mel)

    def set_eos(self, eos):
        """
        Sets the End of Speech (EOS) flag.

        Args:
            eos (bool): The value to set for the EOS flag.
        """
        with self.lock:
            self.eos = eos

    def handle_transcription_output(self, last_segment, duration):
        """
        Handle the transcription output, updating the transcript and sending data to the client.

        Args:
            last_segment (str): The last segment from the whisper output, which is considered incomplete
                because the last word may be truncated in the audio chunk.
            duration (float): Duration of the transcribed audio chunk.
        """
        segments = self.prepare_segments({"text": last_segment})
        self.send_transcription_to_client(segments)
        if self.eos:
            self.update_timestamp_offset(last_segment, duration)

    def transcribe_audio(self, input_bytes):
        """
        Transcribe the audio chunk and send the results to the client.

        Args:
            input_bytes (np.array): The audio chunk to transcribe.
        """
        if ServeClientTensorRT.SINGLE_MODEL:
            ServeClientTensorRT.SINGLE_MODEL_LOCK.acquire()
        logging.info(f"[WhisperTensorRT:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
        mel, duration = self.transcriber.log_mel_spectrogram(input_bytes)
        last_segment = self.transcriber.transcribe(
            mel,
            text_prefix=f"<|startoftranscript|><|{self.language}|><|{self.task}|><|notimestamps|>"
        )
        if ServeClientTensorRT.SINGLE_MODEL:
            ServeClientTensorRT.SINGLE_MODEL_LOCK.release()
        if last_segment:
            self.handle_transcription_output(last_segment, duration)

    def update_timestamp_offset(self, last_segment, duration):
        """
        Update timestamp offset and transcript.

        Args:
            last_segment (str): Last transcribed audio from the whisper model.
            duration (float): Duration of the last audio chunk.
        """
        if not len(self.transcript):
            self.transcript.append({"text": last_segment + " "})
        elif self.transcript[-1]["text"].strip() != last_segment:
            self.transcript.append({"text": last_segment + " "})

        with self.lock:
            self.timestamp_offset += duration

    def speech_to_text(self):
        """
        Process an audio stream in an infinite loop, continuously transcribing the speech.

        This method continuously receives audio frames, performs real-time transcription, and sends
        transcribed segments to the client via a WebSocket connection.

        If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction.
        It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments
        are sent to the client in real-time, and a history of segments is maintained to provide context. Pauses in speech
        (no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if
        there is no speech for a specified duration to indicate a pause.

        Raises:
            Exception: If there is an issue with audio processing or WebSocket communication.
        """
        while True:
            if self.exit:
                logging.info("Exiting speech to text thread")
                break

            if self.frames_np is None:
                time.sleep(0.02)  # wait for any audio to arrive
                continue

            self.clip_audio_if_no_valid_segment()

            input_bytes, duration = self.get_audio_chunk_for_processing()
            if duration < 0.4:
                continue

            try:
                input_sample = input_bytes.copy()
                logging.info(f"[WhisperTensorRT:] Processing audio with duration: {duration}")
                self.transcribe_audio(input_sample)

            except Exception as e:
                logging.error(f"[ERROR]: {e}")


class ServeClientFasterWhisper(ServeClientBase):

    SINGLE_MODEL = None
    SINGLE_MODEL_LOCK = threading.Lock()

    def __init__(self, websocket, task="transcribe", device=None, language=None, client_uid=None, model="small.en",
                 initial_prompt=None, vad_parameters=None, use_vad=True, single_model=False):
        """
        Initialize a ServeClient instance.
        The Whisper model is initialized based on the client's language and device availability.
        The transcription thread is started upon initialization. A "SERVER_READY" message is sent
        to the client to indicate that the server is ready.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.
            task (str, optional): The task type, e.g., "transcribe". Defaults to "transcribe".
            device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None.
            language (str, optional): The language for transcription. Defaults to None.
            client_uid (str, optional): A unique identifier for the client. Defaults to None.
            model (str, optional): The whisper model size. Defaults to 'small.en'.
            initial_prompt (str, optional): Prompt for whisper inference. Defaults to None.
            vad_parameters (dict, optional): Parameters for the voice activity detection filter. Defaults to None.
            use_vad (bool, optional): Whether to apply the VAD filter during transcription. Defaults to True.
            single_model (bool, optional): Whether to share a single model instance across all client
                connections instead of creating one per client. Defaults to False.
        """
        super().__init__(client_uid, websocket)
        self.model_sizes = [
            "tiny", "tiny.en", "base", "base.en", "small", "small.en",
            "medium", "medium.en", "large-v2", "large-v3", "distil-small.en",
            "distil-medium.en", "distil-large-v2", "distil-large-v3",
            "large-v3-turbo", "turbo"
        ]

        if not os.path.exists(model):
            self.model_size_or_path = self.check_valid_model(model)
        else:
            self.model_size_or_path = model
        self.language = "en" if self.model_size_or_path and self.model_size_or_path.endswith("en") else language
        self.task = task
        self.initial_prompt = initial_prompt
        self.vad_parameters = vad_parameters or {"onset": 0.5}
        self.no_speech_thresh = 0.45
        self.same_output_threshold = 10
        self.end_time_for_same_output = None

        device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cuda":
            major, _ = torch.cuda.get_device_capability(device)
            self.compute_type = "float16" if major >= 7 else "float32"
        else:
            self.compute_type = "int8"

        if self.model_size_or_path is None:
            return
        logging.info(f"Using Device={device} with precision {self.compute_type}")

        try:
            if single_model:
                if ServeClientFasterWhisper.SINGLE_MODEL is None:
                    self.create_model(device)
                    ServeClientFasterWhisper.SINGLE_MODEL = self.transcriber
                else:
                    self.transcriber = ServeClientFasterWhisper.SINGLE_MODEL
            else:
                self.create_model(device)
        except Exception as e:
            logging.error(f"Failed to load model: {e}")
            self.websocket.send(json.dumps({
                "uid": self.client_uid,
                "status": "ERROR",
                "message": f"Failed to load model: {str(self.model_size_or_path)}"
            }))
            self.websocket.close()
            return

        self.use_vad = use_vad

        # threading
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()
        self.websocket.send(
            json.dumps(
                {
                    "uid": self.client_uid,
                    "message": self.SERVER_READY,
                    "backend": "faster_whisper"
                }
            )
        )

    def create_model(self, device):
        """
        Instantiates a new model and sets it as the transcriber.
        """
        self.transcriber = WhisperModel(
            self.model_size_or_path,
            device=device,
            compute_type=self.compute_type,
            local_files_only=False,
        )

    def check_valid_model(self, model_size):
        """
        Check if it's a valid whisper model size.

        Args:
            model_size (str): The name of the model size to check.

        Returns:
            str: The model size if valid, None otherwise.
        """
        if model_size not in self.model_sizes:
            self.websocket.send(
                json.dumps(
                    {
                        "uid": self.client_uid,
                        "status": "ERROR",
                        "message": f"Invalid model size {model_size}. Available choices: {self.model_sizes}"
                    }
                )
            )
            return None
        return model_size

    def set_language(self, info):
        """
        Updates the language attribute based on the detected language information.

        Args:
            info (object): An object containing the detected language and its probability. This object
                must have at least two attributes: `language`, a string indicating the detected
                language, and `language_probability`, a float representing the confidence level
                of the language detection.
        """
        if info.language_probability > 0.5:
            self.language = info.language
            logging.info(f"Detected language {self.language} with probability {info.language_probability}")
            self.websocket.send(json.dumps(
                {"uid": self.client_uid, "language": self.language, "language_prob": info.language_probability}))

    def transcribe_audio(self, input_sample):
        """
        Transcribes the provided audio sample using the configured transcriber instance.

        If the language has not been set, it updates the session's language based on the transcription
        information.

        Args:
            input_sample (np.array): The audio chunk to be transcribed. This should be a NumPy
                array representing the audio data.

        Returns:
            The transcription result from the transcriber. The exact format of this result
            depends on the implementation of the `transcriber.transcribe` method but typically
            includes the transcribed text.
        """
        if ServeClientFasterWhisper.SINGLE_MODEL:
            ServeClientFasterWhisper.SINGLE_MODEL_LOCK.acquire()
        result, info = self.transcriber.transcribe(
            input_sample,
            initial_prompt=self.initial_prompt,
            language=self.language,
            task=self.task,
            vad_filter=self.use_vad,
            vad_parameters=self.vad_parameters if self.use_vad else None)
        if ServeClientFasterWhisper.SINGLE_MODEL:
            ServeClientFasterWhisper.SINGLE_MODEL_LOCK.release()

        if self.language is None and info is not None:
            self.set_language(info)
        return result

    def get_previous_output(self):
        """
        Retrieves previously generated transcription outputs if no new transcription is available
        from the current audio chunks.

        Checks the time since the last transcription output and, if it is within a specified
        threshold, returns the most recent segments of transcribed text. It also manages
        adding a pause (blank segment) to indicate a significant gap in speech based on a defined
        threshold.

        Returns:
            segments (list): A list of transcription segments. This may include the most recent
                transcribed text segments or a blank segment to indicate a pause
                in speech.
        """
        segments = []
        if self.t_start is None:
            self.t_start = time.time()
        if time.time() - self.t_start < self.show_prev_out_thresh:
            segments = self.prepare_segments()

        # add a blank if there is no speech for 3 seconds
        if len(self.text) and self.text[-1] != '':
            if time.time() - self.t_start > self.add_pause_thresh:
                self.text.append('')
        return segments

    def handle_transcription_output(self, result, duration):
        """
        Handle the transcription output, updating the transcript and sending data to the client.

        Args:
            result (list): The result from whisper inference, i.e. the list of segments.
            duration (float): Duration of the transcribed audio chunk.
        """
        segments = []
        if len(result):
            self.t_start = None
            last_segment = self.update_segments(result, duration)
            segments = self.prepare_segments(last_segment)
        else:
            # show previous output if there is a pause, i.e. no output from whisper
            segments = self.get_previous_output()

        if len(segments):
            self.send_transcription_to_client(segments)

    def speech_to_text(self):
        """
        Process an audio stream in an infinite loop, continuously transcribing the speech.

        This method continuously receives audio frames, performs real-time transcription, and sends
        transcribed segments to the client via a WebSocket connection.

        If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction.
        It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments
        are sent to the client in real-time, and a history of segments is maintained to provide context. Pauses in speech
        (no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if
        there is no speech for a specified duration to indicate a pause.

        Raises:
            Exception: If there is an issue with audio processing or WebSocket communication.
        """
        while True:
            if self.exit:
                logging.info("Exiting speech to text thread")
                break

            if self.frames_np is None:
                continue

            self.clip_audio_if_no_valid_segment()

            input_bytes, duration = self.get_audio_chunk_for_processing()
            if duration < 1.0:
                time.sleep(0.1)  # wait for audio chunks to arrive
                continue
            try:
                input_sample = input_bytes.copy()
                result = self.transcribe_audio(input_sample)

                if result is None or self.language is None:
                    self.timestamp_offset += duration
                    time.sleep(0.25)  # wait for voice activity; result is None when there is no voice activity
                    continue
                self.handle_transcription_output(result, duration)

            except Exception as e:
                logging.error(f"[ERROR]: Failed to transcribe audio chunk: {e}")
                time.sleep(0.01)

    def format_segment(self, start, end, text, completed=False):
        """
        Formats a transcription segment with precise start and end times alongside the transcribed text.

        Args:
            start (float): The start time of the transcription segment in seconds.
            end (float): The end time of the transcription segment in seconds.
            text (str): The transcribed text corresponding to the segment.
            completed (bool, optional): Whether the segment is considered final. Defaults to False.

        Returns:
            dict: A dictionary representing the formatted transcription segment, including
                'start' and 'end' times as strings with three decimal places, the 'text'
                of the transcription, and the 'completed' flag.
        """
        return {
            'start': "{:.3f}".format(start),
            'end': "{:.3f}".format(end),
            'text': text,
            'completed': completed
        }
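
    # For example, format_segment(3.0, 4.25, "hello world") returns:
    #     {'start': '3.000', 'end': '4.250', 'text': 'hello world', 'completed': False}
    # which is the shape of each entry in self.transcript and of the segments sent to the client.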

    def update_segments(self, segments, duration):
        """
        Processes the segments from whisper. Appends all the segments to the list
        except for the last segment, assuming that it is incomplete.

        Updates the ongoing transcript with transcribed segments, including their start and end times.
        Complete segments are appended to the transcript in chronological order. Incomplete segments
        (assumed to be the last one) are processed to identify repeated content. If the same incomplete
        segment is seen multiple times, it updates the offset and appends the segment to the transcript.
        A threshold is used to detect repeated content and ensure it is only included once in the transcript.
        The timestamp offset is updated based on the duration of processed segments. The method returns the
        last processed segment, allowing it to be sent to the client for real-time updates.

        Args:
            segments (list): list of segments as returned by whisper
            duration (float): duration of the current chunk

        Returns:
            dict or None: The last processed segment with its start time, end time, and transcribed text.
                Returns None if there are no valid segments to process.
        """
        offset = None
        self.current_out = ''
        last_segment = None

        # process complete segments
        if len(segments) > 1 and segments[-1].no_speech_prob <= self.no_speech_thresh:
            for i, s in enumerate(segments[:-1]):
                text_ = s.text
                self.text.append(text_)
                with self.lock:
                    start, end = self.timestamp_offset + s.start, self.timestamp_offset + min(duration, s.end)

                if start >= end:
                    continue
                if s.no_speech_prob > self.no_speech_thresh:
                    continue

                self.transcript.append(self.format_segment(start, end, text_, completed=True))
                offset = min(duration, s.end)

        # only process the last segment if it satisfies the no_speech_thresh
        if segments[-1].no_speech_prob <= self.no_speech_thresh:
            self.current_out += segments[-1].text
            with self.lock:
                last_segment = self.format_segment(
                    self.timestamp_offset + segments[-1].start,
                    self.timestamp_offset + min(duration, segments[-1].end),
                    self.current_out,
                    completed=False
                )

        if self.current_out.strip() == self.prev_out.strip() and self.current_out != '':
            self.same_output_count += 1

            # if we remove the audio because of the same output on the nth repetition we might remove
            # audio that's not yet transcribed, so capture the time when it was repeated for the first time
            if self.end_time_for_same_output is None:
                self.end_time_for_same_output = segments[-1].end
            time.sleep(0.1)  # wait for some voice activity in case there is an unintended pause from the speaker, for better punctuation.
        else:
            self.same_output_count = 0
            self.end_time_for_same_output = None

        # if the same incomplete segment is seen multiple times then update the offset
        # and append the segment to the list
        if self.same_output_count > self.same_output_threshold:
            if not len(self.text) or self.text[-1].strip().lower() != self.current_out.strip().lower():
                self.text.append(self.current_out)
                with self.lock:
                    self.transcript.append(self.format_segment(
                        self.timestamp_offset,
                        self.timestamp_offset + min(duration, self.end_time_for_same_output),
                        self.current_out,
                        completed=True
                    ))
            self.current_out = ''
            offset = min(duration, self.end_time_for_same_output)
            self.same_output_count = 0
            last_segment = None
            self.end_time_for_same_output = None
        else:
            self.prev_out = self.current_out

        # update offset
        if offset is not None:
            with self.lock:
                self.timestamp_offset += offset

        return last_segment
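

# Illustrative launch sketch (an assumption, not shipped entry-point code): the
# host, fallback port, and backend below are example values; adjust as needed.
if __name__ == "__main__":
    server = TranscriptionServer()
    server.run(
        "0.0.0.0",
        port=int(os.getenv("PORT_WHISPERLIVE", "9090")),
        backend="faster_whisper",
    )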