diff --git a/assemblyai/__init__.py b/assemblyai/__init__.py index 4662522..7d92a77 100644 --- a/assemblyai/__init__.py +++ b/assemblyai/__init__.py @@ -2,6 +2,7 @@ from .__version__ import __version__ from .client import Client from .lemur import Lemur +from .sync import SyncTranscriber from .transcriber import Transcriber, Transcript, TranscriptGroup from .types import ( AssemblyAIError, @@ -63,6 +64,11 @@ StatusResult, SummarizationModel, SummarizationType, + SyncSpeechModel, + SyncTranscriptError, + SyncTranscriptionConfig, + SyncTranscriptResponse, + SyncWord, Timestamp, TranscriptError, TranscriptionConfig, @@ -138,6 +144,12 @@ "StatusResult", "SummarizationModel", "SummarizationType", + "SyncSpeechModel", + "SyncTranscriber", + "SyncTranscriptError", + "SyncTranscriptionConfig", + "SyncTranscriptResponse", + "SyncWord", "Timestamp", "Transcriber", "TranscriptionConfig", diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py index af24c75..cae757e 100644 --- a/assemblyai/__version__.py +++ b/assemblyai/__version__.py @@ -1 +1 @@ -__version__ = "0.64.7" +__version__ = "0.64.8" diff --git a/assemblyai/sync.py b/assemblyai/sync.py new file mode 100644 index 0000000..c0a2d1e --- /dev/null +++ b/assemblyai/sync.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +import concurrent.futures +import os +from typing import BinaryIO, Optional, Tuple, Union +from urllib.parse import urlparse + +from . import client as _client +from . import sync_api, types + +AudioInput = Union[str, bytes, bytearray, "os.PathLike[str]", BinaryIO] + +# Extensions that signal raw S16LE PCM rather than a WAV container. +_PCM_SUFFIXES = (".pcm", ".raw") + + +def _resolve_audio( + data: AudioInput, + config: types.SyncTranscriptionConfig, +) -> Tuple[bytes, str, str]: + """ + Reads the audio input into bytes and decides its multipart Content-Type. + + PCM is selected when the source has a `.pcm`/`.raw` extension or when + `sample_rate`/`channels` are set on the config (the fields the sync API + requires only for raw PCM) — and both must then be present. Everything + else is treated as a WAV container. URLs are rejected — the sync API has + no URL ingestion. + + Returns: `(audio_bytes, filename, content_type)`. + """ + suffix = "" + filename: Optional[str] = None + + if isinstance(data, (bytes, bytearray)): + audio = bytes(data) + elif isinstance(data, (str, os.PathLike)): + path = os.fspath(data) + if urlparse(path).scheme in ("http", "https"): + raise ValueError( + "SyncTranscriber does not accept URLs. Pass a local file path or " + "audio bytes, or use aai.Transcriber for URL/async transcription." + ) + with open(path, "rb") as f: + audio = f.read() + filename = os.path.basename(path) + suffix = os.path.splitext(path)[1].lower() + elif hasattr(data, "read"): + audio = data.read() + name = getattr(data, "name", None) + if name: + filename = os.path.basename(name) + suffix = os.path.splitext(name)[1].lower() + else: + raise TypeError(f"unsupported audio input type: {type(data).__name__}") + + wants_pcm = config.sample_rate is not None or config.channels is not None + is_pcm = suffix in _PCM_SUFFIXES or wants_pcm + if is_pcm and (config.sample_rate is None or config.channels is None): + raise ValueError( + "raw PCM audio requires both sample_rate and channels in " + "SyncTranscriptionConfig" + ) + + content_type = "audio/pcm" if is_pcm else "audio/wav" + if not filename: + filename = "audio.pcm" if is_pcm else "audio.wav" + + return audio, filename, content_type + + +def _config_to_json(config: types.SyncTranscriptionConfig) -> Optional[dict]: + """Serializes the config to the JSON `config` part, dropping the routing model.""" + data = config.dict(exclude_none=True) + data.pop("model", None) + return data or None + + +class _SyncTranscriberImpl: + def __init__( + self, + *, + client: _client.Client, + config: types.SyncTranscriptionConfig, + ) -> None: + self._client = client + self.config = config + + def transcribe( + self, + *, + data: AudioInput, + config: Optional[types.SyncTranscriptionConfig], + ) -> types.SyncTranscriptResponse: + config = config or self.config + audio, filename, content_type = _resolve_audio(data, config) + return sync_api.transcribe( + self._client.http_client, + base_url=self._client.settings.sync_base_url, + audio=audio, + filename=filename, + audio_content_type=content_type, + model=config.model, + config=_config_to_json(config), + timeout=self._client.settings.sync_http_timeout, + ) + + +class SyncTranscriber: + """ + Transcribes audio synchronously: audio in, transcript out, one request. + + Unlike `Transcriber` (which submits a job to the async API and polls for + completion), `SyncTranscriber` posts the audio to the sync API and returns + the finished `SyncTranscriptResponse` directly. There is no job id or + status to poll. Accepts a local file path, raw bytes, or a binary file + object — but not a URL. + + Example: + ```python + import assemblyai as aai + + aai.settings.api_key = "your-key" + + result = aai.SyncTranscriber().transcribe("./call.wav") + print(result.text) + ``` + """ + + def __init__( + self, + *, + client: Optional[_client.Client] = None, + config: Optional[types.SyncTranscriptionConfig] = None, + max_workers: Optional[int] = None, + ) -> None: + """ + Creates a `SyncTranscriber`. + + Args: + client: The HTTP client to use. Defaults to the shared default client. + config: Default transcription options. Per-call `config` overrides it. + max_workers: Thread pool size for `transcribe_async`. Defaults to + the CPU count minus one. + """ + self._client = client or _client.Client.get_default() + self._impl = _SyncTranscriberImpl( + client=self._client, + config=config or types.SyncTranscriptionConfig(), + ) + + if not max_workers: + cpu_count = os.cpu_count() + max_workers = max(1, cpu_count - 1) if cpu_count else 1 + + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers, + ) + + @property + def config(self) -> types.SyncTranscriptionConfig: + """The default configuration of the `SyncTranscriber`.""" + return self._impl.config + + @config.setter + def config(self, config: types.SyncTranscriptionConfig) -> None: + self._impl.config = config + + def transcribe( + self, + data: AudioInput, + config: Optional[types.SyncTranscriptionConfig] = None, + ) -> types.SyncTranscriptResponse: + """ + Transcribes audio and returns the finished transcript. + + Args: + data: A local file path, raw audio bytes, or a binary file object. + Raw PCM also requires `sample_rate` and `channels` on the config. + config: Options for this call. If `None`, the transcriber's default + configuration is used. + + Raises: `SyncTranscriptError` if the request fails. + """ + return self._impl.transcribe(data=data, config=config) + + def transcribe_async( + self, + data: AudioInput, + config: Optional[types.SyncTranscriptionConfig] = None, + ) -> "concurrent.futures.Future[types.SyncTranscriptResponse]": + """ + Transcribes audio on a worker thread. + + Returns a `concurrent.futures.Future` (not an asyncio coroutine); call + `.result()` to block for the transcript. Useful for fanning out a + handful of files concurrently. + """ + return self._executor.submit( + self._impl.transcribe, + data=data, + config=config, + ) diff --git a/assemblyai/sync_api.py b/assemblyai/sync_api.py new file mode 100644 index 0000000..f5d8c2e --- /dev/null +++ b/assemblyai/sync_api.py @@ -0,0 +1,91 @@ +import json +from typing import Optional + +import httpx + +from . import types + +ENDPOINT_TRANSCRIBE = "/transcribe" +MODEL_HEADER = "X-AAI-Model" + + +def _error_from_response(response: httpx.Response) -> types.SyncTranscriptError: + """ + Builds a `SyncTranscriptError` from a non-200 response. + + The service uses two error envelopes: `{"error_code", "message"}` for + audio/capacity/inference errors and `{"detail"}` for auth and rate-limit + errors. Parse by status code, not by assuming `error_code` is present. + """ + error_code: Optional[str] = None + message: Optional[str] = None + + try: + body = response.json() + if isinstance(body, dict): + error_code = body.get("error_code") + message = body.get("message") or body.get("detail") + except Exception: + message = response.text or None + + if not message: + message = f"sync transcription failed with status {response.status_code}" + + retry_after_header = response.headers.get("retry-after") + retry_after = ( + int(retry_after_header) + if retry_after_header and retry_after_header.isdigit() + else None + ) + + return types.SyncTranscriptError( + message, + status_code=response.status_code, + error_code=error_code, + retry_after=retry_after, + ) + + +def transcribe( + client: httpx.Client, + *, + base_url: str, + audio: bytes, + filename: str, + audio_content_type: str, + model: str, + config: Optional[dict], + timeout: float, +) -> types.SyncTranscriptResponse: + """ + Posts a single synchronous transcription request. + + Args: + client: the HTTP client (carries the `Authorization` header). + base_url: the sync API base URL, e.g. `https://sync.assemblyai.com`. + audio: raw audio bytes (WAV container or S16LE PCM). + filename: name for the audio multipart part. + audio_content_type: `audio/wav` or `audio/pcm`; selects the decoder. + model: sent as the `X-AAI-Model` routing header. + config: the JSON `config` part, or None to omit it. + timeout: per-request timeout in seconds. + + Returns: the parsed transcript response. + + Raises: `SyncTranscriptError` on any non-200 response. + """ + files = {"audio": (filename, audio, audio_content_type)} + if config: + files["config"] = (None, json.dumps(config), "application/json") + + response = client.post( + base_url.rstrip("/") + ENDPOINT_TRANSCRIBE, + files=files, + headers={MODEL_HEADER: model}, + timeout=timeout, + ) + + if response.status_code != httpx.codes.OK: + raise _error_from_response(response) + + return types.SyncTranscriptResponse.parse_obj(response.json()) diff --git a/assemblyai/types.py b/assemblyai/types.py index 6f29e36..4fa101b 100644 --- a/assemblyai/types.py +++ b/assemblyai/types.py @@ -84,6 +84,28 @@ class LemurError(AssemblyAIError): """ +class SyncTranscriptError(AssemblyAIError): + """ + Error raised when a synchronous transcription request fails. + + Carries the server's machine-readable `error_code` (e.g. `bad_audio`, + `audio_too_large`, `capacity_exceeded`, `inference_timeout`) when present, + and `retry_after` (seconds) for 429/503 responses that include a + `Retry-After` header. + """ + + def __init__( + self, + message: str, + status_code: Optional[int] = None, + error_code: Optional[str] = None, + retry_after: Optional[int] = None, + ): + super().__init__(message, status_code) + self.error_code = error_code + self.retry_after = retry_after + + class Sourcable: """ A base class for all sourcable objects @@ -106,6 +128,12 @@ class Settings(BaseSettings): base_url: str = "https://api.assemblyai.com" "The base URL for the AssemblyAI API" + sync_base_url: str = "https://sync.assemblyai.com" + "The base URL for the synchronous transcription API (used by `SyncTranscriber`)" + + sync_http_timeout: float = 60.0 + "The HTTP timeout for synchronous transcription requests. Kept above the server's 30s deadline so the client doesn't race it." + polling_interval: float = Field(default=3.0, gt=0.0) "The default polling interval for long-running requests (e.g. polling the `Transcript`'s status)" @@ -2951,3 +2979,103 @@ class LemurPurgeResponse(BaseModel): deleted: bool "The result of the LeMUR purge request" + + +# Caps mirror the sync service's `config` part so an oversized request fails +# locally with a clear message instead of a 400 round trip. +_SYNC_MAX_PROMPT_LEN = 4096 +_SYNC_MAX_WORD_BOOST_LEN = 2048 + + +class SyncSpeechModel(str, Enum): + """Speech models available on the synchronous transcription API.""" + + u3_sync_pro = "u3-sync-pro" + + +class SyncTranscriptionConfig(BaseModel): + """ + Options for a synchronous transcription request. + + Only `prompt` and `word_boost` shape the transcript; `sample_rate` and + `channels` are required only for raw PCM audio (WAV carries them in its + header). `model` selects the sync speech model and is sent as the + `X-AAI-Model` routing header, not in the request body. + """ + + model: str = SyncSpeechModel.u3_sync_pro.value + "The sync speech model to route to. Sent as the `X-AAI-Model` header." + + prompt: Optional[str] = Field(default=None, max_length=_SYNC_MAX_PROMPT_LEN) + "Custom transcription instruction prepended to the model's system prompt. Max 4096 characters." + + word_boost: Optional[List[str]] = None + "Keyterms biasing the decoder. Whitespace is stripped and empty terms dropped. Max 2048 characters total." + + sample_rate: Optional[int] = None + "Source sample rate in Hz. Required for raw PCM audio; ignored for WAV." + + channels: Optional[int] = None + "Channel count (1 mono, 2 stereo). Required for raw PCM audio; ignored for WAV." + + if pydantic_v2: + + @field_validator("word_boost") + @classmethod + def _normalize_word_boost(cls, v): + if not v: + return None + terms = [t.strip() for t in v if t and t.strip()] + total = sum(len(t) for t in terms) + if total > _SYNC_MAX_WORD_BOOST_LEN: + raise ValueError( + f"word_boost exceeds {_SYNC_MAX_WORD_BOOST_LEN} characters (got {total})" + ) + return terms or None + + else: + + @validator("word_boost") + def _normalize_word_boost(cls, v): + if not v: + return None + terms = [t.strip() for t in v if t and t.strip()] + total = sum(len(t) for t in terms) + if total > _SYNC_MAX_WORD_BOOST_LEN: + raise ValueError( + f"word_boost exceeds {_SYNC_MAX_WORD_BOOST_LEN} characters (got {total})" + ) + return terms or None + + +class SyncWord(BaseModel): + """A single word from a synchronous transcript, with timing and confidence.""" + + text: str + start_ms: int + "Word start time in milliseconds." + end_ms: int + "Word end time in milliseconds." + confidence: float + + +class SyncTranscriptResponse(BaseModel): + """The result of a synchronous transcription request.""" + + text: str + "The full transcript text." + + words: List[SyncWord] = Field(default_factory=list) + "Per-word timing and confidence." + + confidence: float + "Overall transcript confidence in the range 0-1." + + audio_duration_ms: int + "Total audio duration in milliseconds." + + inference_time_ms: float + "Model inference time in milliseconds. Excludes auth, decode, and queue wait." + + session_id: str + "Server-generated UUID for this request. Record it to correlate with support." diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py new file mode 100644 index 0000000..91d168f --- /dev/null +++ b/tests/unit/test_sync.py @@ -0,0 +1,185 @@ +import httpx +import pytest +from pytest_httpx import HTTPXMock + +import assemblyai as aai + +aai.settings.api_key = "test" + +TRANSCRIBE_URL = f"{aai.settings.sync_base_url}/transcribe" + +_OK_RESPONSE = { + "text": "hello world", + "words": [ + {"text": "hello", "start_ms": 0, "end_ms": 200, "confidence": 0.9}, + {"text": "world", "start_ms": 220, "end_ms": 400, "confidence": 0.95}, + ], + "confidence": 0.92, + "audio_duration_ms": 400, + "inference_time_ms": 12.5, + "session_id": "eb92c4ff-4bbb-429f-9b99-7279d7fe738f", +} + + +def _mock_ok(httpx_mock: HTTPXMock) -> None: + httpx_mock.add_response( + url=TRANSCRIBE_URL, + method="POST", + status_code=httpx.codes.OK, + json=_OK_RESPONSE, + ) + + +def test_transcribe_bytes_parses_response(httpx_mock: HTTPXMock): + # Given a mocked sync endpoint + _mock_ok(httpx_mock) + + # When transcribing raw audio bytes + result = aai.SyncTranscriber().transcribe(b"RIFFfake-wav-bytes") + + # Then the response is parsed into a SyncTranscriptResponse + assert isinstance(result, aai.SyncTranscriptResponse) + assert result.text == "hello world" + assert result.session_id == _OK_RESPONSE["session_id"] + assert result.words[0].start_ms == 0 + assert result.words[1].text == "world" + + +def test_transcribe_sends_model_header_and_wav_part(httpx_mock: HTTPXMock): + # Given a mocked sync endpoint + _mock_ok(httpx_mock) + + # When transcribing bytes with the default config + aai.SyncTranscriber().transcribe(b"RIFFfake-wav-bytes") + + # Then the request routes via X-AAI-Model and ships a WAV audio part + request = httpx_mock.get_requests()[0] + assert request.headers["X-AAI-Model"] == "u3-sync-pro" + body = request.read() + assert b'name="audio"' in body + assert b"Content-Type: audio/wav" in body + # And no config part is sent when the config is empty + assert b'name="config"' not in body + + +def test_transcribe_sends_prompt_and_word_boost(httpx_mock: HTTPXMock): + # Given a mocked sync endpoint + _mock_ok(httpx_mock) + + # When transcribing with a prompt and word_boost + config = aai.SyncTranscriptionConfig( + prompt="Transcribe verbatim.", + word_boost=["AssemblyAI", " Lemur ", ""], + ) + aai.SyncTranscriber().transcribe(b"RIFFfake-wav-bytes", config=config) + + # Then a config JSON part carries the prompt and normalized word_boost + body = httpx_mock.get_requests()[0].read() + assert b'name="config"' in body + assert b"Transcribe verbatim." in body + assert b'"AssemblyAI"' in body + assert b'"Lemur"' in body # whitespace stripped, empty term dropped + # And the routing model is never placed in the body + assert b'"model"' not in body + + +def test_transcribe_pcm_sends_pcm_part_and_rate(httpx_mock: HTTPXMock): + # Given a mocked sync endpoint + _mock_ok(httpx_mock) + + # When transcribing bytes with sample_rate + channels (raw PCM) + config = aai.SyncTranscriptionConfig(sample_rate=16000, channels=1) + aai.SyncTranscriber().transcribe(b"\x00\x01" * 100, config=config) + + # Then the audio part is PCM and the config carries rate + channels + body = httpx_mock.get_requests()[0].read() + assert b"Content-Type: audio/pcm" in body + assert b'"sample_rate"' in body + assert b'"channels"' in body + + +def test_transcribe_pcm_without_rate_raises(): + # Given a config with sample_rate but no channels (partial PCM intent) + config = aai.SyncTranscriptionConfig(sample_rate=16000) + + # When transcribing, Then it fails locally before any request + with pytest.raises(ValueError, match="sample_rate and channels"): + aai.SyncTranscriber().transcribe(b"\x00\x01" * 100, config=config) + + +def test_transcribe_rejects_url(): + # Given an http URL as input + transcriber = aai.SyncTranscriber() + + # When transcribing, Then it is rejected with a pointer to Transcriber + with pytest.raises(ValueError, match="does not accept URLs"): + transcriber.transcribe("https://example.com/audio.wav") + + +def test_transcribe_path_input(httpx_mock: HTTPXMock, tmp_path): + # Given a local WAV file + _mock_ok(httpx_mock) + audio_file = tmp_path / "call.wav" + audio_file.write_bytes(b"RIFFfake-wav-bytes") + + # When transcribing the path + result = aai.SyncTranscriber().transcribe(str(audio_file)) + + # Then it succeeds and ships the file under its own name + assert result.text == "hello world" + body = httpx_mock.get_requests()[0].read() + assert b'filename="call.wav"' in body + + +def test_word_boost_too_long_raises(): + # Given a word_boost exceeding the 2048-char cap + # When building the config, Then validation fails immediately + with pytest.raises(ValueError, match="word_boost exceeds"): + aai.SyncTranscriptionConfig(word_boost=["x" * 3000]) + + +def test_error_envelope_maps_to_sync_transcript_error(httpx_mock: HTTPXMock): + # Given the server rejects oversized audio + httpx_mock.add_response( + url=TRANSCRIBE_URL, + method="POST", + status_code=413, + json={"error_code": "audio_too_large", "message": "too long"}, + ) + + # When transcribing, Then a SyncTranscriptError carries code + status + with pytest.raises(aai.SyncTranscriptError) as exc_info: + aai.SyncTranscriber().transcribe(b"RIFFfake-wav-bytes") + + error = exc_info.value + assert error.status_code == 413 + assert error.error_code == "audio_too_large" + assert "too long" in str(error) + + +def test_rate_limit_surfaces_retry_after(httpx_mock: HTTPXMock): + # Given a rate-limit response with a Retry-After header + httpx_mock.add_response( + url=TRANSCRIBE_URL, + method="POST", + status_code=429, + json={"detail": "Too many requests"}, + headers={"Retry-After": "5"}, + ) + + # When transcribing, Then retry_after is parsed and error_code is absent + with pytest.raises(aai.SyncTranscriptError) as exc_info: + aai.SyncTranscriber().transcribe(b"RIFFfake-wav-bytes") + + error = exc_info.value + assert error.status_code == 429 + assert error.error_code is None + assert error.retry_after == 5 + + +def test_default_model_is_u3_sync_pro(): + # Given a default config + # When inspecting the model + # Then it is the sync U3-Pro identifier + assert aai.SyncTranscriptionConfig().model == "u3-sync-pro" + assert aai.SyncSpeechModel.u3_sync_pro.value == "u3-sync-pro"