Skip to content

Commit 07e4a04

Browse files
committed
Add macro arguments
1 parent 63345b5 commit 07e4a04

File tree

4 files changed

+196
-24
lines changed

4 files changed

+196
-24
lines changed

example/http.nmfu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ out int content_size = 0;
1313
// "subroutine"-esque functionality. duplicate states are optimized out.
1414
//
1515
// Effectively just substitudes into the ast at the appropriate point.
16-
macro waitend {
16+
macro waitend() {
1717
wait "\r\n\r\n";
1818
}
1919

2020
// optional space as in the spec
21-
macro ows {
21+
macro ows() {
2222
optional {" ";}
2323
}
2424

example/test/macro.ok.nmfu

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
out int int1;
2+
out int int2;
3+
4+
hook hi;
5+
6+
macro test1(out out2, match words) {
7+
out2 = 0;
8+
foreach {
9+
words;
10+
} do {
11+
out2 = [out2 * 10 + ($last - '0')];
12+
}
13+
}
14+
15+
macro dummy() {
16+
"hi";
17+
}
18+
19+
macro test2(hook hooky, macro submacro) {
20+
submacro();
21+
"ooflio";
22+
hooky();
23+
}
24+
25+
macro foreach_wrapper(expr action, loop target) {
26+
foreach {
27+
/\w+/;
28+
" ";
29+
} do {
30+
int1 = action;
31+
}
32+
break target;
33+
}
34+
35+
parser {
36+
test1(int2, /\d+/);
37+
test2(hi, dummy);
38+
39+
loop outer {
40+
loop inner {
41+
foreach_wrapper([int1 + 5], outer);
42+
}
43+
}
44+
}

example/ttc_rdf.nmfu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ out str[32] date;
55

66
hook on_advisory;
77

8-
macro ows {
8+
macro ows() {
99
/\s*/;
1010
}
1111

nmfu.py

Lines changed: 149 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,17 @@ def Lark(*args, **kwargs):
6969
7070
hook_decl: "hook" IDENTIFIER ";"
7171
72-
macro_decl: "macro" IDENTIFIER "{" statement* "}"
72+
macro_decl: "macro" IDENTIFIER macro_args "{" statement* "}"
73+
74+
macro_args: "(" macro_arg ("," macro_arg)* ")"
75+
| "(" ")" -> macro_arg_empty
76+
77+
macro_arg: "macro" IDENTIFIER -> macro_macro_arg
78+
| "out" IDENTIFIER -> macro_out_arg
79+
| "match" IDENTIFIER -> macro_match_expr_arg
80+
| "expr" IDENTIFIER -> macro_int_expr_arg
81+
| "hook" IDENTIFIER -> macro_hook_arg
82+
| "loop" IDENTIFIER -> macro_breaktgt_arg
7383
7484
parser_decl: "parser" "{" statement+ "}"
7585
@@ -1006,8 +1016,8 @@ def __str__(self):
10061016
return self._get_message()
10071017

10081018
class IllegalASTStateError(NMFUError):
1009-
def __init__(self, msg, source):
1010-
super().__init__([source])
1019+
def __init__(self, msg, *source):
1020+
super().__init__([*source])
10111021
self.source = source
10121022
self.msg = msg
10131023

@@ -2825,14 +2835,52 @@ def convert(self, current_error_handlers):
28252835
sub_dfa.append_after(self.next.convert(current_error_handlers), current_error_handlers[ErrorReasons.NO_MATCH], chain_actions=self.after_actions)
28262836

28272837
return sub_dfa
2838+
2839+
class MacroArgumentKind(enum.Enum):
2840+
MACRO = 0
2841+
OUT = 1
2842+
MATCH = 2
2843+
INTEXPR = 3
2844+
HOOK = 4
2845+
LOOP = 5
2846+
2847+
EXPR = 10
2848+
2849+
class MacroArgument:
2850+
def __init__(self, name: str, kind: MacroArgumentKind):
2851+
self.name = name
2852+
self.kind = kind
2853+
2854+
def get_lookup_type(self):
2855+
if self.kind in (MacroArgumentKind.MATCH, MacroArgumentKind.INTEXPR):
2856+
return MacroArgumentKind.EXPR
2857+
else:
2858+
return self.kind
28282859

28292860
class Macro:
2830-
def __init__(self, name_token: lark.Token, parse_tree: lark.Tree):
2861+
def __init__(self, name_token: lark.Token, parse_tree: lark.Tree, arguments: List[MacroArgument]):
28312862
self.name = name_token.value
28322863
self.parse_tree = parse_tree
2864+
self.arguments = arguments
28332865
ProgramData.imbue(self, DebugTag.SOURCE_LINE, name_token.line)
28342866
ProgramData.imbue(self, DebugTag.NAME, "macro " + self.name)
28352867

2868+
def bind_arguments_for(self, input_trees: List[lark.Tree]):
2869+
bound_arguments = {}
2870+
for argspec, value in zip(self.arguments, input_trees):
2871+
allowed_types = {
2872+
MacroArgumentKind.MACRO: ("identifier_const",),
2873+
MacroArgumentKind.OUT: ("identifier_const",),
2874+
MacroArgumentKind.HOOK: ("identifier_const",),
2875+
MacroArgumentKind.LOOP: ("identifier_const",),
2876+
MacroArgumentKind.MATCH: ("regex", "end_expr", "concat_expr", "string_const", "string_case_const"),
2877+
MacroArgumentKind.INTEXPR: ("string_const", "sum_expr", "bool_const", "number_const", "char_const", "identifier_const")
2878+
}[argspec.kind]
2879+
if value.data not in allowed_types:
2880+
raise IllegalParseTree("Invalid argument type for argument " + argspec.name, value)
2881+
bound_arguments[(argspec.get_lookup_type(), argspec.name)] = value
2882+
return bound_arguments
2883+
28362884
# =========
28372885
# PARSE CTX
28382886
# =========
@@ -2850,6 +2898,8 @@ def __init__(self, parse_tree: lark.Tree):
28502898
self.exception_handlers = defaultdict(lambda: self.generic_fail_state) # normal ErrorReason -> State
28512899
self.break_handlers = {} # "string name" -> Action
28522900
self.innermost_break_handler = None # just an Action
2901+
2902+
self.bound_argument_stack: List[Dict[Tuple[MacroArgumentKind, str], lark.Tree]] = []
28532903

28542904
def parse(self):
28552905
# Parse state_object_spec
@@ -2860,7 +2910,7 @@ def parse(self):
28602910
self.state_object_spec[out_obj.name] = out_obj
28612911
# Parse macros
28622912
for macro in self._parse_tree.find_data("macro_decl"):
2863-
macro_obj = Macro(macro.children[0], macro.children[1:])
2913+
macro_obj = Macro(macro.children[0], macro.children[2:], self._parse_macro_arguments(macro.children[1]))
28642914
if macro_obj.name in self.macros:
28652915
raise DuplicateDefinitionError("macro", macro_obj, macro_obj.name)
28662916
self.macros[macro_obj.name] = macro_obj
@@ -2876,6 +2926,32 @@ def parse(self):
28762926
if isinstance(self.ast, ActionSourceNode):
28772927
self.start_actions, self.ast = self.ast.adopt_actions_from()
28782928

2929+
def _lookup_bound_argument(self, name, context: MacroArgumentKind):
2930+
try:
2931+
return self.bound_argument_stack[-1][(context, name.value)]
2932+
except (KeyError, IndexError):
2933+
raise UndefinedReferenceError("bound argument", name)
2934+
2935+
def _parse_macro_arguments(self, args: lark.Tree):
2936+
if args.data == "macro_arg_empty":
2937+
return []
2938+
defined = set()
2939+
parsed_args = []
2940+
for i in args.children:
2941+
name = i.children[0].value
2942+
if name in defined:
2943+
raise DuplicateDefinitionError("macro argument", i, name)
2944+
kind = {
2945+
"macro_macro_arg": MacroArgumentKind.MACRO,
2946+
"macro_out_arg": MacroArgumentKind.OUT,
2947+
"macro_match_expr_arg": MacroArgumentKind.MATCH,
2948+
"macro_int_expr_arg": MacroArgumentKind.INTEXPR,
2949+
"macro_hook_arg": MacroArgumentKind.HOOK,
2950+
"macro_breaktgt_arg": MacroArgumentKind.LOOP,
2951+
}[i.data]
2952+
parsed_args.append(MacroArgument(name, kind))
2953+
return parsed_args
2954+
28792955
def _convert_string(self, escaped_string: str):
28802956
"""
28812957
Parse the string `escaped_string`, which is the direct token from lark (still with quotes and escapes)
@@ -2944,7 +3020,17 @@ def _parse_math_expr(self, expr: lark.Tree):
29443020
return ProgramData.imbue(ProgramData.imbue(LiteralIntegerExpr(ord(expr.children[0].value[1])), DebugTag.SOURCE_LINE, expr.line), DebugTag.SOURCE_COLUMN, expr.column)
29453021
elif expr.data == "math_var":
29463022
try:
2947-
return ProgramData.imbue(ProgramData.imbue(OutIntegerExpr(self.state_object_spec[expr.children[0].value]), DebugTag.SOURCE_LINE, expr.line), DebugTag.SOURCE_COLUMN, expr.column)
3023+
out_spec = self.state_object_spec[expr.children[0].value]
3024+
except KeyError:
3025+
try:
3026+
target_name = self._lookup_bound_argument(expr.children[0], MacroArgumentKind.OUT).children[0]
3027+
out_spec = self.state_object_spec[target_name.value]
3028+
except UndefinedReferenceError:
3029+
raise UndefinedReferenceError("output", expr.children[0])
3030+
except KeyError:
3031+
raise UndefinedReferenceError("output", target_name)
3032+
try:
3033+
return ProgramData.imbue(ProgramData.imbue(OutIntegerExpr(out_spec), DebugTag.SOURCE_LINE, expr.line), DebugTag.SOURCE_COLUMN, expr.column)
29483034
except KeyError:
29493035
raise UndefinedReferenceError("output", expr.children[0])
29503036
elif expr.data == "builtin_math_var":
@@ -2980,6 +3066,11 @@ def _parse_integer_expr(self, expr: lark.Tree, into_storage: OutputStorage=None)
29803066
ProgramData.imbue(val, DebugTag.SOURCE_COLUMN, expr.children[0].column)
29813067
return val
29823068
elif expr.data == "identifier_const":
3069+
try:
3070+
expr = self._lookup_bound_argument(expr.children[0], MacroArgumentKind.EXPR)
3071+
return self._parse_integer_expr(expr, into_storage=into_storage)
3072+
except UndefinedReferenceError:
3073+
pass
29833074
if into_storage is None:
29843075
raise IllegalParseTree("Undefined enumeration value, no into_storage", expr)
29853076

@@ -3041,6 +3132,8 @@ def _parse_match_expr(self, expr: lark.Tree) -> Match:
30413132
return match
30423133
elif expr.data == "concat_expr":
30433134
return ConcatMatch(list(self._parse_match_expr(x) for x in expr.children))
3135+
elif expr.data == "identifier_const":
3136+
return self._parse_match_expr(self._lookup_bound_argument(expr.children[0], MacroArgumentKind.EXPR))
30443137
else:
30453138
raise IllegalParseTree("Invalid expression in match expression context", expr)
30463139

@@ -3052,23 +3145,35 @@ def _parse_assign_stmt(self, stmt: lark.Tree, is_append) -> Node:
30523145
# Resolve the target
30533146
targeted = stmt.children[0].value
30543147
if targeted not in self.state_object_spec:
3055-
raise UndefinedReferenceError("output", stmt.children[0])
3148+
try:
3149+
targeted = self._lookup_bound_argument(stmt.children[0], MacroArgumentKind.OUT).children[0].value
3150+
except UndefinedReferenceError:
3151+
raise UndefinedReferenceError("output", stmt.children[0])
30563152
targeted = self.state_object_spec[targeted]
30573153

30583154
if targeted.type == OutputStorageType.STR and is_append:
3155+
# Handle arguments
3156+
if stmt.children[1].data == "identifier_const":
3157+
sub_expr = self._lookup_bound_argument(stmt.children[1].children[0], MacroArgumentKind.EXPR)
3158+
else:
3159+
sub_expr = stmt.children[1]
30593160
# Check if this is a math expression (only valid append type other than match)
3060-
if stmt.children[1].data in ["math_num", "math_var", "builtin_math_var", "sum_expr", "mul_expr"]:
3161+
if sub_expr.data in ["math_num", "math_var", "builtin_math_var", "sum_expr", "mul_expr"]:
30613162
# Create an AppendCharTo action
3062-
return ActionNode(AppendCharTo(self.exception_handlers[ErrorReasons.OUT_OF_SPACE], self._parse_math_expr(stmt.children[1]), targeted))
3163+
return ActionNode(AppendCharTo(self.exception_handlers[ErrorReasons.OUT_OF_SPACE], self._parse_math_expr(sub_expr), targeted))
30633164
# Create an append expression
3064-
match_node = MatchNode(self._parse_match_expr(stmt.children[1]))
3165+
match_node = MatchNode(self._parse_match_expr(sub_expr))
30653166
match_node.match.attach(AppendTo(self.exception_handlers[ErrorReasons.OUT_OF_SPACE], targeted))
30663167
return match_node
30673168
elif targeted.type == OutputStorageType.STR:
3068-
if stmt.children[1].data != "string_const":
3069-
raise IllegalParseTree("String assignment only supports string constants, did you mean +=?", stmt.children[1])
3169+
# Handle arguments
3170+
if stmt.children[1].data == "identifier_const":
3171+
sub_expr = self._lookup_bound_argument(stmt.children[1].children[0], MacroArgumentKind.EXPR)
30703172
else:
3071-
return ActionNode(SetToStr(self._convert_string(stmt.children[1].children[0].value), targeted))
3173+
sub_expr = stmt.children[1]
3174+
if sub_expr.data != "string_const":
3175+
raise IllegalParseTree("String assignment only supports string constants, did you mean +=?", sub_expr)
3176+
return ActionNode(SetToStr(self._convert_string(sub_expr.children[0].value), targeted))
30723177
elif not is_append:
30733178
return ActionNode(SetTo(self._parse_integer_expr(stmt.children[1], targeted), targeted))
30743179

@@ -3092,6 +3197,16 @@ def _parse_case_clause(self, clause: lark.Tree):
30923197

30933198
return frozenset(result_set), target_dfa
30943199

3200+
def _parse_macro_call(self, lark_node_for_error: lark.Tree, name: str, arguments: List[lark.Tree]):
3201+
if len(arguments) != len(self.macros[name].arguments):
3202+
raise IllegalParseTree("Incorrect number of arguments", lark_node_for_error)
3203+
self.bound_argument_stack.append(
3204+
self.macros[name].bind_arguments_for(arguments)
3205+
)
3206+
node = ProgramData.imbue(self._parse_stmt_seq(self.macros[name].parse_tree), DebugTag.PARENT, self.macros[name])
3207+
del self.bound_argument_stack[-1]
3208+
return node
3209+
30953210
def _parse_stmt(self, stmt: lark.Tree) -> Node:
30963211
"""
30973212
Parse a statement into a node
@@ -3103,12 +3218,21 @@ def _parse_stmt(self, stmt: lark.Tree) -> Node:
31033218
elif stmt.data == "call_stmt":
31043219
name = stmt.children[0].value
31053220
try:
3106-
return ProgramData.imbue(self._parse_stmt_seq(self.macros[name].parse_tree), DebugTag.PARENT, self.macros[name])
3221+
return self._parse_macro_call(stmt, name, stmt.children[1:])
31073222
except KeyError:
3108-
if name in self.hooks:
3109-
return ActionNode(ProgramData.imbue(ProgramData.imbue(CallHook(name), DebugTag.SOURCE_LINE, stmt.line), DebugTag.SOURCE_COLUMN, stmt.column))
3110-
else:
3111-
raise UndefinedReferenceError("macro", stmt.children[0])
3223+
try:
3224+
name2 = self._lookup_bound_argument(stmt.children[0], MacroArgumentKind.MACRO).children[0].value
3225+
return self._parse_macro_call(stmt, name2, stmt.children[1:])
3226+
except UndefinedReferenceError:
3227+
if name not in self.hooks:
3228+
try:
3229+
name = self._lookup_bound_argument(stmt.children[0], MacroArgumentKind.HOOK).children[0].value
3230+
except UndefinedReferenceError:
3231+
pass
3232+
if name in self.hooks:
3233+
return ActionNode(ProgramData.imbue(ProgramData.imbue(CallHook(name), DebugTag.SOURCE_LINE, stmt.line), DebugTag.SOURCE_COLUMN, stmt.column))
3234+
else:
3235+
raise UndefinedReferenceError("macro or hook", stmt.children[0])
31123236
elif stmt.data == "finish_stmt":
31133237
act = FinishAction()
31143238
ProgramData.imbue(act, DebugTag.SOURCE_LINE, stmt.line)
@@ -3118,9 +3242,13 @@ def _parse_stmt(self, stmt: lark.Tree) -> Node:
31183242
if stmt.children:
31193243
target_name = stmt.children[0].value
31203244
if target_name not in self.break_handlers:
3121-
raise UndefinedReferenceError("loop", stmt.children[0])
3122-
else:
3123-
act = self.break_handlers[target_name]
3245+
try:
3246+
target_name = self._lookup_bound_argument(stmt.children[0], MacroArgumentKind.LOOP).children[0].value
3247+
if target_name not in self.break_handlers:
3248+
raise KeyError()
3249+
except (KeyError, UndefinedReferenceError):
3250+
raise UndefinedReferenceError("loop", stmt.children[0])
3251+
act = self.break_handlers[target_name]
31243252
else:
31253253
act = self.innermost_break_handler
31263254
if act is None:

0 commit comments

Comments
 (0)