Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.firebase.ai.ksp

import com.google.devtools.ksp.KspExperimental
import com.google.devtools.ksp.isPublic
import com.google.devtools.ksp.processing.CodeGenerator
import com.google.devtools.ksp.processing.Dependencies
import com.google.devtools.ksp.processing.KSPLogger
Expand All @@ -26,6 +27,7 @@ import com.google.devtools.ksp.symbol.ClassKind
import com.google.devtools.ksp.symbol.KSAnnotated
import com.google.devtools.ksp.symbol.KSAnnotation
import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSFunctionDeclaration
import com.google.devtools.ksp.symbol.KSType
import com.google.devtools.ksp.symbol.KSVisitorVoid
import com.google.devtools.ksp.symbol.Modifier
Expand All @@ -40,27 +42,178 @@ import com.squareup.kotlinpoet.TypeSpec
import com.squareup.kotlinpoet.ksp.toClassName
import com.squareup.kotlinpoet.ksp.toTypeName
import com.squareup.kotlinpoet.ksp.writeTo
import java.util.Locale
import javax.annotation.processing.Generated

public class SchemaSymbolProcessor(
public class FirebaseSymbolProcessor(
private val codeGenerator: CodeGenerator,
private val logger: KSPLogger,
) : SymbolProcessor {
private val baseKdocRegex = Regex("^\\s*(.*?)((@\\w* .*)|\\z)", RegexOption.DOT_MATCHES_ALL)
private val propertyKdocRegex =
Regex("\\s*@property (\\w*) (.*?)(?=@\\w*|\\z)", RegexOption.DOT_MATCHES_ALL)

override fun process(resolver: Resolver): List<KSAnnotated> {
resolver
.getSymbolsWithAnnotation("com.google.firebase.ai.annotations.Generable")
.filterIsInstance<KSClassDeclaration>()
.map { it to SchemaSymbolProcessorVisitor() }
.forEach { (klass, visitor) -> visitor.visitClassDeclaration(klass, Unit) }

resolver
.getSymbolsWithAnnotation("com.google.firebase.ai.annotations.Tool")
.filterIsInstance<KSFunctionDeclaration>()
.map { it to FunctionSymbolProcessorVisitor(it, resolver) }
.forEach { (klass, visitor) -> visitor.visitFunctionDeclaration(klass, Unit) }

return emptyList()
}

private inner class FunctionSymbolProcessorVisitor(
private val func: KSFunctionDeclaration,
private val resolver: Resolver,
) : KSVisitorVoid() {
override fun visitFunctionDeclaration(function: KSFunctionDeclaration, data: Unit) {
var shouldError = false
val fullFunctionName = function.qualifiedName!!.asString()
if (!function.isPublic()) {
logger.warn("$fullFunctionName must be public.")
shouldError = true
}
val containingClass = function.parentDeclaration as? KSClassDeclaration
if (containingClass == null || !containingClass.isCompanionObject) {
logger.warn(
"$fullFunctionName must be within a companion object " +
containingClass!!.qualifiedName!!.asString()
)
shouldError = true
}
if (function.parameters.size != 1) {
logger.warn("$fullFunctionName must have exactly one parameter")
shouldError = true
}
val parameter = function.parameters.firstOrNull()?.type?.resolve()?.declaration
if (parameter != null) {
if (parameter.annotations.find { it.shortName.getShortName() == "Generable" } == null) {
logger.warn("$fullFunctionName parameter must be annotated @Generable")
shouldError = true
}
if (parameter.annotations.find { it.shortName.getShortName() == "Serializable" } == null) {
logger.warn("$fullFunctionName parameter must be annotated @Serializable")
shouldError = true
}
}
val output = function.returnType?.resolve()
if (
output != null &&
output.toClassName().canonicalName != "kotlinx.serialization.json.JsonObject"
) {
if (
output.declaration.annotations.find { it.shortName.getShortName() != "Generable" } == null
) {
logger.warn("$fullFunctionName output must be annotated @Generable")
shouldError = true
}
if (
output.declaration.annotations.find { it.shortName.getShortName() != "Serializable" } ==
null
) {
logger.warn("$fullFunctionName output must be annotated @Serializable")
shouldError = true
}
}
if (shouldError) {
logger.error("$fullFunctionName has one or more errors, please resolve them.")
}
val generatedFunctionFile = generateFileSpec(function)
generatedFunctionFile.writeTo(
codeGenerator,
Dependencies(true, function.containingFile!!),
)
}

private fun generateFileSpec(functionDeclaration: KSFunctionDeclaration): FileSpec {
val generatedClassName =
functionDeclaration.simpleName.asString().replaceFirstChar {
if (it.isLowerCase()) it.titlecase(Locale.ROOT) else it.toString()
} + "GeneratedFunctionDeclaration"
return FileSpec.builder(functionDeclaration.packageName.asString(), generatedClassName)
.addImport("com.google.firebase.ai.type", "AutoFunctionDeclaration")
.addType(
TypeSpec.classBuilder(generatedClassName)
.addAnnotation(Generated::class)
.addType(
TypeSpec.companionObjectBuilder()
.addProperty(
PropertySpec.builder(
"FUNCTION_DECLARATION",
ClassName("com.google.firebase.ai.type", "AutoFunctionDeclaration")
.parameterizedBy(
functionDeclaration.parameters.first().type.resolve().toClassName(),
functionDeclaration.returnType?.resolve()?.toClassName()
?: ClassName("kotlinx.serialization.json", "JsonObject")
),
KModifier.PUBLIC,
)
.mutable(false)
.initializer(
CodeBlock.builder()
.add(generateCodeBlockForFunctionDeclaration(functionDeclaration))
.build()
)
.build()
)
.build()
)
.build()
)
.build()
}

fun generateCodeBlockForFunctionDeclaration(
functionDeclaration: KSFunctionDeclaration
): CodeBlock {
val builder = CodeBlock.builder()
val hasTypedOutput =
!(functionDeclaration.returnType == null ||
functionDeclaration.returnType!!.resolve().toClassName().canonicalName ==
"kotlinx.serialization.json.JsonObject")
val kdocDescription = functionDeclaration.docString?.let { extractBaseKdoc(it) }
val annotationDescription =
getStringFromAnnotation(
functionDeclaration.annotations.find { it.shortName.getShortName() == "Tool" },
"description"
)
val description = annotationDescription ?: kdocDescription ?: ""
val inputSchemaName =
"${
functionDeclaration.parameters.first().type.resolve().toClassName().canonicalName
}GeneratedSchema.SCHEMA"
builder
.addStatement("AutoFunctionDeclaration.create(")
.indent()
.addStatement("functionName = %S,", functionDeclaration.simpleName.getShortName())
.addStatement("description = %S,", description)
.addStatement("inputSchema = $inputSchemaName,")
if (hasTypedOutput) {
val outputSchemaName =
"${
functionDeclaration.returnType!!.resolve().toClassName().canonicalName
}GeneratedSchema.SCHEMA"
builder.addStatement("outputSchema = $outputSchemaName,")
}
builder.addStatement(
"functionReference = " +
functionDeclaration.qualifiedName!!.getQualifier() +
"::${functionDeclaration.qualifiedName!!.getShortName()},"
)
builder.unindent().addStatement(")")
return builder.build()
}
}

