1
+ @file:OptIn(RequiresKotlinCompilerEmbeddable ::class )
2
+
1
3
package kotlinx.benchmark.gradle
2
4
3
5
import com.squareup.kotlinpoet.*
6
+ import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
7
+ import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.measureAnnotationFQN
8
+ import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.paramAnnotationFQN
9
+ import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.setupAnnotationFQN
10
+ import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.teardownAnnotationFQN
11
+ import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.warmupAnnotationFQN
12
+ import kotlinx.benchmark.gradle.internal.generator.RequiresKotlinCompilerEmbeddable
4
13
import java.io.File
5
14
import java.util.*
6
15
@@ -10,7 +19,11 @@ internal fun generateBenchmarkSourceFiles(
10
19
) {
11
20
classDescriptors.forEach { descriptor ->
12
21
if (descriptor.visibility == Visibility .PUBLIC && ! descriptor.isAbstract) {
13
- generateDescriptorFile(descriptor, targetDir)
22
+ if (descriptor.getSpecificField(paramAnnotationFQN).isNotEmpty()) {
23
+ generateParameterizedDescriptorFile(descriptor, targetDir)
24
+ } else {
25
+ generateDescriptorFile(descriptor, targetDir)
26
+ }
14
27
}
15
28
}
16
29
}
@@ -27,6 +40,12 @@ private fun generateDescriptorFile(descriptor: ClassAnnotationsDescriptor, andro
27
40
.addImport(" androidx.benchmark" , " BenchmarkState" )
28
41
.addImport(" androidx.benchmark" , " ExperimentalBenchmarkStateApi" )
29
42
43
+ if (descriptor.hasSetupOrTeardownMethods()) {
44
+ fileSpecBuilder
45
+ .addImport(" org.junit" , " Before" )
46
+ .addImport(" org.junit" , " After" )
47
+ }
48
+
30
49
val typeSpecBuilder = TypeSpec .classBuilder(descriptorName)
31
50
.addAnnotation(
32
51
AnnotationSpec .builder(ClassName (" org.junit.runner" , " RunWith" ))
@@ -40,7 +59,122 @@ private fun generateDescriptorFile(descriptor: ClassAnnotationsDescriptor, andro
40
59
fileSpecBuilder.build().writeTo(androidTestDir)
41
60
}
42
61
43
- private fun addBenchmarkMethods (typeSpecBuilder : TypeSpec .Builder , descriptor : ClassAnnotationsDescriptor ) {
62
+ private fun generateParameterizedDescriptorFile (descriptor : ClassAnnotationsDescriptor , androidTestDir : File ) {
63
+ val descriptorName = " ${descriptor.name} _Descriptor"
64
+ val packageName = descriptor.packageName
65
+ val fileSpecBuilder = FileSpec .builder(packageName, descriptorName)
66
+ .addImport(" org.junit.runner" , " RunWith" )
67
+ .addImport(" org.junit.runners" , " Parameterized" )
68
+ .addImport(" androidx.benchmark" , " BenchmarkState" )
69
+ .addImport(" androidx.benchmark" , " ExperimentalBenchmarkStateApi" )
70
+ .addImport(" org.junit" , " Test" )
71
+
72
+ if (descriptor.hasSetupOrTeardownMethods()) {
73
+ fileSpecBuilder
74
+ .addImport(" org.junit" , " Before" )
75
+ .addImport(" org.junit" , " After" )
76
+ }
77
+
78
+ fileSpecBuilder.addAnnotation(
79
+ AnnotationSpec .builder(ClassName (" org.junit.runner" , " RunWith" ))
80
+ .addMember(" %T::class" , ClassName (" org.junit.runners" , " Parameterized" ))
81
+ .build()
82
+ )
83
+
84
+ // Generate constructor
85
+ val constructorSpec = FunSpec .constructorBuilder()
86
+ val paramFields = descriptor.getSpecificField(paramAnnotationFQN)
87
+ paramFields.forEach { param ->
88
+ constructorSpec.addParameter(param.name, getTypeName(param.type))
89
+ }
90
+
91
+ val typeSpecBuilder = TypeSpec .classBuilder(descriptorName)
92
+ .primaryConstructor(constructorSpec.build())
93
+ .addProperties(paramFields.map { param ->
94
+ PropertySpec .builder(param.name, getTypeName(param.type))
95
+ .initializer(param.name)
96
+ .addModifiers(KModifier .PRIVATE )
97
+ .build()
98
+ })
99
+
100
+ addBenchmarkMethods(typeSpecBuilder, descriptor, true )
101
+
102
+ // Generate companion object with parameters
103
+ val companionSpec = TypeSpec .companionObjectBuilder()
104
+ .addFunction(generateParametersFunction(paramFields))
105
+ .build()
106
+
107
+ typeSpecBuilder.addType(companionSpec)
108
+
109
+ fileSpecBuilder.addType(typeSpecBuilder.build())
110
+ fileSpecBuilder.build().writeTo(androidTestDir)
111
+ }
112
+
113
+ private fun generateParametersFunction (paramFields : List <FieldAnnotationsDescriptor >): FunSpec {
114
+ val dataFunctionBuilder = FunSpec .builder(" data" )
115
+ .addAnnotation(JvmStatic ::class )
116
+ .returns(
117
+ ClassName (" java.util" , " Collection" )
118
+ .parameterizedBy(
119
+ ClassName (" kotlin" , " Array" )
120
+ .parameterizedBy(ANY )
121
+ )
122
+ )
123
+
124
+ val paramNameAndIndex = paramFields.mapIndexed { index, param ->
125
+ " ${param.name} ={${index} }"
126
+ }.joinToString(" , " )
127
+
128
+ val paramAnnotationValue = " {index}: $paramNameAndIndex "
129
+
130
+ dataFunctionBuilder.addAnnotation(
131
+ AnnotationSpec .builder(ClassName (" org.junit.runners" , " Parameterized.Parameters" ))
132
+ .addMember(" name = \" %L\" " , paramAnnotationValue)
133
+ .build()
134
+ )
135
+
136
+ val paramValueLists = paramFields.map { param ->
137
+ val values = param.annotations
138
+ .find { it.name == paramAnnotationFQN }
139
+ ?.parameters?.get(" value" ) as List <* >
140
+
141
+ values.map { value ->
142
+ if (param.type == " java.lang.String" ) {
143
+ " \"\"\" $value \"\"\" "
144
+ } else {
145
+ value.toString()
146
+ }
147
+ }
148
+ }
149
+
150
+ val cartesianProduct = cartesianProduct(paramValueLists as List <List <Any >>)
151
+
152
+ val returnStatement = StringBuilder (" return listOf(\n " )
153
+ cartesianProduct.forEachIndexed { index, combination ->
154
+ val arrayContent = combination.joinToString(" , " )
155
+ returnStatement.append(" arrayOf($arrayContent )" )
156
+ if (index != cartesianProduct.size - 1 ) {
157
+ returnStatement.append(" ,\n " )
158
+ }
159
+ }
160
+ returnStatement.append(" \n )" )
161
+ dataFunctionBuilder.addStatement(returnStatement.toString())
162
+
163
+ return dataFunctionBuilder.build()
164
+ }
165
+
166
+ private fun cartesianProduct (lists : List <List <Any >>): List <List <Any >> {
167
+ if (lists.isEmpty()) return emptyList()
168
+ return lists.fold(listOf (listOf<Any >())) { acc, list ->
169
+ acc.flatMap { prefix -> list.map { value -> prefix + value } }
170
+ }
171
+ }
172
+
173
+ private fun addBenchmarkMethods (
174
+ typeSpecBuilder : TypeSpec .Builder ,
175
+ descriptor : ClassAnnotationsDescriptor ,
176
+ isParameterized : Boolean = false
177
+ ) {
44
178
val className = " ${descriptor.packageName} .${descriptor.name} "
45
179
val propertyName = descriptor.name.decapitalize(Locale .getDefault())
46
180
@@ -55,70 +189,106 @@ private fun addBenchmarkMethods(typeSpecBuilder: TypeSpec.Builder, descriptor: C
55
189
descriptor.methods
56
190
.filter { it.visibility == Visibility .PUBLIC && it.parameters.isEmpty() }
57
191
.filterNot { method ->
58
- method.annotations.any { annotation -> annotation.name == " kotlinx.benchmark.Param " }
192
+ method.annotations.any { annotation -> annotation.name == paramAnnotationFQN }
59
193
}
60
194
.forEach { method ->
61
195
when {
62
- method.annotations.any { it.name == " kotlinx.benchmark.Setup " || it.name == " kotlinx.benchmark.TearDown " } -> {
196
+ method.annotations.any { it.name == setupAnnotationFQN || it.name == teardownAnnotationFQN } -> {
63
197
generateNonMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder)
64
198
}
199
+
200
+ isParameterized && descriptor.getSpecificField(paramAnnotationFQN).isNotEmpty() -> {
201
+ generateParameterizedMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder)
202
+ }
203
+
65
204
else -> {
66
205
generateMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder)
67
206
}
68
207
}
69
208
}
70
209
}
71
210
72
- private fun generateMeasurableMethod (
211
+ private fun generateCommonMeasurableMethod (
73
212
descriptor : ClassAnnotationsDescriptor ,
74
213
method : MethodAnnotationsDescriptor ,
75
214
propertyName : String ,
76
- typeSpecBuilder : TypeSpec .Builder
215
+ typeSpecBuilder : TypeSpec .Builder ,
216
+ isParameterized : Boolean
77
217
) {
78
218
val measurementIterations = descriptor.annotations
79
- .find { it.name == " kotlinx.benchmark.Measurement " }
219
+ .find { it.name == measureAnnotationFQN }
80
220
?.parameters?.get(" iterations" ) as ? Int ? : 5
81
221
val warmupIterations = descriptor.annotations
82
- .find { it.name == " kotlinx.benchmark.Warmup " }
222
+ .find { it.name == warmupAnnotationFQN }
83
223
?.parameters?.get(" iterations" ) as ? Int ? : 5
84
224
225
+ val methodName = " ${descriptor.packageName} .${descriptor.name} .${method.name} "
226
+
85
227
val methodSpecBuilder = FunSpec .builder(" benchmark_${descriptor.name} _${method.name} " )
86
228
.addAnnotation(ClassName (" org.junit" , " Test" ))
87
229
.addAnnotation(
88
230
AnnotationSpec .builder(ClassName (" kotlin" , " OptIn" ))
89
231
.addMember(" %T::class" , ClassName (" androidx.benchmark" , " ExperimentalBenchmarkStateApi" ))
90
232
.build()
91
233
)
92
- // TODO: Add warmupCount and repeatCount parameters
234
+
235
+ if (isParameterized) {
236
+ descriptor.getSpecificField(paramAnnotationFQN).forEach { field ->
237
+ methodSpecBuilder.addStatement(" $propertyName .${field.name} = ${field.name} " )
238
+ }
239
+ }
240
+
241
+ methodSpecBuilder
93
242
.addStatement(
94
243
" val state = %T(warmupCount = $warmupIterations , repeatCount = $measurementIterations )" ,
95
244
ClassName (" androidx.benchmark" , " BenchmarkState" )
96
245
)
246
+ .addStatement(" println(\" Android: $methodName \" )" )
97
247
.beginControlFlow(" while (state.keepRunning())" )
98
248
.addStatement(" $propertyName .${method.name} ()" )
99
249
.endControlFlow()
100
250
.addStatement(" val measurementResult = state.getMeasurementTimeNs()" )
101
251
.beginControlFlow(" measurementResult.forEachIndexed { index, time ->" )
102
252
.addStatement(" println(\" Iteration \$ {index + 1}: \$ time ns\" )" )
103
253
.endControlFlow()
254
+
104
255
typeSpecBuilder.addFunction(methodSpecBuilder.build())
105
256
}
106
257
258
+ private fun generateParameterizedMeasurableMethod (
259
+ descriptor : ClassAnnotationsDescriptor ,
260
+ method : MethodAnnotationsDescriptor ,
261
+ propertyName : String ,
262
+ typeSpecBuilder : TypeSpec .Builder
263
+ ) {
264
+ generateCommonMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder, isParameterized = true )
265
+ }
266
+
267
+ private fun generateMeasurableMethod (
268
+ descriptor : ClassAnnotationsDescriptor ,
269
+ method : MethodAnnotationsDescriptor ,
270
+ propertyName : String ,
271
+ typeSpecBuilder : TypeSpec .Builder
272
+ ) {
273
+ generateCommonMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder, isParameterized = false )
274
+ }
275
+
276
+
107
277
private fun generateNonMeasurableMethod (
108
278
descriptor : ClassAnnotationsDescriptor ,
109
279
method : MethodAnnotationsDescriptor ,
110
280
propertyName : String ,
111
281
typeSpecBuilder : TypeSpec .Builder
112
282
) {
113
283
when (method.annotations.first().name) {
114
- " kotlinx.benchmark.Setup " -> {
284
+ setupAnnotationFQN -> {
115
285
val methodSpecBuilder = FunSpec .builder(" benchmark_${descriptor.name} _setUp" )
116
286
.addAnnotation(ClassName (" org.junit" , " Before" ))
117
287
.addStatement(" $propertyName .${method.name} ()" )
118
288
typeSpecBuilder.addFunction(methodSpecBuilder.build())
119
289
}
120
290
121
- " kotlinx.benchmark.TearDown " -> {
291
+ teardownAnnotationFQN -> {
122
292
val methodSpecBuilder = FunSpec .builder(" benchmark_${descriptor.name} _tearDown" )
123
293
.addAnnotation(ClassName (" org.junit" , " After" ))
124
294
.addStatement(" $propertyName .${method.name} ()" )
@@ -127,49 +297,16 @@ private fun generateNonMeasurableMethod(
127
297
}
128
298
}
129
299
130
- private fun updateAndroidDependencies (buildGradleFile : File , dependencies : List <Pair <String , String ?>>) {
131
- if (buildGradleFile.exists()) {
132
- val buildGradleContent = buildGradleFile.readText()
133
-
134
- if (buildGradleContent.contains(" android {" )) {
135
- val androidBlockStart = buildGradleContent.indexOf(" android {" )
136
- val androidBlockEnd = buildGradleContent.lastIndexOf(" }" ) + 1
137
- val androidBlockContent = buildGradleContent.substring(androidBlockStart, androidBlockEnd)
138
-
139
- val newDependencies = dependencies.filterNot { (dependency, version) ->
140
- val dependencyString = version?.let { """ $dependency :$version """ } ? : dependency
141
- androidBlockContent.contains(dependencyString)
142
- }
143
- if (newDependencies.isNotEmpty()) {
144
- val updatedAndroidBlockContent = if (androidBlockContent.contains(" dependencies {" )) {
145
- val dependenciesBlockStart = androidBlockContent.indexOf(" dependencies {" )
146
- val dependenciesBlockEnd = androidBlockContent.indexOf(" }" , dependenciesBlockStart) + 1
147
- val dependenciesBlockContent =
148
- androidBlockContent.substring(dependenciesBlockStart, dependenciesBlockEnd)
149
-
150
- val newDependenciesString = newDependencies.joinToString(" \n " ) { (dependency, version) ->
151
- version?.let { """ androidTestImplementation("$dependency :$version ")""" }
152
- ? : """ androidTestImplementation(files("$dependency "))"""
153
- }
154
- androidBlockContent.replace(
155
- dependenciesBlockContent,
156
- dependenciesBlockContent.replace(
157
- " dependencies {" ,
158
- " dependencies {\n $newDependenciesString "
159
- )
160
- )
161
- } else {
162
- val newDependenciesString = newDependencies.joinToString(" \n " ) { (dependency, version) ->
163
- version?.let { """ androidTestImplementation("$dependency :$version ")""" }
164
- ? : """ androidTestImplementation(files("$dependency "))"""
165
- }
166
- androidBlockContent.replace(" {" , " {\n dependencies {\n $newDependenciesString \n }\n " )
167
- }
168
-
169
- val updatedBuildGradleContent =
170
- buildGradleContent.replace(androidBlockContent, updatedAndroidBlockContent)
171
- buildGradleFile.writeText(updatedBuildGradleContent)
172
- }
173
- }
300
+ private fun getTypeName (type : String ): TypeName {
301
+ return when (type) {
302
+ " int" -> Int ::class .asTypeName()
303
+ " long" -> Long ::class .asTypeName()
304
+ " boolean" -> Boolean ::class .asTypeName()
305
+ " float" -> Float ::class .asTypeName()
306
+ " double" -> Double ::class .asTypeName()
307
+ " char" -> Char ::class .asTypeName()
308
+ " byte" -> Byte ::class .asTypeName()
309
+ " short" -> Short ::class .asTypeName()
310
+ else -> ClassName .bestGuess(type)
174
311
}
175
- }
312
+ }
0 commit comments