# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from subprocess import CalledProcessError, run
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import kaldialign
import numpy as np
import soundfile
import torch
import torch.nn.functional as F

Pathlike = Union[str, Path]

SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk


def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
    """
    Open an audio file and read it as a mono waveform, resampling as necessary.

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio to, if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """

    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg", "-nostdin", "-threads", "0", "-i", file, "-f", "s16le", "-ac",
        "1", "-acodec", "pcm_s16le", "-ar",
        str(sr), "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    # Decode the raw 16-bit PCM bytes and scale to [-1.0, 1.0].
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
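
# Example usage (a sketch; "speech.flac" is a hypothetical path, and the
# ffmpeg CLI must be installed and on PATH):
#
#     samples = load_audio("speech.flac")   # mono float32 at 16 kHz
#     print(samples.shape, samples.dtype)   # (160000,) float32 for 10 seconds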


def load_audio_wav_format(wav_path: str) -> Tuple[np.ndarray, int]:
    # Make sure the audio is in .wav format.
    assert wav_path.endswith(
        '.wav'), f"Only .wav format is supported, but got {wav_path}"
    # soundfile returns a float64 waveform; multi-channel files come back
    # with shape (frames, channels).
    waveform, sample_rate = soundfile.read(wav_path)
    assert sample_rate == 16000, f"Only 16 kHz sample rate is supported, but got {sample_rate}"
    return waveform, sample_rate
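
# Example (a sketch; "speech_16k.wav" is a hypothetical 16 kHz mono file):
#
#     waveform, sr = load_audio_wav_format("speech_16k.wav")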


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to `length` samples along `axis`
    (N_SAMPLES by default, as expected by the encoder).
    """
    if torch.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(dim=axis,
                                       index=torch.arange(length,
                                                          device=array.device))

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            # F.pad expects the pad widths flattened in reverse axis order.
            array = F.pad(array,
                          [pad for sizes in pad_widths[::-1] for pad in sizes])
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array
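
# Example (a sketch): every clip comes out exactly N_SAMPLES long, so a
# 10-second clip is zero-padded and a 60-second clip is truncated:
#
#     short = np.zeros(10 * SAMPLE_RATE, dtype=np.float32)
#     long = np.zeros(60 * SAMPLE_RATE, dtype=np.float32)
#     assert pad_or_trim(short).shape == (N_SAMPLES,)
#     assert pad_or_trim(long).shape == (N_SAMPLES,)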


@lru_cache(maxsize=None)
def mel_filters(device,
                n_mels: int,
                mel_filters_dir: Optional[str] = None) -> torch.Tensor:
    """
    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling the librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
    if mel_filters_dir is None:
        mel_filters_path = os.path.join(os.path.dirname(__file__), "assets",
                                        "mel_filters.npz")
    else:
        mel_filters_path = os.path.join(mel_filters_dir, "mel_filters.npz")
    with np.load(mel_filters_path) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
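
# Since n_mels=128 is also accepted, the .npz asset is assumed to contain a
# "mel_128" entry as well. To regenerate it with both filterbanks (a sketch;
# librosa is required here but is otherwise not a dependency of this module):
#
#     import librosa
#     np.savez_compressed(
#         "mel_filters.npz",
#         mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
#         mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
#     )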


def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
    return_duration: bool = False,
    mel_filters_dir: Optional[str] = None,
):
    """
    Compute the log-Mel spectrogram of the given audio.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to an audio file, or a NumPy array or Tensor containing the
        audio waveform at 16 kHz

    n_mels: int
        The number of Mel-frequency filters; only 80 and 128 are supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    return_duration: bool
        If True, also return the duration of the original audio in seconds

    Returns
    -------
    torch.Tensor, shape = (80 or 128, n_frames)
        A Tensor that contains the log-Mel spectrogram (and the duration in
        seconds, if `return_duration` is True)
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            if audio.endswith('.wav'):
                audio, _ = load_audio_wav_format(audio)
            else:
                audio = load_audio(audio)
        assert isinstance(audio,
                          np.ndarray), f"Unsupported audio type: {type(audio)}"
        duration = audio.shape[-1] / SAMPLE_RATE
        audio = pad_or_trim(audio, N_SAMPLES)
        audio = audio.astype(np.float32)
        audio = torch.from_numpy(audio)
    else:
        # Tensor input is assumed to be an already-loaded 16 kHz waveform;
        # compute the duration here so `return_duration=True` also works for
        # tensor inputs.
        duration = audio.shape[-1] / SAMPLE_RATE

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio,
                      N_FFT,
                      HOP_LENGTH,
                      window=window,
                      return_complex=True)
    # Drop the last frame and take squared magnitudes.
    magnitudes = stft[..., :-1].abs()**2

    filters = mel_filters(audio.device, n_mels, mel_filters_dir)
    mel_spec = filters @ magnitudes

    # Log-compress, floor values at 8 log10 units below the peak, and rescale
    # roughly into [-1, 1].
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    if return_duration:
        return log_spec, duration
    else:
        return log_spec
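
# Typical feature-extraction call (a sketch; "speech.wav" is a hypothetical
# 16 kHz mono file):
#
#     mel, dur = log_mel_spectrogram("speech.wav", n_mels=80,
#                                    return_duration=True, device="cuda")
#     print(mel.shape)  # torch.Size([80, 3000]) for the 30-second window
#     print(dur)        # duration of the original audio, in seconds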


def store_transcripts(filename: Pathlike,
                      texts: Iterable[Tuple[str, str, str]]) -> None:
    """Save predicted results and reference transcripts to a file.
    https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
    Returns:
      Return None.
    """
    with open(filename, "w") as f:
        for cut_id, ref, hyp in texts:
            print(f"{cut_id}:\tref={ref}", file=f)
            print(f"{cut_id}:\thyp={hyp}", file=f)
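
# Example (a sketch with made-up utterance IDs):
#
#     store_transcripts("recogs-test.txt",
#                       [("utt1", "hello world", "hello word"),
#                        ("utt2", "good morning", "good morning")])
#     # The file then holds one ref line and one hyp line per utterance.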


def write_error_stats(  # noqa: C901
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, List[str], List[str]]],
    enable_log: bool = True,
) -> float:
    """Write statistics based on predicted results and reference transcripts.
    https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
    It will write the following to the given file:

    - WER
    - number of insertions, deletions, substitutions, correct words and total
      reference words. For example::

          Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
          reference words (2337 correct)

    - The difference between the reference transcript and predicted result.
      An instance is given below::

        THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

      The above example shows that the reference word is `EDISON`,
      but it is predicted as `ADDISON` (a substitution error).

      Another example is::

        FOR THE FIRST DAY (SIR->*) I THINK

      The reference word `SIR` is missing in the predicted
      results (a deletion error).
    Args:
      f:
        File to write the statistics to.
      test_set_name:
        Name of the test set; used in the log message.
      results:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.
    Returns:
      The total word error rate, in percent, as a float.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum(len(r) for _, r, _ in results)
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    if enable_log:
        logging.info(f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
                     f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
                     f"{del_errs} del, {sub_errs} sub ]")

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )
print("", file=f)
|
|
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
|
|
for cut_id, ref, hyp in results:
|
|
ali = kaldialign.align(ref, hyp, ERR)
|
|
combine_successive_errors = True
|
|
if combine_successive_errors:
|
|
ali = [[[x], [y]] for x, y in ali]
|
|
for i in range(len(ali) - 1):
|
|
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
|
|
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
|
|
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
|
|
ali[i] = [[], []]
|
|
ali = [[
|
|
list(filter(lambda a: a != ERR, x)),
|
|
list(filter(lambda a: a != ERR, y)),
|
|
] for x, y in ali]
|
|
ali = list(filter(lambda x: x != [[], []], ali))
|
|
ali = [[
|
|
ERR if x == [] else " ".join(x),
|
|
ERR if y == [] else " ".join(y),
|
|
] for x, y in ali]
|
|
|
|
print(
|
|
f"{cut_id}:\t" + " ".join((ref_word if ref_word == hyp_word else
|
|
f"({ref_word}->{hyp_word})"
|
|
for ref_word, hyp_word in ali)),
|
|
file=f,
|
|
)
|
|
|
|
print("", file=f)
|
|
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
|
|
|
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()],
|
|
reverse=True):
|
|
print(f"{count} {ref} -> {hyp}", file=f)
|
|
|
|
print("", file=f)
|
|
print("DELETIONS: count ref", file=f)
|
|
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
|
|
print(f"{count} {ref}", file=f)
|
|
|
|
print("", file=f)
|
|
print("INSERTIONS: count hyp", file=f)
|
|
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
|
|
print(f"{count} {hyp}", file=f)
|
|
|
|
print("", file=f)
|
|
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp",
|
|
file=f)
|
|
for _, word, counts in sorted([(sum(v[1:]), k, v)
|
|
for k, v in words.items()],
|
|
reverse=True):
|
|
(corr, ref_sub, hyp_sub, ins, dels) = counts
|
|
tot_errs = ref_sub + hyp_sub + ins + dels
|
|
ref_count = corr + ref_sub + dels
|
|
hyp_count = corr + hyp_sub + ins
|
|
|
|
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
|
return float(tot_err_rate)
|
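

# A minimal smoke test of the WER utilities (hypothetical data; a sketch,
# not part of the library API). Run this module directly to see the report
# format: one substitution over 5 reference words gives a WER of 20.0.
if __name__ == "__main__":
    import io

    demo_results = [
        ("utt1", "the cat sat".split(), "the cat sat".split()),
        ("utt2", "hello world".split(), "hello word".split()),
    ]
    buf = io.StringIO()
    wer = write_error_stats(buf, "demo", demo_results, enable_log=False)
    print(buf.getvalue())
    print(f"WER = {wer}%")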