OpenAI's Whisper in depth
Whisper is a speech-to-text model released by OpenAI in September 2022. The release came with a paper and an open-source implementation on OpenAI's GitHub.
The architecture behind Whisper isn't that complicated, but the inference workflow in the open-source implementation can be quite complex, because it includes:
- Language detection
- Optional generation of timestamps at the segment or word level
- Decoding fallbacks: re-decoding at higher sampling temperatures to avoid low-quality transcriptions caused by repetition loops, lack of speech, or low average log probability
- Logit filters
- KV caching
For this reason, the Whisper codebase can be challenging to understand. The purpose of this post is to strip away all of these bells and whistles and build a version of the implementation that is much easier to follow.
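To give a flavor of these bells and whistles, here is a minimal sketch of the temperature-fallback idea, simplified from my reading of `transcribe.py` in the official repo: decode greedily first, and retry at higher temperatures if the output looks repetitive (high zlib compression ratio) or low-confidence (low average log probability), unless the segment is probably silence. The thresholds (2.4, -1.0, 0.6) and the temperature schedule are the defaults as I understand them; `DecodeResult` and the `decode` callable are hypothetical stand-ins for the real decoding task.

```python
import zlib
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass
class DecodeResult:
    # Hypothetical stand-in for the output of one decoding pass
    text: str
    avg_logprob: float
    no_speech_prob: float


def compression_ratio(text: str) -> float:
    """Repetitive text compresses well, so a high ratio hints at a decoding loop."""
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))


def decode_with_fallback(
    decode: Callable[[float], DecodeResult],
    temperatures: Sequence[float] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: float = 2.4,
    logprob_threshold: float = -1.0,
    no_speech_threshold: float = 0.6,
) -> DecodeResult:
    result = None
    for t in temperatures:
        result = decode(t)
        needs_fallback = False
        if compression_ratio(result.text) > compression_ratio_threshold:
            needs_fallback = True  # too repetitive, likely stuck in a loop
        if result.avg_logprob < logprob_threshold:
            needs_fallback = True  # the model is not confident in its output
        if result.no_speech_prob > no_speech_threshold:
            needs_fallback = False  # probably silence, so retrying will not help
        if not needs_fallback:
            break
    return result


def fake_decode(temperature: float) -> DecodeResult:
    # Hypothetical decoder: greedy decoding (t=0) loops, sampling at t=0.2 does not
    if temperature == 0.0:
        return DecodeResult("ask not " * 50, avg_logprob=-0.2, no_speech_prob=0.01)
    return DecodeResult("ask not what your country can do for you", -0.3, 0.01)


print(decode_with_fallback(fake_decode).text)  # ask not what your country can do for you
```

The real implementation also tracks token-level log probabilities and timestamps; this sketch keeps only the decision logic.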
Some context on Whisper
Whisper represents a significant advancement in speech recognition research. To better understand its impact and unique features, let's review the main insights from the original paper:
- Wav2Vec2 needs to be fine-tuned.
- Whisper trains on a lot of data. Wav2Vec2 also trains on a lot of data, but it is a self-supervised, encoder-only model: a decoder has to be fine-tuned on labeled data before it can transcribe anything, and this fine-tuning step can hurt generalization to other datasets.
  - Self-supervised pre-training in the Wav2Vec2 line of work: up to ~1M hours of unlabeled audio
  - Whisper: 680k hours of labeled audio
- So lots of data is important, but having diverse sources is important too:
  - The paper points to earlier work that trained on diverse sources, but with relatively small amounts of data (O(1k-10k) hours).
- So, as I see it, the main takeaways of Whisper are:
  - Lots of data.
  - Diverse sources.
Audio loading and preprocessing
Whisper takes a 30-second audio clip as input. The clip needs to be converted from its waveform representation into a log mel spectrogram: we turn a 1D array of audio samples into a 2D array that represents the energy of the audio in different frequency bands over time.
There are some parameters related to how the mel spectrogram is computed, but they are beyond the scope of this post.
```python
from subprocess import run

import numpy as np
import torch
import torch.nn.functional as F

# [Audio hyperparams]

def load_audio(file: str, sr: int = SAMPLE_RATE):
    # [ffmpeg command setup]
    out = run(cmd, capture_output=True, check=True).stdout
    # [Convert to numpy array and normalize]

def log_mel_spectrogram(audio, n_mels, padding):
    if padding > 0:
        audio = F.pad(audio, (0, padding))  # append `padding` zero samples
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2  # power spectrum, dropping the last frame
    filters = mel_filters(n_mels)  # precomputed mel filterbank (helper not shown)
    mel_spec = filters @ magnitudes
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)  # keep the top 80 dB
    log_spec = (log_spec + 4.0) / 4.0  # shift and scale (constants from the official code)
    return log_spec

# Usage
audio = load_audio("jfk.flac")
audio = torch.from_numpy(audio)
mel = log_mel_spectrogram(audio, n_mels=80, padding=N_SAMPLES)
mel = mel[None, :, :N_FRAMES]  # add a batch dimension, keep exactly N_FRAMES frames (30 s)
```
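To make the shapes concrete, here is a NumPy-only sketch of the same pipeline without the mel filterbank (skipping the mel projection is my simplification). It assumes the constants from the open-source `whisper/audio.py` (16 kHz audio, `N_FFT = 400`, `HOP_LENGTH = 160`) and mimics the centered, reflect-padded framing that `torch.stft` uses by default. A 30-second clip gives 1 + 480000 // 160 = 3001 frames, so dropping the last frame leaves exactly 3000.

```python
import numpy as np

SAMPLE_RATE = 16000                   # Hz
N_FFT = 400                           # 25 ms analysis window
HOP_LENGTH = 160                      # 10 ms between frames
N_SAMPLES = 30 * SAMPLE_RATE          # 480000 samples in a 30 s chunk
N_FRAMES = N_SAMPLES // HOP_LENGTH    # 3000 frames


def power_spectrogram(audio: np.ndarray) -> np.ndarray:
    """Centered STFT power spectrum with shape (N_FFT // 2 + 1, n_frames)."""
    pad = N_FFT // 2
    padded = np.pad(audio, (pad, pad), mode="reflect")  # like torch.stft(center=True)
    n_frames = 1 + (len(padded) - N_FFT) // HOP_LENGTH
    window = np.hanning(N_FFT + 1)[:-1]  # periodic Hann, like torch.hann_window
    frames = np.stack(
        [padded[i * HOP_LENGTH : i * HOP_LENGTH + N_FFT] for i in range(n_frames)]
    )
    spectrum = np.fft.rfft(frames * window, axis=-1)
    return (np.abs(spectrum) ** 2).T


def log_compress(spec: np.ndarray) -> np.ndarray:
    """The same clamp / log10 / dynamic-range / rescale steps as log_mel_spectrogram."""
    log_spec = np.log10(np.maximum(spec, 1e-10))
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)  # keep the top 80 dB
    return (log_spec + 4.0) / 4.0


t = np.arange(N_SAMPLES) / SAMPLE_RATE
audio = 0.1 * np.sin(2 * np.pi * 440.0 * t)  # 30 s of a 440 Hz tone
power = power_spectrogram(audio)[:, :-1]     # drop the extra frame
features = log_compress(power)
print(power.shape)                        # (201, 3000)
print(power[:, N_FRAMES // 2].argmax())   # 11, since each bin spans 16000 / 400 = 40 Hz
```

Because of the 80 dB clamp, the compressed features always span at most 8 / 4 = 2 units.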
Whisper architecture
At a high level, the Whisper architecture comprises two parts: an audio encoder and a text decoder. The audio encoder runs once per clip to extract audio features, which the text decoder then uses. The text decoder is a transformer-based causal language model; in each block, it cross-attends to the audio features after the self-attention layer.
```python
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
```
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, n_state):
        super().__init__()
        self.n_state = n_state
        self.n_head = n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def forward(self, x, mask=None, xa=None):
        # Scaling q and k by d_head ** -0.25 each scales q @ k by d_head ** -0.5
        scale = (self.n_state // self.n_head) ** -0.25
        q = self.query(x)
        if xa is not None:
            # Cross-attention: keys and values come from the audio features
            k = self.key(xa)
            v = self.value(xa)
        else:
            # Self-attention: keys and values come from x itself
            k = self.key(x)
            v = self.value(x)
        # Split into heads: q -> (B, H, T, D), k -> (B, H, D, S), v -> (B, H, S, D)
        q = q.view(*q.shape[:2], self.n_head, -1).transpose(1, 2) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).transpose(1, 2)
        qk = q @ k  # (B, H, T, S) attention scores
        if mask is not None:
            qk = qk + mask[: x.shape[1], : x.shape[1]]
        w = qk.softmax(-1)
        out = (w @ v).transpose(1, 2).flatten(start_dim=2)  # merge heads -> (B, T, n_state)
        out = self.out(out)
        return out
```
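The `scale` trick in `MultiHeadAttention` is worth a quick sanity check: multiplying both `q` and `k` by d^-0.25 gives the same attention weights as the textbook d^-0.5 scaling applied to `q @ k`. This small NumPy check uses random vectors and made-up sizes (one head, head dimension 64, 5 queries, 7 keys); it only verifies the arithmetic and is not code from the official repo.

```python
import numpy as np

rng = np.random.default_rng(0)
d_head = 64
q = rng.standard_normal((5, d_head))  # 5 query positions
k = rng.standard_normal((7, d_head))  # 7 key positions (e.g. audio frames in cross-attention)


def softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


# Textbook scaled dot-product attention weights
w_textbook = softmax(q @ k.T / np.sqrt(d_head))

# Whisper-style: scale q and k separately by d_head ** -0.25
scale = d_head ** -0.25
w_split = softmax((q * scale) @ (k * scale).T)

print(np.allclose(w_textbook, w_split))  # True
print(w_split.shape)  # (5, 7): one distribution over keys per query
```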
```python
class ResidualAttention(nn.Module):
    def __init__(self, n_head, n_state, cross_attention=False):
        super().__init__()
        self.n_state = n_state
        self.n_head = n_head
        self.cross_attention = cross_attention
        self.attn = MultiHeadAttention(n_head, n_state)
        self.attn_ln = nn.LayerNorm(n_state)
        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_head, n_state)
            self.cross_attn_ln = nn.LayerNorm(n_state)
        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_state * 4), nn.GELU(), nn.Linear(n_state * 4, n_state)
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(self, x, mask=None, xa=None):
        # Pre-norm residual sub-layers: x + sublayer(LayerNorm(x))
        x = x + self.attn(self.attn_ln(x), mask=mask)
        if self.cross_attention:
            # Decoder blocks only: attend to the audio features xa
            x = x + self.cross_attn(self.cross_attn_ln(x), xa=xa)
        x = x + self.mlp(self.mlp_ln(x))
        return x
```
```python
class AudioEncoder(nn.Module):
    def __init__(self, n_mels, n_head, n_state, n_layer, n_ctx):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.blocks = nn.ModuleList(
            [ResidualAttention(n_head, n_state) for _ in range(n_layer)]
        )
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
        self.ln_post = nn.LayerNorm(n_state)

    def forward(self, x):
        # x: log mel spectrogram, (batch, n_mels, n_frames)
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))  # stride 2 halves the time axis
        x = x.transpose(1, 2)  # -> (batch, n_ctx, n_state)
        x = x + self.positional_embedding
        for block in self.blocks:
            x = block(x)
        x = self.ln_post(x)
        return x
```
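The stride of 2 in `conv2` is what maps the 3000 mel frames of a 30-second clip onto the encoder's context length. The standard output-length formula for a 1D convolution (dilation 1), floor((L + 2 * padding - kernel_size) / stride) + 1, makes this easy to check in plain Python; the resulting 1500 should match `n_audio_ctx` in the official checkpoints, which is the length `positional_embedding` is built for.

```python
def conv1d_out_len(length: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of a 1D convolution with dilation 1, as for torch.nn.Conv1d."""
    return (length + 2 * padding - kernel_size) // stride + 1


n_frames = 3000  # mel frames in a 30 s clip (10 ms hop)
after_conv1 = conv1d_out_len(n_frames, kernel_size=3, stride=1, padding=1)
after_conv2 = conv1d_out_len(after_conv1, kernel_size=3, stride=2, padding=1)
print(after_conv1, after_conv2)  # 3000 1500
print(30.0 / after_conv2)        # 0.02: each encoder position covers 20 ms of audio
```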
```python
class TextDecoder(nn.Module):
    def __init__(self, n_vocab, n_head, n_state, n_layer, n_ctx):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        # Learned positional embedding; values are loaded from the checkpoint
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
        self.blocks = nn.ModuleList(
            [
                ResidualAttention(n_head, n_state, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = nn.LayerNorm(n_state)
        # Causal mask: -inf above the diagonal blocks attention to future tokens
        mask = torch.empty(n_ctx, n_ctx).fill_(-torch.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x, xa):
        # x: token ids (batch, n_tokens); xa: audio features (batch, n_audio_ctx, n_state)
        x = self.token_embedding(x) + self.positional_embedding[: x.shape[-1]]
        for block in self.blocks:
            x = block(x, mask=self.mask, xa=xa)
        x = self.ln(x)
        # The output projection reuses (ties) the token embedding matrix
        logits = x @ self.token_embedding.weight.transpose(0, 1)
        return logits
```
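To see what the `-inf` upper triangle in `TextDecoder` does, here is a NumPy sketch with toy sizes and random scores (not from the official code) that adds the same kind of mask to a 4x4 score matrix. After the softmax, every weight above the diagonal is exactly zero, so position i can only attend to positions 0..i.

```python
import numpy as np

n_ctx = 4
# Same construction as the PyTorch mask: -inf strictly above the diagonal, 0 elsewhere
mask = np.triu(np.full((n_ctx, n_ctx), -np.inf), k=1)

scores = np.random.default_rng(0).standard_normal((n_ctx, n_ctx))
masked = scores + mask
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))  # exp(-inf) == 0
weights /= weights.sum(axis=-1, keepdims=True)

print(mask)
print(np.round(weights, 2))  # upper triangle is all zeros; each row sums to 1
```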
```python
class Whisper(nn.Module):
    def __init__(
        self,
        n_mels,
        n_vocab,
        n_audio_ctx,
        n_audio_state,
        n_audio_head,
        n_audio_layer,
        n_text_ctx,
        n_text_state,
        n_text_head,
        n_text_layer,
    ):
        super().__init__()
        self.encoder = AudioEncoder(
            n_mels, n_audio_head, n_audio_state, n_audio_layer, n_audio_ctx
        )
        self.decoder = TextDecoder(
            n_vocab, n_text_head, n_text_state, n_text_layer, n_text_ctx
        )

    def forward(self, x, xa):
        # x: token ids; xa: log mel spectrogram
        return self.decoder(x, self.encoder(xa))
```
Generation
The generation process involves several steps:
- Audio Feature Extraction
- Tokenization
- Logit Filters
- Autoregressive Generation
Audio features
We first compute the audio features, once, by running the encoder on the mel spectrogram:
```python
audio_features = model.encoder(mel)  # (1, n_audio_ctx, n_audio_state)
```
Tokenizer
Each generation starts with the <|startoftranscript|> token, followed by a language token that identifies the source language of the audio and a task token, which is either transcribe or translate:
```python
from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(
    multilingual=True,
    num_languages=99,
    language="en",
    task="transcribe",
)
initial_tokens = tokenizer.sot_sequence
print(tokenizer.decode(initial_tokens))
```
```
<|startoftranscript|><|en|><|transcribe|>
```
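Under the hood, this prompt is just a short list of special-token IDs. Here is a plain-Python sketch of how such a sequence can be assembled, following my reading of `whisper/tokenizer.py`: the language token sits at an offset right after <|startoftranscript|>, and the task token is appended last. The numeric IDs and the three-language list are hypothetical placeholders, not the real vocabulary.

```python
from typing import List, Optional

# Hypothetical special-token IDs and language list, for illustration only
SOT = 1000
LANGUAGES = ["en", "zh", "de"]  # language tokens follow SOT in this order
TRANSCRIBE, TRANSLATE = 2000, 2001


def build_sot_sequence(language: Optional[str], task: Optional[str]) -> List[int]:
    sequence = [SOT]
    if language is not None:
        # Language tokens are laid out right after <|startoftranscript|>
        sequence.append(SOT + 1 + LANGUAGES.index(language))
    if task is not None:
        sequence.append(TRANSCRIBE if task == "transcribe" else TRANSLATE)
    return sequence


print(build_sot_sequence("en", "transcribe"))  # [1000, 1001, 2000]
print(build_sot_sequence("de", "translate"))   # [1000, 1003, 2001]
```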
Logit filters
The official code implements some logit filters, which post-process the logits output by the text decoder before the next token is chosen. There are two main filters:
- One that suppresses a list of undesirable tokens that, in general, should never appear in any generation
- One that suppresses a blank space (and the end-of-text token) at the very start of generation
```python
def _get_suppress_tokens():
    suppress_tokens = "-1"  # "-1" means: use the default set of non-speech tokens
    if isinstance(suppress_tokens, str):
        suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
    if -1 in suppress_tokens:
        suppress_tokens = [t for t in suppress_tokens if t >= 0]
        suppress_tokens.extend(tokenizer.non_speech_tokens)
    elif suppress_tokens is None or len(suppress_tokens) == 0:
        suppress_tokens = []
    else:
        assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
    # Special tokens that should never be generated after the prompt
    suppress_tokens.extend(
        [
            tokenizer.transcribe,
            tokenizer.translate,
            tokenizer.sot,
            tokenizer.sot_prev,
            tokenizer.sot_lm,
        ]
    )
    if tokenizer.no_speech is not None:
        suppress_tokens.append(tokenizer.no_speech)
    return tuple(sorted(set(suppress_tokens)))


class SuppressBlank:
    """Suppresses a leading blank and end-of-text at the first generated position."""

    def __init__(self, tokenizer, sample_begin):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits, tokens):
        if tokens.shape[1] == self.sample_begin:
            logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -torch.inf


class SuppressTokens:
    """Suppresses a fixed set of tokens at every step."""

    def __init__(self, tokens):
        self.tokens = list(tokens)  # a list works as an index on the vocabulary axis

    def apply(self, logits, tokens):
        logits[:, self.tokens] = -torch.inf


sample_begin = len(initial_tokens)  # index of the first generated token (3 here)
blank_filter = SuppressBlank(tokenizer, sample_begin)
suppress_tokens_filter = SuppressTokens(_get_suppress_tokens())
logits_filters = [blank_filter, suppress_tokens_filter]
```
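Both filters simply set some logits to -inf before the argmax, so those tokens can never be picked. A NumPy sketch with a made-up six-token vocabulary (the IDs for the blank, end-of-text, and always-suppressed tokens are hypothetical) shows how the blank filter only kicks in at the first generated position:

```python
import numpy as np

# Hypothetical toy vocabulary: IDs for a blank, end-of-text, and a token to always suppress
BLANK, EOT, SUPPRESSED = 0, 1, 2
SAMPLE_BEGIN = 3  # prompt length, i.e. index of the first generated token


def apply_filters(logits: np.ndarray, n_tokens_so_far: int) -> np.ndarray:
    logits = logits.copy()
    if n_tokens_so_far == SAMPLE_BEGIN:
        logits[:, [BLANK, EOT]] = -np.inf  # no leading blank and no empty transcript
    logits[:, [SUPPRESSED]] = -np.inf  # suppressed at every step
    return logits


logits = np.array([[5.0, 4.0, 3.0, 2.0, 1.0, 0.0]])  # batch of one
print(apply_filters(logits, n_tokens_so_far=3).argmax(-1))  # [3]: tokens 0, 1 and 2 are masked
print(apply_filters(logits, n_tokens_so_far=4).argmax(-1))  # [0]: a blank is allowed later on
```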
Here are the tokens that we suppress:
```python
tokenizer.decode(_get_suppress_tokens())
```
```
'"#()*+/:;<=>@[\\]^_`{|}~ - " ( [ �>> >>-- \' ♪ -- * : / <「」� # ♫♪ ] + = -( ) ♪♪)) @ { ~ \\ > ; >>>♫ -[ (( ("『』 | ^--- 「 ♬♪♪ _ ))) `}} ♪♪♪ )) --- ♩♬ << } (\'<|startoftranscript|><|translate|><|transcribe|><|startoflm|><|startofprev|><|nospeech|>'
```
Inference loop
Here’s a simplified version of the inference loop:
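```python
tokens = torch.tensor(initial_tokens).view(1, -1)  # (batch=1, n_tokens)
sample_len = checkpoint["dims"]["n_text_ctx"] // 2  # maximum number of tokens to generate
with torch.no_grad():
    for idx in range(sample_len):
        # No KV cache: the decoder re-processes the whole prefix at every step
        logits = model.decoder(tokens, audio_features)
        if idx == 0:
            # Probability of <|nospeech|> at the <|startoftranscript|> position
            no_speech_prob = logits[:, 0].softmax(-1)[:, tokenizer.no_speech].tolist()
            print("No speech prob:", no_speech_prob)
        logits = logits[:, -1]  # keep only the logits for the next token
        for logit_filter in logits_filters:
            logit_filter.apply(logits, tokens)
        next_token = logits.argmax(-1)  # greedy decoding
        completed = (next_token == tokenizer.eot).item()
        tokens = torch.cat([tokens, next_token[None]], dim=1)
        if completed:
            break
print(tokenizer.decode(tokens[0].tolist()))
```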
Here is the output generated by the model:
```
'<|startoftranscript|><|en|><|transcribe|><|notimestamps|> And so my fellow Americans ask not what your country can do for you ask what you can do for your country.<|endoftext|>'
```
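The simplified loop re-runs `model.decoder` on the whole prefix at every step. That is easy to read but wasteful: with causal masking, the keys and values of earlier positions never change, which is what the KV caching in the official code exploits (as I understand it, the cross-attention keys and values over the audio features are cached too, since they stay fixed during decoding). Below is a minimal single-head NumPy sketch of the idea with random toy weights, not the official implementation: it caches keys and values as tokens arrive and checks that each new output matches full recomputation.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))  # toy projection weights


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def attend_full(xs):
    """Causal self-attention over the whole prefix; returns the last position's output."""
    q, k, v = xs @ W_q, xs @ W_k, xs @ W_v
    n = len(xs)
    mask = np.triu(np.full((n, n), -np.inf), k=1)
    w = softmax(q @ k.T / np.sqrt(d) + mask)
    return (w @ v)[-1]


class KVCache:
    """Stores keys/values of past positions so each step only projects the new token."""

    def __init__(self):
        self.keys, self.values = [], []

    def step(self, x):
        self.keys.append(x @ W_k)
        self.values.append(x @ W_v)
        k, v = np.stack(self.keys), np.stack(self.values)
        w = softmax((x @ W_q) @ k.T / np.sqrt(d))  # the newest token sees all cached positions
        return w @ v


xs = rng.standard_normal((5, d))  # embeddings of 5 tokens, fed one at a time
cache = KVCache()
for i in range(len(xs)):
    cached_out = cache.step(xs[i])
    assert np.allclose(cached_out, attend_full(xs[: i + 1]))
print("cached and full recomputation agree")
```

With the cache, each step projects only one new token instead of the whole prefix, at the cost of keeping past keys and values in memory.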
Questions & Ideas for Further Exploration
As we conclude this post, several intriguing ideas and questions about the Whisper model and its architecture have emerged:
- How does the model handle audio segments containing multiple languages?
- Is word alignment influenced by the specified language?
- Are there performance differences between translation and transcription to English?
- To what extent are the text standardizations truly innocuous?
- What are the possibilities for long-form transcriptions?
- How can Whisper be adapted for real-time transcriptions?
- IDEA: Implementing native speaker diarization with Whisper
These thoughts open up exciting avenues for future research and applications of the Whisper model.