Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,17 @@
_MEDIA_URL_CONTENT_TYPE_BY_MAJOR_MIME_TYPE = {
"image": "image_url",
"video": "video_url",
"audio": "audio_url",
}

# LiteLLM input_audio only accepts "mp3" and "wav" as format values.
# Maps audio MIME subtypes (including common aliases) to the canonical format.
_AUDIO_MIME_SUBTYPE_TO_FORMAT: dict[str, str] = {
"mpeg": "mp3",
"mp3": "mp3",
"wav": "wav",
"x-wav": "wav",
"wave": "wav",
"vnd.wave": "wav",
}

# Mapping of LiteLLM finish_reason strings to FinishReason enum values
Expand Down Expand Up @@ -1048,6 +1058,21 @@ async def _get_content(
"type": url_content_type,
url_content_type: {"url": data_uri},
})
elif mime_type.startswith("audio/"):
audio_subtype = mime_type.split("/", 1)[1]
audio_format = _AUDIO_MIME_SUBTYPE_TO_FORMAT.get(audio_subtype)
if audio_format is None:
raise ValueError(
f"Unsupported audio MIME type '{part.inline_data.mime_type}'."
" LiteLLM input_audio only supports mp3 and wav."
)
content_objects.append({
"type": "input_audio",
"input_audio": {
"data": base64_string,
"format": audio_format,
},
})
elif mime_type in _SUPPORTED_FILE_CONTENT_MIME_TYPES:
# OpenAI/Azure require file_id from uploaded file, not inline data
if provider in _FILE_ID_REQUIRED_PROVIDERS:
Expand Down
105 changes: 95 additions & 10 deletions tests/unittests/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2873,12 +2873,6 @@ async def test_get_content_file_uri_file_id_required_falls_back_to_text(
"video_url",
id="video",
),
pytest.param(
"https://example.com/audio.mp3",
"audio/mpeg",
"audio_url",
id="audio",
),
],
)
async def test_get_content_file_uri_media_url_file_id_required_uses_url_type(
Expand All @@ -2899,6 +2893,32 @@ async def test_get_content_file_uri_media_url_file_id_required_uses_url_type(
}]


@pytest.mark.asyncio
@pytest.mark.parametrize(
"provider,model",
[
("openai", "openai/gpt-4o-audio-preview"),
("azure", "azure/gpt-4o-audio-preview"),
],
)
async def test_get_content_file_uri_audio_http_url_file_id_required_falls_back_to_text(
provider, model
):
# audio_url is not a valid LiteLLM content type; HTTP audio URLs for
# file-id-required providers fall back to a text reference.
parts = [
types.Part(
file_data=types.FileData(
file_uri="https://example.com/audio.mp3",
mime_type="audio/mpeg",
display_name="audio.mp3",
)
)
]
content = await _get_content(parts, provider=provider, model=model)
assert content == [{"type": "text", "text": '[File reference: "audio.mp3"]'}]


@pytest.mark.asyncio
@pytest.mark.parametrize(
"provider,model",
Expand Down Expand Up @@ -3144,16 +3164,81 @@ async def test_get_content_file_uri_mime_type_inference(

@pytest.mark.asyncio
async def test_get_content_audio():
# Audio inline_data must produce an input_audio block (not audio_url).
# The data field is raw base64 (no data URI prefix) and format is the
# MIME subtype extracted from the MIME type.
parts = [
types.Part.from_bytes(data=b"test_audio_data", mime_type="audio/mpeg")
]
content = await _get_content(parts)
assert content[0]["type"] == "audio_url"
assert content[0]["type"] == "input_audio"
assert content[0]["input_audio"]["data"] == "dGVzdF9hdWRpb19kYXRh"
assert content[0]["input_audio"]["format"] == "mp3"
assert "url" not in content[0]["input_audio"]


@pytest.mark.asyncio
@pytest.mark.parametrize(
"mime_type,expected_format",
[
pytest.param("audio/mpeg", "mp3", id="mpeg_to_mp3"),
pytest.param("audio/mp3", "mp3", id="mp3_alias"),
pytest.param("audio/wav", "wav", id="wav"),
pytest.param("audio/x-wav", "wav", id="x-wav_alias"),
pytest.param("audio/wave", "wav", id="wave_alias"),
],
)
async def test_get_content_audio_formats(mime_type, expected_format):
# Only mp3 and wav are valid input_audio formats; verify MIME aliases map
# to the correct canonical format string.
parts = [types.Part.from_bytes(data=b"audio_bytes", mime_type=mime_type)]
content = await _get_content(parts)
assert content[0]["type"] == "input_audio"
assert content[0]["input_audio"]["format"] == expected_format
assert (
content[0]["audio_url"]["url"]
== "data:audio/mpeg;base64,dGVzdF9hdWRpb19kYXRh"
content[0]["input_audio"]["data"]
== base64.b64encode(b"audio_bytes").decode()
)
assert "format" not in content[0]["audio_url"]


@pytest.mark.asyncio
@pytest.mark.parametrize(
"mime_type",
["audio/mp4", "audio/ogg", "audio/webm", "audio/aac"],
)
async def test_get_content_audio_unsupported_format_raises(mime_type):
# Formats other than mp3/wav are not supported by LiteLLM input_audio and
# should raise a ValueError rather than silently producing a bad payload.
parts = [types.Part.from_bytes(data=b"audio_bytes", mime_type=mime_type)]
with pytest.raises(ValueError, match="Unsupported audio MIME type"):
await _get_content(parts)


@pytest.mark.asyncio
async def test_get_content_audio_raw_base64_not_data_uri():
# Ensure the data field is raw base64 with no "data:audio/...;base64," prefix.
raw_bytes = b"\x00\x01\x02\x03"
parts = [types.Part.from_bytes(data=raw_bytes, mime_type="audio/wav")]
content = await _get_content(parts)
audio_data = content[0]["input_audio"]["data"]
assert not audio_data.startswith("data:")
assert audio_data == base64.b64encode(raw_bytes).decode()


@pytest.mark.asyncio
async def test_get_content_audio_mixed_with_text():
# When audio is combined with text, both parts appear as separate content
# objects: text block followed by input_audio block.
parts = [
types.Part.from_text(text="What is said in this audio?"),
types.Part.from_bytes(data=b"test_audio_data", mime_type="audio/mpeg"),
]
content = await _get_content(parts)
assert len(content) == 2
assert content[0] == {"type": "text", "text": "What is said in this audio?"}
assert content[1]["type"] == "input_audio"
assert content[1]["input_audio"]["data"] == "dGVzdF9hdWRpb19kYXRh"
assert content[1]["input_audio"]["format"] == "mp3"


def test_to_litellm_role():
Expand Down
Loading