diff --git a/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/SchemaSymbolProcessor.kt b/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/FirebaseSymbolProcessor.kt similarity index 54% rename from firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/SchemaSymbolProcessor.kt rename to firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/FirebaseSymbolProcessor.kt index 71211448ebc..16dbe30b202 100644 --- a/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/SchemaSymbolProcessor.kt +++ b/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/FirebaseSymbolProcessor.kt @@ -17,6 +17,7 @@ package com.google.firebase.ai.ksp import com.google.devtools.ksp.KspExperimental +import com.google.devtools.ksp.isPublic import com.google.devtools.ksp.processing.CodeGenerator import com.google.devtools.ksp.processing.Dependencies import com.google.devtools.ksp.processing.KSPLogger @@ -26,6 +27,7 @@ import com.google.devtools.ksp.symbol.ClassKind import com.google.devtools.ksp.symbol.KSAnnotated import com.google.devtools.ksp.symbol.KSAnnotation import com.google.devtools.ksp.symbol.KSClassDeclaration +import com.google.devtools.ksp.symbol.KSFunctionDeclaration import com.google.devtools.ksp.symbol.KSType import com.google.devtools.ksp.symbol.KSVisitorVoid import com.google.devtools.ksp.symbol.Modifier @@ -40,12 +42,17 @@ import com.squareup.kotlinpoet.TypeSpec import com.squareup.kotlinpoet.ksp.toClassName import com.squareup.kotlinpoet.ksp.toTypeName import com.squareup.kotlinpoet.ksp.writeTo +import java.util.Locale import javax.annotation.processing.Generated -public class SchemaSymbolProcessor( +public class FirebaseSymbolProcessor( private val codeGenerator: CodeGenerator, private val logger: KSPLogger, ) : SymbolProcessor { + private val baseKdocRegex = Regex("^\\s*(.*?)((@\\w* .*)|\\z)", RegexOption.DOT_MATCHES_ALL) + private val propertyKdocRegex = + Regex("\\s*@property (\\w*) (.*?)(?=@\\w*|\\z)", RegexOption.DOT_MATCHES_ALL) + override fun process(resolver: Resolver): List { resolver .getSymbolsWithAnnotation("com.google.firebase.ai.annotations.Generable") @@ -53,14 +60,160 @@ public class SchemaSymbolProcessor( .map { it to SchemaSymbolProcessorVisitor() } .forEach { (klass, visitor) -> visitor.visitClassDeclaration(klass, Unit) } + resolver + .getSymbolsWithAnnotation("com.google.firebase.ai.annotations.Tool") + .filterIsInstance() + .map { it to FunctionSymbolProcessorVisitor(it, resolver) } + .forEach { (klass, visitor) -> visitor.visitFunctionDeclaration(klass, Unit) } + return emptyList() } + private inner class FunctionSymbolProcessorVisitor( + private val func: KSFunctionDeclaration, + private val resolver: Resolver, + ) : KSVisitorVoid() { + override fun visitFunctionDeclaration(function: KSFunctionDeclaration, data: Unit) { + var shouldError = false + val fullFunctionName = function.qualifiedName!!.asString() + if (!function.isPublic()) { + logger.warn("$fullFunctionName must be public.") + shouldError = true + } + val containingClass = function.parentDeclaration as? KSClassDeclaration + if (containingClass == null || !containingClass.isCompanionObject) { + logger.warn( + "$fullFunctionName must be within a companion object " + + containingClass!!.qualifiedName!!.asString() + ) + shouldError = true + } + if (function.parameters.size != 1) { + logger.warn("$fullFunctionName must have exactly one parameter") + shouldError = true + } + val parameter = function.parameters.firstOrNull()?.type?.resolve()?.declaration + if (parameter != null) { + if (parameter.annotations.find { it.shortName.getShortName() == "Generable" } == null) { + logger.warn("$fullFunctionName parameter must be annotated @Generable") + shouldError = true + } + if (parameter.annotations.find { it.shortName.getShortName() == "Serializable" } == null) { + logger.warn("$fullFunctionName parameter must be annotated @Serializable") + shouldError = true + } + } + val output = function.returnType?.resolve() + if ( + output != null && + output.toClassName().canonicalName != "kotlinx.serialization.json.JsonObject" + ) { + if ( + output.declaration.annotations.find { it.shortName.getShortName() != "Generable" } == null + ) { + logger.warn("$fullFunctionName output must be annotated @Generable") + shouldError = true + } + if ( + output.declaration.annotations.find { it.shortName.getShortName() != "Serializable" } == + null + ) { + logger.warn("$fullFunctionName output must be annotated @Serializable") + shouldError = true + } + } + if (shouldError) { + logger.error("$fullFunctionName has one or more errors, please resolve them.") + } + val generatedFunctionFile = generateFileSpec(function) + generatedFunctionFile.writeTo( + codeGenerator, + Dependencies(true, function.containingFile!!), + ) + } + + private fun generateFileSpec(functionDeclaration: KSFunctionDeclaration): FileSpec { + val generatedClassName = + functionDeclaration.simpleName.asString().replaceFirstChar { + if (it.isLowerCase()) it.titlecase(Locale.ROOT) else it.toString() + } + "GeneratedFunctionDeclaration" + return FileSpec.builder(functionDeclaration.packageName.asString(), generatedClassName) + .addImport("com.google.firebase.ai.type", "AutoFunctionDeclaration") + .addType( + TypeSpec.classBuilder(generatedClassName) + .addAnnotation(Generated::class) + .addType( + TypeSpec.companionObjectBuilder() + .addProperty( + PropertySpec.builder( + "FUNCTION_DECLARATION", + ClassName("com.google.firebase.ai.type", "AutoFunctionDeclaration") + .parameterizedBy( + functionDeclaration.parameters.first().type.resolve().toClassName(), + functionDeclaration.returnType?.resolve()?.toClassName() + ?: ClassName("kotlinx.serialization.json", "JsonObject") + ), + KModifier.PUBLIC, + ) + .mutable(false) + .initializer( + CodeBlock.builder() + .add(generateCodeBlockForFunctionDeclaration(functionDeclaration)) + .build() + ) + .build() + ) + .build() + ) + .build() + ) + .build() + } + + fun generateCodeBlockForFunctionDeclaration( + functionDeclaration: KSFunctionDeclaration + ): CodeBlock { + val builder = CodeBlock.builder() + val hasTypedOutput = + !(functionDeclaration.returnType == null || + functionDeclaration.returnType!!.resolve().toClassName().canonicalName == + "kotlinx.serialization.json.JsonObject") + val kdocDescription = functionDeclaration.docString?.let { extractBaseKdoc(it) } + val annotationDescription = + getStringFromAnnotation( + functionDeclaration.annotations.find { it.shortName.getShortName() == "Tool" }, + "description" + ) + val description = annotationDescription ?: kdocDescription ?: "" + val inputSchemaName = + "${ + functionDeclaration.parameters.first().type.resolve().toClassName().canonicalName + }GeneratedSchema.SCHEMA" + builder + .addStatement("AutoFunctionDeclaration.create(") + .indent() + .addStatement("functionName = %S,", functionDeclaration.simpleName.getShortName()) + .addStatement("description = %S,", description) + .addStatement("inputSchema = $inputSchemaName,") + if (hasTypedOutput) { + val outputSchemaName = + "${ + functionDeclaration.returnType!!.resolve().toClassName().canonicalName + }GeneratedSchema.SCHEMA" + builder.addStatement("outputSchema = $outputSchemaName,") + } + builder.addStatement( + "functionReference = " + + functionDeclaration.qualifiedName!!.getQualifier() + + "::${functionDeclaration.qualifiedName!!.getShortName()}," + ) + builder.unindent().addStatement(")") + return builder.build() + } + } + private inner class SchemaSymbolProcessorVisitor() : KSVisitorVoid() { private val numberTypes = setOf("kotlin.Int", "kotlin.Long", "kotlin.Double", "kotlin.Float") - private val baseKdocRegex = Regex("^\\s*(.*?)((@\\w* .*)|\\z)", RegexOption.DOT_MATCHES_ALL) - private val propertyKdocRegex = - Regex("\\s*@property (\\w*) (.*?)(?=@\\w*|\\z)", RegexOption.DOT_MATCHES_ALL) override fun visitClassDeclaration(classDeclaration: KSClassDeclaration, data: Unit) { val isDataClass = classDeclaration.modifiers.contains(Modifier.DATA) @@ -240,7 +393,8 @@ public class SchemaSymbolProcessor( } if ((format != null || pattern != null) && className.canonicalName != "kotlin.String") { logger.warn( - "${parentType?.toClassName()?.simpleName?.let { "$it." }}$name is not a String type, format and pattern are not a valid parameter to specify in @Guide" + "${parentType?.toClassName()?.simpleName?.let { "$it." }}$name is not a String type, " + + "format and pattern are not a valid parameter to specify in @Guide" ) } if (minimum != null) { @@ -264,73 +418,72 @@ public class SchemaSymbolProcessor( builder.addStatement("nullable = %L)", className.isNullable).unindent() return builder.build() } + } - private fun getDescriptionFromAnnotations( - guideAnnotation: KSAnnotation?, - guideClassAnnotation: KSAnnotation?, - description: String?, - baseKdoc: String?, - ): String? { - val guidePropertyDescription = getStringFromAnnotation(guideAnnotation, "description") - - val guideClassDescription = getStringFromAnnotation(guideClassAnnotation, "description") + private fun getDescriptionFromAnnotations( + guideAnnotation: KSAnnotation?, + guideClassAnnotation: KSAnnotation?, + description: String?, + baseKdoc: String?, + ): String? { + val guidePropertyDescription = getStringFromAnnotation(guideAnnotation, "description") - return guidePropertyDescription ?: guideClassDescription ?: description ?: baseKdoc - } + val guideClassDescription = getStringFromAnnotation(guideClassAnnotation, "description") - private fun getDoubleFromAnnotation( - guideAnnotation: KSAnnotation?, - doubleName: String, - ): Double? { - val guidePropertyDoubleValue = - guideAnnotation - ?.arguments - ?.firstOrNull { it.name?.getShortName()?.equals(doubleName) == true } - ?.value as? Double - if (guidePropertyDoubleValue == null || guidePropertyDoubleValue == -1.0) { - return null - } - return guidePropertyDoubleValue + return guidePropertyDescription ?: guideClassDescription ?: description ?: baseKdoc + } + private fun getDoubleFromAnnotation( + guideAnnotation: KSAnnotation?, + doubleName: String, + ): Double? { + val guidePropertyDoubleValue = + guideAnnotation + ?.arguments + ?.firstOrNull { it.name?.getShortName()?.equals(doubleName) == true } + ?.value as? Double + if (guidePropertyDoubleValue == null || guidePropertyDoubleValue == -1.0) { + return null } + return guidePropertyDoubleValue + } - private fun getIntFromAnnotation(guideAnnotation: KSAnnotation?, intName: String): Int? { - val guidePropertyIntValue = - guideAnnotation - ?.arguments - ?.firstOrNull { it.name?.getShortName()?.equals(intName) == true } - ?.value as? Int - if (guidePropertyIntValue == null || guidePropertyIntValue == -1) { - return null - } - return guidePropertyIntValue + private fun getIntFromAnnotation(guideAnnotation: KSAnnotation?, intName: String): Int? { + val guidePropertyIntValue = + guideAnnotation + ?.arguments + ?.firstOrNull { it.name?.getShortName()?.equals(intName) == true } + ?.value as? Int + if (guidePropertyIntValue == null || guidePropertyIntValue == -1) { + return null } + return guidePropertyIntValue + } - private fun getStringFromAnnotation( - guideAnnotation: KSAnnotation?, - stringName: String, - ): String? { - val guidePropertyStringValue = - guideAnnotation - ?.arguments - ?.firstOrNull { it.name?.getShortName()?.equals(stringName) == true } - ?.value as? String - if (guidePropertyStringValue.isNullOrEmpty()) { - return null - } - return guidePropertyStringValue + private fun getStringFromAnnotation( + guideAnnotation: KSAnnotation?, + stringName: String, + ): String? { + val guidePropertyStringValue = + guideAnnotation + ?.arguments + ?.firstOrNull { it.name?.getShortName()?.equals(stringName) == true } + ?.value as? String + if (guidePropertyStringValue.isNullOrEmpty()) { + return null } + return guidePropertyStringValue + } - private fun extractBaseKdoc(kdoc: String): String? { - return baseKdocRegex.matchEntire(kdoc)?.groups?.get(1)?.value?.trim().let { - if (it.isNullOrEmpty()) null else it - } + private fun extractBaseKdoc(kdoc: String): String? { + return baseKdocRegex.matchEntire(kdoc)?.groups?.get(1)?.value?.trim().let { + if (it.isNullOrEmpty()) null else it } + } - private fun extractPropertyKdocs(kdoc: String): Map { - return propertyKdocRegex - .findAll(kdoc) - .map { it.groups[1]!!.value to it.groups[2]!!.value.replace("\n", "").trim() } - .toMap() - } + private fun extractPropertyKdocs(kdoc: String): Map { + return propertyKdocRegex + .findAll(kdoc) + .map { it.groups[1]!!.value to it.groups[2]!!.value.replace("\n", "").trim() } + .toMap() } } diff --git a/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/SchemaSymbolProcessorProvider.kt b/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/FirebaseSymbolProcessorProvider.kt similarity index 85% rename from firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/SchemaSymbolProcessorProvider.kt rename to firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/FirebaseSymbolProcessorProvider.kt index 2c8015bc8a9..771706d9866 100644 --- a/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/SchemaSymbolProcessorProvider.kt +++ b/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/FirebaseSymbolProcessorProvider.kt @@ -20,8 +20,8 @@ import com.google.devtools.ksp.processing.SymbolProcessor import com.google.devtools.ksp.processing.SymbolProcessorEnvironment import com.google.devtools.ksp.processing.SymbolProcessorProvider -public class SchemaSymbolProcessorProvider : SymbolProcessorProvider { +public class FirebaseSymbolProcessorProvider : SymbolProcessorProvider { override fun create(environment: SymbolProcessorEnvironment): SymbolProcessor { - return SchemaSymbolProcessor(environment.codeGenerator, environment.logger) + return FirebaseSymbolProcessor(environment.codeGenerator, environment.logger) } } diff --git a/firebase-ai-ksp-processor/src/main/resources/META-INF/services/com.google.devtools.ksp.processing.SymbolProcessorProvider b/firebase-ai-ksp-processor/src/main/resources/META-INF/services/com.google.devtools.ksp.processing.SymbolProcessorProvider index 83d92f28c7e..b5a8cffc5a6 100644 --- a/firebase-ai-ksp-processor/src/main/resources/META-INF/services/com.google.devtools.ksp.processing.SymbolProcessorProvider +++ b/firebase-ai-ksp-processor/src/main/resources/META-INF/services/com.google.devtools.ksp.processing.SymbolProcessorProvider @@ -1 +1 @@ -com.google.firebase.ai.ksp.SchemaSymbolProcessorProvider \ No newline at end of file +com.google.firebase.ai.ksp.FirebaseSymbolProcessorProvider \ No newline at end of file diff --git a/firebase-ai/api.txt b/firebase-ai/api.txt index a1daa7e8d38..bdfe6ac2289 100644 --- a/firebase-ai/api.txt +++ b/firebase-ai/api.txt @@ -120,6 +120,11 @@ package com.google.firebase.ai.annotations { property public abstract String pattern; } + @kotlin.annotation.Retention(kotlin.annotation.AnnotationRetention.SOURCE) @kotlin.annotation.Target(allowedTargets=kotlin.annotation.AnnotationTarget.FUNCTION) public @interface Tool { + method public abstract String description() default ""; + property public abstract String description; + } + } package com.google.firebase.ai.java { diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotations/Generable.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotations/Generable.kt index fe217272bbd..7a721ed0921 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotations/Generable.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotations/Generable.kt @@ -16,8 +16,10 @@ package com.google.firebase.ai.annotations +import com.google.firebase.ai.type.JsonSchema + /** - * This annotation is used with the firebase-ai-ksp-processor plugin to generate JsonSchema that + * This annotation is used with the firebase-ai-ksp-processor plugin to generate [JsonSchema] that * match an existing kotlin class structure. For more info see: * https://github.com/firebase/firebase-android-sdk/blob/main/firebase-ai-ksp-processor/README.md */ diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotations/Tool.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotations/Tool.kt new file mode 100644 index 00000000000..a0a81d2ff57 --- /dev/null +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotations/Tool.kt @@ -0,0 +1,30 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.ai.annotations + +import com.google.firebase.ai.type.AutoFunctionDeclaration + +/** + * This annotation is used with the firebase-ai-ksp-processor plugin to generate + * [AutoFunctionDeclaration]s that match an existing kotlin function. For more info see: + * https://github.com/firebase/firebase-android-sdk/blob/main/firebase-ai-ksp-processor/README.md + * + * @property description a description of the function + */ +@Target(AnnotationTarget.FUNCTION) +@Retention(AnnotationRetention.SOURCE) +public annotation class Tool(public val description: String = "")