import argparse
import os
import warnings
from typing import List, Optional, Tuple, Union, TYPE_CHECKING

import numpy as np
import torch
import tqdm

from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt

if TYPE_CHECKING:
    from .model import Whisper


def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    force_extraction: bool = False,
    **decode_options,
):
    """
    Run the Whisper encoder over an audio file and collect encoder embeddings
    for each 30-second window. This variant performs no decoding; several of
    the decoding-related parameters below are accepted for compatibility with
    the upstream `transcribe` signature but are not used here.

    Parameters
    ----------
    model: Whisper
        The Whisper model instance

    audio: Union[str, np.ndarray, torch.Tensor]
        The path to the audio file to open, or the audio waveform

    verbose: bool
        Whether to display the text being decoded to the console. If True, displays all the details;
        if False, displays minimal details. If None, does not display anything

    temperature: Union[float, Tuple[float, ...]]
        Temperature for sampling. It can be a tuple of temperatures, which will be successively used
        upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.

    compression_ratio_threshold: float
        If the gzip compression ratio is above this value, treat the decoding as failed

    logprob_threshold: float
        If the average log probability over sampled tokens is below this value, treat the decoding as failed

    no_speech_threshold: float
        If the no_speech probability is higher than this value AND the average log probability
        over sampled tokens is below `logprob_threshold`, consider the segment as silent

    condition_on_previous_text: bool
        If True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    decode_options: dict
        Keyword arguments to construct `DecodingOptions` instances

    Returns
    -------
    A dictionary containing segment-level details ("segments"); each segment records the start
    and end mel-frame offsets of a window and its encoder embeddings ("encoder_embeddings").
    """
    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            warnings.warn("Performing inference on CPU when MPS is available")

    if dtype == torch.float32:
        decode_options["fp16"] = False

    mel = log_mel_spectrogram(audio)

    all_segments = []

    def add_segment(*, start: float, end: float, encoder_embeddings):
        all_segments.append(
            {
                "start": start,
                "end": end,
                "encoder_embeddings": encoder_embeddings,
            }
        )

    # show the progress bar when verbose is False
    num_frames = mel.shape[-1]
    seek = 0
    previous_seek_value = seek
    sample_skip = 3000  # number of mel frames per window (N_FRAMES, i.e. 30 seconds of audio)
    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
        while seek < num_frames:
            # seek is the frame index at which the current window starts
            end_seek = min(seek + sample_skip, num_frames)
            segment = pad_or_trim(mel[:, seek:seek + sample_skip], N_FRAMES).to(model.device).to(dtype)

            single = segment.ndim == 2
            if single:
                segment = segment.unsqueeze(0)
            if dtype == torch.float16:
                segment = segment.half()
            audio_features, embeddings = model.encoder(segment, include_embeddings=True)

            add_segment(
                start=seek,
                end=end_seek,
                encoder_embeddings=embeddings,
            )
            seek += sample_skip

            # update the progress bar
            pbar.update(min(num_frames, seek) - previous_seek_value)
            previous_seek_value = seek

    return dict(segments=all_segments)
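
# Example usage (a minimal sketch; assumes this fork's encoder, which accepts
# `include_embeddings=True`, and a hypothetical "audio.wav" file):
#
#   from . import load_model
#   model = load_model("small")
#   result = transcribe(model, "audio.wav")
#   for seg in result["segments"]:
#       # start/end are mel-frame offsets; at the default hop length there are
#       # 100 frames per second, so divide by 100 for seconds
#       print(seg["start"] / 100, seg["end"] / 100, seg["encoder_embeddings"].shape)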


def cli():
    from . import available_models

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")

    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")

    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")

    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")

    args = parser.parse_args().__dict__
    model_name: str = args.pop("model")
    model_dir: str = args.pop("model_dir")
    output_dir: str = args.pop("output_dir")
    device: str = args.pop("device")
    os.makedirs(output_dir, exist_ok=True)

    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
        if args["language"] is not None:
            warnings.warn(f"{model_name} is an English-only model but received '{args['language']}'; using English instead.")
        args["language"] = "en"

    temperature = args.pop("temperature")
    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
    if temperature_increment_on_fallback is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
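        # e.g. the default --temperature 0 with increment 0.2 yields the
        # fallback schedule (0.0, 0.2, 0.4, 0.6, 0.8, 1.0); the 1e-6 keeps
        # the endpoint 1.0 inside np.arange's half-open interval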
    else:
        temperature = [temperature]

    threads = args.pop("threads")
    if threads > 0:
        torch.set_num_threads(threads)

    from . import load_model
    model = load_model(model_name, device=device, download_root=model_dir)

    for audio_path in args.pop("audio"):
        result = transcribe(model, audio_path, temperature=temperature, **args)

        audio_basename = os.path.basename(audio_path)

        # NOTE: the write_* helpers below assume segments that carry a "text"
        # field; this fork's transcribe() returns embedding-only segments, so
        # these outputs may be empty or raise unless transcription is restored

        # save TXT
        with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
            write_txt(result["segments"], file=txt)

        # save VTT
        with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
            write_vtt(result["segments"], file=vtt)

        # save SRT
        with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
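
# Example invocation (a sketch; assumes the package is importable as `whisper`
# and an audio file named audio.wav):
#   python -m whisper.transcribe audio.wav --model small --output_dir ./out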


if __name__ == '__main__':
    cli()