@@ -46,7 +46,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
46
46
private static final Set <Class <?>> ALLOWED_START_TYPES = new HashSet <Class <?>>(
47
47
Arrays .<Class <?>> asList (AggregationExpression .class , String .class , Field .class , Document .class ));
48
48
49
- private final String from ;
49
+ private final Object from ;
50
50
private final List <Object > startWith ;
51
51
private final Field connectFrom ;
52
52
private final Field connectTo ;
@@ -55,7 +55,7 @@ public class GraphLookupOperation implements InheritsFieldsAggregationOperation
55
55
private final @ Nullable Field depthField ;
56
56
private final @ Nullable CriteriaDefinition restrictSearchWithMatch ;
57
57
58
- private GraphLookupOperation (String from , List <Object > startWith , Field connectFrom , Field connectTo , Field as ,
58
+ private GraphLookupOperation (Object from , List <Object > startWith , Field connectFrom , Field connectTo , Field as ,
59
59
@ Nullable Long maxDepth , @ Nullable Field depthField , @ Nullable CriteriaDefinition restrictSearchWithMatch ) {
60
60
61
61
this .from = from ;
@@ -82,7 +82,7 @@ public Document toDocument(AggregationOperationContext context) {
82
82
83
83
Document graphLookup = new Document ();
84
84
85
- graphLookup .put ("from" , from );
85
+ graphLookup .put ("from" , getCollectionName ( context ) );
86
86
87
87
List <Object > mappedStartWith = new ArrayList <>(startWith .size ());
88
88
@@ -99,7 +99,7 @@ public Document toDocument(AggregationOperationContext context) {
99
99
100
100
graphLookup .put ("startWith" , mappedStartWith .size () == 1 ? mappedStartWith .iterator ().next () : mappedStartWith );
101
101
102
- graphLookup .put ("connectFromField" , connectFrom . getTarget ( ));
102
+ graphLookup .put ("connectFromField" , getForeignFieldName ( context ));
103
103
graphLookup .put ("connectToField" , connectTo .getTarget ());
104
104
graphLookup .put ("as" , as .getName ());
105
105
@@ -118,6 +118,16 @@ public Document toDocument(AggregationOperationContext context) {
118
118
return new Document (getOperator (), graphLookup );
119
119
}
120
120
121
+ String getCollectionName (AggregationOperationContext context ) {
122
+ return from instanceof Class <?> type ? context .getCollection (type ) : from .toString ();
123
+ }
124
+
125
+ String getForeignFieldName (AggregationOperationContext context ) {
126
+
127
+ return from instanceof Class <?> type ? context .getMappedFieldName (type , connectFrom .getTarget ())
128
+ : connectFrom .getTarget ();
129
+ }
130
+
121
131
@ Override
122
132
public String getOperator () {
123
133
return "$graphLookup" ;
@@ -128,7 +138,7 @@ public ExposedFields getFields() {
128
138
129
139
List <ExposedField > fields = new ArrayList <>(2 );
130
140
fields .add (new ExposedField (as , true ));
131
- if (depthField != null ) {
141
+ if (depthField != null ) {
132
142
fields .add (new ExposedField (depthField , true ));
133
143
}
134
144
return ExposedFields .from (fields .toArray (new ExposedField [0 ]));
@@ -146,6 +156,17 @@ public interface FromBuilder {
146
156
* @return never {@literal null}.
147
157
*/
148
158
StartWithBuilder from (String collectionName );
159
+
160
+ /**
161
+ * Use the given type to determine name of the foreign collection and map
162
+ * {@link ConnectFromBuilder#connectFrom(String)} against it to consider eventually present
163
+ * {@link org.springframework.data.mongodb.core.mapping.Field} annotations.
164
+ *
165
+ * @param type must not be {@literal null}.
166
+ * @return never {@literal null}.
167
+ * @since 4.2
168
+ */
169
+ StartWithBuilder from (Class <?> type );
149
170
}
150
171
151
172
/**
@@ -218,7 +239,7 @@ public interface ConnectToBuilder {
218
239
static final class GraphLookupOperationFromBuilder
219
240
implements FromBuilder , StartWithBuilder , ConnectFromBuilder , ConnectToBuilder {
220
241
221
- private @ Nullable String from ;
242
+ private @ Nullable Object from ;
222
243
private @ Nullable List <? extends Object > startWith ;
223
244
private @ Nullable String connectFrom ;
224
245
@@ -231,6 +252,14 @@ public StartWithBuilder from(String collectionName) {
231
252
return this ;
232
253
}
233
254
255
+ @ Override
256
+ public StartWithBuilder from (Class <?> type ) {
257
+
258
+ Assert .notNull (type , "Type must not be null" );
259
+ this .from = type ;
260
+ return this ;
261
+ }
262
+
234
263
@ Override
235
264
public ConnectFromBuilder startWith (String ... fieldReferences ) {
236
265
@@ -321,15 +350,15 @@ public GraphLookupOperationBuilder connectTo(String fieldName) {
321
350
*/
322
351
public static final class GraphLookupOperationBuilder {
323
352
324
- private final String from ;
353
+ private final Object from ;
325
354
private final List <Object > startWith ;
326
355
private final Field connectFrom ;
327
356
private final Field connectTo ;
328
357
private @ Nullable Long maxDepth ;
329
358
private @ Nullable Field depthField ;
330
359
private @ Nullable CriteriaDefinition restrictSearchWithMatch ;
331
360
332
- protected GraphLookupOperationBuilder (String from , List <? extends Object > startWith , String connectFrom ,
361
+ protected GraphLookupOperationBuilder (Object from , List <? extends Object > startWith , String connectFrom ,
333
362
String connectTo ) {
334
363
335
364
this .from = from ;
0 commit comments