import os
import time
import threading
import json
import functools
import logging
from enum import Enum
from typing import List, Optional

import torch
import numpy as np
from websockets.sync.server import serve
from websockets.exceptions import ConnectionClosed

from whisper_live.vad import VoiceActivityDetector
from whisper_live.transcriber import WhisperModel

try:
    from whisper_live.transcriber_tensorrt import WhisperTRTLLM
except Exception:
    # TensorRT-LLM is optional. If this import fails, ServeClientTensorRT.create_model raises
    # NameError, which initialize_client catches before falling back to faster_whisper.
    pass

logging.basicConfig(level=logging.INFO)


class ClientManager:
    """Tracks connected clients and enforces connection-count and connection-duration limits."""

    def __init__(self, max_clients=4, max_connection_time=600):
        """
        Initializes the ClientManager with specified limits on client connections and connection durations.

        Args:
            max_clients (int, optional): The maximum number of simultaneous client connections allowed.
                Defaults to 4.
            max_connection_time (int, optional): The maximum duration (in seconds) a client can stay
                connected. Defaults to 600 seconds (10 minutes).
        """
        self.clients = {}
        self.start_times = {}
        self.max_clients = max_clients
        self.max_connection_time = max_connection_time

    def add_client(self, websocket, client):
        """
        Adds a client and their connection start time to the tracking dictionaries.

        Args:
            websocket: The websocket associated with the client to add.
            client: The client object to be added and tracked.
        """
        self.clients[websocket] = client
        self.start_times[websocket] = time.time()

    def get_client(self, websocket):
        """
        Retrieves a client associated with the given websocket.

        Args:
            websocket: The websocket associated with the client to retrieve.

        Returns:
            The client object if found, False otherwise.
        """
        if websocket in self.clients:
            return self.clients[websocket]
        return False

    def remove_client(self, websocket):
        """
        Removes a client and their connection start time from the tracking dictionaries.
        Calls ``cleanup()`` on the client if one was registered.

        Args:
            websocket: The websocket associated with the client to be removed.
        """
        client = self.clients.pop(websocket, None)
        if client:
            client.cleanup()
        self.start_times.pop(websocket, None)

    def get_wait_time(self):
        """
        Estimates how long a new client must wait for a slot, i.e. the smallest remaining
        connection time among the currently connected clients.

        Returns:
            The estimated wait time in minutes. Returns 0 if no clients are connected.
        """
        wait_time = None
        for start_time in self.start_times.values():
            current_client_time_remaining = self.max_connection_time - (time.time() - start_time)
            if wait_time is None or current_client_time_remaining < wait_time:
                wait_time = current_client_time_remaining
        return wait_time / 60 if wait_time is not None else 0

    def is_server_full(self, websocket, options):
        """
        Checks if the server is at its maximum client capacity and sends a wait message to the
        client if so.

        Args:
            websocket: The websocket of the client attempting to connect.
            options (dict): Client options; must contain the client's "uid".

        Returns:
            True if the server is full, False otherwise.
        """
        if len(self.clients) >= self.max_clients:
            wait_time = self.get_wait_time()
            response = {"uid": options["uid"], "status": "WAIT", "message": wait_time}
            websocket.send(json.dumps(response))
            return True
        return False

    def is_client_timeout(self, websocket):
        """
        Checks if a client has exceeded the maximum allowed connection time and, if so, sends it
        a disconnect message and logs a warning.

        Args:
            websocket: The websocket associated with the client to check. Must be registered via
                ``add_client`` (the start time is looked up directly).

        Returns:
            True if the client's connection time has exceeded the maximum limit, False otherwise.
        """
        elapsed_time = time.time() - self.start_times[websocket]
        if elapsed_time >= self.max_connection_time:
            self.clients[websocket].disconnect()
            logging.warning(f"Client with uid '{self.clients[websocket].client_uid}' disconnected due to overtime.")
            return True
        return False


class BackendType(Enum):
    """ASR backends the server can run."""

    FASTER_WHISPER = "faster_whisper"
    TENSORRT = "tensorrt"

    @staticmethod
    def valid_types() -> List[str]:
        """Return the string values of all backend types."""
        return [backend_type.value for backend_type in BackendType]

    @staticmethod
    def is_valid(backend: str) -> bool:
        """Return True if ``backend`` is the string value of a known backend type."""
        return backend in BackendType.valid_types()

    def is_faster_whisper(self) -> bool:
        return self == BackendType.FASTER_WHISPER

    def is_tensorrt(self) -> bool:
        return self == BackendType.TENSORRT


