@@ -78,6 +78,8 @@ def transcribe(
         if dtype == torch.float16:
             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
             dtype = torch.float32
+        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            warnings.warn("Performing inference on CPU when MPS is available")
 
     if dtype == torch.float32:
         decode_options["fp16"] = False
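For context, this hunk keeps the existing FP16-to-FP32 fallback and only adds the MPS warning, so a CPU run that requests fp16 still degrades gracefully. A minimal usage sketch, assuming the standard whisper API (the model size and audio path are placeholders):

```python
import whisper

# Loading on CPU with fp16 requested: transcribe() warns
# "FP16 is not supported on CPU; using FP32 instead" and,
# after this change, also warns when MPS is available but unused.
model = whisper.load_model("small", device="cpu")
result = model.transcribe("audio.mp3", fp16=True)
print(result["text"])
```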
@@ -135,7 +137,7 @@ def cli():
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
-    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "mps", help="device to use for PyTorch inference")
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
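For reference, a minimal sketch of how the MPS availability probe from the first hunk could compute this CLI default, so that machines with neither CUDA nor MPS still fall back to CPU (the helper name `default_device` is hypothetical and not part of this diff):

```python
import torch

def default_device() -> str:
    # Hypothetical helper: prefer CUDA, then Apple-silicon MPS, then CPU.
    if torch.cuda.is_available():
        return "cuda"
    # hasattr guards older PyTorch builds that lack torch.backends.mps
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```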