diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py index dd34a94..af24c75 100644 --- a/assemblyai/__version__.py +++ b/assemblyai/__version__.py @@ -1 +1 @@ -__version__ = "0.64.4" +__version__ = "0.64.7" diff --git a/assemblyai/streaming/v3/__init__.py b/assemblyai/streaming/v3/__init__.py index abee882..2bcec35 100644 --- a/assemblyai/streaming/v3/__init__.py +++ b/assemblyai/streaming/v3/__init__.py @@ -7,12 +7,14 @@ LLMGatewayResponseEvent, NoiseSuppressionModel, SpeakerRevisionEvent, + SpeakerRevisionItem, SpeechModel, SpeechStartedEvent, StreamingClientOptions, StreamingError, StreamingErrorCodes, StreamingEvents, + StreamingMode, StreamingParameters, StreamingPiiPolicy, StreamingPiiSubstitution, @@ -31,6 +33,7 @@ "LLMGatewayResponseEvent", "NoiseSuppressionModel", "SpeakerRevisionEvent", + "SpeakerRevisionItem", "SpeechModel", "SpeechStartedEvent", "StreamingClient", @@ -38,6 +41,7 @@ "StreamingError", "StreamingErrorCodes", "StreamingEvents", + "StreamingMode", "StreamingParameters", "StreamingPiiPolicy", "StreamingPiiSubstitution", diff --git a/assemblyai/streaming/v3/models.py b/assemblyai/streaming/v3/models.py index a59d5a8..db76abe 100644 --- a/assemblyai/streaming/v3/models.py +++ b/assemblyai/streaming/v3/models.py @@ -74,21 +74,31 @@ class LLMGatewayResponseEvent(BaseModel): data: Any -class SpeakerRevisionEvent(BaseModel): - """Server-side correction to a previously-emitted Turn's speaker labels. +class SpeakerRevisionItem(BaseModel): + """A single Turn whose speaker labels were revised by reclustering. - Emitted after offline reclustering refines the live tentative labels. Match by `turn_order` against the original Turn; replace its per-word - speaker assignments (and the turn-level `speaker_label`) with these. - Text and word timestamps are unchanged from the original Turn. + speaker assignments (and the turn-level `speaker_label`) with these. Text + and word timestamps are unchanged from the original Turn. """ - type: Literal["SpeakerRevision"] = "SpeakerRevision" turn_order: int speaker_label: Optional[str] = None words: List[Word] = [] +class SpeakerRevisionEvent(BaseModel): + """Server-side correction to previously-emitted Turns' speaker labels. + + Emitted once per offline-recluster resolve. `revisions` carries one entry + per earlier Turn whose label actually changed (unchanged turns are + omitted). Apply each entry by matching its `turn_order`. + """ + + type: Literal["SpeakerRevision"] = "SpeakerRevision" + revisions: List[SpeakerRevisionItem] = [] + + EventMessage = Union[ BeginEvent, TerminationEvent, @@ -121,6 +131,7 @@ class StreamingSessionParameters(BaseModel): keyterms_prompt: Optional[List[str]] = None filter_profanity: Optional[bool] = None prompt: Optional[str] = None + agent_context: Optional[str] = None interruption_delay: Optional[int] = None turn_left_pad_ms: Optional[int] = None @@ -137,6 +148,7 @@ class SpeechModel(str, Enum): universal_streaming_multilingual = "universal-streaming-multilingual" universal_streaming_english = "universal-streaming-english" u3_rt_pro = "u3-rt-pro" + u3_rt_pro_beta_1 = "u3-rt-pro-beta-1" whisper_rt = "whisper-rt" u3_pro = "u3-pro" # Deprecated: Use u3_rt_pro instead @@ -159,6 +171,15 @@ def __str__(self): return self.value +class StreamingMode(str, Enum): + max_accuracy = "max_accuracy" + min_latency = "min_latency" + balanced = "balanced" + + def __str__(self): + return self.value + + class StreamingPiiSubstitution(str, Enum): hash = "hash" entity_name = "entity_name" @@ -223,7 +244,7 @@ def __str__(self): class StreamingParameters(StreamingSessionParameters): sample_rate: int encoding: Optional[Encoding] = None - speech_model: SpeechModel + speech_model: Optional[SpeechModel] = None language_detection: Optional[bool] = None domain: Optional[StreamingDomain] = None inactivity_timeout: Optional[int] = None @@ -244,6 +265,7 @@ class StreamingParameters(StreamingSessionParameters): redact_pii: Optional[bool] = None redact_pii_policies: Optional[List[StreamingPiiPolicy]] = None redact_pii_sub: Optional[StreamingPiiSubstitution] = None + mode: Optional[StreamingMode] = None class UpdateConfiguration(StreamingSessionParameters): diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py index bb0b551..4a6451d 100644 --- a/tests/unit/test_streaming.py +++ b/tests/unit/test_streaming.py @@ -6,7 +6,6 @@ from urllib.parse import urlencode import pytest -from pydantic import ValidationError from pytest_mock import MockFixture from websockets.exceptions import ConnectionClosed, InvalidStatus from websockets.frames import Close @@ -19,12 +18,14 @@ StreamingClient, StreamingClientOptions, StreamingEvents, + StreamingMode, StreamingParameters, StreamingPiiPolicy, StreamingPiiSubstitution, TurnEvent, Word, ) +from assemblyai.streaming.v3._base import _build_uri from assemblyai.streaming.v3.models import TerminateSession @@ -245,6 +246,37 @@ def mocked_websocket_connect( assert "noise_suppression_threshold" not in actual_url +def test_client_connect_with_mode(mocker: MockFixture): + # Given: client + mode parameter + actual_url = None + + def mocked_websocket_connect( + url: str, additional_headers: dict, open_timeout: float + ): + nonlocal actual_url + actual_url = url + + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + new=mocked_websocket_connect, + ) + _disable_rw_threads(mocker) + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + params = StreamingParameters( + sample_rate=16000, + speech_model=SpeechModel.u3_rt_pro, + mode=StreamingMode.max_accuracy, + ) + + # When: connect + client.connect(params) + + # Then: the mode wire param is present with its underscore value + assert "mode=max_accuracy" in actual_url + + def test_noise_suppression_deprecated_alias_migrates_to_voice_focus( mocker: MockFixture, caplog: pytest.LogCaptureFixture ): @@ -568,6 +600,7 @@ def mocked_websocket_connect( speech_model=SpeechModel.u3_pro, min_end_of_turn_silence_when_confident=200, prompt="Transcribe this audio with beautiful punctuation and formatting.", + agent_context="What is your account number?", keyterms_prompt=["yes", "no", "okay"], ) @@ -579,6 +612,7 @@ def mocked_websocket_connect( assert "min_turn_silence=200" in actual_url assert "min_end_of_turn_silence_when_confident" not in actual_url assert "prompt=Transcribe" in actual_url + assert "agent_context=What" in actual_url assert "keyterms_prompt=" in actual_url # keyterms_prompt is JSON-encoded assert actual_additional_headers["Authorization"] == "test" @@ -913,28 +947,38 @@ def test_turn_event_with_word_speakers(): def test_speaker_revision_event_parses(): - # Given: a SpeakerRevision payload as emitted by the server (revision words - # use the same Word schema as Turn — start/end/confidence/text/word_is_final/speaker) + # Given: a SpeakerRevision payload as emitted by the server — one message + # per recluster resolve carrying a list of revised turns. Revision words + # use the same Word schema as Turn (start/end/confidence/text/word_is_final/speaker). data = { "type": "SpeakerRevision", - "turn_order": 3, - "speaker_label": "B", - "words": [ + "revisions": [ { - "start": 1000, - "end": 1200, - "confidence": 0.9, - "text": "hello", - "word_is_final": True, - "speaker": "B", + "turn_order": 3, + "speaker_label": "B", + "words": [ + { + "start": 1000, + "end": 1200, + "confidence": 0.9, + "text": "hello", + "word_is_final": True, + "speaker": "B", + }, + { + "start": 1210, + "end": 1400, + "confidence": 0.88, + "text": "world", + "word_is_final": True, + "speaker": "A", + }, + ], }, { - "start": 1210, - "end": 1400, - "confidence": 0.88, - "text": "world", - "word_is_final": True, - "speaker": "A", + "turn_order": 7, + "speaker_label": "A", + "words": [], }, ], } @@ -942,11 +986,13 @@ def test_speaker_revision_event_parses(): # When: parsed event = SpeakerRevisionEvent.parse_obj(data) - # Then: the revision carries the corrected per-word and turn-level speakers + # Then: each revision carries the corrected per-word and turn-level speakers assert event.type == "SpeakerRevision" - assert event.turn_order == 3 - assert event.speaker_label == "B" - assert [w.speaker for w in event.words] == ["B", "A"] + assert [r.turn_order for r in event.revisions] == [3, 7] + assert event.revisions[0].speaker_label == "B" + assert [w.speaker for w in event.revisions[0].words] == ["B", "A"] + assert event.revisions[1].speaker_label == "A" + assert event.revisions[1].words == [] def test_speaker_revision_event_dispatched_to_handler(mocker: MockFixture): @@ -954,16 +1000,20 @@ def test_speaker_revision_event_dispatched_to_handler(mocker: MockFixture): revision_json = json.dumps( { "type": "SpeakerRevision", - "turn_order": 5, - "speaker_label": "A", - "words": [ + "revisions": [ { - "start": 500, - "end": 700, - "confidence": 0.95, - "text": "yes", - "word_is_final": True, - "speaker": "A", + "turn_order": 5, + "speaker_label": "A", + "words": [ + { + "start": 500, + "end": 700, + "confidence": 0.95, + "text": "yes", + "word_is_final": True, + "speaker": "A", + }, + ], }, ], } @@ -989,15 +1039,23 @@ def test_speaker_revision_event_dispatched_to_handler(mocker: MockFixture): # Then: the handler is invoked with a parsed SpeakerRevisionEvent assert len(received) == 1 assert isinstance(received[0], SpeakerRevisionEvent) - assert received[0].turn_order == 5 - assert received[0].speaker_label == "A" - assert [w.speaker for w in received[0].words] == ["A"] + assert len(received[0].revisions) == 1 + assert received[0].revisions[0].turn_order == 5 + assert received[0].revisions[0].speaker_label == "A" + assert [w.speaker for w in received[0].revisions[0].words] == ["A"] + + +def test_speech_model_optional(): + """Test that omitting speech_model is valid and excluded from the wire URI.""" + # Given: streaming parameters with no speech_model + params = StreamingParameters(sample_rate=16000) + # When: the params are constructed and serialized to a connection URI + uri = _build_uri("wss://example.com/v3/ws", params) -def test_speech_model_required(): - """Test that omitting speech_model raises a validation error.""" - with pytest.raises(ValidationError): - StreamingParameters(sample_rate=16000) + # Then: speech_model defaults to None and is not sent to the server + assert params.speech_model is None + assert "speech_model" not in uri def test_speech_started_event(): diff --git a/tests/unit/test_streaming_async.py b/tests/unit/test_streaming_async.py index bf00701..dc1d002 100644 --- a/tests/unit/test_streaming_async.py +++ b/tests/unit/test_streaming_async.py @@ -794,6 +794,45 @@ async def test_set_params_enqueues_update_configuration(mocker: MockFixture): await client.disconnect() +async def test_set_params_with_agent_context(mocker: MockFixture): + # Given: a connected async streaming client + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + from assemblyai.streaming.v3.models import ( + StreamingSessionParameters, + ) + + # When: set_params is called with agent_context mid-stream + await client.set_params( + StreamingSessionParameters(agent_context="What is your account number?") + ) + + for _ in range(100): + update_frames = [ + s for s in fake_ws.sent if isinstance(s, str) and "UpdateConfiguration" in s + ] + if update_frames: + break + await asyncio.sleep(0.01) + + # Then: an UpdateConfiguration frame carrying agent_context is sent + update_frames = [ + s for s in fake_ws.sent if isinstance(s, str) and "UpdateConfiguration" in s + ] + assert len(update_frames) == 1 + payload = json.loads(update_frames[0]) + assert payload["type"] == "UpdateConfiguration" + assert payload["agent_context"] == "What is your account number?" + + await client.disconnect() + + async def test_force_endpoint_enqueues_force_endpoint_frame(mocker: MockFixture): fake_ws = _FakeAsyncWebSocket() _patch_connect(mocker, fake_ws)