private inner class SchemaSymbolProcessorVisitor() : KSVisitorVoid() {
private val numberTypes = setOf("kotlin.Int", "kotlin.Long", "kotlin.Double", "kotlin.Float")
private val baseKdocRegex = Regex("^\\s*(.*?)((@\\w* .*)|\\z)", RegexOption.DOT_MATCHES_ALL)
private val propertyKdocRegex =
Regex("\\s*@property (\\w*) (.*?)(?=@\\w*|\\z)", RegexOption.DOT_MATCHES_ALL)

override fun visitClassDeclaration(classDeclaration: KSClassDeclaration, data: Unit) {
val isDataClass = classDeclaration.modifiers.contains(Modifier.DATA)
Expand Down Expand Up @@ -240,7 +393,8 @@ public class SchemaSymbolProcessor(
}
if ((format != null || pattern != null) && className.canonicalName != "kotlin.String") {
logger.warn(
"${parentType?.toClassName()?.simpleName?.let { "$it." }}$name is not a String type, format and pattern are not a valid parameter to specify in @Guide"
"${parentType?.toClassName()?.simpleName?.let { "$it." }}$name is not a String type, " +
"format and pattern are not a valid parameter to specify in @Guide"
)
}
if (minimum != null) {
Expand All @@ -264,73 +418,72 @@ public class SchemaSymbolProcessor(
builder.addStatement("nullable = %L)", className.isNullable).unindent()
return builder.build()
}
}

private fun getDescriptionFromAnnotations(
guideAnnotation: KSAnnotation?,
guideClassAnnotation: KSAnnotation?,
description: String?,
baseKdoc: String?,
): String? {
val guidePropertyDescription = getStringFromAnnotation(guideAnnotation, "description")

val guideClassDescription = getStringFromAnnotation(guideClassAnnotation, "description")
private fun getDescriptionFromAnnotations(
guideAnnotation: KSAnnotation?,
guideClassAnnotation: KSAnnotation?,
description: String?,
baseKdoc: String?,
): String? {
val guidePropertyDescription = getStringFromAnnotation(guideAnnotation, "description")

return guidePropertyDescription ?: guideClassDescription ?: description ?: baseKdoc
}
val guideClassDescription = getStringFromAnnotation(guideClassAnnotation, "description")

private fun getDoubleFromAnnotation(
guideAnnotation: KSAnnotation?,
doubleName: String,
): Double? {
val guidePropertyDoubleValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(doubleName) == true }
?.value as? Double
if (guidePropertyDoubleValue == null || guidePropertyDoubleValue == -1.0) {
return null
}
return guidePropertyDoubleValue
return guidePropertyDescription ?: guideClassDescription ?: description ?: baseKdoc
}
private fun getDoubleFromAnnotation(
guideAnnotation: KSAnnotation?,
doubleName: String,
): Double? {
val guidePropertyDoubleValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(doubleName) == true }
?.value as? Double
if (guidePropertyDoubleValue == null || guidePropertyDoubleValue == -1.0) {
return null
}
return guidePropertyDoubleValue
}

private fun getIntFromAnnotation(guideAnnotation: KSAnnotation?, intName: String): Int? {
val guidePropertyIntValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(intName) == true }
?.value as? Int
if (guidePropertyIntValue == null || guidePropertyIntValue == -1) {
return null
}
return guidePropertyIntValue
private fun getIntFromAnnotation(guideAnnotation: KSAnnotation?, intName: String): Int? {
val guidePropertyIntValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(intName) == true }
?.value as? Int
if (guidePropertyIntValue == null || guidePropertyIntValue == -1) {
return null
}
return guidePropertyIntValue
}

