From 2f6a34fffc003dfdb29664cfd78d8fa239810f9f Mon Sep 17 00:00:00 2001 From: Benjamin Callonnec Date: Thu, 23 Apr 2026 07:59:50 +0200 Subject: [PATCH] fix litellm audio management --- src/google/adk/models/lite_llm.py | 27 ++++++- tests/unittests/models/test_litellm.py | 105 ++++++++++++++++++++++--- 2 files changed, 121 insertions(+), 11 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 7d13696c96..7e90b4912e 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -103,7 +103,17 @@ _MEDIA_URL_CONTENT_TYPE_BY_MAJOR_MIME_TYPE = { "image": "image_url", "video": "video_url", - "audio": "audio_url", +} + +# LiteLLM input_audio only accepts "mp3" and "wav" as format values. +# Maps audio MIME subtypes (including common aliases) to the canonical format. +_AUDIO_MIME_SUBTYPE_TO_FORMAT: dict[str, str] = { + "mpeg": "mp3", + "mp3": "mp3", + "wav": "wav", + "x-wav": "wav", + "wave": "wav", + "vnd.wave": "wav", } # Mapping of LiteLLM finish_reason strings to FinishReason enum values @@ -1048,6 +1058,21 @@ async def _get_content( "type": url_content_type, url_content_type: {"url": data_uri}, }) + elif mime_type.startswith("audio/"): + audio_subtype = mime_type.split("/", 1)[1] + audio_format = _AUDIO_MIME_SUBTYPE_TO_FORMAT.get(audio_subtype) + if audio_format is None: + raise ValueError( + f"Unsupported audio MIME type '{part.inline_data.mime_type}'." + " LiteLLM input_audio only supports mp3 and wav." 
+ ) + content_objects.append({ + "type": "input_audio", + "input_audio": { + "data": base64_string, + "format": audio_format, + }, + }) elif mime_type in _SUPPORTED_FILE_CONTENT_MIME_TYPES: # OpenAI/Azure require file_id from uploaded file, not inline data if provider in _FILE_ID_REQUIRED_PROVIDERS: diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index ace08ad997..092a1baa89 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -2873,12 +2873,6 @@ async def test_get_content_file_uri_file_id_required_falls_back_to_text( "video_url", id="video", ), - pytest.param( - "https://example.com/audio.mp3", - "audio/mpeg", - "audio_url", - id="audio", - ), ], ) async def test_get_content_file_uri_media_url_file_id_required_uses_url_type( @@ -2899,6 +2893,32 @@ async def test_get_content_file_uri_media_url_file_id_required_uses_url_type( }] +@pytest.mark.asyncio +@pytest.mark.parametrize( + "provider,model", + [ + ("openai", "openai/gpt-4o-audio-preview"), + ("azure", "azure/gpt-4o-audio-preview"), + ], +) +async def test_get_content_file_uri_audio_http_url_file_id_required_falls_back_to_text( + provider, model +): + # audio_url is not a valid LiteLLM content type; HTTP audio URLs for + # file-id-required providers fall back to a text reference. + parts = [ + types.Part( + file_data=types.FileData( + file_uri="https://example.com/audio.mp3", + mime_type="audio/mpeg", + display_name="audio.mp3", + ) + ) + ] + content = await _get_content(parts, provider=provider, model=model) + assert content == [{"type": "text", "text": '[File reference: "audio.mp3"]'}] + + @pytest.mark.asyncio @pytest.mark.parametrize( "provider,model", @@ -3144,16 +3164,81 @@ async def test_get_content_file_uri_mime_type_inference( @pytest.mark.asyncio async def test_get_content_audio(): + # Audio inline_data must produce an input_audio block (not audio_url). 
+ # The data field is raw base64 (no data URI prefix) and format is the + # canonical value ("mp3"/"wav") mapped from the MIME subtype. parts = [ types.Part.from_bytes(data=b"test_audio_data", mime_type="audio/mpeg") ] content = await _get_content(parts) - assert content[0]["type"] == "audio_url" + assert content[0]["type"] == "input_audio" + assert content[0]["input_audio"]["data"] == "dGVzdF9hdWRpb19kYXRh" + assert content[0]["input_audio"]["format"] == "mp3" + assert "url" not in content[0]["input_audio"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "mime_type,expected_format", + [ + pytest.param("audio/mpeg", "mp3", id="mpeg_to_mp3"), + pytest.param("audio/mp3", "mp3", id="mp3_alias"), + pytest.param("audio/wav", "wav", id="wav"), + pytest.param("audio/x-wav", "wav", id="x-wav_alias"), + pytest.param("audio/wave", "wav", id="wave_alias"), + ], +) +async def test_get_content_audio_formats(mime_type, expected_format): + # Only mp3 and wav are valid input_audio formats; verify MIME aliases map + # to the correct canonical format string. parts = [types.Part.from_bytes(data=b"audio_bytes", mime_type=mime_type)] + content = await _get_content(parts) + assert content[0]["type"] == "input_audio" + assert content[0]["input_audio"]["format"] == expected_format assert ( - content[0]["audio_url"]["url"] - == "data:audio/mpeg;base64,dGVzdF9hdWRpb19kYXRh" + content[0]["input_audio"]["data"] + == base64.b64encode(b"audio_bytes").decode() ) - assert "format" not in content[0]["audio_url"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "mime_type", + ["audio/mp4", "audio/ogg", "audio/webm", "audio/aac"], +) +async def test_get_content_audio_unsupported_format_raises(mime_type): + # Formats other than mp3/wav are not supported by LiteLLM input_audio and + # should raise a ValueError rather than silently producing a bad payload. 
+ parts = [types.Part.from_bytes(data=b"audio_bytes", mime_type=mime_type)] + with pytest.raises(ValueError, match="Unsupported audio MIME type"): + await _get_content(parts) + + +@pytest.mark.asyncio +async def test_get_content_audio_raw_base64_not_data_uri(): + # Ensure the data field is raw base64 with no "data:audio/...;base64," prefix. + raw_bytes = b"\x00\x01\x02\x03" + parts = [types.Part.from_bytes(data=raw_bytes, mime_type="audio/wav")] + content = await _get_content(parts) + audio_data = content[0]["input_audio"]["data"] + assert not audio_data.startswith("data:") + assert audio_data == base64.b64encode(raw_bytes).decode() + + +@pytest.mark.asyncio +async def test_get_content_audio_mixed_with_text(): + # When audio is combined with text, both parts appear as separate content + # objects: text block followed by input_audio block. + parts = [ + types.Part.from_text(text="What is said in this audio?"), + types.Part.from_bytes(data=b"test_audio_data", mime_type="audio/mpeg"), + ] + content = await _get_content(parts) + assert len(content) == 2 + assert content[0] == {"type": "text", "text": "What is said in this audio?"} + assert content[1]["type"] == "input_audio" + assert content[1]["input_audio"]["data"] == "dGVzdF9hdWRpb19kYXRh" + assert content[1]["input_audio"]["format"] == "mp3" def test_to_litellm_role():