diff --git a/android/CMakeLists.txt b/android/CMakeLists.txt
new file mode 100644
index 00000000..2feb0039
--- /dev/null
+++ b/android/CMakeLists.txt
@@ -0,0 +1,17 @@
+cmake_minimum_required(VERSION 3.13)
+project(RnExecutorch)
+
+set (CMAKE_VERBOSE_MAKEFILE ON)
+set (CMAKE_CXX_STANDARD 20)
+
+include("${REACT_NATIVE_DIR}/ReactAndroid/cmake-utils/folly-flags.cmake")
+add_compile_options(${folly_FLAGS})
+
+string(APPEND CMAKE_CXX_FLAGS " -DRCT_NEW_ARCH_ENABLED")
+
+set(ANDROID_CPP_DIR "${CMAKE_SOURCE_DIR}/src/main/cpp")
+set(COMMON_CPP_DIR "${CMAKE_SOURCE_DIR}/../common")
+set(ET_LIB_DIR "${CMAKE_SOURCE_DIR}/../third-party/android/libs")
+set(ET_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/../third-party/include")
+
+add_subdirectory("${ANDROID_CPP_DIR}")
\ No newline at end of file
diff --git a/android/build.gradle b/android/build.gradle
index c670398b..1dc3f045 100644
--- a/android/build.gradle
+++ b/android/build.gradle
@@ -1,3 +1,5 @@
+import org.apache.tools.ant.taskdefs.condition.Os
+
 buildscript {
   ext {
     agp_version = '8.4.2'
@@ -21,19 +23,20 @@ buildscript {
 def reactNativeArchitectures() {
   def value = rootProject.getProperties().get("reactNativeArchitectures")
-  return value ? value.split(",") : ["armeabi-v7a", "x86", "x86_64", "arm64-v8a"]
-}
-
-def isNewArchitectureEnabled() {
-  return rootProject.hasProperty("newArchEnabled") && rootProject.getProperty("newArchEnabled") == "true"
+  // react-native-executorch supports only these architectures, since
+  // ExecuTorch does not support any others.
+  def defaultArchitectures = ["x86_64", "arm64-v8a"]
+  if(!value) {
+    return defaultArchitectures
+  }
+  def architectures = value.split(",")
+  return architectures.findAll { it in defaultArchitectures }
 }
 
 apply plugin: "com.android.library"
 apply plugin: "kotlin-android"
+apply plugin: "com.facebook.react"
 
-if (isNewArchitectureEnabled()) {
-  apply plugin: "com.facebook.react"
-}
 
 def getExtOrDefault(name) {
   return rootProject.ext.has(name) ? rootProject.ext.get(name) : project.properties["RnExecutorch_" + name]
 }
@@ -52,6 +55,38 @@ def supportsNamespace() {
   return (major == 7 && minor >= 3) || major >= 8
 }
 
+def safeAppExtGet(prop, fallback) {
+  def appProject = rootProject.allprojects.find { it.plugins.hasPlugin('com.android.application') }
+  appProject?.ext?.has(prop) ? appProject.ext.get(prop) : fallback
+}
+
+def toPlatformFileString(String path) {
+  if (Os.isFamily(Os.FAMILY_WINDOWS)) {
+    path = path.replace(File.separatorChar, '/' as char)
+  }
+  return path
+}
+
+def resolveReactNativeDirectory() {
+  def reactNativeLocation = safeAppExtGet("REACT_NATIVE_NODE_MODULES_DIR", null)
+
+  if (reactNativeLocation !== null) {
+    return file(reactNativeLocation)
+  }
+
+  // Fallback to node resolver for custom directory structures like monorepos.
+  def reactNativePackage = file(["node", "--print", "require.resolve('react-native/package.json')"].execute(null, rootDir).text.trim())
+  if(reactNativePackage.exists()) {
+    return reactNativePackage.parentFile
+  }
+
+  throw new GradleException(
+    "[RnExecutorch] Unable to resolve the react-native location in node_modules. Set the project extension property `REACT_NATIVE_NODE_MODULES_DIR` (in `app/build.gradle`) to the path of react-native."
+  )
+}
+
+def reactNativeRootDir = resolveReactNativeDirectory()
+
 android {
   if (supportsNamespace()) {
     namespace "com.swmansion.rnexecutorch"
@@ -63,12 +98,34 @@ android {
     }
   }
 
+  buildFeatures {
+    prefab true
+    prefabPublishing true
+    buildConfig true
+  }
+
   compileSdkVersion getExtOrIntegerDefault("compileSdkVersion")
 
   defaultConfig {
     minSdkVersion getExtOrIntegerDefault("minSdkVersion")
     targetSdkVersion getExtOrIntegerDefault("targetSdkVersion")
-    buildConfigField("boolean", "IS_NEW_ARCHITECTURE_ENABLED", isNewArchitectureEnabled().toString())
+    externalNativeBuild {
+      cmake {
+        cppFlags "-O2 -frtti -fexceptions -Wall -fstack-protector-all"
+        abiFilters (*reactNativeArchitectures())
+        arguments "-DANDROID_STL=c++_shared",
+          "-DREACT_NATIVE_DIR=${toPlatformFileString(reactNativeRootDir.path)}",
+          "-DBUILD_DIR=${project.buildDir}",
+          "-DANDROID_TOOLCHAIN=clang"
+      }
+    }
+  }
+
+
+  externalNativeBuild {
+    cmake {
+      path "CMakeLists.txt"
+    }
   }
 
   buildTypes {
@@ -101,9 +158,11 @@ dependencies {
   //noinspection GradleDynamicVersion
   implementation 'com.github.wendykierp:JTransforms:3.1'
   implementation "com.facebook.react:react-android:+"
+  implementation "com.facebook.react:react-native:+"
+  implementation 'com.facebook.fbjni:fbjni:0.6.0'
   implementation 'org.opencv:opencv:4.10.0'
   implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
-  implementation(files("libs/executorch.aar"))
+  implementation(files("../third-party/android/libs/executorch.aar"))
   implementation 'org.opencv:opencv:4.10.0'
   implementation("com.squareup.okhttp3:okhttp:4.9.2")
 }
diff --git a/android/libs/executorch.aar b/android/libs/executorch.aar
deleted file mode 100644
index f5e6830d..00000000
Binary files a/android/libs/executorch.aar and /dev/null differ
diff --git a/android/src/main/cpp/CMakeLists.txt b/android/src/main/cpp/CMakeLists.txt
new file mode 100644
index 00000000..83ce1213
--- /dev/null
+++ b/android/src/main/cpp/CMakeLists.txt
@@ -0,0 +1,44 @@
+cmake_minimum_required(VERSION 3.13)
+
+file(GLOB_RECURSE ANDROID_CPP_SOURCES CONFIGURE_DEPENDS "${ANDROID_CPP_DIR}/*.cpp")
+file(GLOB_RECURSE COMMON_CPP_SOURCES CONFIGURE_DEPENDS "${COMMON_CPP_DIR}/*.cpp" "${COMMON_CPP_DIR}/*.c")
+
+add_library(react-native-executorch SHARED ${ANDROID_CPP_SOURCES} ${COMMON_CPP_SOURCES})
+
+find_package(ReactAndroid REQUIRED CONFIG)
+find_package(fbjni REQUIRED CONFIG)
+
+target_include_directories(
+  react-native-executorch
+  PUBLIC
+  "${COMMON_CPP_DIR}"
+  "${ANDROID_CPP_DIR}"
+  "${ET_INCLUDE_DIR}"
+  "${REACT_NATIVE_DIR}/ReactCommon"
+  "${REACT_NATIVE_DIR}/ReactAndroid/src/main/jni/react/turbomodule"
+  "${REACT_NATIVE_DIR}/ReactCommon/callinvoker"
+  "${BUILD_DIR}/generated/source/codegen/jni/react/renderer/components/RnExecutorchSpec"
+)
+
+set(LINK_LIBRARIES
+  ReactAndroid::jsi
+  fbjni::fbjni
+  android
+  log
+)
+
+set(RN_VERSION_LINK_LIBRARIES
+  ReactAndroid::reactnative
+)
+
+add_library(executorch SHARED IMPORTED)
+
+set_target_properties(executorch PROPERTIES
+  IMPORTED_LOCATION "${ET_LIB_DIR}/${ANDROID_ABI}/libexecutorch.so")
+
+target_link_libraries(
+  react-native-executorch
+  ${LINK_LIBRARIES}
+  ${RN_VERSION_LINK_LIBRARIES}
+  executorch
+)
\ No newline at end of file
diff --git a/android/src/main/cpp/ETInstallerModule.cpp b/android/src/main/cpp/ETInstallerModule.cpp
new file mode 100644
index 00000000..cec899bd
--- /dev/null
+++ b/android/src/main/cpp/ETInstallerModule.cpp
@@ -0,0 +1,35 @@
+#include "ETInstallerModule.h"
+#include "RnExecutorchInstaller.h"
+
+namespace rnexecutorch {
+
+using namespace facebook::jni;
+
+ETInstallerModule::ETInstallerModule( + jni::alias_ref &jThis, + jsi::Runtime *jsiRuntime, + const std::shared_ptr &jsCallInvoker) + : javaPart_(make_global(jThis)), jsiRuntime_(jsiRuntime), + jsCallInvoker_(jsCallInvoker) {} + +jni::local_ref ETInstallerModule::initHybrid( + jni::alias_ref jThis, jlong jsContext, + jni::alias_ref + jsCallInvokerHolder) { + auto jsCallInvoker = jsCallInvokerHolder->cthis()->getCallInvoker(); + auto rnRuntime = reinterpret_cast(jsContext); + return makeCxxInstance(jThis, rnRuntime, jsCallInvoker); +} + +void ETInstallerModule::registerNatives() { + registerHybrid({ + makeNativeMethod("initHybrid", ETInstallerModule::initHybrid), + makeNativeMethod("injectJSIBindings", + ETInstallerModule::injectJSIBindings), + }); +} + +void ETInstallerModule::injectJSIBindings() { + RnExecutorchInstaller::injectJSIBindings(jsiRuntime_, jsCallInvoker_); +} +} // namespace rnexecutorch \ No newline at end of file diff --git a/android/src/main/cpp/ETInstallerModule.h b/android/src/main/cpp/ETInstallerModule.h new file mode 100644 index 00000000..78a95f4f --- /dev/null +++ b/android/src/main/cpp/ETInstallerModule.h @@ -0,0 +1,43 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace rnexecutorch { + +using namespace facebook; +using namespace react; + +class ETInstallerModule : public jni::HybridClass { +public: + static auto constexpr kJavaDescriptor = + "Lcom/swmansion/rnexecutorch/ETInstaller;"; + + static jni::local_ref + initHybrid(jni::alias_ref jThis, jlong jsContext, + jni::alias_ref + jsCallInvokerHolder); + + static void registerNatives(); + + void injectJSIBindings(); + +private: + friend HybridBase; + + jni::global_ref javaPart_; + jsi::Runtime *jsiRuntime_; + std::shared_ptr jsCallInvoker_; + + explicit ETInstallerModule( + jni::alias_ref &jThis, + jsi::Runtime *jsiRuntime, + const std::shared_ptr &jsCallInvoker); +}; + +} // namespace rnexecutorch \ No newline at end of file diff --git a/android/src/main/cpp/OnLoad.cpp b/android/src/main/cpp/OnLoad.cpp new file mode 100644 index 00000000..dedb3871 --- /dev/null +++ b/android/src/main/cpp/OnLoad.cpp @@ -0,0 +1,10 @@ +#include + +#include + +using namespace rnexecutorch; + +JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *) { + return facebook::jni::initialize( + vm, [] { ETInstallerModule::registerNatives(); }); +} \ No newline at end of file diff --git a/android/src/main/java/com/swmansion/rnexecutorch/ETInstaller.kt b/android/src/main/java/com/swmansion/rnexecutorch/ETInstaller.kt new file mode 100644 index 00000000..51d4e9fc --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/ETInstaller.kt @@ -0,0 +1,44 @@ +package com.swmansion.rnexecutorch + +import com.facebook.jni.HybridData +import com.facebook.react.bridge.ReactApplicationContext +import com.facebook.react.bridge.ReactMethod +import com.facebook.react.common.annotations.FrameworkAPI +import com.facebook.react.module.annotations.ReactModule +import com.facebook.react.turbomodule.core.CallInvokerHolderImpl + +@OptIn(FrameworkAPI::class) +@ReactModule(name = ETInstaller.NAME) +class ETInstaller( + reactContext: ReactApplicationContext, +) : NativeETInstallerSpec(reactContext) { + companion object { + const val NAME = NativeETInstallerSpec.NAME + } + + private val mHybridData: HybridData + + external fun initHybrid( + jsContext: Long, + callInvoker: CallInvokerHolderImpl, + ): HybridData + + private external fun injectJSIBindings() + + init { + try { + System.loadLibrary("executorch") + 
System.loadLibrary("react-native-executorch") + val jsCallInvokerHolder = reactContext.jsCallInvokerHolder as CallInvokerHolderImpl + mHybridData = initHybrid(reactContext.javaScriptContextHolder!!.get(), jsCallInvokerHolder) + } catch (exception: UnsatisfiedLinkError) { + throw RuntimeException("Could not load native module Install", exception) + } + } + + @ReactMethod(isBlockingSynchronousMethod = true) + override fun install(): Boolean { + injectJSIBindings() + return true + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt index 861f8ac6..8e0092b3 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt @@ -32,6 +32,8 @@ class RnExecutorchPackage : TurboReactPackage() { VerticalOCR(reactContext) } else if (name == ImageSegmentation.NAME) { ImageSegmentation(reactContext) + } else if (name == ETInstaller.NAME) { + ETInstaller(reactContext) } else if (name == Tokenizer.NAME) { Tokenizer(reactContext) } else if (name == TextEmbeddings.NAME) { @@ -49,6 +51,7 @@ class RnExecutorchPackage : TurboReactPackage() { LLM.NAME, false, // canOverrideExistingModule false, // needsEagerInit + true, // hasConstants false, // isCxxModule true, ) @@ -58,6 +61,7 @@ class RnExecutorchPackage : TurboReactPackage() { ETModule.NAME, false, // canOverrideExistingModule false, // needsEagerInit + true, // hasConstants false, // isCxxModule true, ) @@ -68,6 +72,7 @@ class RnExecutorchPackage : TurboReactPackage() { StyleTransfer.NAME, false, // canOverrideExistingModule false, // needsEagerInit + true, // hasConstants false, // isCxxModule true, ) @@ -78,6 +83,7 @@ class RnExecutorchPackage : TurboReactPackage() { Classification.NAME, false, // canOverrideExistingModule false, // needsEagerInit + true, // hasConstants false, // isCxxModule true, ) @@ -88,6 +94,7 @@ class RnExecutorchPackage : TurboReactPackage() { ObjectDetection.NAME, false, // canOverrideExistingModule false, // needsEagerInit + true, // hasConstants false, // isCxxModule true, ) @@ -98,6 +105,7 @@ class RnExecutorchPackage : TurboReactPackage() { SpeechToText.NAME, false, // canOverrideExistingModule false, // needsEagerInit + true, // hasConstants false, // isCxxModule true, ) @@ -108,6 +116,7 @@ class RnExecutorchPackage : TurboReactPackage() { OCR.NAME, false, // canOverrideExistingModule false, // needsEagerInit + true, // hasConstants false, // isCxxModule true, ) @@ -118,6 +127,7 @@ class RnExecutorchPackage : TurboReactPackage() { VerticalOCR.NAME, false, // canOverrideExistingModule false, // needsEagerInit + true, // hasConstants false, // isCxxModule true, ) @@ -128,6 +138,7 @@ class RnExecutorchPackage : TurboReactPackage() { ImageSegmentation.NAME, false, // canOverrideExistingModule false, // needsEagerInit + true, // hasConstants false, // isCxxModule true, ) @@ -138,6 +149,18 @@ class RnExecutorchPackage : TurboReactPackage() { Tokenizer.NAME, false, // canOverrideExistingModule false, // needsEagerInit + true, // hasConstants + false, // isCxxModule + true, + ) + + moduleInfos[ETInstaller.NAME] = + ReactModuleInfo( + ETInstaller.NAME, + ETInstaller.NAME, + false, // canOverrideExistingModule + false, // needsEagerInit + true, // hasConstants false, // isCxxModule true, ) diff --git a/common/Log.h b/common/Log.h new file mode 100644 index 00000000..d72fae37 --- /dev/null +++ b/common/Log.h @@ -0,0 +1,71 @@ 
+#pragma once + +#include +#include + +#ifdef __ANDROID__ +#include +#endif +#ifdef __APPLE__ +#include +#endif + +namespace rnexecutorch { + +enum class LOG_LEVEL { INFO, ERROR, DEBUG }; + +#ifdef __ANDROID__ +android_LogPriority androidLogLevel(LOG_LEVEL logLevel) { + switch (logLevel) { + case LOG_LEVEL::INFO: + default: + return ANDROID_LOG_INFO; + case LOG_LEVEL::ERROR: + return ANDROID_LOG_ERROR; + case LOG_LEVEL::DEBUG: + return ANDROID_LOG_DEBUG; + } +} +#endif + +// const char* instead of const std::string& as va_start doesn't take references +void log(LOG_LEVEL logLevel, const char *fmt, ...) { + va_list args; + va_start(args, fmt); + + // Maximum length of a log message. + static constexpr size_t kMaxLogMessageLength = 1024; + char buf[kMaxLogMessageLength]; + size_t len = vsnprintf(buf, kMaxLogMessageLength, fmt, args); + if (len >= kMaxLogMessageLength - 1) { + for (std::size_t i = 0; i < 3; ++i) + buf[kMaxLogMessageLength - 2 - i] = '.'; + len = kMaxLogMessageLength - 3; + } + buf[kMaxLogMessageLength - 1] = 0; + +#ifdef __ANDROID__ + + __android_log_print(androidLogLevel(logLevel), "RnExecutorch", "%s", buf); + +#endif // ifdef __ANDROID__ +#ifdef __APPLE__ + + switch (logLevel) { + case LOG_LEVEL::INFO: + default: + os_log_info(OS_LOG_DEFAULT, "%s", buf); + break; + case LOG_LEVEL::ERROR: + os_log_error(OS_LOG_DEFAULT, "%s", buf); + break; + case LOG_LEVEL::DEBUG: + os_log_debug(OS_LOG_DEFAULT, "%s", buf); + break; + } + +#endif // ifdef __APPLE__ + va_end(args); +} + +} // namespace rnexecutorch diff --git a/common/RnExecutorchInstaller.h b/common/RnExecutorchInstaller.h new file mode 100644 index 00000000..f2118fd1 --- /dev/null +++ b/common/RnExecutorchInstaller.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include +#include + +#include + +#include "jsi/JsiPromise.h" + +namespace rnexecutorch { + +using namespace facebook; + +class RnExecutorchInstaller { +public: + static void + injectJSIBindings(jsi::Runtime *jsiRuntime, + const std::shared_ptr &jsCallInvoker) { + // Install JSI methods here + } + +private: +}; + +} // namespace rnexecutorch \ No newline at end of file diff --git a/common/jsi/JsiPromise.cpp b/common/jsi/JsiPromise.cpp new file mode 100644 index 00000000..2043ac92 --- /dev/null +++ b/common/jsi/JsiPromise.cpp @@ -0,0 +1,60 @@ +#include "JsiPromise.h" + +namespace rnexecutorch { + +using namespace facebook; + +jsi::Value PromiseVendor::createPromise( + const std::function)> &function) { + if (runtime_ == nullptr) { + throw std::runtime_error("Runtime was null!"); + } + + auto &runtime = *runtime_; + auto callInvoker = callInvoker_; + + // get Promise constructor + auto promiseCtor = runtime.global().getPropertyAsFunction(runtime, "Promise"); + + // create a "run" function (first Promise arg) + auto runPromise = jsi::Function::createFromHostFunction( + runtime, jsi::PropNameID::forUtf8(runtime, "runPromise"), 2, + [callInvoker, + function](jsi::Runtime &runtime, const jsi::Value &thisValue, + const jsi::Value *arguments, size_t count) -> jsi::Value { + auto resolveLocal = arguments[0].asObject(runtime).asFunction(runtime); + auto resolve = std::make_shared(std::move(resolveLocal)); + auto rejectLocal = arguments[1].asObject(runtime).asFunction(runtime); + auto reject = std::make_shared(std::move(rejectLocal)); + + auto resolveWrapper = + [resolve, &runtime, callInvoker]( + const std::function &resolver) + -> void { + callInvoker->invokeAsync([resolve, &runtime, resolver]() -> void { + auto valueShared = std::make_shared(resolver(runtime)); + + 
resolve->call(runtime, *valueShared); + }); + }; + + auto rejectWrapper = [reject, &runtime, callInvoker]( + const std::string &errorMessage) -> void { + auto error = jsi::JSError(runtime, errorMessage); + auto errorShared = std::make_shared(error); + callInvoker->invokeAsync([reject, &runtime, errorShared]() -> void { + reject->call(runtime, errorShared->value()); + }); + }; + + auto promise = std::make_shared(resolveWrapper, rejectWrapper); + function(promise); + + return jsi::Value::undefined(); + }); + + // return new Promise((resolve, reject) => ...) + return promiseCtor.callAsConstructor(runtime, runPromise); +} + +} // namespace rnexecutorch \ No newline at end of file diff --git a/common/jsi/JsiPromise.h b/common/jsi/JsiPromise.h new file mode 100644 index 00000000..cc04c174 --- /dev/null +++ b/common/jsi/JsiPromise.h @@ -0,0 +1,47 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace rnexecutorch { + +using namespace facebook; + +class Promise { +public: + Promise(std::function)> + resolve, + std::function reject) + : resolve_(std::move(resolve)), reject_(std::move(reject)) {} + + void resolve(const std::function &resolver) { + resolve_(std::forward>( + resolver)); + } + + void reject(const std::string &errorMessage) { reject_(errorMessage); } + +private: + std::function)> resolve_; + std::function reject_; +}; + +class PromiseVendor { +public: + PromiseVendor(jsi::Runtime *runtime, + const std::shared_ptr &callInvoker) + : runtime_(runtime), callInvoker_(callInvoker) {} + + jsi::Value + createPromise(const std::function)> &function); + +private: + jsi::Runtime *runtime_; + std::shared_ptr callInvoker_; +}; + +} // namespace rnexecutorch \ No newline at end of file diff --git a/examples/computer-vision/ios/computervision.xcodeproj/project.pbxproj b/examples/computer-vision/ios/computervision.xcodeproj/project.pbxproj index ace28495..ce46b822 100644 --- a/examples/computer-vision/ios/computervision.xcodeproj/project.pbxproj +++ b/examples/computer-vision/ios/computervision.xcodeproj/project.pbxproj @@ -12,13 +12,15 @@ 13B07FC11A68108700A75B9A /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 13B07FB71A68108700A75B9A /* main.m */; }; 3D44DDE8855509EE8F14BD35 /* PrivacyInfo.xcprivacy in Resources */ = {isa = PBXBuildFile; fileRef = 9F1C1E848042D20F63F5A766 /* PrivacyInfo.xcprivacy */; }; 3E461D99554A48A4959DE609 /* SplashScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = AA286B85B6C04FC6940260E9 /* SplashScreen.storyboard */; }; - 410BB317783F79CD19490DEA /* libPods-computervision.a in Frameworks */ = {isa = PBXBuildFile; fileRef = B4BA30D50A7961C34FDC0CE3 /* libPods-computervision.a */; }; + 8ECE9C761F0C334C5905F3DC /* libPods-computervision.a in Frameworks */ = {isa = PBXBuildFile; fileRef = A883A86E8ACD521325B193BA /* libPods-computervision.a */; }; B18059E884C0ABDD17F3DC3D /* ExpoModulesProvider.swift in Sources */ = {isa = PBXBuildFile; fileRef = FAC715A2D49A985799AEE119 /* ExpoModulesProvider.swift */; }; BB2F792D24A3F905000567C9 /* Expo.plist in Resources */ = {isa = PBXBuildFile; fileRef = BB2F792C24A3F905000567C9 /* Expo.plist */; }; EE159840D482449C972B155B /* noop-file.swift in Sources */ = {isa = PBXBuildFile; fileRef = 9A2BFCDC01274C44ADBAA6A1 /* noop-file.swift */; }; /* End PBXBuildFile section */ /* Begin PBXFileReference section */ + 00D3ACA31040E8F7A3BA0935 /* Pods-computervision.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = 
"Pods-computervision.debug.xcconfig"; path = "Target Support Files/Pods-computervision/Pods-computervision.debug.xcconfig"; sourceTree = ""; }; + 04B98E7657007205BDD2A29A /* Pods-computervision.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-computervision.release.xcconfig"; path = "Target Support Files/Pods-computervision/Pods-computervision.release.xcconfig"; sourceTree = ""; }; 13B07F961A680F5B00A75B9A /* computervision.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = computervision.app; sourceTree = BUILT_PRODUCTS_DIR; }; 13B07FAF1A68108700A75B9A /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = AppDelegate.h; path = computervision/AppDelegate.h; sourceTree = ""; }; 13B07FB01A68108700A75B9A /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; name = AppDelegate.mm; path = computervision/AppDelegate.mm; sourceTree = ""; }; @@ -26,11 +28,9 @@ 13B07FB61A68108700A75B9A /* Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; name = Info.plist; path = computervision/Info.plist; sourceTree = ""; }; 13B07FB71A68108700A75B9A /* main.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; name = main.m; path = computervision/main.m; sourceTree = ""; }; 9A2BFCDC01274C44ADBAA6A1 /* noop-file.swift */ = {isa = PBXFileReference; explicitFileType = undefined; fileEncoding = 4; includeInIndex = 0; lastKnownFileType = sourcecode.swift; name = "noop-file.swift"; path = "computervision/noop-file.swift"; sourceTree = ""; }; - 9A9D9D5D3E179ACB16F7C961 /* Pods-computervision.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-computervision.release.xcconfig"; path = "Target Support Files/Pods-computervision/Pods-computervision.release.xcconfig"; sourceTree = ""; }; 9F1C1E848042D20F63F5A766 /* PrivacyInfo.xcprivacy */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xml; name = PrivacyInfo.xcprivacy; path = computervision/PrivacyInfo.xcprivacy; sourceTree = ""; }; + A883A86E8ACD521325B193BA /* libPods-computervision.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-computervision.a"; sourceTree = BUILT_PRODUCTS_DIR; }; AA286B85B6C04FC6940260E9 /* SplashScreen.storyboard */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.storyboard; name = SplashScreen.storyboard; path = computervision/SplashScreen.storyboard; sourceTree = ""; }; - B34AACF0D072BF11B622DA21 /* Pods-computervision.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-computervision.debug.xcconfig"; path = "Target Support Files/Pods-computervision/Pods-computervision.debug.xcconfig"; sourceTree = ""; }; - B4BA30D50A7961C34FDC0CE3 /* libPods-computervision.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-computervision.a"; sourceTree = BUILT_PRODUCTS_DIR; }; BB2F792C24A3F905000567C9 /* Expo.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = Expo.plist; sourceTree = ""; }; E5089F61F1384BB681122A7F /* computervision-Bridging-Header.h */ = {isa = PBXFileReference; explicitFileType = undefined; fileEncoding = 4; includeInIndex = 0; 
lastKnownFileType = sourcecode.c.h; name = "computervision-Bridging-Header.h"; path = "computervision/computervision-Bridging-Header.h"; sourceTree = ""; }; ED297162215061F000B7C4FE /* JavaScriptCore.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = JavaScriptCore.framework; path = System/Library/Frameworks/JavaScriptCore.framework; sourceTree = SDKROOT; }; @@ -42,7 +42,7 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( - 410BB317783F79CD19490DEA /* libPods-computervision.a in Frameworks */, + 8ECE9C761F0C334C5905F3DC /* libPods-computervision.a in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -70,7 +70,7 @@ isa = PBXGroup; children = ( ED297162215061F000B7C4FE /* JavaScriptCore.framework */, - B4BA30D50A7961C34FDC0CE3 /* libPods-computervision.a */, + A883A86E8ACD521325B193BA /* libPods-computervision.a */, ); name = Frameworks; sourceTree = ""; @@ -125,8 +125,8 @@ D65327D7A22EEC0BE12398D9 /* Pods */ = { isa = PBXGroup; children = ( - B34AACF0D072BF11B622DA21 /* Pods-computervision.debug.xcconfig */, - 9A9D9D5D3E179ACB16F7C961 /* Pods-computervision.release.xcconfig */, + 00D3ACA31040E8F7A3BA0935 /* Pods-computervision.debug.xcconfig */, + 04B98E7657007205BDD2A29A /* Pods-computervision.release.xcconfig */, ); path = Pods; sourceTree = ""; @@ -146,14 +146,14 @@ isa = PBXNativeTarget; buildConfigurationList = 13B07F931A680F5B00A75B9A /* Build configuration list for PBXNativeTarget "computervision" */; buildPhases = ( - 4C3CB1F7F4FC3393B10158A0 /* [CP] Check Pods Manifest.lock */, + A44E0F386BE1861A4BB0D152 /* [CP] Check Pods Manifest.lock */, 40729C76425943C737B83F07 /* [Expo] Configure project */, 13B07F871A680F5B00A75B9A /* Sources */, 13B07F8C1A680F5B00A75B9A /* Frameworks */, 13B07F8E1A680F5B00A75B9A /* Resources */, 00DD1BFF1BD5951E006B06BC /* Bundle React Native code and images */, - 00EA61A0AC7C7A004831148A /* [CP] Embed Pods Frameworks */, - 024B2F8B6E2DEF04825F05DE /* [CP] Copy Pods Resources */, + 6B22C61BD8A720CC7014EB81 /* [CP] Embed Pods Frameworks */, + D6A2310306AE2EE1E906EC1D /* [CP] Copy Pods Resources */, ); buildRules = ( ); @@ -225,7 +225,26 @@ shellPath = /bin/sh; shellScript = "if [[ -f \"$PODS_ROOT/../.xcode.env\" ]]; then\n source \"$PODS_ROOT/../.xcode.env\"\nfi\nif [[ -f \"$PODS_ROOT/../.xcode.env.local\" ]]; then\n source \"$PODS_ROOT/../.xcode.env.local\"\nfi\n\n# The project root by default is one level up from the ios directory\nexport PROJECT_ROOT=\"$PROJECT_DIR\"/..\n\nif [[ \"$CONFIGURATION\" = *Debug* ]]; then\n export SKIP_BUNDLING=1\nfi\nif [[ -z \"$ENTRY_FILE\" ]]; then\n # Set the entry JS file using the bundler's entry resolution.\n export ENTRY_FILE=\"$(\"$NODE_BINARY\" -e \"require('expo/scripts/resolveAppEntry')\" \"$PROJECT_ROOT\" ios absolute | tail -n 1)\"\nfi\n\nif [[ -z \"$CLI_PATH\" ]]; then\n # Use Expo CLI\n export CLI_PATH=\"$(\"$NODE_BINARY\" --print \"require.resolve('@expo/cli', { paths: [require.resolve('expo/package.json')] })\")\"\nfi\nif [[ -z \"$BUNDLE_COMMAND\" ]]; then\n # Default Expo CLI command for bundling\n export BUNDLE_COMMAND=\"export:embed\"\nfi\n\n# Source .xcode.env.updates if it exists to allow\n# SKIP_BUNDLING to be unset if needed\nif [[ -f \"$PODS_ROOT/../.xcode.env.updates\" ]]; then\n source \"$PODS_ROOT/../.xcode.env.updates\"\nfi\n# Source local changes to allow overrides\n# if needed\nif [[ -f \"$PODS_ROOT/../.xcode.env.local\" ]]; then\n source \"$PODS_ROOT/../.xcode.env.local\"\nfi\n\n`\"$NODE_BINARY\" --print 
\"require('path').dirname(require.resolve('react-native/package.json')) + '/scripts/react-native-xcode.sh'\"`\n\n"; }; - 00EA61A0AC7C7A004831148A /* [CP] Embed Pods Frameworks */ = { + 40729C76425943C737B83F07 /* [Expo] Configure project */ = { + isa = PBXShellScriptBuildPhase; + alwaysOutOfDate = 1; + buildActionMask = 2147483647; + files = ( + ); + inputFileListPaths = ( + ); + inputPaths = ( + ); + name = "[Expo] Configure project"; + outputFileListPaths = ( + ); + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "# This script configures Expo modules and generates the modules provider file.\nbash -l -c \"./Pods/Target\\ Support\\ Files/Pods-computervision/expo-configure-project.sh\"\n"; + }; + 6B22C61BD8A720CC7014EB81 /* [CP] Embed Pods Frameworks */ = { isa = PBXShellScriptBuildPhase; buildActionMask = 2147483647; files = ( @@ -245,7 +264,29 @@ shellScript = "\"${PODS_ROOT}/Target Support Files/Pods-computervision/Pods-computervision-frameworks.sh\"\n"; showEnvVarsInLog = 0; }; - 024B2F8B6E2DEF04825F05DE /* [CP] Copy Pods Resources */ = { + A44E0F386BE1861A4BB0D152 /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputFileListPaths = ( + ); + inputPaths = ( + "${PODS_PODFILE_DIR_PATH}/Podfile.lock", + "${PODS_ROOT}/Manifest.lock", + ); + name = "[CP] Check Pods Manifest.lock"; + outputFileListPaths = ( + ); + outputPaths = ( + "$(DERIVED_FILE_DIR)/Pods-computervision-checkManifestLockResult.txt", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. 
Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; + showEnvVarsInLog = 0; + }; + D6A2310306AE2EE1E906EC1D /* [CP] Copy Pods Resources */ = { isa = PBXShellScriptBuildPhase; buildActionMask = 2147483647; files = ( @@ -279,47 +320,6 @@ shellScript = "\"${PODS_ROOT}/Target Support Files/Pods-computervision/Pods-computervision-resources.sh\"\n"; showEnvVarsInLog = 0; }; - 40729C76425943C737B83F07 /* [Expo] Configure project */ = { - isa = PBXShellScriptBuildPhase; - alwaysOutOfDate = 1; - buildActionMask = 2147483647; - files = ( - ); - inputFileListPaths = ( - ); - inputPaths = ( - ); - name = "[Expo] Configure project"; - outputFileListPaths = ( - ); - outputPaths = ( - ); - runOnlyForDeploymentPostprocessing = 0; - shellPath = /bin/sh; - shellScript = "# This script configures Expo modules and generates the modules provider file.\nbash -l -c \"./Pods/Target\\ Support\\ Files/Pods-computervision/expo-configure-project.sh\"\n"; - }; - 4C3CB1F7F4FC3393B10158A0 /* [CP] Check Pods Manifest.lock */ = { - isa = PBXShellScriptBuildPhase; - buildActionMask = 2147483647; - files = ( - ); - inputFileListPaths = ( - ); - inputPaths = ( - "${PODS_PODFILE_DIR_PATH}/Podfile.lock", - "${PODS_ROOT}/Manifest.lock", - ); - name = "[CP] Check Pods Manifest.lock"; - outputFileListPaths = ( - ); - outputPaths = ( - "$(DERIVED_FILE_DIR)/Pods-computervision-checkManifestLockResult.txt", - ); - runOnlyForDeploymentPostprocessing = 0; - shellPath = /bin/sh; - shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. 
Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; - showEnvVarsInLog = 0; - }; /* End PBXShellScriptBuildPhase section */ /* Begin PBXSourcesBuildPhase section */ @@ -339,7 +339,7 @@ /* Begin XCBuildConfiguration section */ 13B07F941A680F5B00A75B9A /* Debug */ = { isa = XCBuildConfiguration; - baseConfigurationReference = B34AACF0D072BF11B622DA21 /* Pods-computervision.debug.xcconfig */; + baseConfigurationReference = 00D3ACA31040E8F7A3BA0935 /* Pods-computervision.debug.xcconfig */; buildSettings = { ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; CLANG_ENABLE_MODULES = YES; @@ -376,7 +376,7 @@ }; 13B07F951A680F5B00A75B9A /* Release */ = { isa = XCBuildConfiguration; - baseConfigurationReference = 9A9D9D5D3E179ACB16F7C961 /* Pods-computervision.release.xcconfig */; + baseConfigurationReference = 04B98E7657007205BDD2A29A /* Pods-computervision.release.xcconfig */; buildSettings = { ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; CLANG_ENABLE_MODULES = YES; diff --git a/ios/RnExecutorch.xcodeproj/project.pbxproj b/ios/RnExecutorch.xcodeproj/project.pbxproj index 68e367a8..f95bb5ca 100644 --- a/ios/RnExecutorch.xcodeproj/project.pbxproj +++ b/ios/RnExecutorch.xcodeproj/project.pbxproj @@ -8,6 +8,7 @@ /* Begin PBXBuildFile section */ 55D6EA8C2D0987D2009BA408 /* ExecutorchLib.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 55D6EA8B2D0987D2009BA408 /* ExecutorchLib.xcframework */; }; + 8C53B8782D96BFCD0097900E /* JsiPromise.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 8C53B8702D96BFCD0097900E /* JsiPromise.cpp */; }; /* End PBXBuildFile section */ /* Begin PBXCopyFilesBuildPhase section */ @@ -25,6 +26,10 @@ /* Begin PBXFileReference section */ 550986892CEF541900FECBB8 /* libRnExecutorch.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = libRnExecutorch.a; sourceTree = BUILT_PRODUCTS_DIR; }; 55D6EA8B2D0987D2009BA408 /* ExecutorchLib.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; path = ExecutorchLib.xcframework; sourceTree = ""; }; + 8C53B86F2D96BFCD0097900E /* JsiPromise.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = JsiPromise.h; sourceTree = ""; }; + 8C53B8702D96BFCD0097900E /* JsiPromise.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = JsiPromise.cpp; sourceTree = ""; }; + 8C53B8752D96BFCD0097900E /* RnExecutorchInstaller.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = RnExecutorchInstaller.h; sourceTree = ""; }; + 8CB19AA72DA3D1A200EB6786 /* Log.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = Log.h; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFileSystemSynchronizedGroupBuildPhaseMembershipExceptionSet section */ @@ -71,6 +76,7 @@ 550986802CEF541900FECBB8 = { isa = PBXGroup; children = ( + 8C53B8762D96BFCD0097900E /* common */, 5509868B2CEF541900FECBB8 /* RnExecutorch */, 55D6EA8A2D0987D2009BA408 /* Frameworks */, 5509868A2CEF541900FECBB8 /* Products */, @@ -93,6 +99,26 @@ name = Frameworks; sourceTree = ""; }; + 8C53B8712D96BFCD0097900E /* jsi */ = { + isa = PBXGroup; + children = ( + 8C53B86F2D96BFCD0097900E /* JsiPromise.h */, + 8C53B8702D96BFCD0097900E /* JsiPromise.cpp */, + ); + path = jsi; + sourceTree = ""; + }; + 8C53B8762D96BFCD0097900E /* common */ = { + isa = PBXGroup; + children = ( + 
8CB19AA72DA3D1A200EB6786 /* Log.h */, + 8C53B8712D96BFCD0097900E /* jsi */, + 8C53B8752D96BFCD0097900E /* RnExecutorchInstaller.h */, + ); + name = common; + path = ../common; + sourceTree = SOURCE_ROOT; + }; /* End PBXGroup section */ /* Begin PBXNativeTarget section */ @@ -155,6 +181,7 @@ isa = PBXSourcesBuildPhase; buildActionMask = 2147483647; files = ( + 8C53B8782D96BFCD0097900E /* JsiPromise.cpp in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; diff --git a/ios/RnExecutorch/ETInstaller.h b/ios/RnExecutorch/ETInstaller.h new file mode 100644 index 00000000..d01236e9 --- /dev/null +++ b/ios/RnExecutorch/ETInstaller.h @@ -0,0 +1,8 @@ +#import +#import +#import + +@interface ETInstaller + : RCTEventEmitter + +@end diff --git a/ios/RnExecutorch/ETInstaller.mm b/ios/RnExecutorch/ETInstaller.mm new file mode 100644 index 00000000..b7a079c3 --- /dev/null +++ b/ios/RnExecutorch/ETInstaller.mm @@ -0,0 +1,40 @@ +#import "ETInstaller.h" + +#import + +#import +#import +#include + +using namespace facebook::react; + +@interface RCTBridge (JSIRuntime) +- (void *)runtime; +@end + +@implementation ETInstaller + +@synthesize callInvoker = _callInvoker; + +RCT_EXPORT_MODULE(ETInstaller); + +RCT_EXPORT_BLOCKING_SYNCHRONOUS_METHOD(install) { + auto jsiRuntime = + reinterpret_cast(self.bridge.runtime); + auto jsCallInvoker = _callInvoker.callInvoker; + + assert(jsiRuntime != nullptr); + + rnexecutorch::RnExecutorchInstaller::injectJSIBindings(jsiRuntime, + jsCallInvoker); + + NSLog(@"Successfully installed JSI bindings for react-native-executorch!"); + return @true; +} + +- (std::shared_ptr)getTurboModule: + (const facebook::react::ObjCTurboModule::InitParams &)params { + return std::make_shared(params); +} + +@end diff --git a/ios/libs/libbackend_coreml_ios.a b/ios/libs/libbackend_coreml_ios.a new file mode 100644 index 00000000..933be4c4 Binary files /dev/null and b/ios/libs/libbackend_coreml_ios.a differ diff --git a/ios/libs/libbackend_coreml_simulator.a b/ios/libs/libbackend_coreml_simulator.a new file mode 100644 index 00000000..8bf7ba5a Binary files /dev/null and b/ios/libs/libbackend_coreml_simulator.a differ diff --git a/ios/libs/libbackend_mps_ios.a b/ios/libs/libbackend_mps_ios.a new file mode 100644 index 00000000..71e56e53 Binary files /dev/null and b/ios/libs/libbackend_mps_ios.a differ diff --git a/ios/libs/libbackend_mps_simulator.a b/ios/libs/libbackend_mps_simulator.a new file mode 100644 index 00000000..0b1ea204 Binary files /dev/null and b/ios/libs/libbackend_mps_simulator.a differ diff --git a/ios/libs/libbackend_xnnpack_ios.a b/ios/libs/libbackend_xnnpack_ios.a new file mode 100644 index 00000000..203be474 Binary files /dev/null and b/ios/libs/libbackend_xnnpack_ios.a differ diff --git a/ios/libs/libbackend_xnnpack_simulator.a b/ios/libs/libbackend_xnnpack_simulator.a new file mode 100644 index 00000000..0ab42898 Binary files /dev/null and b/ios/libs/libbackend_xnnpack_simulator.a differ diff --git a/ios/libs/libexecutorch_ios.a b/ios/libs/libexecutorch_ios.a new file mode 100644 index 00000000..757a0da5 Binary files /dev/null and b/ios/libs/libexecutorch_ios.a differ diff --git a/ios/libs/libexecutorch_simulator.a b/ios/libs/libexecutorch_simulator.a new file mode 100644 index 00000000..5c72e3dd Binary files /dev/null and b/ios/libs/libexecutorch_simulator.a differ diff --git a/ios/libs/libkernels_custom_ios.a b/ios/libs/libkernels_custom_ios.a new file mode 100644 index 00000000..19dc7d28 Binary files /dev/null and b/ios/libs/libkernels_custom_ios.a differ diff --git 
a/ios/libs/libkernels_custom_simulator.a b/ios/libs/libkernels_custom_simulator.a new file mode 100644 index 00000000..3f1a8676 Binary files /dev/null and b/ios/libs/libkernels_custom_simulator.a differ diff --git a/ios/libs/libkernels_optimized_ios.a b/ios/libs/libkernels_optimized_ios.a new file mode 100644 index 00000000..1bed14fa Binary files /dev/null and b/ios/libs/libkernels_optimized_ios.a differ diff --git a/ios/libs/libkernels_optimized_simulator.a b/ios/libs/libkernels_optimized_simulator.a new file mode 100644 index 00000000..a62462cc Binary files /dev/null and b/ios/libs/libkernels_optimized_simulator.a differ diff --git a/ios/libs/libkernels_portable_ios.a b/ios/libs/libkernels_portable_ios.a new file mode 100644 index 00000000..d7f46dcd Binary files /dev/null and b/ios/libs/libkernels_portable_ios.a differ diff --git a/ios/libs/libkernels_portable_simulator.a b/ios/libs/libkernels_portable_simulator.a new file mode 100644 index 00000000..2347a872 Binary files /dev/null and b/ios/libs/libkernels_portable_simulator.a differ diff --git a/ios/libs/libkernels_quantized_ios.a b/ios/libs/libkernels_quantized_ios.a new file mode 100644 index 00000000..d86749aa Binary files /dev/null and b/ios/libs/libkernels_quantized_ios.a differ diff --git a/ios/libs/libkernels_quantized_simulator.a b/ios/libs/libkernels_quantized_simulator.a new file mode 100644 index 00000000..2fd5b69b Binary files /dev/null and b/ios/libs/libkernels_quantized_simulator.a differ diff --git a/package.json b/package.json index 5383254f..815ab504 100644 --- a/package.json +++ b/package.json @@ -13,7 +13,10 @@ "android", "ios", "cpp", + "common", "*.podspec", + "third-party/include", + "third-party/android/libs", "!ios/build", "!android/build", "!android/gradle", diff --git a/react-native-executorch.podspec b/react-native-executorch.podspec index 3b960600..fe3a7094 100644 --- a/react-native-executorch.podspec +++ b/react-native-executorch.podspec @@ -13,10 +13,50 @@ Pod::Spec.new do |s| s.platforms = { :ios => min_ios_version_supported } s.source = { :git => "https://github.com/NorbertKlockiewicz/react-native-executorch.git", :tag => "#{s.version}" } - s.ios.vendored_frameworks = "ios/ExecutorchLib.xcframework" - s.source_files = "ios/**/*.{h,m,mm}" + s.user_target_xcconfig = { + "HEADER_SEARCH_PATHS" => "$(PODS_TARGET_SRCROOT)/third-party/include", + "OTHER_LDFLAGS[sdk=iphoneos*][arch=*]" => [ + '$(inherited)', + '-framework "CoreML"', + '-framework "Accelerate"', + '-framework "Metal"', + '-framework "MetalPerformanceShaders"', + '-framework "MetalPerformanceShadersGraph"', + '-force_load "$(PODS_ROOT)/../../node_modules/react-native-executorch/ios/libs/libbackend_coreml_ios.a"', + '-force_load "$(PODS_ROOT)/../../node_modules/react-native-executorch/ios/libs/libbackend_mps_ios.a"', + '-force_load "$(PODS_ROOT)/../../node_modules/react-native-executorch/ios/libs/libbackend_xnnpack_ios.a"', + '-force_load "$(PODS_ROOT)/../../node_modules/react-native-executorch/ios/libs/libexecutorch_ios.a"', + '-force_load "$(PODS_ROOT)/../../node_modules/react-native-executorch/ios/libs/libkernels_custom_ios.a"', + '-force_load "$(PODS_ROOT)/../../node_modules/react-native-executorch/ios/libs/libkernels_optimized_ios.a"', + '-force_load "$(PODS_ROOT)/../../node_modules/react-native-executorch/ios/libs/libkernels_quantized_ios.a"' + ].join(' '), + + "OTHER_LDFLAGS[sdk=iphonesimulator*][arch=*]" => [ + '$(inherited)', + '-framework "CoreML"', + '-framework "Accelerate"', + '-framework "Metal"', + '-framework 
"MetalPerformanceShaders"', + '-framework "MetalPerformanceShadersGraph"', + '-force_load "$(PODS_ROOT)/../../node_modules/react-native-executorch/ios/libs/libbackend_coreml_simulator.a"', + '-force_load "$(PODS_ROOT)/../../node_modules/react-native-executorch/ios/libs/libbackend_mps_simulator.a"', + '-force_load "$(PODS_ROOT)/../../node_modules/react-native-executorch/ios/libs/libbackend_xnnpack_simulator.a"', + '-force_load "$(PODS_ROOT)/../../node_modules/react-native-executorch/ios/libs/libexecutorch_simulator.a"', + '-force_load "$(PODS_ROOT)/../../node_modules/react-native-executorch/ios/libs/libkernels_custom_simulator.a"', + '-force_load "$(PODS_ROOT)/../../node_modules/react-native-executorch/ios/libs/libkernels_optimized_simulator.a"', + '-force_load "$(PODS_ROOT)/../../node_modules/react-native-executorch/ios/libs/libkernels_quantized_simulator.a"' + ].join(' ') + } + + s.pod_target_xcconfig = { + "HEADER_SEARCH_PATHS" => "$(PODS_TARGET_SRCROOT)/third-party/include" + } + s.ios.vendored_frameworks = "ios/ExecutorchLib.xcframework" + s.source_files = "ios/**/*.{h,m,mm}", "common/**/*.{hpp,cpp,c,h}" + s.dependency "opencv-rne", "~> 0.1.0" + s.dependency "sqlite3" install_modules_dependencies(s) end \ No newline at end of file diff --git a/src/index.tsx b/src/index.tsx index ebd75b7b..b69a7741 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -1,5 +1,27 @@ import { SpeechToTextLanguage } from './types/stt'; +import { ETInstallerNativeModule } from './native/RnExecutorchModules'; + +// In the future install pick a symbol to check for to avoid installing multiple times +/* +// eslint-disable no-var +declare global { + var exampleGlobalFunction: () => void; +} +// eslint-disable no-var + +if (global.exampleGlobalFunction == null) { + if (!ETInstallerNativeModule) { + throw new Error( + `Failed to install react-native-executorch: The native module could not be found.` + ); + } + + ETInstallerNativeModule.install(); +} +*/ +ETInstallerNativeModule.install(); + // hooks export * from './hooks/computer_vision/useClassification'; export * from './hooks/computer_vision/useObjectDetection'; diff --git a/src/native/NativeETInstaller.ts b/src/native/NativeETInstaller.ts new file mode 100644 index 00000000..a163524d --- /dev/null +++ b/src/native/NativeETInstaller.ts @@ -0,0 +1,8 @@ +import type { TurboModule } from 'react-native'; +import { TurboModuleRegistry } from 'react-native'; + +export interface Spec extends TurboModule { + install(): boolean; +} + +export default TurboModuleRegistry.get('ETInstaller'); diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts index 90b0cf51..895c0a7d 100644 --- a/src/native/RnExecutorchModules.ts +++ b/src/native/RnExecutorchModules.ts @@ -10,6 +10,7 @@ import { Spec as TextEmbeddingsInterface } from './NativeTextEmbeddings'; import { Spec as LLMInterface } from './NativeLLM'; import { Spec as ClassificationInterface } from './NativeClassification'; import { Spec as TokenizerInterface } from './NativeTokenizer'; +import { Spec as ETInstallerInterface } from './NativeETInstaller'; const LINKING_ERROR = `The package 'react-native-executorch' doesn't seem to be linked. 
Make sure: \n\n` + @@ -56,6 +57,8 @@ const TokenizerNativeModule: TokenizerInterface = returnSpecOrThrowLinkingError( ); const TextEmbeddingsNativeModule: TextEmbeddingsInterface = returnSpecOrThrowLinkingError(require('./NativeTextEmbeddings').default); +const ETInstallerNativeModule: ETInstallerInterface = + returnSpecOrThrowLinkingError(require('./NativeETInstaller').default); export { LLMNativeModule, @@ -69,4 +72,5 @@ export { VerticalOCRNativeModule, TextEmbeddingsNativeModule, TokenizerNativeModule, + ETInstallerNativeModule }; diff --git a/third-party/android/libs/arm64-v8a/libexecutorch.so b/third-party/android/libs/arm64-v8a/libexecutorch.so new file mode 100644 index 00000000..8e937625 Binary files /dev/null and b/third-party/android/libs/arm64-v8a/libexecutorch.so differ diff --git a/third-party/android/libs/x86_64/libexecutorch.so b/third-party/android/libs/x86_64/libexecutorch.so new file mode 100644 index 00000000..fd474f30 Binary files /dev/null and b/third-party/android/libs/x86_64/libexecutorch.so differ diff --git a/third-party/include/executorch/ExecuTorch.h b/third-party/include/executorch/ExecuTorch.h new file mode 100644 index 00000000..e1643971 --- /dev/null +++ b/third-party/include/executorch/ExecuTorch.h @@ -0,0 +1,9 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#import "ExecuTorchLog.h" diff --git a/third-party/include/executorch/ExecuTorchLog.h b/third-party/include/executorch/ExecuTorchLog.h new file mode 100644 index 00000000..a71591c7 --- /dev/null +++ b/third-party/include/executorch/ExecuTorchLog.h @@ -0,0 +1,76 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * Defines log levels with specific character codes representing each level. + */ +typedef NS_ENUM(NSInteger, ExecuTorchLogLevel) { + ExecuTorchLogLevelDebug = 'D', + ExecuTorchLogLevelInfo = 'I', + ExecuTorchLogLevelError = 'E', + ExecuTorchLogLevelFatal = 'F', + ExecuTorchLogLevelUnknown = '?' +} NS_SWIFT_NAME(LogLevel); + +/** + * A protocol defining the requirements for a log sink to receive log messages. + */ +NS_SWIFT_NAME(LogSink) +@protocol ExecuTorchLogSink + +/** + * Logs a message with the specified additional info. + * + * @param level The log level of the message. + * @param timestamp The timestamp of the log message since ExecuTorch PAL start. + * @param filename The name of the file generating the log message. + * @param line The line number in the file where the log message was generated. + * @param message The log message text. + */ +- (void)logWithLevel:(ExecuTorchLogLevel)level + timestamp:(NSTimeInterval)timestamp + filename:(NSString *)filename + line:(NSUInteger)line + message:(NSString *)message + NS_SWIFT_NAME(log(level:timestamp:filename:line:message:)); + +@end + +/** + * A singleton class for managing log sinks and dispatching log messages. + */ +NS_SWIFT_NAME(Log) +@interface ExecuTorchLog : NSObject + +/// The shared singleton log instance. +@property(class, readonly) ExecuTorchLog *sharedLog; + +/** + * Adds a log sink to receive log messages. + * + * @param sink The log sink to add. 
+ */ +- (void)addSink:(id)sink NS_SWIFT_NAME(add(sink:)); + +/** + * Removes a previously added log sink. + * + * @param sink The log sink to remove. + */ +- (void)removeSink:(id)sink NS_SWIFT_NAME(remove(sink:)); + ++ (instancetype)new NS_UNAVAILABLE; +- (instancetype)init NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/third-party/include/executorch/extension/module/module.h b/third-party/include/executorch/extension/module/module.h new file mode 100644 index 00000000..5fb2723c --- /dev/null +++ b/third-party/include/executorch/extension/module/module.h @@ -0,0 +1,447 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace executorch { +namespace extension { + +/** + * A facade class for loading programs and executing methods within them. + */ +class Module { +public: + /** + * Enum to define loading behavior. + */ + enum class LoadMode { + /// Load the whole file as a buffer. + File, + /// Use mmap to load pages into memory. + Mmap, + /// Use memory locking and handle errors. + MmapUseMlock, + /// Use memory locking and ignore errors. + MmapUseMlockIgnoreErrors, + }; + + /** + * Constructs an instance by loading a program from a file with specified + * memory locking behavior. + * + * @param[in] file_path The path to the ExecuTorch program file to load. + * @param[in] load_mode The loading mode to use. + * @param[in] event_tracer A EventTracer used for tracking and logging events. + */ + explicit Module(const std::string &file_path, + const LoadMode load_mode = LoadMode::MmapUseMlock, + std::unique_ptr event_tracer = nullptr); + + /** + * Constructs an instance with the provided data loader and memory allocator. + * + * @param[in] data_loader A DataLoader used for loading program data. + * @param[in] memory_allocator A MemoryAllocator used for memory management. + * @param[in] temp_allocator A MemoryAllocator to use when allocating + * temporary data during kernel or delegate execution. + * @param[in] event_tracer A EventTracer used for tracking and logging events. + */ + explicit Module( + std::unique_ptr data_loader, + std::unique_ptr memory_allocator = nullptr, + std::unique_ptr temp_allocator = nullptr, + std::unique_ptr event_tracer = nullptr); + + /** + * Constructs an instance using an existing shared program. + * + * @param[in] program The shared program to use. It's required the data loader + * the program uses is valid for the lifetime of the program. + * @param[in] memory_allocator A MemoryAllocator used for memory management. + * @param[in] temp_allocator A MemoryAllocator to use when allocating + * temporary data. + * @param[in] event_tracer A EventTracer used for tracking and logging events. + */ + explicit Module( + std::shared_ptr program, + std::unique_ptr memory_allocator = nullptr, + std::unique_ptr temp_allocator = nullptr, + std::unique_ptr event_tracer = nullptr); + + Module(const Module &) = delete; + Module &operator=(const Module &) = delete; + Module(Module &&) = delete; + Module &operator=(Module &&) = delete; + + /** + * Loads the program if needed. + * + * @param[in] verification The type of verification to do before returning + * success. + * + * @returns An Error to indicate success or failure of the loading process. 
+ */ + ET_NODISCARD + runtime::Error load(const runtime::Program::Verification verification = + runtime::Program::Verification::Minimal); + + /** + * Checks if the program is loaded. + * + * @returns true if the program is loaded, false otherwise. + */ + inline bool is_loaded() const { return program_ != nullptr; } + + /** + * Get the program. The data loader used by the program is guaranteed to be + * valid for the lifetime of the program. + * + * @returns Shared pointer to the program or nullptr if it's not yet loaded. + */ + inline std::shared_ptr program() const { return program_; } + + /** + * Get a list of method names available in the loaded program. + * Loads the program and method if needed. + * + * @returns A set of strings containing the names of the methods, or an error + * if the program or method failed to load. + */ + runtime::Result> method_names(); + + /** + * Load a specific method from the program and set up memory management if + * needed. The loaded method is cached to reuse the next time it's executed. + * + * @param[in] method_name The name of the method to load. + * @param[in] event_tracer Per-method event tracer to profile/trace methods + * individually. When not given, the event tracer passed to the Module + * constructor is used. Otherwise, this per-method event tracer takes + * precedence. + * + * @returns An Error to indicate success or failure. + */ + ET_NODISCARD + runtime::Error + load_method(const std::string &method_name, + torch::executor::EventTracer *event_tracer = nullptr); + + /** + * Load the 'forward' method from the program and set up memory management if + * needed. The loaded method is cached to reuse the next time it's executed. + * + * @param[in] event_tracer An event tracer used for tracking and logging + * events. + * + * @returns An Error to indicate success or failure. + */ + ET_NODISCARD inline runtime::Error + load_forward(torch::executor::EventTracer *event_tracer = nullptr) { + return load_method("forward", event_tracer); + } + + /** + * Checks if a specific method is loaded. + * + * @param[in] method_name The name of the method to check. + * + * @returns true if the method specified by method_name is loaded, false + * otherwise. + */ + inline bool is_method_loaded(const std::string &method_name) const { + return methods_.count(method_name); + } + + /** + * Get a method metadata struct by method name. + * Loads the program and method if needed. + * + * @param[in] method_name The name of the method to get the metadata for. + * + * @returns A method metadata, or an error if the program or method failed to + * load. + */ + runtime::Result + method_meta(const std::string &method_name); + + /** + * Execute a specific method with the given input values and retrieve the + * output values. Loads the program and method before executing if needed. + * + * @param[in] method_name The name of the method to execute. + * @param[in] input_values A vector of input values to be passed to the + * method. + * + * @returns A Result object containing either a vector of output values + * from the method or an error to indicate failure. + */ + ET_NODISCARD + runtime::Result> + execute(const std::string &method_name, + const std::vector &input_values); + + /** + * Execute a specific method with a single input value. + * Loads the program and method before executing if needed. + * + * @param[in] method_name The name of the method to execute. + * @param[in] input_value A value to be passed to the method. 
+ * + * @returns A Result object containing either a vector of output values + * from the method or an error to indicate failure. + */ + ET_NODISCARD inline runtime::Result> + execute(const std::string &method_name, const runtime::EValue &input_value) { + return execute(method_name, std::vector{input_value}); + } + + /** + * Execute a specific method without any input values. + * Loads the program and method before executing if needed. + * + * @param[in] method_name The name of the method to execute. + * + * @returns A Result object containing either a vector of output values + * from the method or an error to indicate failure. + */ + ET_NODISCARD inline runtime::Result> + execute(const std::string &method_name) { + return execute(method_name, std::vector{}); + } + + /** + * Retrieve the output value of a specific method with the given input values. + * Loads the program and method before execution if needed. + * + * @param[in] method_name The name of the method to execute. + * @param[in] input_values A vector of input values to be passed to the + * method. + * + * @returns A Result object containing either the first output value from the + * method or an error to indicate failure. + */ + ET_NODISCARD inline runtime::Result + get(const std::string &method_name, + const std::vector &input_values) { + auto result = ET_UNWRAP(execute(method_name, input_values)); + if (result.empty()) { + return runtime::Error::InvalidArgument; + } + return result[0]; + } + + /** + * Retrieve the output value of a specific method with a single input value. + * Loads the program and method before execution if needed. + * + * @param[in] method_name The name of the method to execute. + * @param[in] input_value A value to be passed to the method. + * + * @returns A Result object containing either the first output value from the + * method or an error to indicate failure. + */ + ET_NODISCARD inline runtime::Result + get(const std::string &method_name, const runtime::EValue &input_value) { + return get(method_name, std::vector{input_value}); + } + + /** + * Retrieve the output value of a specific method without any input values. + * Loads the program and method before execution if needed. + * + * @param[in] method_name The name of the method to execute. + * + * @returns A Result object containing either the first output value from the + * method or an error to indicate failure. + */ + ET_NODISCARD inline runtime::Result + get(const std::string &method_name) { + return get(method_name, std::vector{}); + } + + /** + * Execute the 'forward' method with the given input values and retrieve the + * output values. Loads the program and method before executing if needed. + * + * @param[in] input_values A vector of input values for the 'forward' method. + * + * @returns A Result object containing either a vector of output values + * from the 'forward' method or an error to indicate failure. + */ + ET_NODISCARD inline runtime::Result> + forward(const std::vector &input_values) { + return execute("forward", input_values); + } + + /** + * Execute the 'forward' method with a single value. + * Loads the program and method before executing if needed. + * + * @param[in] input_value A value for the 'forward' method. + * + * @returns A Result object containing either a vector of output values + * from the 'forward' method or an error to indicate failure. 
+ */ + ET_NODISCARD inline runtime::Result> + forward(const runtime::EValue &input_value) { + return forward(std::vector{input_value}); + } + + /** + * Execute the 'forward' method without any input values. + * Loads the program and method before executing if needed. + * + * @returns A Result object containing either a vector of output values + * from the 'forward' method or an error to indicate failure. + */ + ET_NODISCARD inline runtime::Result> forward() { + return forward(std::vector{}); + } + + /** + * Sets a single input value for a specific method. + * + * @param[in] method_name The name of the method. + * @param[in] input_value The EValue to set as the method input. + * @param[in] input_index Zero-based index of the input to set. + * + * @returns An Error to indicate success or failure. + */ + ET_NODISCARD + runtime::Error set_input(const std::string &method_name, + const runtime::EValue &input_value, + size_t input_index); + + /** + * Sets a single input value for the "forward" method. + * + * @param[in] input_value The EValue to set as the method input. + * @param[in] input_index Zero-based index of the input to set. + * + * @returns An Error to indicate success or failure. + */ + ET_NODISCARD + inline runtime::Error set_input(const runtime::EValue &input_value, + size_t input_index) { + return set_input("forward", input_value, input_index); + } + + /** + * Sets all input values for a specific method. + * + * @param[in] method_name The name of the method. + * @param[in] input_values A vector of EValues to set as the method inputs. + * + * @returns An Error to indicate success or failure. + */ + ET_NODISCARD + runtime::Error set_inputs(const std::string &method_name, + const std::vector &input_values); + + /** + * Sets all input values for the "forward" method. + * + * @param[in] input_values A vector of EValues to set as the method inputs. + * + * @returns An Error to indicate success or failure. + */ + ET_NODISCARD + inline runtime::Error + set_inputs(const std::vector &input_values) { + return set_inputs("forward", input_values); + } + + /** + * Sets the output tensor for a specific method. + * + * @param[in] method_name The name of the method. + * @param[in] output_value The EValue containing the Tensor to set as the + * method output. + * @param[in] output_index Zero-based index of the output to set. + * + * @returns An Error to indicate success or failure. + * + * @note Only Tensor outputs are currently supported for setting. + */ + ET_NODISCARD + runtime::Error set_output(const std::string &method_name, + runtime::EValue output_value, + size_t output_index = 0); + + /** + * Sets the output tensor for the "forward" method. + * + * @param[in] output_value The EValue containing the Tensor to set as the + * method output. + * @param[in] output_index Zero-based index of the output to set. + * + * @returns An Error to indicate success or failure. + * + * @note Only Tensor outputs are currently supported for setting. + */ + ET_NODISCARD + inline runtime::Error set_output(runtime::EValue output_value, + size_t output_index = 0) { + return set_output("forward", std::move(output_value), output_index); + } + + /** + * Retrieves the EventTracer instance being used by the Module. + * EventTracer is used for tracking and logging events during the execution + * of methods. + * + * @returns A pointer to the EventTracer instance. Returns nullptr if no + * EventTracer is set. 
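+ *
+ * Sketch of the setter APIs documented above (assumes a hypothetical
+ * "model.pte" and a preallocated output tensor; each setter returns a
+ * runtime::Error that real code should check):
+ *
+ *   executorch::extension::Module module("model.pte");
+ *   auto input = executorch::extension::ones({1, 4});
+ *   auto output = executorch::extension::empty({1, 2});
+ *   auto in_status = module.set_input(*input, 0);    // bind input #0 of "forward"
+ *   auto out_status = module.set_output(*output, 0); // results land in `output`
+ *   auto results = module.forward();                 // uses the pre-set input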
+ */ + inline runtime::EventTracer *event_tracer() const { + return event_tracer_.get(); + } + +private: + struct MethodHolder { + std::vector> planned_buffers; + std::vector> planned_spans; + std::unique_ptr planned_memory; + std::unique_ptr memory_manager; + std::unique_ptr method; + std::vector inputs; + }; + +private: + std::string file_path_; + LoadMode load_mode_{LoadMode::MmapUseMlock}; + std::shared_ptr program_; + std::unique_ptr data_loader_; + std::unique_ptr memory_allocator_; + std::unique_ptr temp_allocator_; + std::unique_ptr event_tracer_; + +protected: + std::unordered_map methods_; + + friend class ExecuTorchJni; +}; + +} // namespace extension +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::extension::Module; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/extension/tensor/tensor.h b/third-party/include/executorch/extension/tensor/tensor.h new file mode 100644 index 00000000..80a41018 --- /dev/null +++ b/third-party/include/executorch/extension/tensor/tensor.h @@ -0,0 +1,13 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +// Umbrella header for the Tensor extension. +#include +#include diff --git a/third-party/include/executorch/extension/tensor/tensor_accessor.h b/third-party/include/executorch/extension/tensor/tensor_accessor.h new file mode 100644 index 00000000..d5b59b06 --- /dev/null +++ b/third-party/include/executorch/extension/tensor/tensor_accessor.h @@ -0,0 +1,190 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace executorch { +namespace extension { +namespace internal { + +/** + * Base class template storing the underlying data with size and stride helpers. + * Inherited by TensorAccessor<> which requires specialization on rank. + */ +template class TensorAccessorBase { +public: + /// Returns the size of the underlying tensor at the given dimension. + executorch::aten::SizesType size(ssize_t i) const { + ET_CHECK_MSG(i < dim_ && i >= 0, "Dimension outside of [0, %zd], got %zd", + dim_ - 1, i); + return sizes_[i]; + } + + /// Returns the stride of the underlying tensor at the given dimension. + executorch::aten::StridesType stride(ssize_t i) const { + ET_CHECK_MSG(i < dim_ && i >= 0, "Dimension outside of [0, %zd], got %zd", + dim_ - 1, i); + return strides_[i]; + } + +protected: + TensorAccessorBase(T *data, const executorch::aten::SizesType *sizes, + const executorch::aten::StridesType *strides, ssize_t dim) + : data_(data), sizes_(sizes), strides_(strides), dim_(dim) {} + + T *data_; + const executorch::aten::SizesType *sizes_; + const executorch::aten::StridesType *strides_; + ssize_t dim_; +}; + +} // namespace internal + +/** + * TensorAccessor template with data type and rank as template parameters. No + * public constructors, can only be created using make_tensor_accessor from a + * given executorch::aten::Tensor. Use operator[] to index and obtain a lower + * rank accessor or the underlying scalar value. 
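+ *
+ * Indexing sketch for a rank-2 accessor obtained via make_tensor_accessor()
+ * below (assumes `tensor` is a 2-D float executorch::aten::Tensor):
+ *
+ *   auto result = make_tensor_accessor<float, 2>(tensor);
+ *   if (result.ok()) {
+ *     auto accessor = result.get();
+ *     float value = accessor[1][2];   // row 1, column 2
+ *   }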
+ */
+template <typename T, ssize_t N>
+class TensorAccessor : public internal::TensorAccessorBase<T, N> {
+public:
+  /**
+   * Index into the outermost dimension.
+   *
+   * @param i Index.
+   * @return If N > 1, a TensorAccessor with N-1 dimensions. If N == 1, a
+   * reference to the underlying scalar. Refer to the TensorAccessor<T, 1>
+   * specialization.
+   */
+  TensorAccessor<T, N - 1> operator[](ssize_t i) {
+    return TensorAccessor<T, N - 1>(this->data_ + this->strides_[0] * i,
+                                    this->sizes_ + 1, this->strides_ + 1,
+                                    N - 1);
+  }
+
+  /**
+   * Index into the outermost dimension.
+   *
+   * @param i Index.
+   * @return If N > 1, a constant TensorAccessor with N-1 dimensions. If N == 1,
+   * a constant reference to the underlying scalar. Refer to the
+   * TensorAccessor<T, 1> specialization.
+   */
+  const TensorAccessor<T, N - 1> operator[](ssize_t i) const {
+    return TensorAccessor<T, N - 1>(this->data_ + this->strides_[0] * i,
+                                    this->sizes_ + 1, this->strides_ + 1,
+                                    N - 1);
+  }
+
+private:
+  TensorAccessor(T *data, const executorch::aten::SizesType *sizes,
+                 const executorch::aten::StridesType *strides, ssize_t dim)
+      : internal::TensorAccessorBase<T, N>(data, sizes, strides, dim) {}
+
+  template <typename T2, ssize_t N2> friend class TensorAccessor;
+
+  template <typename T2, ssize_t N2>
+  friend executorch::runtime::Result<TensorAccessor<T2, N2>>
+  make_tensor_accessor(const executorch::aten::Tensor &t);
+};
+
+/**
+ * TensorAccessor specialization for N == 1, where operator[] returns a
+ * reference to the underlying scalar.
+ */
+template <typename T>
+class TensorAccessor<T, 1> : public internal::TensorAccessorBase<T, 1> {
+public:
+  /**
+   * Index into the outermost dimension.
+   *
+   * @param i Index.
+   * @return Reference to the underlying scalar.
+   */
+  T &operator[](ssize_t i) { return this->data_[this->strides_[0] * i]; }
+
+  /**
+   * Index into the outermost dimension.
+   *
+   * @param i Index.
+   * @return Constant reference to the underlying scalar.
+   */
+  const T &operator[](ssize_t i) const {
+    return this->data_[this->strides_[0] * i];
+  }
+
+private:
+  TensorAccessor(T *data, const executorch::aten::SizesType *sizes,
+                 const executorch::aten::StridesType *strides, ssize_t dim)
+      : internal::TensorAccessorBase<T, 1>(data, sizes, strides, dim) {}
+
+  template <typename T2, ssize_t N2> friend class TensorAccessor;
+
+  template <typename T2, ssize_t N2>
+  friend executorch::runtime::Result<TensorAccessor<T2, N2>>
+  make_tensor_accessor(const executorch::aten::Tensor &t);
+};
+
+/**
+ * Creates a TensorAccessor from the given tensor. The number of dimensions N
+ * and the data type T's size must match those of the input tensor. For
+ * ExecuTorch tensors, non-trivial dimension order is not supported.
+ *
+ * @param tensor Origin tensor. The TensorImpl inside must outlive the returned
+ * TensorAccessor.
+ * @return TensorAccessor of the input tensor.
+ * @retval Error::InvalidArgument Mismatch on data type or number of dimensions.
+ * @retval Error::NotSupported Input tensor has non-trivial dimension order.
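+ *
+ * Error-handling sketch (assumes `tensor` is a contiguous 3-D int32 tensor;
+ * a rank or element-size mismatch surfaces as InvalidArgument):
+ *
+ *   auto acc = make_tensor_accessor<int32_t, 3>(tensor);
+ *   if (!acc.ok()) {
+ *     // inspect acc.error(), e.g. Error::InvalidArgument
+ *   }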
+ */
+template <typename T, ssize_t N>
+executorch::runtime::Result<TensorAccessor<T, N>>
+make_tensor_accessor(const executorch::aten::Tensor &tensor) {
+  static_assert(N > 0, "TensorAccessor is used for indexing tensors, for "
+                       "scalars use *_data_ptr<T>()");
+
+  if (N != tensor.dim()) {
+    ET_LOG(Error, "Expecting %zd dimensions but tensor has %zd.",
+           static_cast<ssize_t>(N), static_cast<ssize_t>(tensor.dim()));
+    return executorch::runtime::Error::InvalidArgument;
+  }
+
+  if (sizeof(T) != tensor.element_size()) {
+    ET_LOG(Error,
+           "Size of data type template argument (%zd) not equal to tensor "
+           "element size (%zd)",
+           static_cast<ssize_t>(sizeof(T)),
+           static_cast<ssize_t>(tensor.element_size()));
+    return executorch::runtime::Error::InvalidArgument;
+  }
+
+#ifndef USE_ATEN_LIB
+  auto dim_order = tensor.dim_order();
+  for (ssize_t i = 0; i < dim_order.size(); i++) {
+    if (dim_order[i] != i) {
+      ET_LOG(Error, "Non-trivial dim_order not supported.");
+      return executorch::runtime::Error::NotSupported;
+    }
+  }
+#endif
+
+  T *ptr = nullptr;
+  if constexpr (std::is_const_v<T>) {
+    ptr = tensor.const_data_ptr<std::remove_const_t<T>>();
+  } else {
+    ptr = tensor.mutable_data_ptr<T>();
+  }
+  return TensorAccessor<T, N>(ptr, tensor.sizes().data(),
+                              tensor.strides().data(), N);
+}
+
+} // namespace extension
+} // namespace executorch
diff --git a/third-party/include/executorch/extension/tensor/tensor_ptr.h b/third-party/include/executorch/extension/tensor/tensor_ptr.h
new file mode 100644
index 00000000..f6a8009b
--- /dev/null
+++ b/third-party/include/executorch/extension/tensor/tensor_ptr.h
@@ -0,0 +1,347 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace executorch {
+namespace extension {
+
+/**
+ * A smart pointer type for managing the lifecycle of a Tensor.
+ */
+using TensorPtr = std::shared_ptr<executorch::aten::Tensor>;
+
+/**
+ * Creates a TensorPtr that manages a Tensor with the specified properties.
+ *
+ * @param sizes A vector specifying the size of each dimension.
+ * @param data A pointer to the data buffer.
+ * @param dim_order A vector specifying the order of dimensions.
+ * @param strides A vector specifying the strides of the tensor.
+ * @param type The scalar type of the tensor elements.
+ * @param dynamism Specifies the mutability of the tensor's shape.
+ * @param deleter A custom deleter function for managing the lifetime of the
+ * data buffer. If provided, this deleter will be called when the managed Tensor
+ * object is destroyed.
+ * @return A TensorPtr that manages the newly created Tensor.
+ */
+TensorPtr
+make_tensor_ptr(std::vector<executorch::aten::SizesType> sizes, void *data,
+                std::vector<executorch::aten::DimOrderType> dim_order,
+                std::vector<executorch::aten::StridesType> strides,
+                const executorch::aten::ScalarType type =
+                    executorch::aten::ScalarType::Float,
+                const executorch::aten::TensorShapeDynamism dynamism =
+                    executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
+                std::function<void(void *)> deleter = nullptr);
+
+/**
+ * Creates a TensorPtr that manages a Tensor with the specified properties.
+ *
+ * @param sizes A vector specifying the size of each dimension.
+ * @param data A pointer to the data buffer.
+ * @param type The scalar type of the tensor elements.
+ * @param dynamism Specifies the mutability of the tensor's shape.
+ * @param deleter A custom deleter function for managing the lifetime of the
+ * data buffer. If provided, this deleter will be called when the managed Tensor
+ * object is destroyed.
+ * @return A TensorPtr that manages the newly created Tensor. + */ +inline TensorPtr +make_tensor_ptr(std::vector sizes, void *data, + const executorch::aten::ScalarType type = + executorch::aten::ScalarType::Float, + const executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND, + std::function deleter = nullptr) { + return make_tensor_ptr(std::move(sizes), data, {}, {}, type, dynamism, + std::move(deleter)); +} + +/** + * Creates a TensorPtr that manages a Tensor with the specified properties. + * + * This template overload is specialized for cases where the tensor data is + * provided as a vector. The scalar type is automatically deduced from the + * vector's data type. If the specified `type` differs from the deduced type of + * the vector's elements, and casting is allowed, the data will be cast to the + * specified `type`. This allows for flexible creation of tensors with data + * vectors of one type and a different scalar type. + * + * @tparam T The C++ type of the tensor elements, deduced from the vector. + * @param sizes A vector specifying the size of each dimension. + * @param data A vector containing the tensor's data. + * @param dim_order A vector specifying the order of dimensions. + * @param strides A vector specifying the strides of each dimension. + * @param type The scalar type of the tensor elements. If it differs from the + * deduced type, the data will be cast to this type if allowed. + * @param dynamism Specifies the mutability of the tensor's shape. + * @return A TensorPtr that manages the newly created TensorImpl. + */ +template ::value> +inline TensorPtr +make_tensor_ptr(std::vector sizes, + std::vector data, + std::vector dim_order = {}, + std::vector strides = {}, + executorch::aten::ScalarType type = deduced_type, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + if (type != deduced_type) { + ET_CHECK_MSG(runtime::canCast(deduced_type, type), + "Cannot cast deduced type to specified type."); + std::vector casted_data(data.size() * runtime::elementSize(type)); + ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "make_tensor_ptr", CTYPE, [&] { + std::transform(data.begin(), data.end(), + reinterpret_cast(casted_data.data()), + [](const T &val) { return static_cast(val); }); + }); + const auto raw_data_ptr = casted_data.data(); + auto data_ptr = + std::make_shared>(std::move(casted_data)); + return make_tensor_ptr(std::move(sizes), raw_data_ptr, std::move(dim_order), + std::move(strides), type, dynamism, + [data_ptr = std::move(data_ptr)](void *) {}); + } + const auto raw_data_ptr = data.data(); + auto data_ptr = std::make_shared>(std::move(data)); + return make_tensor_ptr(std::move(sizes), raw_data_ptr, std::move(dim_order), + std::move(strides), type, dynamism, + [data_ptr = std::move(data_ptr)](void *) {}); +} + +/** + * Creates a TensorPtr that manages a Tensor with the specified properties. + * + * This template overload is specialized for cases where the tensor data is + * provided as a vector. The scalar type is automatically deduced from the + * vector's data type. If the specified `type` differs from the deduced type of + * the vector's elements, and casting is allowed, the data will be cast to the + * specified `type`. This allows for flexible creation of tensors with data + * vectors of one type and a different scalar type. + * + * @tparam T The C++ type of the tensor elements, deduced from the vector. 
+ * @param data A vector containing the tensor's data. + * @param type The scalar type of the tensor elements. If it differs from the + * deduced type, the data will be cast to this type if allowed. + * @param dynamism Specifies the mutability of the tensor's shape. + * @return A TensorPtr that manages the newly created TensorImpl. + */ +template ::value> +inline TensorPtr +make_tensor_ptr(std::vector data, + executorch::aten::ScalarType type = deduced_type, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + std::vector sizes{ + executorch::aten::SizesType(data.size())}; + return make_tensor_ptr(std::move(sizes), std::move(data), {0}, {1}, type, + dynamism); +} + +/** + * Creates a TensorPtr that manages a Tensor with the specified properties. + * + * This template overload is specialized for cases where the tensor data is + * provided as an initializer list. The scalar type is automatically deduced + * from the initializer list's data type. If the specified `type` differs from + * the deduced type of the initializer list's elements, and casting is allowed, + * the data will be cast to the specified `type`. This allows for flexible + * creation of tensors with data vectors of one type and a different scalar + * type. + * + * @tparam T The C++ type of the tensor elements, deduced from the initializer + * list. + * @param sizes A vector specifying the size of each dimension. + * @param list An initializer list containing the tensor's data. + * @param dim_order A vector specifying the order of dimensions. + * @param strides A vector specifying the strides of each dimension. + * @param type The scalar type of the tensor elements. If it differs from the + * deduced type, the data will be cast to this type if allowed. + * @param dynamism Specifies the mutability of the tensor's shape. + * @return A TensorPtr that manages the newly created TensorImpl. + */ +template ::value> +inline TensorPtr +make_tensor_ptr(std::vector sizes, + std::initializer_list list, + std::vector dim_order = {}, + std::vector strides = {}, + executorch::aten::ScalarType type = deduced_type, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return make_tensor_ptr(std::move(sizes), std::vector(std::move(list)), + std::move(dim_order), std::move(strides), type, + dynamism); +} + +/** + * Creates a TensorPtr that manages a Tensor with the specified properties. + * + * This template overload allows creating a Tensor from an initializer list + * of data. The scalar type is automatically deduced from the type of the + * initializer list's elements. If the specified `type` differs from + * the deduced type of the initializer list's elements, and casting is allowed, + * the data will be cast to the specified `type`. This allows for flexible + * creation of tensors with data vectors of one type and a different scalar + * type. + * + * @tparam T The C++ type of the tensor elements, deduced from the initializer + * list. + * @param list An initializer list containing the tensor's data. + * @param type The scalar type of the tensor elements. If it differs from the + * deduced type, the data will be cast to this type if allowed. + * @param dynamism Specifies the mutability of the tensor's shape. + * @return A TensorPtr that manages the newly created TensorImpl. 
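+ *
+ * Sketch of the owning-data overloads (the vector or initializer-list data is
+ * moved into the TensorPtr, so no external buffer has to stay alive):
+ *
+ *   auto a = make_tensor_ptr({2, 3}, std::vector<float>{1, 2, 3, 4, 5, 6});
+ *   auto b = make_tensor_ptr({2, 2}, {1.0f, 2.0f, 3.0f, 4.0f});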
+ */ +template ::value> +inline TensorPtr +make_tensor_ptr(std::initializer_list list, + executorch::aten::ScalarType type = deduced_type, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + std::vector sizes{ + executorch::aten::SizesType(list.size())}; + return make_tensor_ptr(std::move(sizes), std::move(list), {0}, {1}, type, + dynamism); +} + +/** + * Creates a TensorPtr that manages a Tensor with a single scalar value. + * + * @tparam T The C++ type of the scalar value. + * @param value The scalar value to be used for the Tensor. + * @return A TensorPtr that manages the newly created TensorImpl. + */ +template inline TensorPtr make_tensor_ptr(T value) { + return make_tensor_ptr({}, std::vector{value}); +} + +/** + * Creates a TensorPtr that manages a Tensor with the specified properties. + * + * This overload accepts a raw memory buffer stored in a std::vector + * and a scalar type to interpret the data. The vector is managed, and the + * memory's lifetime is tied to the TensorImpl. + * + * @param sizes A vector specifying the size of each dimension. + * @param data A vector containing the raw memory for the tensor's data. + * @param dim_order A vector specifying the order of dimensions. + * @param strides A vector specifying the strides of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies the mutability of the tensor's shape. + * @return A TensorPtr managing the newly created Tensor. + */ +TensorPtr make_tensor_ptr( + std::vector sizes, std::vector data, + std::vector dim_order, + std::vector strides, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND); + +/** + * Creates a TensorPtr that manages a Tensor with the specified properties. + * + * This overload accepts a raw memory buffer stored in a std::vector + * and a scalar type to interpret the data. The vector is managed, and the + * memory's lifetime is tied to the TensorImpl. + * + * @param sizes A vector specifying the size of each dimension. + * @param data A vector containing the raw memory for the tensor's data. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies the mutability of the tensor's shape. + * @return A TensorPtr managing the newly created Tensor. + */ +inline TensorPtr make_tensor_ptr( + std::vector sizes, std::vector data, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return make_tensor_ptr(std::move(sizes), std::move(data), {}, {}, type, + dynamism); +} + +/** + * Creates a TensorPtr to manage a new Tensor with the same properties + * as the given Tensor, sharing the same data without owning it. + * + * @param tensor The Tensor whose properties are used to create a new TensorPtr. + * @return A new TensorPtr managing a Tensor with the same properties as the + * original. 
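+ *
+ * Sketch contrasting the non-owning wrapper with clone_tensor_ptr() declared
+ * below (`existing` is assumed to be an executorch::aten::Tensor that
+ * outlives `view`):
+ *
+ *   auto view = make_tensor_ptr(existing);    // shares existing's data
+ *   auto copy = clone_tensor_ptr(existing);   // owns a copy of the data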
+ */ +inline TensorPtr make_tensor_ptr(const executorch::aten::Tensor &tensor) { + return make_tensor_ptr( + std::vector(tensor.sizes().begin(), + tensor.sizes().end()), + tensor.mutable_data_ptr(), +#ifndef USE_ATEN_LIB + std::vector(tensor.dim_order().begin(), + tensor.dim_order().end()), + std::vector(tensor.strides().begin(), + tensor.strides().end()), + tensor.scalar_type(), tensor.shape_dynamism() +#else // USE_ATEN_LIB + {}, + std::vector(tensor.strides().begin(), + tensor.strides().end()), + tensor.scalar_type() +#endif // USE_ATEN_LIB + ); +} + +/** + * Creates a TensorPtr that manages a new Tensor with the same properties + * as the given Tensor, but with a copy of the data owned by the returned + * TensorPtr, or nullptr if the original data is null. + * + * @param tensor The Tensor to clone. + * @return A new TensorPtr that manages a Tensor with the same properties as the + * original but with copied data. + */ +TensorPtr clone_tensor_ptr(const executorch::aten::Tensor &tensor); + +/** + * Creates a new TensorPtr by cloning the given TensorPtr, copying the + * underlying data. + * + * @param tensor The TensorPtr to clone. + * @return A new TensorPtr that manages a Tensor with the same properties as the + * original but with copied data. + */ +inline TensorPtr clone_tensor_ptr(const TensorPtr &tensor) { + return clone_tensor_ptr(*tensor); +} + +/** + * Resizes the Tensor managed by the provided TensorPtr to the new sizes. + * + * @param tensor A TensorPtr managing the Tensor to resize. + * @param sizes A vector representing the new sizes for each dimension. + * @return Error::Ok on success, or an appropriate error code on failure. + */ +ET_NODISCARD +runtime::Error +resize_tensor_ptr(TensorPtr &tensor, + const std::vector &sizes); + +} // namespace extension +} // namespace executorch diff --git a/third-party/include/executorch/extension/tensor/tensor_ptr_maker.h b/third-party/include/executorch/extension/tensor/tensor_ptr_maker.h new file mode 100644 index 00000000..f6a70dd4 --- /dev/null +++ b/third-party/include/executorch/extension/tensor/tensor_ptr_maker.h @@ -0,0 +1,655 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace extension { + +/** + * A helper class for creating TensorPtr instances from raw data and tensor + * properties. Note that the TensorPtr created by this class does not own the + * data, so the data must outlive the TensorPtr. + * + * TensorPtrMaker provides a fluent interface for specifying various tensor + * properties, such as type, sizes, data pointer, dimension order, strides, and + * shape dynamism. The final tensor is created by invoking make_tensor_ptr() or + * by converting TensorPtrMaker to TensorPtr. + */ +class TensorPtrMaker final { +public: + // This class may have non-copyable members in the future. + TensorPtrMaker(const TensorPtrMaker &) = delete; + TensorPtrMaker &operator=(const TensorPtrMaker &) = delete; + // But it is movable. + TensorPtrMaker(TensorPtrMaker &&) = default; + TensorPtrMaker &operator=(TensorPtrMaker &&) = default; + + /** + * Sets the scalar type of the tensor elements. + * + * @param type The scalar type (e.g., float, int, bool). + * @return Rvalue to this TensorPtrMaker for method chaining. 
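+ *
+ * The setters below are typically chained off for_blob(); a minimal sketch
+ * (assumes `data` points to at least 6 floats that outlive the result):
+ *
+ *   auto t = for_blob(data, {2, 3})
+ *                .type(executorch::aten::ScalarType::Float)
+ *                .strides({3, 1})
+ *                .dynamism(executorch::aten::TensorShapeDynamism::STATIC)
+ *                .make_tensor_ptr();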
+ */ + TensorPtrMaker &&type(executorch::aten::ScalarType type) { + type_ = type; + return std::move(*this); + } + + /** + * Sets the order of dimensions in memory. + * + * @param dim_order A vector specifying the dimension order. + * @return Rvalue to this TensorPtrMaker for method chaining. + */ + TensorPtrMaker && + dim_order(std::vector dim_order) { + dim_order_ = std::move(dim_order); + return std::move(*this); + } + + /** + * Sets the strides for each dimension of the tensor. + * + * @param strides A vector specifying the stride for each dimension. + * @return Rvalue to this TensorPtrMaker for method chaining. + */ + TensorPtrMaker &&strides(std::vector strides) { + strides_ = std::move(strides); + return std::move(*this); + } + + /** + * Sets the shape dynamism of the tensor. + * + * @param dynamism Specifies whether the tensor's shape is static, dynamic, or + * bounded. + * @return Rvalue to this TensorPtrMaker for method chaining. + */ + TensorPtrMaker &&dynamism(executorch::aten::TensorShapeDynamism dynamism) { + dynamism_ = dynamism; + return std::move(*this); + } + + /** + * Sets a custom deleter function to manage the lifetime of the data buffer. + * + * @param deleter A function that will be called to delete the data buffer + * when the Tensor object managed by the TensorPtr is destroyed. Explicitly + * consuming an rvalue to avoid unnecessary copies when the deleter is a + * lambda that has captured some state. + * @return Rvalue to this TensorPtrMaker for method chaining. + */ + TensorPtrMaker &&deleter(std::function &&deleter) { + deleter_ = std::move(deleter); + return std::move(*this); + } + + /** + * Creates and returns a TensorPtr instance using the properties set in this + * TensorPtrMaker. + * + * @return A TensorPtr instance that manages the newly created Tensor. + */ + TensorPtr make_tensor_ptr() && { + return ::executorch::extension::make_tensor_ptr( + std::move(sizes_), data_, std::move(dim_order_), std::move(strides_), + type_, dynamism_, std::move(deleter_)); + } + + /** + * Implicit conversion operator to create a TensorPtr. + * + * @return A TensorPtr instance that manages the newly created Tensor. + */ + operator TensorPtr() && { return std::move(*this).make_tensor_ptr(); } + +private: + TensorPtrMaker(void *data, std::vector sizes, + executorch::aten::ScalarType type) + : sizes_(std::move(sizes)), data_(data), type_(type) {} + +private: + // The following properties are required to create a Tensor. + friend TensorPtrMaker for_blob(void *data, + std::vector sizes, + executorch::aten::ScalarType type); + +private: + std::vector sizes_; + std::vector strides_; + std::vector dim_order_; + std::function deleter_ = nullptr; + void *data_ = nullptr; + executorch::aten::ScalarType type_ = executorch::aten::ScalarType::Float; + executorch::aten::TensorShapeDynamism dynamism_ = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND; +}; + +/** + * Creates a TensorPtrMaker instance for building a TensorPtr from a raw data + * pointer and tensor sizes. + * + * The TensorPtrMaker returned by this function allows for further customization + * of the tensor's properties, such as data type, dimension order, strides, and + * shape dynamism, before finalizing the TensorPtr creation. + * + * @param data A pointer to the raw data to be used by the tensor. It must + * outlive the TensorPtr created by this function. + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. 
+ * @return A TensorPtrMaker instance for creating a TensorPtr. + */ +inline TensorPtrMaker for_blob( + void *data, std::vector sizes, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float) { + return TensorPtrMaker(data, std::move(sizes), type); +} + +/** + * Creates a TensorPtr from a raw data pointer and tensor sizes, with an + * optional dynamism setting. + * + * This function provides a convenient way to create a tensor from existing + * data, with the option to specify whether the tensor's shape is static or + * dynamic. + * + * @param data A pointer to the raw data used by the tensor. The data must + * outlive the TensorPtr created by this function. + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr from_blob( + void *data, std::vector sizes, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return for_blob(data, std::move(sizes), type) + .dynamism(dynamism) + .make_tensor_ptr(); +} + +/** + * Creates a TensorPtr from a raw data pointer, tensor sizes, and strides, with + * an optional dynamism setting. + * + * This function allows for the creation of a tensor from existing data, with + * the option to specify custom strides for each dimension and whether the + * tensor’s shape is static, dynamic, or bounded. + * + * @param data A pointer to the raw data used by the tensor. The data must + * outlive the TensorPtr created by this function. + * @param sizes A vector specifying the size of each dimension. + * @param strides A vector specifying the stride for each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static, dynamic, or + * bounded. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr from_blob( + void *data, std::vector sizes, + std::vector strides, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return for_blob(data, std::move(sizes), type) + .strides(std::move(strides)) + .dynamism(dynamism) + .make_tensor_ptr(); +} + +/** + * Creates a TensorPtr from a raw data pointer and tensor sizes, with an + * optional dynamism setting. + * + * This function is a convenient way to create a tensor from existing data, with + * the option to specify whether the tensor's shape is static, dynamic, or + * bounded. + * + * @param data A pointer to the raw data to be used by the tensor. It must + * outlive the TensorPtr created by this function. + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param deleter A function to delete the data when it's no longer needed. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance that manages the newly created Tensor. 
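+ *
+ * Sketch of handing ownership of a heap buffer to the tensor via the deleter
+ * (assumes a float buffer sized for the given shape):
+ *
+ *   float *buffer = new float[12];
+ *   auto t = from_blob(buffer, {3, 4}, executorch::aten::ScalarType::Float,
+ *                      [](void *p) { delete[] static_cast<float *>(p); });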
+ */ +inline TensorPtr +from_blob(void *data, std::vector sizes, + executorch::aten::ScalarType type, + std::function &&deleter, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return for_blob(data, std::move(sizes), type) + .deleter(std::move(deleter)) + .dynamism(dynamism) + .make_tensor_ptr(); +} + +/** + * Creates a TensorPtr from a raw data pointer, tensor sizes, and strides, with + * an optional dynamism setting. + * + * This function allows for the creation of a tensor from existing data, with + * the option to specify custom strides for each dimension and whether the + * tensor's shape is static, dynamic, or bounded. + * + * @param data A pointer to the raw data to be used by the tensor. It must + * outlive the TensorPtr created by this function. + * @param sizes A vector specifying the size of each dimension. + * @param strides A vector specifying the stride for each dimension. + * @param type The scalar type of the tensor elements. + * @param deleter A function to delete the data when it's no longer needed. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance that manages the newly created Tensor. + */ +inline TensorPtr +from_blob(void *data, std::vector sizes, + std::vector strides, + executorch::aten::ScalarType type, + std::function &&deleter, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return for_blob(data, std::move(sizes), type) + .strides(std::move(strides)) + .deleter(std::move(deleter)) + .dynamism(dynamism) + .make_tensor_ptr(); +} + +/** + * Creates a TensorPtr with the specified sizes, strides, and properties. + * + * This function allocates memory for the tensor elements but does not + * initialize them with any specific values. The tensor is created with the + * specified strides. + * + * @param sizes A vector specifying the size of each dimension. + * @param strides A vector specifying the stride for each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +TensorPtr empty_strided( + std::vector sizes, + std::vector strides, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND); + +/** + * Creates an empty TensorPtr with the same size and properties as the given + * tensor. + * + * This function allocates memory for the tensor elements but does not + * initialize them with any specific values. + * + * @param other A reference to another tensor, whose size and properties are + * used. + * @param type The scalar type of the tensor elements. If not provided, the + * scalar type of the other tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. 
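+ *
+ * Sketch of the uninitialized factories (contents must be written before they
+ * are read; empty() is defined further below):
+ *
+ *   auto scratch = empty({1, 8}, executorch::aten::ScalarType::Int);
+ *   auto scratch2 = empty_like(scratch);   // same sizes, strides and type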
+ */ +inline TensorPtr empty_like( + const TensorPtr &other, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Undefined, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + if (type == executorch::aten::ScalarType::Undefined) { + type = other->scalar_type(); + } + return empty_strided({other->sizes().begin(), other->sizes().end()}, + {other->strides().begin(), other->strides().end()}, type, + dynamism); +} + +/** + * Creates an empty TensorPtr with the specified sizes and properties. + * + * This function allocates memory for the tensor elements but does not + * initialize them with any specific values. + * + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr +empty(std::vector sizes, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return empty_strided(std::move(sizes), {}, type, dynamism); +} + +/** + * Creates a TensorPtr filled with the specified value. + * + * @param sizes A vector specifying the size of each dimension. + * @param strides A vector specifying the stride for each dimension. + * @param fill_value The value to fill the tensor with. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +TensorPtr full_strided( + std::vector sizes, + std::vector strides, + executorch::aten::Scalar fill_value, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND); + +/** + * Creates a TensorPtr filled with the specified value, with the same size and + * properties as another tensor. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param fill_value The value to fill the tensor with. + * @param type The scalar type of the tensor elements. If not specified, the + * scalar type of the other tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr full_like( + const TensorPtr &other, executorch::aten::Scalar fill_value, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Undefined, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + if (type == executorch::aten::ScalarType::Undefined) { + type = other->scalar_type(); + } + return full_strided({other->sizes().begin(), other->sizes().end()}, + {other->strides().begin(), other->strides().end()}, + fill_value, type, dynamism); +} + +/** + * Creates a TensorPtr filled with the specified value. + * + * @param sizes A vector specifying the size of each dimension. + * @param fill_value The value used to fill the tensor. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. 
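+ *
+ * Sketch of the fill-style factories (full() is defined right below):
+ *
+ *   auto bias = full({1, 16}, 0.5);                 // every element == 0.5
+ *   auto mask = full_like(bias, 1, executorch::aten::ScalarType::Bool);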
+ */ +inline TensorPtr +full(std::vector sizes, + executorch::aten::Scalar fill_value, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full_strided(std::move(sizes), {}, fill_value, type, dynamism); +} + +/** + * Creates a TensorPtr holding a scalar value. + * + * @param value The scalar value for the tensor. + * @param type The scalar type of the tensor elements. + * @return A TensorPtr instance managing the newly created scalar Tensor. + */ +inline TensorPtr scalar_tensor( + executorch::aten::Scalar value, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float) { + return full({}, value, type); +} + +/** + * Creates a TensorPtr filled with ones, with the same size and properties as + * another tensor. + * + * @param other A reference to another tensor, whose size and properties are + * used. + * @param type The scalar type of the tensor elements. If not provided, the + * scalar type of the other tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr ones_like( + const TensorPtr &other, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Undefined, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full_like(other, 1, type, dynamism); +} + +/** + * Creates a TensorPtr filled with ones. + * + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr +ones(std::vector sizes, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full(std::move(sizes), 1, type, dynamism); +} + +/** + * Creates a TensorPtr filled with zeros, with the same size and properties as + * another tensor. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param type The scalar type of the tensor elements. If not specified, the + * scalar type of the `other` tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr zeros_like( + const TensorPtr &other, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Undefined, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full_like(other, 0, type, dynamism); +} + +/** + * Creates a TensorPtr filled with zeros. + * + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. 
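+ *
+ * Sketch (zeros() is defined right below; ones_like() mirrors zeros_like()):
+ *
+ *   auto accum = zeros({4, 4});             // float zeros by default
+ *   auto flags = ones_like(accum, executorch::aten::ScalarType::Bool);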
+ */ +inline TensorPtr +zeros(std::vector sizes, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full(std::move(sizes), 0, type, dynamism); +} + +/** + * Creates a TensorPtr filled with random values between 0 and 1. + * + * @param sizes A vector specifying the size of each dimension. + * @param strides A vector specifying the stride for each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + **/ +TensorPtr rand_strided( + std::vector sizes, + std::vector strides, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND); + +/** + * Creates a TensorPtr filled with random values between 0 and 1. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param type The scalar type of the tensor elements. If not specified, the + * scalar type of the other tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr rand_like( + const TensorPtr &other, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Undefined, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + if (type == executorch::aten::ScalarType::Undefined) { + type = other->scalar_type(); + } + return rand_strided({other->sizes().begin(), other->sizes().end()}, + {other->strides().begin(), other->strides().end()}, type, + dynamism); +} + +/** + * Creates a TensorPtr filled with random values between 0 and 1. + * + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr +rand(std::vector sizes, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return rand_strided(std::move(sizes), {}, type, dynamism); +} + +/** + * Creates a TensorPtr filled with random values between 0 and 1, with specified + * strides. + * + * @param sizes A vector specifying the size of each dimension. + * @param strides A vector specifying the stride for each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +TensorPtr randn_strided( + std::vector sizes, + std::vector strides, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND); + +/** + * Creates a TensorPtr filled with random values from a normal distribution. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param type The scalar type of the tensor elements. If not specified, the + * scalar type of the other tensor is used. 
+ * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr randn_like( + const TensorPtr &other, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Undefined, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + if (type == executorch::aten::ScalarType::Undefined) { + type = other->scalar_type(); + } + return randn_strided({other->sizes().begin(), other->sizes().end()}, + {other->strides().begin(), other->strides().end()}, type, + dynamism); +} + +/** + * Creates a TensorPtr filled with random values sampled from a normal + * distribution. + * + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr +randn(std::vector sizes, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return randn_strided(std::move(sizes), {}, type, dynamism); +} + +/** + * Creates a TensorPtr filled with random integer values in the given range. + * + * @param low The lower bound (inclusive) of the random values. + * @param high The upper bound (exclusive) of the random values. + * @param sizes A vector specifying the size of each dimension. + * @param strides A vector specifying the stride for each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +TensorPtr randint_strided( + int64_t low, int64_t high, std::vector sizes, + std::vector strides, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Int, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND); + +/** + * Creates a TensorPtr filled with random integer values in the given range. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param low The lower bound (inclusive) of the random values. + * @param high The upper bound (exclusive) of the random values. + * @param type The scalar type of the tensor elements. If not specified, the + * scalar type of the other tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr randint_like( + const TensorPtr &other, int64_t low, int64_t high, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Undefined, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + if (type == executorch::aten::ScalarType::Undefined) { + type = other->scalar_type(); + } + return randint_strided( + low, high, {other->sizes().begin(), other->sizes().end()}, + {other->strides().begin(), other->strides().end()}, type, dynamism); +} + +/** + * Creates a TensorPtr filled with random integer values within the specified + * range. + * + * @param low The inclusive lower bound of the random values. + * @param high The exclusive upper bound of the random values. + * @param sizes A vector specifying the size of each dimension. 
+ * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr +randint(int64_t low, int64_t high, + std::vector sizes, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Int, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return randint_strided(low, high, std::move(sizes), {}, type, dynamism); +} + +} // namespace extension +} // namespace executorch diff --git a/third-party/include/executorch/runtime/backend/backend_execution_context.h b/third-party/include/executorch/runtime/backend/backend_execution_context.h new file mode 100644 index 00000000..42413073 --- /dev/null +++ b/third-party/include/executorch/runtime/backend/backend_execution_context.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace executorch { +namespace runtime { + +/** + * BackendExecutionContext will be used to inject run time context. + */ +class BackendExecutionContext final { +public: + BackendExecutionContext(EventTracer *event_tracer = nullptr, + MemoryAllocator *temp_allocator = nullptr, + const char *method_name = nullptr) + : event_tracer_(event_tracer), temp_allocator_(temp_allocator), + method_name_(method_name) {} + + /** + * Returns a pointer to an instance of EventTracer to do profiling/debugging + * logging inside the delegate backend. Users will need access to this pointer + * to use any of the event tracer APIs. + */ + EventTracer *event_tracer() { return event_tracer_; } + + /** + * Returns a pointer to the address allocated by temp allocator. This + * allocator will be reset after every delegate call during execution. + */ + void *allocate(size_t size, + size_t alignment = MemoryAllocator::kDefaultAlignment) { + // TODO(chenlai): depends on the need, we may expose more functionality for + // memory allocation. + return temp_allocator_->allocate(size, alignment); + } + + /** + * Returns the temp allocator. This allocator will be reset every instruction. + */ + MemoryAllocator *get_temp_allocator() { return temp_allocator_; } + + /** + * Get the name of the executing method from the ExecuTorch runtime. + */ + const char *get_method_name() const { return method_name_; } + +private: + EventTracer *event_tracer_ = nullptr; + MemoryAllocator *temp_allocator_ = nullptr; + const char *method_name_ = nullptr; +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::BackendExecutionContext; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/backend/backend_init_context.h b/third-party/include/executorch/runtime/backend/backend_init_context.h new file mode 100644 index 00000000..291cc720 --- /dev/null +++ b/third-party/include/executorch/runtime/backend/backend_init_context.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include + +namespace executorch { +namespace runtime { + +/** + * BackendInitContext will be used to inject runtime info for to initialize + * delegate. + */ +class BackendInitContext final { +public: + explicit BackendInitContext(MemoryAllocator *runtime_allocator, + const char *method_name = nullptr) + : runtime_allocator_(runtime_allocator), method_name_(method_name) {} + + /** Get the runtime allocator passed from Method. It's the same runtime + * executor used by the standard executor runtime and the life span is the + * same as the model. + */ + MemoryAllocator *get_runtime_allocator() { return runtime_allocator_; } + + /** Get the loaded method name from ExecuTorch runtime. Usually it's + * "forward", however, if there are multiple methods in the .pte file, it can + * be different. One example is that we may have prefill and decode methods in + * the same .pte file. In this case, when client loads "prefill" method, the + * `get_method_name` function will return "prefill", when client loads + * "decode" method, the `get_method_name` function will return "decode". + */ + const char *get_method_name() const { return method_name_; } + +private: + MemoryAllocator *runtime_allocator_ = nullptr; + const char *method_name_ = nullptr; +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::BackendInitContext; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/backend/interface.h b/third-party/include/executorch/runtime/backend/interface.h new file mode 100644 index 00000000..42f8698e --- /dev/null +++ b/third-party/include/executorch/runtime/backend/interface.h @@ -0,0 +1,155 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace runtime { + +struct SizedBuffer { + void *buffer; + size_t nbytes; // number of bytes of buffer +}; + +struct CompileSpec { + const char *key; // spec key + SizedBuffer value; // spec value +}; + +/** + * An opaque handle managed by a backend. Typically points to a backend-private + * class/struct. + */ +using DelegateHandle = void; + +class BackendInterface { +public: + virtual ~BackendInterface() = 0; + + /** + * Returns true if the backend is available to process delegation calls. + */ + ET_NODISCARD virtual bool is_available() const = 0; + + /** + * Responsible to further process (compile/transform/optimize) the compiled + * unit that was produced, ahead-of-time, as well as perform any backend + * initialization to ready it for execution. This method is called every time + * the ExecuTorch program is initialized. Consequently, this is the place to + * perform any backend initialization as well as transformations, + * optimizations, and even compilation that depend on the target device. As + * such, it is strongly encouraged to push as much processing as possible to + * the ahead-of-time processing. 
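+ *
+ * A minimal backend sketch (names are illustrative, not a real backend; a
+ * production backend would parse `processed` in init() and run the resulting
+ * program in execute(); register_backend() is declared later in this header):
+ *
+ *   class StubBackend final : public BackendInterface {
+ *    public:
+ *     bool is_available() const override { return true; }
+ *     Result<DelegateHandle *> init(BackendInitContext &context,
+ *                                   FreeableBuffer *processed,
+ *                                   ArrayRef<CompileSpec> specs) const override {
+ *       return processed;   // reuse the blob itself as the handle
+ *     }
+ *     Error execute(BackendExecutionContext &context, DelegateHandle *handle,
+ *                   EValue **args) const override {
+ *       return Error::Ok;
+ *     }
+ *   };
+ *
+ *   // Registered once, e.g. at static-initialization time:
+ *   static StubBackend instance;
+ *   static auto success = register_backend({"StubBackend", &instance});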
+ * + * @param[in] processed An opaque (to ExecuTorch) backend-specific compiled + * unit from the preprocessor. Can contain anything the backend needs to + * execute the equivalent semantics of the passed-in Module and its + * method. Often passed unmodified to `execute()` as a `DelegateHandle`, + * unless it needs further processing at init time to be fully executable. + * If the data is not needed after init(), calling processed->Free() can + * reclaim its memory. + * @param[in] compile_specs The exact same compiler specification that + * was used ahead-of-time to produce `processed`. + * + * @returns On success, an opaque handle representing the the method + * implemented by the delegate. This handle is passed to `execute()` and + * `destroy()`, and the memory it points to is owned by the backend. + * Typically points to a backend-private class/struct. + * @returns On error, returns an error code other than Error::Ok. If the + * compiled unit (the preprocessed result from ahead of time) is not + * compatible with the current backend runtime, return the error code + * Error::DelegateInvalidCompatibility. Other backend delegate + * specific error codes can be found in error.h. + */ + ET_NODISCARD virtual Result + init(BackendInitContext &context, FreeableBuffer *processed, + ArrayRef compile_specs) const = 0; + + /** + * Responsible for executing the given method’s handle, as it was produced + * by compile. + * + * @param[in] handle An opaque handle returned by `init()`. Usually a backend + * executable unit. This executable unit should be ready to execute the + * delegate blobs. + * @param[in] args The method’s inputs and outputs. + * @retval Error::Ok if successful. + */ + ET_NODISCARD virtual Error execute(BackendExecutionContext &context, + DelegateHandle *handle, + EValue **args) const = 0; + + /** + * Responsible for destroying a handle, if it's required for some backend. + * It may be needed for some backends. For example, resources associated with + * this handle needs to be released. This method is called when the execution + * plan is destroyed (i.e., the program is out of its lifespan). + * + * @param[in] handle The handle to be destroyed. An opaque handle returned by + * `init()`. + */ + virtual void destroy(ET_UNUSED DelegateHandle *handle) const {} +}; + +/** + * Returns the corresponding object pointer for a given string name. + * The mapping is populated using register_backend method. + * + * @param[in] name Name of the user-defined backend delegate. + * @retval Pointer to the appropriate object that implements BackendInterface. + * Nullptr if it can't find anything with the given name. + */ +BackendInterface *get_backend_class(const char *name); + +/** + * A named instance of a backend. + */ +struct Backend { + /// The name of the backend. Must match the string used in the PTE file. + const char *name; + /// The instance of the backend to use when loading and executing programs. + BackendInterface *backend; +}; + +/** + * Registers the Backend object (i.e. string name and BackendInterface pair) so + * that it could be called via the name during the runtime. + * + * @param[in] backend Backend object + * @retval Error code representing whether registration was successful. + */ +ET_NODISCARD Error register_backend(const Backend &backend); + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. 
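[Illustrative aside tying the backend interface above together: a hedged sketch of a do-nothing backend and its registration. `StubBackend` is hypothetical; the template arguments (`Result<DelegateHandle*>`, `ArrayRef<CompileSpec>`) are reconstructed from upstream ExecuTorch, and the name passed to `register_backend` must match the backend id recorded in the .pte file at export time.]

```cpp
#include <executorch/runtime/backend/interface.h>

namespace {

using namespace ::executorch::runtime;

class StubBackend final : public BackendInterface {
 public:
  bool is_available() const override { return true; }

  Result<DelegateHandle*> init(
      BackendInitContext& context,
      FreeableBuffer* processed,
      ArrayRef<CompileSpec> compile_specs) const override {
    (void)context;
    (void)compile_specs;
    // Hand the processed blob back as the opaque handle; a real backend would
    // usually build its own executable object here.
    return static_cast<DelegateHandle*>(processed);
  }

  Error execute(
      BackendExecutionContext& context,
      DelegateHandle* handle,
      EValue** args) const override {
    (void)context;
    (void)handle;
    (void)args;
    return Error::Ok;
  }
};

// Register once at startup; later lookups go through
// get_backend_class("StubBackend").
StubBackend stub_backend;
Backend stub_backend_entry{"StubBackend", &stub_backend};
Error stub_backend_registered = register_backend(stub_backend_entry);

} // namespace
```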
+using ::executorch::runtime::Backend; +using ::executorch::runtime::CompileSpec; +using ::executorch::runtime::DelegateHandle; +using ::executorch::runtime::get_backend_class; +using ::executorch::runtime::register_backend; +using ::executorch::runtime::SizedBuffer; +using PyTorchBackendInterface = ::executorch::runtime::BackendInterface; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/array_ref.h b/third-party/include/executorch/runtime/core/array_ref.h new file mode 100644 index 00000000..e4468251 --- /dev/null +++ b/third-party/include/executorch/runtime/core/array_ref.h @@ -0,0 +1,223 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +// removed llvm-specific functionality +// removed some implicit const -> non-const conversions that rely on +// complicated std::enable_if meta-programming +// removed a bunch of slice variants for simplicity... +// remove constructors for std::array +// remove constructors and operators for std::vector +// removed some prevention of accidental assignments from temporary that +// required std::enable_if meta-programming +// removed reverse iterator + +#pragma once + +#include + +#include + +namespace executorch { +namespace runtime { + +/** + * Represents a constant reference to an array (0 or more elements + * consecutively in memory), i.e. a start pointer and a length. It allows + * various APIs to take consecutive elements easily and conveniently. + * + * This class does not own the underlying data, it is expected to be used in + * situations where the data resides in some other buffer, whose lifetime + * extends past that of the ArrayRef. For this reason, it is not in general + * safe to store an ArrayRef. + * + * Span and ArrayRef are extrememly similar with the difference being ArrayRef + * views a list of constant elements and Span views a list of mutable elements. + * Clients should decide between the two based on if the list elements for their + * use case should be mutable. + * + * This is intended to be trivially copyable, so it should be passed by + * value. + */ +template class ArrayRef final { +public: + using iterator = const T *; + using const_iterator = const T *; + using size_type = size_t; + using value_type = T; + +private: + /// The start of the array, in an external buffer. + const T *Data; + + /// The number of elements. + size_type Length; + +public: + /// @name Constructors + /// @{ + + /// Construct an empty ArrayRef. + /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {} + + /// Construct a ArrayRef from a single element. Implicitly convert element + /// type. It is aligned with PyTorch's c10::ArrayRef. + /* implicit */ constexpr ArrayRef(const T &OneElt) + : Data(&OneElt), Length(1) {} + + /// Construct a ArrayRef from a pointer and length. + ArrayRef(const T *data, size_t length) : Data(data), Length(length) { + ET_DCHECK(Data != nullptr || Length == 0); + } + + /// Construct a ArrayRef from a range. 
+ ArrayRef(const T *begin, const T *end) : Data(begin), Length(end - begin) {} + + /// Construct a ArrayRef from a C array. + template + /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {} + + /// @} + /// @name Simple Operations + /// @{ + + constexpr iterator begin() const { return Data; } + constexpr iterator end() const { return Data + Length; } + + // These are actually the same as iterator, since ArrayRef only + // gives you const iterators. + constexpr const_iterator cbegin() const { return Data; } + constexpr const_iterator cend() const { return Data + Length; } + + /// empty - Check if the array is empty. + constexpr bool empty() const { return Length == 0; } + + constexpr const T *data() const { return Data; } + + /// size - Get the array size. + constexpr size_t size() const { return Length; } + + /// front - Get the first element. + const T &front() const { + // ArrayRef: attempted to access front() of empty list + ET_CHECK(!empty()); + return Data[0]; + } + + /// back - Get the last element. + const T &back() const { + // ArrayRef: attempted to access back() of empty list + ET_CHECK(!empty()); + return Data[Length - 1]; + } + + /// equals - Check for element-wise equality. + bool equals(ArrayRef RHS) const { + if (Length != RHS.Length) { + return false; + } + for (size_t i = 0; i < this->Length; i++) { + if (Data[i] != RHS.Data[i]) { + return false; + } + } + return true; + } + + /// slice(n, m) - Take M elements of the array starting at element N + ArrayRef slice(size_t N, size_t M) const { + // cant slice longer then the array + ET_CHECK(N + M <= size()); + return ArrayRef(data() + N, M); + } + + /// slice(n) - Chop off the first N elements of the array. + constexpr ArrayRef slice(size_t N) const { return slice(N, size() - N); } + + /// @} + /// @name Operator Overloads + /// @{ + constexpr const T &operator[](size_t Index) const { return Data[Index]; } + + /// Vector compatibility + const T &at(size_t Index) const { + // invalid index + ET_CHECK(Index < Length); + return Data[Index]; + } + + /// @} +}; + +/// @name ArrayRef Convenience constructors +/// @{ + +/// Construct an ArrayRef from a single element. +template ArrayRef makeArrayRef(const T &OneElt) { + return OneElt; +} + +/// Construct an ArrayRef from a pointer and length. +template ArrayRef makeArrayRef(const T *data, size_t length) { + return ArrayRef(data, length); +} + +/// Construct an ArrayRef from a range. +template ArrayRef makeArrayRef(const T *begin, const T *end) { + return ArrayRef(begin, end); +} + +/// Construct an ArrayRef from an ArrayRef (no-op) (const) +template ArrayRef makeArrayRef(const ArrayRef &Vec) { + return Vec; +} + +/// Construct an ArrayRef from an ArrayRef (no-op) +template ArrayRef &makeArrayRef(ArrayRef &Vec) { + return Vec; +} + +/// Construct an ArrayRef from a C array. +template ArrayRef makeArrayRef(const T (&Arr)[N]) { + return ArrayRef(Arr); +} + +// WARNING: Template instantiation will NOT be willing to do an implicit +// conversions to get you to an ArrayRef, which is why we need so +// many overloads. + +template bool operator==(ArrayRef a1, ArrayRef a2) { + return a1.equals(a2); +} + +template bool operator!=(ArrayRef a1, ArrayRef a2) { + return !a1.equals(a2); +} + +using IntArrayRef = ArrayRef; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. 
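[Illustrative aside for the `ArrayRef` header above: a minimal usage sketch. `ArrayRef` is a non-owning view, so the backing array must outlive it; `array_ref_basics` is a hypothetical function, not ExecuTorch API.]

```cpp
#include <cstddef>
#include <cstdint>

#include <executorch/runtime/core/array_ref.h>

void array_ref_basics() {
  using executorch::runtime::ArrayRef;
  using executorch::runtime::makeArrayRef;

  const int64_t sizes[] = {1, 3, 224, 224};
  // Deduce the length from the C array.
  ArrayRef<int64_t> view = makeArrayRef(sizes);

  size_t rank = view.size();                     // 4
  int64_t channels = view[1];                    // 3
  ArrayRef<int64_t> spatial = view.slice(2, 2);  // {224, 224}
  bool same = (spatial == makeArrayRef(sizes + 2, size_t(2)));  // true

  (void)rank;
  (void)channels;
  (void)same;
}
```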
+using ::executorch::runtime::ArrayRef; +using ::executorch::runtime::IntArrayRef; +using ::executorch::runtime::makeArrayRef; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/data_loader.h b/third-party/include/executorch/runtime/core/data_loader.h new file mode 100644 index 00000000..25244fb9 --- /dev/null +++ b/third-party/include/executorch/runtime/core/data_loader.h @@ -0,0 +1,132 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include + +namespace executorch { +namespace runtime { + +/** + * Loads from a data source. + * + * See //executorch/extension/data_loader for common implementations. + */ +class DataLoader { +public: + /** + * Describes the content of the segment. + */ + struct SegmentInfo { + /** + * Represents the purpose of the segment. + */ + enum class Type { + /** + * Data for the actual program. + */ + Program, + /** + * Holds constant tensor data. + */ + Constant, + /** + * Data used for initializing a backend. + */ + Backend, + /** + * Data used for initializing mutable tensors. + */ + Mutable, + }; + + /// Type of the segment. + Type segment_type; + + /// Index of the segment within the segment list. Undefined for program + /// segments. + size_t segment_index; + + /// An optional, null-terminated string describing the segment. For + /// `Backend` segments, this is the backend ID. Null for other segment + /// types. + const char *descriptor; + + SegmentInfo() = default; + + explicit SegmentInfo(Type segment_type, size_t segment_index = 0, + const char *descriptor = nullptr) + : segment_type(segment_type), segment_index(segment_index), + descriptor(descriptor) {} + }; + + virtual ~DataLoader() = default; + + /** + * Loads data from the underlying data source. + * + * NOTE: This must be thread-safe. If this call modifies common state, the + * implementation must do its own locking. + * + * @param offset The byte offset in the data source to start loading from. + * @param size The number of bytes to load. + * @param segment_info Information about the segment being loaded. + * + * @returns a `FreeableBuffer` that owns the loaded data. + */ + ET_NODISCARD virtual Result + load(size_t offset, size_t size, const SegmentInfo &segment_info) const = 0; + + /** + * Loads data from the underlying data source into the provided buffer. + * + * NOTE: This must be thread-safe. If this call modifies common state, the + * implementation must do its own locking. + * + * @param offset The byte offset in the data source to start loading from. + * @param size The number of bytes to load. + * @param segment_info Information about the segment being loaded. + * @param buffer The buffer to load data into. Must point to at least `size` + * bytes of memory. + * + * @returns an Error indicating if the load was successful. + */ + ET_NODISCARD virtual Error load_into(size_t offset, size_t size, + const SegmentInfo &segment_info, + void *buffer) const { + // Using a stub implementation here instead of pure virtual to expand the + // data_loader interface in a backwards compatible way. 
+ (void)buffer; + (void)offset; + (void)size; + (void)segment_info; + ET_LOG(Error, "load_into() not implemented for this data loader."); + return Error::NotImplemented; + } + + /** + * Returns the length of the underlying data source, typically the file size. + */ + ET_NODISCARD virtual Result size() const = 0; +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::DataLoader; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/error.h b/third-party/include/executorch/runtime/core/error.h new file mode 100644 index 00000000..62ea8f16 --- /dev/null +++ b/third-party/include/executorch/runtime/core/error.h @@ -0,0 +1,207 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * @file + * ExecuTorch Error declarations. + */ + +#pragma once + +#include + +#include + +namespace executorch { +namespace runtime { + +// Alias error code integral type to minimal platform width (32-bits for now). +typedef uint32_t error_code_t; + +/** + * ExecuTorch Error type. + */ +enum class Error : error_code_t { + /* + * System errors. + */ + + /// Status indicating a successful operation. + Ok = 0x00, + + /// An internal error occurred. + Internal = 0x01, + + /// Status indicating the executor is in an invalid state for a target + /// operation + InvalidState = 0x2, + + /// Status indicating there are no more steps of execution to run + EndOfMethod = 0x03, + + /* + * Logical errors. + */ + + /// Operation is not supported in the current context. + NotSupported = 0x10, + + /// Operation is not yet implemented. + NotImplemented = 0x11, + + /// User provided an invalid argument. + InvalidArgument = 0x12, + + /// Object is an invalid type for the operation. + InvalidType = 0x13, + + /// Operator(s) missing in the operator registry. + OperatorMissing = 0x14, + + /* + * Resource errors. + */ + + /// Requested resource could not be found. + NotFound = 0x20, + + /// Could not allocate the requested memory. + MemoryAllocationFailed = 0x21, + + /// Could not access a resource. + AccessFailed = 0x22, + + /// Error caused by the contents of a program. + InvalidProgram = 0x23, + + /* + * Delegate errors. + */ + + /// Init stage: Backend receives an incompatible delegate version. + DelegateInvalidCompatibility = 0x30, + /// Init stage: Backend fails to allocate memory. + DelegateMemoryAllocationFailed = 0x31, + /// Execute stage: The handle is invalid. + DelegateInvalidHandle = 0x32, + +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::Error; +using ::executorch::runtime::error_code_t; +} // namespace executor +} // namespace torch + +/** + * If cond__ is false, log the specified message and return the specified Error + * from the current function, which must be of return type + * executorch::runtime::Error. + * + * @param[in] cond__ The condition to be checked, asserted as true. + * @param[in] error__ Error enum value to return without the `Error::` prefix, + * like `InvalidArgument`. 
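[Illustrative aside, stepping back to the `DataLoader` interface added above: a hedged sketch of loading a constant segment into a caller-provided buffer via `load_into`. The loader instance, offsets, and `load_constant_segment` are hypothetical; real values come from the program parser.]

```cpp
#include <cstddef>

#include <executorch/runtime/core/data_loader.h>
#include <executorch/runtime/core/error.h>

executorch::runtime::Error load_constant_segment(
    executorch::runtime::DataLoader& loader,
    size_t offset,
    size_t nbytes,
    void* destination) {
  using executorch::runtime::DataLoader;
  // Describe what is being loaded so loaders can specialize their behavior
  // (e.g. mmap constants, stream backend blobs).
  DataLoader::SegmentInfo info(
      DataLoader::SegmentInfo::Type::Constant, /*segment_index=*/0);
  // `destination` must point to at least `nbytes` of writable memory.
  return loader.load_into(offset, nbytes, info, destination);
}
```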
+ * @param[in] message__ Format string for the log error message. + * @param[in] ... Optional additional arguments for the format string. + */ +#define ET_CHECK_OR_RETURN_ERROR(cond__, error__, message__, ...) \ + { \ + if (!(cond__)) { \ + ET_LOG(Error, message__, ##__VA_ARGS__); \ + return ::executorch::runtime::Error::error__; \ + } \ + } + +/** + * If error__ is not Error::Ok, optionally log a message and return the error + * from the current function, which must be of return type + * executorch::runtime::Error. + * + * @param[in] error__ Error enum value asserted to be Error::Ok. + * @param[in] ... Optional format string for the log error message and its + * arguments. + */ +#define ET_CHECK_OK_OR_RETURN_ERROR(error__, ...) \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR(error__, ##__VA_ARGS__) + +// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. +#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR(...) \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, \ + 4, 3, 2, 1) \ + (__VA_ARGS__) + +/** + * Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. + * This macro selects the correct version of + * ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR based on the number of arguments passed. + * It uses a trick with the preprocessor to count the number of arguments and + * then selects the appropriate macro. + * + * The macro expansion uses __VA_ARGS__ to accept any number of arguments and + * then appends them to ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_, followed by the + * count of arguments. The count is determined by the macro + * ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT which takes the arguments and + * passes them along with a sequence of numbers (2, 1). The preprocessor then + * matches this sequence to the correct number of arguments provided. + * + * If two arguments are passed, ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 is + * selected, suitable for cases where an error code and a custom message are + * provided. If only one argument is passed, + * ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1 is selected, which is used for cases + * with just an error code. + * + * Usage: + * ET_CHECK_OK_OR_RETURN_ERROR(error_code); // Calls v1 + * ET_CHECK_OK_OR_RETURN_ERROR(error_code, "Error message", ...); // Calls v2 + */ +#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT(_1, _2, _3, _4, _5, _6, \ + _7, _8, _9, _10, N, ...) \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_##N + +// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. +#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1(error__) \ + do { \ + const auto et_error__ = (error__); \ + if (et_error__ != ::executorch::runtime::Error::Ok) { \ + return et_error__; \ + } \ + } while (0) + +// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. +#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2(error__, message__, ...) \ + do { \ + const auto et_error__ = (error__); \ + if (et_error__ != ::executorch::runtime::Error::Ok) { \ + ET_LOG(Error, message__, ##__VA_ARGS__); \ + return et_error__; \ + } \ + } while (0) + +// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. 
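[Illustrative aside for the two public error-checking macros documented above: a hedged sketch of using them in functions that return `executorch::runtime::Error`. `copy_buffer` and `copy_twice` are hypothetical helpers.]

```cpp
#include <cstddef>
#include <cstring>

#include <executorch/runtime/core/error.h>

executorch::runtime::Error copy_buffer(
    const void* src, void* dst, size_t nbytes, size_t dst_capacity) {
  // Logs and returns Error::InvalidArgument if the condition is false.
  ET_CHECK_OR_RETURN_ERROR(
      src != nullptr && dst != nullptr, InvalidArgument, "null buffer");
  ET_CHECK_OR_RETURN_ERROR(
      nbytes <= dst_capacity,
      InvalidArgument,
      "need %zu bytes, capacity is %zu",
      nbytes,
      dst_capacity);
  std::memcpy(dst, src, nbytes);
  return executorch::runtime::Error::Ok;
}

executorch::runtime::Error copy_twice(
    const void* src, void* dst, size_t nbytes, size_t cap) {
  // Propagates any non-Ok error from the callee, optionally with a message.
  ET_CHECK_OK_OR_RETURN_ERROR(copy_buffer(src, dst, nbytes, cap));
  ET_CHECK_OK_OR_RETURN_ERROR(
      copy_buffer(src, dst, nbytes, cap), "second copy failed");
  return executorch::runtime::Error::Ok;
}
```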
+#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_3 \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 +#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_4 \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 +#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_5 \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 +#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_6 \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 +#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_7 \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 +#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_8 \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 +#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_9 \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 +#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_10 \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 diff --git a/third-party/include/executorch/runtime/core/evalue.h b/third-party/include/executorch/runtime/core/evalue.h new file mode 100644 index 00000000..a4c307ff --- /dev/null +++ b/third-party/include/executorch/runtime/core/evalue.h @@ -0,0 +1,521 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include + +namespace executorch { +namespace runtime { + +struct EValue; + +namespace internal { + +// Tensor gets proper reference treatment because its expensive to copy in aten +// mode, all other types are just copied. +template struct evalue_to_const_ref_overload_return { + using type = T; +}; + +template <> +struct evalue_to_const_ref_overload_return { + using type = const executorch::aten::Tensor &; +}; + +template struct evalue_to_ref_overload_return { + using type = T; +}; + +template <> struct evalue_to_ref_overload_return { + using type = executorch::aten::Tensor &; +}; + +} // namespace internal + +/* + * Helper class used to correlate EValues in the executor table, with the + * unwrapped list of the proper type. Because values in the runtime's values + * table can change during execution, we cannot statically allocate list of + * objects at deserialization. Imagine the serialized list says index 0 in the + * value table is element 2 in the list, but during execution the value in + * element 2 changes (in the case of tensor this means the TensorImpl* stored in + * the tensor changes). To solve this instead they must be created dynamically + * whenever they are used. + */ +template class BoxedEvalueList { +public: + BoxedEvalueList() = default; + /* + * Wrapped_vals is a list of pointers into the values table of the runtime + * whose destinations correlate with the elements of the list, unwrapped_vals + * is a container of the same size whose serves as memory to construct the + * unwrapped vals. 
+ */ + BoxedEvalueList(EValue **wrapped_vals, T *unwrapped_vals, int size) + : wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {} + /* + * Constructs and returns the list of T specified by the EValue pointers + */ + executorch::aten::ArrayRef get() const; + +private: + // Source of truth for the list + executorch::aten::ArrayRef wrapped_vals_; + // Same size as wrapped_vals + mutable T *unwrapped_vals_; +}; + +template <> +executorch::aten::ArrayRef> +BoxedEvalueList>::get() + const; + +// Aggregate typing system similar to IValue only slimmed down with less +// functionality, no dependencies on atomic, and fewer supported types to better +// suit embedded systems (ie no intrusive ptr) +struct EValue { + union Payload { + // When in ATen mode at::Tensor is not trivially copyable, this nested union + // lets us handle tensor as a special case while leaving the rest of the + // fields in a simple state instead of requiring a switch on tag everywhere. + union TriviallyCopyablePayload { + TriviallyCopyablePayload() : as_int(0) {} + // Scalar supported through these 3 types + int64_t as_int; + double as_double; + bool as_bool; + // TODO(jakeszwe): convert back to pointers to optimize size of this + // struct + executorch::aten::ArrayRef as_string; + executorch::aten::ArrayRef as_double_list; + executorch::aten::ArrayRef as_bool_list; + BoxedEvalueList as_int_list; + BoxedEvalueList as_tensor_list; + BoxedEvalueList> + as_list_optional_tensor; + } copyable_union; + + // Since a Tensor just holds a TensorImpl*, there's no value to use Tensor* + // here. + executorch::aten::Tensor as_tensor; + + Payload() {} + ~Payload() {} + }; + + // Data storage and type tag + Payload payload; + Tag tag; + + // Basic ctors and assignments + EValue(const EValue &rhs) : EValue(rhs.payload, rhs.tag) {} + + EValue(EValue &&rhs) noexcept : tag(rhs.tag) { moveFrom(std::move(rhs)); } + + EValue &operator=(EValue &&rhs) & noexcept { + if (&rhs == this) { + return *this; + } + + destroy(); + moveFrom(std::move(rhs)); + return *this; + } + + EValue &operator=(EValue const &rhs) & { + // Define copy assignment through copy ctor and move assignment + *this = EValue(rhs); + return *this; + } + + ~EValue() { destroy(); } + + /****** None Type ******/ + EValue() : tag(Tag::None) { payload.copyable_union.as_int = 0; } + + bool isNone() const { return tag == Tag::None; } + + /****** Int Type ******/ + /*implicit*/ EValue(int64_t i) : tag(Tag::Int) { + payload.copyable_union.as_int = i; + } + + bool isInt() const { return tag == Tag::Int; } + + int64_t toInt() const { + ET_CHECK_MSG(isInt(), "EValue is not an int."); + return payload.copyable_union.as_int; + } + + /****** Double Type ******/ + /*implicit*/ EValue(double d) : tag(Tag::Double) { + payload.copyable_union.as_double = d; + } + + bool isDouble() const { return tag == Tag::Double; } + + double toDouble() const { + ET_CHECK_MSG(isDouble(), "EValue is not a Double."); + return payload.copyable_union.as_double; + } + + /****** Bool Type ******/ + /*implicit*/ EValue(bool b) : tag(Tag::Bool) { + payload.copyable_union.as_bool = b; + } + + bool isBool() const { return tag == Tag::Bool; } + + bool toBool() const { + ET_CHECK_MSG(isBool(), "EValue is not a Bool."); + return payload.copyable_union.as_bool; + } + + /****** Scalar Type ******/ + /// Construct an EValue using the implicit value of a Scalar. 
+ /*implicit*/ EValue(executorch::aten::Scalar s) { + if (s.isIntegral(false)) { + tag = Tag::Int; + payload.copyable_union.as_int = s.to(); + } else if (s.isFloatingPoint()) { + tag = Tag::Double; + payload.copyable_union.as_double = s.to(); + } else if (s.isBoolean()) { + tag = Tag::Bool; + payload.copyable_union.as_bool = s.to(); + } else { + ET_CHECK_MSG(false, "Scalar passed to EValue is not initialized."); + } + } + + bool isScalar() const { + return tag == Tag::Int || tag == Tag::Double || tag == Tag::Bool; + } + + executorch::aten::Scalar toScalar() const { + // Convert from implicit value to Scalar using implicit constructors. + + if (isDouble()) { + return toDouble(); + } else if (isInt()) { + return toInt(); + } else if (isBool()) { + return toBool(); + } else { + ET_CHECK_MSG(false, "EValue is not a Scalar."); + } + } + + /****** Tensor Type ******/ + /*implicit*/ EValue(executorch::aten::Tensor t) : tag(Tag::Tensor) { + // When built in aten mode, at::Tensor has a non trivial constructor + // destructor, so regular assignment to a union field is UB. Instead we must + // go through placement new (which causes a refcount bump). + new (&payload.as_tensor) executorch::aten::Tensor(t); + } + + // Template constructor that allows construction from types that can be + // dereferenced to produce a type that EValue can be implicitly constructed + // from. + template (std::declval())), // declval to + // simulate + // forwarding + EValue>::value>::type> + /*implicit*/ EValue(T &&value) { + ET_CHECK_MSG(value != nullptr, "Pointer is null."); + // Note that this ctor does not initialize this->tag directly; it is set by + // moving in the new value. + moveFrom(*std::forward(value)); + } + + // Delete constructor for raw pointers to ensure they cannot be used. 
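[Illustrative aside for the scalar constructors and accessors of `EValue` shown above: a minimal sketch covering the None/Int/Double/Bool tags only (tensors omitted). `evalue_basics` is a hypothetical function.]

```cpp
#include <cstdint>

#include <executorch/runtime/core/evalue.h>

void evalue_basics() {
  using executorch::runtime::EValue;

  EValue none;  // Default-constructed EValue is tagged None.
  EValue i(static_cast<int64_t>(42));
  EValue d(3.5);
  EValue b(true);

  bool is_none = none.isNone();     // true
  int64_t as_int = i.toInt();       // 42; ET_CHECK fails if the tag is wrong
  double as_double = d.toDouble();  // 3.5
  bool as_bool = b.toBool();        // true
  bool scalar = i.isScalar();       // true: Int, Double, and Bool are Scalars

  (void)is_none;
  (void)as_int;
  (void)as_double;
  (void)as_bool;
  (void)scalar;
}
```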
+ template explicit EValue(T *value) = delete; + + bool isTensor() const { return tag == Tag::Tensor; } + + executorch::aten::Tensor toTensor() && { + ET_CHECK_MSG(isTensor(), "EValue is not a Tensor."); + auto res = std::move(payload.as_tensor); + clearToNone(); + return res; + } + + executorch::aten::Tensor &toTensor() & { + ET_CHECK_MSG(isTensor(), "EValue is not a Tensor."); + return payload.as_tensor; + } + + const executorch::aten::Tensor &toTensor() const & { + ET_CHECK_MSG(isTensor(), "EValue is not a Tensor."); + return payload.as_tensor; + } + + /****** String Type ******/ + /*implicit*/ EValue(const char *s, size_t size) : tag(Tag::String) { + payload.copyable_union.as_string = + executorch::aten::ArrayRef(s, size); + } + + bool isString() const { return tag == Tag::String; } + + executorch::aten::string_view toString() const { + ET_CHECK_MSG(isString(), "EValue is not a String."); + return executorch::aten::string_view( + payload.copyable_union.as_string.data(), + payload.copyable_union.as_string.size()); + } + + /****** Int List Type ******/ + /*implicit*/ EValue(BoxedEvalueList i) : tag(Tag::ListInt) { + payload.copyable_union.as_int_list = i; + } + + bool isIntList() const { return tag == Tag::ListInt; } + + executorch::aten::ArrayRef toIntList() const { + ET_CHECK_MSG(isIntList(), "EValue is not an Int List."); + return payload.copyable_union.as_int_list.get(); + } + + /****** Bool List Type ******/ + /*implicit*/ EValue(executorch::aten::ArrayRef b) : tag(Tag::ListBool) { + payload.copyable_union.as_bool_list = b; + } + + bool isBoolList() const { return tag == Tag::ListBool; } + + executorch::aten::ArrayRef toBoolList() const { + ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List."); + return payload.copyable_union.as_bool_list; + } + + /****** Double List Type ******/ + /*implicit*/ EValue(executorch::aten::ArrayRef d) + : tag(Tag::ListDouble) { + payload.copyable_union.as_double_list = d; + } + + bool isDoubleList() const { return tag == Tag::ListDouble; } + + executorch::aten::ArrayRef toDoubleList() const { + ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List."); + return payload.copyable_union.as_double_list; + } + + /****** Tensor List Type ******/ + /*implicit*/ EValue(BoxedEvalueList t) + : tag(Tag::ListTensor) { + payload.copyable_union.as_tensor_list = t; + } + + bool isTensorList() const { return tag == Tag::ListTensor; } + + executorch::aten::ArrayRef toTensorList() const { + ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List."); + return payload.copyable_union.as_tensor_list.get(); + } + + /****** List Optional Tensor Type ******/ + /*implicit*/ EValue( + BoxedEvalueList> t) + : tag(Tag::ListOptionalTensor) { + payload.copyable_union.as_list_optional_tensor = t; + } + + bool isListOptionalTensor() const { return tag == Tag::ListOptionalTensor; } + + executorch::aten::ArrayRef< + executorch::aten::optional> + toListOptionalTensor() const { + return payload.copyable_union.as_list_optional_tensor.get(); + } + + /****** ScalarType Type ******/ + executorch::aten::ScalarType toScalarType() const { + ET_CHECK_MSG(isInt(), "EValue is not a ScalarType."); + return static_cast( + payload.copyable_union.as_int); + } + + /****** MemoryFormat Type ******/ + executorch::aten::MemoryFormat toMemoryFormat() const { + ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat."); + return static_cast( + payload.copyable_union.as_int); + } + + /****** Layout Type ******/ + executorch::aten::Layout toLayout() const { + ET_CHECK_MSG(isInt(), "EValue is not a 
Layout."); + return static_cast(payload.copyable_union.as_int); + } + + /****** Device Type ******/ + executorch::aten::Device toDevice() const { + ET_CHECK_MSG(isInt(), "EValue is not a Device."); + return executorch::aten::Device(static_cast( + payload.copyable_union.as_int), + -1); + } + + template T to() &&; + template + typename internal::evalue_to_const_ref_overload_return::type to() const &; + template + typename internal::evalue_to_ref_overload_return::type to() &; + + /** + * Converts the EValue to an optional object that can represent both T and + * an uninitialized state. + */ + template + inline executorch::aten::optional toOptional() const { + if (this->isNone()) { + return executorch::aten::nullopt; + } + return this->to(); + } + +private: + // Pre cond: the payload value has had its destructor called + void clearToNone() noexcept { + payload.copyable_union.as_int = 0; + tag = Tag::None; + } + + // Shared move logic + void moveFrom(EValue &&rhs) noexcept { + if (rhs.isTensor()) { + new (&payload.as_tensor) + executorch::aten::Tensor(std::move(rhs.payload.as_tensor)); + rhs.payload.as_tensor.~Tensor(); + } else { + payload.copyable_union = rhs.payload.copyable_union; + } + tag = rhs.tag; + rhs.clearToNone(); + } + + // Destructs stored tensor if there is one + void destroy() { + // Necessary for ATen tensor to refcount decrement the intrusive_ptr to + // tensorimpl that got a refcount increment when we placed it in the evalue, + // no-op if executorch tensor #ifdef could have a + // minor performance bump for a code maintainability hit + if (isTensor()) { + payload.as_tensor.~Tensor(); + } else if (isTensorList()) { + for (auto &tensor : toTensorList()) { + tensor.~Tensor(); + } + } else if (isListOptionalTensor()) { + for (auto &optional_tensor : toListOptionalTensor()) { + optional_tensor.~optional(); + } + } + } + + EValue(const Payload &p, Tag t) : tag(t) { + if (isTensor()) { + new (&payload.as_tensor) executorch::aten::Tensor(p.as_tensor); + } else { + payload.copyable_union = p.copyable_union; + } + } +}; + +#define EVALUE_DEFINE_TO(T, method_name) \ + template <> inline T EValue::to() && { \ + return static_cast(std::move(*this).method_name()); \ + } \ + template <> \ + inline ::executorch::runtime::internal::evalue_to_const_ref_overload_return< \ + T>::type \ + EValue::to() const & { \ + typedef ::executorch::runtime::internal:: \ + evalue_to_const_ref_overload_return::type return_type; \ + return static_cast(this->method_name()); \ + } \ + template <> \ + inline ::executorch::runtime::internal::evalue_to_ref_overload_return< \ + T>::type \ + EValue::to() & { \ + typedef ::executorch::runtime::internal::evalue_to_ref_overload_return< \ + T>::type return_type; \ + return static_cast(this->method_name()); \ + } + +EVALUE_DEFINE_TO(executorch::aten::Scalar, toScalar) +EVALUE_DEFINE_TO(int64_t, toInt) +EVALUE_DEFINE_TO(bool, toBool) +EVALUE_DEFINE_TO(double, toDouble) +EVALUE_DEFINE_TO(executorch::aten::string_view, toString) +EVALUE_DEFINE_TO(executorch::aten::ScalarType, toScalarType) +EVALUE_DEFINE_TO(executorch::aten::MemoryFormat, toMemoryFormat) +EVALUE_DEFINE_TO(executorch::aten::Layout, toLayout) +EVALUE_DEFINE_TO(executorch::aten::Device, toDevice) +// Tensor and Optional Tensor +EVALUE_DEFINE_TO(executorch::aten::optional, + toOptional) +EVALUE_DEFINE_TO(executorch::aten::Tensor, toTensor) + +// IntList and Optional IntList +EVALUE_DEFINE_TO(executorch::aten::ArrayRef, toIntList) +EVALUE_DEFINE_TO( + executorch::aten::optional>, + toOptional>) + +// DoubleList 
and Optional DoubleList +EVALUE_DEFINE_TO(executorch::aten::ArrayRef, toDoubleList) +EVALUE_DEFINE_TO(executorch::aten::optional>, + toOptional>) + +// BoolList and Optional BoolList +EVALUE_DEFINE_TO(executorch::aten::ArrayRef, toBoolList) +EVALUE_DEFINE_TO(executorch::aten::optional>, + toOptional>) + +// TensorList and Optional TensorList +EVALUE_DEFINE_TO(executorch::aten::ArrayRef, + toTensorList) +EVALUE_DEFINE_TO( + executorch::aten::optional< + executorch::aten::ArrayRef>, + toOptional>) + +// List of Optional Tensor +EVALUE_DEFINE_TO(executorch::aten::ArrayRef< + executorch::aten::optional>, + toListOptionalTensor) +#undef EVALUE_DEFINE_TO + +template +executorch::aten::ArrayRef BoxedEvalueList::get() const { + for (typename executorch::aten::ArrayRef::size_type i = 0; + i < wrapped_vals_.size(); i++) { + ET_CHECK(wrapped_vals_[i] != nullptr); + unwrapped_vals_[i] = wrapped_vals_[i]->template to(); + } + return executorch::aten::ArrayRef{unwrapped_vals_, wrapped_vals_.size()}; +} + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::BoxedEvalueList; +using ::executorch::runtime::EValue; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/event_tracer.h b/third-party/include/executorch/runtime/core/event_tracer.h new file mode 100644 index 00000000..ffb60066 --- /dev/null +++ b/third-party/include/executorch/runtime/core/event_tracer.h @@ -0,0 +1,500 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#pragma once + +namespace executorch { +namespace runtime { + +/// Represents an allocator id returned by track_allocator. +typedef uint32_t AllocatorID; +/// Represents the chain id that will be passed in by the user during +/// event logging. +typedef int32_t ChainID; +/// Represents the debug handle that is generally associated with each +/// op executed in the runtime. +typedef uint32_t DebugHandle; + +/// Default id's for chain id and debug handle. +constexpr ChainID kUnsetChainId = -1; +constexpr DebugHandle kUnsetDebugHandle = 0; +// Default bundled input index to indicate that it hasn't been set yet. +constexpr int kUnsetBundledInputIndex = -1; + +/// Different types of delegate debug identifiers that are supported currently. +enum class DelegateDebugIdType { + /// Default value, indicates that it's not a delegate event. + kNone, + /// Indicates a delegate event logged using an integer delegate debug + /// identifier. + kInt, + /// Indicates a delegate event logged using a string delegate debug + /// identifier i.e. the delegate debug id is a pointer to a string table + /// managed by the class implementing EventTracer functionality. + kStr +}; + +/// Indicates the type of the EValue that was logged. These values could be +/// serialized and should not be changed. +enum class LoggedEValueType { + /// Intermediate output from an operator. + kIntermediateOutput = 0, + /// Output at the program level. This is essentially the output + /// of the model. + kProgramOutput = 1, +}; + +/// Indicates the level of event tracer debug logging. Verbosity of the logging +/// increases as we go down the enum list. 
+enum class EventTracerDebugLogLevel { + /// No logging. + kNoLogging, + /// When set to this only the program level outputs will be logged. + kProgramOutputs, + /// When set to this all intermediate outputs and program level outputs + /// will be logged. + kIntermediateOutputs, +}; + +/** + * Indicates the level of profiling that should be enabled. Profiling + * events will be logged in increasing order of verbosity as we go down the + * enum list. Thus it is important to keep the enum values in the right order. + */ +enum class EventTracerProfilingLevel { + /// No operator profiling. + kProfileMethodOnly, + /// All profiling events enabled. + kProfileAllEvents, +}; + +/** + * This is the struct which should be returned when a profiling event is + * started. This is used to uniquely identify that profiling event and will be + * required to be passed into the end_profiling call to signal that the event + * identified by this struct has completed. + **/ +struct EventTracerEntry { + /// An event id to uniquely identify this event that was generated during a + /// call to start the tracking of an event. + int64_t event_id; + /// The chain to which this event belongs to. + ChainID chain_id; + /// The debug handle corresponding to this event. + DebugHandle debug_handle; + /// The time at which this event was started to be tracked. + et_timestamp_t start_time; + /// When delegate_event_id_type != DelegateDebugIdType::kNone it indicates + /// that event_id represents a delegate event. If delegate_event_id_type is: + /// 1) kInt then event_id contains an integer delegate debug id. + /// 2) kStr then event_id contains a string table index into a string table + /// maintained by the class implementing EventTracer functionality that will + /// give us the string identifier of this delegate event. For more details + /// refer to the DelegateMappingBuilder library present in + /// executorch/exir/backend/utils.py. + DelegateDebugIdType delegate_event_id_type; +}; +/** + * EventTracer is a class that users can inherit and implement to + * log/serialize/stream etc. the profiling and debugging events that are + * generated at runtime for a model. An example of this is the ETDump + * implementation in the devtools codebase that serializes these events to a + * flatbuffer. + */ +class EventTracer { +public: + /** + * Start a new event block (can consist of profiling and/or debugging events.) + * identified by this name. A block is conceptually a set of events that we + * want to group together. e.g. all the events that occur during the call to + * execute() (i.e. model inference) could be categorized as a block. + * + * @param[in] name A human readable identifier for the event block. Users + * calling this interface do not need to keep the memory pointed to by this + * pointer around. The string must be copied over into internal memory during + * this call. + */ + virtual void create_event_block(const char *name) = 0; + + /** + * Start the profiling of the event identified by name and debug_handle. + * The user can pass in a chain_id and debug_handle to this call, or leave + * them empty (default values) which would then result in the chain_id and + * debug handle stored within (set by set_chain_debug_handle) this class to be + * used. + * @param[in] name Human readable name for the profiling event. Users calling + * this interface do not need to keep the memory pointed to by this pointer + * around. The string must be copied over into internal memory during this + * call. 
+ * @param[in] chain_id The id of the chain to which this event belongs to. If + * kUnsetChainId is passed in the chain_id and kUnsetDebugHandle for + * debug_handle then the values stored in the class internally for these + * properties will be used. + * @param[in] debug_handle Debug handle generated ahead-of-time during model + * compilation. + * + * @return Returns an instance of EventTracerEntry which should be passed back + * into the end_profiling() call. + */ + virtual EventTracerEntry + start_profiling(const char *name, ChainID chain_id = kUnsetChainId, + DebugHandle debug_handle = kUnsetDebugHandle) = 0; + + /** + * Start the profiling of a delegate event. Similar to start_profiling it will + * return an instance of EventTracerEntry that contains the details of this + * event. + * + * @param[in] name Human readable name for the delegate event. This name has + * to be the same name that was passed in during the Debug delegate mapping + * generation in the export/ahead-of-time process. If indices and not names + * are used by this delegate to identify ops executed in the backend then + * nullptr can be passed in. Users calling this interface do not need to keep + * the memory pointed to by this pointer around. The string must be copied + * over into internal memory during this call. + * @param[in] delegate_debug_index The id of the delegate event. If string + * based names are used by this delegate to identify ops executed in the + * backend then kUnsetDebugHandle should be passed in here. + */ + virtual EventTracerEntry + start_profiling_delegate(const char *name, + DebugHandle delegate_debug_index) = 0; + + /** + * Signal the end of the delegate profiling event contained in + * event_tracer_entry. Users also have the option to log some some free-from + * string based metadata along with this. + * + * @param[in] event_tracer_entry The EventTracerEntry returned by a call to + * start_profiling_delegate(). + * @param[in] metadata Optional data relevant to the execution that the user + * wants to log along with this event. Pointer to metadata doesn't need to be + * valid after the call to this function. The contents and format of the data + * are transparent to the event tracer. It will just pipe along the data and + * make it available for the user again in the post-processing stage. + * @param[in] metadata_len Length of the metadata buffer. + */ + virtual void end_profiling_delegate(EventTracerEntry event_tracer_entry, + const void *metadata = nullptr, + size_t metadata_len = 0) = 0; + + /** + * Some delegates get access to the profiling details only after the complete + * graph has been executed. This interface is to support such use cases. It + * can be called in a loop etc. to log any number of profiling events that are + * part of this delegate. + * + * @param[in] name Human readable name for the delegate event. This name has + * to be the same name that was passed in during the Debug delegate mapping + * generation in the export/ahead-of-time process. If indices and not names + * are used by this delegate to identify ops executed in the backend then + * nullptr can be passed in. Users calling this interface do not need to keep + * the memory pointed to by this pointer around. The string must be copied + * over into internal memory during this call. + * @param[in] delegate_debug_index The id of the delegate event. If string + * based names are used by this delegate to identify ops executed in the + * backend then kUnsetDebugHandle should be passed in here. 
+ * @param[in] start_time The timestamp when the delegate event started. + * @param[in] end_time The timestamp when the delegate event finished. + * @param[in] metadata Optional data relevant to the execution that the user + * wants to log along with this event. Pointer to metadata doesn't need to be + * valid after the call to this function. The contents and format of the data + * are transparent to the event tracer. It will just pipe along the data and + * make it available for the user again in the post-processing stage. + * @param[in] metadata_len Length of the metadata buffer. + */ + virtual void log_profiling_delegate(const char *name, + DebugHandle delegate_debug_index, + et_timestamp_t start_time, + et_timestamp_t end_time, + const void *metadata = nullptr, + size_t metadata_len = 0) = 0; + + /** + * End the profiling of the event identified by prof_entry + * + * @param[in] prof_entry Value returned by a call to start_profiling + */ + virtual void end_profiling(EventTracerEntry prof_entry) = 0; + + /** + * Track this allocation done via a MemoryAllocator which had profiling + * enabled on it. + * + * @param[in] id Allocator id generated by a call to track_allocator. + * @param[in] size The size of the allocation done, in bytes. + */ + virtual void track_allocation(AllocatorID id, size_t size) = 0; + + /** + * Generate an allocator id for this memory allocator that will be used in the + * future to identify all the allocations done by this allocator. + * + * @param[in] name Human readable name for the allocator. Users calling + * this interface do not need to keep the memory pointed to by this pointer + * around. The string should be copied over into internal memory during this + * call. + * + * @return Identifier to uniquely identify this allocator. + */ + virtual AllocatorID track_allocator(const char *name) = 0; + + /** + * Log an evalue during the execution of the model. This is useful for + * debugging purposes. Model outputs are a special case of this and will + * be logged with the output bool enabled. + * + * Users of this should refer to the chain_id and debug_handle to get the + * context for these evalues and their corresponding op. + * + * @param[in] evalue The value to be logged. + * @param[in] evalue_type Indicates what type of output this is logging e.g. + * an intermediate output, program output etc. + */ + virtual void log_evalue(const EValue &evalue, + LoggedEValueType evalue_type) = 0; + + /** + * Log an intermediate tensor output from a delegate. + * + * @param[in] name Human readable name for the delegate event. This name has + * to be the same name that was passed in during the Debug delegate mapping + * generation in the export/ahead-of-time process. If indices and not names + * are used by this delegate to identify ops executed in the backend then + * nullptr can be passed in. Users calling this interface do not need to keep + * the memory pointed to by this pointer around. The string must be copied + * over into internal memory during this call. + * @param[in] delegate_debug_index The id of the delegate event. If string + * based names are used by this delegate to identify ops executed in the + * backend then kUnsetDebugHandle should be passed in here. + * @param[in] output The tensor type output to be logged. + */ + virtual void + log_intermediate_output_delegate(const char *name, + DebugHandle delegate_debug_index, + const executorch::aten::Tensor &output) = 0; + + /** + * Log an intermediate tensor array output from a delegate. 
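[Illustrative aside for the delegate profiling API shown above: a hedged sketch of how a backend might emit a delegate event. The event name "stub_op" and `profile_delegate_op` are hypothetical; the tracer pointer would typically come from `BackendExecutionContext::event_tracer()`.]

```cpp
#include <executorch/runtime/core/event_tracer.h>

void profile_delegate_op(executorch::runtime::EventTracer* tracer) {
  using namespace executorch::runtime;
  if (tracer == nullptr) {
    return;  // No tracer attached; profiling is optional.
  }
  // The name must match the one used in the delegate debug mapping generated
  // at export time; kUnsetDebugHandle means the op is identified by name.
  EventTracerEntry entry =
      tracer->start_profiling_delegate("stub_op", kUnsetDebugHandle);
  // ... run the backend op here ...
  tracer->end_profiling_delegate(
      entry, /*metadata=*/nullptr, /*metadata_len=*/0);
}
```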
+ * + * @param[in] name Human readable name for the delegate event. This name has + * to be the same name that was passed in during the Debug delegate mapping + * generation in the export/ahead-of-time process. If indices and not names + * are used by this delegate to identify ops executed in the backend then + * nullptr can be passed in. Users calling this interface do not need to keep + * the memory pointed to by this pointer around. The string must be copied + * over into internal memory during this call. + * @param[in] delegate_debug_index The id of the delegate event. If string + * based names are used by this delegate to identify ops executed in the + * backend then kUnsetDebugHandle should be passed in here. + * @param[in] output The tensor array type output to be logged. + */ + virtual void log_intermediate_output_delegate( + const char *name, DebugHandle delegate_debug_index, + const ArrayRef output) = 0; + + /** + * Log an intermediate int output from a delegate. + * + * @param[in] name Human readable name for the delegate event. This name has + * to be the same name that was passed in during the Debug delegate mapping + * generation in the export/ahead-of-time process. If indices and not names + * are used by this delegate to identify ops executed in the backend then + * nullptr can be passed in. Users calling this interface do not need to keep + * the memory pointed to by this pointer around. The string must be copied + * over into internal memory during this call. + * @param[in] delegate_debug_index The id of the delegate event. If string + * based names are used by this delegate to identify ops executed in the + * backend then kUnsetDebugHandle should be passed in here. + * @param[in] output The int type output to be logged. + */ + virtual void + log_intermediate_output_delegate(const char *name, + DebugHandle delegate_debug_index, + const int &output) = 0; + + /** + * Log an intermediate bool output from a delegate. + * + * @param[in] name Human readable name for the delegate event. This name has + * to be the same name that was passed in during the Debug delegate mapping + * generation in the export/ahead-of-time process. If indices and not names + * are used by this delegate to identify ops executed in the backend then + * nullptr can be passed in. Users calling this interface do not need to keep + * the memory pointed to by this pointer around. The string must be copied + * over into internal memory during this call. + * @param[in] delegate_debug_index The id of the delegate event. If string + * based names are used by this delegate to identify ops executed in the + * backend then kUnsetDebugHandle should be passed in here. + * @param[in] output The bool type output to be logged. + */ + virtual void + log_intermediate_output_delegate(const char *name, + DebugHandle delegate_debug_index, + const bool &output) = 0; + + /** + * Log an intermediate double output from a delegate. + * + * @param[in] name Human readable name for the delegate event. This name has + * to be the same name that was passed in during the Debug delegate mapping + * generation in the export/ahead-of-time process. If indices and not names + * are used by this delegate to identify ops executed in the backend then + * nullptr can be passed in. Users calling this interface do not need to keep + * the memory pointed to by this pointer around. The string must be copied + * over into internal memory during this call. + * @param[in] delegate_debug_index The id of the delegate event. 
If string + * based names are used by this delegate to identify ops executed in the + * backend then kUnsetDebugHandle should be passed in here. + * @param[in] output The double type output to be logged. + */ + virtual void + log_intermediate_output_delegate(const char *name, + DebugHandle delegate_debug_index, + const double &output) = 0; + + /** + * Helper function to set the chain id ands debug handle. Users have two + * options, the first is that they can directly pass in the chain id and debug + * handle to start_profiling or they can explicitly set them through this + * helper before calling start_profiling. + * + * The reason this helper exists is to + * solve a specific problem. We want to do profiling logging inside the + * codegen layer which calls the kernels. The problem though is that the + * codegen layer doesn't have access to these ids when calling + * start_profiling. + * + * Users should ideally use these within a RAII scope interface to make sure + * that these values are unset after the end_profiling call. If non-default + * values are passed into the start_profiling call they will always be given + * precedence over the values set by this interface. + * + * So what we do is call this helper in method.cpp before + * we hit the codegen layer and in the codegen layer we do a start_profiling + * call without passing in a chain_id or debug_handle. This ensures that the + * values set via this helper are the ones associated with that call. + * + * @param[in] chain_id Chain id of the current instruction being exectuted. + * @param[in] debug_handle Debug handle of the current instruction being + * executed. In this context debug handle and instruction id are the same + * thing. + */ + void set_chain_debug_handle(ChainID chain_id, DebugHandle debug_handle) { + chain_id_ = chain_id; + debug_handle_ = debug_handle; + } + + /** + * When running a program wrapped in a bundled program, log the bundled input + * index of the current bundled input being tested out on this method. + * If users want to unset the index back to the default value, they can call + * this method with kUnsetBundledInputIndex. + * + * @param[in] bundled_input_index Index of the current input being tested + */ + void set_bundled_input_index(int bundled_input_index) { + bundled_input_index_ = bundled_input_index; + } + + /** + * Return the current bundled input index. + */ + int bundled_input_index() { return bundled_input_index_; } + + /** + * Set the level of event tracer debug logging that is desired. + * + */ + void set_event_tracer_debug_level(EventTracerDebugLogLevel log_level) { + event_tracer_debug_level_ = log_level; + } + + /** + * Return the current level of event tracer debug logging. + */ + EventTracerDebugLogLevel event_tracer_debug_level() { + return event_tracer_debug_level_; + } + + /** + * Set the level of event tracer profiling that is desired. + */ + void + set_event_tracer_profiling_level(EventTracerProfilingLevel profiling_level) { + event_tracer_profiling_level_ = profiling_level; + } + + /** + * Return the current level of event tracer profiling. + */ + EventTracerProfilingLevel event_tracer_profiling_level() { + return event_tracer_profiling_level_; + } + + /** + * Return the current status of intermediate outputs logging mode. + */ + bool intermediate_outputs_logging_status() { + return log_intermediate_tensors_; + } + + /** + * Get the current chain id. + * + * @return Current chain id. + */ + ChainID current_chain_id() { return chain_id_; } + + /** + * Get the current debug handle. 
+ * + * @return Current debug handle. + */ + DebugHandle current_debug_handle() { return debug_handle_; } + + virtual ~EventTracer() {} + +protected: + ChainID chain_id_ = kUnsetChainId; + DebugHandle debug_handle_ = kUnsetDebugHandle; + bool event_tracer_enable_debugging_ = false; + bool log_intermediate_tensors_ = false; + int bundled_input_index_ = kUnsetBundledInputIndex; + EventTracerDebugLogLevel event_tracer_debug_level_ = + EventTracerDebugLogLevel::kNoLogging; + EventTracerProfilingLevel event_tracer_profiling_level_ = + EventTracerProfilingLevel::kProfileAllEvents; +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::AllocatorID; +using ::executorch::runtime::ChainID; +using ::executorch::runtime::DebugHandle; +using ::executorch::runtime::DelegateDebugIdType; +using ::executorch::runtime::EventTracer; +using ::executorch::runtime::EventTracerDebugLogLevel; +using ::executorch::runtime::EventTracerEntry; +using ::executorch::runtime::kUnsetBundledInputIndex; +using ::executorch::runtime::kUnsetChainId; +using ::executorch::runtime::kUnsetDebugHandle; +using ::executorch::runtime::LoggedEValueType; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/event_tracer_hooks.h b/third-party/include/executorch/runtime/core/event_tracer_hooks.h new file mode 100644 index 00000000..1e46013c --- /dev/null +++ b/third-party/include/executorch/runtime/core/event_tracer_hooks.h @@ -0,0 +1,323 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +/** + * @file + * + * This file contains the hooks that are inserted across various parts of the + * core runtime code to call into the EventTracer class for logging of profiling + * and debugging events. Any calls made to the EventTracer from the runtime must + * be made via these hooks. + * Users shouldn't directly add these hooks in their code and it's meant only + * for usage in ExecuTorch internal code. + * + * The benefit of defining these hooks is that we can easily control whether or + * not we want to compile in the EventTracer code based on the status of the + * ET_EVENT_TRACER_ENABLED flag. + * + * TODO(dbort): Make this a private header of runtime/executor. It only contains + * runtime-internal functions and should not be part of the public set of + * headers. + */ + +namespace executorch { +namespace runtime { +namespace internal { + +/** + * This class enables scope based profiling where needed using RAII for + * operators only. If operator profiling is disabled then this class is a no-op. + */ +class EventTracerProfileOpScope final { +public: + EventTracerProfileOpScope(EventTracer *event_tracer, const char *name) { +#ifdef ET_EVENT_TRACER_ENABLED + event_tracer_ = event_tracer; + if (event_tracer_ == nullptr) { + return; + } + if (event_tracer_->event_tracer_profiling_level() > + executorch::runtime::EventTracerProfilingLevel::kProfileMethodOnly) { + event_entry_ = event_tracer->start_profiling(name); + } +#else //! 
ET_EVENT_TRACER_ENABLED + (void)event_tracer; + (void)name; +#endif + } + + ~EventTracerProfileOpScope() { +#ifdef ET_EVENT_TRACER_ENABLED + if (event_tracer_ == nullptr) { + return; + } + if (event_tracer_->event_tracer_profiling_level() > + executorch::runtime::EventTracerProfilingLevel::kProfileMethodOnly) { + event_tracer_->end_profiling(event_entry_); + } +#endif + } + +private: +#ifdef ET_EVENT_TRACER_ENABLED + EventTracer *event_tracer_; + EventTracerEntry event_entry_; +#endif +}; + +using EventTracerProfileScope = EventTracerProfileOpScope; + +/** + * This class enables scope based profiling where needed using RAII. + * Profiling will be started when the object is created and will end + * when the object goes out of scope. This is specifically intended to + * be used for profiling methods in the runtime. + */ +class EventTracerProfileMethodScope final { +public: + EventTracerProfileMethodScope(EventTracer *event_tracer, const char *name) { +#ifdef ET_EVENT_TRACER_ENABLED + event_tracer_ = event_tracer; + if (event_tracer_ == nullptr) { + return; + } + event_entry_ = event_tracer->start_profiling(name); +#else //! ET_EVENT_TRACER_ENABLED + (void)event_tracer; + (void)name; +#endif + } + + ~EventTracerProfileMethodScope() { +#ifdef ET_EVENT_TRACER_ENABLED + if (event_tracer_ == nullptr) { + return; + } + event_tracer_->end_profiling(event_entry_); +#endif + } + +private: +#ifdef ET_EVENT_TRACER_ENABLED + EventTracer *event_tracer_; + EventTracerEntry event_entry_; +#endif +}; + +/** + * This class helps us set and then clear out the chain id and debug handle + * values stored in the event tracer class using RAII. This is typically called + * in the executor loop before entering the codegen layer to configure the chain + * id and debug handle of the current instruction being executed. + * After we return from the kernel execution we can then reset the chain id and + * debug handle to defaults when this object goes out of scope. + */ +class EventTracerProfileInstructionScope final { +public: + EventTracerProfileInstructionScope(EventTracer *event_tracer, + ChainID chain_idx, + DebugHandle debug_handle) { +#ifdef ET_EVENT_TRACER_ENABLED + event_tracer_ = event_tracer; + if (event_tracer_ == nullptr) { + return; + } + event_tracer_->set_chain_debug_handle(chain_idx, debug_handle); +#else //! ET_EVENT_TRACER_ENABLED + (void)event_tracer; + (void)chain_idx; + (void)debug_handle; +#endif + } + + ~EventTracerProfileInstructionScope() { +#ifdef ET_EVENT_TRACER_ENABLED + if (event_tracer_ == nullptr) { + return; + } + event_tracer_->set_chain_debug_handle(kUnsetChainId, kUnsetDebugHandle); +#endif + } + +private: +#ifdef ET_EVENT_TRACER_ENABLED + EventTracer *event_tracer_; +#endif +}; + +inline bool event_tracer_enabled() { +#ifdef ET_EVENT_TRACER_ENABLED + return true; +#else //! ET_EVENT_TRACER_ENABLED + return false; +#endif +} +/** + * Create a new event block with the specified name. Any events logged + * after this will be associated with this new event block. + */ +inline void event_tracer_create_event_block(EventTracer *event_tracer, + char const *name) { +#ifdef ET_EVENT_TRACER_ENABLED + if (event_tracer) { + event_tracer->create_event_block(name); + } +#else //! ET_EVENT_TRACER_ENABLED + (void)event_tracer; + (void)name; +#endif +} + +/** + * Explicitly mark the beginning of a new profiling event. This returns + * an instance of an EventTracerEntry object that the user needs to keep + * around and pass into the corresponding event_tracer_end_profiling_event + * call. 
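+ *
+ * For example, runtime code could bracket a region of work like this (an
+ * illustrative sketch; `tracer` stands for whatever EventTracer* the caller
+ * holds, which may be nullptr):
+ *
+ *   EventTracerEntry entry =
+ *       internal::event_tracer_begin_profiling_event(tracer, "my_region");
+ *   // ... work being profiled ...
+ *   internal::event_tracer_end_profiling_event(tracer, entry);
+ *
+ * The RAII helpers above (EventTracerProfileMethodScope and
+ * EventTracerProfileOpScope) achieve the same start/end pairing
+ * automatically, so the end call cannot be missed on early returns.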
+ */ +inline EventTracerEntry +event_tracer_begin_profiling_event(EventTracer *event_tracer, + char const *name) { +#ifdef ET_EVENT_TRACER_ENABLED + if (event_tracer) { + return event_tracer->start_profiling(name); + } +#else //! ET_EVENT_TRACER_ENABLED + (void)event_tracer; + (void)name; +#endif + // There is no active tracer; this value will be ignored. + return EventTracerEntry(); +} + +/** + * Mark the end of a profiling event passing in the entry token + * returned by a previous call to ET_EVENT_TRACER_BEGIN_PROFILING_EVENT. + */ +inline void event_tracer_end_profiling_event(EventTracer *event_tracer, + EventTracerEntry event) { +#ifdef ET_EVENT_TRACER_ENABLED + if (event_tracer) { + event_tracer->end_profiling(event); + } +#else //! ET_EVENT_TRACER_ENABLED + (void)event_tracer; + (void)event; +#endif +} + +/** + * Start the tracking of the allocator represented by this name and returns + * an AllocatorID that will be used to track all subsequent allocations done by + * this allocator. + */ +inline AllocatorID event_tracer_track_allocator(EventTracer *event_tracer, + const char *name) { +#ifdef ET_EVENT_TRACER_ENABLED + if (event_tracer) { + return event_tracer->track_allocator(name); + } +#else //! ET_EVENT_TRACER_ENABLED + (void)event_tracer; + (void)name; +#endif + // There is no active tracer; this value will be ignored. + return 0; +} + +/// Log the allocation event done via the allocator represented by id. +inline void event_tracer_track_allocation(EventTracer *event_tracer, + AllocatorID id, size_t size) { +#ifdef ET_EVENT_TRACER_ENABLED + if (event_tracer) { + event_tracer->track_allocation(id, size); + } +#else //! ET_EVENT_TRACER_ENABLED + (void)event_tracer; + (void)id; + (void)size; +#endif +} + +/// Log an intermediate value. +inline void event_tracer_log_evalue(EventTracer *event_tracer, EValue &evalue) { +#ifdef ET_EVENT_TRACER_ENABLED + if (event_tracer) { + if (event_tracer->event_tracer_debug_level() >= + EventTracerDebugLogLevel::kIntermediateOutputs) { + event_tracer->log_evalue(evalue, LoggedEValueType::kIntermediateOutput); + } + } +#else //! ET_EVENT_TRACER_ENABLED + (void)event_tracer; + (void)evalue; +#endif +} + +/// Log a program output. +inline void event_tracer_log_evalue_output(EventTracer *event_tracer, + const EValue &evalue) { +#ifdef ET_EVENT_TRACER_ENABLED + /* + * If debugging via event tracer is enabled but intermediate output logging is + * disabled then we want to only log the outputs. + */ + if (event_tracer) { + if (event_tracer->event_tracer_debug_level() >= + EventTracerDebugLogLevel::kProgramOutputs) { + event_tracer->log_evalue(evalue, LoggedEValueType::kProgramOutput); + } + } +#else //! ET_EVENT_TRACER_ENABLED + (void)event_tracer; + (void)evalue; +#endif +} + +// Set the bundled input index of the current bundled input being used by the +// method. +inline void event_tracer_set_bundled_input_index(EventTracer *event_tracer, + int bundled_input_index) { +#ifdef ET_EVENT_TRACER_ENABLED + if (event_tracer) { + event_tracer->set_bundled_input_index(bundled_input_index); + } +#else //! ET_EVENT_TRACER_ENABLED + (void)event_tracer; + (void)bundled_input_index; +#endif +} + +} // namespace internal +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +namespace internal { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. 
+using ::executorch::runtime::internal::event_tracer_begin_profiling_event; +using ::executorch::runtime::internal::event_tracer_create_event_block; +using ::executorch::runtime::internal::event_tracer_end_profiling_event; +using ::executorch::runtime::internal::event_tracer_log_evalue; +using ::executorch::runtime::internal::event_tracer_log_evalue_output; +using ::executorch::runtime::internal::event_tracer_set_bundled_input_index; +using ::executorch::runtime::internal::event_tracer_track_allocation; +using ::executorch::runtime::internal::event_tracer_track_allocator; +using ::executorch::runtime::internal::EventTracerProfileInstructionScope; +using ::executorch::runtime::internal::EventTracerProfileMethodScope; +using ::executorch::runtime::internal::EventTracerProfileOpScope; +using ::executorch::runtime::internal::EventTracerProfileScope; + +} // namespace internal +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/event_tracer_hooks_delegate.h b/third-party/include/executorch/runtime/core/event_tracer_hooks_delegate.h new file mode 100644 index 00000000..17024630 --- /dev/null +++ b/third-party/include/executorch/runtime/core/event_tracer_hooks_delegate.h @@ -0,0 +1,197 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +/** + * @file + * + * This file contains the hooks that can be used by runtime delegate backend + * authors to log profiling and debugging events from backend code. In order to + * use these hooks delegate authors would have needed to generate a delegate + * debug identifier mapping using the DelegateMappingBuilder library present in + * executorch/exir/backend/utils.py. The delegate debug identifiers generated by + * that library are the ones that need to be passed to these hooks to log + * events. Using any other identifiers will cause post-processing of the events + * data to not properly link back to the nodes in the original lowered graph. + * + * The benefit of defining these hooks is that we can easily control whether or + * not we want to compile in the EventTracer code based on the status of the + * ET_EVENT_TRACER_ENABLED flag. + */ + +namespace executorch { +namespace runtime { + +/** + * Start the profiling of a delegate event. Similar to start_profiling it will + * return an instance of EventTracerEntry that contains the details of this + * event. Can be left in production code as these hooks compile conditionally. + * + * @param[in] event_tracer The event tracer instance that is doing the logging. + * @param[in] name Human readable name for the delegate event. This name has + * to be the same name that was passed in during the Debug delegate mapping + * generation in the export/ahead-of-time process. If indices and not names + * are used by this delegate to identify ops executed in the backend then + * nullptr can be passed in. Users calling this interface do not need to keep + * the memory pointed to by this pointer around. The string must be copied over + * into internal memory during this call. + * @param[in] delegate_debug_id The id of the delegate event. If string + * based names are used by this delegate to identify ops executed in the + * backend then kUnsetDebugHandle should be passed in here. 
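+ *
+ * For example, a backend's execute() could wrap a single delegated op like
+ * this (an illustrative sketch; `tracer` and the name "my_backend_op" are
+ * placeholders, not part of this API):
+ *
+ *   EventTracerEntry entry = event_tracer_start_profiling_delegate(
+ *       tracer, "my_backend_op", kUnsetDebugHandle);
+ *   // ... run the delegated op ...
+ *   event_tracer_end_profiling_delegate(tracer, entry);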
+ */ +inline EventTracerEntry +event_tracer_start_profiling_delegate(EventTracer *event_tracer, + const char *name, + DebugHandle delegate_debug_id) { +#ifdef ET_EVENT_TRACER_ENABLED + if (event_tracer) { + return event_tracer->start_profiling_delegate(name, delegate_debug_id); + } +#else //! ET_EVENT_TRACER_ENABLED + (void)name; + (void)delegate_debug_id; +#endif + // There is no active tracer; this value will be ignored. + return EventTracerEntry(); +} + +/** + * Signal the end of the delegate profiling event contained in + * event_tracer_entry. Users also have the option to log some some free-from + * string based metadata along with this. Can be left in production code as + * these hooks compile conditionally. + * + * @param[in] event_tracer The event tracer instance that is doing the logging. + * @param[in] event_tracer_entry The EventTracerEntry returned by a call to + * start_profiling_delegate(). + * @param[in] metadata Optional data relevant to the execution that the user + * wants to log along with this event. Pointer to metadata doesn't need to be + * valid after the call to this function. The contents and format of the data + * are transparent to the event tracer. It will just pipe along the data and + * make it available for the user again in the post-processing stage. + * @param[in] metadata_len Length of the metadata buffer. + */ +inline void event_tracer_end_profiling_delegate( + EventTracer *event_tracer, EventTracerEntry event_tracer_entry, + const void *metadata = nullptr, size_t metadata_len = 0) { +#ifdef ET_EVENT_TRACER_ENABLED + if (event_tracer) { + event_tracer->end_profiling_delegate(event_tracer_entry, metadata, + metadata_len); + } +#else //! ET_EVENT_TRACER_ENABLED + (void)event_tracer_entry; + (void)metadata; + (void)metadata_len; +#endif +} + +/** + * Some delegates get access to the profiling details only after the complete + * graph has been executed. This interface is to support such use cases. It + * can be called in a loop etc. to log any number of profiling events that are + * part of this delegate. Can be left in production code as these hooks + * compile conditionally. + * + * @param[in] event_tracer The event tracer instance that is doing the logging. + * @param[in] name Human readable name for the delegate event. This name has + * to be the same name that was passed in during the Debug delegate mapping + * generation in the export/ahead-of-time process. If indices and not names + * are used by this delegate to identify ops executed in the backend then + * nullptr can be passed in. Users calling this interface do not need to keep + * the memory pointed to by this pointer around. The string must + * be copied over into internal memory during this call. + * @param[in] delegate_debug_id The id of the delegate event. If string + * based names are used by this delegate to identify ops executed in the + * backend then -1 should be passed in here. + * @param[in] start_time The timestamp when the delegate event started. + * @param[in] end_time The timestamp when the delegate event finished. + * @param[in] metadata Optional data relevant to the execution that the user + * wants to log along with this event. Pointer to metadata doesn't need to be + * valid after the call to this function. The contents and format of the data + * are transparent to the event tracer. It will just pipe along the data and + * make it available for the user again in the post-processing stage. + * @param[in] metadata_len Length of the metadata buffer. 
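+ *
+ * For example, after the full graph has run, a backend could replay its own
+ * timing records (an illustrative sketch; `tracer` and the `rec` fields are
+ * placeholders, not part of this API; -1 is passed because the sketch
+ * assumes string-based names, as described above):
+ *
+ *   for (const auto& rec : backend_timings) {
+ *     event_tracer_log_profiling_delegate(tracer, rec.name, -1,
+ *                                         rec.start_time, rec.end_time);
+ *   }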
+ */ +inline void event_tracer_log_profiling_delegate( + EventTracer *event_tracer, const char *name, DebugHandle delegate_debug_id, + et_timestamp_t start_time, et_timestamp_t end_time, + const void *metadata = nullptr, size_t metadata_len = 0) { +#ifdef ET_EVENT_TRACER_ENABLED + if (event_tracer) { + event_tracer->log_profiling_delegate(name, delegate_debug_id, start_time, + end_time, metadata, metadata_len); + } +#else //! ET_EVENT_TRACER_ENABLED + (void)name; + (void)delegate_debug_id; + (void)start_time; + (void)end_time; + (void)metadata; + (void)metadata_len; +#endif +} + +/** + * This templated interfaces can be called in a loop etc. to log any number of + * debug events that are part of this delegate. Supported values types are int, + * bool, double, tensor and array of tensors. Can be left in production code as + * these hooks compile conditionally. + * + * @param[in] event_tracer The event tracer instance that is doing the logging. + * @param[in] name Human readable name for the delegate event. This name has + * to be the same name that was passed in during the Debug delegate mapping + * generation in the export/ahead-of-time process. If indices and not names + * are used by this delegate to identify ops executed in the backend then + * nullptr can be passed in. Users calling this interface do not need to keep + * the memory pointed to by this pointer around. The string must + * be copied over into internal memory during this call. + * @param[in] delegate_debug_id The id of the delegate event. If string + * based names are used by this delegate to identify ops executed in the + * backend then -1 should be passed in here. + * @param[in] output The output to be logged. + */ +template +inline void event_tracer_log_output_delegate(EventTracer *event_tracer, + const char *name, + DebugHandle delegate_debug_id, + const T &output) { +#ifdef ET_EVENT_TRACER_ENABLED + if (event_tracer) { + static_assert( + std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same>::value, + "Unsupported type for intermediate output"); + event_tracer->log_intermediate_output_delegate(name, delegate_debug_id, + output); + } +#else //! ET_EVENT_TRACER_ENABLED + (void)name; + (void)delegate_debug_id; + (void)output; +#endif +} + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::event_tracer_end_profiling_delegate; +using ::executorch::runtime::event_tracer_log_output_delegate; +using ::executorch::runtime::event_tracer_log_profiling_delegate; +using ::executorch::runtime::event_tracer_start_profiling_delegate; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/exec_aten/exec_aten.h b/third-party/include/executorch/runtime/core/exec_aten/exec_aten.h new file mode 100644 index 00000000..b63e9f3c --- /dev/null +++ b/third-party/include/executorch/runtime/core/exec_aten/exec_aten.h @@ -0,0 +1,147 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include // @manual +#include +#ifdef USE_ATEN_LIB +#include // @manual +#include +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include +#else // use executor +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include // @manual + +#endif + +namespace executorch { +namespace aten { + +using TensorShapeDynamism = executorch::runtime::TensorShapeDynamism; + +#ifdef USE_ATEN_LIB + +using Tensor = at::Tensor; +using TensorList = at::TensorList; +using TensorImpl = at::TensorImpl; +using string_view = c10::string_view; +template using ArrayRef = c10::ArrayRef; +template using optional = std::optional; +using nullopt_t = std::nullopt_t; +using std::nullopt; +using ScalarType = at::ScalarType; +using Scalar = c10::Scalar; +using MemoryFormat = c10::MemoryFormat; +using SizesType = int64_t; +using DimOrderType = uint8_t; +using StridesType = int64_t; +using Device = c10::Device; +using DeviceType = c10::DeviceType; +using Layout = c10::Layout; + +// Custom types that map to ScalarType +using Half = c10::Half; +template using complex = c10::complex; +using qint8 = c10::qint8; +using quint8 = c10::quint8; +using qint32 = c10::qint32; +using BFloat16 = c10::BFloat16; +using quint4x2 = c10::quint4x2; +using quint2x4 = c10::quint2x4; +using IntArrayRef = at::IntArrayRef; + +template using OptionalArrayRef = c10::OptionalArrayRef; +using OptionalIntArrayRef = OptionalArrayRef; + +inline ssize_t compute_numel(const SizesType *sizes, ssize_t dim) { + return static_cast( + c10::multiply_integers(c10::ArrayRef(sizes, dim))); +} + +#else // Use executor types + +using Tensor = torch::executor::Tensor; +using TensorImpl = torch::executor::TensorImpl; +using string_view = torch::executor::string_view; +template using ArrayRef = torch::executor::ArrayRef; +template using optional = torch::executor::optional; +using nullopt_t = torch::executor::nullopt_t; +// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) +static constexpr nullopt_t nullopt{0}; +using ScalarType = torch::executor::ScalarType; +using TensorList = ArrayRef; +using Scalar = torch::executor::Scalar; +using MemoryFormat = torch::executor::MemoryFormat; +using SizesType = torch::executor::Tensor::SizesType; +using DimOrderType = torch::executor::Tensor::DimOrderType; +using StridesType = torch::executor::Tensor::StridesType; +using Device = torch::executor::Device; +using DeviceType = torch::executor::DeviceType; +using Layout = torch::executor::Layout; + +// Custom types that map to ScalarType +using Half = torch::executor::Half; +template using complex = torch::executor::complex; +using qint8 = torch::executor::qint8; +using quint8 = torch::executor::quint8; +using qint32 = torch::executor::qint32; +using BFloat16 = torch::executor::BFloat16; +using quint4x2 = torch::executor::quint4x2; +using quint2x4 = torch::executor::quint2x4; + +using IntArrayRef = torch::executor::IntArrayRef; + +template +using OptionalArrayRef = + torch::executor::optional>; +using OptionalIntArrayRef = OptionalArrayRef; + +using 
torch::executor::compute_numel; + +#endif // Use ExecuTorch types + +} // namespace aten +} // namespace executorch + +// DEPRECATED: The exec_aten:: namespace is deprecated. Use executorch::aten:: +// instead. +namespace exec_aten = executorch::aten; + +namespace torch { +namespace executor { +using TensorList = exec_aten::TensorList; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/exec_aten/util/dim_order_util.h b/third-party/include/executorch/runtime/core/exec_aten/util/dim_order_util.h new file mode 100644 index 00000000..c1ea45b8 --- /dev/null +++ b/third-party/include/executorch/runtime/core/exec_aten/util/dim_order_util.h @@ -0,0 +1,259 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include + +namespace executorch { +namespace runtime { + +namespace { +template +bool validate_dim_order(const DimOrderType *dim_order, const size_t dims) { + for (int32_t i = 0; i < dims; ++i) { + if (dim_order[i] >= dims) { + return false; + } + } + return true; +} +} // namespace + +/** + * Check if a given dim_order array is equivalent to the contiguous dim order of + * {0, 1, 2, 3, ...} + * + * @param[in] dim_order pointer to dim_order array + * @param[in] dims length of the dim_order array + */ +template +inline bool is_contiguous_dim_order(const DimOrderType *dim_order, + const size_t dims) { + for (int i = 0; i < dims; ++i) { + if (dim_order[i] != i) { + return false; + } + } + return true; +} + +/** + * Check if a given dim_order array is equivalent to a channels last dim order. + * Channels last dim order is only valid for 4-dim and 5-dim tensors. + * + * @param[in] dim_order pointer to dim_order array + * @param[in] dims length of the dim_order array + */ +template +bool is_channels_last_dim_order(const DimOrderType *dim_order, + const size_t dims) { + if (dims != 4 && dims != 5) { + return false; + } + // 4-dim tensor is interpreted as NCHW, 5-dim tensor is interpreted as NCHWD + size_t channels_dim = 1; + // Last value in the dim order should be the channels dim + if (dim_order[dims - 1] != channels_dim) { + return false; + } + + if (dim_order[0] != 0) { + return false; + } + int d = 1; + while (d < dims - 1) { + if (dim_order[d] != d + 1) { + return false; + } + d++; + } + return true; +} + +/* + * This utility translated sizes to strides by using dimension order + * information. Dimension order specifies how the dimensions are laid out in the + * memory. For example for Size = [2, 3, 4, 5] dim_names = [N, C, H, W] + * dim_order = [0, 2, 3, 1] + * strides = [60, 1, 15, 3] + * param[in]: sizes, pointer to sizes array + * param[in]: dim_order, pointer to dimension order array + * param[in]: dims, number of dims. Sizes and dim_order must be sizes to dims + * param[out]: strides, pointer to strides array that is filled in + * + * NB: Reason for not using ArrayRef is the dependency on kernel_types.h + * This header cannot be included, because of circular dep it causes. + * kernel_types depends on executorch_kernel_types in lean mode, which compiles + * TensorImpl.cpp. executorch_kernel_types needs to depend on dim_order_utils + * in order to utilize dim_order_to_stride in its resize impl. If + * dim_order_utils depends on kernel_type, we have circular deps. 
This is also + * the reason for templatizing this function. Better ideas welcome! + * TODO(T148342910) + * + * Note that this function does not check that the provided dim order is valid. + * This function should only be used when the validity of the dim order has been + * checked beforehand. A safer version of this function is provided below as + * dim_order_to_stride which will check that the dim order is valid. + */ +template +inline void dim_order_to_stride_nocheck(const SizesType *sizes, + const DimOrderType *dim_order, + const size_t dims, + StridesType *strides) { + // For 0 dim tensors, just return ok. + if (dims == 0) { + return; + } + // Fastest moving dim has stride of 1. + // For example: + // Size = [2, 3, 4, 5] dim_names = [N, C, H, W] + // dim_order = [0, 2, 3, 1] + // strides = [60, 1, 15, 3] + strides[dim_order[dims - 1]] = 1; + for (int32_t i = dims - 2; i >= 0; --i) { + if (sizes[dim_order[i + 1]] == 0) { + strides[dim_order[i]] = strides[dim_order[i + 1]]; + } else { + strides[dim_order[i]] = + strides[dim_order[i + 1]] * sizes[dim_order[i + 1]]; + } + } +} + +template +ET_NODISCARD inline Error +dim_order_to_stride(const SizesType *sizes, const DimOrderType *dim_order, + const size_t dims, StridesType *strides) { + // For 0 dim tensors, just return ok. + if (dims == 0) { + return Error::Ok; + } + ET_CHECK_OR_RETURN_ERROR(validate_dim_order(dim_order, dims), InvalidArgument, + "Invalid dim order. One of the value is larger than " + "the number of dims %zu", + dims); + + dim_order_to_stride_nocheck(sizes, dim_order, dims, strides); + return Error::Ok; +} + +namespace internal { + +template struct StrideDimOrder { + StridesType stride; + DimOrderType dim_order; + + StrideDimOrder(StridesType stride, DimOrderType dim_order) + : stride(stride), dim_order(dim_order) {} + StrideDimOrder() = default; + bool operator>(const StrideDimOrder &other) const { + // descending order + return stride < other.stride; + } +}; + +template struct Sorter { +public: + void quick_sort(ValueType arr[], int32_t low, int32_t high) { + if (low < high) { + ValueType pivot = arr[high]; + int32_t pos = partition(arr, low, high, pivot); + + quick_sort(arr, low, pos - 1); + quick_sort(arr, pos + 1, high); + } + } + +private: + void swap(ValueType arr[], int32_t pos1, int32_t pos2) noexcept { + ValueType temp = arr[pos1]; + arr[pos1] = arr[pos2]; + arr[pos2] = temp; + } + + int32_t partition(ValueType arr[], int32_t low, int32_t high, + ValueType pivot) { + int32_t i = low; + int32_t j = low; + while (i <= high) { + if (arr[i] > pivot) { + i++; + } else { + swap(arr, i++, j++); + } + } + return j - 1; + } +}; + +} // namespace internal + +/* + * This utility translated strides to dimension order + * information. Dimension order specifies how the dimensions are laid out in the + * memory. For example for tensor with sizes [3, 5, 2] and strides [5, 1, 15], + * dim order should be [2, 0, 1], which is obtained by sorting strides in + * descending order. param[in]: sizes, pointer to sizes array param[in]: + * dim_order, pointer to dimension order array param[in]: dims, number of dims. + * Sizes and dim_order must be sizes to dims param[out]: strides, pointer to + * strides array that is filled in + * + * NB: Reason for not using ArrayRef is the dependency on kernel_types.h + * This header cannot be included, because of circular dep it causes. + * kernel_types depends on executorch_kernel_types in lean mode, which compiles + * TensorImpl.cpp. 
executorch_kernel_types needs to depend on dim_order_utils + * in order to utilize dim_order_to_stride in its resize impl. If + * dim_order_utils depends on kernel_type, we have circular deps. This is also + * the reason for templatizing this function. Better ideas welcome! + * TODO(T148342910) + */ +template +ET_NODISCARD inline Error stride_to_dim_order(const StridesType *strides, + const size_t dims, + DimOrderType *dim_order) { + const size_t kMaxNumOfDimensions = 16; + ET_CHECK_OR_RETURN_ERROR(dim_order != nullptr, MemoryAllocationFailed, + "Need memory to get dim_order."); + ET_CHECK_OR_RETURN_ERROR(dims <= kMaxNumOfDimensions, NotSupported, + "dims %zu exceeds maximum allowed %zu", dims, + kMaxNumOfDimensions); + internal::StrideDimOrder + array[kMaxNumOfDimensions]; + for (DimOrderType i = 0; i < dims; i++) { + array[i].dim_order = i; + array[i].stride = strides[i]; + } + + internal::Sorter> sorter; + + sorter.quick_sort(array, 0, dims - 1); + + for (auto i = 0; i < dims; i++) { + dim_order[i] = array[i].dim_order; + } + return Error::Ok; +} +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::dim_order_to_stride; +using ::executorch::runtime::dim_order_to_stride_nocheck; +using ::executorch::runtime::is_channels_last_dim_order; +using ::executorch::runtime::is_contiguous_dim_order; +using ::executorch::runtime::stride_to_dim_order; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h b/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h new file mode 100644 index 00000000..da20c8c9 --- /dev/null +++ b/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h @@ -0,0 +1,1273 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * @file + * + * Forked from + * https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h + * + * See file comment in ../ScalarType.h. + * + * This file contains all of the non-critical parts of the original ScalarType.h + * that are not required for the core ExecuTorch runtime, but may be helpful for + * code that uses ScalarType. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +#ifdef USE_ATEN_LIB +// Note that a lot of the macros/functions defined in this ScalarTypeUtil.h file +// are also defined in c10/core/ScalarType.h, which is included via +// kernel_types.h when building in ATen mode. They tend to use different names +// and a different namespace, but if there are conflicts they should be resolved +// here. +#define ET_FORALL_SCALAR_TYPES AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS +#include +namespace executorch { +namespace aten { +using ScalarType = at::ScalarType; +} // namespace aten +} // namespace executorch +#else // !USE_ATEN_LIB +#include +#include +namespace executorch { +namespace aten { +using ScalarType = torch::executor::ScalarType; +using string_view = torch::executor::string_view; +} // namespace aten +} // namespace executorch +#endif // USE_ATEN_LIB +// DEPRECATED: The exec_aten:: namespace is deprecated. Use executorch::aten:: +// instead. 
+namespace exec_aten = ::executorch::aten; + +namespace executorch { +namespace runtime { + +#if !defined(USE_ATEN_LIB) +// Util to figure out if the scalar type if one of the +// supported floating point types. +// In aten mode, aten lib already has these utils as part of +// its vec_base.h +template +struct is_floating_point + : std::integral_constant::value || + std::is_same_v || + std::is_same_v> { +}; + +// Util to figure out if the scalar type is one of the +// reduced precision floating point types. +template +struct is_reduced_floating_point + : std::integral_constant || + std::is_same_v> { +}; + +template +constexpr bool is_reduced_floating_point_v = + is_reduced_floating_point::value; +#endif + +/// Maps ScalarTypes to C++ types. +template <::executorch::aten::ScalarType N> struct ScalarTypeToCppType; + +#define SPECIALIZE_ScalarTypeToCppType(cpp_type, scalar_type) \ + template <> \ + struct ScalarTypeToCppType<::executorch::aten::ScalarType::scalar_type> { \ + using type = cpp_type; \ + }; + +ET_FORALL_SCALAR_TYPES(SPECIALIZE_ScalarTypeToCppType) + +#undef SPECIALIZE_ScalarTypeToCppType + +/// Maps C++ types to ScalarTypes. +template struct CppTypeToScalarType; + +#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \ + template <> \ + struct CppTypeToScalarType \ + : std::integral_constant<::executorch::aten::ScalarType, \ + ::executorch::aten::ScalarType::scalar_type> { \ + }; + +ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) + +#undef SPECIALIZE_CppTypeToScalarType + +// +// Macros that iterate across different subsets of ScalarTypes. +// +// See ET_FORALL_SCALAR_TYPES in ScalarType.h to iterate across all ScalarType +// names and types. +// +// For all of these macros, the final `_` parameter is the name of another macro +// that takes two parameters: the name of a C type, and the name of the +// corresponding ScalarType enumerator. +// +// Note that these macros should use fully-qualified namespaces (starting with +// `::`) to ensure that they can be called safely in any arbitrary namespace. +// + +// In this context, "INT" means integer C types, which is why the quantized +// integer types are not included. +#define ET_FORALL_INT_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int32_t, Int) \ + _(int64_t, Long) + +// Here `ANOTHER_INPUT` should be another variable to be forwarded to a given +// function. +#define ET_FORALL_INT_TYPES_WITH(ANOTHER_INPUT, _) \ + _(ANOTHER_INPUT, uint8_t, Byte) \ + _(ANOTHER_INPUT, int8_t, Char) \ + _(ANOTHER_INPUT, int16_t, Short) \ + _(ANOTHER_INPUT, int32_t, Int) \ + _(ANOTHER_INPUT, int64_t, Long) + +#define ET_FORALL_INT_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, uint8_t, Byte) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int8_t, Char) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int16_t, Short) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int32_t, Int) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int64_t, Long) + +#define ET_FORALL_INT_TYPES_AND(SCALARTYPE, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int32_t, Int) \ + _(int64_t, Long) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::executorch::aten::ScalarType::SCALARTYPE>::type, \ + SCALARTYPE) + +// In this context, "FLOAT" means float C types, which is why BFloat16 is not +// included. 
+#define ET_FORALL_FLOAT_TYPES(_) \ + _(float, Float) \ + _(double, Double) + +#define ET_FORALL_FLOAT_TYPES_AND(SCALARTYPE, _) \ + _(float, Float) \ + _(double, Double) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::executorch::aten::ScalarType::SCALARTYPE>::type, \ + SCALARTYPE) + +#define ET_FORALL_FLOAT_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(float, Float) \ + _(double, Double) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::executorch::aten::ScalarType::SCALARTYPE1>::type, \ + SCALARTYPE1) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::executorch::aten::ScalarType::SCALARTYPE2>::type, \ + SCALARTYPE2) + +#define ET_FORALL_FLOATH_TYPES(_) ET_FORALL_FLOAT_TYPES_AND(Half, _) + +#define ET_FORALL_FLOATHBF16_TYPES(_) \ + ET_FORALL_FLOAT_TYPES_AND2(Half, BFloat16, _) + +// Here `ANOTHER_INPUT` should be another variable to be forwarded to a given +// function. Not to be confused with another scalar type as in +// `ET_FORALL_FLOAT_TYPES_AND`. +#define ET_FORALL_FLOAT_TYPES_WITH(ANOTHER_INPUT, _) \ + _(ANOTHER_INPUT, float, Float) \ + _(ANOTHER_INPUT, double, Double) + +#define ET_FORALL_FLOAT_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double) + +#define ET_FORALL_FLOATHBF16_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, ::executorch::aten::Half, Half) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, ::executorch::aten::BFloat16, BFloat16) + +// In this context, "REAL" means integer/float C types, which is why BFloat16 +// and Half are not included. +#define ET_FORALL_REAL_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int32_t, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) + +// Here `ANOTHER_INPUT` should be another variable to be forwarded to a given +// function. Not to be confused with another scalar type as in +// `ET_FORALL_REAL_TYPES_AND`. 
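+//
+// For example (an illustrative sketch, not a macro defined in this file), a
+// caller usually supplies a small macro that is expanded once per
+// (ctype, enumerator) pair:
+//
+//   #define PRINT_ELEMENT_SIZE(ctype, name) printf(#name ": %zu\n", sizeof(ctype));
+//   ET_FORALL_REAL_TYPES(PRINT_ELEMENT_SIZE)
+//   #undef PRINT_ELEMENT_SIZE
+//
+// With the _WITH variants (such as ET_FORALL_REAL_TYPES_WITH below),
+// ANOTHER_INPUT is prepended as the first argument of each such expansion.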
+#define ET_FORALL_REAL_TYPES_WITH(ANOTHER_INPUT, _) \ + _(ANOTHER_INPUT, uint8_t, Byte) \ + _(ANOTHER_INPUT, int8_t, Char) \ + _(ANOTHER_INPUT, int16_t, Short) \ + _(ANOTHER_INPUT, int32_t, Int) \ + _(ANOTHER_INPUT, int64_t, Long) \ + _(ANOTHER_INPUT, float, Float) \ + _(ANOTHER_INPUT, double, Double) + +#define ET_FORALL_REAL_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, uint8_t, Byte) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int8_t, Char) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int16_t, Short) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int32_t, Int) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int64_t, Long) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double) + +#define ET_FORALL_REALHBF16_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, uint8_t, Byte) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int8_t, Char) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int16_t, Short) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int32_t, Int) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int64_t, Long) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, ::executorch::aten::Half, Half) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, ::executorch::aten::BFloat16, BFloat16) + +// For macros that take `SCALARTYPEn` parameters, those parameters should be +// an unquoted/unqualified enumerator name like `Int` or `Float`. +#define ET_FORALL_REAL_TYPES_AND(SCALARTYPE, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int32_t, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::executorch::aten::ScalarType::SCALARTYPE>::type, \ + SCALARTYPE) + +#define ET_FORALL_REAL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int32_t, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::executorch::aten::ScalarType::SCALARTYPE1>::type, \ + SCALARTYPE1) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::executorch::aten::ScalarType::SCALARTYPE2>::type, \ + SCALARTYPE2) + +#define ET_FORALL_REALH_TYPES(_) ET_FORALL_REAL_TYPES_AND(Half, _) + +#define ET_FORALL_REALHBF16_TYPES(_) \ + ET_FORALL_REAL_TYPES_AND2(Half, BFloat16, _) + +#define ET_FORALL_REALHBBF16_TYPES(_) \ + ET_FORALL_REAL_TYPES_AND3(Bool, Half, BFloat16, _) + +#define ET_FORALL_REAL_TYPES_AND_WITH(SCALARTYPE, ANOTHER_INPUT, _) \ + _(ANOTHER_INPUT, uint8_t, Byte) \ + _(ANOTHER_INPUT, int8_t, Char) \ + _(ANOTHER_INPUT, int16_t, Short) \ + _(ANOTHER_INPUT, int32_t, Int) \ + _(ANOTHER_INPUT, int64_t, Long) \ + _(ANOTHER_INPUT, float, Float) \ + _(ANOTHER_INPUT, double, Double) \ + _(ANOTHER_INPUT, \ + ::executorch::runtime::ScalarTypeToCppType< \ + ::executorch::aten::ScalarType::SCALARTYPE>::type, \ + SCALARTYPE) + +#define ET_FORALL_REAL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int32_t, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::executorch::aten::ScalarType::SCALARTYPE1>::type, \ + SCALARTYPE1) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::executorch::aten::ScalarType::SCALARTYPE2>::type, \ + SCALARTYPE2) + +#define ET_FORALL_REAL_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \ + _(uint8_t, Byte) \ + _(int8_t, 
Char) \ + _(int16_t, Short) \ + _(int32_t, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::executorch::aten::ScalarType::SCALARTYPE1>::type, \ + SCALARTYPE1) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::executorch::aten::ScalarType::SCALARTYPE2>::type, \ + SCALARTYPE2) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::executorch::aten::ScalarType::SCALARTYPE3>::type, \ + SCALARTYPE3) + +#define ET_FORALL_QINT_TYPES(_) \ + _(::torch::executor::qint8, QInt8) \ + _(::torch::executor::quint8, QUInt8) \ + _(::torch::executor::qint32, QInt32) \ + _(::torch::executor::quint4x2, QUInt4x2) \ + _(::torch::executor::quint2x4, QUInt2x4) + +// In this context, "COMPLEX" means complex types based on primitive C types, +// which is why ComplexHalf is not included. +#define ET_FORALL_COMPLEX_TYPES(_) \ + _(::torch::executor::complex, ComplexFloat) \ + _(::torch::executor::complex, ComplexDouble) + +// +// Utility functions to retrieve metadata for a given ScalarType +// + +/** + * Returns true if the parameter is one of the values covered by + * ET_FORALL_SCALAR_TYPES. + */ +inline bool isValid(::executorch::aten::ScalarType type) { + return static_cast(type) >= 0 && + type < ::executorch::aten::ScalarType::NumOptions && + type != ::executorch::aten::ScalarType::Undefined; +} + +/** + * Returns the name of a ScalarType as a C string. + * + * @param[in] t The type to get the name of. + * @return The name of the type, or "UNKNOWN_SCALAR" if the type is not known. + */ +inline const char *toString(::executorch::aten::ScalarType t) { +#define DEFINE_CASE(_, name) \ + case ::executorch::aten::ScalarType::name: \ + return #name; + + switch (t) { + ET_FORALL_SCALAR_TYPES(DEFINE_CASE) + case ::executorch::aten::ScalarType::Undefined: + return "Undefined"; + default: + return "UNKNOWN_SCALAR"; + } +#undef DEFINE_CASE +} + +/** + * Returns the size in bytes of the C type associated with the ScalarType. + * + * Calls ET_CHECK_MSG() if the type is unknown or is ScalarType::Undefined. + * + * @param[in] t The type to get the underlying C type size of. + * @return The size of the associated C type in bytes. 
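+ *
+ * For example (illustrative): elementSize(ScalarType::Long) returns
+ * sizeof(int64_t) == 8, and elementSize(ScalarType::Half) returns 2.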
+ */ +inline size_t elementSize(::executorch::aten::ScalarType t) { +#define CASE_ELEMENTSIZE_CASE(ctype, name) \ + case ::executorch::aten::ScalarType::name: \ + return sizeof(ctype); + + switch (t) { + ET_FORALL_SCALAR_TYPES(CASE_ELEMENTSIZE_CASE) + default: + ET_CHECK_MSG(false, "Unknown ScalarType %" PRId8, static_cast(t)); + } +#undef CASE_ELEMENTSIZE_CASE +} + +inline constexpr bool isIntegralType(::executorch::aten::ScalarType t, + bool includeBool) { + return (includeBool && t == ::executorch::aten::ScalarType::Bool) || + (t == ::executorch::aten::ScalarType::Byte || + t == ::executorch::aten::ScalarType::Char || + t == ::executorch::aten::ScalarType::Int || + t == ::executorch::aten::ScalarType::Long || + t == ::executorch::aten::ScalarType::Short); +} + +template +struct is_integral_type + : public std::integral_constant< + bool, isIntegralType(CppTypeToScalarType::value, includeBool)> {}; + +inline constexpr bool isFloatingType(::executorch::aten::ScalarType t) { + return (t == ::executorch::aten::ScalarType::Double || + t == ::executorch::aten::ScalarType::Float || + t == ::executorch::aten::ScalarType::Half || + t == ::executorch::aten::ScalarType::BFloat16); +} + +inline bool isRealType(::executorch::aten::ScalarType t) { + return (t == ::executorch::aten::ScalarType::Byte || + t == ::executorch::aten::ScalarType::Char || + t == ::executorch::aten::ScalarType::Short || + t == ::executorch::aten::ScalarType::Int || + t == ::executorch::aten::ScalarType::Long || + t == ::executorch::aten::ScalarType::Float || + t == ::executorch::aten::ScalarType::Double); +} + +inline bool isRealHType(::executorch::aten::ScalarType t) { + return (t == ::executorch::aten::ScalarType::Byte || + t == ::executorch::aten::ScalarType::Char || + t == ::executorch::aten::ScalarType::Short || + t == ::executorch::aten::ScalarType::Int || + t == ::executorch::aten::ScalarType::Long || + t == ::executorch::aten::ScalarType::Float || + t == ::executorch::aten::ScalarType::Double || + t == ::executorch::aten::ScalarType::Half); +} + +inline bool isRealHBType(::executorch::aten::ScalarType t) { + return (isRealHType(t) || t == ::executorch::aten::ScalarType::Bool); +} + +inline bool isRealHBF16Type(::executorch::aten::ScalarType t) { + return (isRealHType(t) || t == ::executorch::aten::ScalarType::BFloat16); +} + +inline bool isRealHBBF16Type(::executorch::aten::ScalarType t) { + return (isRealHBType(t) || t == ::executorch::aten::ScalarType::BFloat16); +} + +inline constexpr bool isComplexType(::executorch::aten::ScalarType t) { + return (t == ::executorch::aten::ScalarType::ComplexHalf || + t == ::executorch::aten::ScalarType::ComplexFloat || + t == ::executorch::aten::ScalarType::ComplexDouble); +} + +template +struct is_complex_type + : std::integral_constant::value)> {}; + +constexpr bool isBitsType(::executorch::aten::ScalarType t) { + return t == ::executorch::aten::ScalarType::Bits1x8 || + t == ::executorch::aten::ScalarType::Bits2x4 || + t == ::executorch::aten::ScalarType::Bits4x2 || + t == ::executorch::aten::ScalarType::Bits8 || + t == ::executorch::aten::ScalarType::Bits16; +} + +template +struct is_bits_type + : std::integral_constant::value)> { +}; + +constexpr bool isQIntType(::executorch::aten::ScalarType t) { + // Don't forget to extend this when adding new QInt types + return t == ::executorch::aten::ScalarType::QInt8 || + t == ::executorch::aten::ScalarType::QUInt8 || + t == ::executorch::aten::ScalarType::QInt32 || + t == ::executorch::aten::ScalarType::QUInt4x2 || + t == 
::executorch::aten::ScalarType::QUInt2x4; +} + +template +struct is_qint_type + : std::integral_constant::value)> { +}; + +constexpr bool isFloat8Type(::executorch::aten::ScalarType t) { + // Don't forget to extend this when adding new QInt types + return t == ::executorch::aten::ScalarType::Float8_e5m2 || + t == ::executorch::aten::ScalarType::Float8_e4m3fn || + t == ::executorch::aten::ScalarType::Float8_e5m2fnuz || + t == ::executorch::aten::ScalarType::Float8_e4m3fnuz; +} + +template +struct is_float8_type + : std::integral_constant::value)> {}; + +constexpr bool isBarebonesUnsignedType(::executorch::aten::ScalarType t) { + // Don't forget to extend this when adding new QInt types + return t == ::executorch::aten::ScalarType::UInt16 || + t == ::executorch::aten::ScalarType::UInt32 || + t == ::executorch::aten::ScalarType::UInt64; +} + +template +struct is_barebones_unsigned_type + : std::integral_constant::value)> {}; + +inline ::executorch::aten::ScalarType +toQIntType(::executorch::aten::ScalarType t) { + switch (t) { + case ::executorch::aten::ScalarType::Byte: + return ::executorch::aten::ScalarType::QUInt8; + case ::executorch::aten::ScalarType::Char: + return ::executorch::aten::ScalarType::QInt8; + case ::executorch::aten::ScalarType::Int: + return ::executorch::aten::ScalarType::QInt32; + default: + return t; + } +} + +inline ::executorch::aten::ScalarType +toUnderlying(::executorch::aten::ScalarType t) { + switch (t) { + case ::executorch::aten::ScalarType::QUInt8: + return ::executorch::aten::ScalarType::Byte; + case ::executorch::aten::ScalarType::QInt8: + return ::executorch::aten::ScalarType::Char; + case ::executorch::aten::ScalarType::QInt32: + return ::executorch::aten::ScalarType::Int; + case ::executorch::aten::ScalarType::QUInt4x2: + return ::executorch::aten::ScalarType::Byte; + case ::executorch::aten::ScalarType::QUInt2x4: + return ::executorch::aten::ScalarType::Byte; + default: + return t; + } +} + +inline bool isSignedType(::executorch::aten::ScalarType t) { + ET_CHECK_MSG(!::executorch::runtime::isQIntType(t), + "isSignedType not supported for quantized types like %" PRId8, + static_cast(t)); +#define CASE_SIGNED(ctype, name) \ + case ::executorch::aten::ScalarType::name: \ + return std::numeric_limits::is_signed; + + switch (t) { + case ::executorch::aten::ScalarType::ComplexHalf: + case ::executorch::aten::ScalarType::ComplexFloat: + case ::executorch::aten::ScalarType::ComplexDouble: + return true; + ET_FORALL_REAL_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED) + default: + ET_CHECK_MSG(false, "Unknown ScalarType %" PRId8, static_cast(t)); + } +#undef CASE_SIGNED +} + +inline bool isUnderlying(::executorch::aten::ScalarType type, + ::executorch::aten::ScalarType qtype) { + return type == ::executorch::runtime::toUnderlying(qtype); +} + +inline ::executorch::aten::ScalarType +toRealValueType(::executorch::aten::ScalarType t) { + switch (t) { + case ::executorch::aten::ScalarType::ComplexHalf: + return ::executorch::aten::ScalarType::Half; + case ::executorch::aten::ScalarType::ComplexFloat: + return ::executorch::aten::ScalarType::Float; + case ::executorch::aten::ScalarType::ComplexDouble: + return ::executorch::aten::ScalarType::Double; + default: + return t; + } +} + +inline ::executorch::aten::ScalarType +toComplexType(::executorch::aten::ScalarType t) { + switch (t) { + case ::executorch::aten::ScalarType::BFloat16: + // BFloat16 has range equivalent to Float, + // so we map it to ComplexFloat. 
+ return ::executorch::aten::ScalarType::ComplexFloat; + case ::executorch::aten::ScalarType::Half: + return ::executorch::aten::ScalarType::ComplexHalf; + case ::executorch::aten::ScalarType::Float: + return ::executorch::aten::ScalarType::ComplexFloat; + case ::executorch::aten::ScalarType::Double: + return ::executorch::aten::ScalarType::ComplexDouble; + case ::executorch::aten::ScalarType::ComplexHalf: + return ::executorch::aten::ScalarType::ComplexHalf; + case ::executorch::aten::ScalarType::ComplexFloat: + return ::executorch::aten::ScalarType::ComplexFloat; + case ::executorch::aten::ScalarType::ComplexDouble: + return ::executorch::aten::ScalarType::ComplexDouble; + default: + ET_CHECK_MSG(false, "Unknown Complex ScalarType for %" PRId8, + static_cast(t)); + } +} + +/** + * Encodes type casting rules that are consistent with ATen behaviour. + */ +inline constexpr bool canCast(const ::executorch::aten::ScalarType from, + const ::executorch::aten::ScalarType to) { + // Disallow complex -> non-complex + return !(::executorch::runtime::isComplexType(from) && + !::executorch::runtime::isComplexType(to)) && + // Disallow float -> integral + !(::executorch::runtime::isFloatingType(from) && + ::executorch::runtime::isIntegralType(to, /*includeBool=*/false)) && + // Treat bool as a special category. Disallow non-bool -> bool + !(from != ::executorch::aten::ScalarType::Bool && + to == ::executorch::aten::ScalarType::Bool); +} + +template +struct can_cast + : std::integral_constant::value, + CppTypeToScalarType::value)> {}; + +/** + * When casting from floating point to integral type, if the floating value is + * outside the integral type range, then an error is thrown if sanitization is + * enabled. To circumvent this, we cast the floating point to int64_t first. 
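+ *
+ * For example (illustrative): convert<int32_t, float>(3.9f) is evaluated as
+ * static_cast<int32_t>(static_cast<int64_t>(3.9f)) and yields 3.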
+ */ +template ::value && + std::is_integral::value), + int> = 0> +To convert(From val) { + return static_cast(static_cast(val)); +} + +template ::value && + std::is_integral::value), + int> = 0> +To convert(From val) { + return static_cast(val); +} + +namespace internal { +// This is generated according to NumPy's promote_types +inline constexpr auto u1 = ::executorch::aten::ScalarType::Byte; +inline constexpr auto i1 = ::executorch::aten::ScalarType::Char; +inline constexpr auto i2 = ::executorch::aten::ScalarType::Short; +inline constexpr auto i4 = ::executorch::aten::ScalarType::Int; +inline constexpr auto i8 = ::executorch::aten::ScalarType::Long; +inline constexpr auto f2 = ::executorch::aten::ScalarType::Half; +inline constexpr auto f4 = ::executorch::aten::ScalarType::Float; +inline constexpr auto f8 = ::executorch::aten::ScalarType::Double; +inline constexpr auto c2 = ::executorch::aten::ScalarType::ComplexHalf; +inline constexpr auto c4 = ::executorch::aten::ScalarType::ComplexFloat; +inline constexpr auto c8 = ::executorch::aten::ScalarType::ComplexDouble; +inline constexpr auto b1 = ::executorch::aten::ScalarType::Bool; +inline constexpr auto bf = ::executorch::aten::ScalarType::BFloat16; + +using U1 = + typename ScalarTypeToCppType<::executorch::aten::ScalarType::Byte>::type; +using I1 = + typename ScalarTypeToCppType<::executorch::aten::ScalarType::Char>::type; +using I2 = + typename ScalarTypeToCppType<::executorch::aten::ScalarType::Short>::type; +using I4 = + typename ScalarTypeToCppType<::executorch::aten::ScalarType::Int>::type; +using I8 = + typename ScalarTypeToCppType<::executorch::aten::ScalarType::Long>::type; +using F2 = + typename ScalarTypeToCppType<::executorch::aten::ScalarType::Half>::type; +using F4 = + typename ScalarTypeToCppType<::executorch::aten::ScalarType::Float>::type; +using F8 = + typename ScalarTypeToCppType<::executorch::aten::ScalarType::Double>::type; +using C2 = typename ScalarTypeToCppType< + ::executorch::aten::ScalarType::ComplexHalf>::type; +using C4 = typename ScalarTypeToCppType< + ::executorch::aten::ScalarType::ComplexFloat>::type; +using C8 = typename ScalarTypeToCppType< + ::executorch::aten::ScalarType::ComplexDouble>::type; +using B1 = + typename ScalarTypeToCppType<::executorch::aten::ScalarType::Bool>::type; +using BF = typename ScalarTypeToCppType< + ::executorch::aten::ScalarType::BFloat16>::type; + +inline constexpr std::array<::executorch::aten::ScalarType, 13> index2dtype = { + {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf}}; + +constexpr std::array< + int64_t, static_cast(::executorch::aten::ScalarType::NumOptions)> +calculate_dtype2index() { + std::array(::executorch::aten::ScalarType::NumOptions)> + inverse = {}; + for (int64_t i = 0; + i < static_cast(::executorch::aten::ScalarType::NumOptions); + i++) { + inverse[i] = -1; + } + for (int64_t i = 0; i < static_cast(index2dtype.size()); i++) { + inverse[static_cast(index2dtype[i])] = i; + } + return inverse; +} + +inline constexpr auto dtype2index = calculate_dtype2index(); +inline constexpr int NUM_PROMOTE_TYPES = 13; +// Should match _promoteTypesLookup in c10/core/ScalarType.cpp so that +// we match PyTorch core type promotion semantics. 
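+// For example (illustrative), reading this table through promoteTypes()
+// below: promoteTypes(ScalarType::Long, ScalarType::Float) yields Float,
+// promoteTypes(ScalarType::Bool, ScalarType::Byte) yields Byte, and
+// promoteTypes(ScalarType::Half, ScalarType::BFloat16) yields Float.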
+inline constexpr ::executorch::aten::ScalarType + promoteTypesLookup[NUM_PROMOTE_TYPES][NUM_PROMOTE_TYPES] = { + /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 bf*/ + /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, bf}, + /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, bf}, + /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, bf}, + /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, bf}, + /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, bf}, + /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, f4}, + /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, f4}, + /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, f8}, + /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, c4}, + /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, c4}, + /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, + /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf}, + /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, bf}, +}; + +} // namespace internal + +/** + * Implements type promotion rules that are consistent with ATen behaviour, + * which in turn is consistent with NumPy's promote_types. + * If half_to_float is set to true, then half and bfloat16 will be promoted to + * float instead + */ +inline constexpr ::executorch::aten::ScalarType +promoteTypes(::executorch::aten::ScalarType a, ::executorch::aten::ScalarType b, + bool half_to_float = false) { + // For QInt types, only allow exact match + if (::executorch::runtime::isQIntType(a) && a == b) { + return a; + } + if (::executorch::runtime::isQIntType(a) || + ::executorch::runtime::isQIntType(b)) { + ET_CHECK_MSG(false, "promoteTypes not valid for quantized dtypes"); + } + + // For Bits types, only allow exact match + if (::executorch::runtime::isBitsType(a) && a == b) { + return a; + } + if (::executorch::runtime::isBitsType(a) || + ::executorch::runtime::isBitsType(b)) { + ET_CHECK_MSG(false, "promoteTypes not valid for bits dtypes"); + } + + // For Float8 types, only allow exact match + if (::executorch::runtime::isFloat8Type(a) && a == b) { + return a; + } + if (::executorch::runtime::isFloat8Type(a) || + ::executorch::runtime::isFloat8Type(b)) { + ET_CHECK_MSG(false, "promoteTypes not valid for float8 dtypes"); + } + + // For barebones uint types, only allow exact match + if (::executorch::runtime::isBarebonesUnsignedType(a) && a == b) { + return a; + } + if (::executorch::runtime::isBarebonesUnsignedType(a) || + ::executorch::runtime::isBarebonesUnsignedType(b)) { + ET_CHECK_MSG(false, "promoteTypes not valid for barebone unsigned dtypes"); + } + + auto ix_a = ::executorch::runtime::internal::dtype2index[(int)a]; + ET_CHECK(ix_a != -1); + auto ix_b = ::executorch::runtime::internal::dtype2index[(int)b]; + ET_CHECK(ix_b != -1); + ::executorch::aten::ScalarType promoted_type = + ::executorch::runtime::internal::promoteTypesLookup[ix_a][ix_b]; + + if (half_to_float && + (promoted_type == ::executorch::aten::ScalarType::Half || + promoted_type == ::executorch::aten::ScalarType::BFloat16)) { + promoted_type = ::executorch::aten::ScalarType::Float; + } + + return promoted_type; +} + +template +struct promote_types { +private: + static_assert(std::is_same_v || + (!is_qint_type::value && !is_qint_type::value), + "promote_types not valid for quantized dtypes"); + static_assert(std::is_same_v || + (!is_bits_type::value && !is_bits_type::value), + "promote_types not valid for bits dtypes"); + static_assert(std::is_same_v || + (!is_float8_type::value && 
!is_float8_type::value), + "promote_types not valid for float8 dtypes"); + static_assert(std::is_same_v || + (!is_barebones_unsigned_type::value && + !is_barebones_unsigned_type::value), + "promote_types not valid for barebones unsigned dtypes"); + + using promoted_type_not_respecting_half_to_float = + typename ScalarTypeToCppType::value, + CppTypeToScalarType::value)>::type; + +public: + using type = std::conditional_t< + half_to_float && + (std::is_same_v::type> || + std::is_same_v::type>), + typename ScalarTypeToCppType<::executorch::aten::ScalarType::Float>::type, + promoted_type_not_respecting_half_to_float>; +}; + +// +// Helper macros for switch case macros (see below) +// +// These macros are not meant to be used directly. They provide an easy way to +// generate a switch statement that can handle subsets of ScalarTypes supported +// by ExecuTorch. +// + +#ifdef ET_INTERNAL_CHECK_SELECTIVE_BUILD +#define ET_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \ + case enum_type: { \ + ET_INTERNAL_CHECK_SELECTIVE_BUILD(enum_type); \ + using CTYPE_ALIAS = \ + ::executorch::runtime::ScalarTypeToCppType::type; \ + return __VA_ARGS__(); \ + } +#else +#define ET_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \ + case enum_type: { \ + using CTYPE_ALIAS = \ + ::executorch::runtime::ScalarTypeToCppType::type; \ + return __VA_ARGS__(); \ + } +#endif + +#define ET_INTERNAL_SWITCH(TYPE, CONTEXT, NAME, ...) \ + [&] { \ + const auto &_st = TYPE; \ + constexpr const char *et_switch_name = NAME; \ + (void)et_switch_name; /* Suppress unused var */ \ + switch (_st) { \ + __VA_ARGS__ \ + default: \ + ET_CHECK_MSG(false, "Unhandled dtype %s for %s", \ + ::executorch::runtime::toString(_st), et_switch_name); \ + } \ + }() + +#define ET_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) 
\ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Byte, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Char, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Short, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Int, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Long, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Half, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Float, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Double, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::ComplexHalf, \ + CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::ComplexFloat, \ + CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::ComplexDouble, \ + CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Bool, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::QInt8, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::QUInt8, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::QInt32, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::BFloat16, \ + CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::QUInt4x2, \ + CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::QUInt2x4, \ + CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Bits1x8, \ + CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Bits2x4, \ + CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Bits4x2, \ + CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Bits8, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Bits16, CTYPE_ALIAS, \ + __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Byte, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Char, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Short, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Int, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Long, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Float, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Double, CTYPE_ALIAS, \ + __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND(ADDITIONAL, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::ADDITIONAL, \ + CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND2(ADDITIONAL1, ADDITIONAL2, \ + CTYPE_ALIAS, ...) 
\ + ET_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::ADDITIONAL1, \ + CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::ADDITIONAL2, \ + CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND3(ADDITIONAL1, ADDITIONAL2, \ + ADDITIONAL3, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND2(ADDITIONAL1, ADDITIONAL2, \ + CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::ADDITIONAL3, \ + CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Byte, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Char, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Short, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Int, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Long, CTYPE_ALIAS, \ + __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_INT_TYPES_AND(ADDITIONAL, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::ADDITIONAL, \ + CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Double, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Float, CTYPE_ALIAS, \ + __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND(ADDITIONAL, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::ADDITIONAL, \ + CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND2(ADDITIONAL1, ADDITIONAL2, \ + CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND(ADDITIONAL1, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::ADDITIONAL2, \ + CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_QINT_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::QInt8, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::QUInt8, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::QInt32, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::QUInt4x2, \ + CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::QUInt2x4, \ + CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::ComplexFloat, \ + CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::ComplexDouble, \ + CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Bool, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Long, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Double, CTYPE_ALIAS, \ + __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_REAL_TYPES(CTYPE_ALIAS, ...) 
\ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Long, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Double, CTYPE_ALIAS, \ + __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_INTB_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Bool, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Long, CTYPE_ALIAS, \ + __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_FLOATB_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Bool, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::Double, CTYPE_ALIAS, \ + __VA_ARGS__) + +// +// Switch case macros +// +// These macros provide an easy way to generate switch statements that apply a +// common lambda function to subsets of ScalarTypes supported by ExecuTorch. +// The lambda function can type specialize to the ctype associated with the +// ScalarType being handled through an alias passed as the CTYPE_ALIAS argument. +// +// Arguments: +// - ADDITIONAL: Additional ScalarType case to add +// - TYPE: The ScalarType to handle through the switch statement +// - CONTEXT: The KernelRuntimeContext instance used for error handling, etc. +// - NAME: A name for this operation which will be used in error messages +// - CTYPE_ALIAS: A typedef for the ctype associated with the ScalarType. +// - [&](){...}: A lambda function to be applied to each ScalarType case +// +// An example usage is: +// +// ET_SWITCH_REAL_TYPES(input.scalar_type(), "example", CTYPE, [&]() { +// output.mutable_data_ptr[0] = input.const_data_ptr[0]; +// }); +// +// Note that these can be nested as well: +// +// ET_SWITCH_REAL_TYPES(input.scalar_type(), "example", CTYPE_IN, [&]() { +// ET_SWITCH_REAL_TYPES(output.scalar_type(), "example", CTYPE_OUT, [&]() { +// output.mutable_data_ptr[0] = +// input.const_data_ptr[0]; +// }); +// }); +// +// These macros are adapted from Dispatch.h in the ATen library. The primary +// difference is that the CTYPE_ALIAS argument is exposed to users, which is +// used to alias the ctype associated with the ScalarType that is being handled. +// + +#define ET_SWITCH_ALL_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_REAL_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_REAL_TYPES_AND(ADDITIONAL, TYPE, CONTEXT, NAME, CTYPE_ALIAS, \ + ...) \ + ET_INTERNAL_SWITCH(TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND( \ + ADDITIONAL, CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_REAL_TYPES_AND2(ADDITIONAL1, ADDITIONAL2, TYPE, CONTEXT, \ + NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH(TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND2( \ + ADDITIONAL1, ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_REAL_TYPES_AND3(ADDITIONAL1, ADDITIONAL2, ADDITIONAL3, TYPE, \ + CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND3( \ + ADDITIONAL1, ADDITIONAL2, ADDITIONAL3, CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_REALH_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) 
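A compilable version of the usage pattern described in the comment block above (a sketch, not part of the header; the templated mutable_data_ptr accessor and the context type are assumptions taken from ExecuTorch's tensor and kernel APIs, and the context argument is only forwarded for error reporting):

template <typename ContextT>
void fill_with_one(ContextT& ctx, executorch::aten::Tensor& t) {
  ET_SWITCH_REAL_TYPES(t.scalar_type(), ctx, "fill_with_one", CTYPE, [&]() {
    // CTYPE is the C++ type that corresponds to t.scalar_type().
    CTYPE* data = t.mutable_data_ptr<CTYPE>();
    for (size_t i = 0; i < static_cast<size_t>(t.numel()); ++i) {
      data[i] = static_cast<CTYPE>(1);
    }
  });
}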
\ + ET_SWITCH_REAL_TYPES_AND(Half, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) + +#define ET_SWITCH_REALHBF16_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_SWITCH_REAL_TYPES_AND2(Half, BFloat16, TYPE, CONTEXT, NAME, CTYPE_ALIAS, \ + __VA_ARGS__) + +#define ET_SWITCH_REALB_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_SWITCH_REAL_TYPES_AND(Bool, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) + +#define ET_SWITCH_REALHB_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_SWITCH_REAL_TYPES_AND2(Half, Bool, TYPE, CONTEXT, NAME, CTYPE_ALIAS, \ + __VA_ARGS__) + +#define ET_SWITCH_REALHBBF16_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_SWITCH_REAL_TYPES_AND3(Half, Bool, BFloat16, TYPE, CONTEXT, NAME, \ + CTYPE_ALIAS, __VA_ARGS__) + +#define ET_SWITCH_INT_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_INT_TYPES_AND(ADDITIONAL, TYPE, CONTEXT, NAME, CTYPE_ALIAS, \ + ...) \ + ET_INTERNAL_SWITCH(TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_INT_TYPES_AND( \ + ADDITIONAL, CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_FLOAT_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_FLOAT_TYPES_AND(ADDITIONAL, TYPE, CONTEXT, NAME, \ + CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH(TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND( \ + ADDITIONAL, CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_FLOAT_TYPES_AND2(ADDITIONAL1, ADDITIONAL2, TYPE, CONTEXT, \ + NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH(TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND2( \ + ADDITIONAL1, ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_FLOATH_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_SWITCH_FLOAT_TYPES_AND(Half, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) + +#define ET_SWITCH_FLOATHBF16_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_SWITCH_FLOAT_TYPES_AND2(Half, BFloat16, TYPE, CONTEXT, NAME, CTYPE_ALIAS, \ + __VA_ARGS__) + +#define ET_SWITCH_QINT_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_QINT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_COMPLEX_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_SCALAR_OBJ_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_SCALAR_OBJ_REAL_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_SCALAR_OBJ_INTB_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_INTB_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_SCALAR_OBJ_FLOATB_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, \ + ...) \ + ET_INTERNAL_SWITCH(TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_FLOATB_TYPES( \ + CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_TWO_TYPES(T1, T2, TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) 
\ + ET_INTERNAL_SWITCH( \ + TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::T1, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::T2, \ + CTYPE_ALIAS, __VA_ARGS__)) + +#define ET_SWITCH_THREE_TYPES(T1, T2, T3, TYPE, CONTEXT, NAME, CTYPE_ALIAS, \ + ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, CONTEXT, NAME, \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::T1, CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::T2, \ + CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE(::executorch::aten::ScalarType::T3, \ + CTYPE_ALIAS, __VA_ARGS__)) + +} // namespace runtime +} // namespace executorch + +namespace executorch { +namespace aten { +#ifdef USE_ATEN_LIB +using ::at::elementSize; +#else // USE_ATEN_LIB +using ::executorch::runtime::elementSize; +#endif // USE_ATEN_LIB +} // namespace aten +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::can_cast; +using ::executorch::runtime::canCast; +using ::executorch::runtime::convert; +using ::executorch::runtime::CppTypeToScalarType; +using ::executorch::runtime::elementSize; +using ::executorch::runtime::is_barebones_unsigned_type; +using ::executorch::runtime::is_bits_type; +using ::executorch::runtime::is_complex_type; +using ::executorch::runtime::is_float8_type; +using ::executorch::runtime::is_integral_type; +using ::executorch::runtime::is_qint_type; +using ::executorch::runtime::isBitsType; +using ::executorch::runtime::isComplexType; +using ::executorch::runtime::isFloatingType; +using ::executorch::runtime::isIntegralType; +using ::executorch::runtime::isQIntType; +using ::executorch::runtime::isRealHBType; +using ::executorch::runtime::isRealHType; +using ::executorch::runtime::isRealType; +using ::executorch::runtime::isValid; +using ::executorch::runtime::promote_types; +using ::executorch::runtime::promoteTypes; +using ::executorch::runtime::ScalarTypeToCppType; +using ::executorch::runtime::toString; +#if !defined(USE_ATEN_LIB) +using ::executorch::runtime::is_floating_point; +using ::executorch::runtime::is_reduced_floating_point; +#endif +namespace internal { +using ::executorch::runtime::internal::B1; +using ::executorch::runtime::internal::C2; +using ::executorch::runtime::internal::C4; +using ::executorch::runtime::internal::C8; +using ::executorch::runtime::internal::F2; +using ::executorch::runtime::internal::F4; +using ::executorch::runtime::internal::F8; +using ::executorch::runtime::internal::I1; +using ::executorch::runtime::internal::I2; +using ::executorch::runtime::internal::I4; +using ::executorch::runtime::internal::I8; +using ::executorch::runtime::internal::U1; +} // namespace internal +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/exec_aten/util/tensor_util.h b/third-party/include/executorch/runtime/core/exec_aten/util/tensor_util.h new file mode 100644 index 00000000..935f9f5c --- /dev/null +++ b/third-party/include/executorch/runtime/core/exec_aten/util/tensor_util.h @@ -0,0 +1,1230 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include // std::array +#include // PRId64 +#include +#include // size_t +#include + +#include +#include +#include +#include +#include +#include +#include + +/// All assertion messages should begin with this prefix. +#define ET_TENSOR_CHECK_PREFIX__ "Tensors do not match" +#define ET_MIN2(a, b) (std::min(a, b)) +#define ET_MIN3(a, b, c) (std::min(a, std::min(b, c))) + +#define ET_NORMALIZE_IX(IX, UPPER_BOUND) IX < 0 ? IX + UPPER_BOUND : IX + +#define ET_CHECK_VALID_IX(IX, UPPER_BOUND) \ + ET_CHECK_MSG(IX >= -static_cast(UPPER_BOUND) && \ + IX < static_cast(UPPER_BOUND), \ + "index %" PRId64 " must be within range [-%zd, %zd)", IX, \ + UPPER_BOUND, UPPER_BOUND) + +#define ET_CHECK_VALID_DIM(DIM, UPPER_BOUND) \ + ET_CHECK_MSG(DIM >= -static_cast(UPPER_BOUND) && \ + DIM < static_cast(UPPER_BOUND), \ + "dim %" PRId64 " must be within range [-%zd, %zd)", DIM, \ + UPPER_BOUND, UPPER_BOUND) + +#define ET_CHECK_NON_ZERO_DIM_SIZE(DIM, T) \ + const size_t udim = ET_NORMALIZE_IX(DIM, T.dim()); \ + ET_CHECK_MSG(T.size(udim) != 0, "Expected dim %zd to have non-zero size.", \ + udim); + +/** + * Asserts that all tensors have the same shape. + * This also handles a edge case where there is only one element in all the + * tensors being compared but the number of dimensions >= 0. In the for loop + * iterating over the dimensions we make sure that we pick the smallest + * dimension of all the tensors as the upper bound for the for loop. + */ +#define ET_CHECK_SAME_SHAPE2(a__, b__) \ + ({ \ + const size_t a_numel__ = (a__).numel(); \ + const size_t b_numel__ = (b__).numel(); \ + const size_t a_dim__ = (a__).dim(); \ + const size_t b_dim__ = (b__).dim(); \ + ET_CHECK_MSG( \ + a_numel__ == b_numel__ && \ + ((a_numel__ == 1 && b_numel__ == 1) || (a_dim__ == b_dim__)), \ + ET_TENSOR_CHECK_PREFIX__ ": numel={%zu, %zu}, dim={%zu, %zu}", \ + a_numel__, b_numel__, a_dim__, b_dim__); \ + for (size_t dim__ = 0; dim__ < ET_MIN2(a_dim__, b_dim__); ++dim__) { \ + size_t a_size__ = (a__).size(dim__); \ + size_t b_size__ = (b__).size(dim__); \ + ET_CHECK_MSG(a_size__ == b_size__, \ + ET_TENSOR_CHECK_PREFIX__ " at size(%zu): {%zu, %zu}", \ + dim__, a_size__, b_size__); \ + } \ + }) + +#define ET_CHECK_SAME_SHAPE3(a__, b__, c__) \ + ({ \ + const size_t a_numel__ = (a__).numel(); \ + const size_t b_numel__ = (b__).numel(); \ + const size_t c_numel__ = (c__).numel(); \ + const size_t a_dim__ = (a__).dim(); \ + const size_t b_dim__ = (b__).dim(); \ + const size_t c_dim__ = (c__).dim(); \ + ET_CHECK_MSG(a_numel__ == b_numel__ && b_numel__ == c_numel__ && \ + ((a_numel__ == 1 && b_numel__ == 1 && c_numel__ == 1) || \ + a_dim__ == b_dim__ && b_dim__ == c_dim__), \ + ET_TENSOR_CHECK_PREFIX__ \ + ": numel={%zu, %zu, %zu}, dim={%zu, %zu, %zu}", \ + a_numel__, b_numel__, c_numel__, a_dim__, b_dim__, c_dim__); \ + for (size_t dim__ = 0; dim__ < ET_MIN3(a_dim__, b_dim__, c_dim__); \ + ++dim__) { \ + size_t a_size__ = (a__).size(dim__); \ + size_t b_size__ = (b__).size(dim__); \ + size_t c_size__ = (c__).size(dim__); \ + ET_CHECK_MSG(a_size__ == b_size__ && b_size__ == c_size__, \ + ET_TENSOR_CHECK_PREFIX__ " at size(%zu): {%zu, %zu, %zu}", \ + dim__, a_size__, b_size__, c_size__); \ + } \ + }) + +/// Asserts that all tensors have the same dtype. 
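For example, the fatal check macros are typically placed at the top of helpers that cannot recover from mismatched inputs (a sketch under assumptions: the untyped data-pointer accessors and nbytes() come from ExecuTorch's tensor API, and ET_CHECK_SAME_DTYPE2 is the macro defined immediately below this point):

#include <cstring>

void copy_same_shape(executorch::aten::Tensor& dst,
                     const executorch::aten::Tensor& src) {
  ET_CHECK_SAME_SHAPE2(dst, src);  // aborts on shape mismatch
  ET_CHECK_SAME_DTYPE2(dst, src);  // aborts on dtype mismatch
  std::memcpy(dst.mutable_data_ptr(), src.const_data_ptr(), src.nbytes());
}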
+#define ET_CHECK_SAME_DTYPE2(a__, b__) \ + ({ \ + const ::executorch::aten::ScalarType a_type__ = (a__).scalar_type(); \ + const ::executorch::aten::ScalarType b_type__ = (b__).scalar_type(); \ + ET_CHECK_MSG(a_type__ == b_type__, \ + ET_TENSOR_CHECK_PREFIX__ ": dtype={%" PRId8 ", %" PRId8 "}", \ + static_cast(a_type__), \ + static_cast(b_type__)); \ + }) + +#define ET_CHECK_SAME_DTYPE3(a__, b__, c__) \ + ({ \ + const ::executorch::aten::ScalarType a_type__ = (a__).scalar_type(); \ + const ::executorch::aten::ScalarType b_type__ = (b__).scalar_type(); \ + const ::executorch::aten::ScalarType c_type__ = (c__).scalar_type(); \ + ET_CHECK_MSG(a_type__ == b_type__ && b_type__ == c_type__, \ + ET_TENSOR_CHECK_PREFIX__ ": dtype={%" PRId8 ", %" PRId8 \ + ", %" PRId8 "}", \ + static_cast(a_type__), static_cast(b_type__), \ + static_cast(c_type__)); \ + }) + +/** + * Asserts that all tensors have the same shape and dtype. + * + * This macro should produce less code/data than calling the SHAPE and DTYPE + * macros independently, because it only calls ET_CHECK_MSG once. + */ +#define ET_CHECK_SAME_SHAPE_AND_DTYPE2(a__, b__) \ + ({ \ + const size_t a_numel__ = (a__).numel(); \ + const size_t b_numel__ = (b__).numel(); \ + const size_t a_dim__ = (a__).dim(); \ + const size_t b_dim__ = (b__).dim(); \ + const ::executorch::aten::ScalarType a_type__ = (a__).scalar_type(); \ + const ::executorch::aten::ScalarType b_type__ = (b__).scalar_type(); \ + \ + ET_CHECK_MSG( \ + a_numel__ == b_numel__ && \ + ((a_numel__ == 1 && b_numel__ == 1) || a_dim__ == b_dim__) && \ + a_type__ == b_type__, \ + ET_TENSOR_CHECK_PREFIX__ \ + ": numel={%zu, %zu}, dim={%zu, %zu}, dtype={%" PRId8 ", %" PRId8 "}", \ + a_numel__, b_numel__, a_dim__, b_dim__, static_cast(a_type__), \ + static_cast(b_type__)); \ + for (size_t dim__ = 0; dim__ < ET_MIN2(a_dim__, b_dim__); ++dim__) { \ + size_t a_size__ = (a__).size(dim__); \ + size_t b_size__ = (b__).size(dim__); \ + ET_CHECK_MSG(a_size__ == b_size__, \ + ET_TENSOR_CHECK_PREFIX__ " at size(%zu): {%zu, %zu}", \ + dim__, a_size__, b_size__); \ + } \ + }) + +#define ET_CHECK_SAME_SHAPE_AND_DTYPE3(a__, b__, c__) \ + ({ \ + const size_t a_numel__ = (a__).numel(); \ + const size_t b_numel__ = (b__).numel(); \ + const size_t c_numel__ = (c__).numel(); \ + const size_t a_dim__ = (a__).dim(); \ + const size_t b_dim__ = (b__).dim(); \ + const size_t c_dim__ = (c__).dim(); \ + const ::executorch::aten::ScalarType a_type__ = (a__).scalar_type(); \ + const ::executorch::aten::ScalarType b_type__ = (b__).scalar_type(); \ + const ::executorch::aten::ScalarType c_type__ = (c__).scalar_type(); \ + \ + ET_CHECK_MSG(a_numel__ == b_numel__ && b_numel__ == c_numel__ && \ + ((a_numel__ == 1 && b_numel__ == 1 && c_numel__ == 1) || \ + (a_dim__ == b_dim__ && b_dim__ == c_dim__)) && \ + a_type__ == b_type__ && b_type__ == c_type__, \ + ET_TENSOR_CHECK_PREFIX__ \ + ": numel={%zu, %zu, %zu}, dim={%zu, %zu, %zu}, " \ + "dtype={%" PRId8 ", %" PRId8 ", %" PRId8 "}", \ + a_numel__, b_numel__, c_numel__, a_dim__, b_dim__, c_dim__, \ + static_cast(a_type__), static_cast(b_type__), \ + static_cast(c_type__)); \ + for (size_t dim__ = 0; dim__ < ET_MIN3(a_dim__, b_dim__, c_dim__); \ + ++dim__) { \ + size_t a_size__ = (a__).size(dim__); \ + size_t b_size__ = (b__).size(dim__); \ + size_t c_size__ = (c__).size(dim__); \ + ET_CHECK_MSG(a_size__ == b_size__ && b_size__ == c_size__, \ + ET_TENSOR_CHECK_PREFIX__ " at size(%zu): {%zu, %zu, %zu}", \ + dim__, a_size__, b_size__, c_size__); \ + } \ + }) + +/** + * Assert that the 
input tensor is contiguous tensor. + */ +#define ET_CHECK_CONTIGUOUS(a__) \ + ({ \ + const ::executorch::aten::ArrayRef \ + strides = a__.strides(); \ + const ::executorch::aten::ArrayRef sizes = \ + a__.sizes(); \ + ET_CHECK_MSG( \ + strides[strides.size() - 1] == 1, \ + "The stride of the last dimension shall be 1 for contiguous tensor, " \ + "not %d", \ + strides[strides.size() - 1]); \ + for (size_t i = strides.size() - 1; i > 0; i--) { \ + ET_CHECK_MSG(strides[i - 1] == strides[i] * sizes[i], \ + "The stride of the %zu-th dimension shall equal to " \ + "strides[%zu] * sizes[%zu], now is %d and %d", \ + i - 1, i, i, strides[i - 1], strides[i] * sizes[i]); \ + } \ + }) + +/** + * Assert the input two tensors share same strides. + * Noted that this function does not make any check or promise on the contiguity + * of any input tensors. + */ +#define ET_CHECK_SAME_STRIDES2(a__, b__) \ + ({ \ + ET_CHECK_MSG( \ + a__.dim() == b__.dim(), \ + "Two tensors shall have same number of strides, but not %zu and %zu.", \ + a__.dim(), b__.dim()); \ + const ::executorch::aten::ArrayRef \ + a_strides = a__.strides(); \ + const ::executorch::aten::ArrayRef \ + b_strides = b__.strides(); \ + for (size_t i = 0; i < a__.dim(); i++) { \ + ET_CHECK_MSG(a_strides[i] == b_strides[i], \ + "a.strides()[%zu] shall equal to b.strides()[%zu], " \ + "but now is %d and %d.", \ + i, i, (int32_t)a_strides[i], (int32_t)b_strides[i]); \ + } \ + }) + +/** + * Assert the input three tensors share same strides. + * Noted that this function does not make any check or promise on the contiguity + * of any input tensors. + */ +#define ET_CHECK_SAME_STRIDES3(a__, b__, c__) \ + ({ \ + ET_CHECK_MSG(a__.dim() == b__.dim() && b__.dim() == c__.dim(), \ + "Three tensors shall have same number of strides, " \ + "but not %zu, %zu and %zu.", \ + a__.dim(), b__.dim(), c__.dim()); \ + const ::executorch::aten::ArrayRef \ + a_strides = a__.strides(); \ + const ::executorch::aten::ArrayRef \ + b_strides = b__.strides(); \ + const ::executorch::aten::ArrayRef \ + c_strides = c__.strides(); \ + for (size_t i = 0; i < a__.dim(); i++) { \ + ET_CHECK_MSG(a_strides[i] == b_strides[i] && \ + b_strides[i] == c_strides[i], \ + "a_strides[%zu], b_strides[%zu] and c_strides[%zu] " \ + "shall share same value, but now is %d, %d and %d", \ + i, i, i, (int32_t)a_strides[i], (int32_t)b_strides[i], \ + (int32_t)c_strides[i]); \ + } \ + }) + +#define ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(t__) \ + ({ \ + ET_CHECK_MSG(is_contiguous_dim_order(t__.dim_order().data(), \ + t__.dim_order().size()) || \ + is_channels_last_dim_order(t__.dim_order().data(), \ + t__.dim_order().size()), \ + "Tensor must have default or channels last dim order"); \ + }) + +/** + * A convenience macro to be used in utility functions that check whether input + * tensor(s) are valid, which are expected to return a boolean. Checks whether + * `cond` is true; if not, log the failed check and return false. + * + * @param[in] cond the condition to check + */ +#define ET_LOG_AND_RETURN_IF_FALSE(cond) \ + do { \ + if (!(cond)) { \ + ET_LOG(Error, "Check failed (%s): ", #cond); \ + return false; \ + } \ + } while (false) + +/** + * A convenience macro to be used in utility functions that check whether input + * tensor(s) are valid, which are expected to return a boolean. Checks whether + * `cond` is true; if not, log the failed check with `message` and return false. 
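In contrast to the fatal ET_CHECK_* macros above, the *_RETURN_IF_FALSE helpers are meant for bool-returning validators, so callers can react to bad inputs instead of aborting. A minimal sketch (illustrative, with a hypothetical helper name):

bool tensor_is_nonempty(const executorch::aten::Tensor& t) {
  ET_LOG_AND_RETURN_IF_FALSE(t.numel() > 0);
  return true;
}

// A caller can then fail gracefully, e.g.:
//   if (!tensor_is_nonempty(input)) { return Error::InvalidArgument; }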
+ * + * @param[in] cond the condition to check + * @param[in] message an additional message to log with `cond` + */ +#define ET_LOG_MSG_AND_RETURN_IF_FALSE(cond, message, ...) \ + do { \ + if (!(cond)) { \ + ET_LOG(Error, "Check failed (%s): " message, #cond, ##__VA_ARGS__); \ + return false; \ + } \ + } while (false) + +/** + * If `cond` is false, log `cond` and return from the kernel with a failure + * state set. + * + * @param[in] context the runtime context + * @param[in] cond the condition to check + * @param[in] error torch::executor::Error enum value (e.g `InvalidArgument`) + * @param[in] retval return value of the kernel to allow for early exit + */ +#define ET_KERNEL_CHECK(context, cond, error, retval) \ + do { \ + if (!(cond)) { \ + ET_LOG(Error, "Check failed (%s): ", #cond); \ + context.fail(torch::executor::Error::error); \ + return retval; \ + } \ + } while (false) + +/** + * If `cond` is false, log `message` and return from the kernel with a failure + * state set. + * + * @param[in] context the runtime context + * @param[in] cond the condition to check + * @param[in] error torch::executor::Error enum value (e.g `InvalidArgument`) + * @param[in] retval return value of the kernel to allow for early exit + */ +#define ET_KERNEL_CHECK_MSG(context, cond, error, retval, message, ...) \ + do { \ + if (!(cond)) { \ + ET_LOG(Error, "Check failed (%s): " message, #cond, ##__VA_ARGS__); \ + context.fail(torch::executor::Error::error); \ + return retval; \ + } \ + } while (false) + +/** + * Convenience macro to extract a scalar tensor into a value + */ +#define ET_EXTRACT_SCALAR_TENSOR(scalar_tensor, out_val) \ + ET_CHECK_MSG(extract_scalar_tensor(scalar_tensor, &out_val), #scalar_tensor \ + " could not be extracted: wrong type or out of range"); + +namespace executorch { +namespace runtime { + +// +// Utility functions for checking tensor attributes +// +// + +/* + * Returns true if the given dimension value is between -upper_bound and + * upper_bound - 1, inclusive. + */ +inline bool dim_is_valid(int64_t dim, int64_t upper_bound) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + dim >= -upper_bound && dim < upper_bound, + "Dimension %" PRId64 + " is out of range. Dimension should be between %" PRId64 " and %" PRId64 + ", inclusive.", + dim, -upper_bound, upper_bound - 1); + + return true; +} + +/* + * Returns the tensor's number of dimensions, except when the tensor is zero + * dimensional. In this case, it returns 1. This is used to properly handle + * the zero dimensional tensors in some kernels, that treat them as 1D tensors + * with a single element. + */ +inline ssize_t nonzero_dim(const executorch::aten::Tensor &tensor) { + return tensor.dim() == 0 ? 1 : tensor.dim(); +} + +/* + * Returns the size along a dimension dim, except when the tensor is zero + * dimensional. In this case, it returns 1. This is used to properly handle + * the zero dimensional tensors in some kernels, that treat them as 1D tensors + * with a single element. + */ +inline ssize_t nonempty_size(const executorch::aten::Tensor &tensor, + ssize_t dim) { + return tensor.dim() == 0 ? 
1 : tensor.size(dim); +} + +inline bool tensor_can_cast_to(executorch::aten::Tensor a, + executorch::aten::ScalarType dtype) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + torch::executor::canCast(a.scalar_type(), dtype), + "Tensor of dtype %s cannot cast to dtype %s", + torch::executor::toString(a.scalar_type()), + torch::executor::toString(dtype)); + + return true; +} + +inline bool tensor_is_bool_type(executorch::aten::Tensor t) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + t.scalar_type() == executorch::aten::ScalarType::Bool, + "Expected to find bool type, but tensor has type %s", + torch::executor::toString(t.scalar_type())); + + return true; +} + +inline bool tensor_is_type(executorch::aten::Tensor t, + executorch::aten::ScalarType dtype) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + t.scalar_type() == dtype, + "Expected to find %s type, but tensor has type %s", + torch::executor::toString(dtype), + torch::executor::toString(t.scalar_type())); + + return true; +} + +inline bool tensor_is_integral_type(executorch::aten::Tensor t, + bool includeBool = false) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + torch::executor::isIntegralType(t.scalar_type(), includeBool), + "Expected to find a integral type, but tensor has type %s", + torch::executor::toString(t.scalar_type())); + + return true; +} + +inline bool tensor_is_floating_type(executorch::aten::Tensor t) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + torch::executor::isFloatingType(t.scalar_type()), + "Expected to find a floating type, but tensor has type %s", + torch::executor::toString(t.scalar_type())); + + return true; +} + +inline bool tensor_is_real_type(executorch::aten::Tensor t) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + torch::executor::isRealType(t.scalar_type()), + "Expected to find a real type, but tensor has type %s", + torch::executor::toString(t.scalar_type())); + + return true; +} + +inline bool tensor_is_realh_type(executorch::aten::Tensor t) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + torch::executor::isRealHType(t.scalar_type()), + "Expected to find a real type, but tensor has type %s", + torch::executor::toString(t.scalar_type())); + + return true; +} + +inline bool tensor_is_realhbf16_type(executorch::aten::Tensor t) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + executorch::runtime::isRealHBF16Type(t.scalar_type()), + "Expected to find a real type, but tensor has type %s", + torch::executor::toString(t.scalar_type())); + + return true; +} + +inline bool tensor_is_realhb_type(executorch::aten::Tensor t) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + torch::executor::isRealHBType(t.scalar_type()), + "Expected to find a real type, but tensor has type %s", + torch::executor::toString(t.scalar_type())); + + return true; +} + +inline bool tensor_is_realhbbf16_type(executorch::aten::Tensor t) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + executorch::runtime::isRealHBBF16Type(t.scalar_type()), + "Expected to find a real type, but tensor has type %s", + torch::executor::toString(t.scalar_type())); + + return true; +} + +inline bool tensor_is_complex_type(executorch::aten::Tensor t) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + torch::executor::isComplexType(t.scalar_type()), + "Expected to find a complex type, but tensor has type %s", + torch::executor::toString(t.scalar_type())); + + return true; +} + +inline bool tensor_is_bits_type(executorch::aten::Tensor t) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + torch::executor::isBitsType(t.scalar_type()), + "Expected to find a bits type, but tensor has type %s", + torch::executor::toString(t.scalar_type())); + + return true; +} + +inline bool 
tensors_have_same_dtype(executorch::aten::Tensor a, + executorch::aten::Tensor b) { + ET_LOG_MSG_AND_RETURN_IF_FALSE(a.scalar_type() == b.scalar_type(), + ET_TENSOR_CHECK_PREFIX__ ": dtype={%s, %s}", + torch::executor::toString(a.scalar_type()), + torch::executor::toString(b.scalar_type())); + return true; +} + +inline bool tensors_have_same_dtype(executorch::aten::Tensor a, + executorch::aten::Tensor b, + executorch::aten::Tensor c) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + a.scalar_type() == b.scalar_type() && b.scalar_type() == c.scalar_type(), + ET_TENSOR_CHECK_PREFIX__ ": dtype={%s, %s, %s}", + torch::executor::toString(a.scalar_type()), + torch::executor::toString(b.scalar_type()), + torch::executor::toString(c.scalar_type())); + return true; +} + +inline bool tensor_is_rank(executorch::aten::Tensor t, size_t rank) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + t.dim() == rank, "Expected tensor.dim() to be %zu, but got %zu", + static_cast(rank), static_cast(t.dim())); + + return true; +} + +inline bool tensor_has_rank_greater_or_equal_to(executorch::aten::Tensor t, + size_t rank) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + t.dim() >= rank, "Expected tensor.dim() to be >= %zu, but got %zu", + static_cast(rank), static_cast(t.dim())); + + return true; +} + +inline bool tensor_has_rank_smaller_or_equal_to(executorch::aten::Tensor t, + size_t rank) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + t.dim() <= rank, "Expected tensor.dim() to be <= %zu, but got %zu", + static_cast(rank), static_cast(t.dim())); + + return true; +} + +inline bool tensor_has_dim(executorch::aten::Tensor t, int64_t d) { + if (t.dim() == 0) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + d == 0 || d == -1, "dim must be 0 or -1 for 0-dim tensor, got %" PRId64, + d); + } else { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + d > 0 ? 
d < t.dim() : t.dim() + d >= 0, + "%zu-dim tensor does not have dim at index %zu", + static_cast(t.dim()), static_cast(d)); + } + return true; +} + +inline bool tensor_has_non_empty_dim(executorch::aten::Tensor t, int64_t d) { + const size_t udim = ET_NORMALIZE_IX(d, t.dim()); + ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(t, d)); + ET_LOG_AND_RETURN_IF_FALSE(t.size(udim) != 0); + return true; +} + +inline bool tensor_dim_has_index(executorch::aten::Tensor t, int64_t d, + int64_t ix) { + // Indexing ops don't support zero-dim tensors + ET_CHECK(t.dim() != 0); + if (d < 0) { + d += t.dim(); + } + // Dimension must have been already checked by tensor_has_dim + ET_CHECK(d >= 0 && d < t.dim()); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + ix >= -t.size(d) && ix < t.size(d), + "index %" PRId64 " out of range [-%zu,%zu) at dimension %" PRId64 ")", ix, + static_cast(t.size(d)), static_cast(t.size(d)), d); + return true; +} + +inline bool tensors_have_same_size_at_dims(executorch::aten::Tensor a, + size_t dim_a, + executorch::aten::Tensor b, + size_t dim_b) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + dim_a < a.dim(), "Cannot retrieve dim %zu from tensor with dim %zu", + static_cast(dim_a), static_cast(a.dim())); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + dim_b < b.dim(), "Cannot retrieve dim %zu from tensor with dim %zu", + static_cast(dim_b), static_cast(b.dim())); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + a.size(dim_a) == b.size(dim_b), + ET_TENSOR_CHECK_PREFIX__ + ": a.size(%zu) = %zu does not match b.size(%zu) = %zu", + static_cast(dim_a), static_cast(a.size(dim_a)), + static_cast(dim_b), static_cast(b.size(dim_b))); + + return true; +} + +inline bool tensors_have_same_shape(executorch::aten::Tensor a, + executorch::aten::Tensor b) { + if (a.numel() == 1 && b.numel() == 1) { + // PyTorch operators treat all scalar tensors as the same shape even if + // they have different dims. + return true; + } + if (!(a.sizes() == b.sizes() && a.numel() == b.numel())) { + ET_LOG(Error, + ET_TENSOR_CHECK_PREFIX__ ": numel=(%zu, %zu), dim=(%zu, %zu)", + static_cast(a.numel()), static_cast(b.numel()), + static_cast(a.dim()), static_cast(b.dim())); + for (size_t d = 0; d < ET_MIN2(a.dim(), b.dim()); ++d) { + ET_LOG(Error, " size(%zu): (%zu, %zu)", static_cast(d), + static_cast(a.size(d)), static_cast(b.size(d))); + } + + return false; + } + + return true; +} + +inline bool tensors_have_same_shape(executorch::aten::Tensor a, + executorch::aten::Tensor b, + executorch::aten::Tensor c) { + if (a.numel() == 1 && b.numel() == 1 && c.numel() == 1) { + // PyTorch operators treat all scalar tensors as the same shape even if + // they have different dims. 
+ return true; + } + bool cond1 = (a.sizes() == b.sizes()) && (a.numel() == b.numel()); + bool cond2 = (b.sizes() == c.sizes()) && (b.numel() == c.numel()); + + if (!(cond1 && cond2)) { + ET_LOG(Error, + ET_TENSOR_CHECK_PREFIX__ + ": numel=(%zu, %zu, %zu), dim=(%zu, %zu, %zu)", + static_cast(a.numel()), static_cast(b.numel()), + static_cast(c.numel()), static_cast(a.dim()), + static_cast(b.dim()), static_cast(c.dim())); + for (size_t d = 0; d < ET_MIN3(a.dim(), b.dim(), c.dim()); ++d) { + ET_LOG(Error, " size(%zu): (%zu, %zu, %zu)", static_cast(d), + static_cast(a.size(d)), static_cast(b.size(d)), + static_cast(c.size(d))); + } + + return false; + } + + return true; +} + +inline bool tensors_have_same_shape_and_dtype(executorch::aten::Tensor a, + executorch::aten::Tensor b) { + return tensors_have_same_shape(a, b) && tensors_have_same_dtype(a, b); +} + +inline bool tensors_have_same_shape_and_dtype(executorch::aten::Tensor a, + executorch::aten::Tensor b, + executorch::aten::Tensor c) { + return tensors_have_same_shape(a, b, c) && tensors_have_same_dtype(a, b, c); +} + +inline bool tensor_has_expected_size( + executorch::aten::Tensor a, + executorch::aten::ArrayRef expected_sizes) { + if (!(a.sizes() == expected_sizes)) { + ET_LOG(Error, ET_TENSOR_CHECK_PREFIX__ ": dim=(%zu, %zu)", + static_cast(a.dim()), + static_cast(expected_sizes.size())); + size_t a_dim = static_cast(a.dim()); + size_t expected_dim = static_cast(expected_sizes.size()); + for (size_t d = 0; d < ET_MIN2(a_dim, expected_dim); ++d) { + ET_LOG(Error, " size(%zu): (%zu, %zu)", static_cast(d), + static_cast(a.size(d)), + static_cast(expected_sizes[d])); + } + + return false; + } + return true; +} + +inline bool tensors_have_same_strides(executorch::aten::Tensor a, + executorch::aten::Tensor b) { + if (a.strides() != b.strides()) { + ET_LOG(Error, ET_TENSOR_CHECK_PREFIX__ ": dim=(%zu, %zu)", + static_cast(a.dim()), static_cast(b.dim())); + for (size_t d = 0; d < ET_MIN2(a.dim(), b.dim()); ++d) { + ET_LOG(Error, " stride(%zu): (%zu, %zu)", static_cast(d), + static_cast(a.strides()[d]), + static_cast(b.strides()[d])); + } + + return false; + } + return true; +} + +inline bool tensors_have_same_strides(executorch::aten::Tensor a, + executorch::aten::Tensor b, + executorch::aten::Tensor c) { + if (!(a.strides() == b.strides() && b.strides() == c.strides())) { + ET_LOG(Error, ET_TENSOR_CHECK_PREFIX__ ": dim=(%zu, %zu, %zu)", + static_cast(a.dim()), static_cast(b.dim()), + static_cast(c.dim())); + for (size_t d = 0; d < ET_MIN3(a.dim(), b.dim(), c.dim()); ++d) { + ET_LOG(Error, " stride(%zu): (%zu, %zu, %zu)", static_cast(d), + static_cast(a.strides()[d]), + static_cast(b.strides()[d]), + static_cast(c.strides()[d])); + } + + return false; + } + return true; +} + +inline bool tensor_is_contiguous(executorch::aten::Tensor t) { + const auto strides = t.strides(); + const auto sizes = t.sizes(); + // If tensor is 0-dim (i.e. 
a scalar tensor) it is contiguous + if (strides.size() == 0) { + return true; + } + ET_LOG_MSG_AND_RETURN_IF_FALSE( + strides[strides.size() - 1] == 1, + "Tensor is not contiguous; the stride of the last dimension must be 1, " + "but got %zu", + static_cast(strides[strides.size() - 1])); + for (int i = strides.size() - 1; i > 0; --i) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + strides[i - 1] == strides[i] * sizes[i], + "Tensor is not contiguous; the stride of dim %zu should be equal to " + "strides[%zu] * sizes[%zu] = %zu, but found %zu", + static_cast(i - 1), static_cast(i), + static_cast(i), static_cast(strides[i] * sizes[i]), + static_cast(strides[i - 1])); + } + return true; +} + +inline bool tensors_have_same_rank(executorch::aten::Tensor a, + executorch::aten::Tensor b) { + ET_LOG_MSG_AND_RETURN_IF_FALSE(a.dim() == b.dim(), + ET_TENSOR_CHECK_PREFIX__ ": rank={%zd, %zd}", + ssize_t(a.dim()), ssize_t(b.dim())); + return true; +} + +inline bool tensor_is_scalar(executorch::aten::Tensor t) { + return t.dim() == 0 && t.numel() == 1; +} + +/** + * The expected output size may not be the existing size of any inputs and + * outputs if the operator supports both broadcast and dynamic shape. + * Therefore such operators needs extra space to store the calculated expected + * output size. such dynamic allocation is troublesome in executorch so we can + * just hard code a static value of a relatively small value because users + * don't create high dimensional tensors. + */ +constexpr size_t kTensorDimensionLimit = 16; + +/// Returns the product of dim[0:dim), not including dim. +inline size_t getLeadingDims(const executorch::aten::Tensor &tensor, + int64_t dim) { + ET_CHECK_MSG(dim >= 0 && dim <= tensor.dim(), + "Ending dimension %" PRId64 + " should be in the range [0, tensor.dim() %zd].", + dim, ssize_t(tensor.dim())); + size_t dims = 1; + for (size_t i = 0; i < dim; ++i) { + dims *= static_cast(tensor.size(i)); + } + return dims; +} + +/// Returns the product of dim[dim+1:]. +inline size_t getTrailingDims(const executorch::aten::Tensor &tensor, + int64_t dim) { + ET_CHECK_MSG(dim >= -1 && dim < tensor.dim(), + "Starting dimension %" PRId64 + " should be in the range [-1, tensor.dim() -1 %zd).", + dim, ssize_t(tensor.dim())); + size_t dims = 1; + for (size_t i = dim + 1; i < tensor.dim(); ++i) { + dims *= static_cast(tensor.size(i)); + } + return dims; +} + +/** + * Given a N-dimensional tensor coordinate, return a linear index that can be + * used to access the corresponding element in the tensor's data buffer. + * + * @param[in] tensor The tensor that will be indexed + * @param[in] coordinate A n-dimensional array representing the coordinate to + * index. It is assumed that the array has kTensorDimensionLimit elements. + * @param[out] index The linear index to element at the specified coordinate + * in the tensor. + */ +inline size_t coordinateToIndex(const executorch::aten::Tensor &tensor, + const size_t *const coordinate) { + size_t index = 0; + for (int d = 0; d < tensor.dim(); ++d) { + index += coordinate[d] * getTrailingDims(tensor, d); + } + return index; +} + +/** + * Produce a memoized array for use with repeated calls to + * coordinateToIndexWithTrailingDimsMemo, which will be faster than + * repeated calls to coordinateToIndex. 
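To make the indexing helpers above concrete (an illustrative sketch): for a contiguous tensor of shape [2, 3, 4], getLeadingDims(t, 2) == 6, getTrailingDims(t, 0) == 12, and coordinateToIndex() flattens a coordinate by weighting each entry with its trailing-dimension product.

size_t flat_index_of(const executorch::aten::Tensor& t,
                     size_t i, size_t j, size_t k) {
  // Unused trailing entries are zero-initialized; only t.dim() entries are read.
  const size_t coordinate[executorch::runtime::kTensorDimensionLimit] = {i, j, k};
  return executorch::runtime::coordinateToIndex(t, coordinate);
}
// For shape [2, 3, 4]: flat_index_of(t, 1, 2, 3) == 1*12 + 2*4 + 3 == 23.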
+ */ +inline void +memoizeTrailingDims(const executorch::aten::Tensor &tensor, + size_t trailing_dims_memo[kTensorDimensionLimit]) { + const auto tensorDim = tensor.dim(); + size_t dims = 1; + for (int ii = tensorDim - 1; ii >= 0; --ii) { + trailing_dims_memo[ii] = dims; + dims *= static_cast(tensor.size(ii)); + } +} + +/** + * Like coordinateToIndex, but faster for repeated calls with the same + * tensor. trailing_dims_memo must be produced by a call to + * memoizeTrailingDims. + */ +inline size_t coordinateToIndexWithTrailingDimsMemo( + const executorch::aten::Tensor &tensor, const size_t *const coordinate, + const size_t trailing_dims_memo[kTensorDimensionLimit]) { + size_t index = 0; + for (int d = 0; d < tensor.dim(); ++d) { + index += coordinate[d] * trailing_dims_memo[d]; + } + return index; +} + +/** + * Given the linear index return the N-dimensional tensor coordinate. This is + * the inverse operation of coordinateToIndex. + * + * @param[in] tensor The tensor that will be indexed + * @param[in] index The linear index to element at the specified coordinate in + * the tensor. + * @param[out] coordinate A n-dimensional array representing the coordinate to + * index. It is assumed that the array has kTensorDimensionLimit elements. + * @returns void + */ +inline void indexToCoordinate(const executorch::aten::Tensor &tensor, + size_t index, size_t *coordinate) { + ET_CHECK(index < tensor.numel()); + for (auto i = 0; i < tensor.dim(); ++i) { + auto dim = tensor.dim() - 1 - i; + size_t dim_size = tensor.size(dim); + coordinate[dim] = index % dim_size; + index /= dim_size; + } +} + +/** + * Extracts an integer value from a scalar Tensor. + * + * @param[in] tensor The source of the value to extract. + * @param[out] out_val The extracted value, on success. + * @returns `true` if a value was extracted, and sets `*out_val` to that + * value. `false` if a value could not be extracted: either it was not an + * integer Scalar Tensor, or the value of that Scalar Tensor could not be + * represented by INT_T. + */ +template ::value && + !std::is_same::value, + bool>::type = true> +bool extract_scalar_tensor(executorch::aten::Tensor tensor, INT_T *out_val) { + if (tensor.numel() != 1) { + return false; + } +#define CASE_INT_DTYPE(TENSOR_CTYPE, TENSOR_DTYPE) \ + case executorch::aten::ScalarType::TENSOR_DTYPE: { \ + const TENSOR_CTYPE val = tensor.const_data_ptr()[0]; \ + if (val < std::numeric_limits::lowest() || \ + val > std::numeric_limits::max()) { \ + return false; \ + } \ + *out_val = static_cast(val); \ + return true; \ + } + + switch (tensor.scalar_type()) { + ET_FORALL_INT_TYPES(CASE_INT_DTYPE); + default: + return false; + } +#undef CASE_INT_DTYPE +} + +/** + * Extracts a floating point value from a scalar Tensor. + * + * @param[in] tensor The source of the value to extract. + * @param[out] out_val The extracted value, on success. + * @returns `true` if a value was extracted, and sets `*out_val` to that + * value. `false` if a value could not be extracted: either it was not a + * floating point Scalar Tensor, or the value of that Scalar Tensor could not + * be represented by FLOAT_T. + */ +template ::value, + bool>::type = true> +bool extract_scalar_tensor(executorch::aten::Tensor tensor, FLOAT_T *out_val) { + if (tensor.numel() != 1) { + return false; + } +#define CASE_REAL_DTYPE(TENSOR_CTYPE, TENSOR_DTYPE) \ + case executorch::aten::ScalarType::TENSOR_DTYPE: { \ + /* ET_FORALL_REAL_TYPES guarantees TENSOR_CTYPE is a real type. 
*/ \ + double val = \ + static_cast(tensor.const_data_ptr()[0]); \ + if (std::isfinite(val) && (val < std::numeric_limits::lowest() || \ + val > std::numeric_limits::max())) { \ + return false; \ + } \ + *out_val = static_cast(val); \ + return true; \ + } + + switch (tensor.scalar_type()) { + ET_FORALL_REAL_TYPES(CASE_REAL_DTYPE); + default: + return false; + } +#undef CASE_REAL_DTYPE +} + +/** + * Extracts a boolean value from a Scalar. + * + * @param[in] scalar The source of the value to extract. + * @param[out] out_val The extracted value, on success. + * @returns `true` if a value was extracted, and sets `*out_val` to that + * value. `false` if a value could not be extracted, i.e. not a boolean + */ +template ::value, + bool>::type = true> +bool extract_scalar_tensor(executorch::aten::Tensor tensor, BOOL_T *out_val) { + if (tensor.scalar_type() != executorch::aten::ScalarType::Bool) { + return false; + } + if (tensor.numel() != 1) { + return false; + } + + bool val = tensor.const_data_ptr()[0]; + + *out_val = static_cast(val); + + return true; +} + +/// These APIs should not be used outside of Executor.cpp. +namespace internal { +/** + * Share t_src's data_ptr with t_dst. + */ +ET_NODISCARD Error share_tensor_data(const executorch::aten::Tensor &t_dst, + const executorch::aten::Tensor &t_src); + +/** + * Copy t_src's data_ptr to t_dst. + */ +ET_NODISCARD Error copy_tensor_data(const executorch::aten::Tensor &t_dst, + const executorch::aten::Tensor &t_src); + +/** + * Set the data_ptr of t to buffer. + */ +ET_NODISCARD Error set_tensor_data(const executorch::aten::Tensor &t, + void *buffer, size_t buffer_size); + +/** + * Reset tensor's data_ptr, clear all the storage for at::Tensor. + */ +void reset_data_ptr(const executorch::aten::Tensor &tensor); + +/** + * Resize tensor impl + */ +ET_NODISCARD Error resize_tensor_impl( + executorch::aten::TensorImpl *impl, + executorch::aten::ArrayRef new_sizes); + +} // namespace internal + +/** + * Resize a tensor to new_sizes, rank must stay the same. Currently does not + * expand the tensor if new size exceeds the current capacity. Currently + * fails an ET_CHECK if the tensor cannot be resized. + * + * WARNING: Placeholder API until discussion around runtime context is + * settled, will likely move to be a class method on a TensorResizer object + * passed in through runtimeContext. + */ +ET_NODISCARD inline Error resize_tensor( + executorch::aten::Tensor t, + executorch::aten::ArrayRef new_sizes) { + return internal::resize_tensor_impl(t.unsafeGetTensorImpl(), new_sizes); +} + +/** + * Resize a tensor to new_sizes, rank must stay the same. Currently does not + * expand the tensor if new size exceeds the current capacity. Currently + * fails an ET_CHECK if the tensor cannot be resized. + * + * WARNING: Placeholder API until discussion around runtime context is + * settled, will likely move to be a class method on a TensorResizer object + * passed in through runtimeContext. 
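Putting the pieces together, a typical out-variant kernel validates its inputs non-fatally and resizes the output before computing. A sketch under assumptions (the context type is assumed to provide the fail() hook that ET_KERNEL_CHECK expects; the operator itself is hypothetical):

template <typename ContextT>
executorch::aten::Tensor& unary_op_out(ContextT& ctx,
                                       const executorch::aten::Tensor& in,
                                       executorch::aten::Tensor& out) {
  ET_KERNEL_CHECK(ctx, executorch::runtime::tensors_have_same_dtype(in, out),
                  InvalidArgument, out);
  ET_KERNEL_CHECK(ctx,
                  executorch::runtime::resize_tensor(out, in.sizes()) ==
                      executorch::runtime::Error::Ok,
                  InvalidArgument, out);
  // ... elementwise work on `out` would go here ...
  return out;
}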
+ */ +template ::value, + int>::type = 0> +ET_NODISCARD inline Error +resize_tensor(executorch::aten::Tensor t, + executorch::aten::ArrayRef new_sizes) { + // Need to cast the input array to an array of Tensor::SizesType + std::array + new_sizes_casted{}; + size_t new_sizes_ndim = new_sizes.size(); + for (size_t i = 0; i < new_sizes_ndim; ++i) { + new_sizes_casted[i] = + static_cast(new_sizes[i]); + } + + return internal::resize_tensor_impl( + t.unsafeGetTensorImpl(), {new_sizes_casted.data(), new_sizes_ndim}); +} + +/// DEPRECATED: Use `resize_tensor()` instead, which can fail non-fatally. +ET_DEPRECATED inline void +resize(executorch::aten::Tensor t, + executorch::aten::ArrayRef new_sizes) { + Error err = resize_tensor(t, new_sizes); + ET_CHECK_MSG(err == Error::Ok, + "Could not resize Tensor; see logs for details"); +} +/** + * Get dim_order of a Tensor and write it to out_dim_order. + * @param tensor The tensor where we want to get dim order from. + * @param out_dim_order Pointing to an array of DimOrderType where we write + * dim order into it. + * @param out_dim_order_size Size of the DimOrderType array. + */ +ET_NODISCARD Error get_dim_order(const executorch::aten::Tensor &tensor, + executorch::aten::DimOrderType *out_dim_order, + size_t out_dim_order_size); + +/** + * Checks whether a tensor has a valid dim order. If the dim order could not + * be determined, then this function returns false by default. + */ +bool tensor_has_valid_dim_order(executorch::aten::Tensor t); + +/** + * Checks whether a tensor has either the default of channels last dim order. + * If the dim order could not be determined, then this function returns false + * by default. + */ +bool tensor_is_default_or_channels_last_dim_order(executorch::aten::Tensor t); + +/** + * Checks whether a tensor has the default dimension order. + * Logs an error message if the tensor does not meet the expected criteria. + * + * @param t The tensor to check the dimension order of. + * @return True if the tensor has the default dimension order, false otherwise. + */ +bool tensor_is_default_dim_order(executorch::aten::Tensor t); + +/** + * Checks whether a tensor has the channels last dimension order. + * Logs an error message if the tensor does not meet the expected criteria. + * + * @param t The tensor to check the dimension order of. + * @return True if the tensor has the channels last dimension order, false + * otherwise. + */ +bool tensor_is_channels_last_dim_order(executorch::aten::Tensor t); + +/** + * Asserts that four tensors have the same dim_order + * + * Note that this macro only tests dim order, but not others like actual data, + * sizes, etc. + * + */ +bool tensors_have_same_dim_order( + const executorch::aten::ArrayRef tensor_list); + +/** + * Asserts that two tensors have the same dim_order + * + * Note that this macro only tests dim order, but not others like actual data, + * sizes, etc. + */ + +inline bool tensors_have_same_dim_order(const executorch::aten::Tensor &a, + const executorch::aten::Tensor &b) { + executorch::aten::Tensor tensor_list[2] = {a, b}; + return tensors_have_same_dim_order(tensor_list); +} + +/** + * Asserts that three tensors have the same dim_order + * + * Note that this macro only tests dim order, but not others like actual data, + * sizes, etc. 
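The dim-order helpers combine naturally with the *_RETURN_IF_FALSE macros from earlier; a small guard that admits only default (contiguous) or channels-last layouts that agree across tensors might look like this (illustrative sketch with a hypothetical helper name):

bool layouts_are_supported(const executorch::aten::Tensor& a,
                           const executorch::aten::Tensor& b) {
  ET_LOG_AND_RETURN_IF_FALSE(
      executorch::runtime::tensor_is_default_or_channels_last_dim_order(a));
  ET_LOG_AND_RETURN_IF_FALSE(
      executorch::runtime::tensor_is_default_or_channels_last_dim_order(b));
  ET_LOG_AND_RETURN_IF_FALSE(
      executorch::runtime::tensors_have_same_dim_order(a, b));
  return true;
}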
+ * + */ + +inline bool tensors_have_same_dim_order(const executorch::aten::Tensor &a, + const executorch::aten::Tensor &b, + const executorch::aten::Tensor &c) { + executorch::aten::Tensor tensor_list[3] = {a, b, c}; + return tensors_have_same_dim_order(tensor_list); +} + +/** + * Asserts that four tensors have the same dim_order + * + * Note that this macro only tests dim order, but not others like actual data, + * sizes, etc. + * + */ + +inline bool tensors_have_same_dim_order(const executorch::aten::Tensor &a, + const executorch::aten::Tensor &b, + const executorch::aten::Tensor &c, + const executorch::aten::Tensor &d) { + executorch::aten::Tensor tensor_list[4] = {a, b, c, d}; + return tensors_have_same_dim_order(tensor_list); +} + +/** + * Given an n-dimensional coordinate array and an array of tensor strides, + * calculates the linear index that can be used to retrieve the value at the + * given coordinates. + * @param coordinate Pointer to the array of coordinates. + * @param strides Pointer to the array of strides. + * @param ndim Number of dimensions in the tensor. + */ +inline size_t +calculate_linear_index(const executorch::aten::SizesType *coordinate, + const executorch::aten::StridesType *strides, + const size_t ndim) { + size_t index = 0; + for (size_t i = 0; i < ndim; i++) { + index += coordinate[i] * strides[i]; + } + return index; +} + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::calculate_linear_index; +using ::executorch::runtime::coordinateToIndex; +using ::executorch::runtime::dim_is_valid; +using ::executorch::runtime::extract_scalar_tensor; +using ::executorch::runtime::get_dim_order; +using ::executorch::runtime::getLeadingDims; +using ::executorch::runtime::getTrailingDims; +using ::executorch::runtime::indexToCoordinate; +using ::executorch::runtime::kTensorDimensionLimit; +using ::executorch::runtime::nonempty_size; +using ::executorch::runtime::nonzero_dim; +using ::executorch::runtime::resize; +using ::executorch::runtime::resize_tensor; +using ::executorch::runtime::tensor_can_cast_to; +using ::executorch::runtime::tensor_dim_has_index; +using ::executorch::runtime::tensor_has_dim; +using ::executorch::runtime::tensor_has_expected_size; +using ::executorch::runtime::tensor_has_non_empty_dim; +using ::executorch::runtime::tensor_has_rank_greater_or_equal_to; +using ::executorch::runtime::tensor_has_rank_smaller_or_equal_to; +using ::executorch::runtime::tensor_has_valid_dim_order; +using ::executorch::runtime::tensor_is_bits_type; +using ::executorch::runtime::tensor_is_bool_type; +using ::executorch::runtime::tensor_is_complex_type; +using ::executorch::runtime::tensor_is_contiguous; +using ::executorch::runtime::tensor_is_default_dim_order; +using ::executorch::runtime::tensor_is_default_or_channels_last_dim_order; +using ::executorch::runtime::tensor_is_floating_type; +using ::executorch::runtime::tensor_is_integral_type; +using ::executorch::runtime::tensor_is_rank; +using ::executorch::runtime::tensor_is_real_type; +using ::executorch::runtime::tensor_is_realh_type; +using ::executorch::runtime::tensor_is_realhb_type; +using ::executorch::runtime::tensor_is_scalar; +using ::executorch::runtime::tensors_have_same_dim_order; +using ::executorch::runtime::tensors_have_same_dtype; +using ::executorch::runtime::tensors_have_same_rank; +using 
::executorch::runtime::tensors_have_same_shape; +using ::executorch::runtime::tensors_have_same_shape_and_dtype; +using ::executorch::runtime::tensors_have_same_size_at_dims; +using ::executorch::runtime::tensors_have_same_strides; +namespace internal { +using ::executorch::runtime::internal::copy_tensor_data; +using ::executorch::runtime::internal::reset_data_ptr; +using ::executorch::runtime::internal::resize_tensor_impl; +using ::executorch::runtime::internal::set_tensor_data; +using ::executorch::runtime::internal::share_tensor_data; +} // namespace internal +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/freeable_buffer.h b/third-party/include/executorch/runtime/core/freeable_buffer.h new file mode 100644 index 00000000..09c5efca --- /dev/null +++ b/third-party/include/executorch/runtime/core/freeable_buffer.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace runtime { + +/** + * A read-only buffer than can be freed. + */ +class FreeableBuffer final { +public: + // Callback signature for the function that does the freeing. + using FreeFn = void (*)(void *context, void *data, size_t size); + + /** + * Creates an empty FreeableBuffer with size zero and a null data pointer. + */ + FreeableBuffer() + : free_fn_(nullptr), free_fn_context_(nullptr), data_(nullptr), size_(0) { + } + + /** + * Creates a FreeableBuffer with an optional free function. + * + * @param[in] data The data of the segment. + * @param[in] size The size of the segment data, in bytes. + * @param[in] free_fn Optional function to free the data. Guaranteed to be + * called exactly once before the FreeableBuffer is destroyed. May be + * nullptr. NOTE: This function must be thread-safe. If it modifies common + * state, the function must do its own locking. + * @param[in] free_fn_context Opaque pointer to pass as the `context` + * parameter of `free_fn`. May be nullptr. + */ + FreeableBuffer(const void *data, size_t size, FreeFn free_fn, + void *free_fn_context = nullptr) + : free_fn_(free_fn), free_fn_context_(free_fn_context), data_(data), + size_(size) {} + + /** + * Move ctor. Takes the ownership of the data previously owned by `rhs`, + * leaving `rhs` pointing to nullptr. + */ + FreeableBuffer(FreeableBuffer &&rhs) noexcept + : free_fn_(rhs.free_fn_), free_fn_context_(rhs.free_fn_context_), + data_(rhs.data_), size_(rhs.size_) { + rhs.free_fn_ = nullptr; + rhs.free_fn_context_ = nullptr; + rhs.data_ = nullptr; + rhs.size_ = 0; + } + + ~FreeableBuffer() { Free(); } + + /** + * Frees the data if not already free. Safe to call multiple times. + */ + void Free() { + if (data_ != nullptr) { + if (free_fn_ != nullptr) { + free_fn_(free_fn_context_, const_cast(data_), size_); + } + data_ = nullptr; + size_ = 0; + } + } + + /** + * Size of the data in bytes. Returns 0 if the data has been freed. + */ + size_t size() const { return size_; } + + /** + * Pointer to the data. Returns nullptr if the data has been freed. + */ + const void *data() const { return data_; } + +private: + // Delete other rule-of-five methods. 
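An illustrative sketch of how the FreeableBuffer above is meant to be used: wrapping a heap allocation so the free callback runs exactly once. The heap-backed helper is only an example, assuming third-party/include is on the include path.

#include <cstddef>
#include <cstdlib>
#include <executorch/runtime/core/freeable_buffer.h>

using executorch::runtime::FreeableBuffer;

// Free callback matching FreeableBuffer::FreeFn; the context pointer is unused here.
static void free_heap_block(void* /*context*/, void* data, size_t /*size*/) {
  std::free(data);
}

FreeableBuffer make_heap_buffer(size_t size) {
  void* data = std::malloc(size);
  // The callback is invoked exactly once, either by Free() or by the destructor.
  return FreeableBuffer(data, size, free_heap_block, /*free_fn_context=*/nullptr);
}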
+ FreeableBuffer(const FreeableBuffer &rhs) = delete; + FreeableBuffer &operator=(FreeableBuffer &&rhs) noexcept = delete; + FreeableBuffer &operator=(const FreeableBuffer &rhs) = delete; + + FreeFn free_fn_; + void *free_fn_context_; + const void *data_; + size_t size_; +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::FreeableBuffer; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/hierarchical_allocator.h b/third-party/include/executorch/runtime/core/hierarchical_allocator.h new file mode 100644 index 00000000..d894c080 --- /dev/null +++ b/third-party/include/executorch/runtime/core/hierarchical_allocator.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace runtime { + +/** + * A group of buffers that can be used to represent a device's memory hierarchy. + */ +class HierarchicalAllocator final { +public: + /** + * Constructs a new hierarchical allocator with the given array of buffers. + * + * - Memory IDs are based on the index into `buffers`: `buffers[N]` will have + * a memory ID of `N`. + * - `buffers.size()` must be >= `MethodMeta::num_non_const_buffers()`. + * - `buffers[N].size()` must be >= `MethodMeta::non_const_buffer_size(N)`. + */ + explicit HierarchicalAllocator(Span> buffers) + : buffers_(buffers) {} + + /** + * DEPRECATED: Use spans instead. + */ + ET_DEPRECATED HierarchicalAllocator(uint32_t n_allocators, + MemoryAllocator *allocators) + : buffers_(to_spans(n_allocators, allocators)) {} + + /** + * Returns the address at the byte offset `offset_bytes` from the given + * buffer's base address, which points to at least `size_bytes` of memory. + * + * @param[in] memory_id The ID of the buffer in the hierarchy. + * @param[in] offset_bytes The offset in bytes into the specified buffer. + * @param[in] size_bytes The amount of memory that should be available at + * the offset. + * + * @returns On success, the address of the requested byte offset into the + * specified buffer. On failure, a non-Ok Error. + */ + ET_NODISCARD Result get_offset_address(uint32_t memory_id, + size_t offset_bytes, + size_t size_bytes) { + ET_CHECK_OR_RETURN_ERROR(memory_id < buffers_.size(), InvalidArgument, + "id %" PRIu32 " >= %zu", memory_id, + buffers_.size()); + Span buffer = buffers_[memory_id]; + ET_CHECK_OR_RETURN_ERROR( + offset_bytes + size_bytes <= buffer.size(), MemoryAllocationFailed, + "offset_bytes (%zu) + size_bytes (%zu) >= allocator size (%zu) " + "for memory_id %" PRIu32, + offset_bytes, size_bytes, buffer.size(), memory_id); + return buffer.data() + offset_bytes; + } + +private: + // TODO(T162089316): Remove the span array and to_spans once all users move to + // spans. This array is necessary to hold the pointers and sizes that were + // originally provided as MemoryAllocator instances. + static constexpr size_t kSpanArraySize = 16; + // NOTE: span_array_ must be declared before buffers_ so that it isn't + // re-initialized to zeros after initializing buffers_. 
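A sketch of constructing the HierarchicalAllocator above from two statically planned buffers. Span comes in through the same header; the ok()/get() accessors on the returned Result are assumptions taken from the broader runtime API rather than from this hunk.

#include <cstdint>
#include <executorch/runtime/core/hierarchical_allocator.h>

using executorch::runtime::HierarchicalAllocator;
using executorch::runtime::Span;

static uint8_t buffer0[1024];  // memory ID 0
static uint8_t buffer1[4096];  // memory ID 1

void* lookup_example() {
  Span<uint8_t> buffers[] = {
      Span<uint8_t>(buffer0, sizeof(buffer0)),
      Span<uint8_t>(buffer1, sizeof(buffer1)),
  };
  HierarchicalAllocator planned_memory(Span<Span<uint8_t>>(buffers, 2));

  // 256 bytes at offset 128 inside memory ID 1; out-of-range requests yield a
  // non-Ok Error instead of a pointer.
  auto address = planned_memory.get_offset_address(
      /*memory_id=*/1, /*offset_bytes=*/128, /*size_bytes=*/256);
  return address.ok() ? address.get() : nullptr;
}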
+ Span span_array_[kSpanArraySize]; + Span> to_spans(uint32_t n_allocators, + MemoryAllocator *allocators) { + ET_CHECK_MSG(n_allocators <= kSpanArraySize, + "n_allocators %" PRIu32 " > %zu", n_allocators, + kSpanArraySize); + for (uint32_t i = 0; i < n_allocators; ++i) { + span_array_[i] = + Span(allocators[i].base_address(), allocators[i].size()); + } + return {span_array_, n_allocators}; + } + + /// The underlying buffers. + Span> buffers_; +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::HierarchicalAllocator; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/memory_allocator.h b/third-party/include/executorch/runtime/core/memory_allocator.h new file mode 100644 index 00000000..80bdd150 --- /dev/null +++ b/third-party/include/executorch/runtime/core/memory_allocator.h @@ -0,0 +1,362 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch { +namespace runtime { + +/** + * A class that does simple allocation based on a size and returns the pointer + * to the memory address. It bookmarks a buffer with certain size. The + * allocation is simply checking space and growing the cur_ pointer with each + * allocation request. + * + * Simple example: + * + * // User allocates a 100 byte long memory in the heap. + * uint8_t* memory_pool = malloc(100 * sizeof(uint8_t)); + * MemoryAllocator allocator(100, memory_pool) + * // Pass allocator object in the Executor + * + * Underneath the hood, ExecuTorch will call + * allocator.allocate() to keep iterating cur_ pointer + */ +class MemoryAllocator { +public: + /** + * Default alignment of memory returned by this class. Ensures that pointer + * fields of structs will be aligned. Larger types like `long double` may not + * be, however, depending on the toolchain and architecture. + */ + static constexpr size_t kDefaultAlignment = alignof(void *); + + /** + * Constructs a new memory allocator of a given `size`, starting at the + * provided `base_address`. + * + * @param[in] size The size in bytes of the buffer at `base_address`. + * @param[in] base_address The buffer to allocate from. Does not take + * ownership of this buffer, so it must be valid for the lifetime of of + * the MemoryAllocator. + */ + MemoryAllocator(uint32_t size, uint8_t *base_address) + : begin_(base_address), end_(base_address + size), cur_(base_address), + size_(size) {} + + /** + * Allocates `size` bytes of memory. + * + * @param[in] size Number of bytes to allocate. + * @param[in] alignment Minimum alignment for the returned pointer. Must be a + * power of 2. + * + * @returns Aligned pointer to the allocated memory on success. + * @retval nullptr Not enough memory, or `alignment` was not a power of 2. + */ + virtual void *allocate(size_t size, size_t alignment = kDefaultAlignment) { + if (!isPowerOf2(alignment)) { + ET_LOG(Error, "Alignment %zu is not a power of 2", alignment); + return nullptr; + } + + // The allocation will occupy [start, end), where the start is the next + // position that's a multiple of alignment. 
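Expanding the "Simple example" from the class comment above into a compilable sketch: a fixed pool is handed to MemoryAllocator, and each allocate() call simply advances the internal cursor until the pool runs out.

#include <cstdint>
#include <executorch/runtime/core/memory_allocator.h>

using executorch::runtime::MemoryAllocator;

void bump_pointer_example() {
  static uint8_t pool[100];
  MemoryAllocator allocator(sizeof(pool), pool);

  void* a = allocator.allocate(16);   // first 16 bytes of the pool
  void* b = allocator.allocate(32);   // next 32 bytes (subject to alignment)
  void* c = allocator.allocate(128);  // nullptr: more than the pool holds
  (void)a;
  (void)b;
  (void)c;
}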
+ uint8_t *start = alignPointer(cur_, alignment); + uint8_t *end = start + size; + + // If the end of this allocation exceeds the end of this allocator, print + // error messages and return nullptr + if (end > end_) { + ET_LOG(Error, + "Memory allocation failed: %zuB requested (adjusted for " + "alignment), %zuB available", + static_cast(end - cur_), static_cast(end_ - cur_)); + return nullptr; + } + + // Otherwise, record how many bytes were used, advance cur_ to the new end, + // and then return start. Note that the number of bytes used is (end - cur_) + // instead of (end - start) because start > cur_ if there is a misalignment + EXECUTORCH_TRACK_ALLOCATION(prof_id_, end - cur_); + cur_ = end; + return static_cast(start); + } + + /** + * Allocates a buffer large enough for an instance of type T. Note that the + * memory will not be initialized. + * + * Example: + * @code + * auto p = memory_allocator->allocateInstance(); + * @endcode + * + * @param[in] alignment Minimum alignment for the returned pointer. Must be a + * power of 2. Defaults to the natural alignment of T. + * + * @returns Aligned pointer to the allocated memory on success. + * @retval nullptr Not enough memory, or `alignment` was not a power of 2. + */ + template T *allocateInstance(size_t alignment = alignof(T)) { + return static_cast(this->allocate(sizeof(T), alignment)); + } + + /** + * Allocates `size` number of chunks of type T, where each chunk is of size + * equal to sizeof(T) bytes. + * + * @param[in] size Number of memory chunks to allocate. + * @param[in] alignment Minimum alignment for the returned pointer. Must be a + * power of 2. Defaults to the natural alignment of T. + * + * @returns Aligned pointer to the allocated memory on success. + * @retval nullptr Not enough memory, or `alignment` was not a power of 2. + */ + template + T *allocateList(size_t size, size_t alignment = alignof(T)) { + // Some users of this method allocate lists of pointers, causing the next + // line to expand to `sizeof(type *)`, which triggers a clang-tidy warning. + // NOLINTNEXTLINE(bugprone-sizeof-expression) + return static_cast(this->allocate(size * sizeof(T), alignment)); + } + + // Returns the allocator memory's base address. + virtual uint8_t *base_address() const { return begin_; } + + // Returns the total size of the allocator's memory buffer. + virtual uint32_t size() const { return size_; } + + // Resets the current pointer to the base address. It does nothing to + // the contents. + virtual void reset() { cur_ = begin_; } + + void enable_profiling(ET_UNUSED const char *name) { + prof_id_ = EXECUTORCH_TRACK_ALLOCATOR(name); + } + + virtual ~MemoryAllocator() {} + +protected: + /** + * Returns the profiler ID for this allocator. + */ + int32_t prof_id() const { return prof_id_; } + + /** + * Returns true if the value is an integer power of 2. + */ + static bool isPowerOf2(size_t value) { + return value > 0 && (value & ~(value - 1)) == value; + } + + /** + * Returns the next alignment for a given pointer. + */ + static uint8_t *alignPointer(void *ptr, size_t alignment) { + intptr_t addr = reinterpret_cast(ptr); + if ((addr & (alignment - 1)) == 0) { + // Already aligned. + return reinterpret_cast(ptr); + } + addr = (addr | (alignment - 1)) + 1; + return reinterpret_cast(addr); + } + +private: + uint8_t *const begin_; + uint8_t *const end_; + uint8_t *cur_; + uint32_t const size_; + int32_t prof_id_ = -1; +}; + +#if ET_HAVE_GNU_STATEMENT_EXPRESSIONS +/** + * Tries allocating from the specified MemoryAllocator*. 
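A sketch of the typed helpers declared above. As documented, neither allocateInstance() nor allocateList() zero-initializes the returned memory, and both return nullptr once the pool is exhausted.

#include <cstdint>
#include <executorch/runtime/core/memory_allocator.h>

using executorch::runtime::MemoryAllocator;

struct Node {
  Node* next;
  int32_t value;
};

void typed_allocation_example() {
  static uint8_t pool[256];
  MemoryAllocator allocator(sizeof(pool), pool);

  // One Node, aligned to alignof(Node) by default; contents are uninitialized.
  Node* node = allocator.allocateInstance<Node>();

  // Eight contiguous int32_t values (32 bytes), also uninitialized.
  int32_t* values = allocator.allocateList<int32_t>(8);

  if (node == nullptr || values == nullptr) {
    return;  // pool exhausted, or an alignment argument was not a power of 2
  }
  node->next = nullptr;
  node->value = 0;
}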
+ * + * - On success, returns a pointer to the allocated buffer. + * - On failure, executes the provided code block, which must return or panic. + * + * Example: + * @code + * char* buf = ET_TRY_ALLOCATE_OR( + * memory_allocator, bufsize, { + * *out_err = Error::MemoryAllocationFailed; + * return nullopt; + * }); + * @endcode + */ +#define ET_TRY_ALLOCATE_OR(memory_allocator__, nbytes__, ...) \ + ({ \ + void *et_try_allocate_result = memory_allocator__->allocate(nbytes__); \ + if (et_try_allocate_result == nullptr && nbytes__ > 0) { \ + __VA_ARGS__ \ + /* The args must return. */ \ + ET_UNREACHABLE(); \ + } \ + et_try_allocate_result; \ + }) + +/** + * Tries allocating an instance of type__ from the specified MemoryAllocator*. + * + * - On success, returns a pointer to the allocated buffer. Note that the memory + * will not be initialized. + * - On failure, executes the provided code block, which must return or panic. + * + * Example: + * @code + * char* buf = ET_TRY_ALLOCATE_INSTANCE_OR( + * memory_allocator, + * MyType, + * { *out_err = Error::MemoryAllocationFailed; return nullopt; }); + * @endcode + */ +#define ET_TRY_ALLOCATE_INSTANCE_OR(memory_allocator__, type__, ...) \ + ({ \ + type__ *et_try_allocate_result = \ + memory_allocator__->allocateInstance(); \ + if (et_try_allocate_result == nullptr) { \ + __VA_ARGS__ \ + /* The args must return. */ \ + ET_UNREACHABLE(); \ + } \ + et_try_allocate_result; \ + }) + +/** + * Tries allocating multiple elements of a given type from the specified + * MemoryAllocator*. + * + * - On success, returns a pointer to the allocated buffer. + * - On failure, executes the provided code block, which must return or panic. + * + * Example: + * @code + * Tensor* tensor_list = ET_TRY_ALLOCATE_LIST_OR( + * memory_allocator, Tensor, num_tensors, { + * *out_err = Error::MemoryAllocationFailed; + * return nullopt; + * }); + * @endcode + */ +#define ET_TRY_ALLOCATE_LIST_OR(memory_allocator__, type__, nelem__, ...) \ + ({ \ + type__ *et_try_allocate_result = \ + memory_allocator__->allocateList(nelem__); \ + if (et_try_allocate_result == nullptr && nelem__ > 0) { \ + __VA_ARGS__ \ + /* The args must return. */ \ + ET_UNREACHABLE(); \ + } \ + et_try_allocate_result; \ + }) +#else // !ET_HAVE_GNU_STATEMENT_EXPRESSIONS +/** + * The recommended alternative for statement expression-incompatible compilers + * is to directly allocate the memory. + * e.g. memory_allocator__->allocate(nbytes__); + */ +#define ET_TRY_ALLOCATE_OR(memory_allocator__, nbytes__, ...) \ + static_assert(false, "ET_TRY_ALLOCATE_OR uses statement expressions and \ + thus is not available for use with this compiler."); + +/** + * The recommended alternative for statement expression-incompatible compilers + * is to directly allocate the memory. + * e.g. memory_allocator__->allocateInstance(); + */ +#define ET_TRY_ALLOCATE_INSTANCE_OR(memory_allocator__, type__, ...) \ + static_assert(false, "ET_TRY_ALLOCATE_INSTANCE_OR uses statement \ + expressions and thus is not available for use with this compiler."); + +/** + * The recommended alternative for statement expression-incompatible compilers + * is to directly use allocate the memory. + * e.g. memory_allocator__->allocateList(nelem__); + */ +#define ET_TRY_ALLOCATE_LIST_OR(memory_allocator__, type__, nelem__, ...) 
\ + static_assert(false, "ET_TRY_ALLOCATE_LIST_OR uses statement \ + expressions and thus is not available for use with this compiler."); +#endif // !ET_HAVE_GNU_STATEMENT_EXPRESSIONS + +/** + * Tries allocating from the specified MemoryAllocator*. + * + * - On success, returns a pointer to the allocated buffer. + * - On failure, returns `Error::MemoryAllocationFailed` from the calling + * function, which must be declared to return `executorch::runtime::Error`. + * + * Example: + * @code + * char* buf = ET_ALLOCATE_OR_RETURN_ERROR(memory_allocator, bufsize); + * @endcode + */ +#define ET_ALLOCATE_OR_RETURN_ERROR(memory_allocator__, nbytes__) \ + ET_TRY_ALLOCATE_OR(memory_allocator__, nbytes__, { \ + return ::executorch::runtime::Error::MemoryAllocationFailed; \ + }) + +/** + * Tries allocating an instance of type__ from the specified MemoryAllocator*. + * + * - On success, returns a pointer to the allocated buffer. Note that the memory + * will not be initialized. + * - On failure, returns `Error::MemoryAllocationFailed` from the calling + * function, which must be declared to return `executorch::runtime::Error`. + * + * Example: + * @code + * char* buf = ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(memory_allocator, MyType); + * @endcode + */ +#define ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(memory_allocator__, type__) \ + ET_TRY_ALLOCATE_INSTANCE_OR(memory_allocator__, type__, { \ + return ::executorch::runtime::Error::MemoryAllocationFailed; \ + }) + +/** + * Tries allocating multiple elements of a given type from the specified + * MemoryAllocator*. + * + * - On success, returns a pointer to the allocated buffer. + * - On failure, returns `Error::MemoryAllocationFailed` from the calling + * function, which must be declared to return `executorch::runtime::Error`. + * + * Example: + * @code + * Tensor* tensor_list = ET_ALLOCATE_LIST_OR_RETURN_ERROR( + * memory_allocator, Tensor, num_tensors); + * @endcode + */ +#define ET_ALLOCATE_LIST_OR_RETURN_ERROR(memory_allocator__, type__, nelem__) \ + ET_TRY_ALLOCATE_LIST_OR(memory_allocator__, type__, nelem__, { \ + return ::executorch::runtime::Error::MemoryAllocationFailed; \ + }) + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::MemoryAllocator; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/portable_type/bfloat16.h b/third-party/include/executorch/runtime/core/portable_type/bfloat16.h new file mode 100644 index 00000000..2ae42a8e --- /dev/null +++ b/third-party/include/executorch/runtime/core/portable_type/bfloat16.h @@ -0,0 +1,336 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
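A sketch of the ..._OR_RETURN_ERROR convenience macros defined above. They rely on GNU statement expressions, so they are only available on compilers that support them, and the enclosing function must return executorch::runtime::Error. The explicit include of error.h is an assumption; Error may already arrive transitively through memory_allocator.h.

#include <cstddef>
#include <executorch/runtime/core/error.h>            // assumed location of Error
#include <executorch/runtime/core/memory_allocator.h>

using executorch::runtime::Error;
using executorch::runtime::MemoryAllocator;

// On allocation failure the macro returns Error::MemoryAllocationFailed from
// this function, which is why the return type must be Error.
Error zero_scratch(MemoryAllocator* allocator, size_t count) {
  float* scratch = ET_ALLOCATE_LIST_OR_RETURN_ERROR(allocator, float, count);
  for (size_t i = 0; i < count; ++i) {
    scratch[i] = 0.0f;
  }
  return Error::Ok;
}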
+ */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace executorch { +namespace runtime { +namespace etensor { + +namespace internal { +inline float f32_from_bits(uint16_t src) { + float res = 0; + uint32_t tmp = src; + tmp <<= 16; + std::memcpy(&res, &tmp, sizeof(tmp)); + return res; +} + +inline uint16_t round_to_nearest_even(float src) { + if (std::isnan(src)) { + return UINT16_C(0x7FC0); + } + uint32_t U32 = 0; + std::memcpy(&U32, &src, sizeof(U32)); + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((U32 + rounding_bias) >> 16); +} +} // namespace internal + +/** + * The "brain floating-point" type, compatible with c10/util/BFloat16.h from + * pytorch core. + * + * This representation uses 1 bit for the sign, 8 bits for the exponent and 7 + * bits for the mantissa. + */ +struct alignas(2) BFloat16 { + uint16_t x; + + BFloat16() = default; + struct from_bits_t {}; + static constexpr from_bits_t from_bits() { return from_bits_t(); } + + constexpr BFloat16(unsigned short bits, from_bits_t) : x(bits) {} + /* implicit */ BFloat16(float value) + : x(internal::round_to_nearest_even(value)) {} + operator float() const { return internal::f32_from_bits(x); } +}; + +inline std::ostream &operator<<(std::ostream &out, const BFloat16 &value) { + out << (float)value; + return out; +} + +/// Arithmetic + +inline BFloat16 operator+(const BFloat16 &a, const BFloat16 &b) { + return static_cast(a) + static_cast(b); +} + +inline BFloat16 operator-(const BFloat16 &a, const BFloat16 &b) { + return static_cast(a) - static_cast(b); +} + +inline BFloat16 operator*(const BFloat16 &a, const BFloat16 &b) { + return static_cast(a) * static_cast(b); +} + +inline BFloat16 operator/(const BFloat16 &a, const BFloat16 &b) { + return static_cast(a) / static_cast(b); +} + +inline BFloat16 operator-(const BFloat16 &a) { return -static_cast(a); } + +inline BFloat16 &operator+=(BFloat16 &a, const BFloat16 &b) { + a = a + b; + return a; +} + +inline BFloat16 &operator-=(BFloat16 &a, const BFloat16 &b) { + a = a - b; + return a; +} + +inline BFloat16 &operator*=(BFloat16 &a, const BFloat16 &b) { + a = a * b; + return a; +} + +inline BFloat16 &operator/=(BFloat16 &a, const BFloat16 &b) { + a = a / b; + return a; +} + +inline BFloat16 &operator|(BFloat16 &a, const BFloat16 &b) { + a.x = a.x | b.x; + return a; +} + +inline BFloat16 &operator^(BFloat16 &a, const BFloat16 &b) { + a.x = a.x ^ b.x; + return a; +} + +inline BFloat16 &operator&(BFloat16 &a, const BFloat16 &b) { + a.x = a.x & b.x; + return a; +} + +/// Arithmetic with floats + +inline float operator+(BFloat16 a, float b) { + return static_cast(a) + b; +} +inline float operator-(BFloat16 a, float b) { + return static_cast(a) - b; +} +inline float operator*(BFloat16 a, float b) { + return static_cast(a) * b; +} +inline float operator/(BFloat16 a, float b) { + return static_cast(a) / b; +} + +inline float operator+(float a, BFloat16 b) { + return a + static_cast(b); +} +inline float operator-(float a, BFloat16 b) { + return a - static_cast(b); +} +inline float operator*(float a, BFloat16 b) { + return a * static_cast(b); +} +inline float operator/(float a, BFloat16 b) { + return a / static_cast(b); +} + +inline float &operator+=(float &a, const BFloat16 &b) { + return a += static_cast(b); +} +inline float &operator-=(float &a, const BFloat16 &b) { + return a -= static_cast(b); +} +inline float &operator*=(float &a, const BFloat16 &b) { + return a *= static_cast(b); +} +inline float &operator/=(float &a, 
const BFloat16 &b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline double operator+(BFloat16 a, double b) { + return static_cast(a) + b; +} +inline double operator-(BFloat16 a, double b) { + return static_cast(a) - b; +} +inline double operator*(BFloat16 a, double b) { + return static_cast(a) * b; +} +inline double operator/(BFloat16 a, double b) { + return static_cast(a) / b; +} + +inline double operator+(double a, BFloat16 b) { + return a + static_cast(b); +} +inline double operator-(double a, BFloat16 b) { + return a - static_cast(b); +} +inline double operator*(double a, BFloat16 b) { + return a * static_cast(b); +} +inline double operator/(double a, BFloat16 b) { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline BFloat16 operator+(BFloat16 a, int b) { + return a + static_cast(b); +} +inline BFloat16 operator-(BFloat16 a, int b) { + return a - static_cast(b); +} +inline BFloat16 operator*(BFloat16 a, int b) { + return a * static_cast(b); +} +inline BFloat16 operator/(BFloat16 a, int b) { + return a / static_cast(b); +} + +inline BFloat16 operator+(int a, BFloat16 b) { + return static_cast(a) + b; +} +inline BFloat16 operator-(int a, BFloat16 b) { + return static_cast(a) - b; +} +inline BFloat16 operator*(int a, BFloat16 b) { + return static_cast(a) * b; +} +inline BFloat16 operator/(int a, BFloat16 b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline BFloat16 operator+(BFloat16 a, int64_t b) { + return a + static_cast(b); +} +inline BFloat16 operator-(BFloat16 a, int64_t b) { + return a - static_cast(b); +} +inline BFloat16 operator*(BFloat16 a, int64_t b) { + return a * static_cast(b); +} +inline BFloat16 operator/(BFloat16 a, int64_t b) { + return a / static_cast(b); +} + +inline BFloat16 operator+(int64_t a, BFloat16 b) { + return static_cast(a) + b; +} +inline BFloat16 operator-(int64_t a, BFloat16 b) { + return static_cast(a) - b; +} +inline BFloat16 operator*(int64_t a, BFloat16 b) { + return static_cast(a) * b; +} +inline BFloat16 operator/(int64_t a, BFloat16 b) { + return static_cast(a) / b; +} + +// Overloading < and > operators, because std::max and std::min use them. + +inline bool operator>(BFloat16 &lhs, BFloat16 &rhs) { + return float(lhs) > float(rhs); +} + +inline bool operator<(BFloat16 &lhs, BFloat16 &rhs) { + return float(lhs) < float(rhs); +} + +} // namespace etensor +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. 
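A small sketch of the BFloat16 type above: construction and arithmetic round-trip through float, so every result is re-rounded to the 8-bit-exponent, 7-bit-mantissa format.

#include <executorch/runtime/core/portable_type/bfloat16.h>

using executorch::runtime::etensor::BFloat16;

void bfloat16_example() {
  BFloat16 one = 1.0f;
  BFloat16 step = 0.0078125f;  // 2^-7, the spacing between bfloat16 values near 1.0
  BFloat16 sum = one + step;   // computed in float, then rounded back to bfloat16

  float back = sum;            // implicit conversion to float
  // Increments much smaller than 2^-7 relative to 1.0 would round away entirely.
  (void)back;
}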
+using ::executorch::runtime::etensor::BFloat16; +} // namespace executor +} // namespace torch + +namespace std { + +template <> class numeric_limits { +public: + static constexpr bool is_signed = true; + static constexpr bool is_specialized = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 8; + static constexpr int digits10 = 2; + static constexpr int max_digits10 = 4; + static constexpr int radix = 2; + static constexpr int min_exponent = -125; + static constexpr int min_exponent10 = -37; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr torch::executor::BFloat16 min() { + return torch::executor::BFloat16(0x0080, + torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 lowest() { + return torch::executor::BFloat16(0xFF7F, + torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 max() { + return torch::executor::BFloat16(0x7F7F, + torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 epsilon() { + return torch::executor::BFloat16(0x3C00, + torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 round_error() { + return torch::executor::BFloat16(0x3F00, + torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 infinity() { + return torch::executor::BFloat16(0x7F80, + torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 quiet_NaN() { + return torch::executor::BFloat16(0x7FC0, + torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 signaling_NaN() { + return torch::executor::BFloat16(0x7F80, + torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 denorm_min() { + return torch::executor::BFloat16(0x0001, + torch::executor::BFloat16::from_bits()); + } +}; + +} // namespace std diff --git a/third-party/include/executorch/runtime/core/portable_type/bfloat16_math.h b/third-party/include/executorch/runtime/core/portable_type/bfloat16_math.h new file mode 100644 index 00000000..3f6fd400 --- /dev/null +++ b/third-party/include/executorch/runtime/core/portable_type/bfloat16_math.h @@ -0,0 +1,257 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
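The numeric_limits specialization above makes the standard limits queries work for BFloat16; a brief sketch of the values it encodes.

#include <limits>
#include <executorch/runtime/core/portable_type/bfloat16.h>

using executorch::runtime::etensor::BFloat16;

void bfloat16_limits_example() {
  constexpr BFloat16 eps = std::numeric_limits<BFloat16>::epsilon();  // 2^-7
  constexpr BFloat16 hi = std::numeric_limits<BFloat16>::max();       // about 3.39e38
  float eps_f = eps;  // 0.0078125f
  float hi_f = hi;
  (void)eps_f;
  (void)hi_f;
}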
+ */ + +#pragma once + +#include +#include + +namespace std { + +template +struct is_reduced_floating_point + : std::integral_constant< + bool, std::is_same::value || + std::is_same::value> {}; + +template ::value, int>::type = 0> +inline T acos(T a) { + return std::acos(float(a)); +} +template ::value, int>::type = 0> +inline T asin(T a) { + return std::asin(float(a)); +} +template ::value, int>::type = 0> +inline T atan(T a) { + return std::atan(float(a)); +} +template ::value, int>::type = 0> +inline T atanh(T a) { + return std::atanh(float(a)); +} +template ::value, int>::type = 0> +inline T erf(T a) { + return std::erf(float(a)); +} +template ::value, int>::type = 0> +inline T erfc(T a) { + return std::erfc(float(a)); +} +template ::value, int>::type = 0> +inline T exp(T a) { + return std::exp(float(a)); +} +template ::value, int>::type = 0> +inline T expm1(T a) { + return std::expm1(float(a)); +} +template ::value, int>::type = 0> +inline bool isfinite(T a) { + return std::isfinite(float(a)); +} +template ::value, int>::type = 0> +inline T log(T a) { + return std::log(float(a)); +} +template ::value, int>::type = 0> +inline T log10(T a) { + return std::log10(float(a)); +} +template ::value, int>::type = 0> +inline T log1p(T a) { + return std::log1p(float(a)); +} +template ::value, int>::type = 0> +inline T log2(T a) { + return std::log2(float(a)); +} +template ::value, int>::type = 0> +inline T ceil(T a) { + return std::ceil(float(a)); +} +template ::value, int>::type = 0> +inline T cos(T a) { + return std::cos(float(a)); +} +template ::value, int>::type = 0> +inline T floor(T a) { + return std::floor(float(a)); +} +template ::value, int>::type = 0> +inline T nearbyint(T a) { + return std::nearbyint(float(a)); +} +template ::value, int>::type = 0> +inline T sin(T a) { + return std::sin(float(a)); +} +template ::value, int>::type = 0> +inline T tan(T a) { + return std::tan(float(a)); +} +template ::value, int>::type = 0> +inline T sinh(T a) { + return std::sinh(float(a)); +} +template ::value, int>::type = 0> +inline T cosh(T a) { + return std::cosh(float(a)); +} +template ::value, int>::type = 0> +inline T tanh(T a) { + return std::tanh(float(a)); +} +template ::value, int>::type = 0> +inline T trunc(T a) { + return std::trunc(float(a)); +} +template ::value, int>::type = 0> +inline T lgamma(T a) { + return std::lgamma(float(a)); +} +template ::value, int>::type = 0> +inline T sqrt(T a) { + return std::sqrt(float(a)); +} +template ::value, int>::type = 0> +inline T rsqrt(T a) { + return 1.0 / std::sqrt(float(a)); +} +template ::value, int>::type = 0> +inline T abs(T a) { + return std::abs(float(a)); +} +#if defined(_MSC_VER) && defined(__CUDACC__) +template ::value, int>::type = 0> +inline T pow(T a, double b) { + return std::pow(float(a), float(b)); +} +#else +template ::value, int>::type = 0> +inline T pow(T a, double b) { + return std::pow(float(a), b); +} +#endif +template ::value, int>::type = 0> +inline T pow(T a, T b) { + return std::pow(float(a), float(b)); +} +template ::value, int>::type = 0> +inline T fmod(T a, T b) { + return std::fmod(float(a), float(b)); +} + +/* + The following function is inspired from the implementation in `musl` + Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT + ---------------------------------------------------------------------- + Copyright © 2005-2020 Rich Felker, et al. 
+ + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ---------------------------------------------------------------------- + */ +template ::value, int>::type = 0> +inline T nextafter(T from, T to) { + // Reference: + // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c + using int_repr_t = uint16_t; + constexpr uint8_t bits = 16; + union { + T f; + int_repr_t i; + } ufrom = {from}, uto = {to}; + + // get a mask to get the sign bit i.e. MSB + int_repr_t sign_mask = int_repr_t{1} << (bits - 1); + + // short-circuit: if either is NaN, return NaN + if (from != from || to != to) { + return from + to; + } + + // short-circuit: if they are exactly the same. + if (ufrom.i == uto.i) { + return from; + } + + // mask the sign-bit to zero i.e. positive + // equivalent to abs(x) + int_repr_t abs_from = ufrom.i & ~sign_mask; + int_repr_t abs_to = uto.i & ~sign_mask; + if (abs_from == 0) { + // if both are zero but with different sign, + // preserve the sign of `to`. + if (abs_to == 0) { + return to; + } + // smallest subnormal with sign of `to`. + ufrom.i = (uto.i & sign_mask) | int_repr_t{1}; + return ufrom.f; + } + + // if abs(from) > abs(to) or sign(from) != sign(to) + if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) { + ufrom.i--; + } else { + ufrom.i++; + } + + return ufrom.f; +} + +} // namespace std diff --git a/third-party/include/executorch/runtime/core/portable_type/bits_types.h b/third-party/include/executorch/runtime/core/portable_type/bits_types.h new file mode 100644 index 00000000..cddffc48 --- /dev/null +++ b/third-party/include/executorch/runtime/core/portable_type/bits_types.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include + +namespace executorch { +namespace runtime { +namespace etensor { + +/** + * bits1x8 is an uninterpreted dtype of a tensor with 1 bit (packed to byte + * boundary), without any semantics defined. + */ +struct alignas(1) bits1x8 { + using underlying = uint8_t; + uint8_t val_; + bits1x8() = default; + explicit bits1x8(uint8_t val) : val_(val) {} +}; + +/** + * bits2x4 is an uninterpreted dtype of a tensor with 2 bits (packed to byte + * boundary), without any semantics defined. 
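A sketch of the std:: overloads above for reduced-precision types: each promotes its argument to float, calls the usual routine, and rounds the result back, while nextafter steps to the adjacent representable value of the reduced type.

#include <executorch/runtime/core/portable_type/bfloat16.h>
#include <executorch/runtime/core/portable_type/bfloat16_math.h>

using executorch::runtime::etensor::BFloat16;

void bfloat16_math_example() {
  BFloat16 x = 0.5f;
  BFloat16 root = std::sqrt(x);                     // computed as sqrt(0.5f), then rounded
  BFloat16 up = std::nextafter(x, BFloat16(2.0f));  // the next bfloat16 above 0.5
  (void)root;
  (void)up;
}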
+ */ +struct alignas(1) bits2x4 { + using underlying = uint8_t; + uint8_t val_; + bits2x4() = default; + explicit bits2x4(uint8_t val) : val_(val) {} +}; + +/** + * bits4x2 is an uninterpreted dtype of a tensor with 4 bits (packed to byte + * boundary), without any semantics defined. + */ +struct alignas(1) bits4x2 { + using underlying = uint8_t; + uint8_t val_; + bits4x2() = default; + explicit bits4x2(uint8_t val) : val_(val) {} +}; + +/** + * bits8 is an uninterpreted dtype of a tensor with 8 bits, without any + * semantics defined. + */ +struct alignas(1) bits8 { + uint8_t val_; + bits8() = default; + explicit bits8(uint8_t val) : val_(val) {} +}; + +/** + * bits16 is an uninterpreted dtype of a tensor with 16 bits, without any + * semantics defined. + */ +struct alignas(2) bits16 { + uint16_t val_; + bits16() = default; + explicit bits16(uint16_t val) : val_(val) {} +}; + +} // namespace etensor +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::etensor::bits16; +using ::executorch::runtime::etensor::bits1x8; +using ::executorch::runtime::etensor::bits2x4; +using ::executorch::runtime::etensor::bits4x2; +using ::executorch::runtime::etensor::bits8; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/portable_type/complex.h b/third-party/include/executorch/runtime/core/portable_type/complex.h new file mode 100644 index 00000000..56ed55d4 --- /dev/null +++ b/third-party/include/executorch/runtime/core/portable_type/complex.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace runtime { +namespace etensor { + +/** + * An implementation of complex numbers, compatible with c10/util/complex.h from + * pytorch core. + */ +template struct alignas(sizeof(T) * 2) complex { + T real_ = T(0); + T imag_ = T(0); +}; + +/** + * Specialization for Half, which is not a primitive C numeric type. + */ +template <> struct alignas(4) complex { + Half real_; + Half imag_; +}; + +} // namespace etensor +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::etensor::complex; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/portable_type/device.h b/third-party/include/executorch/runtime/core/portable_type/device.h new file mode 100644 index 00000000..7f4f447b --- /dev/null +++ b/third-party/include/executorch/runtime/core/portable_type/device.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace runtime { +namespace etensor { + +/// Denotes the specific genre of compute device. 
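A short sketch of the opaque bitsNxM carriers and the trivial complex struct defined above; none of them define arithmetic, they only fix the size and alignment of the underlying storage. Including complex.h on its own is assumed to pull in the Half definition it needs.

#include <executorch/runtime/core/portable_type/bits_types.h>
#include <executorch/runtime/core/portable_type/complex.h>

using executorch::runtime::etensor::bits8;
using executorch::runtime::etensor::complex;

void portable_type_example() {
  bits8 raw(0xAB);    // one uninterpreted byte; no operators are defined on it
  complex<float> z;   // two packed floats, aligned to 8 bytes
  z.real_ = 1.0f;
  z.imag_ = -2.0f;
  (void)raw;
  (void)z;
}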
+/// Subset of https://github.com/pytorch/pytorch/blob/main/c10/core/Device.h +enum class DeviceType : int8_t { + CPU = 0, +}; + +/// An index representing a specific device; For cpu it should always be -1 or 0 +using DeviceIndex = int8_t; + +/** + * An abstraction for the compute device on which a tensor is located. + * ExecuTorch doesn't allow dynamic dispatching based on device, so this type is + * just a skeleton to allow certain kernels that expect device as an + * argument to still be run. + * + * In ExecuTorch this is always expected to be CPU. + */ +struct Device final { + using Type = DeviceType; + + /// Constructs a new `Device` from a `DeviceType` and an optional device + /// index. + /* implicit */ Device(DeviceType type, DeviceIndex index = -1) + : type_(type), index_(index) {} + + /// Returns the type of device this is. Only CPU is supported. + DeviceType type() const noexcept { return type_; } + + /// Returns true if the device is of CPU type. + bool is_cpu() const noexcept { return type_ == DeviceType::CPU; } + + /// Returns the device index. Always 0 if specified or -1 if not provided. + DeviceIndex index() const noexcept { + ET_CHECK(index_ == 0 || index_ == -1); + return index_; + } + +private: + DeviceType type_; + DeviceIndex index_ = -1; +}; + +} // namespace etensor +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::etensor::Device; +using ::executorch::runtime::etensor::DeviceType; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/portable_type/half.h b/third-party/include/executorch/runtime/core/portable_type/half.h new file mode 100644 index 00000000..fd48fabe --- /dev/null +++ b/third-party/include/executorch/runtime/core/portable_type/half.h @@ -0,0 +1,652 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#if defined(__GNUC__) || defined(__clang__) +#if defined(__aarch64__) +#ifndef __ARM_V8_ONLY__ +#define NATIVE_FP16 1 +#endif // __ARM_V8_ONLY__ +#endif // __aarch64__ +#endif // GNUC or clang + +#if defined(__GNUC__) || defined(__clang__) +#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \ + defined(_M_IX86) +#if defined(__AVX2__) +#define X86_F16 1 +#include // import conversion ops from f16cintrin.h +#endif // __AVX2__ +#endif // __x86_64__ || _M_X64 || __i386 || _M_IX86 +#endif // __GNUC__ || __clang__ + +namespace executorch { +namespace runtime { +namespace etensor { + +/** + * A half-precision floating point type, compatible with c10/util/Half.h from + * pytorch core. 
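A sketch of the Device skeleton above. ExecuTorch only models CPU, so the type mostly exists so kernels that expect a device argument can still be called.

#include <executorch/runtime/core/portable_type/device.h>

using executorch::runtime::etensor::Device;
using executorch::runtime::etensor::DeviceType;

void device_example() {
  Device device(DeviceType::CPU);  // index defaults to -1
  bool on_cpu = device.is_cpu();   // always true for DeviceType::CPU
  (void)on_cpu;
}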
+ */ +struct alignas(2) Half { + union { +#ifdef NATIVE_FP16 + _Float16 y; +#endif + uint16_t x; + }; + + struct from_bits_t {}; + static constexpr from_bits_t from_bits() { return from_bits_t(); } + + Half() = default; + + constexpr Half(uint16_t bits, from_bits_t) : x(bits) {} + /* implicit */ inline Half(float value); + inline operator float() const; +}; + +namespace internal { + +inline float fp32_from_bits(uint32_t w) { + static_assert(sizeof(float) == sizeof(uint32_t)); + union { + uint32_t as_bits; + float as_value; + } fp32 = {w}; + return fp32.as_value; +} + +inline uint32_t fp32_to_bits(float f) { + static_assert(sizeof(float) == sizeof(uint32_t)); + union { + float as_value; + uint32_t as_bits; + } fp32 = {f}; + return fp32.as_bits; +} + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 + * of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 30 27-31 17-26 0-16 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the + * half-precision number normalized. If the initial number is normalized, some + * of its high 6 bits (sign == 0 and 5-bit exponent) equals one. In this case + * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note + * that if we shift denormalized nonsign by renorm_shift, the unit bit of + * mantissa will shift into exponent, turning the biased exponent into 1, and + * making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff half-precision number has exponent of 15, the addition overflows + * it into bit 31, and the subsequent shift turns the high 9 bits + * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number + * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise + */ + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 + * into 1. Otherwise, bit 31 remains 0. 
The signed shift right by 31 + * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input + * was denormal) + * 2. Shift nonsign right by 3 so the exponent (5 bits originally) + * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high + * bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the + * different in exponent bias (0x7F for single-precision number less 0xF + * for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to + * account for renormalization. As renorm_shift is less than 0x70, this + * can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the + * input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent + * into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + return sign | + ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); +} + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding + * mode and no operations on denormals) floating-point operations and bitcasts + * between integer and floating-point variables. + */ +inline float fp16_ieee_to_fp32_value(uint16_t h) { +#ifdef X86_F16 + return _cvtsh_ss(h); +#else + + /* + * Extend the half-precision floating-point number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits + * of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become + * mantissa and exponent of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias + * between single-precision and half-precision formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after + * conversion to the single-precision number. 
Therefore, if the biased + * exponent of the half-precision input was 0x1F (max possible value), the + * biased exponent of the single-precision output must be 0xFF (max possible + * value). We do this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset + * below) rather than by 0x70 suggested by the difference in the exponent bias + * (see above). + * - Then we multiply the single-precision result of exponent adjustment by + * 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the + * necessary exponent adjustment by 0x70 due to difference in exponent bias. + * The floating-point multiplication hardware would ensure than Inf and + * NaN would retain their value on at least partially IEEE754-compliant + * implementations. + * + * Note that the above operations do not handle denormal inputs (where biased + * exponent == 0). However, they also do not operate on denormal inputs, and + * do not produce denormal results. + */ + constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23; + // const float exp_scale = 0x1.0p-112f; + constexpr uint32_t scale_bits = (uint32_t)15 << 23; + float exp_scale_val = 0; + std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); + const float exp_scale = exp_scale_val; + const float normalized_value = + fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + /* + * Convert denormalized half-precision inputs into single-precision results + * (always normalized). Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has + * on-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the + * same mantissa and thehalf-precision input and with an exponent which would + * scale the corresponding mantissa bits to 2**(-24). A normalized + * single-precision floating-point number is represented as: FP32 = (1 + + * mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased + * exponent is 126, a unit change in the mantissa of the input denormalized + * half-precision number causes a change of the constructed single-precision + * number by 2**(-24), i.e. the same amount. + * + * The last step is to adjust the bias of the constructed single-precision + * number. When the input half-precision number is zero, the constructed + * single-precision number has the value of FP32 = 1 * 2**(126 - 127) = + * 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed + * single-precision number to get the numerical equivalent of the input + * half-precision number. + */ + constexpr uint32_t magic_mask = UINT32_C(126) << 23; + constexpr float magic_bias = 0.5f; + const float denormalized_value = + fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or + * as a denormalized number, depending on the input exponent. The variable + * two_w contains input exponent in bits 27-31, therefore if its smaller than + * 2**27, the input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign + * of the input number. 
+ */ + constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = + sign | (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) + : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); + +#endif // not X86_F16 +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 16-bit floating-point number in IEEE half-precision format, in bit + * representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding + * mode and no operations on denormals) floating-point operations and bitcasts + * between integer and floating-point variables. + */ +inline uint16_t fp16_ieee_from_fp32_value(float f) { +#ifdef X86_F16 + return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT); +#else + + // const float scale_to_inf = 0x1.0p+112f; + // const float scale_to_zero = 0x1.0p-110f; + constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23; + constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23; + float scale_to_inf_val = 0, scale_to_zero_val = 0; + std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); + std::memcpy(&scale_to_zero_val, &scale_to_zero_bits, + sizeof(scale_to_zero_val)); + const float scale_to_inf = scale_to_inf_val; + const float scale_to_zero = scale_to_zero_val; + +#if defined(_MSC_VER) && _MSC_VER == 1916 + float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero; +#else + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; +#endif + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return static_cast( + (sign >> 16) | + (shl1_w > UINT32_C(0xFF000000) ? 
UINT16_C(0x7E00) : nonsign)); +#endif // not X86_F16 +} + +} // namespace internal + +/// Constructors +#ifdef NATIVE_FP16 +inline Half::Half(float value) : y(value) {} +#else +inline Half::Half(float value) + : x(internal::fp16_ieee_from_fp32_value(value)) {} +#endif + +/// Implicit conversions +#ifdef NATIVE_FP16 +inline Half::operator float() const { return (float)y; } +#else +inline Half::operator float() const { + return internal::fp16_ieee_to_fp32_value(x); +} +#endif + +/// Arithmetic + +#ifdef NATIVE_FP16 + +#define return_half(r) \ + do { \ + Half ret; \ + ret.y = r; \ + return ret; \ + } while (0) + +inline Half operator+(const Half &a, const Half &b) { return_half(a.y + b.y); } + +inline Half operator-(const Half &a, const Half &b) { + return_half(a.y - b.y); + return static_cast(a) - static_cast(b); +} + +inline Half operator*(const Half &a, const Half &b) { return_half(a.y * b.y); } + +inline Half operator/(const Half &a, const Half &b) { return_half(a.y / b.y); } + +inline Half operator-(const Half &a) { return_half(-a.y); } + +inline Half &operator+=(Half &a, const Half &b) { + a.y += b.y; + return a; +} + +inline Half &operator-=(Half &a, const Half &b) { + a.y -= b.y; + return a; +} + +inline Half &operator*=(Half &a, const Half &b) { + a.y *= b.y; + return a; +} + +inline Half &operator/=(Half &a, const Half &b) { + a.y /= b.y; + return a; +} + +#else + +inline Half operator+(const Half &a, const Half &b) { + return static_cast(a) + static_cast(b); +} + +inline Half operator-(const Half &a, const Half &b) { + return static_cast(a) - static_cast(b); +} + +inline Half operator*(const Half &a, const Half &b) { + return static_cast(a) * static_cast(b); +} + +inline Half operator/(const Half &a, const Half &b) { + return static_cast(a) / static_cast(b); +} + +inline Half operator-(const Half &a) { return -static_cast(a); } + +inline Half &operator+=(Half &a, const Half &b) { + a = a + b; + return a; +} + +inline Half &operator-=(Half &a, const Half &b) { + a = a - b; + return a; +} + +inline Half &operator*=(Half &a, const Half &b) { + a = a * b; + return a; +} + +inline Half &operator/=(Half &a, const Half &b) { + a = a / b; + return a; +} + +#endif + +/// Arithmetic with floats + +inline float operator+(Half a, float b) { return static_cast(a) + b; } +inline float operator-(Half a, float b) { return static_cast(a) - b; } +inline float operator*(Half a, float b) { return static_cast(a) * b; } +inline float operator/(Half a, float b) { return static_cast(a) / b; } + +inline float operator+(float a, Half b) { return a + static_cast(b); } +inline float operator-(float a, Half b) { return a - static_cast(b); } +inline float operator*(float a, Half b) { return a * static_cast(b); } +inline float operator/(float a, Half b) { return a / static_cast(b); } + +inline float &operator+=(float &a, const Half &b) { + return a += static_cast(b); +} +inline float &operator-=(float &a, const Half &b) { + return a -= static_cast(b); +} +inline float &operator*=(float &a, const Half &b) { + return a *= static_cast(b); +} +inline float &operator/=(float &a, const Half &b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline double operator+(Half a, double b) { return static_cast(a) + b; } +inline double operator-(Half a, double b) { return static_cast(a) - b; } +inline double operator*(Half a, double b) { return static_cast(a) * b; } +inline double operator/(Half a, double b) { return static_cast(a) / b; } + +inline double operator+(double a, Half b) { return a + 
static_cast(b); } +inline double operator-(double a, Half b) { return a - static_cast(b); } +inline double operator*(double a, Half b) { return a * static_cast(b); } +inline double operator/(double a, Half b) { return a / static_cast(b); } + +/// Arithmetic with ints + +#ifdef NATIVE_FP16 + +inline Half operator+(Half a, int32_t b) { return_half(a.y + b); } +inline Half operator-(Half a, int32_t b) { return_half(a.y - b); } +inline Half operator*(Half a, int32_t b) { return_half(a.y * b); } +inline Half operator/(Half a, int32_t b) { return_half(a.y / b); } + +inline Half operator+(int32_t a, Half b) { return_half(a + b.y); } +inline Half operator-(int32_t a, Half b) { return_half(a - b.y); } +inline Half operator*(int32_t a, Half b) { return_half(a * b.y); } +inline Half operator/(int32_t a, Half b) { return_half(a / b.y); } + +#else + +inline Half operator+(Half a, int32_t b) { return a + static_cast(b); } +inline Half operator-(Half a, int32_t b) { return a - static_cast(b); } +inline Half operator*(Half a, int32_t b) { return a * static_cast(b); } +inline Half operator/(Half a, int32_t b) { return a / static_cast(b); } + +inline Half operator+(int32_t a, Half b) { return static_cast(a) + b; } +inline Half operator-(int32_t a, Half b) { return static_cast(a) - b; } +inline Half operator*(int32_t a, Half b) { return static_cast(a) * b; } +inline Half operator/(int32_t a, Half b) { return static_cast(a) / b; } + +#endif + +//// Arithmetic with int64_t + +#ifdef NATIVE_FP16 + +inline Half operator+(Half a, int64_t b) { return_half(a.y + b); } +inline Half operator-(Half a, int64_t b) { return_half(a.y - b); } +inline Half operator*(Half a, int64_t b) { return_half(a.y * b); } +inline Half operator/(Half a, int64_t b) { return_half(a.y / b); } + +inline Half operator+(int64_t a, Half b) { return_half(a + b.y); } +inline Half operator-(int64_t a, Half b) { return_half(a - b.y); } +inline Half operator*(int64_t a, Half b) { return_half(a * b.y); } +inline Half operator/(int64_t a, Half b) { return_half(a / b.y); } + +#else + +inline Half operator+(Half a, int64_t b) { return a + static_cast(b); } +inline Half operator-(Half a, int64_t b) { return a - static_cast(b); } +inline Half operator*(Half a, int64_t b) { return a * static_cast(b); } +inline Half operator/(Half a, int64_t b) { return a / static_cast(b); } + +inline Half operator+(int64_t a, Half b) { return static_cast(a) + b; } +inline Half operator-(int64_t a, Half b) { return static_cast(a) - b; } +inline Half operator*(int64_t a, Half b) { return static_cast(a) * b; } +inline Half operator/(int64_t a, Half b) { return static_cast(a) / b; } + +#endif + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion Half to float. + +static inline std::ostream & +operator<<(std::ostream &out, const executorch::runtime::etensor::Half &value) { + out << (float)value; + return out; +} + +} // namespace etensor +} // namespace runtime +} // namespace executorch +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. 
+using ::executorch::runtime::etensor::Half; +} // namespace executor +} // namespace torch + +namespace std { + +template <> class numeric_limits { +public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = true; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 11; + static constexpr int digits10 = 3; + static constexpr int max_digits10 = 5; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + static constexpr executorch::runtime::etensor::Half min() { + return executorch::runtime::etensor::Half( + 0x0400, executorch::runtime::etensor::Half::from_bits()); + } + static constexpr executorch::runtime::etensor::Half lowest() { + return executorch::runtime::etensor::Half( + 0xFBFF, executorch::runtime::etensor::Half::from_bits()); + } + static constexpr executorch::runtime::etensor::Half max() { + return executorch::runtime::etensor::Half( + 0x7BFF, executorch::runtime::etensor::Half::from_bits()); + } + static constexpr executorch::runtime::etensor::Half epsilon() { + return executorch::runtime::etensor::Half( + 0x1400, executorch::runtime::etensor::Half::from_bits()); + } + static constexpr executorch::runtime::etensor::Half round_error() { + return executorch::runtime::etensor::Half( + 0x3800, executorch::runtime::etensor::Half::from_bits()); + } + static constexpr executorch::runtime::etensor::Half infinity() { + return executorch::runtime::etensor::Half( + 0x7C00, executorch::runtime::etensor::Half::from_bits()); + } + static constexpr executorch::runtime::etensor::Half quiet_NaN() { + return executorch::runtime::etensor::Half( + 0x7E00, executorch::runtime::etensor::Half::from_bits()); + } + static constexpr executorch::runtime::etensor::Half signaling_NaN() { + return executorch::runtime::etensor::Half( + 0x7D00, executorch::runtime::etensor::Half::from_bits()); + } + static constexpr executorch::runtime::etensor::Half denorm_min() { + return executorch::runtime::etensor::Half( + 0x0001, executorch::runtime::etensor::Half::from_bits()); + } +}; + +} // namespace std diff --git a/third-party/include/executorch/runtime/core/portable_type/optional.h b/third-party/include/executorch/runtime/core/portable_type/optional.h new file mode 100644 index 00000000..f9df4681 --- /dev/null +++ b/third-party/include/executorch/runtime/core/portable_type/optional.h @@ -0,0 +1,180 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include // std::forward and other template magic checks + +namespace executorch { +namespace runtime { +namespace etensor { + +/// Used to indicate an optional type with uninitialized state. +struct nullopt_t final { + constexpr explicit nullopt_t(int32_t) {} +}; + +/// A constant of type nullopt_t that is used to indicate an optional type with +/// uninitialized state. +constexpr nullopt_t nullopt{0}; + +/// Leaner optional class, subset of c10, std, and boost optional APIs. +template class optional final { +public: + /// The type wrapped by the optional class. + using value_type = T; + + /// Constructs an optional object that does not contain a value. + /* implicit */ optional() noexcept : storage_(trivial_init), init_(false) {} + + /// Constructs an optional object that does not contain a value. + /* implicit */ optional(nullopt_t) noexcept + : storage_(trivial_init), init_(false) {} + + /// Constructs an optional object that matches the state of v. + /* implicit */ optional(const optional &v) + : storage_(trivial_init), init_(v.init_) { + if (init_) { + new (&storage_.value_) T(v.storage_.value_); + } + } + + /// Constructs an optional object that contains the specified value. + /* implicit */ optional(const T &v) : storage_(v), init_(true) {} + + /// Constructs an optional object from v. + /* implicit */ optional(optional &&v) noexcept( + std::is_nothrow_move_constructible::value) + : storage_(trivial_init), init_(v.init_) { + if (init_) { + new (&storage_.value_) T(std::forward(v.storage_.value_)); + } + } + + /// Constructs an optional object that contains the specified value. + /* implicit */ optional(T &&v) : storage_(std::forward(v)), init_(true) {} + + optional &operator=(const optional &rhs) { + if (init_ && !rhs.init_) { + clear(); + } else if (!init_ && rhs.init_) { + init_ = true; + new (&storage_.value_) T(rhs.storage_.value_); + } else if (init_ && rhs.init_) { + storage_.value_ = rhs.storage_.value_; + } + return *this; + } + + optional &operator=(optional &&rhs) noexcept( + std::is_nothrow_move_assignable::value && + std::is_nothrow_move_constructible::value) { + if (init_ && !rhs.init_) { + clear(); + } else if (!init_ && rhs.init_) { + init_ = true; + new (&storage_.value_) T(std::forward(rhs.storage_.value_)); + } else if (init_ && rhs.init_) { + storage_.value_ = std::forward(rhs.storage_.value_); + } + return *this; + } + + /// Destroys the stored value if there is one + ~optional() { + if (init_) { + storage_.value_.~T(); + } + } + + optional &operator=(nullopt_t) noexcept { + clear(); + return *this; + } + + /// Returns true if the object contains a value, false otherwise + explicit operator bool() const noexcept { return init_; } + + /// Returns true if the object contains a value, false otherwise + bool has_value() const noexcept { return init_; } + + /// Returns a constant reference to the contained value. Calls ET_CHECK if + /// the object does not contain a value. + T const &value() const & { + ET_CHECK(init_); + return contained_val(); + } + + /// Returns a mutable reference to the contained value. Calls ET_CHECK if the + /// object does not contain a value. + T &value() & { + ET_CHECK(init_); + return contained_val(); + } + + /// Returns an rvalue of the contained value. Calls ET_CHECK if the object + /// does not contain a value. 
+ T &&value() && { + ET_CHECK(init_); + return std::forward(contained_val()); + } + +private: + // Used to invoke the dummy ctor of storage_t in the initializer lists of + // optional_base as default ctor is implicitly deleted because T is nontrivial + struct trivial_init_t { + } trivial_init{}; + + /** + * A wrapper type that lets us avoid constructing a T when there is no value. + * If there is a value present, the optional class must destroy it. + */ + union storage_t { + /// A small, trivially-constructable alternative to T. + unsigned char dummy_; + /// The constructed value itself, if optional::has_value_ is true. + T value_; + + /* implicit */ storage_t(trivial_init_t) { dummy_ = 0; } + + template + storage_t(Args &&...args) : value_(std::forward(args)...) {} + + ~storage_t() {} + }; + + const T &contained_val() const & { return storage_.value_; } + T &&contained_val() && { return std::move(storage_.value_); } + T &contained_val() & { return storage_.value_; } + + void clear() noexcept { + if (init_) { + storage_.value_.~T(); + } + init_ = false; + } + + storage_t storage_; + bool init_; +}; + +} // namespace etensor +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::etensor::nullopt; +using ::executorch::runtime::etensor::nullopt_t; +using ::executorch::runtime::etensor::optional; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/portable_type/qint_types.h b/third-party/include/executorch/runtime/core/portable_type/qint_types.h new file mode 100644 index 00000000..183675e1 --- /dev/null +++ b/third-party/include/executorch/runtime/core/portable_type/qint_types.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace runtime { +namespace etensor { + +/** + * qint8 is for signed 8 bit quantized Tensors + */ +struct alignas(1) qint8 { + using underlying = int8_t; + int8_t val_; + qint8() = default; + explicit qint8(int8_t val) : val_(val) {} +}; + +/** + * quint8 is for unsigned 8 bit quantized Tensors + */ +struct alignas(1) quint8 { + using underlying = uint8_t; + uint8_t val_; + quint8() = default; + explicit quint8(uint8_t val) : val_(val) {} +}; + +/** + * qint32 is for signed 32 bit quantized Tensors + */ +struct alignas(4) qint32 { + using underlying = int32_t; + int32_t val_; + qint32() = default; + explicit qint32(int32_t val) : val_(val) {} +}; + +/** + * quint4x2 is for un-signed 4 bit quantized Tensors that are packed to byte + * boundary. + */ +struct alignas(1) quint4x2 { + using underlying = uint8_t; + uint8_t val_; + quint4x2() = default; + explicit quint4x2(uint8_t val) : val_(val) {} +}; + +/** + * quint2x4 is for un-signed 2 bit quantized Tensors that are packed to byte + * boundary. 
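+ * Each byte therefore packs four 2-bit values. Like the other quantized
+ * storage types above, this struct only wraps the raw byte; it does not
+ * provide pack/unpack helpers.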
+ */ +struct alignas(1) quint2x4 { + using underlying = uint8_t; + uint8_t val_; + quint2x4() = default; + explicit quint2x4(uint8_t val) : val_(val) {} +}; + +} // namespace etensor +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::etensor::qint32; +using ::executorch::runtime::etensor::qint8; +using ::executorch::runtime::etensor::quint2x4; +using ::executorch::runtime::etensor::quint4x2; +using ::executorch::runtime::etensor::quint8; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/portable_type/scalar.h b/third-party/include/executorch/runtime/core/portable_type/scalar.h new file mode 100644 index 00000000..f503990e --- /dev/null +++ b/third-party/include/executorch/runtime/core/portable_type/scalar.h @@ -0,0 +1,110 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace executorch { +namespace runtime { +namespace etensor { + +/** + * Represents a scalar value. + * + * The API is a source-compatible subset of c10::Scalar, and the + * semantics/behavior should also match the c10 version. + */ +class Scalar { +public: + Scalar() : Scalar(int64_t(0)) {} + + template ::value, + bool>::type = true> + /*implicit*/ Scalar(T val) : tag(Tag::Int) { + v.as_int = static_cast(val); + } + /*implicit*/ Scalar(bool val) : tag(Tag::Bool) { v.as_bool = val; } + /*implicit*/ Scalar(double val) : tag(Tag::Double) { v.as_double = val; } + /*implicit*/ Scalar(BFloat16 val) : Scalar((double)(float)val) {} + /*implicit*/ Scalar(Half val) : Scalar((double)(float)val) {} + + /// Returns the concrete scalar value stored within. + template T to() const; + + /// Returns true if the scalar is integral, false otherwise. + bool isIntegral(bool includeBool) const { + return Tag::Int == tag || (includeBool && isBoolean()); + } + + /// Returns true if the scalar is a floating point, false otherwise. + bool isFloatingPoint() const { return tag == Tag::Double; } + + /// Returns true if the scalar is a boolean, false otherwise. 
+ bool isBoolean() const { return tag == Tag::Bool; } + +private: + int64_t toInt() const { + if (isIntegral(/*includeBool=*/false)) { + return v.as_int; + } else if (isBoolean()) { + return static_cast(v.as_bool); + } else { + ET_CHECK_MSG(false, "Scalar is not an int nor a Boolean."); + } + } + + double toFloatingPoint() const { + ET_CHECK_MSG(isFloatingPoint(), "Scalar is not a Double."); + return v.as_double; + } + + double toDouble() const { + ET_CHECK_MSG(isFloatingPoint(), "Scalar is not a Double."); + return v.as_double; + } + + bool toBool() const { + ET_CHECK_MSG(isBoolean(), "Scalar is not a Boolean."); + return v.as_bool; + } + + Tag tag; + union v_t { + double as_double; + int64_t as_int; + bool as_bool; + v_t() {} // default constructor + } v; +}; + +#define ET_DEFINE_SCALAR_TO_METHOD(T, name) \ + template <> inline T Scalar::to() const { return to##name(); } + +ET_DEFINE_SCALAR_TO_METHOD(double, Double) +ET_DEFINE_SCALAR_TO_METHOD(int64_t, Int) +ET_DEFINE_SCALAR_TO_METHOD(bool, Bool) +#undef ET_DEFINE_SCALAR_TO_METHOD + +} // namespace etensor +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::etensor::Scalar; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/portable_type/scalar_type.h b/third-party/include/executorch/runtime/core/portable_type/scalar_type.h new file mode 100644 index 00000000..b7d20872 --- /dev/null +++ b/third-party/include/executorch/runtime/core/portable_type/scalar_type.h @@ -0,0 +1,154 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * @file + * + * Forked from + * https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h + * + * Everything but the ScalarType definition is in util/ScalarTypeUtil.h + * + * Note that these files do not need to be strictly identical to the pytorch + * core file, as far as names go. The only critical piece is that the types and + * indices of the main ScalarType enum line up, so that serialization is + * compatible between the two. + * + * Modifications for ExecuTorch: + * - Namespace torch::executor instead of c10 + * - Macro prefix ET_ instead of AT_ + * - Use ET_CHECK_MSG() instead of TORCH_CHECK() + * - Don't define standalone constants like `kByte`, `kInt` to keep the + * namespace clean + * - Remove operator<< to avoid a dependency on ostream and stdlib + * - Make `static inline` functions `inline` to avoid creating multiple + * copies of them. See + * https://gist.github.com/htfy96/50308afc11678d2e3766a36aa60d5f75#conclusion. + * - Remove deprecated definitions + * - Minor cleanup for internal consistency + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace executorch { +namespace runtime { +namespace etensor { + +// Placing a bunch of unused dtypes here as our macros don't make it easy +// to skip scalar types defined in aten that we dont have. 
+namespace unused_dtype { +struct alignas(1) Float8_e5m2 { + uint8_t x; + using underlying = uint8_t; + Float8_e5m2() = default; + explicit Float8_e5m2(uint8_t val) : x(val) {} +}; +struct alignas(1) Float8_e4m3fn { + uint8_t x; + using underlying = uint8_t; + Float8_e4m3fn() = default; + explicit Float8_e4m3fn(uint8_t val) : x(val) {} +}; +struct alignas(1) Float8_e5m2fnuz { + uint8_t x; + using underlying = uint8_t; + Float8_e5m2fnuz() = default; + explicit Float8_e5m2fnuz(uint8_t val) : x(val) {} +}; +struct alignas(1) Float8_e4m3fnuz { + uint8_t x; + using underlying = uint8_t; + Float8_e4m3fnuz() = default; + explicit Float8_e4m3fnuz(uint8_t val) : x(val) {} +}; + +} // namespace unused_dtype + +/** + * Calls the provided macro on every ScalarType, providing the C type and the + * ScalarType name to each call. + * + * The indices and C types must be consistent with + * AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS in the core pytorch file + * c10/core/ScalarType.h. This ensures that ExecuTorch serialization is + * compatible with ATen serialization. + * + * @param _ A macro that takes two parameters: the name of a C type, and the + * name of the corresponding ScalarType enumerator. + */ +#define ET_FORALL_SCALAR_TYPES(_) \ + _(uint8_t, Byte) /* 0 */ \ + _(int8_t, Char) /* 1 */ \ + _(int16_t, Short) /* 2 */ \ + _(int32_t, Int) /* 3 */ \ + _(int64_t, Long) /* 4 */ \ + _(::executorch::runtime::etensor::Half, Half) /* 5 */ \ + _(float, Float) /* 6 */ \ + _(double, Double) /* 7 */ \ + _(::executorch::runtime::etensor::complex<::torch::executor::Half>, \ + ComplexHalf) /* 8 */ \ + _(::executorch::runtime::etensor::complex, ComplexFloat) /* 9 */ \ + _(::executorch::runtime::etensor::complex, ComplexDouble) /* 10 */ \ + _(bool, Bool) /* 11 */ \ + _(::executorch::runtime::etensor::qint8, QInt8) /* 12 */ \ + _(::executorch::runtime::etensor::quint8, QUInt8) /* 13 */ \ + _(::executorch::runtime::etensor::qint32, QInt32) /* 14 */ \ + _(::executorch::runtime::etensor::BFloat16, BFloat16) /* 15 */ \ + _(::executorch::runtime::etensor::quint4x2, QUInt4x2) /* 16 */ \ + _(::executorch::runtime::etensor::quint2x4, QUInt2x4) /* 17 */ \ + _(::executorch::runtime::etensor::bits1x8, Bits1x8) /* 18 */ \ + _(::executorch::runtime::etensor::bits2x4, Bits2x4) /* 19 */ \ + _(::executorch::runtime::etensor::bits4x2, Bits4x2) /* 20 */ \ + _(::executorch::runtime::etensor::bits8, Bits8) /* 21 */ \ + _(::executorch::runtime::etensor::bits16, Bits16) /* 22 */ \ + _(::executorch::runtime::etensor::unused_dtype::Float8_e5m2, \ + Float8_e5m2) /* 23 */ \ + _(::executorch::runtime::etensor::unused_dtype::Float8_e4m3fn, \ + Float8_e4m3fn) /* 24 */ \ + _(::executorch::runtime::etensor::unused_dtype::Float8_e5m2fnuz, \ + Float8_e5m2fnuz) /* 25 */ \ + _(::executorch::runtime::etensor::unused_dtype::Float8_e4m3fnuz, \ + Float8_e4m3fnuz) /* 26 */ \ + _(uint16_t, UInt16) /* 27 */ \ + _(uint32_t, UInt32) /* 28 */ \ + _(uint64_t, UInt64) /* 29 */ + +/** + * Data types (dtypes) that can be used as element types in ETensors. + */ +enum class ScalarType : int8_t { +/// Define an enumerator for each ScalarType. +#define DEFINE_ENUM(unused, name) name, + ET_FORALL_SCALAR_TYPES(DEFINE_ENUM) +#undef DEFINE_ENUM + + /// An explicitly undefined ScalarType. Does not map to any C type. + Undefined, + /// The number of ScalarType enumerators. 
+ NumOptions, +}; + +} // namespace etensor +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::etensor::ScalarType; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/portable_type/string_view.h b/third-party/include/executorch/runtime/core/portable_type/string_view.h new file mode 100644 index 00000000..8e28fa02 --- /dev/null +++ b/third-party/include/executorch/runtime/core/portable_type/string_view.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace runtime { +namespace etensor { + +using std::string_view; + +} // namespace etensor +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::etensor::string_view; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/portable_type/tensor.h b/third-party/include/executorch/runtime/core/portable_type/tensor.h new file mode 100644 index 00000000..7e9efaa4 --- /dev/null +++ b/third-party/include/executorch/runtime/core/portable_type/tensor.h @@ -0,0 +1,142 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace executorch { +namespace runtime { +namespace etensor { + +/** + * A minimal Tensor type whose API is a source compatible subset of at::Tensor. + * + * NOTE: Instances of this class do not own the TensorImpl given to it, + * which means that the caller must guarantee that the TensorImpl lives longer + * than any Tensor instances that point to it. + * + * See the documention on TensorImpl for details about the return/parameter + * types used here and how they relate to at::Tensor. + */ +class Tensor { +public: + /// The type used for elements of `sizes()`. + using SizesType = TensorImpl::SizesType; + /// The type used for elements of `dim_order()`. + using DimOrderType = TensorImpl::DimOrderType; + /// The type used for elements of `strides()`. + using StridesType = TensorImpl::StridesType; + + Tensor() = delete; + explicit constexpr Tensor(TensorImpl *impl) : impl_(impl) {} + + /** + * Returns a pointer to the underlying TensorImpl. + * + * NOTE: Clients should be wary of operating on the TensorImpl + * directly instead of the Tensor. It is easy to break things. + */ + TensorImpl *unsafeGetTensorImpl() const { + // TODO(T154114015): See if we can make this api private with friends. + return impl_; + } + + /** + * Returns the size of the tensor in bytes. + * + * NOTE: Only the alive space is returned not the total capacity of the + * underlying data blob. + */ + size_t nbytes() const { return impl_->nbytes(); } + + /** + * Returns the size of the tensor at the given dimension. 
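+ *
+ * Illustrative use only (`t` is assumed to be an existing Tensor; this loop
+ * simply recomputes what `numel()` already reports):
+ * @code
+ *   ssize_t elems = 1;
+ *   for (ssize_t d = 0; d < t.dim(); ++d) {
+ *     elems *= t.size(d);
+ *   }
+ * @endcode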
+ * + * NOTE: that size() intentionally does not return SizeType even though it + * returns an element of an array of SizeType. This is to help make calls of + * this method more compatible with at::Tensor, and more consistent with the + * rest of the methods on this class and in ETensor. + */ + ssize_t size(ssize_t dim) const { return impl_->size(dim); } + + /// Returns the tensor's number of dimensions. + ssize_t dim() const { return impl_->dim(); } + + /// Returns the number of elements in the tensor. + ssize_t numel() const { return impl_->numel(); } + + /// Returns the type of the elements in the tensor (int32, float, bool, etc). + ScalarType scalar_type() const { return impl_->scalar_type(); } + + inline ScalarType dtype() const { return scalar_type(); } + + /// Returns the size in bytes of one element of the tensor. + ssize_t element_size() const { return impl_->element_size(); } + + /// Returns the sizes of the tensor at each dimension. + const ArrayRef sizes() const { return impl_->sizes(); } + + /// Returns the order the dimensions are laid out in memory. + const ArrayRef dim_order() const { return impl_->dim_order(); } + + /// Returns the strides of the tensor at each dimension. + const ArrayRef strides() const { return impl_->strides(); } + + /// Returns the mutability of the shape of the tensor. + TensorShapeDynamism shape_dynamism() const { return impl_->shape_dynamism(); } + + /// Returns a pointer of type T to the constant underlying data blob. + template inline const T *const_data_ptr() const { + return impl_->data(); + } + + /// Returns a pointer to the constant underlying data blob. + inline const void *const_data_ptr() const { return impl_->data(); } + + /// Returns a pointer of type T to the mutable underlying data blob. + template inline T *mutable_data_ptr() const { + return impl_->mutable_data(); + } + + /// Returns a pointer to the mutable underlying data blob. + inline void *mutable_data_ptr() const { return impl_->mutable_data(); } + + /// DEPRECATED: Use const_data_ptr or mutable_data_ptr instead. + template ET_DEPRECATED inline T *data_ptr() const { + return impl_->mutable_data(); + } + + /// DEPRECATED: Use const_data_ptr or mutable_data_ptr instead. + ET_DEPRECATED inline void *data_ptr() const { return impl_->mutable_data(); } + + /** + * DEPRECATED: Changes the data_ptr the tensor aliases. Does not free the + * previously pointed to data, does not assume ownership semantics of the new + * ptr. This api does not exist in at::Tensor so kernel developers should + * avoid it. + */ + ET_DEPRECATED void set_data(void *ptr) const { impl_->set_data(ptr); } + +private: + TensorImpl *impl_ = nullptr; +}; + +} // namespace etensor +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::etensor::Tensor; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/portable_type/tensor_impl.h b/third-party/include/executorch/runtime/core/portable_type/tensor_impl.h new file mode 100644 index 00000000..7357fffa --- /dev/null +++ b/third-party/include/executorch/runtime/core/portable_type/tensor_impl.h @@ -0,0 +1,261 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include + +// Forward declaration of a helper that provides access to internal resizing +// methods of TensorImpl. Real definition is in +// executorch/runtime/core/exec_aten/tensor_util.h. +namespace executorch { +namespace runtime { +namespace internal { +class TensorResizerFriend; +} // namespace internal +} // namespace runtime +} // namespace executorch + +namespace executorch { +namespace runtime { +namespace etensor { + +/** + * Manages the storage behind an ETensor (torch::executor::Tensor). + * + * Note that instances of this class do not own the arrays given to it + * (sizes/strides/data), which means that the caller must guarantee that they + * live longer than a given instance of this class. + * + * Note on types: + * + * Code that uses ETensor should also be able to build against at::Tensor. So, + * although the overlapping APIs don't need to be exactly the same, their types + * should be semantically similar. + * + * Many of the methods in at::Tensor use int64_t for parameter and return types. + * This can be a waste when building for 32-bit environments. So, TensorImpl and + * ETensor use ssize_t instead: like int64_t it is signed, but it will match the + * native word size of the target architecture. This will avoid unnecessarily + * expensive uses of 64-bit integers on 32-bit machines. + * + * But, since the types are not identical, code that uses ETensor needs to be + * generic about the local types it uses when working with these methods. In + * most cases, `auto` will do the trick. In the worst case, code can be guarded + * with `#ifdef USE_ATEN_LIB`. + */ +class TensorImpl { +public: + /** + * The type used for elements of `sizes()`. + * + * This must match the size/signedness of the type used for `Tensor.sizes` in + * //executorch/schema/program.fbs. + * + * Note that at::TensorImpl uses `int64_t` for this type. ExecuTorch uses + * `int32_t` to save memory, since no single size value will ever be larger + * than 2 billion. + */ + using SizesType = int32_t; + + /** + * The type used for elements of `dim_order()`. + * + * This must match the size/signedness of the type used for `Tensor.dim_order` + * in //executorch/schema/program.fbs. + */ + using DimOrderType = uint8_t; + + /** + * The type used for elements of `strides()`. + * + * This must match the size/signedness of the type used for `Tensor.strides` + * in //executorch/schema/program.fbs. + * + * Note that at::TensorImpl uses `int64_t` for this type. ExecuTorch uses + * `int32_t` to save memory, since no single stride value will ever be larger + * than 2 billion. + */ + using StridesType = int32_t; + + TensorImpl() = delete; + + /** + * @param type The type of the data (int, float, bool). + * @param dim Number of dimensions, and the length of the `sizes` array. + * @param sizes Sizes of the tensor at each dimension. Must contain `dim` + * entries. + * @param data Pointer to the data, whose size is determined by `type`, + * `dim`, and `sizes`. The tensor will not own this memory. + * @param dim_order Order in which dimensions are laid out in memory. + * @param strides Strides of the tensor at each dimension. Must contain `dim` + * entries. + * @param dynamism The mutability of the shape of the tensor. 
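+ *
+ * Minimal construction sketch (illustrative only; the caller owns the
+ * size/order/stride arrays and the data buffer, and must keep them alive
+ * for the lifetime of the impl):
+ * @code
+ *   TensorImpl::SizesType sizes[2] = {2, 3};
+ *   TensorImpl::DimOrderType dim_order[2] = {0, 1};
+ *   TensorImpl::StridesType strides[2] = {3, 1};
+ *   float data[6] = {};
+ *   TensorImpl impl(ScalarType::Float, /*dim=*/2, sizes, data,
+ *                   dim_order, strides);
+ * @endcode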
+ */ + TensorImpl(ScalarType type, ssize_t dim, SizesType *sizes, + void *data = nullptr, DimOrderType *dim_order = nullptr, + StridesType *strides = nullptr, + TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC); + + /** + * Returns the size of the tensor in bytes. + * + * NOTE: This returns the size of the data used by the tensor's current shape, + * not the capacity of the underlying buffer. + */ + size_t nbytes() const; + + /** + * Returns the size of the tensor at the given dimension. + * + * NOTE: size() intentionally does not return SizeType even though it + * returns an element of an array of SizeType. This is to help make calls of + * this method more compatible with at::Tensor, and more consistent with the + * rest of the methods on this class and in ETensor. + */ + ssize_t size(ssize_t dim) const { + ET_CHECK_MSG(dim < dim_ && dim >= 0, + "Dimension out of range (expected to be in range of [0, %zd], " + "but got %zd", + dim_ - 1, dim); + return sizes_[dim]; + } + + /// Returns the tensor's number of dimensions. + ssize_t dim() const { return dim_; } + + /// Returns the number of elements in the tensor. + ssize_t numel() const { return numel_; } + + /// Returns the type of the elements in the tensor (int32, float, bool, etc). + ScalarType scalar_type() const { return type_; } + + inline ScalarType dtype() const { return scalar_type(); } + + /// Returns the size in bytes of one element of the tensor. + ssize_t element_size() const; + + /// Returns the sizes of the tensor at each dimension. + const ArrayRef sizes() const { + return ArrayRef{sizes_, static_cast(dim_)}; + } + + /// Returns the order the dimensions are laid out in memory. + const ArrayRef dim_order() const { + return ArrayRef{dim_order_, static_cast(dim_)}; + } + + /// Returns the strides of the tensor at each dimension. + const ArrayRef strides() const { + return ArrayRef{strides_, static_cast(dim_)}; + } + + /// Returns the mutability of the shape of the tensor. + TensorShapeDynamism shape_dynamism() const { return shape_dynamism_; } + + /// Returns a pointer of type T to the constant underlying data blob. + template inline const T *data() const { + return static_cast(data()); + } + + /// Returns a pointer to the constant underlying data blob. + const void *data() const { return data_; } + + /// Returns a pointer of type T to the mutable underlying data blob. + template inline T *mutable_data() const { + return static_cast(mutable_data()); + } + + /// Returns a pointer to the mutable underlying data blob. + void *mutable_data() const { return data_; } + + /// Sets the underlying data blob to the passed in pointer. + void set_data(void *ptr) { data_ = ptr; } + + /* + * DEPRECATED: Use torch::executor::resize_tensor() or + * torch::executor::resize_tensor_impl(). + */ + ET_DEPRECATED + void set_sizes_contiguous(ArrayRef new_sizes) { + Error err = internal_resize_contiguous(new_sizes); + ET_CHECK_MSG(err == Error::Ok, + "Could not resize Tensor; see logs for details"); + } + +private: + // For access to internal_resize_contiguous(). + friend class ::executorch::runtime::internal::TensorResizerFriend; + + /** + * Set the sizes and strides of a tensor assuming contiguous strides. + * Requires that `new_sizes.size() == this.dim()`. + * + * Callers must use torch::executor::resize_tensor() or + * torch::executor::resize_tensor_impl() instead, defined in TensorUtil.h. + * + * Same semantics as at::TensorImpl::set_sizes_contiguous(), but returns an + * error instead of panicking on failure. 
This is not part of the at::Tensor + * API, and can only be used in lean mode. + */ + ET_NODISCARD Error internal_resize_contiguous(ArrayRef new_sizes); + +private: + // Keep fields arranged to avoid unnecessary alignment holes. + + /// List of sizes of each dimension in the tensor. + SizesType *sizes_; + + /// List of the order that dimensions are laid out in memory. + DimOrderType *dim_order_; + + // TODO(T148356881): Get rid of strides from ETensor + StridesType *strides_; + + /// Pointer to underlying data blob. NOTE: Can be null. + void *data_; + + /// Tensor's number of dimensions. + const ssize_t dim_; + + /// Number of elements in the tensor. + ssize_t numel_; + + /// Maximum number of elements in the bounded tensor. Used when resizing up + /// and down. + size_t numel_bound_; + + /// Scalar type (int, float, bool, etc) of the tensor data. + const ScalarType type_; + + /// Specifies the mutability of the shape of the tensor. + const TensorShapeDynamism shape_dynamism_; +}; + +/** + * Compute the number of elements based on the sizes of a tensor. + */ +ssize_t compute_numel( + const ::executorch::runtime::etensor::TensorImpl::SizesType *sizes, + ssize_t dim); + +} // namespace etensor +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::etensor::compute_numel; +using ::executorch::runtime::etensor::TensorImpl; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/portable_type/tensor_options.h b/third-party/include/executorch/runtime/core/portable_type/tensor_options.h new file mode 100644 index 00000000..8b8f9848 --- /dev/null +++ b/third-party/include/executorch/runtime/core/portable_type/tensor_options.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace runtime { +namespace etensor { + +/** + * Tensor data memory formats supported by ExecuTorch. This concept only exists + * for compatibility with ATen; use dim_order to describe non-contiguous + * layouts. + */ +enum class MemoryFormat : int8_t { + /** + * Row-major contiguous data. + */ + Contiguous = 0, + /** + * Output tensor format should remain the same as the input tensor format. + * E.g. if the input tensor is in channels_last format, operator output + * should be in channels_last format. + */ + Preserve = 1, +}; + +/** + * Tensor data memory layout. This concept only exists for compatibility + * with ATen. + */ +enum class Layout : int8_t { + /** + * The tensor occupies memory densely and indexing is managed through strides. + * Contrasted with a sparse tensor layout where the memory structure of the + * data blob will be more complicated and indexing requires larger structures. + * + * This is the only layout supported by ExecuTorch. + */ + Strided = 0, +}; +} // namespace etensor +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. 
+using ::executorch::runtime::etensor::Layout; +using ::executorch::runtime::etensor::MemoryFormat; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/result.h b/third-party/include/executorch/runtime/core/result.h new file mode 100644 index 00000000..00cc7bb8 --- /dev/null +++ b/third-party/include/executorch/runtime/core/result.h @@ -0,0 +1,254 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * @file + * Result type to be used in conjunction with ExecuTorch Error type. + */ + +#pragma once + +#include +#include + +#include "executorch/runtime/core/error.h" +#include "executorch/runtime/platform/assert.h" + +namespace executorch { +namespace runtime { + +/** + * Result type wrapping either a value of type T or an error. + * + * Example use case: + * @code + * Result getOp(int opcode) { + * if (isValidOpCode(opcode)) { + * return opFns[opcode]; + * } + * return Error::NotFound; + * } + * + * Error useOp(int opcode) { + * Result op = getOp(opcode); + * if (!op.ok()) { + * return op.error(); + * } + * print(op->toString()); + * execute(*op); + * return Error::Ok; + * } + * @endcode + */ +template class Result final { +public: + /// `value_type` member for generic programming. + typedef T value_type; + + /** + * Creates a Result object from an Error. + * + * To preserve the invariant that `(result.error() == Error::Ok) == + * result.ok()`, an `error` parameter value of `Error:Ok` will be converted to + * a non-Ok value. + */ + /* implicit */ Result(Error error) + : error_(error == Error::Ok ? Error::Internal : error), hasValue_(false) { + } + + /// Value copy constructor. + /* implicit */ Result(const T &val) : value_(val), hasValue_(true) {} + + /// Value move constructor. + /* implicit */ Result(T &&val) : value_(std::move(val)), hasValue_(true) {} + + /// Result move constructor. + /* implicit */ Result(Result &&rhs) noexcept : hasValue_(rhs.hasValue_) { + if (hasValue_) { + // Use the value type's move constructor. + new (&value_) T(std::move(rhs.value_)); + } else { + error_ = rhs.error_; + } + } + + ~Result() { + if (hasValue_) { + // Manual value destruction. + // Result "owns" the memory, so `delete` would segfault. + value_.~T(); + } + } + + /** + * Returns true if this Result has a value. + * + * If true, it is guaranteed that `error()` will return `Error::Ok`. + * If false, it is guaranteed that `error()` will not return `Error::Ok`. + */ + ET_NODISCARD bool ok() const { return hasValue_; } + + /** + * Returns the error code of this Result. + * + * If this returns `Error::Ok`, it is guaranteed that `ok()` will return true. + * If this does not return `Error:Ok`, it is guaranteed that `ok()` will + * return false. + */ + ET_NODISCARD Error error() const { + if (hasValue_) { + return Error::Ok; + } else { + return error_; + } + } + + /** + * Returns a reference to the Result's value; longhand for operator*(). + * + * Only legal to call if `ok()` returns true. + */ + T &get() { + CheckOk(); + return value_; + } + + /** + * Returns a reference to the Result's value; longhand for operator*(). + * + * Only legal to call if `ok()` returns true. + */ + const T &get() const { + CheckOk(); + return value_; + } + + /* + * Returns a reference to the Result's value; shorthand for get(). + * + * Only legal to call if `ok()` returns true. 
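+ *
+ * For example (illustrative sketch, reusing the hypothetical getOp() from
+ * the class-level example above):
+ * @code
+ *   Result<Op> op = getOp(opcode);
+ *   if (op.ok()) {
+ *     execute(*op);
+ *   }
+ * @endcode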
+ */ + const T &operator*() const &; + T &operator*() &; + + /* + * Returns a pointer to the Result's value. + * + * Only legal to call if `ok()` returns true. + */ + const T *operator->() const; + T *operator->(); + +private: + /** + * Delete default constructor since all Results should contain a value or + * error. + */ + Result() = delete; + /// Delete copy constructor since T may not be copyable. + Result(const Result &) = delete; + /// Delete copy assignment since T may not be copyable. + Result &operator=(const Result &) = delete; + /// Delete move assignment since it's not a supported pattern to reuse Result. + Result &operator=(Result &&rhs) = delete; + + // Panics if ok() would return false; + void CheckOk() const { ET_CHECK(hasValue_); } + + union { + T value_; // Used if hasValue_ is true. + Error error_; // Used if hasValue_ is false. + }; + + /// True if the Result contains a value. + const bool hasValue_; +}; + +template const T &Result::operator*() const & { + CheckOk(); + return value_; +} + +template T &Result::operator*() & { + CheckOk(); + return value_; +} + +template const T *Result::operator->() const { + CheckOk(); + return &value_; +} + +template T *Result::operator->() { + CheckOk(); + return &value_; +} + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::Result; +} // namespace executor +} // namespace torch + +/** + * Unwrap a Result to obtain its value. If the Result contains an error, + * propogate the error via trivial function return. + * + * Note: A function using ET_UNWRAP should itself return a Result or Error. + * + * @param[in] result__ Expression yielding the result to unwrap. + * @param[in] ... Optional format string for the log error message and its + * arguments. + */ +#define ET_UNWRAP(result__, ...) ET_INTERNAL_UNWRAP(result__, ##__VA_ARGS__) + +// Internal only: Use ET_UNWRAP() instead. +#define ET_INTERNAL_UNWRAP(...) \ + ET_INTERNAL_UNWRAP_SELECT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1) \ + (__VA_ARGS__) + +// Internal only: Use ET_UNWRAP() instead. +#define ET_INTERNAL_UNWRAP_SELECT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, \ + ...) \ + ET_INTERNAL_UNWRAP_##N + +// Internal only: Use ET_UNWRAP() instead. +#define ET_INTERNAL_UNWRAP_1(result__) \ + ({ \ + auto et_result__ = (result__); \ + if (!et_result__.ok()) { \ + return et_result__.error(); \ + } \ + std::move(*et_result__); \ + }) + +// Internal only: Use ET_UNWRAP() instead. +#define ET_INTERNAL_UNWRAP_2(result__, message__, ...) \ + ({ \ + auto et_result__ = (result__); \ + if (!et_result__.ok()) { \ + ET_LOG(Error, message__, ##__VA_ARGS__); \ + return et_result__.error(); \ + } \ + std::move(*et_result__); \ + }) + +// Internal only: Use ET_UNWRAP() instead. 
+#define ET_INTERNAL_UNWRAP_3 ET_INTERNAL_UNWRAP_2 +#define ET_INTERNAL_UNWRAP_4 ET_INTERNAL_UNWRAP_2 +#define ET_INTERNAL_UNWRAP_5 ET_INTERNAL_UNWRAP_2 +#define ET_INTERNAL_UNWRAP_6 ET_INTERNAL_UNWRAP_2 +#define ET_INTERNAL_UNWRAP_7 ET_INTERNAL_UNWRAP_2 +#define ET_INTERNAL_UNWRAP_8 ET_INTERNAL_UNWRAP_2 +#define ET_INTERNAL_UNWRAP_9 ET_INTERNAL_UNWRAP_2 +#define ET_INTERNAL_UNWRAP_10 ET_INTERNAL_UNWRAP_2 diff --git a/third-party/include/executorch/runtime/core/span.h b/third-party/include/executorch/runtime/core/span.h new file mode 100644 index 00000000..903a6d27 --- /dev/null +++ b/third-party/include/executorch/runtime/core/span.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace executorch { +namespace runtime { + +/** + * Represent a reference to an array (0 or more elements + * consecutively in memory), i.e. a start pointer and a length. It allows + * various APIs to take consecutive elements easily and conveniently. + * + * This class does not own the underlying data, it is expected to be used in + * situations where the data resides in some other buffer, whose lifetime + * extends past that of the Span. + * + * Span and ArrayRef are extrememly similar with the difference being ArrayRef + * views a list of constant elements and Span views a list of mutable elements. + * Clients should decide between the two based on if the list elements for their + * use case should be mutable. + * + * This is intended to be trivially copyable, so it should be passed by + * value. + */ +template class Span final { +public: + using iterator = T *; + using size_type = size_t; + +public: + /// Construct an empty Span. + /* implicit */ constexpr Span() noexcept : data_(nullptr), length_(0) {} + + /// Construct a Span from a pointer and length. + Span(T *data, size_t length) : data_(data), length_(length) { + ET_DCHECK(data_ != nullptr || length_ == 0); + } + + /// Construct a Span from a range. + Span(T *begin, T *end) : data_(begin), length_(end - begin) {} + + /// Construct a Span from a C array. + template + /* implicit */ constexpr Span(T (&Arr)[N]) : data_(Arr), length_(N) {} + + /// @returns a pointer to the start of the underlying element buffer. + iterator begin() const noexcept { return data_; } + + /// @returns a pointer to the end of the underlying element buffer. + iterator end() const noexcept { return data_ + length_; } + + /// @retval a boolean indicating if the Span is empty. + constexpr bool empty() const noexcept { return length_ == 0; } + + /// @returns a pointer to the start of the underlying element buffer. + constexpr T *data() const noexcept { return data_; } + + /// @returns the number of elements in the Span. + constexpr size_t size() const noexcept { return length_; } + + /// Unchecked index into the array according to the argument index. + /// @returns a reference to the element at the specified index. + T &operator[](size_t index) const { return data_[index]; } + +private: + /// The start of the array, in an external buffer. + T *data_; + + /// The number of elements. + size_type length_; +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. 
+using ::executorch::runtime::Span; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/tag.h b/third-party/include/executorch/runtime/core/tag.h new file mode 100644 index 00000000..8c329105 --- /dev/null +++ b/third-party/include/executorch/runtime/core/tag.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace runtime { + +#define EXECUTORCH_FORALL_TAGS(_) \ + _(None) \ + _(Tensor) \ + _(String) \ + _(Double) \ + _(Int) \ + _(Bool) \ + _(ListBool) \ + _(ListDouble) \ + _(ListInt) \ + _(ListTensor) \ + _(ListScalar) \ + _(ListOptionalTensor) + +/** + * The dynamic type of an EValue. + */ +enum class Tag : uint32_t { +#define DEFINE_TAG(x) x, + EXECUTORCH_FORALL_TAGS(DEFINE_TAG) +#undef DEFINE_TAG +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::Tag; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/core/tensor_shape_dynamism.h b/third-party/include/executorch/runtime/core/tensor_shape_dynamism.h new file mode 100644 index 00000000..ee956288 --- /dev/null +++ b/third-party/include/executorch/runtime/core/tensor_shape_dynamism.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace runtime { + +/** + * The resizing capabilities of a Tensor. + * + * The rank of an ExecuTorch Tensors can never change, but shape sometimes can. + */ +enum class TensorShapeDynamism : uint8_t { + /// Cannot change shape. + STATIC = 0, + /// Shape cannot exceed initial capacity. + DYNAMIC_BOUND = 1, + /// No restriction on shape and capacity. + DYNAMIC_UNBOUND = 2, +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::TensorShapeDynamism; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/executor/memory_manager.h b/third-party/include/executorch/runtime/executor/memory_manager.h new file mode 100644 index 00000000..91cdeb3d --- /dev/null +++ b/third-party/include/executorch/runtime/executor/memory_manager.h @@ -0,0 +1,113 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace executorch { +namespace runtime { + +/** + * A container class for allocators used during Method load and execution. + * + * This class consolidates all dynamic memory needs for Method load and + * execution. This can allow for heap-based as well as heap-less execution + * (relevant to some embedded scenarios), and overall provides more control over + * memory use. 
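+ *
+ * Minimal setup sketch (illustrative only; `method_allocator` and
+ * `planned_memory` are assumed to be a MemoryAllocator and a
+ * HierarchicalAllocator that the caller has constructed over its own
+ * buffers):
+ * @code
+ *   MemoryManager memory_manager(&method_allocator, &planned_memory);
+ * @endcode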
+ * + * This class, however, cannot ensure all allocation is accounted for since + * kernel and backend implementations are free to use a separate way to allocate + * memory (e.g., for things like scratch space). But we do suggest that backends + * and kernels use these provided allocators whenever possible. + */ +class MemoryManager final { +public: + /** + * Constructs a new MemoryManager. + * + * @param[in] method_allocator The allocator to use when loading a Method and + * allocating its internal structures. Must outlive the Method that uses + * it. + * @param[in] planned_memory The memory-planned buffers to use for mutable + * tensor data when executing a Method. Must outlive the Method that uses + * it. May be `nullptr` if the Method does not use any memory-planned + * tensor data. The sizes of the buffers in this HierarchicalAllocator + * must agree with the corresponding + * `MethodMeta::num_memory_planned_buffers()` and + * `MethodMeta::memory_planned_buffer_size(N)` values, which are embedded + * in the Program. + * @param[in] temp_allocator The allocator to use when allocating temporary + * data during kernel or delegate execution. Must outlive the Method that + * uses it. May be `nullptr` if the Method does not use kernels or + * delegates that allocate temporary data. This allocator will be reset + * after every kernel or delegate call during execution. + */ + explicit MemoryManager(MemoryAllocator *method_allocator, + HierarchicalAllocator *planned_memory = nullptr, + MemoryAllocator *temp_allocator = nullptr) + : method_allocator_(method_allocator), planned_memory_(planned_memory), + temp_allocator_(temp_allocator) { + ET_CHECK_MSG(method_allocator != temp_allocator, + "method allocator cannot be the same as temp allocator"); + } + + /** + * DEPRECATED: Use the constructor without `constant_allocator` instead. + * + * TODO(T162089316): Remove this once all users migrate to the new ctor. + */ + ET_DEPRECATED MemoryManager(MemoryAllocator *constant_allocator, + HierarchicalAllocator *non_constant_allocator, + MemoryAllocator *runtime_allocator, + MemoryAllocator *temporary_allocator) + : MemoryManager( + /*method_allocator=*/runtime_allocator, + /*planned_memory=*/non_constant_allocator, + /*temp_allocator=*/temporary_allocator) { + (void)constant_allocator; // Suppress unused variable warning + } + + /** + * Returns the allocator that the runtime will use to allocate internal + * structures while loading a Method. Must not be used after its associated + * Method has been loaded. + */ + MemoryAllocator *method_allocator() const { return method_allocator_; } + + /** + * Returns the memory-planned buffers to use for mutable tensor data. + */ + HierarchicalAllocator *planned_memory() const { return planned_memory_; } + + /** + * Returns the allocator to use for allocating temporary data during kernel or + * delegate execution. + * + * This allocator will be reset after every kernel or delegate call during + * execution. + */ + MemoryAllocator *temp_allocator() const { return temp_allocator_; } + +private: + MemoryAllocator *method_allocator_; + HierarchicalAllocator *planned_memory_; + MemoryAllocator *temp_allocator_; +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. 
+using ::executorch::runtime::MemoryManager; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/executor/method.h b/third-party/include/executorch/runtime/executor/method.h new file mode 100644 index 00000000..06aaf5a8 --- /dev/null +++ b/third-party/include/executorch/runtime/executor/method.h @@ -0,0 +1,339 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +// Forward declare flatbuffer types. This is a public header and must not +// include the generated flatbuffer header. +namespace executorch_flatbuffer { +struct Chain; +struct ExecutionPlan; +struct EValue; +} // namespace executorch_flatbuffer + +namespace executorch { +namespace runtime { + +// Forward declare Program to avoid a circular reference. +class Program; + +// Forward declare internal types. +class BackendDelegate; +struct Chain; +class KernelRuntimeContext; +using OpFunction = void (*)(KernelRuntimeContext &, EValue **); +/// A list of pointers into the master values table that together compose the +/// argument list for a single instruction +using InstructionArgs = Span; + +/** + * An executable method of an executorch program. Maps to a python method like + * `forward()` on the original nn.Module. + */ +class Method final { +public: + /** + * Move ctor. Takes ownership of resources previously owned by `rhs`, + * and leaves `rhs` in an uninitialized state. + */ + Method(Method &&rhs) noexcept + : step_state_(rhs.step_state_), program_(rhs.program_), + memory_manager_(rhs.memory_manager_), + temp_allocator_(rhs.temp_allocator_), + serialization_plan_(rhs.serialization_plan_), + event_tracer_(rhs.event_tracer_), n_value_(rhs.n_value_), + values_(rhs.values_), n_delegate_(rhs.n_delegate_), + delegates_(rhs.delegates_), n_chains_(rhs.n_chains_), + chains_(rhs.chains_), init_state_(rhs.init_state_) { + // Required: clear out fields that the dtor looks at, so that we don't free + // anything twice. + rhs.n_value_ = 0; + rhs.values_ = nullptr; + rhs.n_delegate_ = 0; + rhs.delegates_ = nullptr; + + // Helpful: Try to ensure that any other interactions with the old object + // result in failures. + rhs.init_state_ = InitializationState::Uninitialized; + rhs.step_state_ = {}; + rhs.program_ = nullptr; + rhs.memory_manager_ = nullptr; + rhs.serialization_plan_ = nullptr; + rhs.event_tracer_ = nullptr; + rhs.n_chains_ = 0; + rhs.chains_ = nullptr; + } + + /** + * Sets the internal input value to be equivalent to the to the provided + * value. + * + * @param[in] input_evalue The evalue to copy into the method input. If the + * evalue is a tensor, the data is copied in most cases, so the tensor + * passed in here does not always need to outlive this call. But there is + * a case where the Method will keep a pointer to the tensor's data. + * Based on the memory plan of the method, the inputs may not have + * buffer space pre-allocated for them. In this case the executor will + * alias the memory of the tensors provided as inputs here rather then + * deepcopy the input into the memory planned arena. + * + * @param[in] input_idx Zero-based index of the input to set. Must be less + * than the value returned by inputs_size(). + * + * @returns Error::Ok on success, non-Ok on failure. 
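+ *
+ * Typical call sequence, sketched only for illustration (`method` is a
+ * loaded Method and `input` is an EValue wrapping the caller's input
+ * tensor):
+ * @code
+ *   Error err = method.set_input(input, /*input_idx=*/0);
+ *   if (err == Error::Ok) {
+ *     err = method.execute();
+ *   }
+ * @endcode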
+ */ + ET_NODISCARD Error set_input(const EValue &input_evalue, size_t input_idx); + + /** + * Sets the values of all method inputs. + * + * See set_input() for a more detailed description of the behavior. + * + * @param[in] input_evalues The new values for all of the method inputs. The + * type of each element must match the type of corresponding input. If the + * value of an element is a tensor, attempts to allow dynamic shape, but + * the dtype must always agree. + * + * @returns Error::Ok on success, non-Ok on failure. + */ + ET_NODISCARD Error + set_inputs(const executorch::aten::ArrayRef &input_evalues); + + /** + * Sets the data buffer of the specified method output to the provided value. + * + * NOTE: Based on the memory plan of the method, the output tensors may not + * have buffer space pre-allocated for them, in this case the executor will + * point those tensors to the buffer provided here, so the user should take + * care that the life span of this memory outlasts the executor forward. + * + * @param[in] buffer The block of memory to point the specified tensor at. + * + * @param[in] size the length of buffer in bytes, must be >= the nbytes of the + * specified tensor. + * + * @param[in] output_idx The index of the output to set the data_ptr for. Must + * correspond to a tensor, and that tensor must not have had a buffer + * allocated by the memory plan. + * + * @returns Error::Ok on success, non-Ok on failure. + */ + ET_NODISCARD Error set_output_data_ptr(void *buffer, size_t size, + size_t output_idx); + + /** + * Copies the method's outputs into the provided array. + * + * WARNING: The output contains shallow copies of internal tensor outputs. + * Please do not mutate returned Tensor elements. + * + * TODO(T139259264): Add checks to detect output mutation, or deep-copy + * outputs. + * + * @param[in] output_evalues The array to copy the outputs into. The first + * `outputs_size()` elements will be set to the corresponding output + * values. The rest of the array will be set to the EValue value None. + * @param[in] length The size of the `output_evalues` array in elements. Must + * be greater than or equal to `outputs_size()`. + * + * @returns Error::Ok on success, non-Ok on failure. + */ + ET_NODISCARD Error get_outputs(EValue *output_evalues, size_t length); + + /** + * Copies the method's inputs into the provided array. + * + * WARNING: The input contains shallow copies of internal tensor inputs. + * Please do not mutate returned Tensor elements. + * + * @param[in] input_evalues The array to copy the inputs into. The first + * `inputs_size()` elements will be set to the corresponding input + * values. The rest of the array will be set to the EValue value None. + * @param[in] length The size of the `input_evalues` array in elements. Must + * be greater than or equal to `inputs_size()`. + * + * @returns Error::Ok on success, non-Ok on failure. + */ + ET_NODISCARD Error get_inputs(EValue *input_evalues, size_t length); + + /** + * Execute the method. + * + * NOTE: Will fail if the method has been partially executed using the + * `step()` api. + * + * @returns Error::Ok on success, non-Ok on failure. + */ + ET_NODISCARD Error execute(); + + /** + * EXPERIMENTAL: Advances/executes a single instruction in the method. + * + * @retval Error::Ok step succeeded + * @retval non-Ok step failed + * @retval Error::EndOfMethod method finished executing successfully + */ + ET_EXPERIMENTAL ET_NODISCARD Error step(); + + /// DEPRECATED: Use `step()` instead. 
+ ET_DEPRECATED ET_NODISCARD Error experimental_step(); + + /** + * EXPERIMENTAL: Resets execution state to the start of the Method. For use + * with the `step()` API. + * + * @retval Error:Ok on success + * @retval Error::InvalidState if called before step-based execution reached + * the end of the Method. This means it is not possible to recover a + * Method that failed mid-execution. + */ + ET_EXPERIMENTAL ET_NODISCARD Error reset_execution(); + + /// DEPRECATED: Use `reset_execution()` instead. + ET_DEPRECATED ET_NODISCARD Error experimental_reset_execution(); + + /** + * Returns the MethodMeta that corresponds to the calling Method. + */ + MethodMeta method_meta() const; + + /** + * Returns the number of inputs the Method expects. + */ + size_t inputs_size() const; + + /** + * Returns the number of outputs the Method returns. + */ + size_t outputs_size() const; + + /** + * Retrieves the output at the specified index. + */ + const EValue &get_output(size_t i) const; + + EventTracer *get_event_tracer(); + + /// DEPRECATED: Use MethodMeta instead to access metadata, and set_input to + /// update Method inputs. + ET_DEPRECATED const EValue &get_input(size_t i) const; + /// DEPRECATED: Use MethodMeta instead to access metadata, and set_input to + /// update Method inputs. + ET_DEPRECATED EValue &mutable_input(size_t i); + /// DEPRECATED: Use MethodMeta instead to access metadata, and get_output to + /// retrieve Method outputs. + ET_DEPRECATED EValue &mutable_output(size_t i); + + ~Method(); + +private: + // Delete other rule-of-five methods. + Method(const Method &) = delete; + Method &operator=(const Method &) noexcept = delete; + Method &operator=(Method &&) = delete; + + // Let Program call load(). + friend class Program; + // Let Executor call the ctor and init(). + friend class Executor; + + enum class InitializationState : uint8_t { + Uninitialized, + Initialized, + InitializationFailed, + }; + + /// Tracks what step in program execution we are on + struct StepState { + size_t chain_idx; + size_t instr_idx; + }; + + Method(const Program *program, MemoryManager *memory_manager, + EventTracer *event_tracer, MemoryAllocator *temp_allocator) + : step_state_(), program_(program), memory_manager_(memory_manager), + temp_allocator_(temp_allocator), serialization_plan_(nullptr), + event_tracer_(event_tracer), n_value_(0), values_(nullptr), + n_delegate_(0), delegates_(nullptr), n_chains_(0), chains_(nullptr), + init_state_(InitializationState::Uninitialized) {} + + /// Static factory used by Program. + ET_NODISCARD static Result + load(executorch_flatbuffer::ExecutionPlan *s_plan, const Program *program, + MemoryManager *memory_manager, EventTracer *event_tracer); + + /** + * Initialize the method from its serialized representation. + * + * @returns Error::Ok on success, non-Ok on failure. + */ + ET_NODISCARD Error init(executorch_flatbuffer::ExecutionPlan *s_plan); + + /// Returns true if the Method was successfully initialized. 
+ inline bool initialized() const { + return init_state_ == InitializationState::Initialized; + } + + const EValue &get_value(size_t i) const; + EValue &mutable_value(size_t i); + size_t get_input_index(size_t i) const; + size_t get_output_index(size_t i) const; + + // Executes a single instruction using the state in step_state_ + ET_NODISCARD Error execute_instruction(); + + StepState step_state_; + const Program *program_; + MemoryManager *memory_manager_; + MemoryAllocator *temp_allocator_; + executorch_flatbuffer::ExecutionPlan *serialization_plan_; + EventTracer *event_tracer_; + + size_t n_value_; + EValue *values_; + + size_t n_delegate_; + BackendDelegate *delegates_; + + size_t n_chains_; + Chain *chains_; + + InitializationState init_state_; + + /** + * Parses the elements of the values_ array. On error, n_value_ will be set to + * the number of successfully-initialized entries so that ~Method doesn't try + * to clean up uninitialized entries. + */ + ET_NODISCARD Error parse_values(); + + ET_NODISCARD Error resolve_operator(int32_t op_index, OpFunction *kernels, + size_t kernel_index, InstructionArgs args, + size_t n_args); + + void log_outputs(); +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::Method; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/executor/method_meta.h b/third-party/include/executorch/runtime/executor/method_meta.h new file mode 100644 index 00000000..7675ea2e --- /dev/null +++ b/third-party/include/executorch/runtime/executor/method_meta.h @@ -0,0 +1,227 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +// Forward declare flatbuffer types. This is a public header and must not +// include the generated flatbuffer header. +namespace executorch_flatbuffer { +struct ExecutionPlan; +} // namespace executorch_flatbuffer + +namespace executorch { +namespace runtime { + +/** + * Metadata about a specific tensor of an ExecuTorch Program. + * + * The program used to create the MethodMeta object that created this + * TensorInfo must outlive this TensorInfo. + */ +class TensorInfo final { +public: + TensorInfo() = delete; + TensorInfo(const TensorInfo &) = default; + TensorInfo(TensorInfo &&) = default; + TensorInfo &operator=(const TensorInfo &) = default; + TensorInfo &operator=(TensorInfo &&other) = default; + ~TensorInfo() = default; + + /** + * Returns the sizes of the tensor. + */ + Span sizes() const; + + /** + * Returns the dim order of the tensor. + */ + Span dim_order() const; + + /** + * Returns the scalar type of the input/output. + */ + executorch::aten::ScalarType scalar_type() const; + + /** + * Returns whether the tensor's memory was planned during export. + */ + bool is_memory_planned() const; + + /** + * Returns the size of the tensor in bytes. + */ + size_t nbytes() const; + +private: + // Let MethodMeta create TensorInfo. + friend class MethodMeta; + + TensorInfo(Span sizes, Span dim_order, + executorch::aten::ScalarType scalar_type, + const bool is_memory_planned); + + /** + * The sizes of the tensor. 
+ * + * NOTE: References data from the Program, so the Program must outlive the + * TensorInfo. + */ + Span sizes_; + + /** + * The dim order of the tensor. + * + * NOTE: References data from the Program, so the Program must outlive the + * TensorInfo. + */ + Span dim_order_; + + /// The scalar type of the tensor. + executorch::aten::ScalarType scalar_type_; + + /// Whether the tensor's memory was planned during export. + bool is_memory_planned_; + + /// The size in bytes of the tensor. + size_t nbytes_; +}; + +/** + * Describes a a method in an ExecuTorch program. + * + * The program used to create a MethodMeta object must outlive the MethodMeta. + * It is separate from Method so that this information can be accessed without + * paying the initialization cost of loading the full Method. + */ +class MethodMeta final { +public: + MethodMeta() = delete; + MethodMeta(const MethodMeta &) = default; + MethodMeta(MethodMeta &&) = default; + MethodMeta &operator=(const MethodMeta &) = default; + MethodMeta &operator=(MethodMeta &&other) = default; + ~MethodMeta() = default; + + /** + * Get the name of this method. + * + * @returns The method name. + */ + const char *name() const; + + /** + * Get the number of inputs to this method. + * + * @returns The number of inputs. + */ + size_t num_inputs() const; + + /** + * Get the tag of the specified input. + * + * @param[in] index The index of the input to look up. + * @returns The tag of input, can only be [Tensor, Int, Bool, Double, String]. + */ + Result input_tag(size_t index) const; + + /** + * Get metadata about the specified input. + * + * @param[in] index The index of the input to look up. + * @returns The metadata on success, or an error on failure. Only valid for + * tag::Tensor + */ + Result input_tensor_meta(size_t index) const; + + /** + * Get the number of outputs to this method. + * + * @returns The number of outputs. + */ + size_t num_outputs() const; + + /** + * Get the tag of the specified output. + * + * @param[in] index The index of the output to look up. + * @returns The tag of output, can only be [Tensor, Int, Bool, Double, + * String]. + */ + Result output_tag(size_t index) const; + + /** + * Get metadata about the specified output. + * + * @param[in] index The index of the output to look up. + * @returns The metadata on success, or an error on failure. Only valid for + * tag::Tensor + */ + Result output_tensor_meta(size_t index) const; + + /** + * Get the number of memory-planned buffers this method requires. + * + * @returns The number of memory-planned buffers. + */ + size_t num_memory_planned_buffers() const; + + /** + * Get the size in bytes of the specified memory-planned buffer. + * + * @param[in] index The index of the buffer to look up. + * @returns The size in bytes on success, or an error on failure. + */ + Result memory_planned_buffer_size(size_t index) const; + + /** + * Get the number of instructions in this method. + * + * @returns The number of instructions. + */ + ET_EXPERIMENTAL size_t num_instructions() const; + + /** + * DEPRECATED: Use num_memory_planned_buffers() instead. + */ + ET_DEPRECATED size_t num_non_const_buffers() const { + return num_memory_planned_buffers(); + } + + /** + * DEPRECATED: Use memory_planned_buffer_size() instead. + */ + Result non_const_buffer_size(size_t index) const { + return memory_planned_buffer_size(index); + } + +private: + // Let Program create MethodMeta. 
+ friend class Program; + + explicit MethodMeta(const executorch_flatbuffer::ExecutionPlan *s_plan); + + /// Source of truth for method information + const executorch_flatbuffer::ExecutionPlan *s_plan_; +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::MethodMeta; +using ::executorch::runtime::TensorInfo; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/executor/program.h b/third-party/include/executorch/runtime/executor/program.h new file mode 100644 index 00000000..3c4e30ef --- /dev/null +++ b/third-party/include/executorch/runtime/executor/program.h @@ -0,0 +1,294 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Forward declare flatbuffer types. This is a public header and must not +// include the generated flatbuffer header. +namespace executorch_flatbuffer { +struct Program; +} // namespace executorch_flatbuffer + +namespace executorch { +namespace runtime { + +namespace testing { +// Provides test access to private Program methods. +class ProgramTestFriend; +} // namespace testing + +namespace deserialization { +// Provides Tensor deserializaiton access to private Program methods. +class TensorParser; +} // namespace deserialization + +/** + * A deserialized ExecuTorch program binary. + */ +class Program final { +public: + /** + * Types of validation that the Program can do before parsing the data. + */ + enum class Verification : uint8_t { + /** + * Do minimal verification of the data, ensuring that the header appears + * correct. + * + * Has minimal runtime overhead. + */ + Minimal, + /** + * Do full verification of the data, ensuring that internal pointers are + * self-consistent and that the data has not been truncated or obviously + * corrupted. May not catch all types of corruption, but should guard + * against illegal memory operations during parsing. + * + * Will have higher runtime overhead, scaling with the complexity of the + * proram data. + */ + InternalConsistency, + }; + + /** + * Loads a Program from the provided loader. The Program will hold a pointer + * to the loader, which must outlive the returned Program instance. + * + * @param[in] loader The source to load program data from. The Program will + * hold a pointer to this loader, which must outlive the returned Program + * instance. + * @param[in] verification The type of verification to do before returning + * success. + */ + ET_NODISCARD static Result + load(DataLoader *loader, Verification verification = Verification::Minimal); + + /// DEPRECATED: Use the lowercase `load()` instead. + ET_DEPRECATED ET_NODISCARD static Result + Load(DataLoader *loader, Verification verification = Verification::Minimal) { + return load(loader, verification); + } + + // Movable, to be compatible with Result. + Program(Program &&) noexcept = default; + ~Program() = default; + + /** + * Get the constant buffer inside Program with index buffer_idx. + * @param[in] buffer_idx the index of the buffer in the constant_buffer. 
+ * @param[in] nbytes the number of bytes to read from the buffer. + * @return The buffer with corresponding index. + */ + Result get_constant_buffer_data(size_t buffer_idx, + size_t nbytes) const; + + /** + * Returns the number of methods in the program. + */ + size_t num_methods() const; + + /** + * Returns the name of the method at particular index. + * + * @param[in] method_index The index of the method name to retrieve. Must be + * less than the value returned by `num_methods()`. + * + * @returns The name of the requested method. The pointer is owned by the + * Program, and has the same lifetime as the Program. + */ + Result get_method_name(size_t method_index) const; + + /** + * Loads the named method and prepares it for execution. + * + * @param[in] method_name The name of the method to load. + * @param[in] memory_manager The allocators to use during initialization and + * execution of the loaded method. If `memory_manager.temp_allocator()` is + * null, the runtime will allocate temp memory using `et_pal_allocate()`. + * @param[in] event_tracer The event tracer to use for this method run. + * + * @returns The loaded method on success, or an error on failure. + */ + Result load_method(const char *method_name, + MemoryManager *memory_manager, + EventTracer *event_tracer = nullptr) const; + + /** + * Gathers metadata for the named method. + * + * @param[in] method_name The name of the method to get metadata for. + */ + Result method_meta(const char *method_name) const; + + /** + * DEPRECATED: Get the pytree encoding string for the output. Deprecated as + * this functionality will eventually move out of the core program into a + * higher level structure, but that does not exist at this time. + * @param[in] method_name The name of the method to get the encoding for. + * + * @return The pytree encoding string for the output + */ + ET_DEPRECATED Result + get_output_flattening_encoding(const char *method_name = "forward") const; + + /** + * Describes the presence of an ExecuTorch program header. + */ + enum HeaderStatus { + /** + * An ExecuTorch program header is present, and its version is compatible + * with this version of the runtime. + */ + CompatibleVersion, + + /** + * An ExecuTorch program header is present, but its version is not + * compatible with this version of the runtime. + */ + IncompatibleVersion, + + /** + * An ExecuTorch program header is not present. + */ + NotPresent, + + /** + * The data provided was too short to find the program header. + */ + ShortData, + }; + + /** + * The minimum number of bytes necessary for calls to `check_header`. + */ + static constexpr size_t kMinHeadBytes = 64; + + /** + * Looks for an ExecuTorch program header in the provided data. + * + * @param[in] data The data from the beginning of a file that might contain + * an ExecuTorch program. + * @param[in] size The size of `data` in bytes. Must be >= `kMinHeadBytes`. + * + * @returns A value describing the presence of a header in the data. + */ + static HeaderStatus check_header(const void *data, size_t size); + +private: + // Let some classes call these private methods. + friend class BackendDelegate; + friend class Executor; + friend class Method; + friend class deserialization::TensorParser; + friend class testing::ProgramTestFriend; + + const executorch_flatbuffer::Program *get_internal_program() const { + return internal_program_; + } + + // Used by Method to look up entries in the delegate data table. 
+ Error get_backend_delegate_data(size_t index, const void **out_data, + size_t *out_size) const; + + /** + * Loads a segment by index. + * + * @param[in] segment_info Struct containing an index to load from the + * Program.segments list. The other fields of the struct, such as + * `segment_type` and `descriptor`, need to also be correct. + * + * @returns The data as a FreeableBuffer, if the index is valid. + * @retval Error::NotFound The program does not contain any segments or the + * index is out of range. + * @returns Other errors depending on the implementation of + * DataLoader: The Program.segment table is inconsistent, or the + * data cannot be accessed. + */ + ET_NODISCARD Result + LoadSegment(const DataLoader::SegmentInfo &segment_info) const; + + /** + * Loads a portion of a mutable segment into the provided buffer. + * + * @param[in] mutable_data_segments_index The index into the + * mutable_data_segments_array. + * @param[in] offset_index The index into the segment's offsets array. + * @param[in] size The number of bytes to load. + * @param[in] buffer The buffer to load data into. Must point to at least + * `size` bytes of memory. + * + * @returns An error code on if the load was successful. + * @retval Error::Ok The load was successful. + * @retval Error::NotFound The program does not contain any segments or the + * indices are out of range. + * @returns Other errors depending on the implementation of + * DataLoader: The Program.segment table is inconsistent, or the + * data cannot be accessed. + */ + ET_NODISCARD Error load_mutable_subsegment_into( + size_t mutable_data_segments_index, size_t offset_index, size_t size, + void *buffer) const; + +private: + Program(DataLoader *loader, size_t segment_base_offset, + FreeableBuffer &&program_data, + const executorch_flatbuffer::Program *internal_program, + FreeableBuffer &&constant_segment_data) + : program_data_(std::move(program_data)), + // Don't need the loader if there are no segments. + loader_(segment_base_offset > 0 ? loader : nullptr), + internal_program_(internal_program), + segment_base_offset_(segment_base_offset), + constant_segment_data_(std::move(constant_segment_data)) {} + + // Not copyable or assignable. + Program(const Program &rhs) = delete; + Program &operator=(Program &&rhs) noexcept = delete; + Program &operator=(const Program &rhs) = delete; + + /// The serialized program data. Tensors will point directly into this buffer. + FreeableBuffer program_data_; + + /// Used to load segment data. Null if there are no segments. + DataLoader *loader_; + + /// The flatbuffer representation of the program. Must not be exposed to + /// users. + const executorch_flatbuffer::Program *internal_program_; + + /// The offset to the first segment, in bytes. If zero, no segments should + /// be present in internal_program_. + size_t segment_base_offset_; + + /// Constant segment data. + FreeableBuffer constant_segment_data_; +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. 
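+// End-to-end sketch (illustrative only; assumes an existing DataLoader
+// implementation `loader` and the allocator setup shown in memory_manager.h;
+// error checks on each Result are elided):
+//
+//   using namespace executorch::runtime;
+//
+//   Result<Program> program = Program::load(&loader);
+//   Result<MethodMeta> meta = program->method_meta("forward");
+//   // Size the planned-memory spans from meta->num_memory_planned_buffers()
+//   // and meta->memory_planned_buffer_size(i), then assemble `memory_manager`.
+//   Result<Method> method = program->load_method("forward", &memory_manager);
+//   Error status = method->execute();
+//   if (status == Error::Ok) {
+//     const EValue &output = method->get_output(0);
+//     // ... consume the output without mutating it.
+//   }
+//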
+using ::executorch::runtime::Program; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/executor/tensor_parser.h b/third-party/include/executorch/runtime/executor/tensor_parser.h new file mode 100644 index 00000000..a256109c --- /dev/null +++ b/third-party/include/executorch/runtime/executor/tensor_parser.h @@ -0,0 +1,116 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace executorch { +namespace runtime { +namespace deserialization { + +ET_NODISCARD Result +parseTensor(const Program *program, MemoryManager *memory_manager, + const executorch_flatbuffer::Tensor *s_tensor); + +ET_NODISCARD Result> +parseTensorList(const flatbuffers::Vector *tensor_indices, + EValue *values_, MemoryManager *memory_manager); + +// Deserializes a List of optional type. The code here is the same between all +// list of optionals: list of optional Tensor, list of optional float etc, so we +// just use a template to avoid boilerplate. +template +ET_NODISCARD Result>> +parseListOptionalType(const flatbuffers::Vector *value_indices, + EValue *values_, MemoryManager *memory_manager) { + auto *evalp_list = memory_manager->method_allocator()->allocateList( + value_indices->size()); + if (evalp_list == nullptr) { + return Error::MemoryAllocationFailed; + } + + auto *optional_tensor_list = + memory_manager->method_allocator() + ->allocateList>(value_indices->size()); + if (optional_tensor_list == nullptr) { + return Error::MemoryAllocationFailed; + } + + size_t output_idx = 0; + // For each index look up the corresponding EValue (which has been + // already allocated) and stick it in the list. + for (int32_t index : *value_indices) { + // Lists of objects are stored in fbb as list[int] where the ints are + // indices into values_. Currently serialization is deciding if they want to + // put -1 for serialized None type indices, or give us a valid index to a + // serialized None. We support either for now. + // Placement new as the list elements are not initialized, so calling + // copy assignment is not defined if its non trivial. + if (index == -1) { + new (&optional_tensor_list[output_idx]) + executorch::aten::optional(executorch::aten::nullopt); + // no value to point to. BoxedEvalueList for optional tensor will convert + // this to nullopt. + // TODO(T161156879): do something less hacky here. + evalp_list[output_idx] = nullptr; + } else { + new (&optional_tensor_list[output_idx]) + executorch::aten::optional(values_[index].toOptional()); + evalp_list[output_idx] = &values_[static_cast(index)]; + } + output_idx++; + } + return BoxedEvalueList>( + evalp_list, optional_tensor_list, value_indices->size()); +} + +/** + * Returns the appropriate data pointer for `s_tensor`. + * + * Overall, a Tensor is either constant or non-constant, except we differentiate + * 2 special variants of non-constant Tensor ("input" and control-flow + * "placeholder") as a special optimization to avoid holding unnecessary + * AllocationDetails. Thus, s_tensor can be configured as 1 of 3 options: + * - constant_buffer > 0, allocation_info = Null: Constant Tensor. + * - constant_buffer = 0, allocation_info = Non Null: Non-constant Tensor. + * - constant_buffer = 0, allocation_info = Null: Input/placeholder Tensor. 
+ * + * @param[in] s_tensor The tensor to find the data pointer for. + * @param[in] program The Program to use for constant buffer data. + * @param[in] nbytes The amount of memory to get from the allocator. + * @param[in] allocator The source of memory for non-constant tensors. + * + * @returns On success, the data pointer to use for the tensor. On failure, a + * non-Ok Error. + */ +ET_NODISCARD Result +getTensorDataPtr(const executorch_flatbuffer::Tensor *s_tensor, + const Program *program, size_t nbytes, + HierarchicalAllocator *allocator); + +} // namespace deserialization +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +namespace deserialization { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::deserialization::getTensorDataPtr; +using ::executorch::runtime::deserialization::parseListOptionalType; +using ::executorch::runtime::deserialization::parseTensor; +using ::executorch::runtime::deserialization::parseTensorList; +} // namespace deserialization +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/kernel/kernel_runtime_context.h b/third-party/include/executorch/runtime/kernel/kernel_runtime_context.h new file mode 100644 index 00000000..e6367da9 --- /dev/null +++ b/third-party/include/executorch/runtime/kernel/kernel_runtime_context.h @@ -0,0 +1,122 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace executorch { +namespace runtime { + +/** + * Runtime state and functionality for kernel implementations. + * + * NOTE: Will not be passed to operators if running in ATen mode as those + * operators do not expect to receive a KernelRuntimeContext argument. + */ +class KernelRuntimeContext { +public: + /** + * Construct a new kernel runtime context. + * + * KernelRuntimeContext does not take ownership + * of these pointers, so they must outlive the context instance. + * + * @param[in] event_tracer The optional EventTracer to use for + * profiling/debugging + * @param[in] temp_allocator The optional MemoryAllocator used to allocate + * temporary memory for the kernel. If not provided, an error will be + * returned when calling allocate_temp. + */ + KernelRuntimeContext(EventTracer *event_tracer = nullptr, + MemoryAllocator *temp_allocator = nullptr) + : event_tracer_(event_tracer), temp_allocator_(temp_allocator) {} + /** + * Tells the runtime that the kernel call has failed. Prefer this over + * ET_CHECK_*(), which fatally panics the process/system. + * + * If this is not called, the runtime will treat the kernel call as + * successful. + * + * This unusual error-propagation path is required because kernel signatures + * do not have a natural way to return errors directly. They are generally + * compatible with core PyTorch ATen kernel signatures, which use exceptions + * to report errors. But, ExecuTorch does not use exceptions. + */ + void fail(Error error) { failure_state_ = error; } + + /// Returns the current failure state. + ET_NODISCARD Error failure_state() const { return failure_state_; } + + /** + * INTERNAL ONLY + * + * Returns a pointer to an instance of EventTracer to do profiling/debugging + * logging inside the codegen layer. 
This is only for internal usage inside + * the codegen layer and users should not be accessing this. + */ + EventTracer *internal_event_tracer() { return event_tracer_; } + + /** + * Allocates temporary memory that will be freed when the kernel returns. This + * returns a pointer to the allocated memory or an error if the allocation + * fails. + * + * @param[in] size Number of bytes to allocate. + * @param[in] alignment Minimum alignment for the returned pointer. Must be a + * power of 2. + * + * @returns A result object containing either a pointer to the allocated + * memory or an error to indicate failure + */ + Result + allocate_temp(size_t size, + size_t alignment = MemoryAllocator::kDefaultAlignment) { + ET_CHECK_OR_RETURN_ERROR(temp_allocator_ != nullptr, NotFound, + "No temp allocator provided"); + void *temp_memory = temp_allocator_->allocate(size, alignment); + ET_CHECK_OR_RETURN_ERROR( + temp_memory != nullptr, MemoryAllocationFailed, + "Failed to allocate temp memory. Bytes requested: %zu", size); + return temp_memory; + } + + // TODO(T147221312): Add a way to resize a tensor. + +private: + EventTracer *event_tracer_ = nullptr; + MemoryAllocator *temp_allocator_ = nullptr; + Error failure_state_ = Error::Ok; +}; + +} // namespace runtime +} // namespace executorch + +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +namespace torch { +namespace executor { +/// DEPRECATED: Use ::executorch::runtime::KernelRuntimeContext instead. +using ::executorch::runtime::KernelRuntimeContext; +/// DEPRECATED: Use ::executorch::runtime::KernelRuntimeContext instead. +using RuntimeContext = ::executorch::runtime::KernelRuntimeContext; +} // namespace executor +} // namespace torch +namespace executorch { +namespace aten { +/// DEPRECATED: Use ::executorch::runtime::KernelRuntimeContext instead. +using RuntimeContext = ::executorch::runtime::KernelRuntimeContext; +} // namespace aten +} // namespace executorch +// DEPRECATED: The exec_aten:: namespace is deprecated. Use executorch::aten:: +// instead. +namespace exec_aten = ::executorch::aten; diff --git a/third-party/include/executorch/runtime/kernel/operator_registry.h b/third-party/include/executorch/runtime/kernel/operator_registry.h new file mode 100644 index 00000000..095a6742 --- /dev/null +++ b/third-party/include/executorch/runtime/kernel/operator_registry.h @@ -0,0 +1,260 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +// Debug switch for operator registry +#if defined(ET_OP_REGISTRY_DEBUG) +#include +#endif + +#define ET_LOG_KERNEL_KEY(k) \ + ET_LOG(Error, "key: %s, is_fallback: %s", k.data(), \ + k.is_fallback() ? "true" : "false"); +#define ET_LOG_TENSOR_META(meta_list) \ + for (const auto &meta : meta_list) { \ + ET_LOG(Error, "dtype: %d | dim order: [", int(meta.dtype_)); \ + for (int i = 0; i < meta.dim_order_.size(); i++) { \ + ET_LOG(Error, "%d,", static_cast(meta.dim_order_[i])); \ + } \ + ET_LOG(Error, "]"); \ + } + +namespace executorch { +namespace runtime { + +class KernelRuntimeContext; // Forward declaration +using OpFunction = void (*)(KernelRuntimeContext &, EValue **); + +/** + * Dtype and dim order metadata for a Tensor argument to an operator. 
+ * Used by the Executor to hold the tensor metadata info and retrieve kernel. + */ +struct TensorMeta { + executorch::aten::ScalarType dtype_; + Span dim_order_; + + TensorMeta() = default; + TensorMeta(executorch::aten::ScalarType dtype, + Span order) + : dtype_(dtype), dim_order_(order) {} + + bool operator==(const TensorMeta &other) const { return this->equals(other); } + + bool operator!=(const TensorMeta &other) const { + return !this->equals(other); + } + + bool equals(const TensorMeta &other) const { + if (dtype_ != other.dtype_) { + return false; + } + if (dim_order_.size() != other.dim_order_.size()) { + return false; + } + for (int i = 0; i < dim_order_.size(); i++) { + if (dim_order_[i] != other.dim_order_[i]) { + return false; + } + } + return true; + } + +#if defined(ET_OP_REGISTRY_DEBUG) + friend std::ostream &operator<<(std::ostream &os, const TensorMeta &meta) { + os << "dtype: " << int(meta.dtype_) << " | dim order: ["; + for (int i = 0; i < meta.dim_order_.size(); i++) { + os << static_cast(meta.dim_order_[i]) << ", "; + } + os << "]"; + return os; + } +#endif +}; + +/** + * Describes which dtype & dim order specialized kernel to be bound to an + * operator. If `is_fallback_` is true, it means this kernel can be used as a + * fallback, if false, it means this kernel can only be used if all the + * `TensorMeta` are matched. Fallback means this kernel will be used for + * all input tensor dtypes and dim orders, if the specialized kernel is not + * registered. + * + * The format of a kernel key data is a string: + * "v/|..." + * Size: Up to 691 1 1 1 (42 +1) * 16 + * Assuming max number of tensors is 16 ^ + * Kernel key version is v1 for now. If the kernel key format changes, + * update the version to avoid breaking pre-existing kernel keys. + * Example: v1/7;0,1,2,3 + * The kernel key has only one tensor: a double tensor with dimension 0, 1, 2, 3 + * + * Each tensor_meta has the following format: ";" + * Size: Up to 42 1-2 1 24 (1 byte for 0-9; 2 + * for 10-15) + 15 commas Assuming that the max number of dims is 16 ^ Example: + * 7;0,1,2,3 for [double; 0, 1, 2, 3] + * + * IMPORTANT: + * Users should not construct a kernel key manually. Instead, it should be + * generated from kernel yaml. + */ +struct KernelKey { +public: + KernelKey() : is_fallback_(true) {} + + /* implicit */ KernelKey(const char *kernel_key_data) + : kernel_key_data_(kernel_key_data), is_fallback_(false) {} + + constexpr static int MAX_SIZE = 691; + + bool operator==(const KernelKey &other) const { return this->equals(other); } + + bool operator!=(const KernelKey &other) const { return !this->equals(other); } + + bool equals(const KernelKey &other) const { + if (is_fallback_ != other.is_fallback_) { + return false; + } + if (is_fallback_) { + return true; + } + return strncmp(kernel_key_data_, other.kernel_key_data_, MAX_SIZE) == 0; + } + + bool is_fallback() const { return is_fallback_; } + + const char *data() const { return kernel_key_data_; } + +#if defined(ET_OP_REGISTRY_DEBUG) + friend std::ostream &operator<<(std::ostream &os, const KernelKey &key) { + os << key.kernel_key_data_ << std::endl; + return os; + } +#endif + +private: + const char *kernel_key_data_ = nullptr; + bool is_fallback_; +}; + +/** + * Struct that bundles a kernel key, a function and an op name together. An + * `Operator` may have more than one `Kernel` (maximum kMaxNumOfKernelPerOp) and + * they should have the same op name and different kernel key. A "fallback" + * kernel may or may not live in an `Operator`. 
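+ *
+ * Registration sketch (illustrative; the operator name and kernel body are
+ * hypothetical):
+ *
+ *   void my_add_out(KernelRuntimeContext &ctx, EValue **args) {
+ *     // Validate args, compute into the out tensor, and report problems via
+ *     // ctx.fail(Error::InvalidArgument) instead of aborting.
+ *   }
+ *   // Returning a non-void value lets registration run at static init time.
+ *   static auto res = register_kernel(Kernel("my_ns::add.out", my_add_out));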
+ */ +struct Kernel { + const char *name_; + // String representation of kernel key, with the same format as + // KernelKey.to_string_representation() + // Data is not owned by the Kernel struct. + KernelKey kernel_key_; + OpFunction op_; + /** + * We are doing a copy of the string pointer instead of duplicating the string + * itself, we require the lifetime of the operator name to be at least as long + * as the operator registry. + */ + explicit Kernel(const char *name, OpFunction func) : name_(name), op_(func) {} + + explicit Kernel(const char *name, KernelKey key, OpFunction func) + : name_(name), kernel_key_(key), op_(func) {} + + Kernel() {} +}; + +namespace internal { +void make_kernel_key_string(Span key, char *buf); +} // namespace internal + +/** + * Checks whether an operator exists with a given name and TensorMeta list. When + * TensorMeta is empty, it means this op does not have specialized kernels, so + * it checks whether it has any fallback kernels. + */ +bool registry_has_op_function(const char *name, + Span meta_list = {}); + +/** + * Returns the operator with a given name and TensorMeta list, if present. + */ +::executorch::runtime::Result +get_op_function_from_registry(const char *name, + Span meta_list = {}); + +/** + * Returns all registered kernels. + */ +Span get_registered_kernels(); + +/** + * Registers the provided kernels. + * + * @param[in] kernels Kernel objects to register. + * @retval Error::Ok always. Panics on error. This function needs to return a + * non-void type to run at static initialization time. + */ +ET_NODISCARD Error register_kernels(const Span); + +/** + * Registers a single kernel. + * + * @param[in] kernel Kernel object to register. + * @retval Error::Ok always. Panics on error. This function needs to return a + * non-void type to run at static initialization time. + */ +ET_NODISCARD inline Error register_kernel(const Kernel &kernel) { + return register_kernels({&kernel, 1}); +}; + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::Kernel; +using ::executorch::runtime::KernelKey; +using ::executorch::runtime::KernelRuntimeContext; +using ::executorch::runtime::OpFunction; +using ::executorch::runtime::TensorMeta; +using KernelRuntimeContext = ::executorch::runtime::KernelRuntimeContext; + +inline ::executorch::runtime::Error register_kernels(ArrayRef kernels) { + return ::executorch::runtime::register_kernels( + {kernels.data(), kernels.size()}); +} +inline OpFunction getOpsFn(const char *name, + ArrayRef meta_list = {}) { + auto result = ::executorch::runtime::get_op_function_from_registry( + name, {meta_list.data(), meta_list.size()}); + ET_CHECK(result.ok()); // get_op_function_from_registry() logs details. 
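+  // ET_CHECK aborts the runtime (see platform/assert.h) if no kernel matching
+  // `name` and `meta_list` is registered, so on a missing kernel this legacy
+  // helper aborts rather than returning an error.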
+ return *result; +} +inline bool hasOpsFn(const char *name, ArrayRef meta_list = {}) { + return ::executorch::runtime::registry_has_op_function( + name, {meta_list.data(), meta_list.size()}); +} +inline ArrayRef get_kernels() { + Span kernels = ::executorch::runtime::get_registered_kernels(); + return ArrayRef(kernels.data(), kernels.size()); +} +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/platform/abort.h b/third-party/include/executorch/runtime/platform/abort.h new file mode 100644 index 00000000..ae1a761a --- /dev/null +++ b/third-party/include/executorch/runtime/platform/abort.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * @file + * ExecuTorch global abort wrapper function. + */ + +#pragma once + +#include + +namespace executorch { +namespace runtime { + +/** + * Trigger the ExecuTorch global runtime to immediately exit without cleaning + * up, and set an abnormal exit status (platform-defined). + */ +ET_NORETURN void runtime_abort(); + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::runtime_abort; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/platform/assert.h b/third-party/include/executorch/runtime/platform/assert.h new file mode 100644 index 00000000..14ec2706 --- /dev/null +++ b/third-party/include/executorch/runtime/platform/assert.h @@ -0,0 +1,119 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +/** + * Assertion failure message emit method. + * + * @param[in] _format Printf-style message format string. + * @param[in] ... Format string arguments. + */ +#define ET_ASSERT_MESSAGE_EMIT(_format, ...) \ + ET_LOG(Fatal, "In function %s(), assert failed" _format, ET_FUNCTION, \ + ##__VA_ARGS__) + +/** + * Abort the runtime if the condition is not true. + * This check will be performed even in release builds. + * + * @param[in] _cond Condition asserted as true. + * @param[in] _format Printf-style message format string. + * @param[in] ... Format string arguments. + */ +#define ET_CHECK_MSG(_cond, _format, ...) \ + do { \ + if ET_UNLIKELY (!(_cond)) { \ + ET_ASSERT_MESSAGE_EMIT(" (%s): " _format, #_cond, ##__VA_ARGS__); \ + ::executorch::runtime::runtime_abort(); \ + } \ + } while (0) + +/** + * Abort the runtime if the condition is not true. + * This check will be performed even in release builds. + * + * @param[in] _cond Condition asserted as true. + */ +#define ET_CHECK(_cond) \ + do { \ + if ET_UNLIKELY (!(_cond)) { \ + ET_ASSERT_MESSAGE_EMIT(": %s", #_cond); \ + ::executorch::runtime::runtime_abort(); \ + } \ + } while (0) + +#ifdef NDEBUG + +/** + * Abort the runtime if the condition is not true. + * This check will be performed in debug builds, but not release builds. + * + * @param[in] _cond Condition asserted as true. + * @param[in] _format Printf-style message format string. + * @param[in] ... Format string arguments. 
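+ *
+ * Example (illustrative): `ET_DCHECK_MSG(idx < len, "idx %zu out of range",
+ * idx);` compiles to a no-op here (NDEBUG builds), whereas ET_CHECK_MSG with
+ * the same arguments would still evaluate the condition and abort on failure
+ * even in release builds.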
+ */ +#define ET_DCHECK_MSG(_cond, _format, ...) ((void)0) + +/** + * Abort the runtime if the condition is not true. + * This check will be performed in debug builds, but not release builds. + * + * @param[in] _cond Condition asserted as true. + */ +#define ET_DCHECK(_cond) ((void)0) +#define ET_DEBUG_ONLY [[maybe_unused]] + +#else // NDEBUG + +/** + * Abort the runtime if the condition is not true. + * This check will be performed in debug builds, but not release builds. + * + * @param[in] _cond Condition asserted as true. + * @param[in] _format Printf-style message format string. + * @param[in] ... Format string arguments. + */ +#define ET_DCHECK_MSG(_cond, _format, ...) \ + ET_CHECK_MSG(_cond, _format, ##__VA_ARGS__) + +/** + * Abort the runtime if the condition is not true. + * This check will be performed in debug builds, but not release builds. + * + * @param[in] _cond Condition asserted as true. + */ +#define ET_DCHECK(_cond) ET_CHECK(_cond) +#define ET_DEBUG_ONLY + +#endif // NDEBUG + +/** + * Assert that this code location is unreachable during execution. + */ +#define ET_ASSERT_UNREACHABLE() \ + do { \ + ET_CHECK_MSG(false, "Execution should not reach this point"); \ + ET_UNREACHABLE(); \ + } while (0) + +/** + * Assert that this code location is unreachable during execution. + * + * @param[in] _message Message on how to avoid this assertion error. + */ +#define ET_ASSERT_UNREACHABLE_MSG(_format, ...) \ + do { \ + ET_CHECK_MSG(false, "Execution should not reach this point. " _format, \ + ##__VA_ARGS__); \ + ET_UNREACHABLE(); \ + } while (0) diff --git a/third-party/include/executorch/runtime/platform/clock.h b/third-party/include/executorch/runtime/platform/clock.h new file mode 100644 index 00000000..36c25d90 --- /dev/null +++ b/third-party/include/executorch/runtime/platform/clock.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * @file + * Clock and timing related methods. + */ + +#pragma once + +#include + +namespace executorch { +namespace runtime { + +/** + * Convert an interval from units of system ticks to nanoseconds. + * The conversion ratio is platform-dependent, and thus depends on + * the platform implementation of et_pal_ticks_to_ns_multiplier(). + * + * @param[in] ticks The interval length in system ticks. + * @retval The interval length in nanoseconds. + */ +inline uint64_t ticks_to_ns(et_timestamp_t ticks) { + et_tick_ratio_t ratio = et_pal_ticks_to_ns_multiplier(); + return static_cast(ticks) * ratio.numerator / ratio.denominator; +} + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::ticks_to_ns; +} // namespace executor +} // namespace torch diff --git a/third-party/include/executorch/runtime/platform/compiler.h b/third-party/include/executorch/runtime/platform/compiler.h new file mode 100644 index 00000000..6e98906b --- /dev/null +++ b/third-party/include/executorch/runtime/platform/compiler.h @@ -0,0 +1,184 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +/** + * @file + * Compiler utility macros. + */ + +#pragma once + +/* + * Compiler support checks. Follows the logic used by pytorch/c10/util/C++17.h + * but may support older versions. + */ + +// https://gcc.gnu.org/projects/cxx-status.html#cxx17 +#if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && \ + __GNUC__ < 7 +#error \ + "You're trying to build ExecuTorch with a too old version of GCC. We need GCC 7 or later." +#endif + +// https://clang.llvm.org/cxx_status.html#cxx17 +#if defined(__clang__) && __clang_major__ < 5 +#error \ + "You're trying to build ExecuTorch with a too old version of Clang. We need Clang 5 or later." +#endif + +#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201703L)) || \ + (!defined(_MSC_VER) && __cplusplus < 201703L) +#error "You need C++17 to compile ExecuTorch" +#endif + +#if defined(_MSC_VER) && (defined(min) || defined(max)) +#error \ + "Macro clash with min and max -- define NOMINMAX when compiling your program on Windows" +#endif + +/* + * Define annotations aliasing C++ declaration attributes. + * See all C++ declaration attributes here: + * https://en.cppreference.com/w/cpp/language/attributes + * + * Note that ExecuTorch supports a lower C++ standard version than all standard + * attributes. Therefore, some annotations are defined using their Clang/GNU + * counterparts. + * + * GNU attribute definitions: + * https://gcc.gnu.org/onlinedocs/gcc/Common-Function-Attributes.html + */ + +#define ET_NORETURN [[noreturn]] +#define ET_NOINLINE __attribute__((noinline)) +#define ET_INLINE __attribute__((always_inline)) inline +#define ET_INLINE_ATTRIBUTE __attribute__((always_inline)) + +#if defined(__GNUC__) + +#define ET_UNREACHABLE() __builtin_unreachable() + +#elif defined(_MSC_VER) + +#define ET_UNREACHABLE() __assume(0) + +#else // defined(__GNUC__) + +#define ET_UNREACHABLE() \ + while (1) \ + ; + +#endif // defined(__GNUC__) + +#define ET_DEPRECATED [[deprecated]] +#define ET_EXPERIMENTAL \ + [[deprecated("This API is experimental and may change without notice.")]] +#define ET_FALLTHROUGH [[fallthrough]] +#define ET_NODISCARD [[nodiscard]] +#define ET_UNUSED [[maybe_unused]] + +// UNLIKELY Macro +// example +// if ET_UNLIKELY(a > 10 && b < 5) { +// do something +// } +#if (__cplusplus) >= 202002L + +#define ET_LIKELY(expr) (expr) [[likely]] +#define ET_UNLIKELY(expr) (expr) [[unlikely]] + +#else + +#define ET_LIKELY(expr) (expr) +#define ET_UNLIKELY(expr) (expr) + +#endif // (__cplusplus) >= 202002L + +/// Define a C symbol with weak linkage. +#ifdef _MSC_VER +// There currently doesn't seem to be a great way to do this in Windows and +// given that weak linkage is not really critical on Windows, we'll just leave +// it as a stub. +#define ET_WEAK +#else +#define ET_WEAK __attribute__((weak)) +#endif + +/** + * Annotation marking a function as printf-like, providing compiler support + * for format string argument checking. + */ +#ifdef _MSC_VER +#include +#define ET_PRINTFLIKE(_string_index, _va_index) _Printf_format_string_ +#else +#define ET_PRINTFLIKE(_string_index, _va_index) \ + __attribute__((format(printf, _string_index, _va_index))) +#endif + +#ifndef __has_builtin +#define __has_builtin(x) (0) +#endif + +#if __has_builtin(__builtin_strrchr) +/// Name of the source file without a directory string. +#define ET_SHORT_FILENAME (__builtin_strrchr("/" __FILE__, '/') + 1) +#else +#define ET_SHORT_FILENAME __FILE__ +#endif + +#if __has_builtin(__builtin_LINE) +/// Current line as an integer. 
+#define ET_LINE __builtin_LINE() +#else +#define ET_LINE __LINE__ +#endif // __has_builtin(__builtin_LINE) + +#if __has_builtin(__builtin_FUNCTION) +/// Name of the current function as a const char[]. +#define ET_FUNCTION __builtin_FUNCTION() +#else +#define ET_FUNCTION __FUNCTION__ +#endif // __has_builtin(__builtin_FUNCTION) + +// Whether the compiler supports GNU statement expressions. +// https://gcc.gnu.org/onlinedocs/gcc/Statement-Exprs.html +#ifndef ET_HAVE_GNU_STATEMENT_EXPRESSIONS +#if (defined(__GNUC__) && __GNUC__ >= 3) || defined(__clang__) +#define ET_HAVE_GNU_STATEMENT_EXPRESSIONS 1 +#else +#define ET_HAVE_GNU_STATEMENT_EXPRESSIONS 0 +#endif +#endif // ifndef + +// Define size_t and ssize_t. +#ifndef _MSC_VER +#include +#else +#include +using ssize_t = ptrdiff_t; +#endif + +// DEPRECATED: Use the non-underscore-prefixed versions instead. +// TODO(T199005537): Remove these once all users have stopped using them. +#define __ET_DEPRECATED ET_DEPRECATED +#define __ET_FALLTHROUGH ET_FALLTHROUGH +#define __ET_FUNCTION ET_FUNCTION +#define __ET_HAVE_GNU_STATEMENT_EXPRESSIONS ET_HAVE_GNU_STATEMENT_EXPRESSIONS +#define __ET_INLINE ET_INLINE +#define __ET_LIKELY ET_LIKELY +#define __ET_LINE ET_LINE +#define __ET_NODISCARD ET_NODISCARD +#define __ET_NOINLINE ET_NOINLINE +#define __ET_NORETURN ET_NORETURN +#define __ET_PRINTFLIKE ET_PRINTFLIKE +#define __ET_SHORT_FILENAME ET_SHORT_FILENAME +#define __ET_UNLIKELY ET_UNLIKELY +#define __ET_UNREACHABLE ET_UNREACHABLE +#define __ET_UNUSED ET_UNUSED +#define __ET_WEAK ET_WEAK diff --git a/third-party/include/executorch/runtime/platform/log.h b/third-party/include/executorch/runtime/platform/log.h new file mode 100644 index 00000000..3bcbce7c --- /dev/null +++ b/third-party/include/executorch/runtime/platform/log.h @@ -0,0 +1,168 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * @file + * ExecuTorch logging API. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +// Set minimum log severity if compiler option is not provided. +#ifndef ET_MIN_LOG_LEVEL +#define ET_MIN_LOG_LEVEL Info +#endif // !defined(ET_MIN_LOG_LEVEL) + +/* + * Enable logging by default if compiler option is not provided. + * This should facilitate less confusion for those developing ExecuTorch. + */ +#ifndef ET_LOG_ENABLED +#define ET_LOG_ENABLED 1 +#endif // !defined(ET_LOG_ENABLED) + +namespace executorch { +namespace runtime { + +/** + * Severity level of a log message. Must be ordered from lowest to highest + * severity. + */ +enum class LogLevel : uint8_t { + /** + * Log messages provided for highly granular debuggability. + * + * Log messages using this severity are unlikely to be compiled by default + * into most debug builds. + */ + Debug, + + /** + * Log messages providing information about the state of the system + * for debuggability. + */ + Info, + + /** + * Log messages about errors within ExecuTorch during runtime. + */ + Error, + + /** + * Log messages that precede a fatal error. However, logging at this level + * does not perform the actual abort, something else needs to. + */ + Fatal, + + /** + * Number of supported log levels, with values in [0, NumLevels). + */ + NumLevels, +}; + +namespace internal { + +/** + * Get the current timestamp to construct a log event. + * + * @retval Monotonically non-decreasing timestamp in system ticks. 
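+ *
+ * Callers normally do not invoke this directly; the ET_LOG macro defined below
+ * captures the timestamp itself, e.g. (illustrative)
+ * `ET_LOG(Info, "loaded %zu methods", num_methods);`.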
+ */ +et_timestamp_t get_log_timestamp(); + +/** + * Log a string message. + * + * Note: This is an internal function. Use the `ET_LOG` macro instead. + * + * @param[in] level Log severity level. + * @param[in] timestamp Timestamp (in system ticks) of the log event. + * @param[in] filename Name of the source file creating the log event. + * @param[in] function Name of the function creating the log event. + * @param[in] line Source file line of the caller. + * @param[in] format Format string. + * @param[in] args Variable argument list. + */ +ET_PRINTFLIKE(6, 0) +void vlogf(LogLevel level, et_timestamp_t timestamp, const char *filename, + const char *function, size_t line, const char *format, va_list args); + +/** + * Log a string message. + * + * Note: This is an internal function. Use the `ET_LOG` macro instead. + * + * @param[in] level Log severity level. + * @param[in] timestamp Timestamp (in system ticks) of the log event. + * @param[in] filename Name of the source file creating the log event. + * @param[in] function Name of the function creating the log event. + * @param[in] line Source file line of the caller. + * @param[in] format Format string. + */ +ET_PRINTFLIKE(6, 7) +inline void logf(LogLevel level, et_timestamp_t timestamp, const char *filename, + const char *function, size_t line, const char *format, ...) { +#if ET_LOG_ENABLED + va_list args; + va_start(args, format); + internal::vlogf(level, timestamp, filename, function, line, format, args); + va_end(args); +#endif // ET_LOG_ENABLED +} + +} // namespace internal + +} // namespace runtime +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::runtime::LogLevel; +} // namespace executor +} // namespace torch + +#if ET_LOG_ENABLED + +/** + * Log a message at the given log severity level. + * + * @param[in] _level Log severity level. + * @param[in] _format Log message format string. + */ +#define ET_LOG(_level, _format, ...) \ + do { \ + const auto _log_level = ::executorch::runtime::LogLevel::_level; \ + if (static_cast(_log_level) >= \ + static_cast( \ + ::executorch::runtime::LogLevel::ET_MIN_LOG_LEVEL)) { \ + const auto _timestamp = \ + ::executorch::runtime::internal::get_log_timestamp(); \ + ::executorch::runtime::internal::logf(_log_level, _timestamp, \ + ET_SHORT_FILENAME, ET_FUNCTION, \ + ET_LINE, _format, ##__VA_ARGS__); \ + } \ + } while (0) +#else // ET_LOG_ENABLED + +/** + * Log a message at the given log severity level. + * + * @param[in] _level Log severity level. + * @param[in] _format Log message format string. + */ +#define ET_LOG(_level, _format, ...) ((void)0) + +#endif // ET_LOG_ENABLED diff --git a/third-party/include/executorch/runtime/platform/platform.h b/third-party/include/executorch/runtime/platform/platform.h new file mode 100644 index 00000000..c5cf1dd1 --- /dev/null +++ b/third-party/include/executorch/runtime/platform/platform.h @@ -0,0 +1,133 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * @file + * Platform abstraction layer to allow individual platform libraries to override + * symbols in ExecuTorch. PAL functions are defined as C functions so a platform + * library implementer can use C in lieu of C++. 
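+ *
+ * Override sketch (illustrative): a platform library supplies strong
+ * definitions for the weak et_pal_*() symbols declared below, for example:
+ *
+ *   extern "C" void et_pal_emit_log_message(et_timestamp_t timestamp,
+ *                                           et_pal_log_level_t level,
+ *                                           const char *filename,
+ *                                           const char *function, size_t line,
+ *                                           const char *message,
+ *                                           size_t length) {
+ *     // Forward to the platform's console; fprintf is just a stand-in here.
+ *     fprintf(stderr, "[%c] %s:%zu %s\n", (char)level, filename, line, message);
+ *   }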
+ */
+
+#pragma once
+
+// Use C-style includes so that C code can include this header.
+#include <stddef.h>
+#include <stdint.h>
+
+#include <executorch/runtime/platform/compiler.h>
+#include <executorch/runtime/platform/types.h>
+
+/**
+ * Clients should neither define nor use this macro. Used to optionally declare
+ * the et_pal_*() functions as weak symbols.
+ *
+ * This provides a way to both:
+ * - Include the header and define weak symbols (used by the internal default
+ *   implementations)
+ * - Include the header and define strong symbols (used by client overrides)
+ */
+#ifndef ET_INTERNAL_PLATFORM_WEAKNESS
+#define ET_INTERNAL_PLATFORM_WEAKNESS
+#endif
+
+extern "C" {
+
+/**
+ * Represents the conversion ratio from system ticks to nanoseconds.
+ * To convert, use nanoseconds = ticks * numerator / denominator.
+ */
+typedef struct {
+  uint64_t numerator;
+  uint64_t denominator;
+} et_tick_ratio_t;
+
+/**
+ * Initialize the platform abstraction layer.
+ *
+ * This function should be called before any other function provided by the PAL
+ * to initialize any global state. Typically overridden by PAL implementer.
+ */
+void et_pal_init(void) ET_INTERNAL_PLATFORM_WEAKNESS;
+
+/**
+ * Immediately abort execution, setting the device into an error state, if
+ * available.
+ */
+ET_NORETURN void et_pal_abort(void) ET_INTERNAL_PLATFORM_WEAKNESS;
+
+/**
+ * Return a monotonically non-decreasing timestamp in system ticks.
+ *
+ * @retval Timestamp value in system ticks.
+ */
+et_timestamp_t et_pal_current_ticks(void) ET_INTERNAL_PLATFORM_WEAKNESS;
+
+/**
+ * Return the conversion rate from system ticks to nanoseconds as a fraction.
+ * To convert a system ticks to nanoseconds, multiply the tick count by the
+ * numerator and then divide by the denominator:
+ *   nanoseconds = ticks * numerator / denominator
+ *
+ * The utility method executorch::runtime::ticks_to_ns(et_timestamp_t) can also
+ * be used to perform the conversion for a given tick count. It is defined in
+ * torch/executor/runtime/platform/clock.h.
+ *
+ * @retval The ratio of nanoseconds to system ticks.
+ */
+et_tick_ratio_t
+et_pal_ticks_to_ns_multiplier(void) ET_INTERNAL_PLATFORM_WEAKNESS;
+
+/**
+ * Severity level of a log message. Values must map to printable 7-bit ASCII
+ * uppercase letters.
+ */
+typedef enum {
+  kDebug = 'D',
+  kInfo = 'I',
+  kError = 'E',
+  kFatal = 'F',
+  kUnknown = '?', // Exception to the "uppercase letter" rule.
+} et_pal_log_level_t;
+
+/**
+ * Emit a log message via platform output (serial port, console, etc).
+ *
+ * @param[in] timestamp Timestamp of the log event in system ticks since boot.
+ * @param[in] level Severity level of the message. Must be a printable 7-bit
+ *     ASCII uppercase letter.
+ * @param[in] filename Name of the file that created the log event.
+ * @param[in] function Name of the function that created the log event.
+ * @param[in] line Line in the source file where the log event was created.
+ * @param[in] message Message string to log.
+ * @param[in] length Message string length.
+ */
+void et_pal_emit_log_message(et_timestamp_t timestamp, et_pal_log_level_t level,
+                             const char *filename, const char *function,
+                             size_t line, const char *message,
+                             size_t length) ET_INTERNAL_PLATFORM_WEAKNESS;
+
+/**
+ * NOTE: Core runtime code must not call this directly. It may only be called
+ * by a MemoryAllocator wrapper.
+ *
+ * Allocates size bytes of memory.
+ *
+ * @param[in] size Number of bytes to allocate.
+ * @returns the allocated memory, or nullptr on failure. Must be freed using
+ *     et_pal_free().
+ */
+void *et_pal_allocate(size_t size) ET_INTERNAL_PLATFORM_WEAKNESS;
+
+/**
+ * Frees memory allocated by et_pal_allocate().
+ *
+ * @param[in] ptr Pointer to memory to free. May be nullptr.
+ */
+void et_pal_free(void *ptr) ET_INTERNAL_PLATFORM_WEAKNESS;
+
+} // extern "C"
diff --git a/third-party/include/executorch/runtime/platform/profiler.h b/third-party/include/executorch/runtime/platform/profiler.h
new file mode 100644
index 00000000..6905d643
--- /dev/null
+++ b/third-party/include/executorch/runtime/platform/profiler.h
@@ -0,0 +1,292 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace executorch {
+namespace runtime {
+
+// Version string used to check for compatibility with post-processing
+// tool
+#define ET_PROF_VER 0x00000001
+
+// By default we support profiling upto 1024 perf events. Build
+// targets can override this to increase the profiling buffer size
+// during compilation.
+#ifndef MAX_PROFILE_EVENTS
+#define MAX_PROFILE_EVENTS 1024
+#endif
+// By default we support profiling upto 1024 memory allocation events.
+// Build targets can choose to override this, which will consequently have
+// the effect of increasing/decreasing the profiling buffer size.
+#ifndef MAX_MEM_PROFILE_EVENTS
+#define MAX_MEM_PROFILE_EVENTS 1024
+#endif
+// By default we support profiling only upto 16 allocators. If users
+// have more allocators than these then they can override this during
+// compilation time. There will be an increase/decrease in the profiling
+// buffer size based on the way this value is changed.
+#ifndef MEM_PROFILE_MAX_ALLOCATORS
+#define MEM_PROFILE_MAX_ALLOCATORS 32
+#endif
+// By default we support only one profiling block. If users want to profile
+// something that will be iterated on multiple times then they will have to
+// increment this to support their use case. In post-processing the stats for
+// all these iterations will be consolidated.
+#ifndef MAX_PROFILE_BLOCKS
+#define MAX_PROFILE_BLOCKS 2
+#endif
+
+#define PROF_NAME_MAX_LEN 32
+
+typedef struct alignas(8) {
+  union {
+    const char *name_str;
+    char name[PROF_NAME_MAX_LEN];
+  };
+  // chain_idx == -1 is a null value, when profile event happens out of chain
+  // execution
+  int32_t chain_idx;
+  uint32_t instruction_idx;
+  uint64_t start_time;
+  uint64_t end_time;
+} prof_event_t;
+
+typedef struct alignas(8) {
+  uint32_t allocator_id;
+  uint32_t allocation_size;
+} mem_prof_event_t;
+
+typedef struct alignas(8) {
+  char name[PROF_NAME_MAX_LEN];
+  uint64_t allocator_id;
+} prof_allocator_t;
+
+typedef struct alignas(8) {
+  uint8_t *prof_data;
+  uint32_t num_bytes;
+  uint32_t num_blocks;
+} prof_result_t;
+
+typedef struct alignas(8) {
+  char name[32];
+  uint32_t prof_ver;
+  uint32_t max_prof_entries;
+  uint32_t prof_entries;
+  uint32_t max_allocator_entries;
+  uint32_t allocator_entries;
+  uint32_t max_mem_prof_entries;
+  uint32_t mem_prof_entries;
+} prof_header_t;
+
+/*
+This is what the layout of the profiling buffer looks like.
+---------------------------------------
+| Profiling header                    |
+---------------------------------------
+| Profile events (Perf events)        |
+---------------------------------------
+| Memory allocators info              |
+---------------------------------------
+| Profile events (Memory allocations) |
+---------------------------------------
+*/
+
+// offsets of the various sections in the profiling buffer
+// Total size required for profiling buffer
+constexpr uint32_t prof_buf_size =
+    sizeof(prof_header_t) + sizeof(prof_event_t) * MAX_PROFILE_EVENTS +
+    sizeof(mem_prof_event_t) * MAX_MEM_PROFILE_EVENTS +
+    sizeof(prof_allocator_t) * MEM_PROFILE_MAX_ALLOCATORS;
+
+constexpr size_t prof_header_offset = 0;
+constexpr size_t prof_events_offset = sizeof(prof_header_t);
+constexpr size_t prof_mem_alloc_info_offset =
+    prof_events_offset + sizeof(prof_event_t) * MAX_PROFILE_EVENTS;
+constexpr size_t prof_mem_alloc_events_offset =
+    prof_mem_alloc_info_offset +
+    sizeof(prof_allocator_t) * MEM_PROFILE_MAX_ALLOCATORS;
+
+// Set the initial state for the profiler assuming we're using the
+// statically allocated buffer declared in the profiler module.
+void profiler_init(void);
+
+// This starts the profiling of this event and returns a token
+// by which this event can be referred to in the future.
+uint32_t begin_profiling(const char *name);
+
+// End profiling event represented by token_id
+void end_profiling(uint32_t token_id);
+
+// Dump profiler results, return pointer to prof event array and number of
+// events in it.
+void dump_profile_stats(prof_result_t *prof_result);
+
+void reset_profile_stats();
+
+void track_allocation(int32_t id, uint32_t size);
+
+uint32_t track_allocator(const char *name);
+
+void profiling_create_block(const char *name);
+
+// This class enables scope based profiling where needed. Profiling
+// will be started when the object is created and will end when the
+// object goes out of scope.
+class ExecutorchProfiler {
+public:
+  explicit ExecutorchProfiler(const char *name);
+
+  ~ExecutorchProfiler();
+
+private:
+  uint32_t prof_tok;
+};
+
+typedef struct {
+  int32_t chain_idx;
+  uint32_t instruction_idx;
+} prof_state_t;
+
+const prof_state_t &get_profile_tls_state();
+
+void set_profile_tls_state(const prof_state_t &state);
+
+class ExecutorchProfilerInstructionScope {
+public:
+  explicit ExecutorchProfilerInstructionScope(const prof_state_t &state);
+  ~ExecutorchProfilerInstructionScope();
+
+  // ScopeGuard: non-copyable, non-movable
+  ExecutorchProfilerInstructionScope(
+      const ExecutorchProfilerInstructionScope &) = delete;
+  ExecutorchProfilerInstructionScope &
+  operator=(const ExecutorchProfilerInstructionScope &) = delete;
+
+  ExecutorchProfilerInstructionScope(ExecutorchProfilerInstructionScope &&) =
+      delete;
+  ExecutorchProfilerInstructionScope &
+  operator=(ExecutorchProfilerInstructionScope &&) = delete;
+
+private:
+  prof_state_t old_state_;
+};
+
+} // namespace runtime
+} // namespace executorch
+
+namespace torch {
+namespace executor {
+// TODO(T197294990): Remove these deprecated aliases once all users have moved
+// to the new `::executorch` namespaces.
+using ::executorch::runtime::begin_profiling;
+using ::executorch::runtime::dump_profile_stats;
+using ::executorch::runtime::end_profiling;
+using ::executorch::runtime::ExecutorchProfiler;
+using ::executorch::runtime::ExecutorchProfilerInstructionScope;
+using ::executorch::runtime::get_profile_tls_state;
+using ::executorch::runtime::mem_prof_event_t;
+using ::executorch::runtime::prof_allocator_t;
+using ::executorch::runtime::prof_buf_size;
+using ::executorch::runtime::prof_event_t;
+using ::executorch::runtime::prof_events_offset;
+using ::executorch::runtime::prof_header_offset;
+using ::executorch::runtime::prof_header_t;
+using ::executorch::runtime::prof_mem_alloc_events_offset;
+using ::executorch::runtime::prof_mem_alloc_info_offset;
+using ::executorch::runtime::prof_result_t;
+using ::executorch::runtime::prof_state_t;
+using ::executorch::runtime::profiler_init;
+using ::executorch::runtime::profiling_create_block;
+using ::executorch::runtime::reset_profile_stats;
+using ::executorch::runtime::set_profile_tls_state;
+using ::executorch::runtime::track_allocation;
+using ::executorch::runtime::track_allocator;
+} // namespace executor
+} // namespace torch
+
+#ifdef PROFILING_ENABLED
+
+#define EXECUTORCH_PROFILE_CREATE_BLOCK(name) \
+  ::executorch::runtime::profiling_create_block(name);
+
+// Convenience macros to begin and end profiling. These can be inserted
+// anywhere as it'll be ensured that for the prod builds these will
+// essentially be noops.
+#define EXECUTORCH_BEGIN_PROF(name) \
+  ::executorch::runtime::begin_profiling(name);
+
+#define EXECUTORCH_END_PROF(token_id) \
+  ::executorch::runtime::end_profiling(token_id);
+
+#define EXECUTORCH_SCOPE_PROF(name) \
+  ::executorch::runtime::ExecutorchProfiler profiler(name);
+
+#define EXECUTORCH_PROFILE_INSTRUCTION_SCOPE(chain_idx, instruction_idx) \
+  ::executorch::runtime::ExecutorchProfilerInstructionScope \
+      __profiler_instruction_scope({chain_idx, instruction_idx});
+
+#define EXECUTORCH_DUMP_PROFILE_RESULTS(prof_result) \
+  ::executorch::runtime::dump_profile_stats(prof_result);
+
+#define EXECUTORCH_RESET_PROFILE_RESULTS() \
+  ::executorch::runtime::reset_profile_stats();
+
+#define EXECUTORCH_TRACK_ALLOCATOR(name) \
+  ::executorch::runtime::track_allocator(name);
+
+#define EXECUTORCH_TRACK_ALLOCATION(id, size) \
+  ::executorch::runtime::track_allocation(id, size);
+
+#else
+
+#define EXECUTORCH_PROFILE_CREATE_BLOCK(name) \
+  do { \
+    (void)(name); \
+  } while (0)
+
+#define EXECUTORCH_BEGIN_PROF(name) \
+  { \
+  }
+
+#define EXECUTORCH_END_PROF(token_id) \
+  do { \
+    (void)(token_id); \
+  } while (0)
+
+#define EXECUTORCH_SCOPE_PROF(name) \
+  do { \
+    (void)(name); \
+  } while (0)
+
+#define EXECUTORCH_PROFILE_INSTRUCTION_SCOPE(chain_idx, instruction_idx) \
+  do { \
+    (void)(chain_idx); \
+    (void)(instruction_idx); \
+  } while (0)
+
+#define EXECUTORCH_DUMP_PROFILE_RESULTS(prof_result_test) \
+  memset(prof_result_test, 0, sizeof(::executorch::runtime::prof_result_t));
+
+#define EXECUTORCH_RESET_PROFILE_RESULTS() \
+  { \
+  }
+
+#define EXECUTORCH_TRACK_ALLOCATOR(name) ((void)(name), -1)
+
+#define EXECUTORCH_TRACK_ALLOCATION(id, size) \
+  do { \
+    (void)(id); \
+    (void)(size); \
+  } while (0)
+
+#endif
diff --git a/third-party/include/executorch/runtime/platform/runtime.h b/third-party/include/executorch/runtime/platform/runtime.h
new file mode 100644
index 00000000..375ae795
--- /dev/null
+++ b/third-party/include/executorch/runtime/platform/runtime.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/**
+ * @file
+ * ExecuTorch global runtime wrapper functions.
+ */
+
+#pragma once
+
+#include
+
+namespace executorch {
+namespace runtime {
+
+/**
+ * Initialize the ExecuTorch global runtime.
+ */
+void runtime_init();
+
+} // namespace runtime
+} // namespace executorch
+
+namespace torch {
+namespace executor {
+// TODO(T197294990): Remove these deprecated aliases once all users have moved
+// to the new `::executorch` namespaces.
+using ::executorch::runtime::runtime_init;
+} // namespace executor
+} // namespace torch
diff --git a/third-party/include/executorch/runtime/platform/system.h b/third-party/include/executorch/runtime/platform/system.h
new file mode 100644
index 00000000..ae658507
--- /dev/null
+++ b/third-party/include/executorch/runtime/platform/system.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/**
+ * @file
+ * Platform abstraction layer to allow individual host OS to override
+ * symbols in ExecuTorch. PAL functions are defined as C functions so an
+ * implementer can use C in lieu of C++.
+ */
+#pragma once
+
+/**
+ * To enable dynamic linking debugging capability on UNIX-like OS. If enabled
+ * and see an error like: `undefined symbol: dladdr`, install `libdl` to fix.
+ */
+#if defined(ET_USE_LIBDL)
+#include <dlfcn.h>
+#endif
+
+static constexpr const char *DYNAMIC_LIBRARY_NOT_SUPPORTED = "NOT_SUPPORTED";
+static constexpr const char *DYNAMIC_LIBRARY_NOT_FOUND = "NOT_FOUND";
+
+extern "C" {
+
+/**
+ * Return shared library name.
+ *
+ * @param[in] addr Address to the symbol we are looking for in shared
+ *     libraries.
+ * @retval The path to the shared library containing the symbol.
+ */
+inline const char *et_pal_get_shared_library_name(const void *addr) {
+#if defined(ET_USE_LIBDL)
+  Dl_info info;
+  if (dladdr(addr, &info) && info.dli_fname) {
+    return info.dli_fname;
+  } else {
+    return DYNAMIC_LIBRARY_NOT_FOUND;
+  }
+#endif
+  (void)addr;
+  return DYNAMIC_LIBRARY_NOT_SUPPORTED;
+}
+
+} // extern "C"
diff --git a/third-party/include/executorch/runtime/platform/types.h b/third-party/include/executorch/runtime/platform/types.h
new file mode 100644
index 00000000..a91f357f
--- /dev/null
+++ b/third-party/include/executorch/runtime/platform/types.h
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/**
+ * @file
+ * Public types used by the ExecuTorch Platform Abstraction Layer.
+ */
+
+#pragma once
+
+// Use C-style includes so that C code can include this header.
+#include <stdint.h>
+
+extern "C" {
+
+/// Platform timestamp in system ticks.
+typedef uint64_t et_timestamp_t;
+
+} // extern "C"
diff --git a/third-party/include/executorch/schema/extended_header.h b/third-party/include/executorch/schema/extended_header.h
new file mode 100644
index 00000000..59b28e30
--- /dev/null
+++ b/third-party/include/executorch/schema/extended_header.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/core/result.h>
+
+namespace executorch {
+namespace runtime {
+
+/**
+ * An extended, ExecuTorch-specific header that may be embedded in the
+ * serialized Program data header.
+ *
+ * For details see //executorch/docs/source/pte-file-format.md
+ */
+struct ExtendedHeader {
+  /**
+   * To find the header, callers should provide at least this many bytes of the
+   * head of the serialized Program data.
+   */
+  static constexpr size_t kNumHeadBytes = 64;
+
+  /**
+   * The offset into the Program serialized program data where the extended
+   * header should begin.
+   */
+  static constexpr size_t kHeaderOffset = 8;
+
+  /**
+   * The magic bytes that identify the header.
+   *
+   * This is the canonical definition of the expected value. If the header
+   * layout ever changes in a compatibility-breaking way, increment the digits
+   * in the magic. But, doing so will prevent older binaries from recognizing
+   * the presence of the header. The compatibility-preserving way to make
+   * changes is to increase the header's length field and add new fields at the
+   * end.
+   */
+  static constexpr size_t kMagicSize = 4;
+  static constexpr char kMagic[kMagicSize] = {'e', 'h', '0', '0'};
+
+  /**
+   * Look for and parse an ExtendedHeader in the provided data.
+   *
+   * @param[in] data The contents of the beginning of the serialized binary
+   *     Program data, starting at offset 0 (i.e., the head of the file).
+   * @param[in] size Length of `data` in bytes. Must be >= kNumHeadBytes or
+   *     this call will fail.
+   *
+   * @returns an ExtendedHeader if the header was found and is valid. Returns
+   *     an error if size was too short, if the header was not found, or if
+   *     the header appeared to be corrupt.
+   */
+  static Result<ExtendedHeader> Parse(const void *data, size_t size);
+
+  /**
+   * The size in bytes of the Program flatbuffer data, starting from offset
+   * zero.
+   */
+  uint64_t program_size;
+
+  /**
+   * The offset in bytes of the first segment, if present. Zero if no segment
+   * is present.
+   */
+  uint64_t segment_base_offset;
+};
+
+} // namespace runtime
+} // namespace executorch
diff --git a/turbo.json b/turbo.json
index 405897ee..9a1cd260 100644
--- a/turbo.json
+++ b/turbo.json
@@ -7,6 +7,7 @@
         "package.json",
         "android",
         "!android/build",
+        "common",
         "src/*.ts",
         "src/*.tsx",
         "example/package.json",
@@ -23,6 +24,7 @@
         "package.json",
         "*.podspec",
         "ios",
+        "common",
         "src/*.ts",
         "src/*.tsx",
         "example/package.json",
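
For reference, a minimal sketch of how the vendored logging and runtime headers above are consumed from native code. The function name below is hypothetical; only runtime_init(), ET_LOG, and the LogLevel names come from the headers in this diff:

#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/runtime.h>

// Hypothetical native entry point, e.g. invoked when the RN module is created.
void initExecutorchRuntime() {
  // Initializes the PAL and global runtime state; call once before inference.
  executorch::runtime::runtime_init();

  // Compiles to a no-op when ET_LOG_ENABLED is 0, and is filtered out at
  // runtime when the level is below ET_MIN_LOG_LEVEL.
  ET_LOG(Info, "react-native-executorch: runtime initialized");
  ET_LOG(Debug, "model buffer size: %zu bytes", static_cast<size_t>(4096));
}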
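
Likewise, a sketch of how a platform layer could provide a strong definition for one of the weak et_pal_*() symbols declared in platform.h. Forwarding to Android logcat via __android_log_print is an assumption for illustration; this diff does not add such an override:

#include <android/log.h>
#include <executorch/runtime/platform/platform.h>

// Strong definition that replaces the weak default PAL implementation.
extern "C" void et_pal_emit_log_message(et_timestamp_t timestamp,
                                        et_pal_log_level_t level,
                                        const char *filename,
                                        const char *function, size_t line,
                                        const char *message, size_t length) {
  (void)timestamp;
  (void)length;
  const int priority = (level == kError || level == kFatal)
                           ? ANDROID_LOG_ERROR
                           : ANDROID_LOG_INFO;
  __android_log_print(priority, "ExecuTorch", "%s:%zu %s: %s", filename, line,
                      function, message);
}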
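
Finally, a sketch of the scope-based profiling macro from profiler.h. The wrapper function is hypothetical, and with PROFILING_ENABLED left undefined (the default here) the macro expands to a no-op:

#include <executorch/runtime/platform/profiler.h>

// Hypothetical helper around a model invocation; illustrative only.
void runForwardPass() {
  // Declares an ExecutorchProfiler RAII object when PROFILING_ENABLED is
  // defined; profiling starts here and ends when the scope exits.
  EXECUTORCH_SCOPE_PROF("forward_pass");

  // ... invoke the exported .pte program here ...
}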