class TranscriptionServer:
    """WebSocket server that hands each connection's audio to a per-client ASR worker."""

    RATE = 16000

    def __init__(self):
        self.client_manager = None
        # NOTE(review): this counter lives on the server, so all TensorRT clients' VAD
        # streams share it rather than each client having its own.
        self.no_voice_activity_chunks = 0
        self.use_vad = True
        self.single_model = False

    def initialize_client(
        self, websocket, options, faster_whisper_custom_model_path,
        whisper_tensorrt_path, trt_multilingual
    ):
        """
        Creates the backend-specific client for a new connection and registers it.

        For the TensorRT backend a ServeClientTensorRT is tried first. If that raises, the client
        is sent a WARNING and ``self.backend`` is switched to faster_whisper; the switch persists
        for later connections. For faster_whisper, an existing custom model path overrides
        ``options["model"]``. If the faster_whisper client cannot be constructed, the error is
        logged and no client is registered.

        Args:
            websocket: The client's websocket connection.
            options (dict): Client options; reads "uid", "language", "task", "model" and,
                optionally, "initial_prompt" and "vad_parameters".
            faster_whisper_custom_model_path (str or None): Optional custom faster_whisper model path.
            whisper_tensorrt_path (str or None): TensorRT model path (TensorRT backend only).
            trt_multilingual (bool): Whether the TensorRT model is multilingual.

        Raises:
            ValueError: If no client was created for the configured backend.
        """
        client: Optional[ServeClientBase] = None

        if self.backend.is_tensorrt():
            try:
                client = ServeClientTensorRT(
                    websocket,
                    multilingual=trt_multilingual,
                    language=options["language"],
                    task=options["task"],
                    client_uid=options["uid"],
                    model=whisper_tensorrt_path,
                    single_model=self.single_model,
                )
                logging.info("Running TensorRT backend.")
            except Exception as e:
                logging.error(f"TensorRT-LLM not supported: {e}")
                self.client_uid = options["uid"]
                websocket.send(json.dumps({
                    "uid": self.client_uid,
                    "status": "WARNING",
                    "message": "TensorRT-LLM not supported on Server yet. "
                               "Reverting to available backend: 'faster_whisper'"
                }))
                self.backend = BackendType.FASTER_WHISPER

        try:
            if self.backend.is_faster_whisper():
                if faster_whisper_custom_model_path is not None and os.path.exists(faster_whisper_custom_model_path):
                    logging.info(f"Using custom model {faster_whisper_custom_model_path}")
                    options["model"] = faster_whisper_custom_model_path
                client = ServeClientFasterWhisper(
                    websocket,
                    language=options["language"],
                    task=options["task"],
                    client_uid=options["uid"],
                    model=options["model"],
                    initial_prompt=options.get("initial_prompt"),
                    vad_parameters=options.get("vad_parameters"),
                    use_vad=self.use_vad,
                    single_model=self.single_model,
                )
                logging.info("Running faster_whisper backend.")
        except Exception as e:
            # Log instead of swallowing silently; handle_new_connection notices that no client
            # was registered and ends the connection.
            logging.error(f"Failed to initialize faster_whisper client: {e}")
            return

        if client is None:
            raise ValueError(f"Backend type {self.backend.value} not recognised or not handled.")

        self.client_manager.add_client(websocket, client)

    def get_audio_from_websocket(self, websocket):
        """
        Receives one audio buffer from the websocket and converts it to a numpy array.

        Args:
            websocket: The websocket to receive audio from.

        Returns:
            A float32 numpy array view of the received bytes, or False if the client sent the
            b"END_OF_AUDIO" marker.
        """
        frame_data = websocket.recv()
        if frame_data == b"END_OF_AUDIO":
            return False
        return np.frombuffer(frame_data, dtype=np.float32)

    def handle_new_connection(self, websocket, faster_whisper_custom_model_path,
                              whisper_tensorrt_path, trt_multilingual):
        """
        Reads the client's JSON options, enforces capacity, and creates the client.

        The ClientManager is created lazily from the first client's "max_clients" and
        "max_connection_time" options.

        Args:
            websocket: The newly connected websocket.
            faster_whisper_custom_model_path (str or None): Optional custom faster_whisper model path.
            whisper_tensorrt_path (str or None): TensorRT model path.
            trt_multilingual (bool): Whether the TensorRT model is multilingual.

        Returns:
            True if a client was registered and audio should be received, False otherwise.
        """
        try:
            logging.info("New client connected")
            options = websocket.recv()
            options = json.loads(options)

            if self.client_manager is None:
                max_clients = options.get('max_clients', 4)
                max_connection_time = options.get('max_connection_time', 600)
                self.client_manager = ClientManager(max_clients, max_connection_time)

            self.use_vad = options.get('use_vad')
            if self.client_manager.is_server_full(websocket, options):
                websocket.close()
                return False  # Indicates that the connection should not continue

            if self.backend.is_tensorrt():
                self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)
            self.initialize_client(websocket, options, faster_whisper_custom_model_path,
                                   whisper_tensorrt_path, trt_multilingual)
            if not self.client_manager.get_client(websocket):
                # Without a registered client the receive loop would fail on the missing
                # start-time entry in is_client_timeout.
                websocket.close()
                return False
            return True
        except json.JSONDecodeError:
            logging.error("Failed to decode JSON from client")
            return False
        except ConnectionClosed:
            logging.info("Connection closed by client")
            return False
        except Exception as e:
            logging.error(f"Error during new connection initialization: {str(e)}")
            return False

    def process_audio_frames(self, websocket):
        """
        Receives one audio frame and forwards it to the client's buffer.

        For the TensorRT backend, server-side VAD updates the client's end-of-speech flag, and
        non-speech frames are dropped when ``use_vad`` is enabled.

        Args:
            websocket: The client's websocket.

        Returns:
            False when the client signalled end of audio, True otherwise.
        """
        frame_np = self.get_audio_from_websocket(websocket)
        client = self.client_manager.get_client(websocket)
        if frame_np is False:
            if self.backend.is_tensorrt():
                client.set_eos(True)
            return False

        if self.backend.is_tensorrt():
            voice_active = self.voice_activity(websocket, frame_np)
            if voice_active:
                self.no_voice_activity_chunks = 0
                client.set_eos(False)
            if self.use_vad and not voice_active:
                return True

        client.add_frames(frame_np)
        return True

    def recv_audio(self,
                   websocket,
                   backend: BackendType = BackendType.FASTER_WHISPER,
                   faster_whisper_custom_model_path=None,
                   whisper_tensorrt_path=None,
                   trt_multilingual=False):
        """
        Receives audio chunks from a client until it disconnects, signals end of audio, or
        exceeds the maximum connection time.

        Per-connection flow:
        - The connection is first set up by ``handle_new_connection``. If the server is full,
          that step sends the client a "WAIT" status and the connection ends.
        - Frames are then read in a loop by ``process_audio_frames``. For the TensorRT backend
          they pass through server-side VAD before being added to the client's buffer.
        - On exit, the client's resources are cleaned up and the websocket is closed.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.
            backend (BackendType): The backend to run the server with.
            faster_whisper_custom_model_path (str): path to custom faster whisper model.
            whisper_tensorrt_path (str): Required for tensorrt backend.
            trt_multilingual(bool): Only used for tensorrt, True if multilingual model.
        """
        self.backend = backend
        if not self.handle_new_connection(websocket, faster_whisper_custom_model_path,
                                          whisper_tensorrt_path, trt_multilingual):
            return

        try:
            while not self.client_manager.is_client_timeout(websocket):
                if not self.process_audio_frames(websocket):
                    break
        except ConnectionClosed:
            logging.info("Connection closed by client")
        except Exception as e:
            logging.error(f"Unexpected error: {str(e)}")
        finally:
            if self.client_manager.get_client(websocket):
                self.cleanup(websocket)
                websocket.close()

    def run(self,
            host,
            port=None,
            backend="tensorrt",
            faster_whisper_custom_model_path=None,
            whisper_tensorrt_path=None,
            trt_multilingual=False,
            single_model=False,
            ssl_context=None):
        """
        Run the transcription server and block serving connections forever.

        Args:
            host (str): The host address to bind the server.
            port (int, optional): The port number to bind the server. If None, it is read from the
                PORT_WHISPERLIVE environment variable when ``run`` is called (not at import time).
            backend (str): One of ``BackendType.valid_types()``. Defaults to "tensorrt".
            faster_whisper_custom_model_path (str, optional): Custom faster_whisper model path.
            whisper_tensorrt_path (str, optional): TensorRT model path.
            trt_multilingual (bool): Whether the TensorRT model is multilingual.
            single_model (bool): Share one model across clients; only honoured with a custom model.
            ssl_context: Optional SSL context passed through to ``serve``.

        Raises:
            ValueError: If no port is available, a model path does not exist, or the backend is invalid.
        """
        if port is None:
            env_port = os.getenv('PORT_WHISPERLIVE')
            if env_port is None:
                raise ValueError("No port given and the PORT_WHISPERLIVE environment variable is not set.")
            port = int(env_port)

        if faster_whisper_custom_model_path is not None and not os.path.exists(faster_whisper_custom_model_path):
            raise ValueError(f"Custom faster_whisper model '{faster_whisper_custom_model_path}' is not a valid path.")
        if whisper_tensorrt_path is not None and not os.path.exists(whisper_tensorrt_path):
            raise ValueError(f"TensorRT model '{whisper_tensorrt_path}' is not a valid path.")
        if single_model:
            if faster_whisper_custom_model_path or whisper_tensorrt_path:
                logging.info("Custom model option was provided. Switching to single model mode.")
                self.single_model = True
                # TODO: load model initially
            else:
                logging.info("Single model mode currently only works with custom models.")
        if not BackendType.is_valid(backend):
            raise ValueError(f"{backend} is not a valid backend type. Choose backend from {BackendType.valid_types()}")
        with serve(
            functools.partial(
                self.recv_audio,
                backend=BackendType(backend),
                faster_whisper_custom_model_path=faster_whisper_custom_model_path,
                whisper_tensorrt_path=whisper_tensorrt_path,
                trt_multilingual=trt_multilingual
            ),
            host,
            port,
            ssl_context=ssl_context
        ) as server:
            server.serve_forever()

    def voice_activity(self, websocket, frame_np):
        """
        Runs the VAD model on one frame and manages the client's end-of-speech flag.

        Each non-speech frame increments ``no_voice_activity_chunks``. Once it exceeds 3, the
        client's EOS flag is set (if not already set) and the call sleeps 100 ms. The counter is
        reset in ``process_audio_frames`` when speech is detected.

        Args:
            websocket: The websocket used to look up the client.
            frame_np (numpy.ndarray): The audio frame to analyze.

        Returns:
            bool: True if the VAD model reports voice activity, False otherwise.
        """
        if not self.vad_detector(frame_np):
            self.no_voice_activity_chunks += 1
            if self.no_voice_activity_chunks > 3:
                client = self.client_manager.get_client(websocket)
                if not client.eos:
                    client.set_eos(True)
                time.sleep(0.1)  # Sleep 100 ms; wait for some voice activity.
            return False
        return True

    def cleanup(self, websocket):
        """
        Cleans up resources associated with a given client's websocket.

        Args:
            websocket: The websocket associated with the client to be cleaned up.
        """
        if self.client_manager.get_client(websocket):
            self.client_manager.remove_client(websocket)


