diff --git a/compiler/record-hessian/include/record-hessian/RecordHessian.h b/compiler/record-hessian/include/record-hessian/RecordHessian.h new file mode 100644 index 00000000000..c29dd0ce199 --- /dev/null +++ b/compiler/record-hessian/include/record-hessian/RecordHessian.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __RECORD_HESSIAN_RECORD_HESSIAN_H__ +#define __RECORD_HESSIAN_RECORD_HESSIAN_H__ + +#include "record-hessian/HessianObserver.h" + +#include +#include + +namespace record_hessian +{ + +class RecordHessian +{ +public: + RecordHessian() {} + + void initialize(luci::Module *module); + std::unique_ptr profileData(const std::string &input_data_path); + +private: + luci_interpreter::Interpreter *getInterpreter() const { return _interpreter.get(); } + + // Never return nullptr + HessianObserver *getObserver() const { return _observer.get(); } + + luci::Module *_module = nullptr; + + std::unique_ptr _interpreter; + std::unique_ptr _observer; +}; + +} // namespace record_hessian + +#endif // __RECORD_HESSIAN_RECORD_HESSIAN_H__ diff --git a/compiler/record-hessian/src/RecordHessian.cpp b/compiler/record-hessian/src/RecordHessian.cpp new file mode 100644 index 00000000000..1ff686b63d7 --- /dev/null +++ b/compiler/record-hessian/src/RecordHessian.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "record-hessian/RecordHessian.h" +#include "record-hessian/HessianObserver.h" + +#include + +#include + +using Shape = std::vector; + +namespace record_hessian +{ + +void RecordHessian::initialize(luci::Module *module) +{ + // Create and initialize interpreters and observers + + _module = module; + + auto interpreter = std::make_unique(module); + auto observer = std::make_unique(); + + interpreter->attachObserver(observer.get()); +} + +std::unique_ptr RecordHessian::profileData(const std::string &input_data_path) +{ + try + { + dio::hdf5::HDF5Importer importer(input_data_path); + // To be implemented + } + catch (const H5::Exception &e) + { + H5::Exception::printErrorStack(); + throw std::runtime_error("RecordHessian: HDF5 error occurred."); + } + + return getObserver()->hessianData(); +} + +} // namespace record_hessian