Skip to content

Commit

Permalink
Merge pull request #81 from tinymanorg/exit-statement-checks
Browse files Browse the repository at this point in the history
Exit statement checks
  • Loading branch information
fergalwalsh authored Apr 21, 2023
2 parents 830337d + 8be8f57 commit de87e0a
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 11 deletions.
47 changes: 46 additions & 1 deletion tealish/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def get_current_scope(self) -> Scope:
def consume(cls, compiler: "TealishCompiler", parent: Optional[Node]) -> "Program":
node = Program("", parent=parent, compiler=compiler)
expect_struct_definition = True
exit_statement = None
while True:
if compiler.peek() is None:
break
Expand All @@ -280,6 +281,22 @@ def consume(cls, compiler: "TealishCompiler", parent: Optional[Node]) -> "Progra
)
if not isinstance(n, (TealVersion, Blank, Comment, Struct)):
expect_struct_definition = False

if exit_statement:
if not isinstance(n, (Func, Block, Comment, Blank)):
raise ParseError(
f"Unexpected statement at line {n.line_no}."
+ f" Only Block and Function definitions should occure after a {exit_statement}."
)
else:
if isinstance(n, (Func, Block)):
raise ParseError(
f"Unexpected {n} definition at line {n.line_no}. "
+ "Block and Function definitions must occur after an exit statement (e.g Exit, switch, jump)."
)
if is_exit_statement(n):
exit_statement = n

node.add_child(n)
return node