class ServeClientBase(object):
    """Shared audio buffering and transcript bookkeeping for per-client ASR workers."""

    RATE = 16000
    SERVER_READY = "SERVER_READY"
    DISCONNECT = "DISCONNECT"

    def __init__(self, client_uid, websocket):
        self.client_uid = client_uid
        self.websocket = websocket
        self.frames = b""
        self.timestamp_offset = 0.0
        self.frames_np = None
        self.frames_offset = 0.0
        self.text = []
        self.current_out = ''
        self.prev_out = ''
        self.t_start = None
        self.exit = False
        self.same_output_count = 0
        self.show_prev_out_thresh = 5   # if pause(no output from whisper) show previous output for 5 seconds
        self.add_pause_thresh = 3       # add a blank to segment list as a pause(no speech) for 3 seconds
        self.transcript = []
        self.send_last_n_segments = 10

        # text formatting
        self.pick_previous_segments = 2

        # threading
        self.lock = threading.Lock()

    def speech_to_text(self):
        raise NotImplementedError

    def transcribe_audio(self):
        raise NotImplementedError

    def handle_transcription_output(self):
        raise NotImplementedError

    def add_frames(self, frame_np):
        """
        Appends an audio frame to the client's buffer under the instance lock.

        If the buffer already holds more than 45 seconds of audio, the oldest 30 seconds are
        dropped and ``frames_offset`` advances by 30 s. ``timestamp_offset`` is then raised to
        at least ``frames_offset``.

        Args:
            frame_np (numpy.ndarray): The audio frame data as a NumPy array.
        """
        with self.lock:
            if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
                self.frames_offset += 30.0
                self.frames_np = self.frames_np[int(30 * self.RATE):]
                # timestamp offset should be >= frames_offset; if it lags, nothing in the dropped
                # audio was transcribed, so move it up to the new buffer start
                if self.timestamp_offset < self.frames_offset:
                    self.timestamp_offset = self.frames_offset
            if self.frames_np is None:
                self.frames_np = frame_np.copy()
            else:
                self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)

    def clip_audio_if_no_valid_segment(self):
        """
        Skips ahead when too much untranscribed audio has piled up.

        If more than 25 seconds of audio lie after ``timestamp_offset``, it is moved to 5 seconds
        before the end of the buffer. This implies whisper produced no valid segment for that
        stretch.
        """
        with self.lock:
            if self.frames_np[int((self.timestamp_offset - self.frames_offset) * self.RATE):].shape[0] > 25 * self.RATE:
                duration = self.frames_np.shape[0] / self.RATE
                self.timestamp_offset = self.frames_offset + duration - 5

    def get_audio_chunk_for_processing(self):
        """
        Returns a copy of the buffered audio from ``timestamp_offset`` onward.

        Returns:
            tuple: A tuple containing:
                - input_bytes (np.ndarray): The next chunk of audio data to be processed.
                - duration (float): The duration of the audio chunk in seconds.
        """
        with self.lock:
            samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
            input_bytes = self.frames_np[int(samples_take):].copy()
        duration = input_bytes.shape[0] / self.RATE
        return input_bytes, duration

    def prepare_segments(self, last_segment=None):
        """
        Builds the list of segments to send: the last ``send_last_n_segments`` transcript
        entries, plus ``last_segment`` (the possibly-incomplete current segment) if given.

        Args:
            last_segment (optional): The most recent segment to append. Defaults to None.

        Returns:
            list: A list of transcribed text segments to be sent to the client.
        """
        segments = self.transcript[-self.send_last_n_segments:].copy()
        if last_segment is not None:
            segments = segments + [last_segment]
        return segments

    def get_audio_chunk_duration(self, input_bytes):
        """
        Calculates the duration of the provided audio chunk.

        Args:
            input_bytes (numpy.ndarray): The audio chunk for which to calculate the duration.

        Returns:
            float: The duration of the audio chunk in seconds.
        """
        return input_bytes.shape[0] / self.RATE

    def send_transcription_to_client(self, segments):
        """
        Sends transcription segments to the client as JSON; send errors are logged, not raised.

        Args:
            segments (list): A list of transcription segments to be sent to the client.
        """
        try:
            self.websocket.send(
                json.dumps({
                    "uid": self.client_uid,
                    "segments": segments,
                })
            )
        except Exception as e:
            logging.error(f"[ERROR]: Sending data to client: {e}")

    def disconnect(self):
        """Sends the client a DISCONNECT message over the websocket."""
        self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.DISCONNECT
        }))

    def cleanup(self):
        """Sets the exit flag so the transcription thread's loop terminates."""
        logging.info("Cleaning up.")
        self.exit = True


