diff --git a/src/python/python.cpp b/src/python/python.cpp
index 4ee8f87ed..9f6926eb4 100644
--- a/src/python/python.cpp
+++ b/src/python/python.cpp
@@ -215,6 +215,13 @@ struct PyDeviceMemorySpan {
   pybind11::array_t<T> py_cpu_array_;
 };
 
+struct PyNamedTensors {
+  PyNamedTensors(std::unique_ptr<NamedTensors> named_tensors) : named_tensors_{std::move(named_tensors)} {
+  }
+
+  std::unique_ptr<NamedTensors> named_tensors_;
+};
+
 struct PyGeneratorParams {
   PyGeneratorParams(const Model& model) : params_{std::make_shared<GeneratorParams>(model)} {
   }
@@ -238,6 +245,11 @@ struct PyGeneratorParams {
     refs_.emplace_back(value);
   }
 
+  void SetInputs(std::shared_ptr<PyNamedTensors> named_tensors) {
+    params_->SetInputs(*named_tensors->named_tensors_);
+    named_tensors_ = named_tensors;
+  }
+
   void SetSearchOptions(const pybind11::kwargs& dict) {
     for (auto& entry : dict) {
       auto name = entry.first.cast<std::string>();
@@ -268,14 +280,8 @@ struct PyGeneratorParams {
   pybind11::array py_whisper_input_features_;
   pybind11::array py_alignment_heads_;
 
-  std::vector<pybind11::object> refs_;  // References to data we want to ensure doesn't get garbage collected
-};
-
-struct PyNamedTensors {
-  PyNamedTensors(std::unique_ptr<NamedTensors> named_tensors) : named_tensors_{std::move(named_tensors)} {
-  }
-
-  std::unique_ptr<NamedTensors> named_tensors_;
+  std::vector<pybind11::object> refs_;  // References to data we want to ensure doesn't get garbage collected
+  std::shared_ptr<PyNamedTensors> named_tensors_;  // Ensure the model inputs don't get garbage collected
 };
 
 struct PyGenerator {
@@ -387,11 +393,11 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
       // TODO(baijumeswani): Rename/redesign the whisper_input_features to be more generic
       .def_readwrite("whisper_input_features", &PyGeneratorParams::py_whisper_input_features_)
       .def_readwrite("alignment_heads", &PyGeneratorParams::py_alignment_heads_)
-      .def("set_inputs", [](PyGeneratorParams& generator_params, PyNamedTensors* named_tensors) {
+      .def("set_inputs", [](PyGeneratorParams& generator_params, std::shared_ptr<PyNamedTensors> named_tensors) {
         if (!named_tensors || !named_tensors->named_tensors_)
           throw std::runtime_error("No inputs provided.");
-        generator_params.params_->SetInputs(*named_tensors->named_tensors_);
+        generator_params.SetInputs(named_tensors);
       })
       .def("set_model_input", &PyGeneratorParams::SetModelInput)
       .def("set_search_options", &PyGeneratorParams::SetSearchOptions)  // See config.h 'struct Search' for the options
@@ -456,8 +462,8 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
         generator.SetActiveAdapter(adapters, adapter_name);
       });
 
-  pybind11::class_<Images>(m, "Images")
-      .def_static("open", [](pybind11::args image_paths) {
+  pybind11::class_<Images, std::shared_ptr<Images>>(m, "Images")
+      .def_static("open", [](pybind11::args image_paths) -> std::shared_ptr<Images> {
         if (image_paths.empty())
           throw std::runtime_error("No images provided");
 
@@ -470,7 +476,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
           image_paths_vector.push_back(image_paths_string.back().c_str());
         }
 
-        return LoadImages(image_paths_vector);
+        return std::shared_ptr<Images>(LoadImages(image_paths_vector));
       })
       .def_static("open_bytes", [](pybind11::args image_datas) {
         if (image_datas.empty())
@@ -486,10 +492,10 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
          image_raw_data[i] = ort_extensions::ImageRawData(data, data + info.size);
        }
 
-        return std::make_unique<Images>(std::move(image_raw_data), image_datas.size());
+        return std::make_shared<Images>(std::move(image_raw_data), image_datas.size());
       });
 
-  pybind11::class_<Audios>(m, "Audios")
+  pybind11::class_<Audios, std::shared_ptr<Audios>>(m, "Audios")
       .def_static("open", [](pybind11::args audio_paths) {
         if (audio_paths.empty())
           throw std::runtime_error("No audios provided");
@@ -504,14 +510,14 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
          audio_paths_vector.push_back(audio_paths_string.back().c_str());
        }
 
-        return LoadAudios(audio_paths_vector);
+        return std::shared_ptr<Audios>(LoadAudios(audio_paths_vector));
       });
 
-  pybind11::class_<PyNamedTensors>(m, "NamedTensors");
+  pybind11::class_<PyNamedTensors, std::shared_ptr<PyNamedTensors>>(m, "NamedTensors");
 
   pybind11::class_<MultiModalProcessor, std::shared_ptr<MultiModalProcessor>>(m, "MultiModalProcessor")
       .def(
-          "__call__", [](MultiModalProcessor& processor, const std::optional<std::string>& prompt, const pybind11::kwargs& kwargs) -> std::unique_ptr<PyNamedTensors> {
+          "__call__", [](MultiModalProcessor& processor, const std::optional<std::string>& prompt, const pybind11::kwargs& kwargs) -> std::shared_ptr<PyNamedTensors> {
            if (kwargs.contains("images")) {
              if (processor.image_processor_ == nullptr) {
                throw std::runtime_error("Image processor is not available for this model.");
@@ -520,11 +526,11 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
            if (!prompt.has_value()) {
              throw std::runtime_error("Prompt is required for processing the image.");
            }
-            return std::make_unique<PyNamedTensors>(
+            return std::make_shared<PyNamedTensors>(
                processor.image_processor_->Process(*processor.tokenizer_, *prompt, images));
          } else if (kwargs.contains("audios")) {
            const Audios* audios = kwargs["audios"].cast<const Audios*>();
-            return std::make_unique<PyNamedTensors>(
+            return std::make_shared<PyNamedTensors>(
                processor.audio_processor_->Process(audios));
          } else {
            throw std::runtime_error("Nothing to process.");