add more typing

This commit is contained in:
rany2 2023-01-05 00:56:01 +02:00
parent efe0cbedde
commit c4c3dc5a13
12 changed files with 129 additions and 117 deletions

View File

@ -9,7 +9,7 @@ import asyncio
import edge_tts
async def main():
async def main() -> None:
TEXT = "Hello World!"
VOICE = "en-GB-SoniaNeural"
OUTPUT_FILE = "test.mp3"

View File

@ -11,7 +11,7 @@ import edge_tts
from edge_tts import VoicesManager
async def main():
async def main() -> None:
voices = await VoicesManager.create()
voice = voices.find(Gender="Male", Language="es")
# Also supports Locales

View File

@ -1,3 +1,4 @@
# Format, sort imports, lint and type-check every Python file under
# src/ and examples/.
#
# `-exec TOOL {} +` lets find hand the paths straight to each tool, so file
# names containing whitespace are passed intact and a tool is not started
# at all when no file matches.
find src examples -name '*.py' -exec black {} +
find src examples -name '*.py' -exec isort {} +
find src examples -name '*.py' -exec pylint {} +
find src examples -name '*.py' -exec mypy {} +

13
mypy.ini Normal file
View File

@ -0,0 +1,13 @@
[mypy]
warn_return_any = True
warn_unused_configs = True
#disallow_any_unimported = True
#disallow_any_expr = True
#disallow_any_decorated = True
#disallow_any_explicit = True
#disallow_any_generics = True
#disallow_subclassing_any = True
#disallow_untyped_calls = True
disallow_untyped_defs = True
disallow_incomplete_defs = True

View File

@ -27,4 +27,11 @@ where=src
[options.entry_points]
console_scripts =
edge-tts = edge_tts.__main__:main
edge-playback = edge_playback.__init__:main
edge-playback = edge_playback.__main__:main
[options.extras_require]
dev =
black
isort
mypy
pylint

View File

@ -1,63 +0,0 @@
#!/usr/bin/env python3
"""
Playback TTS with subtitles using edge-tts and mpv.
"""
import os
import subprocess
import sys
import tempfile
from shutil import which
def main():
    """
    Synthesize speech with edge-tts and play it back in mpv with subtitles.

    Every command-line argument (sys.argv[1:]) is forwarded to edge-tts
    unchanged. Audio and subtitles are written to two temporary files that
    are removed again before this function returns or exits.

    Raises:
        SystemExit: With status 1 when mpv or edge-tts cannot be found on
            PATH, or with edge-tts' own status when synthesis fails.
    """
    missing = [tool for tool in ("mpv", "edge-tts") if not which(tool)]
    for tool in missing:
        print(f"{tool} is not installed.", file=sys.stderr)
    if missing:
        print("Please install the missing dependencies.", file=sys.stderr)
        sys.exit(1)

    media = None
    subtitle = None
    try:
        # delete=False followed by close(): only the paths are needed here,
        # the child processes open the files themselves.
        media = tempfile.NamedTemporaryFile(delete=False)
        media.close()
        subtitle = tempfile.NamedTemporaryFile(delete=False)
        subtitle.close()

        print(f"Media file: {media.name}")
        print(f"Subtitle file: {subtitle.name}\n")

        synthesis = subprocess.run(
            [
                "edge-tts",
                f"--write-media={media.name}",
                f"--write-subtitles={subtitle.name}",
            ]
            + sys.argv[1:],
            check=False,
        )
        if synthesis.returncode != 0:
            # Do not start the player on output that may be empty or
            # incomplete; report the failure through the exit status.
            # The finally clause below still removes the temporary files.
            sys.exit(synthesis.returncode)

        subprocess.run(
            ["mpv", f"--sub-file={subtitle.name}", media.name],
            check=False,
        )
    finally:
        for temp_file in (media, subtitle):
            if temp_file is not None:
                os.unlink(temp_file.name)
if __name__ == "__main__":
main()

View File

