Skip to content

fix: fix exception handling in C++ native code #244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: @jakubgonera/style-transfer-cpp
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions android/src/main/cpp/ETInstallerModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ void ETInstallerModule::injectJSIBindings() {
jbyteArray byteData =
(jbyteArray)env->CallStaticObjectMethod(cls, method, jUrl);

if (env->IsSameObject(byteData, NULL)) {
throw std::runtime_error("Error fetching data from a url");
}

int size = env->GetArrayLength(byteData);
jbyte *bytes = env->GetByteArrayElements(byteData, JNI_FALSE);
std::byte *dataBytePtr = reinterpret_cast<std::byte *>(bytes);
Expand Down
20 changes: 12 additions & 8 deletions android/src/main/java/com/swmansion/rnexecutorch/ETInstaller.kt
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,20 @@ class ETInstaller(
@JvmStatic
@DoNotStrip
@Throws(Exception::class)
fun fetchByteDataFromUrl(source: String): ByteArray {
val url = URL(source)
val connection = url.openConnection()
connection.connect()
fun fetchByteDataFromUrl(source: String): ByteArray? {
try {
val url = URL(source)
val connection = url.openConnection()
connection.connect()

val inputStream: InputStream = connection.getInputStream()
val data = inputStream.readBytes()
inputStream.close()
val inputStream: InputStream = connection.getInputStream()
val data = inputStream.readBytes()
inputStream.close()

return data
return data
} catch (exception: Throwable) {
return null
}
}
}

Expand Down
4 changes: 3 additions & 1 deletion common/rnexecutorch/data_processing/ImageProcessing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,14 @@ cv::Mat readImage(const std::string &imageURI) {
// local file
auto url = ada::parse(imageURI);
image = cv::imread(std::string{url->get_pathname()}, cv::IMREAD_COLOR);
} else {
} else if (imageURI.starts_with("http")) {
// remote file
std::vector<std::byte> imageData = fetchUrlFunc(imageURI);
image = cv::imdecode(
cv::Mat(1, imageData.size(), CV_8UC1, (void *)imageData.data()),
cv::IMREAD_COLOR);
} else {
throw std::runtime_error("Read image error: unknown protocol");
}

if (image.empty()) {
Expand Down
52 changes: 31 additions & 21 deletions common/rnexecutorch/host_objects/ModelHostObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,38 +20,48 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
}

JSI_HOST_FUNCTION(forward) {

auto promise = promiseVendor.createPromise(
[this, count, args, &runtime](std::shared_ptr<Promise> promise) {
std::thread([this, promise = std::move(promise), count, args,
&runtime]() {
constexpr std::size_t forwardArgCount =
jsiconversion::getArgumentCount(&Model::forward);
if (forwardArgCount != count) {
char errorMessage[100];
std::snprintf(
errorMessage, sizeof(errorMessage),
"Argument count mismatch, was expecting: %zu but got: %zu",
forwardArgCount, count);
constexpr std::size_t forwardArgCount =
jsiconversion::getArgumentCount(&Model::forward);
if (forwardArgCount != count) {
char errorMessage[100];
std::snprintf(
errorMessage, sizeof(errorMessage),
"Argument count mismatch, was expecting: %zu but got: %zu",
forwardArgCount, count);

promise->reject(errorMessage);
return;
}
promise->reject(errorMessage);
return;
}

// Do the asynchronous work
std::thread([this, promise = std::move(promise), args, &runtime]() {
try {
auto argsConverted = jsiconversion::createArgsTupleFromJsi(
&Model::forward, args, runtime);
promise->resolve([this, argsConverted = std::move(argsConverted)](
jsi::Runtime &runtime) {
auto result = std::apply(
std::bind_front(&Model::forward, model), argsConverted);
auto resultValue =
jsiconversion::getJsiValue(std::move(result), runtime);
return resultValue;
auto result = std::apply(std::bind_front(&Model::forward, model),
argsConverted);

promise->resolve([result =
std::move(result)](jsi::Runtime &runtime) {
return jsiconversion::getJsiValue(std::move(result), runtime);
});
} catch (const std::runtime_error &e) {
// This catch should be merged with the next one
// (std::runtime_error inherits from std::exception) HOWEVER react
// native has broken RTTI which breaks proper exception type
// checking. Remove when the following change is present in our
// version:
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
promise->reject(e.what());
return;
} catch (const std::exception &e) {
promise->reject(e.what());
return;
} catch (...) {
promise->reject("Unknown error");
return;
}
}).detach();
});
Expand Down
6 changes: 3 additions & 3 deletions common/rnexecutorch/jsi/JsiPromise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ jsi::Value PromiseVendor::createPromise(

auto rejectWrapper = [reject, &runtime, callInvoker](
const std::string &errorMessage) -> void {
auto error = jsi::JSError(runtime, errorMessage);
auto errorShared = std::make_shared<jsi::JSError>(error);
callInvoker->invokeAsync([reject, &runtime, errorShared]() -> void {
callInvoker->invokeAsync([reject, &runtime, errorMessage]() -> void {
auto error = jsi::JSError(runtime, errorMessage);
auto errorShared = std::make_shared<jsi::JSError>(error);
reject->call(runtime, errorShared->value());
});
};
Expand Down
22 changes: 14 additions & 8 deletions ios/RnExecutorch/ETInstaller.mm
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#import <React/RCTCallInvoker.h>
#import <ReactCommon/RCTTurboModule.h>
#include <rnexecutorch/RnExecutorchInstaller.h>
#include <stdexcept>

using namespace facebook::react;

Expand All @@ -26,14 +27,19 @@ @implementation ETInstaller
assert(jsiRuntime != nullptr);

auto fetchUrl = [](std::string url) {
NSString *nsUrlStr =
[NSString stringWithCString:url.c_str()
encoding:[NSString defaultCStringEncoding]];
NSURL *nsUrl = [NSURL URLWithString:nsUrlStr];
NSData *data = [NSData dataWithContentsOfURL:nsUrl];
const std::byte *bytePtr = reinterpret_cast<const std::byte *>(data.bytes);
int bufferLength = [data length];
return std::vector<std::byte>(bytePtr, bytePtr + bufferLength);
@try {
NSString *nsUrlStr =
[NSString stringWithCString:url.c_str()
encoding:[NSString defaultCStringEncoding]];
NSURL *nsUrl = [NSURL URLWithString:nsUrlStr];
NSData *data = [NSData dataWithContentsOfURL:nsUrl];
const std::byte *bytePtr =
reinterpret_cast<const std::byte *>(data.bytes);
int bufferLength = [data length];
return std::vector<std::byte>(bytePtr, bytePtr + bufferLength);
} @catch (NSException *exception) {
throw std::runtime_error("Error fetching data from a url");
}
};
rnexecutorch::RnExecutorchInstaller::injectJSIBindings(
jsiRuntime, jsCallInvoker, fetchUrl);
Expand Down