Skip to content

Commit c201d67

Browse files
authored
Merge pull request #3343 from babsey/error-lineno
Add line number to error message
2 parents f495b3e + 91eea27 commit c201d67

File tree

1 file changed

+101
-43
lines changed

1 file changed

+101
-43
lines changed

pynest/nest/server/hl_api_server.py

Lines changed: 101 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@
3636
from flask import Flask, jsonify, request
3737
from flask.logging import default_handler
3838
from flask_cors import CORS
39-
from werkzeug.exceptions import abort
40-
from werkzeug.wrappers import Response
39+
from nest.lib.hl_api_exceptions import NESTError
4140

4241
# This ensures that the logging information shows up in the console running the server,
4342
# even when Flask's event loop is running.
@@ -189,41 +188,36 @@ def index():
189188

190189

191190
def do_exec(args, kwargs):
192-
try:
193-
source_code = kwargs.get("source", "")
194-
source_cleaned = clean_code(source_code)
195-
196-
locals_ = dict()
197-
response = dict()
198-
if RESTRICTION_DISABLED:
199-
with Capturing() as stdout:
200-
globals_ = globals().copy()
201-
globals_.update(get_modules_from_env())
202-
exec(source_cleaned, globals_, locals_)
203-
if len(stdout) > 0:
204-
response["stdout"] = "\n".join(stdout)
205-
else:
206-
code = RestrictedPython.compile_restricted(source_cleaned, "<inline>", "exec") # noqa
207-
globals_ = get_restricted_globals()
191+
source_code = kwargs.get("source", "")
192+
source_cleaned = clean_code(source_code)
193+
194+
locals_ = dict()
195+
response = dict()
196+
if RESTRICTION_DISABLED:
197+
with Capturing() as stdout:
198+
globals_ = globals().copy()
208199
globals_.update(get_modules_from_env())
209-
exec(code, globals_, locals_)
210-
if "_print" in locals_:
211-
response["stdout"] = "".join(locals_["_print"].txt)
212-
213-
if "return" in kwargs:
214-
if isinstance(kwargs["return"], list):
215-
data = dict()
216-
for variable in kwargs["return"]:
217-
data[variable] = locals_.get(variable, None)
218-
else:
219-
data = locals_.get(kwargs["return"], None)
220-
response["data"] = nest.serialize_data(data)
221-
return response
200+
get_or_error(exec)(source_cleaned, globals_, locals_)
201+
if len(stdout) > 0:
202+
response["stdout"] = "\n".join(stdout)
203+
else:
204+
code = RestrictedPython.compile_restricted(source_cleaned, "<inline>", "exec") # noqa
205+
globals_ = get_restricted_globals()
206+
globals_.update(get_modules_from_env())
207+
get_or_error(exec)(code, globals_, locals_)
208+
if "_print" in locals_:
209+
response["stdout"] = "".join(locals_["_print"].txt)
210+
211+
if "return" in kwargs:
212+
if isinstance(kwargs["return"], list):
213+
data = dict()
214+
for variable in kwargs["return"]:
215+
data[variable] = locals_.get(variable, None)
216+
else:
217+
data = locals_.get(kwargs["return"], None)
222218

223-
except Exception as e:
224-
for line in traceback.format_exception(*sys.exc_info()):
225-
print(line, flush=True)
226-
flask.abort(EXCEPTION_ERROR_STATUS, str(e))
219+
response["data"] = get_or_error(nest.serialize_data)(data)
220+
return response
227221

228222

229223
def log(call_name, msg):
@@ -336,10 +330,43 @@ def __exit__(self, *args):
336330
sys.stdout = self._stdout
337331

338332

333+
class ErrorHandler(Exception):
334+
status_code = 400
335+
lineno = -1
336+
337+
def __init__(self, message: str, lineno: int = None, status_code: int = None, payload=None):
338+
super().__init__()
339+
self.message = message
340+
if status_code is not None:
341+
self.status_code = status_code
342+
if lineno is not None:
343+
self.lineno = lineno
344+
self.payload = payload
345+
346+
def to_dict(self):
347+
rv = dict(self.payload or ())
348+
rv["message"] = self.message
349+
if self.lineno != -1:
350+
rv["lineNumber"] = self.lineno
351+
return rv
352+
353+
354+
# https://flask.palletsprojects.com/en/2.3.x/errorhandling/
355+
@app.errorhandler(ErrorHandler)
356+
def error_handler(e):
357+
return jsonify(e.to_dict()), e.status_code
358+
359+
360+
# It comments lines starting with 'import' or 'from' otherwise the line number of error would be wrong.
339361
def clean_code(source):
340362
codes = source.split("\n")
341-
code_cleaned = filter(lambda code: not (code.startswith("import") or code.startswith("from")), codes) # noqa
342-
return "\n".join(code_cleaned)
363+
codes_cleaned = [] # noqa
364+
for code in codes:
365+
if code.startswith("import") or code.startswith("from"):
366+
codes_cleaned.append("#" + code)
367+
else:
368+
codes_cleaned.append(code)
369+
return "\n".join(codes_cleaned)
343370

344371

345372
def get_arguments(request):
@@ -368,6 +395,16 @@ def get_arguments(request):
368395
return list(args), kwargs
369396

370397

398+
def get_lineno(err, tb_idx):
399+
lineno = -1
400+
if hasattr(err, "lineno") and err.lineno is not None:
401+
lineno = err.lineno
402+
else:
403+
tb = sys.exc_info()[2]
404+
lineno = traceback.extract_tb(tb)[tb_idx][1]
405+
return lineno
406+
407+
371408
def get_modules_from_env():
372409
"""Get modules from environment variable NEST_SERVER_MODULES.
373410
@@ -397,13 +434,34 @@ def get_modules_from_env():
397434
def get_or_error(func):
398435
"""Wrapper to get data and status."""
399436

400-
def func_wrapper(call, args, kwargs):
437+
def func_wrapper(call, *args, **kwargs):
401438
try:
402-
return func(call, args, kwargs)
403-
except Exception as e:
404-
for line in traceback.format_exception(*sys.exc_info()):
405-
print(line, flush=True)
406-
flask.abort(EXCEPTION_ERROR_STATUS, str(e))
439+
return func(call, *args, **kwargs)
440+
441+
except NESTError as err:
442+
error_class = err.errorname + " (NESTError)"
443+
detail = err.errormessage
444+
lineno = get_lineno(err, 1)
445+
446+
except (KeyError, SyntaxError, TypeError, ValueError) as err:
447+
error_class = err.__class__.__name__
448+
detail = err.args[0]
449+
lineno = get_lineno(err, 1)
450+
451+
except Exception as err:
452+
error_class = err.__class__.__name__
453+
detail = err.args[0]
454+
lineno = get_lineno(err, -1)
455+
456+
for line in traceback.format_exception(*sys.exc_info()):
457+
print(line, flush=True)
458+
459+
if lineno == -1:
460+
message = "%s: %s" % (error_class, detail)
461+
else:
462+
message = "%s at line %d: %s" % (error_class, lineno, detail)
463+
464+
raise ErrorHandler(message, lineno)
407465

408466
return func_wrapper
409467

0 commit comments

Comments
 (0)