class ServeClientTensorRT(ServeClientBase):
    """Per-client worker that transcribes with a TensorRT-LLM Whisper engine."""

    SINGLE_MODEL = None
    SINGLE_MODEL_LOCK = threading.Lock()

    def __init__(self, websocket, task="transcribe", multilingual=False, language=None,
                 client_uid=None, model=None, single_model=False):
        """
        Creates (or reuses) the TensorRT model and starts the transcription thread.
        A "SERVER_READY" message is sent to the client once the thread has started.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.
            task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe".
            multilingual (bool, optional): Whether the model is multilingual; if False the
                language is forced to "en". Defaults to False.
            language (str, optional): The language for transcription. Defaults to None.
            client_uid (str, optional): A unique identifier for the client. Defaults to None.
            model (str, optional): Path to the TensorRT Whisper model. Defaults to None.
            single_model (bool, optional): If True, share one model instance across all clients
                instead of creating one per connection. Defaults to False.
        """
        super().__init__(client_uid, websocket)
        self.language = language if multilingual else "en"
        self.task = task
        self.eos = False

        if single_model:
            if ServeClientTensorRT.SINGLE_MODEL is None:
                self.create_model(model, multilingual)
                ServeClientTensorRT.SINGLE_MODEL = self.transcriber
            else:
                self.transcriber = ServeClientTensorRT.SINGLE_MODEL
        else:
            self.create_model(model, multilingual)

        # threading
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()

        self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.SERVER_READY,
            "backend": "tensorrt"
        }))

    def create_model(self, model, multilingual, warmup=True):
        """Instantiates a WhisperTRTLLM transcriber and optionally warms it up."""
        self.transcriber = WhisperTRTLLM(
            model,
            assets_dir="assets",
            device="cuda",
            is_multilingual=multilingual,
            language=self.language,
            task=self.task
        )
        if warmup:
            self.warmup()

    def warmup(self, warmup_steps=10):
        """
        Runs a few transcriptions of "assets/jfk.flac" because the first TensorRT inferences are slow.

        Args:
            warmup_steps (int): Number of steps to warm up the model for.
        """
        logging.info("[INFO:] Warming up TensorRT engine..")
        mel, _ = self.transcriber.log_mel_spectrogram("assets/jfk.flac")
        for _ in range(warmup_steps):
            self.transcriber.transcribe(mel)

    def set_eos(self, eos):
        """
        Sets the End of Speech (EOS) flag under the instance lock.

        Args:
            eos (bool): The value to set for the EOS flag.
        """
        with self.lock:
            self.eos = eos

    def handle_transcription_output(self, last_segment, duration):
        """
        Sends the current transcript plus ``last_segment`` to the client; when EOS is set, the
        segment is committed and the timestamp offset advanced.

        Args:
            last_segment (str): The last segment from the whisper output, possibly incomplete.
            duration (float): Duration of the transcribed audio chunk.
        """
        segments = self.prepare_segments({"text": last_segment})
        self.send_transcription_to_client(segments)
        if self.eos:
            self.update_timestamp_offset(last_segment, duration)

    def transcribe_audio(self, input_bytes):
        """
        Transcribe the audio chunk and send the results to the client.

        When a shared model is in use, inference is serialized through SINGLE_MODEL_LOCK. The
        lock is released in ``finally`` so a failing inference cannot block other clients.

        Args:
            input_bytes (np.array): The audio chunk to transcribe.
        """
        # Decide once, so we never release a lock we did not acquire.
        use_shared_lock = bool(ServeClientTensorRT.SINGLE_MODEL)
        if use_shared_lock:
            ServeClientTensorRT.SINGLE_MODEL_LOCK.acquire()
        try:
            logging.info(f"[WhisperTensorRT:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
            mel, duration = self.transcriber.log_mel_spectrogram(input_bytes)
            last_segment = self.transcriber.transcribe(
                mel,
                text_prefix=f"<|startoftranscript|><|{self.language}|><|{self.task}|><|notimestamps|>"
            )
        finally:
            if use_shared_lock:
                ServeClientTensorRT.SINGLE_MODEL_LOCK.release()

        if last_segment:
            self.handle_transcription_output(last_segment, duration)

    def update_timestamp_offset(self, last_segment, duration):
        """
        Appends ``last_segment`` to the transcript unless it repeats the previous entry, then
        advances the timestamp offset by ``duration``.

        Args:
            last_segment (str): Last transcribed audio from the whisper model.
            duration (float): Duration of the last audio chunk.
        """
        if not len(self.transcript):
            self.transcript.append({"text": last_segment + " "})
        elif self.transcript[-1]["text"].strip() != last_segment:
            self.transcript.append({"text": last_segment + " "})

        with self.lock:
            self.timestamp_offset += duration

    def speech_to_text(self):
        """
        Transcription thread loop. It runs until ``exit`` is set.

        Each iteration:
        - clips stale audio;
        - takes the untranscribed chunk;
        - waits while the chunk is shorter than 0.4 s;
        - otherwise transcribes it.

        Transcription errors are logged and the loop continues.
        """
        while True:
            if self.exit:
                logging.info("Exiting speech to text thread")
                break

            if self.frames_np is None:
                time.sleep(0.02)    # wait for any audio to arrive
                continue

            self.clip_audio_if_no_valid_segment()

            input_bytes, duration = self.get_audio_chunk_for_processing()
            if duration < 0.4:
                time.sleep(0.02)    # avoid busy-spinning while waiting for more audio
                continue

            try:
                input_sample = input_bytes.copy()
                logging.info(f"[WhisperTensorRT:] Processing audio with duration: {duration}")
                self.transcribe_audio(input_sample)

            except Exception as e:
                logging.error(f"[ERROR]: {e}")


class ServeClientFasterWhisper(ServeClientBase):
    """Per-client worker that transcribes with a faster_whisper model."""

    SINGLE_MODEL = None
    SINGLE_MODEL_LOCK = threading.Lock()

    def __init__(self, websocket, task="transcribe", device=None, language=None, client_uid=None, model="small.en",
                 initial_prompt=None, vad_parameters=None, use_vad=True, single_model=False):
        """
        Validates the model, creates (or reuses) it, and starts the transcription thread.

        A "SERVER_READY" message is sent to the client once the thread has started. If the model
        name is invalid or the model fails to load, an ERROR is sent, the websocket is closed, and
        no thread is started.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.
            task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe".
            device (str, optional): "cuda" or "cpu". If None, "cuda" is used when available,
                otherwise "cpu". Defaults to None.
            language (str, optional): The language for transcription. Forced to "en" for
                models whose name ends in "en". Defaults to None.
            client_uid (str, optional): A unique identifier for the client. Defaults to None.
            model (str, optional): A whisper model size name or a local model path. Defaults to 'small.en'.
            initial_prompt (str, optional): Prompt for whisper inference. Defaults to None.
            vad_parameters (dict, optional): VAD parameters; defaults to {"onset": 0.5}.
            use_vad (bool, optional): Whether the transcriber applies its VAD filter. Defaults to True.
            single_model (bool, optional): If True, share one model instance across all clients
                instead of creating one per connection. Defaults to False.
        """
        super().__init__(client_uid, websocket)
        self.model_sizes = [
            "tiny", "tiny.en", "base", "base.en", "small", "small.en",
            "medium", "medium.en", "large-v2", "large-v3", "distil-small.en",
            "distil-medium.en", "distil-large-v2", "distil-large-v3",
            "large-v3-turbo", "turbo"
        ]
        if not os.path.exists(model):
            self.model_size_or_path = self.check_valid_model(model)
        else:
            self.model_size_or_path = model
        self.task = task
        self.initial_prompt = initial_prompt
        self.vad_parameters = vad_parameters or {"onset": 0.5}
        self.no_speech_thresh = 0.45
        self.same_output_threshold = 10
        self.end_time_for_same_output = None
        self.use_vad = use_vad

        if self.model_size_or_path is None:
            # check_valid_model already sent an ERROR; this must run before the .endswith()
            # call below, which would otherwise raise AttributeError on None.
            self.language = language
            self.websocket.close()
            return

        self.language = "en" if self.model_size_or_path.endswith("en") else language

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cuda":
            major, _ = torch.cuda.get_device_capability(device)
            self.compute_type = "float16" if major >= 7 else "float32"
        else:
            self.compute_type = "int8"

        logging.info(f"Using Device={device} with precision {self.compute_type}")

        try:
            if single_model:
                if ServeClientFasterWhisper.SINGLE_MODEL is None:
                    self.create_model(device)
                    ServeClientFasterWhisper.SINGLE_MODEL = self.transcriber
                else:
                    self.transcriber = ServeClientFasterWhisper.SINGLE_MODEL
            else:
                self.create_model(device)
        except Exception as e:
            logging.error(f"Failed to load model: {e}")
            self.websocket.send(json.dumps({
                "uid": self.client_uid,
                "status": "ERROR",
                "message": f"Failed to load model: {str(self.model_size_or_path)}"
            }))
            self.websocket.close()
            return

        # threading
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()
        self.websocket.send(
            json.dumps(
                {
                    "uid": self.client_uid,
                    "message": self.SERVER_READY,
                    "backend": "faster_whisper"
                }
            )
        )

    def create_model(self, device):
        """Instantiates a WhisperModel for ``model_size_or_path`` and sets it as the transcriber."""
        self.transcriber = WhisperModel(
            self.model_size_or_path,
            device=device,
            compute_type=self.compute_type,
            local_files_only=False,
        )

    def check_valid_model(self, model_size):
        """
        Check if it's a valid whisper model size; sends an ERROR to the client if not.

        Args:
            model_size (str): The name of the model size to check.

        Returns:
            str: The model size if valid, None otherwise.
        """
        if model_size not in self.model_sizes:
            self.websocket.send(
                json.dumps(
                    {
                        "uid": self.client_uid,
                        "status": "ERROR",
                        "message": f"Invalid model size {model_size}. Available choices: {self.model_sizes}"
                    }
                )
            )
            return None
        return model_size

    def set_language(self, info):
        """
        Adopts the detected language if its probability exceeds 0.5 and reports it to the client.

        Args:
            info (object): Detection result exposing ``language`` and ``language_probability``.
        """
        if info.language_probability > 0.5:
            self.language = info.language
            logging.info(f"Detected language {self.language} with probability {info.language_probability}")
            self.websocket.send(json.dumps(
                {"uid": self.client_uid, "language": self.language, "language_prob": info.language_probability}))

    def transcribe_audio(self, input_sample):
        """
        Transcribes an audio chunk and, if no language is set yet, tries to adopt the detected one.

        When a shared model is in use, inference is serialized through SINGLE_MODEL_LOCK. The
        lock is released in ``finally`` so a failing inference cannot block other clients.

        Args:
            input_sample (np.array): The audio chunk to be transcribed.

        Returns:
            The ``result`` part of ``self.transcriber.transcribe(...)``.
        """
        # Decide once, so we never release a lock we did not acquire.
        use_shared_lock = bool(ServeClientFasterWhisper.SINGLE_MODEL)
        if use_shared_lock:
            ServeClientFasterWhisper.SINGLE_MODEL_LOCK.acquire()
        try:
            result, info = self.transcriber.transcribe(
                input_sample,
                initial_prompt=self.initial_prompt,
                language=self.language,
                task=self.task,
                vad_filter=self.use_vad,
                vad_parameters=self.vad_parameters if self.use_vad else None)
        finally:
            if use_shared_lock:
                ServeClientFasterWhisper.SINGLE_MODEL_LOCK.release()

        if self.language is None and info is not None:
            self.set_language(info)
        return result

    def get_previous_output(self):
        """
        Returns recent segments to re-show during a pause in whisper output.

        The pause timer ``t_start`` starts on the first call. For the first
        ``show_prev_out_thresh`` seconds the recent segments are returned; after that an empty
        list is returned. Once more than ``add_pause_thresh`` seconds have passed, a blank entry
        is appended to ``self.text`` if the last entry is not already blank.

        Returns:
            segments (list): Recent transcript segments, or an empty list.
        """
        segments = []
        if self.t_start is None:
            self.t_start = time.time()
        if time.time() - self.t_start < self.show_prev_out_thresh:
            segments = self.prepare_segments()

        # add a blank if there is no speech for 3 seconds
        if len(self.text) and self.text[-1] != '':
            if time.time() - self.t_start > self.add_pause_thresh:
                self.text.append('')
        return segments

    def handle_transcription_output(self, result, duration):
        """
        Updates the transcript from ``result`` (or falls back to previous output when it is
        empty) and sends any resulting segments to the client.

        Args:
            result (list): The segments from whisper inference.
            duration (float): Duration of the transcribed audio chunk.
        """
        segments = []
        if len(result):
            self.t_start = None
            last_segment = self.update_segments(result, duration)
            segments = self.prepare_segments(last_segment)
        else:
            # show previous output if there is pause i.e. no output from whisper
            segments = self.get_previous_output()

        if len(segments):
            self.send_transcription_to_client(segments)

    def speech_to_text(self):
        """
        Transcription thread loop. It runs until ``exit`` is set.

        Each iteration:
        - clips stale audio;
        - takes the untranscribed chunk and waits until it is at least 1 s long;
        - transcribes it.

        If the result is None or no language is known yet, the chunk is skipped by advancing
        ``timestamp_offset``. Otherwise the output is handled and sent. Errors are logged and
        the loop continues.
        """
        while True:
            if self.exit:
                logging.info("Exiting speech to text thread")
                break

            if self.frames_np is None:
                time.sleep(0.02)    # wait for any audio to arrive instead of busy-spinning
                continue

            self.clip_audio_if_no_valid_segment()

            input_bytes, duration = self.get_audio_chunk_for_processing()
            if duration < 1.0:
                time.sleep(0.1)     # wait for audio chunks to arrive
                continue
            try:
                input_sample = input_bytes.copy()
                result = self.transcribe_audio(input_sample)

                if result is None or self.language is None:
                    # add_frames and clip_audio_if_no_valid_segment also write this under the lock
                    with self.lock:
                        self.timestamp_offset += duration
                    time.sleep(0.25)    # wait for voice activity, result is None when no voice activity
                    continue
                self.handle_transcription_output(result, duration)

            except Exception as e:
                logging.error(f"[ERROR]: Failed to transcribe audio chunk: {e}")
                time.sleep(0.01)

    def format_segment(self, start, end, text, completed=False):
        """
        Formats a transcription segment for the client.

        Args:
            start (float): The start time of the transcription segment in seconds.
            end (float): The end time of the transcription segment in seconds.
            text (str): The transcribed text corresponding to the segment.
            completed (bool): Whether the segment is final. Defaults to False.

        Returns:
            dict: 'start' and 'end' as strings with three decimal places, plus 'text' and 'completed'.
        """
        return {
            'start': "{:.3f}".format(start),
            'end': "{:.3f}".format(end),
            'text': text,
            'completed': completed
        }

    def update_segments(self, segments, duration):
        """
        Processes the segments from whisper.

        All segments except the last are treated as complete and appended to the transcript,
        skipping empty spans and those above ``no_speech_thresh``. The last segment is assumed
        to be incomplete and is formatted separately. Repeated identical output of the last
        segment is tracked so the repetition can be detected.

        Args:
            segments (list): segment objects as returned by whisper; each exposes ``start``,
                ``end``, ``text`` and ``no_speech_prob``.
            duration (float): duration of the current chunk

        Returns:
            dict or None: The last processed segment with its start time, end time, and
            transcribed text, or None if there is no valid last segment.
        """
        offset = None
        self.current_out = ''
        last_segment = None

        # process complete segments
        if len(segments) > 1 and segments[-1].no_speech_prob <= self.no_speech_thresh:
            for i, s in enumerate(segments[:-1]):
                text_ = s.text
                self.text.append(text_)
                with self.lock:
                    start, end = self.timestamp_offset + s.start, self.timestamp_offset + min(duration, s.end)

                if start >= end:
                    continue
                if s.no_speech_prob > self.no_speech_thresh:
                    continue

                self.transcript.append(self.format_segment(start, end, text_, completed=True))
                offset = min(duration, s.end)

        # only process the last segment if it satisfies the no_speech_thresh
        if segments[-1].no_speech_prob <= self.no_speech_thresh:
            self.current_out += segments[-1].text
            with self.lock:
                last_segment = self.format_segment(
                    self.timestamp_offset + segments[-1].start,
                    self.timestamp_offset + min(duration, segments[-1].end),
                    self.current_out,
                    completed=False
                )

        if self.current_out.strip() == self.prev_out.strip() and self.current_out != '':
            self.same_output_count += 1

            # if we remove the audio because of same output on the nth repetition we might remove the
            # audio that's not yet transcribed so, capturing the time when it was repeated for the first time
            if self.end_time_for_same_output is None:
                self.end_time_for_same_output = segments[-1].end
            time.sleep(0.1)     # wait for some voice activity just in case there is an unintended pause from the speaker for better punctuations.
else: self.same_output_count = 0 self.end_time_for_same_output = None # if same incomplete segment is seen multiple times then update the offset # and append the segment to the list if self.same_output_count > self.same_output_threshold: if not len(self.text) or self.text[-1].strip().lower() != self.current_out.strip().lower(): self.text.append(self.current_out) with self.lock: self.transcript.append(self.format_segment( self.timestamp_offset, self.timestamp_offset + min(duration, self.end_time_for_same_output), self.current_out, completed=True )) self.current_out = '' offset = min(duration, self.end_time_for_same_output) self.same_output_count = 0 last_segment = None self.end_time_for_same_output = None else: self.prev_out = self.current_out # update offset if offset is not None: with self.lock: self.timestamp_offset += offset return last_segment