Skip to content

Commit 2e3f863

Browse files
committed
Map @param annotation to androidx.benchmark
1 parent 097c69f commit 2e3f863

File tree

3 files changed

+212
-64
lines changed

3 files changed

+212
-64
lines changed

plugin/main/src/kotlinx/benchmark/gradle/AndroidMultiplatformTasks.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ private fun Project.createSetupAndroidProjectTask(target: AndroidBenchmarkTarget
4747
val unpackedDir = getUnpackAarDir(compilation)
4848
val newText = it.readText().replace(
4949
"<<BENCHMARK_CLASSES_JAR_PATH>>",
50-
unpackedDir.resolve("classes.jar").absolutePath
50+
unpackedDir.resolve("classes.jar").absolutePath.replace("\\", "/")
5151
)
5252
it.writeText(newText)
5353
}
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
1+
@file:OptIn(RequiresKotlinCompilerEmbeddable::class)
2+
13
package kotlinx.benchmark.gradle
24

35
import com.squareup.kotlinpoet.*
6+
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
7+
import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.measureAnnotationFQN
8+
import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.paramAnnotationFQN
9+
import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.setupAnnotationFQN
10+
import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.teardownAnnotationFQN
11+
import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.warmupAnnotationFQN
12+
import kotlinx.benchmark.gradle.internal.generator.RequiresKotlinCompilerEmbeddable
413
import java.io.File
514
import java.util.*
615

