import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm


class Vocoder(nn.Module):
    """Universal vocoder: an autoregressive WaveRNN-style model that
    predicts mu-law audio samples conditioned on mel spectrograms."""

    def __init__(
        self,
        sample_rate,
        mel_channels,
        conditioning_channels,
        embedding_dim,
        rnn_channels,
        fc_channels,
        bits,
        hop_length,
    ):
        super().__init__()

        self.init_params = {
            "sample_rate": sample_rate,
            "mel_channels": mel_channels,
            "conditioning_channels": conditioning_channels,
            "embedding_dim": embedding_dim,
            "rnn_channels": rnn_channels,
            "fc_channels": fc_channels,
            "bits": bits,
            "hop_length": hop_length,
        }

        self.rnn_channels = rnn_channels
        self.quantization_channels = 2 ** bits
        self.hop_length = hop_length
        self.sample_rate = sample_rate

        # Bidirectional conditioning network over mel frames.
        self.rnn1 = nn.GRU(
            mel_channels,
            conditioning_channels,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # Embedding of the previous (mu-law quantized) sample.
        self.embedding = nn.Embedding(self.quantization_channels, embedding_dim)
        # Autoregressive sample-rate network.
        self.rnn2 = nn.GRU(
            embedding_dim + 2 * conditioning_channels, rnn_channels, batch_first=True
        )
        self.fc1 = nn.Linear(rnn_channels, fc_channels)
        self.fc2 = nn.Linear(fc_channels, self.quantization_channels)

    def forward(self, x, mels):
        # x: (batch, samples) of mu-law class indices;
        # mels: (batch, frames, mel_channels).
        sample_frames = mels.size(1)
        audio_slice_frames = x.size(1) // self.hop_length
        pad = (sample_frames - audio_slice_frames) // 2

        # Condition on the full mel window, then keep the center slice
        # aligned with the audio samples.
        mels, _ = self.rnn1(mels)
        mels = mels[:, pad : pad + audio_slice_frames, :]

        # Upsample the conditioning to one vector per audio sample.
        mels = F.interpolate(mels.transpose(1, 2), scale_factor=float(self.hop_length))
        mels = mels.transpose(1, 2)

        x = self.embedding(x)

        x, _ = self.rnn2(torch.cat((x, mels), dim=2))

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    @classmethod
    def load_checkpoint(cls, checkpoint_path):
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        model = cls(**ckpt["init_params"])
        model.load_state_dict(ckpt["model"])
        return model

    def save_checkpoint(self, checkpoint_path):
        torch.save(
            {"init_params": self.init_params, "model": self.state_dict()},
            checkpoint_path,
        )

    def generate(self, mel):
        """Generate a waveform from a mel spectrogram, sampling one
        mu-law class per step and expanding the result to audio."""
        output = []
        cell = get_gru_cell(self.rnn2)

        with torch.no_grad():
            mel, _ = self.rnn1(mel)

            mel = F.interpolate(
                mel.transpose(1, 2), scale_factor=float(self.hop_length)
            )
            mel = mel.transpose(1, 2)

            batch_size, _, _ = mel.size()

            # Initial state: zero hidden state, mid-scale class (mu-law zero).
            h = torch.zeros(batch_size, self.rnn_channels, device=mel.device)
            x = (
                torch.zeros(batch_size, device=mel.device)
                .fill_(self.quantization_channels // 2)
                .long()
            )

            for m in tqdm(torch.unbind(mel, dim=1), leave=False):
                x = self.embedding(x)
                h = cell(torch.cat((x, m), dim=1), h)

                x = F.relu(self.fc1(h))
                logits = self.fc2(x)

                posterior = F.softmax(logits, dim=1)
                dist = torch.distributions.Categorical(posterior)

                x = dist.sample()
                # Note: .item() assumes batch_size == 1.
                output.append(
                    2 * x.float().item() / (self.quantization_channels - 1.0) - 1.0
                )

        output = np.asarray(output, dtype=np.float64)
        output = mulaw_decode(output, self.quantization_channels)

        return output


def get_gru_cell(gru):
    """Build a GRUCell sharing the (first-layer) weights of `gru`,
    for stepwise autoregressive inference."""
    gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
    gru_cell.weight_hh.data = gru.weight_hh_l0.data
    gru_cell.weight_ih.data = gru.weight_ih_l0.data
    gru_cell.bias_hh.data = gru.bias_hh_l0.data
    gru_cell.bias_ih.data = gru.bias_ih_l0.data
    return gru_cell

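# A quick sanity sketch (illustrative, not part of the model): the extracted
# cell should reproduce a single step of a 1-layer, unidirectional,
# batch_first GRU such as `rnn2` above:
#
#     gru = nn.GRU(8, 16, batch_first=True)
#     cell = get_gru_cell(gru)
#     x = torch.randn(4, 1, 8)                # (batch, time=1, features)
#     out, _ = gru(x)
#     h = cell(x[:, 0], torch.zeros(4, 16))   # one manual step
#     assert torch.allclose(out[:, 0], h, atol=1e-6)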

def mulaw_decode(x_mu: np.ndarray, n_channels: int) -> np.ndarray:
    """Expand a mu-law companded signal in [-1, 1] back to a linear waveform."""
    mu = n_channels - 1
    x = np.sign(x_mu) / mu * ((1 + mu) ** np.abs(x_mu) - 1)
    return x
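
def mulaw_encode(x: np.ndarray, n_channels: int) -> np.ndarray:
    """Compress a linear signal in [-1, 1] with mu-law companding.

    A minimal sketch of the inverse companding step, included for reference;
    it is an assumption about the preprocessing, not code from this file.
    Training targets would be obtained by companding and then quantizing
    to `n_channels` levels.
    """
    mu = n_channels - 1
    return np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)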
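
if __name__ == "__main__":
    # Minimal usage sketch. The hyperparameters below are illustrative
    # assumptions, not values prescribed by this file.
    vocoder = Vocoder(
        sample_rate=16000,
        mel_channels=80,
        conditioning_channels=128,
        embedding_dim=256,
        rnn_channels=896,
        fc_channels=1024,
        bits=10,
        hop_length=200,
    )

    # Training-style forward pass: logits over the next mu-law class for
    # each sample in the center slice of the mel window.
    mel = torch.randn(1, 40, 80)               # (batch=1, frames, mel_channels)
    x = torch.randint(0, 1024, (1, 4 * 200))   # (batch, samples), divisible by hop_length
    logits = vocoder(x, mel)
    print(logits.shape)                        # torch.Size([1, 800, 1024])

    # Autoregressive generation: one sample per upsampled conditioning frame.
    wav = vocoder.generate(mel)                # waveform with 40 * hop_length samples
    print(wav.shape)                           # (8000,)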