Slight performance improvements and bug fixes

* also add new functionality for edge-playback to keep temp files
* and bump version to 6.0.9
This commit is contained in:
rany2 2023-01-09 17:47:04 +02:00
parent d4da421ef6
commit bd9cc2bd2d
3 changed files with 62 additions and 55 deletions

View File

@ -1,6 +1,6 @@
[metadata]
name = edge-tts
version = 6.0.8
version = 6.0.9
author = rany
author_email = ranygh@riseup.net
description = Microsoft Edge's TTS

View File

@ -4,6 +4,7 @@
Playback TTS with subtitles using edge-tts and mpv.
"""
import os
import subprocess
import sys
import tempfile
@ -22,9 +23,10 @@ def _main() -> None:
print("Please install the missing dependencies.", file=sys.stderr)
sys.exit(1)
keep = os.environ.get("EDGE_PLAYBACK_KEEP_TEMP") is not None
with tempfile.NamedTemporaryFile(
suffix=".mp3", delete=False
) as media, tempfile.NamedTemporaryFile(suffix=".vtt", delete=False) as subtitle:
suffix=".mp3", delete=not keep
) as media, tempfile.NamedTemporaryFile(suffix=".vtt", delete=not keep) as subtitle:
media.close()
subtitle.close()
@ -49,6 +51,9 @@ def _main() -> None:
) as process:
process.communicate()
if keep:
print(f"\nKeeping temporary files: {media.name} and {subtitle.name}")
if __name__ == "__main__":
_main()

View File

@ -34,7 +34,7 @@ from edge_tts.exceptions import (
from .constants import WSS_URL
def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[str, str], bytes]:
def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[bytes, bytes], bytes]:
"""
Returns the headers and data from the given data.
@ -50,14 +50,11 @@ def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[str, str], bytes
raise TypeError("data must be str or bytes")
headers = {}
for line in data.split(b"\r\n\r\n")[0].split(b"\r\n"):
line_split = line.split(b":")
key, value = line_split[0], b":".join(line_split[1:])
if value.startswith(b" "):
value = value[1:]
headers[key.decode("utf-8")] = value.decode("utf-8")
for line in data[: data.find(b"\r\n\r\n")].split(b"\r\n"):
key, value = line.split(b":", 1)
headers[key] = value
return headers, b"\r\n\r\n".join(data.split(b"\r\n\r\n")[1:])
return headers, data[data.find(b"\r\n\r\n") + 4 :]
def remove_incompatible_characters(string: Union[str, bytes]) -> str:
@ -98,55 +95,59 @@ def connect_id() -> str:
return str(uuid.uuid4()).replace("-", "")
def iter_bytes(my_bytes: bytes) -> Generator[bytes, None, None]:
    """
    Iterates over a bytes object one byte at a time.

    Args:
        my_bytes: Bytes object to iterate over

    Yields:
        bytes: Each individual byte as a bytes object of length 1,
            in order. Nothing is yielded for an empty input.
    """
    # Iterating a bytes object produces ints (0-255); wrap each one back
    # into a length-1 bytes object so callers keep working with bytes.
    for value in my_bytes:
        yield bytes((value,))
def split_text_by_byte_length(
    text: Union[str, bytes], byte_length: int
) -> Generator[bytes, None, None]:
    """
    Splits text into chunks of at most byte_length bytes while attempting
    to keep words together. This function assumes text will be inside of
    an XML tag, so a chunk never ends in the middle of an "&...;" entity.

    Because this is a generator, the argument checks below run (and may
    raise) when iteration starts, not when the function is called.

    Args:
        text (str or bytes): The text to be split. A str is encoded as UTF-8.
        byte_length (int): The maximum byte length of each chunk.

    Yields:
        bytes: The next chunk with surrounding whitespace stripped.
            Chunks that are empty after stripping are skipped.

    Raises:
        TypeError: If text is neither str nor bytes.
        ValueError: If byte_length is not greater than 0, or if an "&" is
            not terminated by a ";" within byte_length bytes.
    """
    if isinstance(text, str):
        text = text.encode("utf-8")
    if not isinstance(text, bytes):
        raise TypeError("text must be str or bytes")

    if byte_length <= 0:
        raise ValueError("byte_length must be greater than 0")

    while len(text) > byte_length:
        # Find the last space in the first byte_length bytes
        split_at = text.rfind(b" ", 0, byte_length)

        if split_at == -1:
            # No space found, split at the byte length
            split_at = byte_length

            # Do not cut a multi-byte UTF-8 character in half: step back
            # over continuation bytes (0b10xxxxxx) to the character start.
            # text[split_at] always exists here as len(text) > byte_length.
            boundary = split_at
            while boundary > 0 and (text[boundary] & 0xC0) == 0x80:
                boundary -= 1
            # If no character start was found (invalid UTF-8 or a limit
            # smaller than one character), keep the plain byte split so
            # the loop still makes progress.
            if boundary > 0:
                split_at = boundary

        # Verify all & are terminated with a ;
        while b"&" in text[:split_at]:
            ampersand_index = text.rindex(b"&", 0, split_at)
            if text.find(b";", ampersand_index, split_at) != -1:
                break

            if ampersand_index == 0:
                # The entity starts the chunk and still does not fit
                raise ValueError("Maximum byte length is too small or invalid text")

            # Split right before the unterminated entity so that every
            # byte in front of it stays in the current chunk.
            split_at = ampersand_index

        new_text = text[:split_at].strip()
        if new_text:
            yield new_text

        # split_at is 0 only when text starts with a space; skip that
        # single byte so the loop always advances.
        text = text[split_at or 1 :]

    new_text = text.strip()
    if new_text:
        yield new_text
def mkssml(text: Union[str, bytes], voice: str, rate: str, volume: str) -> str:
@ -352,15 +353,14 @@ class Communicate:
async for received in websocket:
if received.type == aiohttp.WSMsgType.TEXT:
parameters, data = get_headers_and_data(received.data)
if parameters.get("Path") == "turn.start":
path = parameters.get(b"Path")
if path == b"turn.start":
download_audio = True
elif parameters.get("Path") == "turn.end":
elif path == b"turn.end":
download_audio = False
break # End of audio data
elif parameters.get("Path") == "audio.metadata":
meta = json.loads(data)
for i in range(len(meta["Metadata"])):
meta_obj = meta["Metadata"][i]
elif path == b"audio.metadata":
for meta_obj in json.loads(data)["Metadata"]:
meta_type = meta_obj["Type"]
if meta_type == "WordBoundary":
yield {
@ -375,7 +375,7 @@ class Communicate:
raise UnknownResponse(
f"Unknown metadata type: {meta_type}"
)
elif parameters.get("Path") == "response":
elif path == b"response":
pass
else:
raise UnknownResponse(
@ -390,13 +390,15 @@ class Communicate:
yield {
"type": "audio",
"data": b"Path:audio\r\n".join(
received.data.split(b"Path:audio\r\n")[1:]
),
"data": received.data[
received.data.find(b"Path:audio\r\n") + 12 :
],
}
audio_was_received = True
elif received.type == aiohttp.WSMsgType.ERROR:
raise WebSocketError(received.data)
raise WebSocketError(
received.data if received.data else "Unknown error"
)
if not audio_was_received:
raise NoAudioReceived(