private fun getStringFromAnnotation(
guideAnnotation: KSAnnotation?,
stringName: String,
): String? {
val guidePropertyStringValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(stringName) == true }
?.value as? String
if (guidePropertyStringValue.isNullOrEmpty()) {
return null
}
return guidePropertyStringValue
private fun getStringFromAnnotation(
guideAnnotation: KSAnnotation?,
stringName: String,
): String? {
val guidePropertyStringValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(stringName) == true }
?.value as? String
if (guidePropertyStringValue.isNullOrEmpty()) {
return null
}
return guidePropertyStringValue
}

private fun extractBaseKdoc(kdoc: String): String? {
return baseKdocRegex.matchEntire(kdoc)?.groups?.get(1)?.value?.trim().let {
if (it.isNullOrEmpty()) null else it
}
private fun extractBaseKdoc(kdoc: String): String? {
return baseKdocRegex.matchEntire(kdoc)?.groups?.get(1)?.value?.trim().let {
if (it.isNullOrEmpty()) null else it
}
}

private fun extractPropertyKdocs(kdoc: String): Map<String, String> {
return propertyKdocRegex
.findAll(kdoc)
.map { it.groups[1]!!.value to it.groups[2]!!.value.replace("\n", "").trim() }
.toMap()
}
private fun extractPropertyKdocs(kdoc: String): Map<String, String> {
return propertyKdocRegex
.findAll(kdoc)
.map { it.groups[1]!!.value to it.groups[2]!!.value.replace("\n", "").trim() }
.toMap()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import com.google.devtools.ksp.processing.SymbolProcessor
import com.google.devtools.ksp.processing.SymbolProcessorEnvironment
import com.google.devtools.ksp.processing.SymbolProcessorProvider

public class SchemaSymbolProcessorProvider : SymbolProcessorProvider {
public class FirebaseSymbolProcessorProvider : SymbolProcessorProvider {
override fun create(environment: SymbolProcessorEnvironment): SymbolProcessor {
return SchemaSymbolProcessor(environment.codeGenerator, environment.logger)
return FirebaseSymbolProcessor(environment.codeGenerator, environment.logger)
}
}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
com.google.firebase.ai.ksp.SchemaSymbolProcessorProvider
com.google.firebase.ai.ksp.FirebaseSymbolProcessorProvider
5 changes: 5 additions & 0 deletions firebase-ai/api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ package com.google.firebase.ai.annotations {
property public abstract String pattern;
}

@kotlin.annotation.Retention(kotlin.annotation.AnnotationRetention.SOURCE) @kotlin.annotation.Target(allowedTargets=kotlin.annotation.AnnotationTarget.FUNCTION) public @interface Tool {
method public abstract String description() default "";
property public abstract String description;
}

}

package com.google.firebase.ai.java {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

package com.google.firebase.ai.annotations

import com.google.firebase.ai.type.JsonSchema

/**
* This annotation is used with the firebase-ai-ksp-processor plugin to generate JsonSchema that
* This annotation is used with the firebase-ai-ksp-processor plugin to generate [JsonSchema] that
* match an existing kotlin class structure. For more info see:
* https://github.com/firebase/firebase-android-sdk/blob/main/firebase-ai-ksp-processor/README.md
*/
Expand Down
Loading