Skip to content

Commit f0078d0

Browse files
python wrapper
1 parent 13b9668 commit f0078d0

File tree

13 files changed

+705
-147
lines changed

13 files changed

+705
-147
lines changed

lib/conversions/src/RLCToPython.cpp

Lines changed: 189 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,85 @@
11
#include "rlc/conversions/RLCToPython.hpp"
22

3+
#include "rlc/dialect/ActionArgumentAnalysis.hpp"
34
#include "rlc/dialect/Operations.hpp"
45
#include "rlc/dialect/Types.hpp"
56
#include "rlc/python/Operations.hpp"
67
#include "rlc/python/Types.hpp"
78

8-
static void registerConversions(mlir::TypeConverter& converter)
9+
static void registerBuiltinConversions(
10+
mlir::TypeConverter& converter, mlir::TypeConverter& ctypesConverter)
911
{
1012
converter.addConversion([](mlir::rlc::IntegerType t) -> mlir::Type {
11-
return mlir::rlc::python::PythonIntType::get(t.getContext());
13+
return mlir::rlc::python::IntType::get(t.getContext());
1214
});
1315

1416
converter.addConversion([](mlir::rlc::BoolType t) -> mlir::Type {
15-
return mlir::rlc::python::PythonBoolType::get(t.getContext());
17+
return mlir::rlc::python::BoolType::get(t.getContext());
1618
});
1719

1820
converter.addConversion([](mlir::rlc::FloatType t) -> mlir::Type {
19-
return mlir::rlc::python::PythonFloatType::get(t.getContext());
21+
return mlir::rlc::python::FloatType::get(t.getContext());
2022
});
2123

2224
converter.addConversion([](mlir::rlc::VoidType t) -> mlir::Type {
23-
return mlir::rlc::python::PythonNoneType::get(t.getContext());
25+
return mlir::rlc::python::NoneType::get(t.getContext());
26+
});
27+
28+
converter.addConversion([&](mlir::rlc::ArrayType t) -> mlir::Type {
29+
auto converted = converter.convertType(t.getUnderlying());
30+
assert(converted);
31+
return mlir::rlc::python::CArrayType::get(
32+
t.getContext(), converted, t.getSize());
33+
});
34+
35+
converter.addConversion([&](mlir::rlc::EntityType t) -> mlir::Type {
36+
llvm::SmallVector<mlir::Type, 3> types;
37+
for (auto sub : t.getBody())
38+
{
39+
auto converted = ctypesConverter.convertType(sub);
40+
assert(converted);
41+
types.push_back(converted);
42+
}
43+
return mlir::rlc::python::CTypeStructType::get(
44+
t.getContext(), t.getName(), types);
45+
});
46+
47+
converter.addConversion([&](mlir::FunctionType t) -> mlir::Type {
48+
llvm::SmallVector<mlir::Type, 3> resTypes;
49+
for (auto sub : t.getResults())
50+
{
51+
auto converted = converter.convertType(sub);
52+
assert(converted);
53+
resTypes.push_back(converted);
54+
}
55+
56+
llvm::SmallVector<mlir::Type, 3> inputTypes;
57+
for (auto sub : t.getInputs())
58+
{
59+
auto converted = converter.convertType(sub);
60+
assert(converted);
61+
inputTypes.push_back(converted);
62+
}
63+
return mlir::FunctionType::get(t.getContext(), inputTypes, resTypes);
64+
});
65+
}
66+
67+
static void registerCTypesConversions(mlir::TypeConverter& converter)
68+
{
69+
converter.addConversion([](mlir::rlc::IntegerType t) -> mlir::Type {
70+
return mlir::rlc::python::CTypesIntType::get(t.getContext());
71+
});
72+
73+
converter.addConversion([](mlir::rlc::BoolType t) -> mlir::Type {
74+
return mlir::rlc::python::CTypesBoolType::get(t.getContext());
75+
});
76+
77+
converter.addConversion([](mlir::rlc::FloatType t) -> mlir::Type {
78+
return mlir::rlc::python::CTypesFloatType::get(t.getContext());
79+
});
80+
81+
converter.addConversion([](mlir::rlc::VoidType t) -> mlir::Type {
82+
return mlir::rlc::python::NoneType::get(t.getContext());
2483
});
2584

2685
converter.addConversion([&](mlir::rlc::ArrayType t) -> mlir::Type {
@@ -88,7 +147,7 @@ class EntityDeclarationToClassDecl
88147
}
89148
};
90149

91-
static void emitFunctionWrapper(
150+
static mlir::rlc::python::PythonFun emitFunctionWrapper(
92151
mlir::Location loc,
93152
mlir::rlc::python::CTypesLoad* library,
94153
mlir::ConversionPatternRewriter& rewriter,
@@ -99,7 +158,7 @@ static void emitFunctionWrapper(
99158
mlir::FunctionType fType)
100159
{
101160
if (fName.startswith("_"))
102-
return;
161+
return nullptr;
103162

104163
auto funType = converter->convertType(fType).cast<mlir::FunctionType>();
105164

@@ -116,17 +175,45 @@ static void emitFunctionWrapper(
116175
auto res = rewriter.create<mlir::rlc::python::PythonAccess>(
117176
loc, funType, *library, f.getSymName());
118177

119-
auto resType =
120-
funType.getNumResults() == 0
121-
? mlir::rlc::python::PythonNoneType::get(fType.getContext())
122-
: funType.getResult(0);
178+
auto resType = funType.getNumResults() == 0
179+
? mlir::rlc::python::NoneType::get(fType.getContext())
180+
: mlir::rlc::pythonBuiltinToCTypes(funType.getResult(0));
123181

124182
rewriter.create<mlir::rlc::python::AssignResultType>(loc, res, resType);
183+
llvm::SmallVector<mlir::Value> values;
184+
185+
for (auto value : block->getArguments())
186+
{
187+
if (mlir::rlc::isBuiltinType(value.getType()))
188+
{
189+
auto res = rewriter.create<mlir::rlc::python::PythonCast>(
190+
value.getLoc(),
191+
mlir::rlc::pythonBuiltinToCTypes(value.getType()),
192+
value);
193+
values.push_back(res);
194+
}
195+
else
196+
{
197+
values.push_back(value);
198+
}
199+
}
125200

126201
auto result = rewriter.create<mlir::rlc::python::PythonCall>(
127-
loc, mlir::TypeRange({ resType }), res, block->getArguments());
202+
loc, mlir::TypeRange({ resType }), res, values);
203+
204+
mlir::Value toReturn = result.getResult(0);
128205

129-
rewriter.create<mlir::rlc::python::PythonReturn>(loc, result.getResults());
206+
if (resType.isa<mlir::rlc::python::CTypesFloatType>())
207+
{
208+
toReturn = rewriter.create<mlir::rlc::python::PythonAccess>(
209+
result.getLoc(),
210+
mlir::rlc::pythonCTypesToBuiltin(resType),
211+
toReturn,
212+
"value");
213+
}
214+
215+
rewriter.create<mlir::rlc::python::PythonReturn>(loc, toReturn);
216+
return f;
130217
}
131218

132219
class FunctionToPyFunction
@@ -163,6 +250,70 @@ class FunctionToPyFunction
163250
}
164251
};
165252

253+
static void emitActionContraints(
254+
mlir::rlc::ActionStatement action,
255+
mlir::Value emittedPythonFunction,
256+
mlir::ConversionPatternRewriter& rewriter)
257+
{
258+
mlir::rlc::ActionArgumentAnalysis analysis(action);
259+
auto created = rewriter.create<mlir::rlc::python::PythonActionInfo>(
260+
action->getLoc(), emittedPythonFunction);
261+
262+
llvm::SmallVector<mlir::Location, 2> locs;
263+
for (size_t i = 0; i < action.getResultTypes().size(); i++)
264+
locs.push_back(action.getLoc());
265+
266+
auto* block = rewriter.createBlock(
267+
&created.getBody(),
268+
created.getBody().begin(),
269+
action.getResultTypes(),
270+
locs);
271+
272+
rewriter.setInsertionPoint(block, block->begin());
273+
274+
for (const auto& [pythonArg, rlcArg] : llvm::zip(
275+
block->getArguments(), action.getPrecondition().getArguments()))
276+
{
277+
const auto& argInfo = analysis.getBoundsOf(rlcArg);
278+
rewriter.create<mlir::rlc::python::PythonArgumentConstraint>(
279+
action.getLoc(), pythonArg, argInfo.getMin(), argInfo.getMax());
280+
}
281+
282+
rewriter.setInsertionPointAfter(created);
283+
}
284+
285+
static void emitActionContraints(
286+
mlir::rlc::ActionFunction action,
287+
mlir::Value emittedPythonFunction,
288+
mlir::ConversionPatternRewriter& rewriter)
289+
{
290+
mlir::rlc::ActionArgumentAnalysis analysis(action);
291+
auto created = rewriter.create<mlir::rlc::python::PythonActionInfo>(
292+
action->getLoc(), emittedPythonFunction);
293+
294+
llvm::SmallVector<mlir::Location, 2> locs;
295+
for (size_t i = 0; i < action.getFunctionType().getNumResults(); i++)
296+
locs.push_back(action.getLoc());
297+
298+
auto* block = rewriter.createBlock(
299+
&created.getBody(),
300+
created.getBody().begin(),
301+
action.getFunctionType().getResults(),
302+
locs);
303+
304+
rewriter.setInsertionPoint(block, block->begin());
305+
306+
for (const auto& [pythonArg, rlcArg] : llvm::zip(
307+
block->getArguments(), action.getBody().front().getArguments()))
308+
{
309+
const auto& argInfo = analysis.getBoundsOf(rlcArg);
310+
rewriter.create<mlir::rlc::python::PythonArgumentConstraint>(
311+
action.getLoc(), pythonArg, argInfo.getMin(), argInfo.getMax());
312+
}
313+
314+
rewriter.setInsertionPointAfter(created);
315+
}
316+
166317
class ActionDeclToTNothing
167318
: public mlir::OpConversionPattern<mlir::rlc::ActionFunction>
168319
{
@@ -188,7 +339,7 @@ class ActionDeclToTNothing
188339
OpAdaptor adaptor,
189340
mlir::ConversionPatternRewriter& rewriter) const final
190341
{
191-
emitFunctionWrapper(
342+
auto f = emitFunctionWrapper(
192343
op.getLoc(),
193344
library,
194345
rewriter,
@@ -197,7 +348,17 @@ class ActionDeclToTNothing
197348
mlir::rlc::mangledName(op.getSymName(), op.getFunctionType()),
198349
op.getArgNames(),
199350
op.getFunctionType());
351+
352+
if (f == nullptr)
353+
{
354+
rewriter.eraseOp(op);
355+
return mlir::success();
356+
}
357+
200358
rewriter.setInsertionPointAfter(op);
359+
360+
emitActionContraints(op, f, rewriter);
361+
201362
for (const auto& [type, action] :
202363
llvm::zip(op.getActions(), builder->actionStatementsOfAction(op)))
203364
{
@@ -209,7 +370,7 @@ class ActionDeclToTNothing
209370
arrayAttr.push_back(attr.cast<mlir::StringAttr>());
210371

211372
auto castedType = type.getType().cast<mlir::FunctionType>();
212-
emitFunctionWrapper(
373+
auto f = emitFunctionWrapper(
213374
casted.getLoc(),
214375
library,
215376
rewriter,
@@ -218,14 +379,18 @@ class ActionDeclToTNothing
218379
mlir::rlc::mangledName(casted.getName(), castedType),
219380
rewriter.getStrArrayAttr(arrayAttr),
220381
castedType);
221-
rewriter.setInsertionPointAfter(op);
382+
rewriter.setInsertionPointAfter(f);
383+
if (f == nullptr)
384+
continue;
385+
386+
emitActionContraints(casted, f, rewriter);
222387

223388
auto validityType = mlir::FunctionType::get(
224389
getContext(),
225390
castedType.getInputs(),
226391
mlir::rlc::BoolType::get(getContext()));
227392
auto name = ("can_" + casted.getName()).str();
228-
emitFunctionWrapper(
393+
auto preconditionCheckFunction = emitFunctionWrapper(
229394
casted.getLoc(),
230395
library,
231396
rewriter,
@@ -234,7 +399,7 @@ class ActionDeclToTNothing
234399
mlir::rlc::mangledName(name, validityType),
235400
rewriter.getStrArrayAttr(arrayAttr),
236401
validityType);
237-
rewriter.setInsertionPointAfter(op);
402+
rewriter.setInsertionPointAfter(preconditionCheckFunction);
238403
}
239404
rewriter.eraseOp(op);
240405
return mlir::success();
@@ -251,14 +416,18 @@ void rlc::RLCToPython::runOnOperation()
251416
mlir::rlc::python::CDLLType::get(&getContext()),
252417
"./lib.so");
253418
mlir::ConversionTarget target(getContext());
419+
420+
mlir::TypeConverter ctypesConverter;
421+
registerCTypesConversions(ctypesConverter);
422+
254423
mlir::TypeConverter converter;
255-
registerConversions(converter);
424+
registerBuiltinConversions(converter, ctypesConverter);
256425

257426
target.addLegalDialect<mlir::rlc::python::RLCPython>();
258427
target.addIllegalDialect<mlir::rlc::RLCDialect>();
259428

260429
mlir::RewritePatternSet patterns(&getContext());
261-
patterns.add<EntityDeclarationToClassDecl>(converter, &getContext());
430+
patterns.add<EntityDeclarationToClassDecl>(ctypesConverter, &getContext());
262431
patterns.add<ActionDeclToTNothing>(
263432
&lib, &rlcBuilder, converter, &getContext());
264433
patterns.add<FunctionToPyFunction>(&lib, converter, &getContext());

lib/dialect/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
rlcAddLibrary(dialect SHARED src/Dialect.cpp src/Types.cpp src/Operations.cpp src/Conversion.cpp src/EmitMain.cpp src/TypeCheck.cpp src/Interfaces.cpp src/SymbolTable.cpp)
1+
rlcAddLibrary(dialect SHARED src/Dialect.cpp src/Types.cpp src/Operations.cpp src/Conversion.cpp src/EmitMain.cpp src/TypeCheck.cpp src/Interfaces.cpp src/SymbolTable.cpp src/ActionArgumentAnalysis.cpp)
22
target_link_libraries(dialect PUBLIC MLIRSupport MLIRDialect MLIRLLVMDialect MLIRLLVMIRTransforms MLIRControlFlowDialect)
33

44
set(tblgen ${LLVM_BINARY_DIR}/bin/mlir-tblgen)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#pragma once
2+
3+
#include "mlir/IR/Operation.h"
4+
#include "rlc/dialect/Operations.hpp"
5+
6+
namespace mlir::rlc
7+
{
8+
9+
class IntegerArgumentConstraints
10+
{
11+
public:
12+
using IntType = int64_t;
13+
14+
void maybeNewMax(IntType newMax) { max = std::min(max, newMax); }
15+
void maybeNewMin(IntType newMin) { min = std::max(min, newMin); }
16+
17+
[[nodiscard]] IntType getMin() const { return min; }
18+
19+
[[nodiscard]] IntType getMax() const { return max; }
20+
21+
private:
22+
IntType min = std::numeric_limits<IntType>::min();
23+
IntType max = std::numeric_limits<IntType>::max();
24+
};
25+
26+
class ActionArgumentAnalysis
27+
{
28+
public:
29+
explicit ActionArgumentAnalysis(mlir::Operation* op);
30+
31+
const IntegerArgumentConstraints& getBoundsOf(mlir::Value arg)
32+
{
33+
return contraints[arg];
34+
}
35+
36+
private:
37+
void handle(mlir::rlc::ActionStatement statement);
38+
void handle(mlir::rlc::ActionFunction statement);
39+
void handleArgument(mlir::Value argument, mlir::Operation* contraint);
40+
void handleBinaryOp(
41+
mlir::Operation* op,
42+
mlir::Value argument,
43+
IntegerArgumentConstraints::IntType constant);
44+
void handleBinaryOp(
45+
mlir::Operation* op,
46+
IntegerArgumentConstraints::IntType constant,
47+
mlir::Value argument);
48+
mlir::DenseMap<mlir::Value, IntegerArgumentConstraints> contraints;
49+
};
50+
} // namespace mlir::rlc

0 commit comments

Comments
 (0)