1
1
#include " rlc/conversions/RLCToPython.hpp"
2
2
3
+ #include " rlc/dialect/ActionArgumentAnalysis.hpp"
3
4
#include " rlc/dialect/Operations.hpp"
4
5
#include " rlc/dialect/Types.hpp"
5
6
#include " rlc/python/Operations.hpp"
6
7
#include " rlc/python/Types.hpp"
7
8
8
- static void registerConversions (mlir::TypeConverter& converter)
9
+ static void registerBuiltinConversions (
10
+ mlir::TypeConverter& converter, mlir::TypeConverter& ctypesConverter)
9
11
{
10
12
converter.addConversion ([](mlir::rlc::IntegerType t) -> mlir::Type {
11
- return mlir::rlc::python::PythonIntType ::get (t.getContext ());
13
+ return mlir::rlc::python::IntType ::get (t.getContext ());
12
14
});
13
15
14
16
converter.addConversion ([](mlir::rlc::BoolType t) -> mlir::Type {
15
- return mlir::rlc::python::PythonBoolType ::get (t.getContext ());
17
+ return mlir::rlc::python::BoolType ::get (t.getContext ());
16
18
});
17
19
18
20
converter.addConversion ([](mlir::rlc::FloatType t) -> mlir::Type {
19
- return mlir::rlc::python::PythonFloatType ::get (t.getContext ());
21
+ return mlir::rlc::python::FloatType ::get (t.getContext ());
20
22
});
21
23
22
24
converter.addConversion ([](mlir::rlc::VoidType t) -> mlir::Type {
23
- return mlir::rlc::python::PythonNoneType::get (t.getContext ());
25
+ return mlir::rlc::python::NoneType::get (t.getContext ());
26
+ });
27
+
28
+ converter.addConversion ([&](mlir::rlc::ArrayType t) -> mlir::Type {
29
+ auto converted = converter.convertType (t.getUnderlying ());
30
+ assert (converted);
31
+ return mlir::rlc::python::CArrayType::get (
32
+ t.getContext (), converted, t.getSize ());
33
+ });
34
+
35
+ converter.addConversion ([&](mlir::rlc::EntityType t) -> mlir::Type {
36
+ llvm::SmallVector<mlir::Type, 3 > types;
37
+ for (auto sub : t.getBody ())
38
+ {
39
+ auto converted = ctypesConverter.convertType (sub);
40
+ assert (converted);
41
+ types.push_back (converted);
42
+ }
43
+ return mlir::rlc::python::CTypeStructType::get (
44
+ t.getContext (), t.getName (), types);
45
+ });
46
+
47
+ converter.addConversion ([&](mlir::FunctionType t) -> mlir::Type {
48
+ llvm::SmallVector<mlir::Type, 3 > resTypes;
49
+ for (auto sub : t.getResults ())
50
+ {
51
+ auto converted = converter.convertType (sub);
52
+ assert (converted);
53
+ resTypes.push_back (converted);
54
+ }
55
+
56
+ llvm::SmallVector<mlir::Type, 3 > inputTypes;
57
+ for (auto sub : t.getInputs ())
58
+ {
59
+ auto converted = converter.convertType (sub);
60
+ assert (converted);
61
+ inputTypes.push_back (converted);
62
+ }
63
+ return mlir::FunctionType::get (t.getContext (), inputTypes, resTypes);
64
+ });
65
+ }
66
+
67
+ static void registerCTypesConversions (mlir::TypeConverter& converter)
68
+ {
69
+ converter.addConversion ([](mlir::rlc::IntegerType t) -> mlir::Type {
70
+ return mlir::rlc::python::CTypesIntType::get (t.getContext ());
71
+ });
72
+
73
+ converter.addConversion ([](mlir::rlc::BoolType t) -> mlir::Type {
74
+ return mlir::rlc::python::CTypesBoolType::get (t.getContext ());
75
+ });
76
+
77
+ converter.addConversion ([](mlir::rlc::FloatType t) -> mlir::Type {
78
+ return mlir::rlc::python::CTypesFloatType::get (t.getContext ());
79
+ });
80
+
81
+ converter.addConversion ([](mlir::rlc::VoidType t) -> mlir::Type {
82
+ return mlir::rlc::python::NoneType::get (t.getContext ());
24
83
});
25
84
26
85
converter.addConversion ([&](mlir::rlc::ArrayType t) -> mlir::Type {
@@ -88,7 +147,7 @@ class EntityDeclarationToClassDecl
88
147
}
89
148
};
90
149
91
- static void emitFunctionWrapper (
150
+ static mlir::rlc::python::PythonFun emitFunctionWrapper (
92
151
mlir::Location loc,
93
152
mlir::rlc::python::CTypesLoad* library,
94
153
mlir::ConversionPatternRewriter& rewriter,
@@ -99,7 +158,7 @@ static void emitFunctionWrapper(
99
158
mlir::FunctionType fType )
100
159
{
101
160
if (fName .startswith (" _" ))
102
- return ;
161
+ return nullptr ;
103
162
104
163
auto funType = converter->convertType (fType ).cast <mlir::FunctionType>();
105
164
@@ -116,17 +175,45 @@ static void emitFunctionWrapper(
116
175
auto res = rewriter.create <mlir::rlc::python::PythonAccess>(
117
176
loc, funType, *library, f.getSymName ());
118
177
119
- auto resType =
120
- funType.getNumResults () == 0
121
- ? mlir::rlc::python::PythonNoneType::get (fType .getContext ())
122
- : funType.getResult (0 );
178
+ auto resType = funType.getNumResults () == 0
179
+ ? mlir::rlc::python::NoneType::get (fType .getContext ())
180
+ : mlir::rlc::pythonBuiltinToCTypes (funType.getResult (0 ));
123
181
124
182
rewriter.create <mlir::rlc::python::AssignResultType>(loc, res, resType);
183
+ llvm::SmallVector<mlir::Value> values;
184
+
185
+ for (auto value : block->getArguments ())
186
+ {
187
+ if (mlir::rlc::isBuiltinType (value.getType ()))
188
+ {
189
+ auto res = rewriter.create <mlir::rlc::python::PythonCast>(
190
+ value.getLoc (),
191
+ mlir::rlc::pythonBuiltinToCTypes (value.getType ()),
192
+ value);
193
+ values.push_back (res);
194
+ }
195
+ else
196
+ {
197
+ values.push_back (value);
198
+ }
199
+ }
125
200
126
201
auto result = rewriter.create <mlir::rlc::python::PythonCall>(
127
- loc, mlir::TypeRange ({ resType }), res, block->getArguments ());
202
+ loc, mlir::TypeRange ({ resType }), res, values);
203
+
204
+ mlir::Value toReturn = result.getResult (0 );
128
205
129
- rewriter.create <mlir::rlc::python::PythonReturn>(loc, result.getResults ());
206
+ if (resType.isa <mlir::rlc::python::CTypesFloatType>())
207
+ {
208
+ toReturn = rewriter.create <mlir::rlc::python::PythonAccess>(
209
+ result.getLoc (),
210
+ mlir::rlc::pythonCTypesToBuiltin (resType),
211
+ toReturn,
212
+ " value" );
213
+ }
214
+
215
+ rewriter.create <mlir::rlc::python::PythonReturn>(loc, toReturn);
216
+ return f;
130
217
}
131
218
132
219
class FunctionToPyFunction
@@ -163,6 +250,70 @@ class FunctionToPyFunction
163
250
}
164
251
};
165
252
253
+ static void emitActionContraints (
254
+ mlir::rlc::ActionStatement action,
255
+ mlir::Value emittedPythonFunction,
256
+ mlir::ConversionPatternRewriter& rewriter)
257
+ {
258
+ mlir::rlc::ActionArgumentAnalysis analysis (action);
259
+ auto created = rewriter.create <mlir::rlc::python::PythonActionInfo>(
260
+ action->getLoc (), emittedPythonFunction);
261
+
262
+ llvm::SmallVector<mlir::Location, 2 > locs;
263
+ for (size_t i = 0 ; i < action.getResultTypes ().size (); i++)
264
+ locs.push_back (action.getLoc ());
265
+
266
+ auto * block = rewriter.createBlock (
267
+ &created.getBody (),
268
+ created.getBody ().begin (),
269
+ action.getResultTypes (),
270
+ locs);
271
+
272
+ rewriter.setInsertionPoint (block, block->begin ());
273
+
274
+ for (const auto & [pythonArg, rlcArg] : llvm::zip (
275
+ block->getArguments (), action.getPrecondition ().getArguments ()))
276
+ {
277
+ const auto & argInfo = analysis.getBoundsOf (rlcArg);
278
+ rewriter.create <mlir::rlc::python::PythonArgumentConstraint>(
279
+ action.getLoc (), pythonArg, argInfo.getMin (), argInfo.getMax ());
280
+ }
281
+
282
+ rewriter.setInsertionPointAfter (created);
283
+ }
284
+
285
+ static void emitActionContraints (
286
+ mlir::rlc::ActionFunction action,
287
+ mlir::Value emittedPythonFunction,
288
+ mlir::ConversionPatternRewriter& rewriter)
289
+ {
290
+ mlir::rlc::ActionArgumentAnalysis analysis (action);
291
+ auto created = rewriter.create <mlir::rlc::python::PythonActionInfo>(
292
+ action->getLoc (), emittedPythonFunction);
293
+
294
+ llvm::SmallVector<mlir::Location, 2 > locs;
295
+ for (size_t i = 0 ; i < action.getFunctionType ().getNumResults (); i++)
296
+ locs.push_back (action.getLoc ());
297
+
298
+ auto * block = rewriter.createBlock (
299
+ &created.getBody (),
300
+ created.getBody ().begin (),
301
+ action.getFunctionType ().getResults (),
302
+ locs);
303
+
304
+ rewriter.setInsertionPoint (block, block->begin ());
305
+
306
+ for (const auto & [pythonArg, rlcArg] : llvm::zip (
307
+ block->getArguments (), action.getBody ().front ().getArguments ()))
308
+ {
309
+ const auto & argInfo = analysis.getBoundsOf (rlcArg);
310
+ rewriter.create <mlir::rlc::python::PythonArgumentConstraint>(
311
+ action.getLoc (), pythonArg, argInfo.getMin (), argInfo.getMax ());
312
+ }
313
+
314
+ rewriter.setInsertionPointAfter (created);
315
+ }
316
+
166
317
class ActionDeclToTNothing
167
318
: public mlir::OpConversionPattern<mlir::rlc::ActionFunction>
168
319
{
@@ -188,7 +339,7 @@ class ActionDeclToTNothing
188
339
OpAdaptor adaptor,
189
340
mlir::ConversionPatternRewriter& rewriter) const final
190
341
{
191
- emitFunctionWrapper (
342
+ auto f = emitFunctionWrapper (
192
343
op.getLoc (),
193
344
library,
194
345
rewriter,
@@ -197,7 +348,17 @@ class ActionDeclToTNothing
197
348
mlir::rlc::mangledName (op.getSymName (), op.getFunctionType ()),
198
349
op.getArgNames (),
199
350
op.getFunctionType ());
351
+
352
+ if (f == nullptr )
353
+ {
354
+ rewriter.eraseOp (op);
355
+ return mlir::success ();
356
+ }
357
+
200
358
rewriter.setInsertionPointAfter (op);
359
+
360
+ emitActionContraints (op, f, rewriter);
361
+
201
362
for (const auto & [type, action] :
202
363
llvm::zip (op.getActions (), builder->actionStatementsOfAction (op)))
203
364
{
@@ -209,7 +370,7 @@ class ActionDeclToTNothing
209
370
arrayAttr.push_back (attr.cast <mlir::StringAttr>());
210
371
211
372
auto castedType = type.getType ().cast <mlir::FunctionType>();
212
- emitFunctionWrapper (
373
+ auto f = emitFunctionWrapper (
213
374
casted.getLoc (),
214
375
library,
215
376
rewriter,
@@ -218,14 +379,18 @@ class ActionDeclToTNothing
218
379
mlir::rlc::mangledName (casted.getName (), castedType),
219
380
rewriter.getStrArrayAttr (arrayAttr),
220
381
castedType);
221
- rewriter.setInsertionPointAfter (op);
382
+ rewriter.setInsertionPointAfter (f);
383
+ if (f == nullptr )
384
+ continue ;
385
+
386
+ emitActionContraints (casted, f, rewriter);
222
387
223
388
auto validityType = mlir::FunctionType::get (
224
389
getContext (),
225
390
castedType.getInputs (),
226
391
mlir::rlc::BoolType::get (getContext ()));
227
392
auto name = (" can_" + casted.getName ()).str ();
228
- emitFunctionWrapper (
393
+ auto preconditionCheckFunction = emitFunctionWrapper (
229
394
casted.getLoc (),
230
395
library,
231
396
rewriter,
@@ -234,7 +399,7 @@ class ActionDeclToTNothing
234
399
mlir::rlc::mangledName (name, validityType),
235
400
rewriter.getStrArrayAttr (arrayAttr),
236
401
validityType);
237
- rewriter.setInsertionPointAfter (op );
402
+ rewriter.setInsertionPointAfter (preconditionCheckFunction );
238
403
}
239
404
rewriter.eraseOp (op);
240
405
return mlir::success ();
@@ -251,14 +416,18 @@ void rlc::RLCToPython::runOnOperation()
251
416
mlir::rlc::python::CDLLType::get (&getContext ()),
252
417
" ./lib.so" );
253
418
mlir::ConversionTarget target (getContext ());
419
+
420
+ mlir::TypeConverter ctypesConverter;
421
+ registerCTypesConversions (ctypesConverter);
422
+
254
423
mlir::TypeConverter converter;
255
- registerConversions (converter);
424
+ registerBuiltinConversions (converter, ctypesConverter );
256
425
257
426
target.addLegalDialect <mlir::rlc::python::RLCPython>();
258
427
target.addIllegalDialect <mlir::rlc::RLCDialect>();
259
428
260
429
mlir::RewritePatternSet patterns (&getContext ());
261
- patterns.add <EntityDeclarationToClassDecl>(converter , &getContext ());
430
+ patterns.add <EntityDeclarationToClassDecl>(ctypesConverter , &getContext ());
262
431
patterns.add <ActionDeclToTNothing>(
263
432
&lib, &rlcBuilder, converter, &getContext ());
264
433
patterns.add <FunctionToPyFunction>(&lib, converter, &getContext ());
0 commit comments