@ -1,10 +1,63 @@
#!/usr/bin/env python3
"""
This is the main file for the edge_playback package.
Playback TTS with subtitles using edge-tts and mpv.
"""
from edge_playback.__init__ import main
import os
import subprocess
import sys
import tempfile
from shutil import which
def main() -> None:
    """
    Synthesize speech with edge-tts and play it back in mpv with subtitles.

    Every command-line argument (sys.argv[1:]) is forwarded to edge-tts
    unchanged. Audio and subtitles are written to two temporary files that
    are removed again before this function returns or exits.

    Raises:
        SystemExit: With status 1 when mpv or edge-tts cannot be found on
            PATH, or with edge-tts' own status when synthesis fails.
    """
    missing = [tool for tool in ("mpv", "edge-tts") if not which(tool)]
    for tool in missing:
        print(f"{tool} is not installed.", file=sys.stderr)
    if missing:
        print("Please install the missing dependencies.", file=sys.stderr)
        sys.exit(1)

    media = None
    subtitle = None
    try:
        # delete=False followed by close(): only the paths are needed here,
        # the child processes open the files themselves.
        media = tempfile.NamedTemporaryFile(delete=False)
        media.close()
        subtitle = tempfile.NamedTemporaryFile(delete=False)
        subtitle.close()

        print(f"Media file: {media.name}")
        print(f"Subtitle file: {subtitle.name}\n")

        synthesis = subprocess.run(
            [
                "edge-tts",
                f"--write-media={media.name}",
                f"--write-subtitles={subtitle.name}",
            ]
            + sys.argv[1:],
            check=False,
        )
        if synthesis.returncode != 0:
            # Do not start the player on output that may be empty or
            # incomplete; report the failure through the exit status.
            # The finally clause below still removes the temporary files.
            sys.exit(synthesis.returncode)

        subprocess.run(
            ["mpv", f"--sub-file={subtitle.name}", media.name],
            check=False,
        )
    finally:
        for temp_file in (media, subtitle):
            if temp_file is not None:
                os.unlink(temp_file.name)
if __name__ == "__main__":
main()

View File

View File

@ -7,7 +7,7 @@ import json
import re
import time
import uuid
from typing import Dict, Generator, List, Optional
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
from xml.sax.saxutils import escape
import aiohttp
@ -96,7 +96,7 @@ def iter_bytes(my_bytes: bytes) -> Generator[bytes, None, None]:
yield my_bytes[i : i + 1]
def split_text_by_byte_length(text: bytes, byte_length: int) -> List[bytes]:
def split_text_by_byte_length(text: str | bytes, byte_length: int) -> List[bytes]:
"""
Splits a string into a list of strings of a given byte length
while attempting to keep words together.
@ -151,7 +151,7 @@ def mkssml(text: str | bytes, voice: str, pitch: str, rate: str, volume: str) ->
return ssml
def date_to_string():
def date_to_string() -> str:
"""
Return Javascript-style date string.
@ -193,7 +193,7 @@ class Communicate:
def __init__(
self,
text: str | List[str],
text: str,
voice: str = "Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)",
*,
pitch: str = "+0Hz",
@ -207,9 +207,9 @@ class Communicate:
Raises:
ValueError: If the voice is not valid.
"""
self.text = text
self.codec = "audio-24khz-48kbitrate-mono-mp3"
self.voice = voice
self.text: str = text
self.codec: str = "audio-24khz-48kbitrate-mono-mp3"
self.voice: str = voice
# Possible values for voice are:
# - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural)
# - cy-GB-NiaNeural
@ -232,19 +232,19 @@ class Communicate:
if re.match(r"^[+-]\d+Hz$", pitch) is None:
raise ValueError(f"Invalid pitch '{pitch}'.")
self.pitch = pitch
self.pitch: str = pitch
if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", rate) is None:
raise ValueError(f"Invalid rate '{rate}'.")
self.rate = rate
self.rate: str = rate
if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", volume) is None:
raise ValueError(f"Invalid volume '{volume}'.")
self.volume = volume
self.volume: str = volume
self.proxy = proxy
self.proxy: Optional[str] = proxy
async def stream(self):
async def stream(self) -> AsyncGenerator[Dict[str, Any], None]:
"""Streams audio and metadata from the service."""
websocket_max_size = 2**16
@ -403,7 +403,7 @@ class Communicate:
async def save(
self, audio_fname: str | bytes, metadata_fname: Optional[str | bytes] = None
):
) -> None:
"""
Save the audio and metadata to the specified files.
"""

View File

@ -3,13 +3,14 @@ list_voices package for edge_tts.
"""
import json
from typing import Any, Optional
import aiohttp
from .constants import VOICE_LIST
async def list_voices(*, proxy=None):
async def list_voices(*, proxy: Optional[str] = None) -> Any:
"""
List all available voices and their attributes.
@ -47,7 +48,7 @@ class VoicesManager:
"""
@classmethod
async def create(cls):
async def create(cls): # type: ignore
"""
Creates a VoicesManager object and populates it with all available voices.
"""
@ -59,12 +60,12 @@ class VoicesManager:
]
return self
def find(self, **kwargs):
def find(self, **kwargs: Any) -> list[dict[str, Any]]:
"""
Finds all matching voices based on the provided attributes.
"""
matching_voices = [
voice for voice in self.voices if kwargs.items() <= voice.items()
voice for voice in self.voices if kwargs.items() <= voice.items() # type: ignore
]
return matching_voices

View File

