Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion assemblyai/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.64.4"
__version__ = "0.64.7"
4 changes: 4 additions & 0 deletions assemblyai/streaming/v3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
LLMGatewayResponseEvent,
NoiseSuppressionModel,
SpeakerRevisionEvent,
SpeakerRevisionItem,
SpeechModel,
SpeechStartedEvent,
StreamingClientOptions,
StreamingError,
StreamingErrorCodes,
StreamingEvents,
StreamingMode,
StreamingParameters,
StreamingPiiPolicy,
StreamingPiiSubstitution,
Expand All @@ -31,13 +33,15 @@
"LLMGatewayResponseEvent",
"NoiseSuppressionModel",
"SpeakerRevisionEvent",
"SpeakerRevisionItem",
"SpeechModel",
"SpeechStartedEvent",
"StreamingClient",
"StreamingClientOptions",
"StreamingError",
"StreamingErrorCodes",
"StreamingEvents",
"StreamingMode",
"StreamingParameters",
"StreamingPiiPolicy",
"StreamingPiiSubstitution",
Expand Down
36 changes: 29 additions & 7 deletions assemblyai/streaming/v3/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,31 @@ class LLMGatewayResponseEvent(BaseModel):
data: Any


class SpeakerRevisionEvent(BaseModel):
"""Server-side correction to a previously-emitted Turn's speaker labels.
class SpeakerRevisionItem(BaseModel):
"""A single Turn whose speaker labels were revised by reclustering.

Emitted after offline reclustering refines the live tentative labels.
Match by `turn_order` against the original Turn; replace its per-word
speaker assignments (and the turn-level `speaker_label`) with these.
Text and word timestamps are unchanged from the original Turn.
speaker assignments (and the turn-level `speaker_label`) with these. Text
and word timestamps are unchanged from the original Turn.
"""

type: Literal["SpeakerRevision"] = "SpeakerRevision"
turn_order: int
speaker_label: Optional[str] = None
words: List[Word] = []


class SpeakerRevisionEvent(BaseModel):
"""Server-side correction to previously-emitted Turns' speaker labels.

