
Commit ed79ef2

add training scripts
1 parent 1529cc3 commit ed79ef2

File tree

6 files changed (+393, -1 lines)

.gitignore

Lines changed: 3 additions & 1 deletion
@@ -1,3 +1,5 @@
 *.npy
 *.npz
-*.tfevents.*
+*.tfevents.*
+*.pt
+metadata*

configs/training_config.yaml

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+frames_per_sample: 40
+frames_per_slice: 8
+bits: 9
+conditioning_channels: 128
+embedding_dim: 256
+rnn_channels: 896
+fc_channels: 512
+batch_size: 32
+n_steps: 100000
+valid_every: 1000
+valid_ratio: 0.05
+save_every: 5000
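
How this config is consumed is not visible in the diff above. A minimal loading sketch, assuming PyYAML; the committed training entry point may differ:

import yaml

# Hypothetical loader; the actual training script is not shown in this diff.
with open("configs/training_config.yaml") as f:
    config = yaml.safe_load(f)

print(config["frames_per_sample"])  # 40
print(config["n_steps"])            # 100000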

data/vocoder_dataset.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+"""Vocoder dataset."""
+
+import json
+from random import randint
+from pathlib import Path
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+from .utils import mulaw_encode
+
+
+class VocoderDataset(Dataset):
+    """Sample a random segment of an utterance for training the vocoder."""
+
+    def __init__(
+        self, data_dir, metadata_path, frames_per_sample, frames_per_slice, bits
+    ):
+        with open(metadata_path, "r") as f:
+            metadata = json.load(f)
+
+        self.data_dir = Path(data_dir)
+        self.sample_rate = metadata["sample_rate"]
+        self.hop_len = metadata["hop_len"]
+        self.n_mels = metadata["n_mels"]
+        self.n_pad = (frames_per_sample - frames_per_slice) // 2
+        self.frames_per_sample = frames_per_sample
+        self.frames_per_slice = frames_per_slice
+        self.bits = bits
+        # Keep only utterances long enough to cut a full sample window from.
+        self.uttr_infos = [
+            uttr_info
+            for uttr_info in metadata["utterances"]
+            if uttr_info["mel_len"] > frames_per_sample
+        ]
+
+    def __len__(self):
+        return len(self.uttr_infos)
+
+    def __getitem__(self, index):
+        uttr_info = self.uttr_infos[index]
+        features = np.load(self.data_dir / uttr_info["feature_path"])
+        wav = features["wav"]
+        mel = features["mel"]
+
+        # Align the waveform to a whole number of mel frames, then pad both
+        # signals so slices near either edge have full conditioning context.
+        wav = np.pad(wav, (0, (len(mel) * self.hop_len - len(wav))), "constant")
+        mel = np.pad(mel, ((self.n_pad,), (0,)), "constant")
+        wav = np.pad(wav, (self.n_pad * self.hop_len,), "constant")
+        wav = mulaw_encode(wav, 2 ** self.bits)
+
+        pos = randint(0, len(mel) - self.frames_per_sample)
+        mel_seg = mel[pos : pos + self.frames_per_sample, :]
+
+        # The target waveform covers only the centre slice, plus one extra
+        # sample so shifted input/target pairs can be formed from it.
+        pos1 = pos + self.n_pad
+        pos2 = pos1 + self.frames_per_slice
+        wav_seg = wav[pos1 * self.hop_len : pos2 * self.hop_len + 1]
+
+        return torch.FloatTensor(mel_seg), torch.LongTensor(wav_seg)
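
A usage sketch for the dataset with a standard DataLoader. The feature directory and metadata path are assumptions (they come from a preprocessing step outside this diff); the other hyperparameters mirror configs/training_config.yaml:

from torch.utils.data import DataLoader

from data.vocoder_dataset import VocoderDataset

dataset = VocoderDataset(
    data_dir="features",                     # assumed feature directory
    metadata_path="features/metadata.json",  # assumed metadata location
    frames_per_sample=40,
    frames_per_slice=8,
    bits=9,
)
loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)

mels, wavs = next(iter(loader))
# mels: (32, 40, n_mels) float conditioning frames
# wavs: (32, 8 * hop_len + 1) mu-law class indices

Because __getitem__ draws a fresh random window each call, successive epochs see different segments of the same utterances.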

models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .vocoder import *
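
The wildcard re-export makes the model importable from the package root:

from models import Vocoder  # resolves to models.vocoder.Vocoder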

models/vocoder.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from tqdm import tqdm
+
+
+class Vocoder(nn.Module):
+    """Universal vocoding."""
+
+    def __init__(
+        self,
+        sample_rate,
+        mel_channels,
+        conditioning_channels,
+        embedding_dim,
+        rnn_channels,
+        fc_channels,
+        bits,
+        hop_length,
+    ):
+        super().__init__()
+
+        # Keep the constructor arguments so checkpoints can rebuild the model.
+        self.init_params = {
+            "sample_rate": sample_rate,
+            "mel_channels": mel_channels,
+            "conditioning_channels": conditioning_channels,
+            "embedding_dim": embedding_dim,
+            "rnn_channels": rnn_channels,
+            "fc_channels": fc_channels,
+            "bits": bits,
+            "hop_length": hop_length,
+        }
+
+        self.rnn_channels = rnn_channels
+        self.quantization_channels = 2 ** bits
+        self.hop_length = hop_length
+        self.sample_rate = sample_rate
+
+        self.rnn1 = nn.GRU(
+            mel_channels,
+            conditioning_channels,
+            num_layers=2,
+            batch_first=True,
+            bidirectional=True,
+        )
+        self.embedding = nn.Embedding(self.quantization_channels, embedding_dim)
+        self.rnn2 = nn.GRU(
+            embedding_dim + 2 * conditioning_channels, rnn_channels, batch_first=True
+        )
+        self.fc1 = nn.Linear(rnn_channels, fc_channels)
+        self.fc2 = nn.Linear(fc_channels, self.quantization_channels)
+
+    def forward(self, x, mels):
+        sample_frames = mels.size(1)
+        audio_slice_frames = x.size(1) // self.hop_length
+        pad = (sample_frames - audio_slice_frames) // 2
+
+        # Condition on the full mel window, then keep only the centre slice
+        # that corresponds to the target audio samples.
+        mels, _ = self.rnn1(mels)
+        mels = mels[:, pad : pad + audio_slice_frames, :]
+
+        # Upsample the frame-rate conditioning to sample rate.
+        mels = F.interpolate(mels.transpose(1, 2), scale_factor=float(self.hop_length))
+        mels = mels.transpose(1, 2)
+
+        x = self.embedding(x)
+
+        x, _ = self.rnn2(torch.cat((x, mels), dim=2))
+
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+    @classmethod
+    def load_checkpoint(cls, checkpoint_path):
+        ckpt = torch.load(checkpoint_path, map_location="cpu")
+        model = cls(**ckpt["init_params"])
+        model.load_state_dict(ckpt["model"])
+        return model
+
+    def save_checkpoint(self, checkpoint_path):
+        torch.save(
+            {"init_params": self.init_params, "model": self.state_dict()},
+            checkpoint_path,
+        )
+
+    def generate(self, mel):
+        """Generate a waveform from a mel spectrogram."""
+        output = []
+        cell = get_gru_cell(self.rnn2)
+
+        with torch.no_grad():
+            mel, _ = self.rnn1(mel)
+
+            mel = F.interpolate(
+                mel.transpose(1, 2), scale_factor=float(self.hop_length)
+            )
+            mel = mel.transpose(1, 2)
+
+            batch_size, _, _ = mel.size()
+
+            h = torch.zeros(batch_size, self.rnn_channels, device=mel.device)
+            # Start from the mu-law "zero" class. The item() call below
+            # collapses the batch, so generation assumes batch_size == 1.
+            x = (
+                torch.zeros(batch_size, device=mel.device)
+                .fill_(self.quantization_channels // 2)
+                .long()
+            )
+
+            # Autoregressive sampling: one output sample per upsampled frame.
+            for m in tqdm(torch.unbind(mel, dim=1), leave=False):
+                x = self.embedding(x)
+                h = cell(torch.cat((x, m), dim=1), h)
+
+                x = F.relu(self.fc1(h))
+                logits = self.fc2(x)
+
+                posterior = F.softmax(logits, dim=1)
+                dist = torch.distributions.Categorical(posterior)
+
+                x = dist.sample()
+                # Map the sampled class back to [-1, 1] before mu-law expansion.
+                output.append(
+                    2 * x.float().item() / (self.quantization_channels - 1.0) - 1.0
+                )
+
+        output = np.asarray(output, dtype=np.float64)
+        output = mulaw_decode(output, self.quantization_channels)
+
+        return output
+
+
+def get_gru_cell(gru):
+    # Copy the weights of the (single-layer) GRU into a GRUCell so inference
+    # can run one step at a time.
+    gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
+    gru_cell.weight_hh.data = gru.weight_hh_l0.data
+    gru_cell.weight_ih.data = gru.weight_ih_l0.data
+    gru_cell.bias_hh.data = gru.bias_hh_l0.data
+    gru_cell.bias_ih.data = gru.bias_ih_l0.data
+    return gru_cell
+
+
+def mulaw_decode(x_mu: np.ndarray, n_channels: int) -> np.ndarray:
+    """Decode a mu-law encoded signal."""
+    mu = n_channels - 1
+    x = np.sign(x_mu) / mu * ((1 + mu) ** np.abs(x_mu) - 1)
+    return x
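
Note that the dataset above imports mulaw_encode from data/utils.py, which is not among the files shown. For reference, a minimal sketch of the inverse of mulaw_decode, assuming it compands a [-1, 1] float signal and quantizes it to integer classes in [0, n_channels); the committed implementation may differ:

import numpy as np

def mulaw_encode(x: np.ndarray, n_channels: int) -> np.ndarray:
    # Hypothetical counterpart of mulaw_decode above; not the committed code.
    mu = n_channels - 1
    x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)  # compand
    return ((x_mu + 1) / 2 * mu + 0.5).astype(np.int64)          # quantize

Finally, an end-to-end sketch of how the pieces fit together. sample_rate, mel_channels, and hop_length come from preprocessing metadata and are assumed here, and the teacher-forcing pairing is an illustration rather than the commit's training loop (which is not shown):

import torch
import torch.nn.functional as F

from models import Vocoder

model = Vocoder(
    sample_rate=16000,          # assumed
    mel_channels=80,            # assumed
    conditioning_channels=128,
    embedding_dim=256,
    rnn_channels=896,
    fc_channels=512,
    bits=9,
    hop_length=200,             # assumed
)

mels = torch.randn(32, 40, 80)                   # (batch, frames_per_sample, n_mels)
wavs = torch.randint(0, 512, (32, 8 * 200 + 1))  # mu-law classes, 2 ** 9 = 512

# Teacher forcing: logits for sample t+1 conditioned on samples up to t.
logits = model(wavs[:, :-1], mels)
loss = F.cross_entropy(logits.transpose(1, 2), wavs[:, 1:])

# Autoregressive inference; note generate() assumes a batch size of 1.
wav = model.generate(torch.randn(1, 100, 80))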