@ -6,10 +6,11 @@ information provided by the service easier.
"""
import math
from typing import List, Tuple
from xml.sax.saxutils import escape, unescape
def formatter(offset1, offset2, subdata):
def formatter(offset1: float, offset2: float, subdata: str) -> str:
"""
formatter returns the timecode and the text of the subtitle.
"""
@ -19,7 +20,7 @@ def formatter(offset1, offset2, subdata):
)
def mktimestamp(time_unit):
def mktimestamp(time_unit: float) -> str:
"""
mktimestamp returns the timecode of the subtitle.
@ -39,7 +40,7 @@ class SubMaker:
SubMaker class
"""
def __init__(self, overlapping=1):
def __init__(self, overlapping: int = 1) -> None:
"""
SubMaker constructor.
@ -47,10 +48,11 @@ class SubMaker:
overlapping (int): The amount of time in seconds that the
subtitles should overlap.
"""
self.subs_and_offset = []
self.overlapping = overlapping * (10**7)
self.offset: List[Tuple[float, float]] = []
self.subs: List[str] = []
self.overlapping: int = overlapping * (10**7)
def create_sub(self, timestamp, text):
def create_sub(self, timestamp: Tuple[float, float], text: str) -> None:
"""
create_sub creates a subtitle with the given timestamp and text
and adds it to the list of subtitles
@ -62,40 +64,37 @@ class SubMaker:
Returns:
None
"""
timestamp[1] += timestamp[0]
self.subs_and_offset.append(timestamp)
self.subs_and_offset.append(text)
self.offset.append((timestamp[0], timestamp[0] + timestamp[1]))
self.subs.append(text)
def generate_subs(self):
def generate_subs(self) -> str:
"""
generate_subs generates the complete subtitle file.
Returns:
str: The complete subtitle file.
"""
if len(self.subs_and_offset) >= 2:
if len(self.subs) == len(self.offset):
data = "WEBVTT\r\n\r\n"
for offset, subs in zip(
self.subs_and_offset[::2], self.subs_and_offset[1::2]
):
for offset, subs in zip(self.offset, self.subs):
subs = unescape(subs)
subs = [subs[i : i + 79] for i in range(0, len(subs), 79)]
split_subs: List[str] = [subs[i : i + 79] for i in range(0, len(subs), 79)]
for i in range(len(subs) - 1):
sub = subs[i]
for i in range(len(split_subs) - 1):
sub = split_subs[i]
split_at_word = True
if sub[-1] == " ":
subs[i] = sub[:-1]
split_subs[i] = sub[:-1]
split_at_word = False
if sub[0] == " ":
subs[i] = sub[1:]
split_subs[i] = sub[1:]
split_at_word = False
if split_at_word:
subs[i] += "-"
split_subs[i] += "-"
subs = "\r\n".join(subs)
subs = "\r\n".join(split_subs)
data += formatter(offset[0], offset[1] + self.overlapping, subs)
return data

View File

@ -5,12 +5,14 @@ Main package.
import argparse
import asyncio
from io import BufferedWriter
import sys
from typing import Any
from edge_tts import Communicate, SubMaker, list_voices
async def _print_voices(proxy):
async def _print_voices(*, proxy: str) -> None:
"""Print all available voices."""
for idx, voice in enumerate(await list_voices(proxy=proxy)):
if idx != 0:
@ -23,9 +25,9 @@ async def _print_voices(proxy):
print(f"{key}: {voice[key]}")
async def _run_tts(args):
async def _run_tts(args: Any) -> None:
"""Run TTS after parsing arguments from command line."""
tts = await Communicate(
tts = Communicate(
args.text,
args.voice,
proxy=args.proxy,
@ -35,18 +37,17 @@ async def _run_tts(args):
try:
media_file = None
if args.write_media:
# pylint: disable=consider-using-with
media_file = open(args.write_media, "wb")
subs = SubMaker(args.overlapping)
async for data in tts.stream():
if data["type"] == "audio":
if not args.write_media:
sys.stdout.buffer.write(data["data"])
else:
if isinstance(media_file, BufferedWriter):
media_file.write(data["data"])
else:
sys.stdout.buffer.write(data["data"])
elif data["type"] == "WordBoundary":
subs.create_sub([data["offset"], data["duration"]], data["text"])
subs.create_sub((data["offset"], data["duration"]), data["text"])
if not args.write_subtitles:
sys.stderr.write(subs.generate_subs())
@ -58,7 +59,7 @@ async def _run_tts(args):
media_file.close()
async def _async_main():
async def _async_main() -> None:
parser = argparse.ArgumentParser(description="Microsoft Edge TTS")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("-t", "--text", help="what TTS will say")
@ -111,7 +112,7 @@ async def _async_main():
args = parser.parse_args()
if args.list_voices:
await _print_voices(args.proxy)
await _print_voices(proxy=args.proxy)
sys.exit(0)
if args.text is not None or args.file is not None:
@ -129,7 +130,7 @@ async def _async_main():
await _run_tts(args)
def main():
def main() -> None:
    """
    Run the asynchronous command-line entry point to completion.

    asyncio.run() creates a fresh event loop for the coroutine and closes it
    afterwards. It replaces asyncio.get_event_loop().run_until_complete(),
    whose use without a running loop is deprecated in recent Python versions.
    """
    asyncio.run(_async_main())