Emitted once per offline-recluster resolve. `revisions` carries one entry
per earlier Turn whose label actually changed (unchanged turns are
omitted). Apply each entry by matching its `turn_order`.
"""

type: Literal["SpeakerRevision"] = "SpeakerRevision"
revisions: List[SpeakerRevisionItem] = []


EventMessage = Union[
BeginEvent,
TerminationEvent,
Expand Down Expand Up @@ -121,6 +131,7 @@ class StreamingSessionParameters(BaseModel):
keyterms_prompt: Optional[List[str]] = None
filter_profanity: Optional[bool] = None
prompt: Optional[str] = None
agent_context: Optional[str] = None
interruption_delay: Optional[int] = None
turn_left_pad_ms: Optional[int] = None

Expand All @@ -137,6 +148,7 @@ class SpeechModel(str, Enum):
universal_streaming_multilingual = "universal-streaming-multilingual"
universal_streaming_english = "universal-streaming-english"
u3_rt_pro = "u3-rt-pro"
u3_rt_pro_beta_1 = "u3-rt-pro-beta-1"
whisper_rt = "whisper-rt"
u3_pro = "u3-pro" # Deprecated: Use u3_rt_pro instead

Expand All @@ -159,6 +171,15 @@ def __str__(self):
return self.value


class StreamingMode(str, Enum):
    """String-valued choices for the streaming `mode` setting.

    Members also subclass `str`, so each one compares equal to its raw value
    (for example `StreamingMode.balanced == "balanced"`).
    """

    max_accuracy = "max_accuracy"
    min_latency = "min_latency"
    balanced = "balanced"

    def __str__(self):
        # Render the bare value instead of the `StreamingMode.<name>` form
        # that Enum produces by default.
        return self.value


class StreamingPiiSubstitution(str, Enum):
hash = "hash"
entity_name = "entity_name"
Expand Down Expand Up @@ -223,7 +244,7 @@ def __str__(self):
class StreamingParameters(StreamingSessionParameters):
sample_rate: int
encoding: Optional[Encoding] = None
speech_model: SpeechModel
speech_model: Optional[SpeechModel] = None
language_detection: Optional[bool] = None
domain: Optional[StreamingDomain] = None
inactivity_timeout: Optional[int] = None
Expand All @@ -244,6 +265,7 @@ class StreamingParameters(StreamingSessionParameters):
redact_pii: Optional[bool] = None
redact_pii_policies: Optional[List[StreamingPiiPolicy]] = None
redact_pii_sub: Optional[StreamingPiiSubstitution] = None
mode: Optional[StreamingMode] = None


class UpdateConfiguration(StreamingSessionParameters):
Expand Down
134 changes: 96 additions & 38 deletions tests/unit/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from urllib.parse import urlencode

import pytest
from pydantic import ValidationError
from pytest_mock import MockFixture
from websockets.exceptions import ConnectionClosed, InvalidStatus
from websockets.frames import Close
Expand All @@ -19,12 +18,14 @@
StreamingClient,
StreamingClientOptions,
StreamingEvents,
StreamingMode,
StreamingParameters,
StreamingPiiPolicy,
StreamingPiiSubstitution,
TurnEvent,
Word,
)
from assemblyai.streaming.v3._base import _build_uri
from assemblyai.streaming.v3.models import TerminateSession


Expand Down Expand Up @@ -245,6 +246,37 @@ def mocked_websocket_connect(
assert "noise_suppression_threshold" not in actual_url


def test_client_connect_with_mode(mocker: MockFixture):
    """Connecting with `mode` set puts `mode=<value>` in the connection URL."""
    # Given: client + mode parameter
    actual_url = None

    def mocked_websocket_connect(
        url: str, additional_headers: dict, open_timeout: float
    ):
        # Stand-in for websocket_connect that only records the requested URL.
        nonlocal actual_url
        actual_url = url

    mocker.patch(
        "assemblyai.streaming.v3.client.websocket_connect",
        new=mocked_websocket_connect,
    )
    _disable_rw_threads(mocker)
    client = StreamingClient(
        StreamingClientOptions(api_key="test", api_host="api.example.com")
    )
    params = StreamingParameters(
        sample_rate=16000,
        speech_model=SpeechModel.u3_rt_pro,
        mode=StreamingMode.max_accuracy,
    )

    # When: connect
    client.connect(params)

    # Then: the mode wire param is present with its underscore value.
    # Check the mock actually ran first, so a skipped connect fails as a
    # plain assertion rather than a TypeError from `in None`.
    assert actual_url is not None
    assert "mode=max_accuracy" in actual_url


def test_noise_suppression_deprecated_alias_migrates_to_voice_focus(
mocker: MockFixture, caplog: pytest.LogCaptureFixture
):
Expand Down Expand Up @@ -568,6 +600,7 @@ def mocked_websocket_connect(
speech_model=SpeechModel.u3_pro,
min_end_of_turn_silence_when_confident=200,
prompt="Transcribe this audio with beautiful punctuation and formatting.",
agent_context="What is your account number?",
keyterms_prompt=["yes", "no", "okay"],
)

Expand All @@ -579,6 +612,7 @@ def mocked_websocket_connect(
assert "min_turn_silence=200" in actual_url
assert "min_end_of_turn_silence_when_confident" not in actual_url
assert "prompt=Transcribe" in actual_url
assert "agent_context=What" in actual_url
assert "keyterms_prompt=" in actual_url # keyterms_prompt is JSON-encoded

assert actual_additional_headers["Authorization"] == "test"
Expand Down Expand Up @@ -913,57 +947,73 @@ def test_turn_event_with_word_speakers():


def test_speaker_revision_event_parses():
# Given: a SpeakerRevision payload as emitted by the server (revision words
# use the same Word schema as Turn — start/end/confidence/text/word_is_final/speaker)
# Given: a SpeakerRevision payload as emitted by the server — one message
# per recluster resolve carrying a list of revised turns. Revision words
# use the same Word schema as Turn (start/end/confidence/text/word_is_final/speaker).
data = {
"type": "SpeakerRevision",
"turn_order": 3,
"speaker_label": "B",
"words": [
"revisions": [
{
"start": 1000,
"end": 1200,
"confidence": 0.9,
"text": "hello",
"word_is_final": True,
"speaker": "B",
"turn_order": 3,
"speaker_label": "B",
"words": [
{
"start": 1000,
"end": 1200,
"confidence": 0.9,
"text": "hello",
"word_is_final": True,
"speaker": "B",
},
{
"start": 1210,
"end": 1400,
"confidence": 0.88,
"text": "world",
"word_is_final": True,
"speaker": "A",
},
],
},
{
"start": 1210,
"end": 1400,
"confidence": 0.88,
"text": "world",
"word_is_final": True,
"speaker": "A",
"turn_order": 7,
"speaker_label": "A",
"words": [],
},
],
}

# When: parsed
event = SpeakerRevisionEvent.parse_obj(data)

# Then: the revision carries the corrected per-word and turn-level speakers
# Then: each revision carries the corrected per-word and turn-level speakers
assert event.type == "SpeakerRevision"
assert event.turn_order == 3
assert event.speaker_label == "B"
assert [w.speaker for w in event.words] == ["B", "A"]
assert [r.turn_order for r in event.revisions] == [3, 7]
assert event.revisions[0].speaker_label == "B"
assert [w.speaker for w in event.revisions[0].words] == ["B", "A"]
assert event.revisions[1].speaker_label == "A"
assert event.revisions[1].words == []


def test_speaker_revision_event_dispatched_to_handler(mocker: MockFixture):
# Given: a SpeakerRevision frame on the wire and a handler registered
revision_json = json.dumps(
{
"type": "SpeakerRevision",
"turn_order": 5,
"speaker_label": "A",
"words": [
"revisions": [
{
"start": 500,
"end": 700,
"confidence": 0.95,
"text": "yes",
"word_is_final": True,
"speaker": "A",
"turn_order": 5,
"speaker_label": "A",
"words": [
{
"start": 500,
"end": 700,
"confidence": 0.95,
"text": "yes",
"word_is_final": True,
"speaker": "A",
},
],
},
],
}
Expand All @@ -989,15 +1039,23 @@ def test_speaker_revision_event_dispatched_to_handler(mocker: MockFixture):
# Then: the handler is invoked with a parsed SpeakerRevisionEvent
assert len(received) == 1
assert isinstance(received[0], SpeakerRevisionEvent)
assert received[0].turn_order == 5
assert received[0].speaker_label == "A"
assert [w.speaker for w in received[0].words] == ["A"]
assert len(received[0].revisions) == 1
assert received[0].revisions[0].turn_order == 5
assert received[0].revisions[0].speaker_label == "A"
assert [w.speaker for w in received[0].revisions[0].words] == ["A"]


def test_speech_model_optional():
"""Test that omitting speech_model is valid and excluded from the wire URI."""
# Given: streaming parameters with no speech_model
params = StreamingParameters(sample_rate=16000)

# When: the params are constructed and serialized to a connection URI
uri = _build_uri("wss://example.com/v3/ws", params)

def test_speech_model_required():
"""Test that omitting speech_model raises a validation error."""
with pytest.raises(ValidationError):
StreamingParameters(sample_rate=16000)
# Then: speech_model defaults to None and is not sent to the server
assert params.speech_model is None
assert "speech_model" not in uri


def test_speech_started_event():
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/test_streaming_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,45 @@ async def test_set_params_enqueues_update_configuration(mocker: MockFixture):
await client.disconnect()


async def test_set_params_with_agent_context(mocker: MockFixture):
    """set_params with agent_context sends one UpdateConfiguration frame carrying it."""
    from assemblyai.streaming.v3.models import (
        StreamingSessionParameters,
    )

    # Given: a connected async streaming client
    fake_ws = _FakeAsyncWebSocket()
    _patch_connect(mocker, fake_ws)

    client = AsyncStreamingClient(
        StreamingClientOptions(api_key="test", api_host="api.example.com")
    )
    await client.connect(_default_params())

    def sent_update_frames():
        # Text frames recorded by the fake socket that mention UpdateConfiguration.
        return [
            s for s in fake_ws.sent if isinstance(s, str) and "UpdateConfiguration" in s
        ]

    # Disconnect in `finally` so a failing assertion does not leave the
    # client connected.
    try:
        # When: set_params is called with agent_context mid-stream
        await client.set_params(
            StreamingSessionParameters(agent_context="What is your account number?")
        )

        # Poll up to 100 times, 10 ms apart, until the frame shows up.
        for _ in range(100):
            if sent_update_frames():
                break
            await asyncio.sleep(0.01)

        # Then: an UpdateConfiguration frame carrying agent_context is sent
        update_frames = sent_update_frames()
        assert len(update_frames) == 1
        payload = json.loads(update_frames[0])
        assert payload["type"] == "UpdateConfiguration"
        assert payload["agent_context"] == "What is your account number?"
    finally:
        await client.disconnect()


async def test_force_endpoint_enqueues_force_endpoint_frame(mocker: MockFixture):
fake_ws = _FakeAsyncWebSocket()
_patch_connect(mocker, fake_ws)
Expand Down
Loading