mirror of
https://github.com/rany2/edge-tts
synced 2024-11-22 01:45:02 +00:00
Slight performance improvements and bug fixes
* also add new functionality for edge-playback to keep temp files
* and bump version to 6.0.9
This commit is contained in:
parent
d4da421ef6
commit
bd9cc2bd2d
@ -1,6 +1,6 @@
|
||||
[metadata]
|
||||
name = edge-tts
|
||||
version = 6.0.8
|
||||
version = 6.0.9
|
||||
author = rany
|
||||
author_email = ranygh@riseup.net
|
||||
description = Microsoft Edge's TTS
|
||||
|
@ -4,6 +4,7 @@
|
||||
Playback TTS with subtitles using edge-tts and mpv.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
@ -22,9 +23,10 @@ def _main() -> None:
|
||||
print("Please install the missing dependencies.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
keep = os.environ.get("EDGE_PLAYBACK_KEEP_TEMP") is not None
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=".mp3", delete=False
|
||||
) as media, tempfile.NamedTemporaryFile(suffix=".vtt", delete=False) as subtitle:
|
||||
suffix=".mp3", delete=not keep
|
||||
) as media, tempfile.NamedTemporaryFile(suffix=".vtt", delete=not keep) as subtitle:
|
||||
media.close()
|
||||
subtitle.close()
|
||||
|
||||
@ -49,6 +51,9 @@ def _main() -> None:
|
||||
) as process:
|
||||
process.communicate()
|
||||
|
||||
if keep:
|
||||
print(f"\nKeeping temporary files: {media.name} and {subtitle.name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_main()
|
||||
|
@ -34,7 +34,7 @@ from edge_tts.exceptions import (
|
||||
from .constants import WSS_URL
|
||||
|
||||
|
||||
def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[str, str], bytes]:
|
||||
def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[bytes, bytes], bytes]:
|
||||
"""
|
||||
Returns the headers and data from the given data.
|
||||
|
||||
@ -50,14 +50,11 @@ def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[str, str], bytes
|
||||
raise TypeError("data must be str or bytes")
|
||||
|
||||
headers = {}
|
||||
for line in data.split(b"\r\n\r\n")[0].split(b"\r\n"):
|
||||
line_split = line.split(b":")
|
||||
key, value = line_split[0], b":".join(line_split[1:])
|
||||
if value.startswith(b" "):
|
||||
value = value[1:]
|
||||
headers[key.decode("utf-8")] = value.decode("utf-8")
|
||||
for line in data[: data.find(b"\r\n\r\n")].split(b"\r\n"):
|
||||
key, value = line.split(b":", 1)
|
||||
headers[key] = value
|
||||
|
||||
return headers, b"\r\n\r\n".join(data.split(b"\r\n\r\n")[1:])
|
||||
return headers, data[data.find(b"\r\n\r\n") + 4 :]
|
||||
|
||||
|
||||
def remove_incompatible_characters(string: Union[str, bytes]) -> str:
|
||||
@ -98,55 +95,59 @@ def connect_id() -> str:
|
||||
return str(uuid.uuid4()).replace("-", "")
|
||||
|
||||
|
||||
def iter_bytes(my_bytes: bytes) -> Generator[bytes, None, None]:
|
||||
"""
|
||||
Iterates over bytes object
|
||||
|
||||
Args:
|
||||
my_bytes: Bytes object to iterate over
|
||||
|
||||
Yields:
|
||||
the individual bytes
|
||||
"""
|
||||
for i in range(len(my_bytes)):
|
||||
yield my_bytes[i : i + 1]
|
||||
|
||||
|
||||
def split_text_by_byte_length(text: Union[str, bytes], byte_length: int) -> List[bytes]:
|
||||
def split_text_by_byte_length(
|
||||
text: Union[str, bytes], byte_length: int
|
||||
) -> Generator[bytes, None, None]:
|
||||
"""
|
||||
Splits a string into a list of strings of a given byte length
|
||||
while attempting to keep words together.
|
||||
while attempting to keep words together. This function assumes
|
||||
text will be inside of an XML tag.
|
||||
|
||||
Args:
|
||||
text (str or bytes): The string to be split.
|
||||
byte_length (int): The maximum byte length of each string in the list.
|
||||
|
||||
Returns:
|
||||
list: A list of bytes of the given byte length.
|
||||
Yield:
|
||||
bytes: The next string in the list.
|
||||
"""
|
||||
if isinstance(text, str):
|
||||
text = text.encode("utf-8")
|
||||
if not isinstance(text, bytes):
|
||||
raise TypeError("text must be str or bytes")
|
||||
|
||||
words = []
|
||||
if byte_length <= 0:
|
||||
raise ValueError("byte_length must be greater than 0")
|
||||
|
||||
while len(text) > byte_length:
|
||||
# Find the last space in the string
|
||||
last_space = text.rfind(b" ", 0, byte_length)
|
||||
if last_space == -1:
|
||||
# No space found, just split at the byte length
|
||||
words.append(text[:byte_length])
|
||||
text = text[byte_length:]
|
||||
else:
|
||||
# Split at the last space
|
||||
words.append(text[:last_space])
|
||||
text = text[last_space:]
|
||||
words.append(text)
|
||||
split_at = text.rfind(b" ", 0, byte_length)
|
||||
|
||||
# Remove empty strings from the list
|
||||
words = [word for word in words if word]
|
||||
# Return the list
|
||||
return words
|
||||
# If no space found, split_at is byte_length
|
||||
split_at = split_at if split_at != -1 else byte_length
|
||||
|
||||
# Verify all & are terminated with a ;
|
||||
while b"&" in text[:split_at]:
|
||||
ampersand_index = text.rindex(b"&", 0, split_at)
|
||||
if text.find(b";", ampersand_index, split_at) != -1:
|
||||
break
|
||||
|
||||
split_at = ampersand_index - 1
|
||||
if split_at < 0:
|
||||
raise ValueError("Maximum byte length is too small or invalid text")
|
||||
if split_at == 0:
|
||||
break
|
||||
|
||||
# Append the string to the list
|
||||
new_text = text[:split_at].strip()
|
||||
if new_text:
|
||||
yield new_text
|
||||
if split_at == 0:
|
||||
split_at = 1
|
||||
text = text[split_at:]
|
||||
|
||||
new_text = text.strip()
|
||||
if new_text:
|
||||
yield new_text
|
||||
|
||||
|
||||
def mkssml(text: Union[str, bytes], voice: str, rate: str, volume: str) -> str:
|
||||
@ -352,15 +353,14 @@ class Communicate:
|
||||
async for received in websocket:
|
||||
if received.type == aiohttp.WSMsgType.TEXT:
|
||||
parameters, data = get_headers_and_data(received.data)
|
||||
if parameters.get("Path") == "turn.start":
|
||||
path = parameters.get(b"Path")
|
||||
if path == b"turn.start":
|
||||
download_audio = True
|
||||
elif parameters.get("Path") == "turn.end":
|
||||
elif path == b"turn.end":
|
||||
download_audio = False
|
||||
break # End of audio data
|
||||
elif parameters.get("Path") == "audio.metadata":
|
||||
meta = json.loads(data)
|
||||
for i in range(len(meta["Metadata"])):
|
||||
meta_obj = meta["Metadata"][i]
|
||||
elif path == b"audio.metadata":
|
||||
for meta_obj in json.loads(data)["Metadata"]:
|
||||
meta_type = meta_obj["Type"]
|
||||
if meta_type == "WordBoundary":
|
||||
yield {
|
||||
@ -375,7 +375,7 @@ class Communicate:
|
||||
raise UnknownResponse(
|
||||
f"Unknown metadata type: {meta_type}"
|
||||
)
|
||||
elif parameters.get("Path") == "response":
|
||||
elif path == b"response":
|
||||
pass
|
||||
else:
|
||||
raise UnknownResponse(
|
||||
@ -390,13 +390,15 @@ class Communicate:
|
||||
|
||||
yield {
|
||||
"type": "audio",
|
||||
"data": b"Path:audio\r\n".join(
|
||||
received.data.split(b"Path:audio\r\n")[1:]
|
||||
),
|
||||
"data": received.data[
|
||||
received.data.find(b"Path:audio\r\n") + 12 :
|
||||
],
|
||||
}
|
||||
audio_was_received = True
|
||||
elif received.type == aiohttp.WSMsgType.ERROR:
|
||||
raise WebSocketError(received.data)
|
||||
raise WebSocketError(
|
||||
received.data if received.data else "Unknown error"
|
||||
)
|
||||
|
||||
if not audio_was_received:
|
||||
raise NoAudioReceived(
|
||||
|
Loading…
Reference in New Issue
Block a user