Skip to content

Commit

Permalink
Merge pull request #33237 from vespa-engine/bratseth/type-inference-c…
Browse files Browse the repository at this point in the history
…leanup-take-2

Bratseth/type inference cleanup take 2
  • Loading branch information
bratseth authored Feb 3, 2025
2 parents 32def7a + b72e6f0 commit a6cd029
Show file tree
Hide file tree
Showing 73 changed files with 380 additions and 378 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,9 @@ private void exactMatchSettingsForField(SDField field) {

private static class MyProvider extends TypedTransformProvider {

private int maxTokenLength;
private final int maxTokenLength;

MyProvider(Schema schema, int maxTokenLength)
{
MyProvider(Schema schema, int maxTokenLength) {
super(ExactExpression.class, schema);
this.maxTokenLength = maxTokenLength;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ protected boolean shouldConvert(Expression expression) {
outputs.add(fieldName);
prevNames.add(fieldName);
}
if (expression.createdOutputType() != null) {
if (expression.isMutating()) {
prevNames.clear();
}
return false;
Expand All @@ -98,9 +98,8 @@ private static class MyAdapter implements FieldTypeAdapter {
@Override
public DataType getInputType(Expression exp, String fieldName) {
SDField field = schema.getDocumentField(fieldName);
if (field == null) {
throw new VerificationException(exp, "Input field '" + fieldName + "' not found.");
}
if (field == null)
throw new VerificationException(exp, "Input field '" + fieldName + "' not found");
return field.getDataType();
}

Expand All @@ -110,16 +109,14 @@ public void tryOutputType(Expression expression, String fieldName, DataType valu
DataType fieldType;
if (expression instanceof AttributeExpression) {
Attribute attribute = schema.getAttribute(fieldName);
if (attribute == null) {
throw new VerificationException(expression, "Attribute '" + fieldName + "' not found.");
}
if (attribute == null)
throw new VerificationException(expression, "Attribute '" + fieldName + "' not found");
fieldDesc = "attribute";
fieldType = attribute.getDataType();
} else if (expression instanceof IndexExpression) {
SDField field = schema.getConcreteField(fieldName);
if (field == null) {
throw new VerificationException(expression, "Index field '" + fieldName + "' not found.");
}
if (field == null)
throw new VerificationException(expression, "Index field '" + fieldName + "' not found");
fieldDesc = "index field";
fieldType = field.getDataType();
} else if (expression instanceof SummaryExpression) {
Expand All @@ -131,7 +128,7 @@ public void tryOutputType(Expression expression, String fieldName, DataType valu
fieldDesc = "document field";
fieldType = sdField.getDataType();
} else {
throw new VerificationException(expression, "Summary field '" + fieldName + "' not found.");
throw new VerificationException(expression, "Summary field '" + fieldName + "' not found");
}
} else {
fieldDesc = "summary field";
Expand All @@ -142,7 +139,7 @@ public void tryOutputType(Expression expression, String fieldName, DataType valu
}
if ( ! fieldType.isAssignableFrom(valueType))
throw new VerificationException(expression, "Can not assign " + valueType.getName() + " to " + fieldDesc +
" '" + fieldName + "' which is " + fieldType.getName() + ".");
" '" + fieldName + "' which is " + fieldType.getName());
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ protected boolean shouldConvert(Expression exp) {
}
if (exp instanceof InputExpression && ((InputExpression)exp).getFieldName().equals(field.getName())) {
mutatedBy = null;
} else if (exp.createdOutputType() != null) {
} else if (exp.isMutating()) {
mutatedBy = exp;
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,14 @@ void testConnectivity() {
"which does not exist in the query."));
}

@Test
void testWeight() {
QueryTree parsed = parse("select * from sources * where " +
"weakAnd(field1 contains ({weight: 120}'term1'), " +
" field1 contains ({weight: 70}'term2'))");
assertEquals("WEAKAND(100) field1:term1!120 field1:term2!70", parsed.toString());
}

@Test
void testAnnotatedPhrase() {
QueryTree parsed =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public Class getValueClass() {
@Override
public FieldPath buildFieldPath(String remainFieldName)
{
if (remainFieldName.length() > 0 && remainFieldName.charAt(0) == '[') {
if (!remainFieldName.isEmpty() && remainFieldName.charAt(0) == '[') {
int endPos = remainFieldName.indexOf(']');
if (endPos == -1) {
throw new IllegalArgumentException("Array subscript must be closed with ]");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,9 @@ public PrimitiveDataType getPrimitiveType() {

@Override
public boolean isValueCompatible(FieldValue value) {
if (!(value instanceof CollectionFieldValue)) {
return false;
}
CollectionFieldValue<?> cfv = (CollectionFieldValue<?>) value;
return equals(cfv.getDataType());
if (!(value instanceof CollectionFieldValue<?> collectionValue)) return false;
if (collectionValue.getDataType().getClass() != this.getClass()) return false;
return collectionValue.getDataType().getNestedType().isAssignableTo(this.getNestedType());
}

@Override
Expand Down
6 changes: 5 additions & 1 deletion document/src/main/java/com/yahoo/document/MapDataType.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.document;

import com.yahoo.document.datatypes.CollectionFieldValue;
import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.MapFieldValue;

Expand Down Expand Up @@ -41,7 +42,10 @@ public MapDataType clone() {

@Override
public boolean isValueCompatible(FieldValue value) {
return value.getDataType().equals(this);
if (!(value instanceof MapFieldValue<?,?> mapValue)) return false;
if (mapValue.getDataType().getClass() != this.getClass()) return false;
return mapValue.getDataType().getKeyType().isAssignableTo(this.getKeyType()) &&
mapValue.getDataType().getValueType().isAssignableTo(this.getValueType());
}

public DataType getKeyType() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,14 @@
*/
public abstract class FieldValueConverter {

@SuppressWarnings({ "unchecked" })
@SuppressWarnings({ "rawtypes", "unchecked" })
public final FieldValue convert(FieldValue value) {
if (value == null) {
return null;
}
if (shouldConvert(value)) {
return doConvert(value);
}
if (value instanceof Array) {
return convertArray((Array)value);
}
if (value instanceof MapFieldValue) {
return convertMap((MapFieldValue)value);
}
if (value instanceof WeightedSet) {
return convertWset((WeightedSet)value);
}
if (value instanceof StructuredFieldValue) {
return convertStructured((StructuredFieldValue)value);
}
if (value == null) return null;
if (shouldConvert(value)) return doConvert(value);
if (value instanceof Array arrayValue) return convertArray(arrayValue);
if (value instanceof MapFieldValue mapValue) return convertMap(mapValue);
if (value instanceof WeightedSet weightedSetValue) return convertWset(weightedSetValue);
if (value instanceof StructuredFieldValue structuredFieldValue) return convertStructured(structuredFieldValue);
return value;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,17 @@ public final ExpressionConverter branch() {
}

@Override
protected final boolean shouldConvert(Expression exp) {
if (transformClass.isInstance(exp)) {
protected final boolean shouldConvert(Expression expression) {
if (transformClass.isInstance(expression)) {
if (transformed) {
duplicate = true;
return true;
}
transformed = true;
return false;
}
if (exp.createdOutputType() != null) {
transformed = false;
return false;
}
if ( ! requiresTransform(exp)) {
return false;
}
if (transformed) {
return false;
}
if ( ! requiresTransform(expression)) return false;
if (transformed) return false;
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
*/
public final class BusyWaitExpression extends Expression {

@Override
public boolean isMutating() { return false; }

@Override
protected void doExecute(ExecutionContext context) {
FieldValue value = context.getCurrentValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public DataType setOutputType(DataType outputType, VerificationContext context)
throw new VerificationException(this, "Produces type " + value.getDataType().getName() + ", but type " +
outputType.getName() + " is required");
super.setOutputType(outputType, context);
return null;
return AnyDataType.instance;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ public EchoExpression(PrintStream out) {
this.out = out;
}

@Override
public boolean isMutating() { return false; }

public PrintStream getOutputStream() { return out; }

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ public ExactExpression(int maxTokenLength) {
this(OptionalInt.of(maxTokenLength));
}

@Override
public boolean isMutating() { return false; }

@Override
public DataType setInputType(DataType inputType, VerificationContext context) {
return super.setInputType(inputType, DataType.STRING, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ public abstract class Expression extends Selectable {
/** Returns whether this expression requires an input value. */
public boolean requiresInput() { return true; }

/**
* Returns whether this expression outputs a different value than what it gets as input.
* Annotating a string value does not count as modifying it.
*/
public boolean isMutating() { return true; }

/**
* Returns an expression where the children of this has been converted using the given converter.
* This default implementation returns this as it has no children.
Expand Down Expand Up @@ -152,50 +158,38 @@ public final void verify(DocumentType type) {
verify(new DocumentTypeAdapter(type));
}

public final Document verify(Document doc) {
return verify(new SimpleAdapterFactory(), doc);
public final void verify(Document doc) {
verify(new SimpleAdapterFactory(), doc);
}

public final Document verify(AdapterFactory factory, Document doc) {
return verify(factory.newDocumentAdapter(doc));
public final void verify(AdapterFactory factory, Document doc) {
verify(factory.newDocumentAdapter(doc));
}

public final Document verify(DocumentAdapter adapter) {
public final void verify(DocumentAdapter adapter) {
verify((FieldTypeAdapter)adapter);
return adapter.getFullOutput();
adapter.getFullOutput();
}

public final DocumentUpdate verify(DocumentUpdate upd) {
return verify(new SimpleAdapterFactory(), upd);
public final void verify(DocumentUpdate upd) {
verify(new SimpleAdapterFactory(), upd);
}

public final DocumentUpdate verify(AdapterFactory factory, DocumentUpdate upd) {
DocumentUpdate ret = null;
for (UpdateAdapter adapter : factory.newUpdateAdapterList(upd)) {
DocumentUpdate output = verify(adapter);
if (output == null) {
// ignore
} else if (ret != null) {
ret.addAll(output);
} else {
ret = output;
}
}
return ret;
public final void verify(AdapterFactory factory, DocumentUpdate upd) {
for (UpdateAdapter adapter : factory.newUpdateAdapterList(upd))
verify(adapter);
}

public final DocumentUpdate verify(UpdateAdapter adapter) {
public final void verify(UpdateAdapter adapter) {
verify((FieldTypeAdapter)adapter);
return adapter.getOutput();
}

public final DataType verify(FieldTypeAdapter adapter) {
return verify(new VerificationContext(adapter));
public final void verify(FieldTypeAdapter adapter) {
verify(new VerificationContext(adapter));
}

public final DataType verify(VerificationContext context) {
public final void verify(VerificationContext context) {
doVerify(context);
return context.getCurrentType();
}

protected void doVerify(VerificationContext context) {}
Expand Down Expand Up @@ -246,14 +240,6 @@ public final FieldValue execute(ExecutionContext context) {
if (input == null) return null;
}
doExecute(context);
DataType outputType = createdOutputType();
if (outputType != null) {
FieldValue output = context.getCurrentValue();
if (output != null && !outputType.isValueCompatible(output)) {
throw new IllegalStateException("Expression '" + this + "' expected " + outputType.getName() +
" output, got " + output.getDataType().getName());
}
}
return context.getCurrentValue();
}

Expand All @@ -276,17 +262,6 @@ public static Expression newInstance(ScriptParserContext context) throws ParseEx
return ScriptParser.parseExpression(context);
}

protected static boolean equals(Object lhs, Object rhs) {
if (lhs == null) {
return rhs == null;
} else {
if (rhs == null) {
return false;
}
return lhs.equals(rhs);
}
}

// Convenience For testing
public static Document execute(Expression expression, Document doc) {
expression.verify(doc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ public ForEachExpression(Expression expression) {
this.expression = Objects.requireNonNull(expression);
}

@Override
public boolean isMutating() { return expression.isMutating(); }

public Expression getInnerExpression() { return expression; }

@Override
Expand Down Expand Up @@ -167,12 +170,8 @@ protected void doExecute(ExecutionContext context) {
FieldValue input = context.getCurrentValue();
if (input instanceof Array || input instanceof WeightedSet) {
FieldValue next = new ExecutionConverter(context, expression).convert(input);
if (next == null) {
VerificationContext verificationContext = new VerificationContext(context.getFieldValue());
context.fillVariableTypes(verificationContext);
verificationContext.setCurrentType(input.getDataType()).verify(this);
next = verificationContext.getCurrentType().createFieldValue();
}
if (next == null)
next = getOutputType().createFieldValue();
context.setCurrentValue(next);
} else if (input instanceof Struct || input instanceof Map) {
context.setCurrentValue(new ExecutionConverter(context, expression).convert(input));
Expand Down Expand Up @@ -227,7 +226,7 @@ protected boolean shouldConvert(FieldValue value) {
/** Converts a map into an array by passing each entry through the expression. */
@Override
protected FieldValue convertMap(MapFieldValue<FieldValue, FieldValue> map) {
var values = new Array<>(new ArrayDataType(expression.createdOutputType()), map.size());
var values = new Array<>(new ArrayDataType(expression.getOutputType()), map.size());
for (var entry : map.entrySet())
values.add(doConvert(new MapEntryFieldValue(entry.getKey(), entry.getValue())));
return values;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ public GuardExpression(Expression innerExpression) {
shouldExecute = shouldExecute(innerExpression);
}

@Override
public boolean isMutating() { return innerExpression.isMutating(); }

@Override
public boolean requiresInput() { return innerExpression.requiresInput(); }

Expand Down
Loading

0 comments on commit a6cd029

Please sign in to comment.