Skip to content

Commit a3ee1ea

Browse files
added python driver
1 parent 3f04f1a commit a3ee1ea

File tree

6 files changed

+342
-13
lines changed

6 files changed

+342
-13
lines changed

.pylintrc

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[MASTER]
2+
init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"
3+
[MESSAGES CONTROL]
4+
disable=C0326,missing-docstring,invalid-name,trailing-whitespace
5+

lib/conversions/src/RLCToPython.cpp

+30-9
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,19 @@ class FunctionToPyFunction
250250
}
251251
};
252252

253+
static mlir::rlc::EntityType getActionTypeOfActionStatement(
254+
mlir::rlc::ActionStatement action)
255+
{
256+
auto* currentOp = action.getOperation()->getParentOp();
257+
assert(currentOp != nullptr);
258+
while (not mlir::dyn_cast<mlir::rlc::ActionFunction>(currentOp))
259+
{
260+
currentOp = currentOp->getParentOp();
261+
assert(currentOp != nullptr);
262+
}
263+
return mlir::cast<mlir::rlc::ActionFunction>(currentOp).getEntityType();
264+
}
265+
253266
static void emitActionContraints(
254267
mlir::rlc::ActionStatement action,
255268
mlir::Value emittedPythonFunction,
@@ -259,20 +272,28 @@ static void emitActionContraints(
259272
auto created = rewriter.create<mlir::rlc::python::PythonActionInfo>(
260273
action->getLoc(), emittedPythonFunction);
261274

275+
mlir::rlc::EntityType ActionType = getActionTypeOfActionStatement(action);
276+
262277
llvm::SmallVector<mlir::Location, 2> locs;
278+
llvm::SmallVector<mlir::Type, 2> types;
279+
280+
locs.push_back(action.getLoc());
281+
types.push_back(ActionType);
282+
263283
for (size_t i = 0; i < action.getResultTypes().size(); i++)
264284
locs.push_back(action.getLoc());
265285

286+
for (auto type : action.getResultTypes())
287+
types.push_back(type);
288+
266289
auto* block = rewriter.createBlock(
267-
&created.getBody(),
268-
created.getBody().begin(),
269-
action.getResultTypes(),
270-
locs);
290+
&created.getBody(), created.getBody().begin(), types, locs);
271291

272292
rewriter.setInsertionPoint(block, block->begin());
273293

274294
for (const auto& [pythonArg, rlcArg] : llvm::zip(
275-
block->getArguments(), action.getPrecondition().getArguments()))
295+
block->getArguments().drop_front(),
296+
action.getPrecondition().getArguments()))
276297
{
277298
const auto& argInfo = analysis.getBoundsOf(rlcArg);
278299
rewriter.create<mlir::rlc::python::PythonArgumentConstraint>(
@@ -292,13 +313,13 @@ static void emitActionContraints(
292313
action->getLoc(), emittedPythonFunction);
293314

294315
llvm::SmallVector<mlir::Location, 2> locs;
295-
for (size_t i = 0; i < action.getFunctionType().getNumResults(); i++)
316+
for (size_t i = 0; i < action.getFunctionType().getNumInputs(); i++)
296317
locs.push_back(action.getLoc());
297318

298319
auto* block = rewriter.createBlock(
299320
&created.getBody(),
300321
created.getBody().begin(),
301-
action.getFunctionType().getResults(),
322+
action.getFunctionType().getInputs(),
302323
locs);
303324

304325
rewriter.setInsertionPoint(block, block->begin());
@@ -345,7 +366,7 @@ class ActionDeclToTNothing
345366
rewriter,
346367
getTypeConverter(),
347368
op.getUnmangledName(),
348-
mlir::rlc::mangledName(op.getMangledName(), op.getFunctionType()),
369+
op.getMangledName(),
349370
op.getArgNames(),
350371
op.getFunctionType());
351372

@@ -414,7 +435,7 @@ void rlc::RLCToPython::runOnOperation()
414435
auto lib = builder.create<mlir::rlc::python::CTypesLoad>(
415436
getOperation().getLoc(),
416437
mlir::rlc::python::CDLLType::get(&getContext()),
417-
"./lib.so");
438+
"lib.so");
418439
mlir::ConversionTarget target(getContext());
419440

420441
mlir::TypeConverter ctypesConverter;

lib/python/src/Operations.cpp

+25-4
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,19 @@ mlir::LogicalResult mlir::rlc::python::PythonFun::emit(
7979

8080
OS << "\n";
8181

82+
OS << "signatures[" << getSymName() << "] = [";
83+
if (getFunctionType().getNumResults() != 0)
84+
writeTypeName(OS, getFunctionType().getResult(0));
85+
else
86+
OS << "None";
87+
OS << ", ";
88+
for (mlir::Type type : getFunctionType().getInputs())
89+
{
90+
writeTypeName(OS, type);
91+
OS << ", ";
92+
}
93+
OS << "]\n";
94+
8295
return mlir::success();
8396
}
8497

