Skip to content

Commit

Permalink
Hold onto named tensors to ensure they don't get garbage collected in…
Browse files Browse the repository at this point in the history
… Python (#1174)
  • Loading branch information
baijumeswani authored Jan 9, 2025
1 parent ee1fadd commit 636a95e
Showing 1 changed file with 26 additions and 20 deletions.
46 changes: 26 additions & 20 deletions src/python/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ struct PyDeviceMemorySpan {
pybind11::array_t<T> py_cpu_array_;
};

// Wrapper that owns a NamedTensors instance through a unique_ptr. Instances are
// built from the result of a processor's Process() call.
struct PyNamedTensors {
  // Takes ownership of `named_tensors`. Marked explicit so that a bare
  // std::unique_ptr<NamedTensors> is never converted into a wrapper implicitly.
  explicit PyNamedTensors(std::unique_ptr<NamedTensors> named_tensors)
      : named_tensors_{std::move(named_tensors)} {
  }

  // The owned tensors. The constructor does not reject a null pointer, so
  // users must check this before dereferencing.
  std::unique_ptr<NamedTensors> named_tensors_;
};

struct PyGeneratorParams {
PyGeneratorParams(const Model& model) : params_{std::make_shared<GeneratorParams>(model)} {
}
Expand All @@ -238,6 +245,11 @@ struct PyGeneratorParams {
refs_.emplace_back(value);
}

// Passes the wrapped NamedTensors to params_->SetInputs and then stores the
// shared wrapper in named_tensors_ so this object holds a reference to it.
//
// Throws std::runtime_error if `named_tensors` is null or wraps a null
// NamedTensors, instead of dereferencing a null pointer.
void SetInputs(std::shared_ptr<PyNamedTensors> named_tensors) {
  if (!named_tensors || !named_tensors->named_tensors_)
    throw std::runtime_error("No inputs provided.");

  params_->SetInputs(*named_tensors->named_tensors_);
  // Assigned only after the call above returns, so an exception there leaves
  // the previously stored wrapper untouched. The by-value parameter is moved
  // to avoid an extra reference-count increment/decrement.
  named_tensors_ = std::move(named_tensors);
}

void SetSearchOptions(const pybind11::kwargs& dict) {
for (auto& entry : dict) {
auto name = entry.first.cast<std::string>();
Expand Down Expand Up @@ -268,14 +280,8 @@ struct PyGeneratorParams {
pybind11::array py_whisper_input_features_;
pybind11::array py_alignment_heads_;

std::vector<pybind11::object> refs_; // References to data we want to ensure doesn't get garbage collected
};

struct PyNamedTensors {
PyNamedTensors(std::unique_ptr<NamedTensors> named_tensors) : named_tensors_{std::move(named_tensors)} {
}

std::unique_ptr<NamedTensors> named_tensors_;
std::vector<pybind11::object> refs_; // References to data we want to ensure doesn't get garbage collected
std::shared_ptr<PyNamedTensors> named_tensors_; // Ensure the model inputs don't get garbage collected
};

struct PyGenerator {
Expand Down Expand Up @@ -387,11 +393,11 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
// TODO(baijumeswani): Rename/redesign the whisper_input_features to be more generic
.def_readwrite("whisper_input_features", &PyGeneratorParams::py_whisper_input_features_)
.def_readwrite("alignment_heads", &PyGeneratorParams::py_alignment_heads_)
.def("set_inputs", [](PyGeneratorParams& generator_params, PyNamedTensors* named_tensors) {
.def("set_inputs", [](PyGeneratorParams& generator_params, std::shared_ptr<PyNamedTensors> named_tensors) {
if (!named_tensors || !named_tensors->named_tensors_)
throw std::runtime_error("No inputs provided.");

generator_params.params_->SetInputs(*named_tensors->named_tensors_);
generator_params.SetInputs(named_tensors);
})
.def("set_model_input", &PyGeneratorParams::SetModelInput)
.def("set_search_options", &PyGeneratorParams::SetSearchOptions) // See config.h 'struct Search' for the options
Expand Down Expand Up @@ -456,8 +462,8 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
generator.SetActiveAdapter(adapters, adapter_name);
});

pybind11::class_<Images>(m, "Images")
.def_static("open", [](pybind11::args image_paths) {
pybind11::class_<Images, std::shared_ptr<Images>>(m, "Images")
.def_static("open", [](pybind11::args image_paths) -> std::shared_ptr<Images> {
if (image_paths.empty())
throw std::runtime_error("No images provided");

Expand All @@ -470,7 +476,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
image_paths_vector.push_back(image_paths_string.back().c_str());
}

return LoadImages(image_paths_vector);
return std::shared_ptr<Images>(LoadImages(image_paths_vector));
})
.def_static("open_bytes", [](pybind11::args image_datas) {
if (image_datas.empty())
Expand All @@ -486,10 +492,10 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
image_raw_data[i] = ort_extensions::ImageRawData(data, data + info.size);
}

return std::make_unique<Images>(std::move(image_raw_data), image_datas.size());
return std::make_shared<Images>(std::move(image_raw_data), image_datas.size());
});

pybind11::class_<Audios>(m, "Audios")
pybind11::class_<Audios, std::shared_ptr<Audios>>(m, "Audios")
.def_static("open", [](pybind11::args audio_paths) {
if (audio_paths.empty())
throw std::runtime_error("No audios provided");
Expand All @@ -504,14 +510,14 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
audio_paths_vector.push_back(audio_paths_string.back().c_str());
}

return LoadAudios(audio_paths_vector);
return std::shared_ptr<Audios>(LoadAudios(audio_paths_vector));
});

pybind11::class_<PyNamedTensors>(m, "NamedTensors");
pybind11::class_<PyNamedTensors, std::shared_ptr<PyNamedTensors>>(m, "NamedTensors");

pybind11::class_<MultiModalProcessor, std::shared_ptr<MultiModalProcessor>>(m, "MultiModalProcessor")
.def(
"__call__", [](MultiModalProcessor& processor, const std::optional<std::string>& prompt, const pybind11::kwargs& kwargs) -> std::unique_ptr<PyNamedTensors> {
"__call__", [](MultiModalProcessor& processor, const std::optional<std::string>& prompt, const pybind11::kwargs& kwargs) -> std::shared_ptr<PyNamedTensors> {
if (kwargs.contains("images")) {
if (processor.image_processor_ == nullptr) {
throw std::runtime_error("Image processor is not available for this model.");
Expand All @@ -520,11 +526,11 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
if (!prompt.has_value()) {
throw std::runtime_error("Prompt is required for processing the image.");
}
return std::make_unique<PyNamedTensors>(
return std::make_shared<PyNamedTensors>(
processor.image_processor_->Process(*processor.tokenizer_, *prompt, images));
} else if (kwargs.contains("audios")) {
const Audios* audios = kwargs["audios"].cast<const Audios*>();
return std::make_unique<PyNamedTensors>(
return std::make_shared<PyNamedTensors>(
processor.audio_processor_->Process(audios));
} else {
throw std::runtime_error("Nothing to process.");
Expand Down

0 comments on commit 636a95e

Please sign in to comment.