From bc1d992f4fc209527c6dfb8b90a46f565df82d24 Mon Sep 17 00:00:00 2001 From: wl-zhao Date: Sat, 23 Dec 2023 20:38:52 +0800 Subject: [PATCH] add chinese model --- api.py | 11 +- demo_part1.ipynb | 141 +++++++++++++++++--- demo_part2.ipynb | 2 +- text/cleaners.py | 1 + text/mandarin.py | 326 +++++++++++++++++++++++++++++++++++++++++++++++ utils.py | 58 +++++++++ 6 files changed, 516 insertions(+), 23 deletions(-) create mode 100644 text/mandarin.py diff --git a/api.py b/api.py index 1c07c07..b769470 100644 --- a/api.py +++ b/api.py @@ -41,7 +41,8 @@ class OpenVoiceBaseClass(object): class BaseSpeakerTTS(OpenVoiceBaseClass): language_marks = { - "english": "[EN]", + "english": "EN", + "chinese": "ZH", } @staticmethod @@ -62,8 +63,8 @@ class BaseSpeakerTTS(OpenVoiceBaseClass): return audio_segments @staticmethod - def split_sentences_into_pieces(text): - texts = utils.split_sentences_latin(text) + def split_sentences_into_pieces(text, language_str): + texts = utils.split_sentence(text, language_str=language_str) print(" > Text splitted to sentences.") print('\n'.join(texts)) print(" > ===========================") @@ -73,12 +74,12 @@ class BaseSpeakerTTS(OpenVoiceBaseClass): mark = self.language_marks.get(language.lower(), None) assert mark is not None, f"language {language} is not supported" - texts = self.split_sentences_into_pieces(text) + texts = self.split_sentences_into_pieces(text, mark) audio_list = [] for t in texts: t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t) - t = mark + t + mark + t = f'[{mark}]{t}[{mark}]' stn_tst = self.get_text(t, self.hps, False) device = self.device speaker_id = self.hps.speakers[speaker] diff --git a/demo_part1.ipynb b/demo_part1.ipynb index 2b84ac6..58c2e32 100644 --- a/demo_part1.ipynb +++ b/demo_part1.ipynb @@ -10,10 +10,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "b7f043ee", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/zwl/anaconda3/envs/openvoice/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import os\n", "import torch\n", @@ -31,12 +40,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "aacad912", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded checkpoint 'checkpoints/base_speakers/EN/checkpoint.pth'\n", + "missing/unexpected keys: [] []\n", + "Loaded checkpoint 'checkpoints/converter/checkpoint.pth'\n", + "missing/unexpected keys: [] []\n" + ] + } + ], "source": [ - "ckpt_base = 'checkpoints/base_speaker'\n", + "ckpt_base = 'checkpoints/base_speakers/EN'\n", "ckpt_converter = 'checkpoints/converter'\n", "device = 'cuda:0'\n", "output_dir = 'outputs'\n", @@ -64,19 +84,18 @@ "metadata": {}, "source": [ "The `source_se` is the tone color embedding of the base speaker. \n", - "It is an average for multiple sentences with multiple emotions\n", - "of the base speaker. We directly provide the result here but\n", + "It is an average of multiple sentences generated by the base speaker. We directly provide the result here but\n", "the readers feel free to extract `source_se` by themselves." 
] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "63ff6273", "metadata": {}, "outputs": [], "source": [ - "source_se = torch.load(f'{ckpt_base}/source_se.pth').to(device)" + "source_se = torch.load(f'{ckpt_base}/en_default_se.pth').to(device)" ] }, { @@ -89,7 +108,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "55105eae", "metadata": {}, "outputs": [], @@ -108,17 +127,38 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "73dc1259", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " > Text splitted to sentences.\n", + "This audio is generated by open voice.\n", + " > ===========================\n", + "ðɪs ˈɑdiˌoʊ ɪz ˈdʒɛnəɹˌeɪtɪd baɪ ˈoʊpən vɔɪs.\n", + " length:45\n", + " length:45\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/zwl/anaconda3/envs/openvoice/lib/python3.9/site-packages/wavmark/models/my_model.py:25: UserWarning: istft will require a complex-valued input tensor in a future PyTorch release. Matching the output from stft with return_complex=True. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/native/SpectralOps.cpp:978.)\n", + " return torch.istft(signal_wmd_fft, n_fft=self.n_fft, hop_length=self.hop_length, window=window,\n" + ] + } + ], "source": [ - "save_path = f'{output_dir}/output_friendly.wav'\n", + "save_path = f'{output_dir}/output_en_default.wav'\n", "\n", "# Run the base speaker tts\n", "text = \"This audio is generated by open voice.\"\n", "src_path = f'{output_dir}/tmp.wav'\n", - "base_speaker_tts.tts(text, src_path, speaker='friendly', language='English', speed=1.0)\n", + "base_speaker_tts.tts(text, src_path, speaker='default', language='English', speed=1.0)\n", "\n", "# Run the tone color converter\n", "encode_message = \"@MyShell\"\n", @@ -135,16 +175,30 @@ "id": "6e3ea28a", "metadata": {}, "source": [ - "**Try with different styles and speed.** The style can be controlled by the `speaker` parameter in the `base_speaker_tts.tts` method. Available choices: friendly, cheerful, excited, sad, angry, terrified, shouting, whispering. The speed can be controlled by the `speed` parameter. Let's try whispering with speed 0.9." + "**Try with different styles and speed.** The style can be controlled by the `speaker` parameter in the `base_speaker_tts.tts` method. Available choices: friendly, cheerful, excited, sad, angry, terrified, shouting, whispering. Note that the tone color embedding need to be updated. The speed can be controlled by the `speed` parameter. Let's try whispering with speed 0.9." 
] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "fd022d38", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " > Text splitted to sentences.\n", + "This audio is generated by open voice with a half-performance model.\n", + " > ===========================\n", + "ðɪs ˈɑdiˌoʊ ɪz ˈdʒɛnəɹˌeɪtɪd baɪ ˈoʊpən vɔɪs wɪθ ə half-peɹfoɹmance* ˈmɑdəɫ.\n", + " length:76\n", + " length:75\n" + ] + } + ], "source": [ + "source_se = torch.load(f'{ckpt_base}/en_style_se.pth').to(device)\n", "save_path = f'{output_dir}/output_whispering.wav'\n", "\n", "# Run the base speaker tts\n", @@ -162,6 +216,59 @@ " message=encode_message)" ] }, + { + "cell_type": "markdown", + "id": "5fcfc70b", + "metadata": {}, + "source": [ + "**Try with different languages.** OpenVoice can achieve multi-lingual voice cloning by simply replace the base speaker. We provide an example with a Chinese base speaker here and we encourage the readers to try `demo_part2.ipynb` for a detaied demo." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "a71d1387", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded checkpoint 'checkpoints/base_speakers/ZH/checkpoint.pth'\n", + "missing/unexpected keys: [] []\n", + " > Text splitted to sentences.\n", + "今天天气真好, 我们一起出去吃饭吧.\n", + " > ===========================\n", + "tʃ⁼in→tʰjɛn→tʰjɛn→tʃʰi↓ ts`⁼ən→ xɑʊ↓↑, wo↓↑mən i↓tʃʰi↓↑ ts`ʰu→tʃʰɥ↓ ts`ʰɹ`→fan↓ p⁼a.\n", + " length:85\n", + " length:85\n" + ] + } + ], + "source": [ + "\n", + "ckpt_base = 'checkpoints/base_speakers/ZH'\n", + "base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base}/config.json', device=device)\n", + "base_speaker_tts.load_ckpt(f'{ckpt_base}/checkpoint.pth')\n", + "\n", + "source_se = torch.load(f'{ckpt_base}/zh_default_se.pth').to(device)\n", + "save_path = f'{output_dir}/output_chinese.wav'\n", + "\n", + "# Run the base speaker tts\n", + "text = \"今天天气真好,我们一起出去吃饭吧。\"\n", + "src_path = f'{output_dir}/tmp.wav'\n", + "base_speaker_tts.tts(text, src_path, speaker='default', language='Chinese', speed=1.0)\n", + "\n", + "# Run the tone color converter\n", + "encode_message = \"@MyShell\"\n", + "tone_color_converter.convert(\n", + " audio_src_path=src_path, \n", + " src_se=source_se, \n", + " tgt_se=target_se, \n", + " output_path=save_path,\n", + " message=encode_message)" + ] + }, { "cell_type": "markdown", "id": "8e513094", diff --git a/demo_part2.ipynb b/demo_part2.ipynb index 2a628b4..fec3f2f 100644 --- a/demo_part2.ipynb +++ b/demo_part2.ipynb @@ -51,7 +51,7 @@ "id": "3db80fcf", "metadata": {}, "source": [ - "In this demo, we will use OpenAI TTS as the base speaker to produce multi-lingual speech audio. The users can flexibly change the base speaker according to their own needs. Please create a file named `.env` and place OpenAI key as `OPENAI_API_KEY=xxx`." + "In this demo, we will use OpenAI TTS as the base speaker to produce multi-lingual speech audio. The users can flexibly change the base speaker according to their own needs. Please create a file named `.env` and place OpenAI key as `OPENAI_API_KEY=xxx`. We have also provided a Chinese base speaker model (see `demo_part1.ipynb`)." 
] }, { diff --git a/text/cleaners.py b/text/cleaners.py index d219c41..619ad47 100644 --- a/text/cleaners.py +++ b/text/cleaners.py @@ -1,5 +1,6 @@ import re from text.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2 +from text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2 def cjke_cleaners2(text): text = re.sub(r'\[ZH\](.*?)\[ZH\]', diff --git a/text/mandarin.py b/text/mandarin.py new file mode 100644 index 0000000..162e1b9 --- /dev/null +++ b/text/mandarin.py @@ -0,0 +1,326 @@ +import os +import sys +import re +from pypinyin import lazy_pinyin, BOPOMOFO +import jieba +import cn2an +import logging + + +# List of (Latin alphabet, bopomofo) pairs: +_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ + ('a', 'ㄟˉ'), + ('b', 'ㄅㄧˋ'), + ('c', 'ㄙㄧˉ'), + ('d', 'ㄉㄧˋ'), + ('e', 'ㄧˋ'), + ('f', 'ㄝˊㄈㄨˋ'), + ('g', 'ㄐㄧˋ'), + ('h', 'ㄝˇㄑㄩˋ'), + ('i', 'ㄞˋ'), + ('j', 'ㄐㄟˋ'), + ('k', 'ㄎㄟˋ'), + ('l', 'ㄝˊㄛˋ'), + ('m', 'ㄝˊㄇㄨˋ'), + ('n', 'ㄣˉ'), + ('o', 'ㄡˉ'), + ('p', 'ㄆㄧˉ'), + ('q', 'ㄎㄧㄡˉ'), + ('r', 'ㄚˋ'), + ('s', 'ㄝˊㄙˋ'), + ('t', 'ㄊㄧˋ'), + ('u', 'ㄧㄡˉ'), + ('v', 'ㄨㄧˉ'), + ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), + ('x', 'ㄝˉㄎㄨˋㄙˋ'), + ('y', 'ㄨㄞˋ'), + ('z', 'ㄗㄟˋ') +]] + +# List of (bopomofo, romaji) pairs: +_bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('ㄅㄛ', 'p⁼wo'), + ('ㄆㄛ', 'pʰwo'), + ('ㄇㄛ', 'mwo'), + ('ㄈㄛ', 'fwo'), + ('ㄅ', 'p⁼'), + ('ㄆ', 'pʰ'), + ('ㄇ', 'm'), + ('ㄈ', 'f'), + ('ㄉ', 't⁼'), + ('ㄊ', 'tʰ'), + ('ㄋ', 'n'), + ('ㄌ', 'l'), + ('ㄍ', 'k⁼'), + ('ㄎ', 'kʰ'), + ('ㄏ', 'h'), + ('ㄐ', 'ʧ⁼'), + ('ㄑ', 'ʧʰ'), + ('ㄒ', 'ʃ'), + ('ㄓ', 'ʦ`⁼'), + ('ㄔ', 'ʦ`ʰ'), + ('ㄕ', 's`'), + ('ㄖ', 'ɹ`'), + ('ㄗ', 'ʦ⁼'), + ('ㄘ', 'ʦʰ'), + ('ㄙ', 's'), + ('ㄚ', 'a'), + ('ㄛ', 'o'), + ('ㄜ', 'ə'), + ('ㄝ', 'e'), + ('ㄞ', 'ai'), + ('ㄟ', 'ei'), + ('ㄠ', 'au'), + ('ㄡ', 'ou'), + ('ㄧㄢ', 'yeNN'), + ('ㄢ', 'aNN'), + ('ㄧㄣ', 'iNN'), + ('ㄣ', 'əNN'), + ('ㄤ', 'aNg'), + ('ㄧㄥ', 'iNg'), + ('ㄨㄥ', 'uNg'), + ('ㄩㄥ', 'yuNg'), + ('ㄥ', 'əNg'), + ('ㄦ', 'əɻ'), + ('ㄧ', 'i'), + ('ㄨ', 'u'), + ('ㄩ', 'ɥ'), + ('ˉ', '→'), + ('ˊ', '↑'), + ('ˇ', '↓↑'), + ('ˋ', '↓'), + ('˙', ''), + (',', ','), + ('。', '.'), + ('!', '!'), + ('?', '?'), + ('—', '-') +]] + +# List of (romaji, ipa) pairs: +_romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ + ('ʃy', 'ʃ'), + ('ʧʰy', 'ʧʰ'), + ('ʧ⁼y', 'ʧ⁼'), + ('NN', 'n'), + ('Ng', 'ŋ'), + ('y', 'j'), + ('h', 'x') +]] + +# List of (bopomofo, ipa) pairs: +_bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('ㄅㄛ', 'p⁼wo'), + ('ㄆㄛ', 'pʰwo'), + ('ㄇㄛ', 'mwo'), + ('ㄈㄛ', 'fwo'), + ('ㄅ', 'p⁼'), + ('ㄆ', 'pʰ'), + ('ㄇ', 'm'), + ('ㄈ', 'f'), + ('ㄉ', 't⁼'), + ('ㄊ', 'tʰ'), + ('ㄋ', 'n'), + ('ㄌ', 'l'), + ('ㄍ', 'k⁼'), + ('ㄎ', 'kʰ'), + ('ㄏ', 'x'), + ('ㄐ', 'tʃ⁼'), + ('ㄑ', 'tʃʰ'), + ('ㄒ', 'ʃ'), + ('ㄓ', 'ts`⁼'), + ('ㄔ', 'ts`ʰ'), + ('ㄕ', 's`'), + ('ㄖ', 'ɹ`'), + ('ㄗ', 'ts⁼'), + ('ㄘ', 'tsʰ'), + ('ㄙ', 's'), + ('ㄚ', 'a'), + ('ㄛ', 'o'), + ('ㄜ', 'ə'), + ('ㄝ', 'ɛ'), + ('ㄞ', 'aɪ'), + ('ㄟ', 'eɪ'), + ('ㄠ', 'ɑʊ'), + ('ㄡ', 'oʊ'), + ('ㄧㄢ', 'jɛn'), + ('ㄩㄢ', 'ɥæn'), + ('ㄢ', 'an'), + ('ㄧㄣ', 'in'), + ('ㄩㄣ', 'ɥn'), + ('ㄣ', 'ən'), + ('ㄤ', 'ɑŋ'), + ('ㄧㄥ', 'iŋ'), + ('ㄨㄥ', 'ʊŋ'), + ('ㄩㄥ', 'jʊŋ'), + ('ㄥ', 'əŋ'), + ('ㄦ', 'əɻ'), + ('ㄧ', 'i'), + ('ㄨ', 'u'), + ('ㄩ', 'ɥ'), + ('ˉ', '→'), + ('ˊ', '↑'), + ('ˇ', '↓↑'), + ('ˋ', '↓'), + ('˙', ''), + (',', ','), + ('。', '.'), + ('!', '!'), + ('?', '?'), + ('—', '-') +]] + +# List of (bopomofo, ipa2) pairs: +_bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ + 
('ㄅㄛ', 'pwo'), + ('ㄆㄛ', 'pʰwo'), + ('ㄇㄛ', 'mwo'), + ('ㄈㄛ', 'fwo'), + ('ㄅ', 'p'), + ('ㄆ', 'pʰ'), + ('ㄇ', 'm'), + ('ㄈ', 'f'), + ('ㄉ', 't'), + ('ㄊ', 'tʰ'), + ('ㄋ', 'n'), + ('ㄌ', 'l'), + ('ㄍ', 'k'), + ('ㄎ', 'kʰ'), + ('ㄏ', 'h'), + ('ㄐ', 'tɕ'), + ('ㄑ', 'tɕʰ'), + ('ㄒ', 'ɕ'), + ('ㄓ', 'tʂ'), + ('ㄔ', 'tʂʰ'), + ('ㄕ', 'ʂ'), + ('ㄖ', 'ɻ'), + ('ㄗ', 'ts'), + ('ㄘ', 'tsʰ'), + ('ㄙ', 's'), + ('ㄚ', 'a'), + ('ㄛ', 'o'), + ('ㄜ', 'ɤ'), + ('ㄝ', 'ɛ'), + ('ㄞ', 'aɪ'), + ('ㄟ', 'eɪ'), + ('ㄠ', 'ɑʊ'), + ('ㄡ', 'oʊ'), + ('ㄧㄢ', 'jɛn'), + ('ㄩㄢ', 'yæn'), + ('ㄢ', 'an'), + ('ㄧㄣ', 'in'), + ('ㄩㄣ', 'yn'), + ('ㄣ', 'ən'), + ('ㄤ', 'ɑŋ'), + ('ㄧㄥ', 'iŋ'), + ('ㄨㄥ', 'ʊŋ'), + ('ㄩㄥ', 'jʊŋ'), + ('ㄥ', 'ɤŋ'), + ('ㄦ', 'əɻ'), + ('ㄧ', 'i'), + ('ㄨ', 'u'), + ('ㄩ', 'y'), + ('ˉ', '˥'), + ('ˊ', '˧˥'), + ('ˇ', '˨˩˦'), + ('ˋ', '˥˩'), + ('˙', ''), + (',', ','), + ('。', '.'), + ('!', '!'), + ('?', '?'), + ('—', '-') +]] + + +def number_to_chinese(text): + numbers = re.findall(r'\d+(?:\.?\d+)?', text) + for number in numbers: + text = text.replace(number, cn2an.an2cn(number), 1) + return text + + +def chinese_to_bopomofo(text): + text = text.replace('、', ',').replace(';', ',').replace(':', ',') + words = jieba.lcut(text, cut_all=False) + text = '' + for word in words: + bopomofos = lazy_pinyin(word, BOPOMOFO) + if not re.search('[\u4e00-\u9fff]', word): + text += word + continue + for i in range(len(bopomofos)): + bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i]) + if text != '': + text += ' ' + text += ''.join(bopomofos) + return text + + +def latin_to_bopomofo(text): + for regex, replacement in _latin_to_bopomofo: + text = re.sub(regex, replacement, text) + return text + + +def bopomofo_to_romaji(text): + for regex, replacement in _bopomofo_to_romaji: + text = re.sub(regex, replacement, text) + return text + + +def bopomofo_to_ipa(text): + for regex, replacement in _bopomofo_to_ipa: + text = re.sub(regex, replacement, text) + return text + + +def bopomofo_to_ipa2(text): + for regex, replacement in _bopomofo_to_ipa2: + text = re.sub(regex, replacement, text) + return text + + +def chinese_to_romaji(text): + text = number_to_chinese(text) + text = chinese_to_bopomofo(text) + text = latin_to_bopomofo(text) + text = bopomofo_to_romaji(text) + text = re.sub('i([aoe])', r'y\1', text) + text = re.sub('u([aoəe])', r'w\1', text) + text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', + r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') + text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) + return text + + +def chinese_to_lazy_ipa(text): + text = chinese_to_romaji(text) + for regex, replacement in _romaji_to_ipa: + text = re.sub(regex, replacement, text) + return text + + +def chinese_to_ipa(text): + text = number_to_chinese(text) + text = chinese_to_bopomofo(text) + text = latin_to_bopomofo(text) + text = bopomofo_to_ipa(text) + text = re.sub('i([aoe])', r'j\1', text) + text = re.sub('u([aoəe])', r'w\1', text) + text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', + r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') + text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) + return text + + +def chinese_to_ipa2(text): + text = number_to_chinese(text) + text = chinese_to_bopomofo(text) + text = latin_to_bopomofo(text) + text = bopomofo_to_ipa2(text) + text = re.sub(r'i([aoe])', r'j\1', text) + text = re.sub(r'u([aoəe])', r'w\1', text) + text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text) + text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text) + return text diff --git a/utils.py b/utils.py index f3a29a1..747a3b7 100644 --- a/utils.py +++ b/utils.py @@ -75,6 +75,13 @@ def 
bits_to_string(bits_array): return output_string +def split_sentence(text, min_len=10, language_str='[EN]'): + if language_str in ['EN']: + sentences = split_sentences_latin(text, min_len=min_len) + else: + sentences = split_sentences_zh(text, min_len=min_len) + return sentences + def split_sentences_latin(text, min_len=10): """Split Long sentences into list of short ones @@ -133,4 +140,55 @@ def merge_short_sentences_latin(sens): sens_out.pop(-1) except: pass + return sens_out + +def split_sentences_zh(text, min_len=10): + text = re.sub('[。!?;]', '.', text) + text = re.sub('[,]', ',', text) + # Replace newlines, tabs, and spaces in the text with a single space + text = re.sub('[\n\t ]+', ' ', text) + # Insert a split marker after each punctuation mark + text = re.sub('([,.!?;])', r'\1 $#!', text) + # Split into sentences and strip surrounding whitespace + # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)] + sentences = [s.strip() for s in text.split('$#!')] + if len(sentences[-1]) == 0: del sentences[-1] + + new_sentences = [] + new_sent = [] + count_len = 0 + for ind, sent in enumerate(sentences): + new_sent.append(sent) + count_len += len(sent) + if count_len > min_len or ind == len(sentences) - 1: + count_len = 0 + new_sentences.append(' '.join(new_sent)) + new_sent = [] + return merge_short_sentences_zh(new_sentences) + + +def merge_short_sentences_zh(sens): + # return sens + """Avoid short sentences by merging them with the following sentence. + + Args: + List[str]: list of input sentences. + + Returns: + List[str]: list of output sentences. + """ + sens_out = [] + for s in sens: + # If the previous sentence is too short, merge it with + # the current sentence. + if len(sens_out) > 0 and len(sens_out[-1]) <= 2: + sens_out[-1] = sens_out[-1] + " " + s + else: + sens_out.append(s) + try: + if len(sens_out[-1]) <= 2: + sens_out[-2] = sens_out[-2] + " " + sens_out[-1] + sens_out.pop(-1) + except: + pass return sens_out \ No newline at end of file
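
For reference, a minimal usage sketch of the new Chinese text-processing helpers added above in `utils.py` and `text/mandarin.py` (assuming the patch is applied, the snippet is run from the repository root, and `pypinyin`, `jieba`, and `cn2an` are installed):

```python
from utils import split_sentence
from text.mandarin import chinese_to_ipa2

text = "今天天气真好,我们一起出去吃饭吧。"

# 'EN' routes to split_sentences_latin; any other marker (e.g. 'ZH')
# falls through to the new split_sentences_zh.
pieces = split_sentence(text, language_str='ZH')

for piece in pieces:
    # chinese_to_ipa2 chains number_to_chinese -> chinese_to_bopomofo
    # -> latin_to_bopomofo -> bopomofo_to_ipa2, plus a few IPA clean-up regexes.
    print(piece, chinese_to_ipa2(piece))
```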