@@ -166,13 +179,16 @@ mlir::LogicalResult mlir::rlc::python::CTypesLoad::emit(
166179
{
167180
OS << "from ctypes import *\n";
168181
OS << "from typing import overload\n";
182+
OS << "from pathlib import Path\n";
169183
OS << "import builtins\n";
170184
OS << "from collections import defaultdict\n\n";
171-
OS << "lib = CDLL(\"" << getLibName() << "\")\n";
185+
OS << "lib = CDLL(Path(__file__).resolve().parent / \"" << getLibName()
186+
<< "\")\n";
172187
context.registerValue(getResult(), "lib");
173188

174189
OS << "actions = defaultdict(list)\n";
175190
OS << "args_info = {}\n";
191+
OS << "signatures = {}\n";
176192

177193
return mlir::success();
178194
}
@@ -207,15 +223,20 @@ mlir::LogicalResult mlir::rlc::python::PythonActionInfo::emit(
207223
for (auto& arg : getBody().getArguments())
208224
{
209225
assert(std::distance(arg.getUses().begin(), arg.getUses().end()) <= 1);
210-
for (auto& use : arg.getUses())
226+
if (not arg.getUses().empty())
211227
{
228+
auto& Use = *arg.getUses().begin();
212229
auto argConstraint =
213230
mlir::cast<mlir::rlc::python::PythonArgumentConstraint>(
214-
use.getOwner());
231+
Use.getOwner());
215232
OS << "(" << argConstraint.getMin() << ", " << argConstraint.getMax()
216233
<< ")";
217-
OS << ", ";
218234
}
235+
else
236+
{
237+
OS << "None";
238+
}
239+
OS << ", ";
219240
}
220241
OS << "]\n\n";
221242

python/loader/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .simulation import Simulation, Action, Argument, State, compile

python/loader/simulation.py

+232
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
from importlib import import_module, machinery, util
2+
import inspect
3+
from collections import defaultdict
4+
from tempfile import TemporaryDirectory
5+
from subprocess import run
6+
7+
8+
def import_file(name, file_path):
9+
loader = machinery.SourceFileLoader(name, file_path)
10+
spec = util.spec_from_loader(name, loader)
11+
mod = util.module_from_spec(spec)
12+
loader.exec_module(mod)
13+
return mod
14+
15+
16+
class Argument:
17+
def __init__(self, action, arg_index):
18+
self.action = action
19+
self.index = arg_index
20+
21+
@property
22+
def name(self):
23+
return inspect.getfullargspec(self.action.action)[0][self.index]
24+
25+
def get_min_max(self):
26+
arg_info = self.action.module.module.args_info
27+
if self.action.action not in arg_info:
28+
return None
29+
30+
return arg_info[self.action.action][self.index]
31+
32+
@property
33+
def max(self):
34+
maybe_min_max = self.get_min_max()
35+
if maybe_min_max is None:
36+
return None
37+
(_, max) = maybe_min_max
38+
return max
39+
40+
@property
41+
def min(self):
42+
maybe_min_max = self.get_min_max()
43+
if maybe_min_max is None:
44+
return None
45+
(min, _) = maybe_min_max
46+
return min
47+
48+
def dump(self):
49+
print("\targ: {}".format(self.name))
50+
if self.min is not None:
51+
print("\t\tmin: {}".format(self.min))
52+
if self.min is not None:
53+
print("\t\tmax: {}".format(self.max))
54+
55+
@property
56+
def type(self):
57+
return self.action.arg_types[self.index]
58+
59+
def parse(self, to_convert):
60+
if isinstance(to_convert, self.type):
61+
return to_convert
62+
63+
if not isinstance(to_convert, str):
64+
print(
65+
"Unable to convert argument {} to type {}",
66+
to_conver,
67+
type(to_convert).__name__(),
68+
)
69+
return None
70+
71+
if self.type == int:
72+
return int(to_convert)
73+
74+
if self.type == bool:
75+
return bool(to_convert)
76+
77+
if self.type == float:
78+
return float(to_convert)
79+
80+
print("Conversion to non primitive types is not implemented yet")
81+
assert false
82+
83+
84+
class Action:
85+
def __init__(self, action, name: str, module):
86+
self.action = action
87+
self.name = name
88+
self.module = module
89+
90+
arg_info = self.module.module.args_info[self.action]
91+
self.args = [Argument(self, i) for i in range(len(arg_info))]
92+
93+
@property
94+
def return_type(self):
95+
return self.module.module.signatures[self.action][0]
96+
97+
@property
98+
def arg_types(self):
99+
return self.module.module.signatures[self.action][1:]
100+
101+
@property
102+
def signature(self):
103+
return self.module.module.signatures[self.action]
104+
105+
def get_simulation_init(self):
106+
return self.module.action_to_simulation_init[self]
107+
108+
def is_simulation_init(self) -> bool:
109+
return self.get_simulation_init() == self
110+
111+
def dump(self):
112+
if self.return_type is not None:
113+
print(
114+
"{}({}) -> {}".format(
115+
self.name,
116+
", ".join(type.__name__ for type in self.arg_types),
117+
self.return_type.__name__,
118+
)
119+
)
120+
else:
121+
print(
122+
"{}({})".format(
123+
self.name, ", ".join(type.__name__ for type in self.arg_types)
124+
)
125+
)
126+
for arg in self.args:
127+
arg.dump()
128+
129+
def invoke(self, *args):
130+
casted_args = [
131+
formal_arg.parse(string_arg)
132+
for (formal_arg, string_arg) in zip(self.args, args)
133+
]
134+
if None in casted_args:
135+
return
136+
137+
return self.action(*casted_args)
138+
139+
class State:
140+
def __init__(self, simulation, state):
141+
self.simulation = simulation
142+
self.state = state
143+
144+
def execute(self, arguments):
145+
return self.simulation.execute([arguments[0], self.state, *arguments[1:]])
146+
147+
def dump(self):
148+
print(self.state)
149+
150+
class Simulation:
151+
def __init__(self, wrapper: str):
152+
self.wrapper_path = wrapper
153+
self.module = import_file("sim", wrapper)
154+
155+
self.actions = []
156+
for action_name in self.action_names:
157+
for overload in self.module.actions[action_name]:
158+
self.actions.append(Action(overload, action_name, self))
159+
160+
self.action_to_simulation_init = {}
161+
self.entity_type_to_simulation_init = {}
162+
self.simulation_inits = []
163+
self.names_to_overloads = defaultdict(list)
164+
165+
for action in self.actions:
166+
if action.return_type is None:
167+
continue
168+
169+
self.simulation_inits.append(action)
170+
self.entity_type_to_simulation_init[action.return_type] = action
171+
self.action_to_simulation_init[action] = action
172+
self.names_to_overloads[action.name].append(action)
173+
174+
for action in self.actions:
175+
if not action.return_type is None:
176+
continue
177+
178+
entity_type = action.arg_types[0]
179+
self.action_to_simulation_init[
180+
action
181+
] = self.entity_type_to_simulation_init[entity_type]
182+
self.names_to_overloads[action.name].append(action)
183+
184+
@property
185+
def action_names(self) -> [str]:
186+
return [name for name in self.module.actions.keys()]
187+
188+
def get_overloads_of_action(self, action_name: str) -> [Action]:
189+
return self.names_to_overloads[action_name]
190+
191+
def dump(self):
192+
for name in self.action_names:
193+
for overload in self.get_overloads_of_action(name):
194+
overload.dump()
195+
196+
def execute(self, arguments, include_simulations_init = False):
197+
assert len(arguments) != 0
198+
action_name = arguments[0]
199+
args = arguments[1:]
200+
201+
overloads = self.get_overloads_of_action(action_name)
202+
if len(overloads) == 0:
203+
print("No known action named {}".format(action_name))
204+
return
205+
206+
overload = self.resolve_overload(action_name, len(args), include_simulations_init)
207+
208+
if overload is None:
209+
print(
210+
"No known action named {} with {} arguments".format( action_name, len(args)
211+
))
212+
return
213+
214+
return overload.invoke(*args)
215+
216+
def start(self, args) -> State:
217+
return State(self, self.execute(args, True))
218+
219+
def resolve_overload(self, overload_name, args_count, include_simulations_init=False):
220+
for overload in self.get_overloads_of_action(overload_name):
221+
if not include_simulations_init and overload.is_simulation_init():
222+
continue
223+
224+
if len(overload.args) == args_count:
225+
return overload
226+
return None
227+
228+
def compile(source, rlc_compiler="rlc"):
229+
with TemporaryDirectory() as tmp_dir:
230+
assert(run([rlc_compiler, source, "--python", "-o", "{}/wrapper.py".format(tmp_dir)]).returncode == 0)
231+
assert(run([rlc_compiler, source, "--shared", "-o", "{}/lib.so".format(tmp_dir)]).returncode == 0)
232+
return Simulation(tmp_dir + "/wrapper.py")

0 commit comments

Comments
 (0)