diff --git a/src/edge_playback/__init__.py b/src/edge_playback/__init__.py index 86df7c4..863e4d0 100644 --- a/src/edge_playback/__init__.py +++ b/src/edge_playback/__init__.py @@ -12,44 +12,51 @@ from shutil import which def main(): - """ - Main function. - """ - if which("mpv") and which("edge-tts"): + depcheck_failed = False + if not which("mpv"): + print("mpv is not installed.", file=sys.stderr) + depcheck_failed = True + if not which("edge-tts"): + print("edge-tts is not installed.", file=sys.stderr) + depcheck_failed = True + if depcheck_failed: + print("Please install the missing dependencies.", file=sys.stderr) + sys.exit(1) + + media = None + subtitle = None + try: media = tempfile.NamedTemporaryFile(delete=False) + media.close() + subtitle = tempfile.NamedTemporaryFile(delete=False) - try: - media.close() - subtitle.close() + subtitle.close() - print() - print(f"Media file: {media.name}") - print(f"Subtitle file: {subtitle.name}\n") - with subprocess.Popen( - [ - "edge-tts", - "--boundary-type=1", - f"--write-media={media.name}", - f"--write-subtitles={subtitle.name}", - ] - + sys.argv[1:] - ) as process: - process.communicate() + print(f"Media file: {media.name}") + print(f"Subtitle file: {subtitle.name}\n") + with subprocess.Popen( + [ + "edge-tts", + f"--write-media={media.name}", + f"--write-subtitles={subtitle.name}", + ] + + sys.argv[1:] + ) as process: + process.communicate() - with subprocess.Popen( - [ - "mpv", - "--keep-open=yes", - f"--sub-file={subtitle.name}", - media.name, - ] - ) as process: - process.communicate() - finally: + with subprocess.Popen( + [ + "mpv", + f"--sub-file={subtitle.name}", + media.name, + ] + ) as process: + process.communicate() + finally: + if media is not None: os.unlink(media.name) + if subtitle is not None: os.unlink(subtitle.name) - else: - print("This script requires mpv and edge-tts.") if __name__ == "__main__": diff --git a/src/edge_tts/communicate.py b/src/edge_tts/communicate.py index f81d063..2c14322 
100644 --- a/src/edge_tts/communicate.py +++ b/src/edge_tts/communicate.py @@ -4,16 +4,20 @@ Communicate package. import json +import re import time import uuid +from typing import Dict, Generator, List, Optional from xml.sax.saxutils import escape import aiohttp +from edge_tts.exceptions import * + from .constants import WSS_URL -def get_headers_and_data(data): +def get_headers_and_data(data: str | bytes) -> tuple[Dict[str, str], bytes]: """ Returns the headers and data from the given data. @@ -25,6 +29,8 @@ def get_headers_and_data(data): """ if isinstance(data, str): data = data.encode("utf-8") + if not isinstance(data, bytes): + raise TypeError("data must be str or bytes") headers = {} for line in data.split(b"\r\n\r\n")[0].split(b"\r\n"): @@ -37,7 +43,7 @@ def get_headers_and_data(data): return headers, b"\r\n\r\n".join(data.split(b"\r\n\r\n")[1:]) -def remove_incompatible_characters(string): +def remove_incompatible_characters(string: str | bytes) -> str: """ The service does not support a couple character ranges. Most important being the vertical tab character which is @@ -52,31 +58,30 @@ def remove_incompatible_characters(string): """ if isinstance(string, bytes): string = string.decode("utf-8") + if not isinstance(string, str): + raise TypeError("string must be str or bytes") - string = list(string) + chars: List[str] = list(string) - for idx, char in enumerate(string): - code = ord(char) + for idx, char in enumerate(chars): + code: int = ord(char) if (0 <= code <= 8) or (11 <= code <= 12) or (14 <= code <= 31): - string[idx] = " " + chars[idx] = " " - return "".join(string) + return "".join(chars) -def connect_id(): +def connect_id() -> str: """ Returns a UUID without dashes. - Args: - None - Returns: str: A UUID without dashes. 
""" return str(uuid.uuid4()).replace("-", "") -def iter_bytes(my_bytes): +def iter_bytes(my_bytes: bytes) -> Generator[bytes, None, None]: """ Iterates over bytes object @@ -90,20 +95,22 @@ def iter_bytes(my_bytes): yield my_bytes[i : i + 1] -def split_text_by_byte_length(text, byte_length): +def split_text_by_byte_length(text: bytes, byte_length: int) -> List[bytes]: """ Splits a string into a list of strings of a given byte length while attempting to keep words together. Args: - text (byte): The string to be split. - byte_length (int): The byte length of each string in the list. + text (str or bytes): The string to be split. + byte_length (int): The maximum byte length of each string in the list. Returns: - list: A list of strings of the given byte length. + list: A list of bytes of the given byte length. """ if isinstance(text, str): text = text.encode("utf-8") + if not isinstance(text, bytes): + raise TypeError("text must be str or bytes") words = [] while len(text) > byte_length: @@ -125,17 +132,10 @@ def split_text_by_byte_length(text, byte_length): return words -def mkssml(text, voice, pitch, rate, volume): +def mkssml(text: str | bytes, voice: str, pitch: str, rate: str, volume: str) -> str: """ Creates a SSML string from the given parameters. - Args: - text (str): The text to be spoken. - voice (str): The voice to be used. - pitch (str): The pitch to be used. - rate (str): The rate to be used. - volume (str): The volume to be used. - Returns: str: The SSML string. """ @@ -154,9 +154,6 @@ def date_to_string(): """ Return Javascript-style date string. - Args: - None - Returns: str: Javascript-style date string. """ @@ -171,15 +168,10 @@ def date_to_string(): ) -def ssml_headers_plus_data(request_id, timestamp, ssml): +def ssml_headers_plus_data(request_id: str, timestamp: str, ssml: str) -> str: """ Returns the headers and data to be used in the request. - Args: - request_id (str): The request ID. - timestamp (str): The timestamp. 
- ssml (str): The SSML string. - Returns: str: The headers and data to be used in the request. """ @@ -198,73 +190,86 @@ class Communicate: Class for communicating with the service. """ - def __init__(self): - """ - Initializes the Communicate class. - """ - self.date = date_to_string() - - async def run( + def __init__( self, - messages, - boundary_type=0, - codec="audio-24khz-48kbitrate-mono-mp3", - voice="Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)", - pitch="+0Hz", - rate="+0%", - volume="+0%", - proxy=None, + text: str | List[str], + voice: str = "Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)", + *, + pitch: str = "+0Hz", + rate: str = "+0%", + volume: str = "+0%", + proxy: Optional[str] = None, ): """ - Runs the Communicate class. + Initializes the Communicate class. - Args: - messages (str or list): A list of SSML strings or a single text. - boundery_type (int): The type of boundary to use. 0 for none, 1 for word_boundary, 2 for sentence_boundary. - codec (str): The codec to use. - voice (str): The voice to use. - pitch (str): The pitch to use. - rate (str): The rate to use. - volume (str): The volume to use. - - Yields: - tuple: The subtitle offset, subtitle, and audio data. + Raises: + ValueError: If the voice is not valid. """ - - word_boundary = False - - if boundary_type > 0: - word_boundary = True - if boundary_type > 1: - raise ValueError( - "Invalid boundary type. SentenceBoundary is no longer supported." + self.text = text + self.boundary_type = 1 + self.codec = "audio-24khz-48kbitrate-mono-mp3" + self.voice = voice + # Possible values for voice are: + # - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural) + # - cy-GB-NiaNeural + # Always send the first variant as that is what Microsoft Edge does. 
+ match = re.match(r"^([a-z]{2})-([A-Z]{2})-(.+Neural)$", voice) + if match is not None: + self.voice = ( + "Microsoft Server Speech Text to Speech Voice" + + f" ({match.group(1)}-{match.group(2)}, {match.group(3)})" ) - word_boundary = str(word_boundary).lower() + if ( + re.match( + r"^Microsoft Server Speech Text to Speech Voice \(.+,.+\)$", + self.voice, + ) + is None + ): + raise ValueError(f"Invalid voice '{voice}'.") - websocket_max_size = 2 ** 16 + if re.match(r"^[+-]\d+Hz$", pitch) is None: + raise ValueError(f"Invalid pitch '{pitch}'.") + self.pitch = pitch + + if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", rate) is None: + raise ValueError(f"Invalid rate '{rate}'.") + self.rate = rate + + if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", volume) is None: + raise ValueError(f"Invalid volume '{volume}'.") + self.volume = volume + + self.proxy = proxy + + async def stream(self): + """Streams audio and metadata from the service.""" + + websocket_max_size = 2**16 overhead_per_message = ( len( ssml_headers_plus_data( - connect_id(), self.date, mkssml("", voice, pitch, rate, volume) + connect_id(), + date_to_string(), + mkssml("", self.voice, self.pitch, self.rate, self.volume), ) ) - + 50 - ) # margin of error - messages = split_text_by_byte_length( - escape(remove_incompatible_characters(messages)), + + 50 # margin of error + ) + texts = split_text_by_byte_length( + escape(remove_incompatible_characters(self.text)), websocket_max_size - overhead_per_message, ) - # Variables for the loop - download = False async with aiohttp.ClientSession(trust_env=True) as session: async with session.ws_connect( f"{WSS_URL}&ConnectionId={connect_id()}", compress=15, autoclose=True, autoping=True, - proxy=proxy, + proxy=self.proxy, headers={ "Pragma": "no-cache", "Cache-Control": "no-cache", @@ -275,9 +280,19 @@ class Communicate: " (KHTML, like Gecko) Chrome/91.0.4472.77 Safari/537.36 Edg/91.0.864.41", }, ) as websocket: - for message in messages: + for text in texts: + # 
download indicates whether we should be expecting audio data, + # this is so that we avoid getting binary data from the websocket + # and falsely thinking it's audio data. + download = False + + # audio_was_received indicates whether we have received audio data + # from the websocket. This is so we can raise an exception if we + # don't receive any audio data. + audio_was_received = False + # Each message needs to have the proper date - self.date = date_to_string() + date = date_to_string() # Prepare the request to be sent to the service. # @@ -290,26 +305,26 @@ class Communicate: # # Also pay close attention to double { } in request (escape for f-string). request = ( - f"X-Timestamp:{self.date}\r\n" + f"X-Timestamp:{date}\r\n" "Content-Type:application/json; charset=utf-8\r\n" "Path:speech.config\r\n\r\n" '{"context":{"synthesis":{"audio":{"metadataoptions":{' - f'"sentenceBoundaryEnabled":false,' - f'"wordBoundaryEnabled":{word_boundary}}},"outputFormat":"{codec}"' + '"sentenceBoundaryEnabled":false,"wordBoundaryEnabled":true},' + f'"outputFormat":"{self.codec}"' "}}}}\r\n" ) - # Send the request to the service. await websocket.send_str(request) - # Send the message itself. + await websocket.send_str( ssml_headers_plus_data( connect_id(), - self.date, - mkssml(message, voice, pitch, rate, volume), + date, + mkssml( + text, self.voice, self.pitch, self.rate, self.volume + ), ) ) - # Begin listening for the response. 
async for received in websocket: if received.type == aiohttp.WSMsgType.TEXT: parameters, data = get_headers_and_data(received.data) @@ -329,35 +344,34 @@ class Communicate: and parameters["Path"] == "audio.metadata" ): metadata = json.loads(data) - metadata_type = metadata["Metadata"][0]["Type"] - metadata_offset = metadata["Metadata"][0]["Data"][ - "Offset" - ] - if metadata_type == "WordBoundary": - metadata_duration = metadata["Metadata"][0]["Data"][ - "Duration" + for i in range(len(metadata["Metadata"])): + metadata_type = metadata["Metadata"][i]["Type"] + metadata_offset = metadata["Metadata"][i]["Data"][ + "Offset" ] - metadata_text = metadata["Metadata"][0]["Data"][ - "text" - ]["Text"] - yield ( - [ - metadata_offset, - metadata_duration, - ], - metadata_text, - None, - ) - elif metadata_type == "SentenceBoundary": - raise NotImplementedError( - "SentenceBoundary is not supported due to being broken." - ) - elif metadata_type == "SessionEnd": - continue - else: - raise NotImplementedError( - f"Unknown metadata type: {metadata_type}" - ) + if metadata_type == "WordBoundary": + metadata_duration = metadata["Metadata"][i][ + "Data" + ]["Duration"] + metadata_text = metadata["Metadata"][i]["Data"][ + "text" + ]["Text"] + yield { + "type": metadata_type, + "offset": metadata_offset, + "duration": metadata_duration, + "text": metadata_text, + } + elif metadata_type == "SentenceBoundary": + raise UnknownResponse( + "SentenceBoundary is not supported due to being broken." 
+ ) + elif metadata_type == "SessionEnd": + continue + else: + raise UnknownResponse( + f"Unknown metadata type: {metadata_type}" + ) elif ( "Path" in parameters and parameters["Path"] == "response" @@ -368,25 +382,60 @@ class Communicate: Content-Type:application/json; charset=utf-8 Path:response - {"context":{"serviceTag":"yyyyyyyyyyyyyyyyyyy"},"audio":{"type":"inline","streamId":"zzzzzzzzzzzzzzzzz"}} + {"context":{"serviceTag":"yyyyyyyyyyyyyyyyyyy"},"audio": + {"type":"inline","streamId":"zzzzzzzzzzzzzzzzz"}} """ pass else: - raise ValueError( + raise UnknownResponse( "The response from the service is not recognized.\n" + received.data ) elif received.type == aiohttp.WSMsgType.BINARY: if download: - yield ( - None, - None, - b"Path:audio\r\n".join( + yield { + "type": "audio", + "data": b"Path:audio\r\n".join( received.data.split(b"Path:audio\r\n")[1:] ), - ) + } + audio_was_received = True else: - raise ValueError( + raise UnexpectedResponse( "The service sent a binary message, but we are not expecting one." ) - await websocket.close() + + if not audio_was_received: + raise NoAudioReceived( + "No audio was received from the service. Please verify that your parameters are correct." + ) + + async def save( + self, audio_fname: str | bytes, metadata_fname: Optional[str | bytes] = None + ): + """ + Save the audio and metadata to the specified files. + """ + written_audio = False + audio = open(audio_fname, "wb") + try: + metadata = None + if metadata_fname is not None: + metadata = open(metadata_fname, "w") + + async for message in self.stream(): + if message["type"] == "audio": + audio.write(message["data"]) + written_audio = True + elif metadata is not None and message["type"] == "WordBoundary": + json.dump(message, metadata) + metadata.write("\n") + finally: + audio.close() + if metadata is not None: + metadata.close() + + if not written_audio: + raise NoAudioReceived( + "No audio was received from the service, so the file is empty." 
+ ) diff --git a/src/edge_tts/exceptions.py b/src/edge_tts/exceptions.py new file mode 100644 index 0000000..c37c55a --- /dev/null +++ b/src/edge_tts/exceptions.py @@ -0,0 +1,13 @@ +class UnknownResponse(Exception): + """Raised when an unknown response is received from the server.""" + + +class UnexpectedResponse(Exception): + """Raised when an unexpected response is received from the server. + + This hasn't happened yet, but it's possible that the server will + change its response format in the future.""" + + +class NoAudioReceived(Exception): + """Raised when no audio is received from the server.""" diff --git a/src/edge_tts/submaker.py b/src/edge_tts/submaker.py index 6988518..5a432c3 100644 --- a/src/edge_tts/submaker.py +++ b/src/edge_tts/submaker.py @@ -28,9 +28,9 @@ def mktimestamp(time_unit): Returns: str: The timecode of the subtitle. """ - hour = math.floor(time_unit / 10 ** 7 / 3600) - minute = math.floor((time_unit / 10 ** 7 / 60) % 60) - seconds = (time_unit / 10 ** 7) % 60 + hour = math.floor(time_unit / 10**7 / 3600) + minute = math.floor((time_unit / 10**7 / 60) % 60) + seconds = (time_unit / 10**7) % 60 return f"{hour:02d}:{minute:02d}:{seconds:06.3f}" @@ -48,7 +48,7 @@ class SubMaker: subtitles should overlap. """ self.subs_and_offset = [] - self.overlapping = overlapping * (10 ** 7) + self.overlapping = overlapping * (10**7) def create_sub(self, timestamp, text): """ diff --git a/src/edge_tts/util.py b/src/edge_tts/util.py index 6a4a29f..7f55ed5 100644 --- a/src/edge_tts/util.py +++ b/src/edge_tts/util.py @@ -11,9 +11,6 @@ from edge_tts import Communicate, SubMaker, list_voices async def _list_voices(proxy): - """ - List available voices. 
- """ for idx, voice in enumerate(await list_voices(proxy=proxy)): if idx != 0: print() @@ -26,34 +23,36 @@ async def _list_voices(proxy): async def _tts(args): - tts = Communicate() - subs = SubMaker(args.overlapping) - if args.write_media: - media_file = open(args.write_media, "wb") # pylint: disable=consider-using-with - async for i in tts.run( + tts = Communicate( args.text, - args.boundary_type, - args.codec, args.voice, - args.pitch, - args.rate, - args.volume, proxy=args.proxy, - ): - if i[2] is not None: - if not args.write_media: - sys.stdout.buffer.write(i[2]) - else: - media_file.write(i[2]) - elif i[0] is not None and i[1] is not None: - subs.create_sub(i[0], i[1]) - if args.write_media: - media_file.close() - if not args.write_subtitles: - sys.stderr.write(subs.generate_subs()) - else: - with open(args.write_subtitles, "w", encoding="utf-8") as file: - file.write(subs.generate_subs()) + rate=args.rate, + volume=args.volume, + ) + try: + media_file = None + if args.write_media: + media_file = open(args.write_media, "wb") + + subs = SubMaker(args.overlapping) + async for data in tts.stream(): + if data["type"] == "audio": + if not args.write_media: + sys.stdout.buffer.write(data["data"]) + else: + media_file.write(data["data"]) + elif data["type"] == "WordBoundary": + subs.create_sub([data["offset"], data["duration"]], data["text"]) + + if not args.write_subtitles: + sys.stderr.write(subs.generate_subs()) + else: + with open(args.write_subtitles, "w", encoding="utf-8") as file: + file.write(subs.generate_subs()) + finally: + if media_file is not None: + media_file.close() async def _main(): @@ -64,23 +63,13 @@ async def _main(): parser.add_argument( "-v", "--voice", - help="voice for TTS. " - "Default: Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)", - default="Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)", - ) - parser.add_argument( - "-c", - "--codec", - help="codec format. 
Default: audio-24khz-48kbitrate-mono-mp3. " - "Another choice is webm-24khz-16bit-mono-opus. " - "For more info check https://bit.ly/2T33h6S", - default="audio-24khz-48kbitrate-mono-mp3", + help="voice for TTS. " "Default: en-US-AriaNeural", + default="en-US-AriaNeural", ) group.add_argument( "-l", "--list-voices", - help="lists available voices. " - "Edge's list is incomplete so check https://bit.ly/2SFq1d3", + help="lists available voices", action="store_true", ) parser.add_argument( @@ -109,32 +98,19 @@ async def _main(): type=float, ) parser.add_argument( - "-b", - "--boundary-type", - help="set boundary type for subtitles. Default 0 for none. Set 1 for word_boundary.", - default=0, - type=int, - ) - parser.add_argument( - "--write-media", help="instead of stdout, send media output to provided file" + "--write-media", help="send media output to file instead of stdout" ) parser.add_argument( "--write-subtitles", - help="instead of stderr, send subtitle output to provided file (implies boundary-type is 1)", - ) - parser.add_argument( - "--proxy", - help="proxy", + help="send subtitle output to provided file instead of stderr", ) + parser.add_argument("--proxy", help="use a proxy for TTS and voice list.") args = parser.parse_args() if args.list_voices: await _list_voices(args.proxy) sys.exit(0) - if args.write_subtitles and args.boundary_type == 0: - args.boundary_type = 1 - if args.text is not None or args.file is not None: if args.file is not None: # we need to use sys.stdin.read() because some devices @@ -151,9 +127,6 @@ async def _main(): def main(): - """ - Main function. - """ asyncio.get_event_loop().run_until_complete(_main())