Expand Down Expand Up @@ -623,11 +640,34 @@ def __init__(
def consume(cls, compiler: "TealishCompiler", parent: Optional[Node]) -> "Block":
line = compiler.consume_line()
block = Block(line, parent, compiler=compiler)
exit_statement = None
while True:
if compiler.peek() == "end":
compiler.consume_line()
if exit_statement is None:
raise ParseError(
f"Unexpected end of block at line {compiler.line_no}."
+ " Blocks must end with an exit statement (e.g. exit, switch, jump)"
)
break
block.add_child(Statement.consume(compiler, block))

n = Statement.consume(compiler, block)
if exit_statement:
if not isinstance(n, (Func, Block, Comment, Blank)):
raise ParseError(
f"Unexpected statement at line {n.line_no}."
+ f" Only Block and Function definitions should occure after a {exit_statement}."
)
else:
if isinstance(n, (Func, Block)):
raise ParseError(
f"Unexpected {n} definition at line {n.line_no}. "
+ "Block and Function definitions must occur after an exit statement (e.g. exit, switch, jump)."
)
if is_exit_statement(n):
exit_statement = n

block.add_child(n)
return block

def process(self) -> None:
Expand Down Expand Up @@ -1760,3 +1800,8 @@ def split_return_args(s):

def indent(s: str) -> str:
return textwrap.indent(s, " ")


def is_exit_statement(node):
if isinstance(node, (Exit, Switch, Jump)):
return True
3 changes: 3 additions & 0 deletions tests/everything.teal
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ inner_stuff:
fail:
// oops()
callsub __func__oops
// exit(1)
pushint 1
return

// Function with no args or return value
// func oops():
Expand Down
1 change: 1 addition & 0 deletions tests/everything.tl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ end

block fail:
oops()
exit(1)
end

# Function with no args or return value
Expand Down
93 changes: 83 additions & 10 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ def compile_min(p):
return min_teal


def compile_function_min(p):
program = ["exit(1)"] + p
teal = compile_lines(program)
min_teal = strip_comments(teal)
return min_teal[2:]


def compile_expression_min(p, **kwargs):
teal = compile_expression(p, **kwargs)
min_teal = strip_comments(teal)
Expand Down Expand Up @@ -317,7 +324,7 @@ def test_fail_wrong_type(self):

class TestFunctionReturn(unittest.TestCase):
def test_pass(self):
compile_lines(
compile_function_min(
[
"func f():",
"return",
Expand All @@ -327,7 +334,7 @@ def test_pass(self):

def test_fail_no_return(self):
with self.assertRaises(ParseError) as e:
compile_lines(
compile_function_min(
[
"func f():",
"assert(1)",
Expand All @@ -339,7 +346,7 @@ def test_fail_no_return(self):
@expectedFailure
def test_fail_wrong_sig_1_return(self):
with self.assertRaises(CompileError) as e:
compile_lines(
compile_function_min(
[
"func f():",
"return 1",
Expand All @@ -353,7 +360,7 @@ def test_fail_wrong_sig_1_return(self):
@expectedFailure
def test_fail_wrong_sig_2_returns(self):
with self.assertRaises(CompileError) as e:
compile_lines(
compile_function_min(
[
"func f() int:",
"return 1, 2",
Expand All @@ -365,7 +372,7 @@ def test_fail_wrong_sig_2_returns(self):
)

def test_pass_return_literal(self):
compile_lines(
compile_function_min(
[
"func f() int:",
"return 1",
Expand All @@ -374,7 +381,7 @@ def test_pass_return_literal(self):
)

def test_pass_return_two_literals(self):
compile_lines(
compile_function_min(
[
"func f() int, int:",
"return 1, 2",
Expand All @@ -383,7 +390,7 @@ def test_pass_return_two_literals(self):
)

def test_pass_return_math_expression(self):
compile_lines(
compile_function_min(
[
"func f() int:",
"return 1 + 2",
Expand All @@ -392,7 +399,7 @@ def test_pass_return_math_expression(self):
)

def test_pass_return_two_math_expressions(self):
compile_lines(
compile_function_min(
[
"func f() int, int:",
"return 1 + 2, 3 + 1",
Expand All @@ -401,7 +408,7 @@ def test_pass_return_two_math_expressions(self):
)

def test_pass_return_bytes_with_comma(self):
teal = compile_min(
teal = compile_function_min(
[
"func f() bytes:",
'return "1,2,3"',
Expand All @@ -411,7 +418,7 @@ def test_pass_return_bytes_with_comma(self):
self.assertListEqual(teal[1:], ['pushbytes "1,2,3"', "retsub"])

def test_pass_return_two_func_calls(self):
teal = compile_min(
teal = compile_function_min(
[
"func f() int, int:",
"return sqrt(25), exp(5, 2)",
Expand Down Expand Up @@ -1263,6 +1270,72 @@ def test_pass_method_const(self):
self.assertListEqual(teal, ['method "name(uint64,uint64)"', "log"])


class TestExits(unittest.TestCase):
def test_pass_nested_blocks(self):
compile_min(
[
"jump a",
"block a:",
" jump b",
" block b:",
" exit(1)",
" end",
"end",
]
)

def test_pass_comments_after_jump(self):
compile_min(
[
"jump a",
"# a comment",
"block a:",
" jump b",
" # a comment",
" block b:",
" exit(1)",
" end",
"end",
]
)

def test_pass_func_in_block(self):
compile_min(
[
"jump a",
"block a:",
" f()",
" exit(1)",
" func f():",
" return",
" end",
"end",
]
)

def test_fail_block_before_exit(self):
with self.assertRaises(ParseError) as e:
compile_min(['log("abc")', "block a:", " exit(1)", "end"])
self.assertIn("Unexpected Block definition", str(e.exception))

def test_fail_func_before_exit(self):
with self.assertRaises(ParseError) as e:
compile_min(['log("abc")', "func a():", " return", "end"])
self.assertIn("Unexpected Func definition", str(e.exception))

def test_fail_block_without_exit(self):
with self.assertRaises(ParseError) as e:
compile_min(["jump a", "block a:", " assert(1)", "end"])
self.assertIn("Unexpected end of block", str(e.exception))

def test_fail_block_without_final_exit(self):
with self.assertRaises(ParseError) as e:
compile_min(
["jump a", "block a:", " if 1:", " exit(1)", " end", "end"]
)
self.assertIn("Unexpected end of block", str(e.exception))


class TestEverythingProgram(unittest.TestCase):
maxDiff = None

Expand Down

0 comments on commit de87e0a

Please sign in to comment.