Skip to content

Commit 3f42d44

Browse files
committed
Data type helpers
Signed-off-by: Ryan Nett <[email protected]>
1 parent c1ab496 commit 3f42d44

File tree

1 file changed

+54
-0
lines changed
  • tensorflow-core-kotlin/tensorflow-core-kotlin-api/src/main/kotlin/org/tensorflow/op

1 file changed

+54
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
17+
package org.tensorflow.op
18+
19+
import org.tensorflow.internal.types.registry.TensorTypeRegistry
20+
import org.tensorflow.proto.framework.DataType
21+
import org.tensorflow.types.family.TType
22+
import kotlin.reflect.KClass
23+
24+
/**
25+
* Converts a tensor type class to a [DataType] attribute.
26+
*
27+
* @return data type
28+
* @see Operands.toDataType
29+
*/
30+
public fun <T: TType> Class<T>.dataType(): DataType = Operands.toDataType(this)
31+
32+
/**
33+
* Converts a tensor type class to a [DataType] attribute.
34+
*
35+
* @return data type
36+
* @see Operands.toDataType
37+
*/
38+
public fun <T: TType> KClass<T>.dataType(): DataType = Operands.toDataType(this.java)
39+
40+
/**
41+
* Converts a tensor type class to a [DataType] attribute.
42+
*
43+
* @return data type
44+
* @see Operands.toDataType
45+
*/
46+
public inline fun <reified T: TType> dataType(): DataType = T::class.dataType()
47+
48+
/**
49+
* Converts a [DataType] attribute to a tensor type class.
50+
*
51+
* @return the tensor type class
52+
* @see TensorTypeRegistry.find
53+
*/
54+
public fun <T: TType> DataType.tType(): Class<T> = TensorTypeRegistry.find<T>(this).type()

0 commit comments

Comments
 (0)