From e9184c9846d97f65907667aa2a6bbac1f65fed64 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 1 May 2026 09:15:52 -0700 Subject: [PATCH] feat: Add support for refusal content using "[[REFUSAL]]:" prefix This is part of a larger chain of commits for adding chat completion API support to the Apigee model. PiperOrigin-RevId: 908765169 --- .../models/chat/ChatCompletionsCommon.java | 45 ++++++ .../models/chat/ChatCompletionsRequest.java | 30 ++-- .../models/chat/ChatCompletionsResponse.java | 2 +- .../chat/ChatCompletionsRequestTest.java | 151 ++++++++++++++++++ .../chat/ChatCompletionsResponseTest.java | 6 +- 5 files changed, 218 insertions(+), 16 deletions(-) diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java index e26546313..1ed997824 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java @@ -25,6 +25,7 @@ import com.google.genai.types.Part; import java.util.Base64; import java.util.Map; +import java.util.Objects; import org.jspecify.annotations.Nullable; /** Shared models for Chat Completions Request and Response. */ @@ -45,6 +46,50 @@ private ChatCompletionsCommon() {} public static final String METADATA_KEY_SYSTEM_FINGERPRINT = "system_fingerprint"; public static final String METADATA_KEY_SERVICE_TIER = "service_tier"; + /** + * Prefix used to mark refusal content in a text Part, since there is no dedicated field for + * refusal content in the Gemini API. + */ + static final String REFUSAL_PREFIX = "[[REFUSAL]]: "; + + /** + * Result of splitting a text part into its non-refusal content and refusal content. Either + * component may be {@code null} when absent. + */ + record RefusalSplit(@Nullable String content, @Nullable String refusal) {} + + /** + * Splits a text Part value into a content portion and a refusal portion based on the {@link + * #REFUSAL_PREFIX} sentinel: + * + * + * + * @param text the raw text from a {@link Part#text()}. + * @return a {@link RefusalSplit} with the content and refusal portions. + */ + static RefusalSplit parseRefusalPrefix(String text) { + Objects.requireNonNull(text, "text cannot be null"); + if (text.startsWith(REFUSAL_PREFIX)) { + return new RefusalSplit(null, text.substring(REFUSAL_PREFIX.length())); + } + String separator = "\n" + REFUSAL_PREFIX; + int index = text.indexOf(separator); + if (index >= 0) { + String before = text.substring(0, index); + String after = text.substring(index + separator.length()); + return new RefusalSplit(before.isEmpty() ? null : before, after); + } + return new RefusalSplit(text, null); + } + /** * See * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_message_tool_call%20%3E%20(schema) diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java index ea49bbe2f..523c04a5a 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java @@ -350,6 +350,7 @@ private static List processContent(Content content) { List contentParts = new ArrayList<>(); List toolCalls = new ArrayList<>(); List toolResponses = new ArrayList<>(); + List refusals = new ArrayList<>(); content .parts() @@ -357,7 +358,18 @@ private static List processContent(Content content) { parts -> { for (Part part : parts) { if (part.text().isPresent()) { - contentParts.add(processTextPart(part)); + // Text Parts may carry refusal content prefixed with REFUSAL_PREFIX. + ChatCompletionsCommon.RefusalSplit split = + ChatCompletionsCommon.parseRefusalPrefix(part.text().get()); + if (split.content() != null) { + ContentPart textPart = new ContentPart(); + textPart.type = "text"; + textPart.text = split.content(); + contentParts.add(textPart); + } + if (split.refusal() != null) { + refusals.add(split.refusal()); + } } else if (part.inlineData().isPresent()) { contentParts.add(processInlineDataPart(part)); } else if (part.fileData().isPresent()) { @@ -381,6 +393,9 @@ private static List processContent(Content content) { if (!toolCalls.isEmpty()) { msg.toolCalls = ImmutableList.copyOf(toolCalls); } + if (!refusals.isEmpty()) { + msg.refusal = String.join("\n", refusals); + } if (!contentParts.isEmpty()) { if (contentParts.size() == 1 && Objects.equals(contentParts.get(0).type, "text")) { msg.content = new MessageContent(contentParts.get(0).text); @@ -394,19 +409,6 @@ private static List processContent(Content content) { } } - /** - * Processes a text part and returns a mapped ContentPart. - * - * @param part The input part containing simple text. - * @return The mapped text part. - */ - private static ContentPart processTextPart(Part part) { - ContentPart textPart = new ContentPart(); - textPart.type = "text"; - textPart.text = part.text().get(); - return textPart; - } - /** * Processes an inline data part and returns a mapped ContentPart. * diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java index a718f9a43..61e7e8358 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java @@ -180,7 +180,7 @@ private ImmutableList mapMessageToParts(Message message) { parts.add(Part.fromText(message.content)); } if (message.refusal != null) { - parts.add(Part.fromText(message.refusal)); + parts.add(Part.fromText(ChatCompletionsCommon.REFUSAL_PREFIX + message.refusal)); } if (message.toolCalls != null) { parts.addAll(mapToolCallsToParts(message.toolCalls)); diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java index aaddc690d..1f41189a2 100644 --- a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java @@ -245,6 +245,157 @@ public void testFromLlmRequest_basic() throws Exception { assertThat(request.messages.get(0).content.getValue()).isEqualTo("Hello"); } + @Test + public void testFromLlmRequest_withRefusal() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.fromText("Regular text response"), + Part.fromText( + ChatCompletionsCommon.REFUSAL_PREFIX + "I cannot do that."))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message message = request.messages.get(0); + assertThat(message.role).isEqualTo("assistant"); + assertThat(message.refusal).isEqualTo("I cannot do that."); + assertThat(message.content.getValue()).isEqualTo("Regular text response"); + } + + @Test + public void testFromLlmRequest_withRefusalEmbeddedAfterNewline() throws Exception { + // A single Part containing both content and refusal, separated by "\n[[REFUSAL]]: ". + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.fromText( + "Partial text answer\n" + + ChatCompletionsCommon.REFUSAL_PREFIX + + "System error or refusal"))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message message = request.messages.get(0); + assertThat(message.role).isEqualTo("assistant"); + assertThat(message.content.getValue()).isEqualTo("Partial text answer"); + assertThat(message.refusal).isEqualTo("System error or refusal"); + } + + @Test + public void testFromLlmRequest_withMultipleRefusalsJoinedWithNewline() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.fromText(ChatCompletionsCommon.REFUSAL_PREFIX + "First"), + Part.fromText(ChatCompletionsCommon.REFUSAL_PREFIX + "Second"))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message message = request.messages.get(0); + assertThat(message.role).isEqualTo("assistant"); + assertThat(message.refusal).isEqualTo("First\nSecond"); + assertThat(message.content).isNull(); + } + + @Test + public void testFromLlmRequest_withRefusalOnlyHasNullContent() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.fromText( + ChatCompletionsCommon.REFUSAL_PREFIX + "Only a refusal"))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message message = request.messages.get(0); + assertThat(message.role).isEqualTo("assistant"); + assertThat(message.refusal).isEqualTo("Only a refusal"); + assertThat(message.content).isNull(); + } + + @Test + public void testFromLlmRequest_withRefusalPrefixAfterEmptyContentLine() throws Exception { + // Edge case: text begins with "\n[[REFUSAL]]: ..." -- empty content before the prefix. + // Expectation: no content part, refusal populated. + String text = "\n" + ChatCompletionsCommon.REFUSAL_PREFIX + "Refusal only"; + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts(ImmutableList.of(Part.fromText(text))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message message = request.messages.get(0); + assertThat(message.refusal).isEqualTo("Refusal only"); + assertThat(message.content).isNull(); + } + + @Test + public void testFromLlmRequest_withRefusalPrefixMidLineIsNotSplit() throws Exception { + // The prefix is intentionally NOT recognized mid-line without a preceding newline. + String inlineText = "foo " + ChatCompletionsCommon.REFUSAL_PREFIX + "bar"; + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts(ImmutableList.of(Part.fromText(inlineText))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message message = request.messages.get(0); + assertThat(message.refusal).isNull(); + assertThat(message.content.getValue()).isEqualTo(inlineText); + } + @Test public void testFromLlmRequest_withSystemInstruction() throws Exception { LlmRequest llmRequest = diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java index ad1839019..367545207 100644 --- a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java @@ -504,6 +504,7 @@ public void testToLlmResponse_withRefusal() throws Exception { "index": 0, "message": { "role": "assistant", + "content": "Partial text answer", "refusal": "System error or refusal" }, "finish_reason": "stop" @@ -521,8 +522,11 @@ public void testToLlmResponse_withRefusal() throws Exception { // Content assertThat(response.content().get().role()).hasValue("model"); + assertThat(response.content().get().parts().get()).hasSize(2); assertThat(response.content().get().parts().get().get(0).text()) - .hasValue("System error or refusal"); + .hasValue("Partial text answer"); + assertThat(response.content().get().parts().get().get(1).text()) + .hasValue("[[REFUSAL]]: System error or refusal"); // Custom Metadata List metadata = response.customMetadata().get();