Skip to content

Commit c13c352

Browse files
committed
[Refactor](UDF) Refactor the java udf code to reduce the unless code
1 parent e042e9f commit c13c352

File tree

5 files changed

+92
-113
lines changed

5 files changed

+92
-113
lines changed

fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/JavaUdfDataType.java

+56-41
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
import java.math.BigInteger;
3030
import java.net.InetAddress;
3131
import java.util.ArrayList;
32-
import java.util.HashSet;
32+
import java.util.HashMap;
33+
import java.util.Map;
3334
import java.util.Set;
3435

3536
// Data types that are supported as return or argument types in Java UDFs.
@@ -63,32 +64,36 @@ public class JavaUdfDataType {
6364
public static final JavaUdfDataType MAP_TYPE = new JavaUdfDataType("MAP_TYPE", TPrimitiveType.MAP, 0);
6465
public static final JavaUdfDataType STRUCT_TYPE = new JavaUdfDataType("STRUCT_TYPE", TPrimitiveType.STRUCT, 0);
6566

66-
private static Set<JavaUdfDataType> JavaUdfDataTypeSet = new HashSet<>();
67+
private static final Map<TPrimitiveType, JavaUdfDataType> javaUdfDataTypeMap = new HashMap<>();
68+
69+
public static void addJavaUdfDataType(JavaUdfDataType dataType) {
70+
javaUdfDataTypeMap.put(dataType.getPrimitiveType(), dataType);
71+
}
6772

6873
static {
69-
JavaUdfDataTypeSet.add(INVALID_TYPE);
70-
JavaUdfDataTypeSet.add(BOOLEAN);
71-
JavaUdfDataTypeSet.add(TINYINT);
72-
JavaUdfDataTypeSet.add(SMALLINT);
73-
JavaUdfDataTypeSet.add(INT);
74-
JavaUdfDataTypeSet.add(BIGINT);
75-
JavaUdfDataTypeSet.add(FLOAT);
76-
JavaUdfDataTypeSet.add(DOUBLE);
77-
JavaUdfDataTypeSet.add(STRING);
78-
JavaUdfDataTypeSet.add(DATE);
79-
JavaUdfDataTypeSet.add(DATETIME);
80-
JavaUdfDataTypeSet.add(LARGEINT);
81-
JavaUdfDataTypeSet.add(DECIMALV2);
82-
JavaUdfDataTypeSet.add(DATEV2);
83-
JavaUdfDataTypeSet.add(DATETIMEV2);
84-
JavaUdfDataTypeSet.add(DECIMAL32);
85-
JavaUdfDataTypeSet.add(DECIMAL64);
86-
JavaUdfDataTypeSet.add(DECIMAL128);
87-
JavaUdfDataTypeSet.add(ARRAY_TYPE);
88-
JavaUdfDataTypeSet.add(MAP_TYPE);
89-
JavaUdfDataTypeSet.add(STRUCT_TYPE);
90-
JavaUdfDataTypeSet.add(IPV4);
91-
JavaUdfDataTypeSet.add(IPV6);
74+
addJavaUdfDataType(INVALID_TYPE);
75+
addJavaUdfDataType(BOOLEAN);
76+
addJavaUdfDataType(TINYINT);
77+
addJavaUdfDataType(SMALLINT);
78+
addJavaUdfDataType(INT);
79+
addJavaUdfDataType(BIGINT);
80+
addJavaUdfDataType(FLOAT);
81+
addJavaUdfDataType(DOUBLE);
82+
addJavaUdfDataType(STRING);
83+
addJavaUdfDataType(DATE);
84+
addJavaUdfDataType(DATETIME);
85+
addJavaUdfDataType(LARGEINT);
86+
addJavaUdfDataType(DECIMALV2);
87+
addJavaUdfDataType(DATEV2);
88+
addJavaUdfDataType(DATETIMEV2);
89+
addJavaUdfDataType(DECIMAL32);
90+
addJavaUdfDataType(DECIMAL64);
91+
addJavaUdfDataType(DECIMAL128);
92+
addJavaUdfDataType(ARRAY_TYPE);
93+
addJavaUdfDataType(MAP_TYPE);
94+
addJavaUdfDataType(STRUCT_TYPE);
95+
addJavaUdfDataType(IPV4);
96+
addJavaUdfDataType(IPV6);
9297
}
9398

9499
private final String description;
@@ -117,17 +122,33 @@ public JavaUdfDataType(JavaUdfDataType other) {
117122

118123
@Override
119124
public String toString() {
120-
return description;
121-
}
125+
StringBuilder res = new StringBuilder();
126+
res.append(description);
127+
// TODO: the item/key/value type should be dispose in child class
128+
if (getItemType() != null) {
129+
res.append(" item: ").append(getItemType().toString()).append(" sql: ")
130+
.append(getItemType().toSql());
131+
}
132+
if (getKeyType() != null) {
133+
res.append(" key: ").append(getKeyType().toString()).append(" sql: ")
134+
.append(getKeyType().toSql());
135+
}
136+
if (getValueType() != null) {
137+
res.append(" value: ").append(getValueType().toString()).append(" sql: ")
138+
.append(getValueType().toSql());
139+
}
122140

123-
public TPrimitiveType getPrimitiveType() {
124-
return thriftType;
141+
return res.toString();
125142
}
126143

127144
public int getLen() {
128145
return len;
129146
}
130147

148+
public TPrimitiveType getPrimitiveType() {
149+
return thriftType;
150+
}
151+
131152
public static Set<JavaUdfDataType> getCandidateTypes(Class<?> c) {
132153
if (c == boolean.class || c == Boolean.class) {
133154
return Sets.newHashSet(JavaUdfDataType.BOOLEAN);
@@ -169,19 +190,14 @@ public static Set<JavaUdfDataType> getCandidateTypes(Class<?> c) {
169190
}
170191

171192
public static boolean isSupported(Type t) {
172-
for (JavaUdfDataType javaType : JavaUdfDataTypeSet) {
173-
if (javaType == JavaUdfDataType.INVALID_TYPE) {
174-
continue;
175-
}
176-
if (javaType.getPrimitiveType() == t.getPrimitiveType().toThrift()) {
177-
return true;
178-
}
179-
}
180-
if (t.getPrimitiveType().toThrift() == TPrimitiveType.VARCHAR
181-
|| t.getPrimitiveType().toThrift() == TPrimitiveType.CHAR) {
193+
TPrimitiveType thriftType = t.getPrimitiveType().toThrift();
194+
// varchar and char are supported in java udf, type is String
195+
if (thriftType == TPrimitiveType.VARCHAR
196+
|| thriftType == TPrimitiveType.CHAR) {
182197
return true;
183198
}
184-
return false;
199+
return !thriftType.equals(TPrimitiveType.INVALID_TYPE)
200+
&& javaUdfDataTypeMap.containsKey(thriftType);
185201
}
186202

187203
public int getPrecision() {
@@ -209,7 +225,6 @@ public void setItemType(Type type) throws InternalException {
209225
this.itemType = type;
210226
} else {
211227
if (!this.itemType.matchesType(type)) {
212-
LOG.info("set error");
213228
throw new InternalException("udf type not matches origin type :" + this.itemType.toSql()
214229
+ " set type :" + type.toSql());
215230
}

fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/UdfClassCache.java

+2
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,6 @@ public class UdfClassCache {
3838
public JavaUdfDataType retType;
3939
// the class type of the arguments in evaluate() method
4040
public Class[] argClass;
41+
// The return type class of evaluate() method
42+
public Class retClass;
4143
}

fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java

+26-24
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.doris.common.exception.UdfRuntimeException;
2424
import org.apache.doris.common.jni.utils.JavaUdfDataType;
2525
import org.apache.doris.common.jni.vec.ColumnValueConverter;
26+
import org.apache.doris.common.jni.vec.VectorTable;
2627
import org.apache.doris.thrift.TFunction;
2728
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
2829
import org.apache.doris.thrift.TPrimitiveType;
@@ -39,6 +40,7 @@
3940
import java.time.LocalDateTime;
4041
import java.util.ArrayList;
4142
import java.util.HashMap;
43+
import java.util.Map;
4244
import java.util.Map.Entry;
4345

4446
public abstract class BaseExecutor {
@@ -69,7 +71,9 @@ public abstract class BaseExecutor {
6971
protected JavaUdfDataType retType;
7072
protected Class[] argClass;
7173
protected MethodAccess methodAccess;
74+
protected VectorTable outputTable = null;
7275
protected TFunction fn;
76+
protected Class retClass;
7377

7478
/**
7579
* Create a UdfExecutor, using parameters from a serialized thrift object. Used
@@ -102,32 +106,8 @@ public String debugString() {
102106
StringBuilder res = new StringBuilder();
103107
for (JavaUdfDataType type : argTypes) {
104108
res.append(type.toString());
105-
if (type.getItemType() != null) {
106-
res.append(" item: ").append(type.getItemType().toString()).append(" sql: ")
107-
.append(type.getItemType().toSql());
108-
}
109-
if (type.getKeyType() != null) {
110-
res.append(" key: ").append(type.getKeyType().toString()).append(" sql: ")
111-
.append(type.getKeyType().toSql());
112-
}
113-
if (type.getValueType() != null) {
114-
res.append(" key: ").append(type.getValueType().toString()).append(" sql: ")
115-
.append(type.getValueType().toSql());
116-
}
117109
}
118110
res.append(" return type: ").append(retType.toString());
119-
if (retType.getItemType() != null) {
120-
res.append(" item: ").append(retType.getItemType().toString()).append(" sql: ")
121-
.append(retType.getItemType().toSql());
122-
}
123-
if (retType.getKeyType() != null) {
124-
res.append(" key: ").append(retType.getKeyType().toString()).append(" sql: ")
125-
.append(retType.getKeyType().toSql());
126-
}
127-
if (retType.getValueType() != null) {
128-
res.append(" key: ").append(retType.getValueType().toString()).append(" sql: ")
129-
.append(retType.getValueType().toSql());
130-
}
131111
res.append(" methodAccess: ").append(methodAccess.toString());
132112
res.append(" fn.toString(): ").append(fn.toString());
133113
return res.toString();
@@ -150,6 +130,10 @@ public void close() {
150130
}
151131
}
152132
}
133+
// Close the output table if it exists.
134+
if (outputTable != null) {
135+
outputTable.close();
136+
}
153137
// We are now un-usable (because the class loader has been
154138
// closed), so null out method_ and classLoader_.
155139
classLoader = null;
@@ -330,4 +314,22 @@ protected ColumnValueConverter getOutputConverter(JavaUdfDataType returnType, Cl
330314
}
331315
return null;
332316
}
317+
318+
// Add unified converter methods
319+
protected Map<Integer, ColumnValueConverter> getInputConverters(int numColumns, boolean isUdaf) {
320+
Map<Integer, ColumnValueConverter> converters = new HashMap<>();
321+
for (int j = 0; j < numColumns; ++j) {
322+
// For UDAF, we need to offset by 1 since first arg is state
323+
int argIndex = isUdaf ? j + 1 : j;
324+
ColumnValueConverter converter = getInputConverter(argTypes[j].getPrimitiveType(), argClass[argIndex]);
325+
if (converter != null) {
326+
converters.put(j, converter);
327+
}
328+
}
329+
return converters;
330+
}
331+
332+
protected ColumnValueConverter getOutputConverter() {
333+
return getOutputConverter(retType, retClass);
334+
}
333335
}

fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java

+1-22
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,7 @@ public class UdafExecutor extends BaseExecutor {
5353

5454
private HashMap<String, Method> allMethods;
5555
private HashMap<Long, Object> stateObjMap;
56-
private Class retClass;
5756
private int addIndex;
58-
private VectorTable outputTable = null;
59-
6057
/**
6158
* Constructor to create an object.
6259
*/
@@ -69,35 +66,17 @@ public UdafExecutor(byte[] thriftParams) throws Exception {
6966
*/
7067
@Override
7168
public void close() {
72-
if (outputTable != null) {
73-
outputTable.close();
74-
}
7569
super.close();
7670
allMethods = null;
7771
stateObjMap = null;
7872
}
7973

80-
private Map<Integer, ColumnValueConverter> getInputConverters(int numColumns) {
81-
Map<Integer, ColumnValueConverter> converters = new HashMap<>();
82-
for (int j = 0; j < numColumns; ++j) {
83-
ColumnValueConverter converter = getInputConverter(argTypes[j].getPrimitiveType(), argClass[j + 1]);
84-
if (converter != null) {
85-
converters.put(j, converter);
86-
}
87-
}
88-
return converters;
89-
}
90-
91-
private ColumnValueConverter getOutputConverter() {
92-
return getOutputConverter(retType, retClass);
93-
}
94-
9574
public void addBatch(boolean isSinglePlace, int rowStart, int rowEnd, long placeAddr, int offset,
9675
Map<String, String> inputParams) throws UdfRuntimeException {
9776
try {
9877
VectorTable inputTable = VectorTable.createReadableTable(inputParams);
9978
Object[][] inputs = inputTable.getMaterializedData(rowStart, rowEnd,
100-
getInputConverters(inputTable.getNumColumns()));
79+
getInputConverters(inputTable.getNumColumns(), true));
10180
if (isSinglePlace) {
10281
addBatchSingle(rowStart, rowEnd, placeAddr, inputs);
10382
} else {

fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java

+7-26
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,6 @@ public class UdfExecutor extends BaseExecutor {
5353

5454
private int evaluateIndex;
5555

56-
private VectorTable outputTable = null;
57-
5856
private boolean isStaticLoad = false;
5957

6058
/**
@@ -70,33 +68,16 @@ public UdfExecutor(byte[] thriftParams) throws Exception {
7068
*/
7169
@Override
7270
public void close() {
73-
// inputTable is released by c++, only release outputTable
74-
if (outputTable != null) {
75-
outputTable.close();
76-
}
7771
// We are now un-usable (because the class loader has been
7872
// closed), so null out method_ and classLoader_.
7973
method = null;
8074
if (!isStaticLoad) {
8175
super.close();
76+
} else if (outputTable != null) {
77+
outputTable.close();
8278
}
8379
}
8480

85-
private Map<Integer, ColumnValueConverter> getInputConverters(int numColumns) {
86-
Map<Integer, ColumnValueConverter> converters = new HashMap<>();
87-
for (int j = 0; j < numColumns; ++j) {
88-
ColumnValueConverter converter = getInputConverter(argTypes[j].getPrimitiveType(), argClass[j]);
89-
if (converter != null) {
90-
converters.put(j, converter);
91-
}
92-
}
93-
return converters;
94-
}
95-
96-
private ColumnValueConverter getOutputConverter() {
97-
return getOutputConverter(retType, method.getReturnType());
98-
}
99-
10081
public long evaluate(Map<String, String> inputParams, Map<String, String> outputParams) throws UdfRuntimeException {
10182
try {
10283
VectorTable inputTable = VectorTable.createReadableTable(inputParams);
@@ -112,7 +93,7 @@ public long evaluate(Map<String, String> inputParams, Map<String, String> output
11293
Object[] result = outputTable.getColumnType(0).isPrimitive()
11394
? outputTable.getColumn(0).newObjectContainerArray(numRows)
11495
: (Object[]) Array.newInstance(method.getReturnType(), numRows);
115-
Object[][] inputs = inputTable.getMaterializedData(getInputConverters(numColumns));
96+
Object[][] inputs = inputTable.getMaterializedData(getInputConverters(numColumns, false));
11697
Object[] parameters = new Object[numColumns];
11798
for (int i = 0; i < numRows; ++i) {
11899
for (int j = 0; j < numColumns; ++j) {
@@ -216,16 +197,15 @@ private void checkAndCacheUdfClass(String className, UdfClassCache cache, Type f
216197
} else {
217198
cache.retType = returnType.second;
218199
}
219-
Type keyType = cache.retType.getKeyType();
220-
Type valueType = cache.retType.getValueType();
221200
Pair<Boolean, JavaUdfDataType[]> inputType = UdfUtils.setArgTypes(parameterTypes, cache.argClass, false);
222201
if (!inputType.first) {
223202
continue;
224203
} else {
225204
cache.argTypes = inputType.second;
226205
}
227-
cache.retType.setKeyType(keyType);
228-
cache.retType.setValueType(valueType);
206+
if (cache.method != null) {
207+
cache.retClass = cache.method.getReturnType();
208+
}
229209
return;
230210
}
231211
StringBuilder sb = new StringBuilder();
@@ -269,6 +249,7 @@ protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type fun
269249
evaluateIndex = cache.evaluateIndex;
270250
retType = cache.retType;
271251
argTypes = cache.argTypes;
252+
retClass = cache.retClass;
272253
} catch (MalformedURLException e) {
273254
throw new UdfRuntimeException("Unable to load jar.", e);
274255
} catch (SecurityException e) {

0 commit comments

Comments
 (0)