Merge pull request #167 from myshell-ai/openvoice-v2

add openvoice-v2
This commit is contained in:
Zengyi Qin 2024-04-17 14:39:47 -04:00 committed by GitHub
commit 782971e67a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 150 additions and 3 deletions

142
demo_part3.ipynb Normal file
View File

@ -0,0 +1,142 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multi-Accent and Multi-Lingual Voice Clone Demo with MeloTTS"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"from openvoice import se_extractor\n",
"from openvoice.api import ToneColorConverter"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialization\n",
"\n",
"In this example, we will use the checkpoints from OpenVoiceV2. OpenVoiceV2 is trained with more aggressive augmentations and thus demonstrate better robustness in some cases."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ckpt_converter = 'checkpoints_v2/converter'\n",
"device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
"output_dir = 'outputs_v2'\n",
"\n",
"tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)\n",
"tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')\n",
"\n",
"os.makedirs(output_dir, exist_ok=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Obtain Tone Color Embedding\n",
"We only extract the tone color embedding for the target speaker. The source tone color embeddings can be directly loaded from `checkpoints_v2/ses` folder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"reference_speaker = 'resources/example_reference.mp3'\n",
"target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, vad=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Use MeloTTS as Base Speakers\n",
"\n",
"MeloTTS is a high-quality multi-lingual text-to-speech library by @MyShell.ai, supporting languages including English (American, British, Indian, Australian, Default), Spanish, French, Chinese, Japanese, Korean. In the following example, we will use the models in MeloTTS as the base speakers. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from melo.api import TTS\n",
"\n",
"texts = {\n",
" 'EN': \"Did you ever hear a folk tale about a giant turtle?\",\n",
" 'ES': \"El resplandor del sol acaricia las olas, pintando el cielo con una paleta deslumbrante.\",\n",
" 'FR': \"La lueur dorée du soleil caresse les vagues, peignant le ciel d'une palette éblouissante.\",\n",
" 'ZH': \"在这次vacation中我们计划去Paris欣赏埃菲尔铁塔和卢浮宫的美景。\",\n",
" 'JP': \"彼は毎朝ジョギングをして体を健康に保っています。\",\n",
" 'KR': \"안녕하세요! 오늘은 날씨가 정말 좋네요.\",\n",
"}\n",
"\n",
"\n",
"src_path = f'{output_dir}/tmp.wav'\n",
"\n",
"# Speed is adjustable\n",
"speed = 1.0\n",
"\n",
"for language, text in texts.items():\n",
" model = TTS(language=language, device=device)\n",
" speaker_ids = model.hps.data.spk2id\n",
" \n",
" for speaker_key in speaker_ids.keys():\n",
" speaker_id = speaker_ids[speaker_key]\n",
" speaker_key = speaker_key.lower().replace('_', '-')\n",
" \n",
" source_se = torch.load(f'checkpoints_v2/base_speakers/ses/{speaker_key}.pth', map_location=device)\n",
" model.tts_to_file(text, speaker_id, src_path, speed=speed)\n",
" save_path = f'{output_dir}/output_melotts_{speaker_key}.wav'\n",
"\n",
" # Run the tone color converter\n",
" encode_message = \"@MyShell\"\n",
" tone_color_converter.convert(\n",
" audio_src_path=src_path, \n",
" src_se=source_se, \n",
" tgt_se=target_se, \n",
" output_path=save_path,\n",
" message=encode_message)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "melo",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -107,6 +107,7 @@ class ToneColorConverter(OpenVoiceBaseClass):
self.watermark_model = wavmark.load_model().to(self.device)
else:
self.watermark_model = None
self.version = getattr(self.hps, '_version_', "v1")

View File

@ -420,6 +420,7 @@ class SynthesizerTrn(nn.Module):
upsample_kernel_sizes,
n_speakers=256,
gin_channels=256,
zero_g=False,
**kwargs
):
super().__init__()
@ -461,6 +462,7 @@ class SynthesizerTrn(nn.Module):
self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
self.emb_g = nn.Embedding(n_speakers, gin_channels)
self.zero_g = zero_g
def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., sdp_ratio=0.2, max_len=None):
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
@ -490,8 +492,8 @@ class SynthesizerTrn(nn.Module):
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
g_src = sid_src
g_tgt = sid_tgt
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src, tau=tau)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=tau)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.dec(z_hat * y_mask, g=g_tgt)
o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt))
return o_hat, y_mask, (z, z_p, z_hat)

View File

@ -128,8 +128,10 @@ def hash_numpy_array(audio_path):
def get_se(audio_path, vc_model, target_dir='processed', vad=True):
device = vc_model.device
version = vc_model.version
print("OpenVoice version:", version)
audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{hash_numpy_array(audio_path)}"
audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{version}_{hash_numpy_array(audio_path)}"
se_path = os.path.join(target_dir, audio_name, 'se.pth')
if os.path.isfile(se_path):