
Commit 2c47453

turn the model into TorchScript

1 parent f610d33
6 files changed: 110 additions & 162 deletions

mel2wav.py

Lines changed: 3 additions & 5 deletions

@@ -8,8 +8,6 @@
 import torch
 import soundfile as sf

-from models import Vocoder
-

 def parse_args():
     """Parse command-line arguments."""
@@ -25,14 +23,14 @@ def main(ckpt_path, npy_path, output_path):

     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-    model = Vocoder.load_checkpoint(ckpt_path)
+    model = torch.jit.load(ckpt_path)
     model.to(device)
-    model.eval()

     mel = np.load(npy_path)
     mel = torch.FloatTensor(mel).to(device).transpose(0, 1).unsqueeze(0)

-    wav = model.generate(mel)
+    with torch.no_grad():
+        wav = model.generate(mel).squeeze().detach().cpu().numpy()

     npy_path_name = Path(npy_path).name
     wav_path = npy_path_name + ".wav" if output_path is None else output_path
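
Note: the point of this change is that a TorchScript archive bundles the module code together with its weights, so inference scripts no longer need the Vocoder class on their import path. A minimal sketch of the new loading flow (the checkpoint filename and the 80-bin mel shape are illustrative, not taken from this commit):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.jit.load restores code and weights together; no model class import needed.
model = torch.jit.load("vocoder-ckpt-100000.pt", map_location=device)

mel = torch.randn(1, 100, 80, device=device)  # (batch, mel_len, mel_dim)
with torch.no_grad():
    wav = model.generate(mel).squeeze().cpu().numpy()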

models/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-from .vocoder import *
+from .universal_vocoder import *

models/universal_vocoder.py

Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
+"""Universal vocoder"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class UniversalVocoder(nn.Module):
+    """Universal vocoding"""
+
+    def __init__(
+        self,
+        sample_rate,
+        frames_per_sample,
+        frames_per_slice,
+        mel_dim,
+        mel_rnn_dim,
+        emb_dim,
+        wav_rnn_dim,
+        affine_dim,
+        bits,
+        hop_length,
+    ):
+        super().__init__()
+
+        self.sample_rate = sample_rate
+        self.frames_per_slice = frames_per_slice
+        self.pad = (frames_per_sample - frames_per_slice) // 2
+        self.wav_rnn_dim = wav_rnn_dim
+        self.quant_dim = 2 ** bits
+        self.hop_len = hop_length
+
+        self.mel_rnn = nn.GRU(
+            mel_dim, mel_rnn_dim, num_layers=2, batch_first=True, bidirectional=True
+        )
+        self.embedding = nn.Embedding(self.quant_dim, emb_dim)
+        self.wav_rnn = nn.GRU(emb_dim + 2 * mel_rnn_dim, wav_rnn_dim, batch_first=True)
+        self.affine = nn.Sequential(
+            nn.Linear(wav_rnn_dim, affine_dim),
+            nn.ReLU(),
+            nn.Linear(affine_dim, self.quant_dim),
+        )
+
+    def forward(self, wavs, mels):
+        """Generate waveform from mel spectrogram with teacher-forcing."""
+        mel_embs, _ = self.mel_rnn(mels)
+        mel_embs = mel_embs.transpose(1, 2)
+        mel_embs = mel_embs[:, :, self.pad : self.pad + self.frames_per_slice]
+
+        conditions = F.interpolate(mel_embs, scale_factor=float(self.hop_len))
+        conditions = conditions.transpose(1, 2)
+
+        wav_embs = self.embedding(wavs)
+        wav_outs, _ = self.wav_rnn(torch.cat((wav_embs, conditions), dim=2))
+
+        return self.affine(wav_outs)
+
+    @torch.jit.export
+    def generate(self, mels):
+        """Generate waveform from mel spectrogram."""
+        mel_embs, _ = self.mel_rnn(mels)
+        mel_embs = mel_embs.transpose(1, 2)
+
+        conditions = F.interpolate(mel_embs, scale_factor=float(self.hop_len))
+        conditions = conditions.transpose(1, 2)
+
+        hid = torch.zeros(1, mels.size(0), self.wav_rnn_dim, device=mels.device)
+        wav = torch.full(
+            (mels.size(0),), self.quant_dim // 2, dtype=torch.long, device=mels.device,
+        )
+        wavs = torch.empty(
+            mels.size(0), mels.size(1) * self.hop_len, device=mels.device
+        )
+
+        for i, condition in enumerate(torch.unbind(conditions, dim=1)):
+            wav_emb = self.embedding(wav)
+            _, hid = self.wav_rnn(
+                torch.cat((wav_emb, condition), dim=1).unsqueeze(1), hid
+            )
+            logit = self.affine(hid.squeeze(0))
+            posterior = F.softmax(logit, dim=1)
+            wav = torch.multinomial(posterior, 1).squeeze(1)
+            wavs[:, i] = 2 * wav / (self.quant_dim - 1.0) - 1.0
+
+        mu = self.quant_dim - 1
+        wavs = torch.sign(wavs) / mu * ((1 + mu) ** torch.abs(wavs) - 1)
+
+        return wavs
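
Note on the new model file: generate upsamples the conditioning sequence by hop_length via F.interpolate, so it emits mel_len * hop_length samples per utterance, and the last two lines undo the mu-law companding, sign(y) / mu * ((1 + mu)^|y| - 1). A shape smoke test, using illustrative hyperparameters rather than this repo's training config:

import torch

from models import UniversalVocoder  # assumes the repo root is on sys.path

# All dimensions below are invented for the test.
model = UniversalVocoder(
    sample_rate=16000,
    frames_per_sample=40,
    frames_per_slice=8,
    mel_dim=80,
    mel_rnn_dim=128,
    emb_dim=256,
    wav_rnn_dim=512,
    affine_dim=512,
    bits=9,
    hop_length=200,
)
scripted = torch.jit.script(model)  # same call train.py now makes

mels = torch.randn(1, 10, 80)  # (batch, mel_len, mel_dim)
with torch.no_grad():
    wav = scripted.generate(mels)
print(wav.shape)  # torch.Size([1, 2000]) == (batch, mel_len * hop_length)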

models/vocoder.py

Lines changed: 0 additions & 142 deletions
This file was deleted.

reconstruct.py

Lines changed: 5 additions & 5 deletions

@@ -8,7 +8,6 @@
 from jsonargparse import ArgumentParser, ActionConfigFile

 from data import load_wav, log_mel_spectrogram
-from models import Vocoder


 def parse_args():
@@ -46,21 +45,22 @@ def main(

     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-    model = Vocoder.load_checkpoint(ckpt_path)
+    model = torch.jit.load(ckpt_path)
     model.to(device)
-    model.eval()

     wav = load_wav(audio_path, sample_rate)
     mel = log_mel_spectrogram(
         wav, preemph, sample_rate, n_mels, n_fft, hop_len, win_len, f_min
     ).T

     mel = torch.FloatTensor(mel).to(device).transpose(0, 1).unsqueeze(0)
-    wav = model.generate(mel)
+
+    with torch.no_grad():
+        wav = model.generate(mel).squeeze().detach().cpu().numpy()

     npy_path_name = Path(audio_path).name
     wav_path = npy_path_name + ".rec.wav" if output_path is None else output_path
-    sf.write(wav_path, wav, model.sample_rate)
+    sf.write(wav_path, wav, sample_rate)


 if __name__ == "__main__":
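
Note: the last hunk also changes where the output sample rate comes from; the script now passes the sample_rate it already received as an argument instead of reading model.sample_rate off the loaded module. For reference, sf.write takes the data array and the sample rate explicitly (the tone below is invented purely for illustration):

import numpy as np
import soundfile as sf

sr = 16000  # illustrative rate
t = np.arange(sr) / sr
# Write one second of a 440 Hz sine at the given sample rate.
sf.write("tone.wav", np.sin(2 * np.pi * 440 * t).astype(np.float32), sr)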

train.py

Lines changed: 14 additions & 9 deletions

@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-"""Train reconstruction model."""
+"""Train universal vocoder."""

 from datetime import datetime
 from pathlib import Path
@@ -13,7 +13,8 @@
 from jsonargparse import ArgumentParser, ActionConfigFile

 from data import VocoderDataset
-from models import Vocoder
+
+from models import UniversalVocoder


 def parse_args():
@@ -90,17 +91,20 @@ def main(
         pin_memory=True,
     )

-    model = Vocoder(
+    model = UniversalVocoder(
         sample_rate=dataset.sample_rate,
-        mel_channels=dataset.n_mels,
-        conditioning_channels=conditioning_channels,
-        embedding_dim=embedding_dim,
-        rnn_channels=rnn_channels,
-        fc_channels=fc_channels,
+        frames_per_sample=frames_per_sample,
+        frames_per_slice=frames_per_slice,
+        mel_dim=dataset.n_mels,
+        mel_rnn_dim=conditioning_channels,
+        emb_dim=embedding_dim,
+        wav_rnn_dim=rnn_channels,
+        affine_dim=fc_channels,
         bits=bits,
         hop_length=dataset.hop_len,
     )
     model.to(device)
+    model = torch.jit.script(model)

     optimizer = Adam(model.parameters())

@@ -169,7 +173,8 @@ def main(
         save_dir_path = Path(save_dir)
         save_dir_path.mkdir(parents=True, exist_ok=True)
         checkpoint_path = save_dir_path / f"vocoder-ckpt-{step+1}.pt"
-        model.save_checkpoint(checkpoint_path)
+        torch.jit.save(model.cpu(), str(checkpoint_path))
+        model.to(device)


 if __name__ == "__main__":
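
Note: the training script now scripts the model once after construction, then at each checkpoint moves it to CPU, serializes it with torch.jit.save, and moves it back to the device, presumably so checkpoints load cleanly on CPU-only machines. A self-contained sketch of the round trip, with a toy module standing in for UniversalVocoder:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x):
        return x * 2

model = torch.jit.script(Toy())        # compile the eagerly defined module
torch.jit.save(model.cpu(), "toy.pt")  # the archive bundles code and weights
restored = torch.jit.load("toy.pt")    # no Toy class needed in this process
print(restored(torch.ones(3)))         # tensor([2., 2., 2.])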
