I’m trying to implement Graves (GMM) attention based on the Mozilla TTS repo. Here is a link with a brief discussion of the implementation by the repo maintainer. The code below is my implementation, adapted to fit Flowtron (https://github.com/NVIDIA/flowtron).
The authors of Flowtron made some changes to Tacotron: they feed all encoded mel frames at once, so the dimensions of the queries, keys, and values are
queries: T_mel * B * attn_hidden_dim
keys, values: T_text * B * text_embedding_dim
When I train it, it just doesn’t work well: only the first frame’s alignment gets close to the maximum value, and the attention scores of the other frames stay really low. Convergence is also slow compared to an additive attention. Can someone help me figure out which part of the code needs a fix (possibly the mu_t part)?
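For reference, the update rule I am trying to reproduce (my reading of Graves (2013) and of the sigmoid-CDF variant discussed in the Mozilla TTS thread, so the exact form is an assumption on my part) is, for decoder step t, text position j and mixture component k:

$$\mu_{t,k} = \mu_{t-1,k} + \mathrm{softplus}(\hat{\kappa}_{t,k}), \qquad \sigma_{t,k} = \mathrm{softplus}(\hat{\beta}_{t,k}), \qquad g_t = \mathrm{softmax}(\hat{g}_t)$$

$$\alpha_{t,j} = \sum_{k=1}^{K} g_{t,k}\left[F\!\left(\frac{j + 0.5 - \mu_{t,k}}{\sigma_{t,k}}\right) - F\!\left(\frac{j - 0.5 - \mu_{t,k}}{\sigma_{t,k}}\right)\right]$$

where F is a sigmoid-shaped CDF, so alpha_(t,j) is the mass the mixture assigns to the unit interval around text position j, and the cumsum over the decoder-time dimension in the code below is what is supposed to give mu_t = mu_(t-1) + softplus(k_t).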
import torch
from torch import nn
from flowtron import LinearNorm  # flowtron's linear layer wrapper


class GravesAttention(torch.nn.Module):
    def __init__(self, n_mel_channels=80, n_speaker_dim=128,
                 n_text_channels=512, n_att_channels=256, K=4):
        super(GravesAttention, self).__init__()
        # K is the number of Gaussian components in the mixture
        self.K = K
        self._mask_value = 1e-8
        self.eps = 1e-5
        self.J = None
        self.N_a = nn.Sequential(
            nn.Linear(n_mel_channels, n_mel_channels, bias=True),
            nn.ReLU(),
            nn.Linear(n_mel_channels, 3 * K, bias=True)
        )
        self.key = LinearNorm(n_text_channels + n_speaker_dim,
                              n_att_channels, bias=False, w_init_gain='tanh')
        self.value = LinearNorm(n_text_channels + n_speaker_dim,
                                n_att_channels, bias=False,
                                w_init_gain='tanh')
        self.init_layers()

    def init_layers(self):
        torch.nn.init.constant_(self.N_a[2].bias[(2 * self.K):(3 * self.K)], 1.)  # bias for the mean increments (k)
        torch.nn.init.constant_(self.N_a[2].bias[self.K:(2 * self.K)], 10)        # bias for the std (b)

    def init_states(self, inputs):
        # J is the half-integer grid of text positions; grow it lazily with the text length
        if self.J is None or inputs.shape[0] + 1 > self.J.shape[-1]:
            self.J = torch.arange(0, inputs.shape[0] + 2.0).to(inputs.device) + 0.5

    def forward(self, queries, keys, values, mask=None, attn=None):
        self.init_states(keys)  # initialize self.J from the text length
        if attn is None:
            keys = self.key(keys).transpose(0, 1)      # B x T_text x n_att_channels
            values = self.value(values) if hasattr(self, 'value') else values
            values = values.transpose(0, 1)            # B x T_text x n_att_channels

            gbk_t = self.N_a(queries).transpose(0, 1)  # B x T_mel x 3K
            gbk_t = gbk_t.view(gbk_t.size(0), gbk_t.size(1), -1, self.K)

            # each is B x T_mel x K
            g_t = gbk_t[:, :, 0, :]
            b_t = gbk_t[:, :, 1, :]
            k_t = gbk_t[:, :, 2, :]

            g_t = torch.nn.functional.dropout(g_t, p=0.5, training=self.training)
            sig_t = torch.nn.functional.softplus(b_t) + self.eps
            k_t = torch.nn.functional.softplus(k_t)
            mu_t = torch.cumsum(k_t, dim=1)            # mu_t = mu_(t-1) + k_t, mu_0 = 0
            g_t = torch.softmax(g_t, dim=-1) + self.eps

            j = self.J[:values.size(1) + 1]

            # mixture CDF evaluated on the half-integer grid: B x T_mel x K x (T_text + 1)
            phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.sigmoid((mu_t.unsqueeze(-1) - j) /
                                                                sig_t.unsqueeze(-1))))
            alpha_t = torch.sum(phi_t, 2)              # sum over the K mixture components
            # discrete difference of the CDF -> probability mass per text position
            alpha_t = alpha_t[:, :, 1:] - alpha_t[:, :, :-1]
            alpha_t[alpha_t == 0] = 1e-8
            if mask is not None:
                alpha_t.data.masked_fill_(mask.transpose(1, 2), self._mask_value)
        else:
            # use the precomputed attention map
            alpha_t = attn
            values = self.value(values)
            values = values.transpose(0, 1)

        print("with_dropout flows2 max, min in alpha_t {} {}".format(
            torch.max(alpha_t), torch.min(alpha_t)))   # debug print
        output = torch.bmm(alpha_t, values)
        output = output.transpose(1, 2)                # B x n_att_channels x T_mel
        return output, alpha_t
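In case it helps, this is how I sanity-check the module standalone with random tensors of the shapes described above. The sizes here are placeholders I picked for the check, not Flowtron's actual config; note that the queries' last dimension has to be n_mel_channels because of the N_a layer:

import torch

# placeholder sizes for a standalone shape check
T_mel, T_text, B = 100, 40, 2
n_mel_channels, n_text_channels, n_speaker_dim = 80, 512, 128

gmm_attn = GravesAttention(n_mel_channels=n_mel_channels,
                           n_speaker_dim=n_speaker_dim,
                           n_text_channels=n_text_channels,
                           n_att_channels=256, K=4)

queries = torch.randn(T_mel, B, n_mel_channels)                  # decoder-side mel features
keys = torch.randn(T_text, B, n_text_channels + n_speaker_dim)   # text encodings + speaker embedding
values = keys.clone()

output, alpha_t = gmm_attn(queries, keys, values)
print(output.shape)   # B x n_att_channels x T_mel, i.e. torch.Size([2, 256, 100])
print(alpha_t.shape)  # B x T_mel x T_text, i.e. torch.Size([2, 100, 40])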