Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion core/src/main/java/com/google/adk/SchemaUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ private SchemaUtils() {} // Private constructor for utility class
*/
@SuppressWarnings("unchecked") // For tool parameter type casting.
private static Boolean matchType(Object value, Schema schema, Boolean isInput) {
if (value == null) {
return schema.nullable().orElse(false);
}
// Based on types from https://cloud.google.com/vertex-ai/docs/reference/rest/v1/Schema
Type.Known type = schema.type().get().knownEnum();
switch (type) {
Expand Down Expand Up @@ -73,7 +76,6 @@ private static Boolean matchType(Object value, Schema schema, Boolean isInput) {
throw new IllegalArgumentException(
"Unsupported type: " + type + " is not a Open API data type.");
default:
// This category includes NULL, which is not supported.
break;
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ public static Schema buildSchemaFromType(Type type, ObjectMapper objectMapper) {
*/
private static Schema buildSchemaRecursive(
JavaType javaType, SchemaGenerationContext context, ObjectMapper objectMapper) {
if (Optional.class.isAssignableFrom(javaType.getRawClass())) {
JavaType containedType = javaType.containedType(0);
if (containedType == null) {
return Schema.builder().type("OBJECT").build();
}
Schema innerSchema = buildSchemaRecursive(containedType, context, objectMapper);
return innerSchema.toBuilder().nullable(true).build();
}
if (context.isProcessing(javaType)) {
logger.warn("Type {} is recursive. Omitting from schema.", javaType.toCanonical());
return Schema.builder()
Expand Down
91 changes: 66 additions & 25 deletions core/src/main/java/com/google/adk/tools/FunctionTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -271,29 +271,48 @@ private Maybe<Map<String, Object>> call(Map<String, Object> args, ToolContext to
throws IllegalAccessException, InvocationTargetException {
Object[] arguments = buildArguments(args, toolContext, null);
Object result = func.invoke(instance, arguments);
if (result == null) {
if (result == null || isEmptyOptional(result)) {
return Maybe.empty();
} else if (result instanceof Maybe) {
return ((Maybe<?>) result)
.map(
data -> objectMapper.convertValue(data, new TypeReference<Map<String, Object>>() {}));
.filter(data -> !isEmptyOptional(data))
.map(this::convertToMapOrResult);
} else if (result instanceof Single) {
return ((Single<?>) result)
.map(data -> objectMapper.convertValue(data, new TypeReference<Map<String, Object>>() {}))
.toMaybe();
.toMaybe()
.filter(data -> !isEmptyOptional(data))
.map(this::convertToMapOrResult);
} else {
try {
return Maybe.just(
objectMapper.convertValue(result, new TypeReference<Map<String, Object>>() {}));
} catch (IllegalArgumentException e) {
// Conversion to map failed, in this case we follow
// https://google.github.io/adk-docs/tools-custom/function-tools/#return-type and return
// the { "result": $result }
return Maybe.just(ImmutableMap.of("result", result));
return Maybe.just(convertToMapOrResult(result));
}
}

private Map<String, Object> convertToMapOrResult(Object value) {
if (value == null || isEmptyOptional(value)) {
return ImmutableMap.of();
}
if (value instanceof Optional) {
value = ((Optional<?>) value).get();
}
try {
Map<String, Object> map =
objectMapper.convertValue(value, new TypeReference<Map<String, Object>>() {});
if (map == null) {
return ImmutableMap.of("result", value);
}
return map;
} catch (IllegalArgumentException e) {
// Conversion to map failed, in this case we follow
// https://google.github.io/adk-docs/tools-custom/function-tools/#return-type and return
// the { "result": $result }
return ImmutableMap.of("result", value);
}
}

private static boolean isEmptyOptional(Object value) {
return value instanceof Optional && !((Optional<?>) value).isPresent();
}

@SuppressWarnings("unchecked")
public Flowable<Map<String, Object>> callLive(
Map<String, Object> args, ToolContext toolContext, InvocationContext invocationContext)
Expand All @@ -308,6 +327,22 @@ public Flowable<Map<String, Object>> callLive(
}
}

@SuppressWarnings("unchecked") // For tool parameter type casting.
@Nullable
private Object resolveArgumentValue(
@Nullable Object argValue, Class<?> paramType, Type parameterizedType, String paramName) {
if (paramType.equals(List.class)) {
if (argValue instanceof List) {
Type type = ((ParameterizedType) parameterizedType).getActualTypeArguments()[0];
Class<?> typeArgClass = getTypeClass(type, paramName);
return createList((List<Object>) argValue, typeArgClass);
}
} else if (argValue instanceof Map) {
return objectMapper.convertValue(argValue, paramType);
}
return castValue(argValue, paramType);
}

@SuppressWarnings("unchecked") // For tool parameter type casting.
private Object[] buildArguments(
Map<String, Object> args,
Expand Down Expand Up @@ -336,9 +371,14 @@ private Object[] buildArguments(
continue;
}
Annotations.Schema schema = parameters[i].getAnnotation(Annotations.Schema.class);
Class<?> paramType = parameters[i].getType();
if (!args.containsKey(paramName)) {
if (schema != null && schema.optional()) {
arguments[i] = null;
if (paramType.equals(Optional.class)) {
arguments[i] = Optional.empty();
} else {
arguments[i] = null;
}
continue;
} else {
throw new IllegalArgumentException(
Expand All @@ -347,22 +387,23 @@ private Object[] buildArguments(
paramName));
}
}
Class<?> paramType = parameters[i].getType();
Object argValue = args.get(paramName);
if (paramType.equals(List.class)) {
if (argValue instanceof List) {
Type type =
if (paramType.equals(Optional.class)) {
if (argValue == null) {
arguments[i] = Optional.empty();
} else {
Type innerType =
((ParameterizedType) parameters[i].getParameterizedType())
.getActualTypeArguments()[0];
Class<?> typeArgClass = getTypeClass(type, paramName);
arguments[i] = createList((List<Object>) argValue, typeArgClass);
continue;
Class<?> innerClass = getTypeClass(innerType, paramName);
Object resolvedValue = resolveArgumentValue(argValue, innerClass, innerType, paramName);
arguments[i] = Optional.ofNullable(resolvedValue);
}
} else if (argValue instanceof Map) {
arguments[i] = objectMapper.convertValue(argValue, paramType);
continue;
} else {
arguments[i] =
resolveArgumentValue(
argValue, paramType, parameters[i].getParameterizedType(), paramName);
}
arguments[i] = castValue(argValue, paramType);
}
return arguments;
}
Expand Down
83 changes: 83 additions & 0 deletions core/src/test/java/com/google/adk/SchemaUtilsTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.adk;

import static org.junit.Assert.assertThrows;

import com.google.common.collect.ImmutableMap;
import com.google.genai.types.Schema;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Unit tests for {@link SchemaUtils}. */
@RunWith(JUnit4.class)
public final class SchemaUtilsTest {

@Test
public void validateMapOnSchema_nullableField_allowsNull() {
Schema schema =
Schema.builder()
.type("OBJECT")
.properties(
ImmutableMap.of(
"nullableField", Schema.builder().type("STRING").nullable(true).build()))
.build();

Map<String, Object> args = new HashMap<>();
args.put("nullableField", null);

// Should not throw exception
SchemaUtils.validateMapOnSchema(args, schema, /* isInput= */ true);
}

@Test
public void validateMapOnSchema_nonNullableField_throwsException() {
Schema schema =
Schema.builder()
.type("OBJECT")
.properties(
ImmutableMap.of(
"nonNullableField", Schema.builder().type("STRING").nullable(false).build()))
.build();

Map<String, Object> args = new HashMap<>();
args.put("nonNullableField", null);

assertThrows(
IllegalArgumentException.class,
() -> SchemaUtils.validateMapOnSchema(args, schema, /* isInput= */ true));
}

@Test
public void validateMapOnSchema_implicitNonNullableField_throwsException() {
Schema schema =
Schema.builder()
.type("OBJECT")
.properties(ImmutableMap.of("defaultField", Schema.builder().type("STRING").build()))
.build();

Map<String, Object> args = new HashMap<>();
args.put("defaultField", null);

assertThrows(
IllegalArgumentException.class,
() -> SchemaUtils.validateMapOnSchema(args, schema, /* isInput= */ true));
}
}
105 changes: 105 additions & 0 deletions core/src/test/java/com/google/adk/tools/FunctionCallingUtilsTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.adk.tools;

import static com.google.common.truth.Truth.assertThat;

import com.fasterxml.jackson.core.type.TypeReference;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.Schema;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Optional;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Unit tests for {@link FunctionCallingUtils}. */
@RunWith(JUnit4.class)
public final class FunctionCallingUtilsTest {

public static class PojoWithFields {
public String field1;
public int field2;
}

public static class PojoWithOptionalFields {
public Optional<String> optionalField;
public Optional<PojoWithFields> optionalPojo;
public Optional<List<String>> optionalList;
}

@Test
public void buildSchemaFromType_optionalString_returnsNullableString() {
Type type = new TypeReference<Optional<String>>() {}.getType();

Schema schema = FunctionCallingUtils.buildSchemaFromType(type);

assertThat(schema).isEqualTo(Schema.builder().type("STRING").nullable(true).build());
}

@Test
public void buildSchemaFromType_optionalPojo_returnsNullablePojoWithProperties() {
Type type = new TypeReference<Optional<PojoWithFields>>() {}.getType();

Schema schema = FunctionCallingUtils.buildSchemaFromType(type);

assertThat(schema)
.isEqualTo(
Schema.builder()
.type("OBJECT")
.nullable(true)
.properties(
ImmutableMap.of(
"field1", Schema.builder().type("STRING").build(),
"field2", Schema.builder().type("INTEGER").build()))
.build());
}

@Test
public void buildSchemaFromType_pojoWithOptionalFields_generatesCorrectSchema() {
Type type = PojoWithOptionalFields.class;

Schema schema = FunctionCallingUtils.buildSchemaFromType(type);

Schema expectedSchema =
Schema.builder()
.type("OBJECT")
.properties(
ImmutableMap.of(
"optionalField",
Schema.builder().type("STRING").nullable(true).build(),
"optionalPojo",
Schema.builder()
.type("OBJECT")
.nullable(true)
.properties(
ImmutableMap.of(
"field1", Schema.builder().type("STRING").build(),
"field2", Schema.builder().type("INTEGER").build()))
.build(),
"optionalList",
Schema.builder()
.type("ARRAY")
.nullable(true)
.items(Schema.builder().type("STRING").build())
.build()))
.build();

assertThat(schema).isEqualTo(expectedSchema);
}
}
Loading