@@ -10,7 +19,11 @@ internal fun generateBenchmarkSourceFiles(
1019
) {
1120
classDescriptors.forEach { descriptor ->
1221
if (descriptor.visibility == Visibility.PUBLIC && !descriptor.isAbstract) {
13-
generateDescriptorFile(descriptor, targetDir)
22+
if (descriptor.getSpecificField(paramAnnotationFQN).isNotEmpty()) {
23+
generateParameterizedDescriptorFile(descriptor, targetDir)
24+
} else {
25+
generateDescriptorFile(descriptor, targetDir)
26+
}
1427
}
1528
}
1629
}
@@ -27,6 +40,12 @@ private fun generateDescriptorFile(descriptor: ClassAnnotationsDescriptor, andro
2740
.addImport("androidx.benchmark", "BenchmarkState")
2841
.addImport("androidx.benchmark", "ExperimentalBenchmarkStateApi")
2942

43+
if (descriptor.hasSetupOrTeardownMethods()) {
44+
fileSpecBuilder
45+
.addImport("org.junit", "Before")
46+
.addImport("org.junit", "After")
47+
}
48+
3049
val typeSpecBuilder = TypeSpec.classBuilder(descriptorName)
3150
.addAnnotation(
3251
AnnotationSpec.builder(ClassName("org.junit.runner", "RunWith"))
@@ -40,7 +59,122 @@ private fun generateDescriptorFile(descriptor: ClassAnnotationsDescriptor, andro
4059
fileSpecBuilder.build().writeTo(androidTestDir)
4160
}
4261

43-
private fun addBenchmarkMethods(typeSpecBuilder: TypeSpec.Builder, descriptor: ClassAnnotationsDescriptor) {
62+
private fun generateParameterizedDescriptorFile(descriptor: ClassAnnotationsDescriptor, androidTestDir: File) {
63+
val descriptorName = "${descriptor.name}_Descriptor"
64+
val packageName = descriptor.packageName
65+
val fileSpecBuilder = FileSpec.builder(packageName, descriptorName)
66+
.addImport("org.junit.runner", "RunWith")
67+
.addImport("org.junit.runners", "Parameterized")
68+
.addImport("androidx.benchmark", "BenchmarkState")
69+
.addImport("androidx.benchmark", "ExperimentalBenchmarkStateApi")
70+
.addImport("org.junit", "Test")
71+
72+
if (descriptor.hasSetupOrTeardownMethods()) {
73+
fileSpecBuilder
74+
.addImport("org.junit", "Before")
75+
.addImport("org.junit", "After")
76+
}
77+
78+
fileSpecBuilder.addAnnotation(
79+
AnnotationSpec.builder(ClassName("org.junit.runner", "RunWith"))
80+
.addMember("%T::class", ClassName("org.junit.runners", "Parameterized"))
81+
.build()
82+
)
83+
84+
// Generate constructor
85+
val constructorSpec = FunSpec.constructorBuilder()
86+
val paramFields = descriptor.getSpecificField(paramAnnotationFQN)
87+
paramFields.forEach { param ->
88+
constructorSpec.addParameter(param.name, getTypeName(param.type))
89+
}
90+
91+
val typeSpecBuilder = TypeSpec.classBuilder(descriptorName)
92+
.primaryConstructor(constructorSpec.build())
93+
.addProperties(paramFields.map { param ->
94+
PropertySpec.builder(param.name, getTypeName(param.type))
95+
.initializer(param.name)
96+
.addModifiers(KModifier.PRIVATE)
97+
.build()
98+
})
99+
100+
addBenchmarkMethods(typeSpecBuilder, descriptor, true)
101+
102+
// Generate companion object with parameters
103+
val companionSpec = TypeSpec.companionObjectBuilder()
104+
.addFunction(generateParametersFunction(paramFields))
105+
.build()
106+
107+
typeSpecBuilder.addType(companionSpec)
108+
109+
fileSpecBuilder.addType(typeSpecBuilder.build())
110+
fileSpecBuilder.build().writeTo(androidTestDir)
111+
}
112+
113+
private fun generateParametersFunction(paramFields: List<FieldAnnotationsDescriptor>): FunSpec {
114+
val dataFunctionBuilder = FunSpec.builder("data")
115+
.addAnnotation(JvmStatic::class)
116+
.returns(
117+
ClassName("java.util", "Collection")
118+
.parameterizedBy(
119+
ClassName("kotlin", "Array")
120+
.parameterizedBy(ANY)
121+
)
122+
)
123+
124+
val paramNameAndIndex = paramFields.mapIndexed { index, param ->
125+
"${param.name}={${index}}"
126+
}.joinToString(", ")
127+
128+
val paramAnnotationValue = "{index}: $paramNameAndIndex"
129+
130+
dataFunctionBuilder.addAnnotation(
131+
AnnotationSpec.builder(ClassName("org.junit.runners", "Parameterized.Parameters"))
132+
.addMember("name = \"%L\"", paramAnnotationValue)
133+
.build()
134+
)
135+
136+
val paramValueLists = paramFields.map { param ->
137+
val values = param.annotations
138+
.find { it.name == paramAnnotationFQN }
139+
?.parameters?.get("value") as List<*>
140+
141+
values.map { value ->
142+
if (param.type == "java.lang.String") {
143+
"\"\"\"$value\"\"\""
144+
} else {
145+
value.toString()
146+
}
147+
}
148+
}
149+
150+
val cartesianProduct = cartesianProduct(paramValueLists as List<List<Any>>)
151+
152+
val returnStatement = StringBuilder("return listOf(\n")
153+
cartesianProduct.forEachIndexed { index, combination ->
154+
val arrayContent = combination.joinToString(", ")
155+
returnStatement.append(" arrayOf($arrayContent)")
156+
if (index != cartesianProduct.size - 1) {
157+
returnStatement.append(",\n")
158+
}
159+
}
160+
returnStatement.append("\n)")
161+
dataFunctionBuilder.addStatement(returnStatement.toString())
162+
163+
return dataFunctionBuilder.build()
164+
}
165+
166+
private fun cartesianProduct(lists: List<List<Any>>): List<List<Any>> {
167+
if (lists.isEmpty()) return emptyList()
168+
return lists.fold(listOf(listOf<Any>())) { acc, list ->
169+
acc.flatMap { prefix -> list.map { value -> prefix + value } }
170+
}
171+
}
172+
173+
private fun addBenchmarkMethods(
174+
typeSpecBuilder: TypeSpec.Builder,
175+
descriptor: ClassAnnotationsDescriptor,
176+
isParameterized: Boolean = false
177+
) {
44178
val className = "${descriptor.packageName}.${descriptor.name}"
45179
val propertyName = descriptor.name.decapitalize(Locale.getDefault())
46180

@@ -55,70 +189,106 @@ private fun addBenchmarkMethods(typeSpecBuilder: TypeSpec.Builder, descriptor: C
55189
descriptor.methods
56190
.filter { it.visibility == Visibility.PUBLIC && it.parameters.isEmpty() }
57191
.filterNot { method ->
58-
method.annotations.any { annotation -> annotation.name == "kotlinx.benchmark.Param" }
192+
method.annotations.any { annotation -> annotation.name == paramAnnotationFQN }
59193
}
60194
.forEach { method ->
61195
when {
62-
method.annotations.any { it.name == "kotlinx.benchmark.Setup" || it.name == "kotlinx.benchmark.TearDown" } -> {
196+
method.annotations.any { it.name == setupAnnotationFQN || it.name == teardownAnnotationFQN } -> {
63197
generateNonMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder)
64198
}
199+
200+
isParameterized && descriptor.getSpecificField(paramAnnotationFQN).isNotEmpty() -> {
201+
generateParameterizedMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder)
202+
}
203+
65204
else -> {
66205
generateMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder)
67206
}
68207
}
69208
}
70209
}
71210

72-
private fun generateMeasurableMethod(
211+
private fun generateCommonMeasurableMethod(
73212
descriptor: ClassAnnotationsDescriptor,
74213
method: MethodAnnotationsDescriptor,
75214
propertyName: String,
76-
typeSpecBuilder: TypeSpec.Builder
215+
typeSpecBuilder: TypeSpec.Builder,
216+
isParameterized: Boolean
77217
) {
78218
val measurementIterations = descriptor.annotations
79-
.find { it.name == "kotlinx.benchmark.Measurement" }
219+
.find { it.name == measureAnnotationFQN }
80220
?.parameters?.get("iterations") as? Int ?: 5
81221
val warmupIterations = descriptor.annotations
82-
.find { it.name == "kotlinx.benchmark.Warmup" }
222+
.find { it.name == warmupAnnotationFQN }
83223
?.parameters?.get("iterations") as? Int ?: 5
84224

225+
val methodName = "${descriptor.packageName}.${descriptor.name}.${method.name}"
226+
85227
val methodSpecBuilder = FunSpec.builder("benchmark_${descriptor.name}_${method.name}")
86228
.addAnnotation(ClassName("org.junit", "Test"))
87229
.addAnnotation(
88230
AnnotationSpec.builder(ClassName("kotlin", "OptIn"))
89231
.addMember("%T::class", ClassName("androidx.benchmark", "ExperimentalBenchmarkStateApi"))
90232
.build()
91233
)
92-
// TODO: Add warmupCount and repeatCount parameters
234+
235+
if (isParameterized) {
236+
descriptor.getSpecificField(paramAnnotationFQN).forEach { field ->
237+
methodSpecBuilder.addStatement("$propertyName.${field.name} = ${field.name}")
238+
}
239+
}
240+
241+
methodSpecBuilder
93242
.addStatement(
94243
"val state = %T(warmupCount = $warmupIterations, repeatCount = $measurementIterations)",
95244
ClassName("androidx.benchmark", "BenchmarkState")
96245
)
246+
.addStatement("println(\"Android: $methodName\")")
97247
.beginControlFlow("while (state.keepRunning())")
98248
.addStatement("$propertyName.${method.name}()")
99249
.endControlFlow()
100250
.addStatement("val measurementResult = state.getMeasurementTimeNs()")
101251
.beginControlFlow("measurementResult.forEachIndexed { index, time ->")
102252
.addStatement("println(\"Iteration \${index + 1}: \$time ns\")")
103253
.endControlFlow()
254+
104255
typeSpecBuilder.addFunction(methodSpecBuilder.build())
105256
}
106257

258+
private fun generateParameterizedMeasurableMethod(
259+
descriptor: ClassAnnotationsDescriptor,
260+
method: MethodAnnotationsDescriptor,
261+
propertyName: String,
262+
typeSpecBuilder: TypeSpec.Builder
263+
) {
264+
generateCommonMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder, isParameterized = true)
265+
}
266+
267+
private fun generateMeasurableMethod(
268+
descriptor: ClassAnnotationsDescriptor,
269+
method: MethodAnnotationsDescriptor,
270+
propertyName: String,
271+
typeSpecBuilder: TypeSpec.Builder
272+
) {
273+
generateCommonMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder, isParameterized = false)
274+
}
275+
276+
107277
private fun generateNonMeasurableMethod(
108278
descriptor: ClassAnnotationsDescriptor,
109279
method: MethodAnnotationsDescriptor,
110280
propertyName: String,
111281
typeSpecBuilder: TypeSpec.Builder
112282
) {
113283
when (method.annotations.first().name) {
114-
"kotlinx.benchmark.Setup" -> {
284+
setupAnnotationFQN -> {
115285
val methodSpecBuilder = FunSpec.builder("benchmark_${descriptor.name}_setUp")
116286
.addAnnotation(ClassName("org.junit", "Before"))
117287
.addStatement("$propertyName.${method.name}()")
118288
typeSpecBuilder.addFunction(methodSpecBuilder.build())
119289
}
120290

121-
"kotlinx.benchmark.TearDown" -> {
291+
teardownAnnotationFQN -> {
122292
val methodSpecBuilder = FunSpec.builder("benchmark_${descriptor.name}_tearDown")
123293
.addAnnotation(ClassName("org.junit", "After"))
124294
.addStatement("$propertyName.${method.name}()")
@@ -127,49 +297,16 @@ private fun generateNonMeasurableMethod(
127297
}
128298
}
129299

130-
private fun updateAndroidDependencies(buildGradleFile: File, dependencies: List<Pair<String, String?>>) {
131-
if (buildGradleFile.exists()) {
132-
val buildGradleContent = buildGradleFile.readText()
133-
134-
if (buildGradleContent.contains("android {")) {
135-
val androidBlockStart = buildGradleContent.indexOf("android {")
136-
val androidBlockEnd = buildGradleContent.lastIndexOf("}") + 1
137-
val androidBlockContent = buildGradleContent.substring(androidBlockStart, androidBlockEnd)
138-
139-
val newDependencies = dependencies.filterNot { (dependency, version) ->
140-
val dependencyString = version?.let { """$dependency:$version""" } ?: dependency
141-
androidBlockContent.contains(dependencyString)
142-
}
143-
if (newDependencies.isNotEmpty()) {
144-
val updatedAndroidBlockContent = if (androidBlockContent.contains("dependencies {")) {
145-
val dependenciesBlockStart = androidBlockContent.indexOf("dependencies {")
146-
val dependenciesBlockEnd = androidBlockContent.indexOf("}", dependenciesBlockStart) + 1
147-
val dependenciesBlockContent =
148-
androidBlockContent.substring(dependenciesBlockStart, dependenciesBlockEnd)
149-
150-
val newDependenciesString = newDependencies.joinToString("\n ") { (dependency, version) ->
151-
version?.let { """androidTestImplementation("$dependency:$version")""" }
152-
?: """androidTestImplementation(files("$dependency"))"""
153-
}
154-
androidBlockContent.replace(
155-
dependenciesBlockContent,
156-
dependenciesBlockContent.replace(
157-
"dependencies {",
158-
"dependencies {\n $newDependenciesString"
159-
)
160-
)
161-
} else {
162-
val newDependenciesString = newDependencies.joinToString("\n ") { (dependency, version) ->
163-
version?.let { """androidTestImplementation("$dependency:$version")""" }
164-
?: """androidTestImplementation(files("$dependency"))"""
165-
}
166-
androidBlockContent.replace("{", "{\n dependencies {\n $newDependenciesString\n }\n")
167-
}
168-
169-
val updatedBuildGradleContent =
170-
buildGradleContent.replace(androidBlockContent, updatedAndroidBlockContent)
171-
buildGradleFile.writeText(updatedBuildGradleContent)
172-
}
173-
}
300+
private fun getTypeName(type: String): TypeName {
301+
return when (type) {
302+
"int" -> Int::class.asTypeName()
303+
"long" -> Long::class.asTypeName()
304+
"boolean" -> Boolean::class.asTypeName()
305+
"float" -> Float::class.asTypeName()
306+
"double" -> Double::class.asTypeName()
307+
"char" -> Char::class.asTypeName()
308+
"byte" -> Byte::class.asTypeName()
309+
"short" -> Short::class.asTypeName()
310+
else -> ClassName.bestGuess(type)
174311
}
175-
}
312+
}

0 commit comments

Comments
 (0)