|
36 | 36 | from flask import Flask, jsonify, request
|
37 | 37 | from flask.logging import default_handler
|
38 | 38 | from flask_cors import CORS
|
39 |
| -from werkzeug.exceptions import abort |
40 |
| -from werkzeug.wrappers import Response |
| 39 | +from nest.lib.hl_api_exceptions import NESTError |
41 | 40 |
|
42 | 41 | # This ensures that the logging information shows up in the console running the server,
|
43 | 42 | # even when Flask's event loop is running.
|
@@ -189,41 +188,36 @@ def index():
|
189 | 188 |
|
190 | 189 |
|
191 | 190 | def do_exec(args, kwargs):
|
192 |
| - try: |
193 |
| - source_code = kwargs.get("source", "") |
194 |
| - source_cleaned = clean_code(source_code) |
195 |
| - |
196 |
| - locals_ = dict() |
197 |
| - response = dict() |
198 |
| - if RESTRICTION_DISABLED: |
199 |
| - with Capturing() as stdout: |
200 |
| - globals_ = globals().copy() |
201 |
| - globals_.update(get_modules_from_env()) |
202 |
| - exec(source_cleaned, globals_, locals_) |
203 |
| - if len(stdout) > 0: |
204 |
| - response["stdout"] = "\n".join(stdout) |
205 |
| - else: |
206 |
| - code = RestrictedPython.compile_restricted(source_cleaned, "<inline>", "exec") # noqa |
207 |
| - globals_ = get_restricted_globals() |
| 191 | + source_code = kwargs.get("source", "") |
| 192 | + source_cleaned = clean_code(source_code) |
| 193 | + |
| 194 | + locals_ = dict() |
| 195 | + response = dict() |
| 196 | + if RESTRICTION_DISABLED: |
| 197 | + with Capturing() as stdout: |
| 198 | + globals_ = globals().copy() |
208 | 199 | globals_.update(get_modules_from_env())
|
209 |
| - exec(code, globals_, locals_) |
210 |
| - if "_print" in locals_: |
211 |
| - response["stdout"] = "".join(locals_["_print"].txt) |
212 |
| - |
213 |
| - if "return" in kwargs: |
214 |
| - if isinstance(kwargs["return"], list): |
215 |
| - data = dict() |
216 |
| - for variable in kwargs["return"]: |
217 |
| - data[variable] = locals_.get(variable, None) |
218 |
| - else: |
219 |
| - data = locals_.get(kwargs["return"], None) |
220 |
| - response["data"] = nest.serialize_data(data) |
221 |
| - return response |
| 200 | + get_or_error(exec)(source_cleaned, globals_, locals_) |
| 201 | + if len(stdout) > 0: |
| 202 | + response["stdout"] = "\n".join(stdout) |
| 203 | + else: |
| 204 | + code = RestrictedPython.compile_restricted(source_cleaned, "<inline>", "exec") # noqa |
| 205 | + globals_ = get_restricted_globals() |
| 206 | + globals_.update(get_modules_from_env()) |
| 207 | + get_or_error(exec)(code, globals_, locals_) |
| 208 | + if "_print" in locals_: |
| 209 | + response["stdout"] = "".join(locals_["_print"].txt) |
| 210 | + |
| 211 | + if "return" in kwargs: |
| 212 | + if isinstance(kwargs["return"], list): |
| 213 | + data = dict() |
| 214 | + for variable in kwargs["return"]: |
| 215 | + data[variable] = locals_.get(variable, None) |
| 216 | + else: |
| 217 | + data = locals_.get(kwargs["return"], None) |
222 | 218 |
|
223 |
| - except Exception as e: |
224 |
| - for line in traceback.format_exception(*sys.exc_info()): |
225 |
| - print(line, flush=True) |
226 |
| - flask.abort(EXCEPTION_ERROR_STATUS, str(e)) |
| 219 | + response["data"] = get_or_error(nest.serialize_data)(data) |
| 220 | + return response |
227 | 221 |
|
228 | 222 |
|
229 | 223 | def log(call_name, msg):
|
@@ -336,10 +330,43 @@ def __exit__(self, *args):
|
336 | 330 | sys.stdout = self._stdout
|
337 | 331 |
|
338 | 332 |
|
| 333 | +class ErrorHandler(Exception): |
| 334 | + status_code = 400 |
| 335 | + lineno = -1 |
| 336 | + |
| 337 | + def __init__(self, message: str, lineno: int = None, status_code: int = None, payload=None): |
| 338 | + super().__init__() |
| 339 | + self.message = message |
| 340 | + if status_code is not None: |
| 341 | + self.status_code = status_code |
| 342 | + if lineno is not None: |
| 343 | + self.lineno = lineno |
| 344 | + self.payload = payload |
| 345 | + |
| 346 | + def to_dict(self): |
| 347 | + rv = dict(self.payload or ()) |
| 348 | + rv["message"] = self.message |
| 349 | + if self.lineno != -1: |
| 350 | + rv["lineNumber"] = self.lineno |
| 351 | + return rv |
| 352 | + |
| 353 | + |
| 354 | +# https://flask.palletsprojects.com/en/2.3.x/errorhandling/ |
| 355 | +@app.errorhandler(ErrorHandler) |
| 356 | +def error_handler(e): |
| 357 | + return jsonify(e.to_dict()), e.status_code |
| 358 | + |
| 359 | + |
| 360 | +# It comments lines starting with 'import' or 'from' otherwise the line number of error would be wrong. |
339 | 361 | def clean_code(source):
|
340 | 362 | codes = source.split("\n")
|
341 |
| - code_cleaned = filter(lambda code: not (code.startswith("import") or code.startswith("from")), codes) # noqa |
342 |
| - return "\n".join(code_cleaned) |
| 363 | + codes_cleaned = [] # noqa |
| 364 | + for code in codes: |
| 365 | + if code.startswith("import") or code.startswith("from"): |
| 366 | + codes_cleaned.append("#" + code) |
| 367 | + else: |
| 368 | + codes_cleaned.append(code) |
| 369 | + return "\n".join(codes_cleaned) |
343 | 370 |
|
344 | 371 |
|
345 | 372 | def get_arguments(request):
|
@@ -368,6 +395,16 @@ def get_arguments(request):
|
368 | 395 | return list(args), kwargs
|
369 | 396 |
|
370 | 397 |
|
| 398 | +def get_lineno(err, tb_idx): |
| 399 | + lineno = -1 |
| 400 | + if hasattr(err, "lineno") and err.lineno is not None: |
| 401 | + lineno = err.lineno |
| 402 | + else: |
| 403 | + tb = sys.exc_info()[2] |
| 404 | + lineno = traceback.extract_tb(tb)[tb_idx][1] |
| 405 | + return lineno |
| 406 | + |
| 407 | + |
371 | 408 | def get_modules_from_env():
|
372 | 409 | """Get modules from environment variable NEST_SERVER_MODULES.
|
373 | 410 |
|
@@ -397,13 +434,34 @@ def get_modules_from_env():
|
397 | 434 | def get_or_error(func):
|
398 | 435 | """Wrapper to get data and status."""
|
399 | 436 |
|
400 |
| - def func_wrapper(call, args, kwargs): |
| 437 | + def func_wrapper(call, *args, **kwargs): |
401 | 438 | try:
|
402 |
| - return func(call, args, kwargs) |
403 |
| - except Exception as e: |
404 |
| - for line in traceback.format_exception(*sys.exc_info()): |
405 |
| - print(line, flush=True) |
406 |
| - flask.abort(EXCEPTION_ERROR_STATUS, str(e)) |
| 439 | + return func(call, *args, **kwargs) |
| 440 | + |
| 441 | + except NESTError as err: |
| 442 | + error_class = err.errorname + " (NESTError)" |
| 443 | + detail = err.errormessage |
| 444 | + lineno = get_lineno(err, 1) |
| 445 | + |
| 446 | + except (KeyError, SyntaxError, TypeError, ValueError) as err: |
| 447 | + error_class = err.__class__.__name__ |
| 448 | + detail = err.args[0] |
| 449 | + lineno = get_lineno(err, 1) |
| 450 | + |
| 451 | + except Exception as err: |
| 452 | + error_class = err.__class__.__name__ |
| 453 | + detail = err.args[0] |
| 454 | + lineno = get_lineno(err, -1) |
| 455 | + |
| 456 | + for line in traceback.format_exception(*sys.exc_info()): |
| 457 | + print(line, flush=True) |
| 458 | + |
| 459 | + if lineno == -1: |
| 460 | + message = "%s: %s" % (error_class, detail) |
| 461 | + else: |
| 462 | + message = "%s at line %d: %s" % (error_class, lineno, detail) |
| 463 | + |
| 464 | + raise ErrorHandler(message, lineno) |
407 | 465 |
|
408 | 466 | return func_wrapper
|
409 | 467 |
|
|
0 commit comments