From 261fe25b434db5d5d7ad8cb4020d54cb7d0f2643 Mon Sep 17 00:00:00 2001 From: Adriankhl Date: Wed, 29 May 2024 19:58:05 +0800 Subject: [PATCH] fix: use Ref instead of raw pointer --- src/gdllava.cpp | 6 ++--- src/gdllava.hpp | 6 ++--- src/llm_db.cpp | 58 +++++++++++++++++++++++++++---------------------- src/llm_db.hpp | 10 ++++----- 4 files changed, 43 insertions(+), 37 deletions(-) diff --git a/src/gdllava.cpp b/src/gdllava.cpp index d8e26d3..5285efc 100644 --- a/src/gdllava.cpp +++ b/src/gdllava.cpp @@ -304,7 +304,7 @@ Error GDLlava::run_generate_text_base64(String prompt, String image_base64) { } -String GDLlava::generate_text_image_internal(String prompt, Image* image) { +String GDLlava::generate_text_image_internal(String prompt, Ref image) { glog_verbose("generate_text_image_internal"); String image_base64 = Marshalls::get_singleton()->raw_to_base64(image->save_jpg_to_buffer()); @@ -319,7 +319,7 @@ String GDLlava::generate_text_image_internal(String prompt, Image* image) { return full_generated_text; } -String GDLlava::generate_text_image(String prompt, Image* image) { +String GDLlava::generate_text_image(String prompt, Ref image) { glog_verbose("generate_text_image"); func_mutex->lock(); @@ -341,7 +341,7 @@ String GDLlava::generate_text_image(String prompt, Image* image) { return full_generated_text; } -Error GDLlava::run_generate_text_image(String prompt, Image* image) { +Error GDLlava::run_generate_text_image(String prompt, Ref image) { glog_verbose("run_generate_text_image"); func_mutex->lock(); diff --git a/src/gdllava.hpp b/src/gdllava.hpp index 8f84aef..0f9a3ff 100644 --- a/src/gdllava.hpp +++ b/src/gdllava.hpp @@ -22,7 +22,7 @@ class GDLlava : public Node { Ref generate_text_thread; String generate_text_common(String prompt, String image_base64); String generate_text_base64_internal(String prompt, String image_base64); - String generate_text_image_internal(String prompt, Image* image); + String generate_text_image_internal(String prompt, Ref image); std::function glog; std::function glog_verbose; std::string generate_text_buffer; @@ -60,8 +60,8 @@ class GDLlava : public Node { bool is_running(); String generate_text_base64(String prompt, String image_base64); Error run_generate_text_base64(String prompt, String image_base64); - String generate_text_image(String prompt, Image* image); - Error run_generate_text_image(String prompt, Image* image); + String generate_text_image(String prompt, Ref image); + Error run_generate_text_image(String prompt, Ref image); void stop_generate_text(); }; diff --git a/src/llm_db.cpp b/src/llm_db.cpp index f27e5a6..a15c2bd 100644 --- a/src/llm_db.cpp +++ b/src/llm_db.cpp @@ -2,6 +2,7 @@ #include "gdembedding.hpp" #include "sqlite3.h" #include "sqlite-vec.h" +#include #include #include #include @@ -20,6 +21,7 @@ #include #include #include +#include namespace godot { @@ -50,36 +52,36 @@ LlmDBMetaData::LlmDBMetaData() : data_name {"default_name"}, LlmDBMetaData::~LlmDBMetaData() {} -LlmDBMetaData* LlmDBMetaData::create(String data_name, int data_type) { - LlmDBMetaData* data = memnew(LlmDBMetaData()); +Ref LlmDBMetaData::create(String data_name, int data_type) { + Ref data = memnew(LlmDBMetaData()); data->set_data_name(data_name); data->set_data_type(data_type); return data; } -LlmDBMetaData* LlmDBMetaData::create_int(String data_name) { - LlmDBMetaData* data = memnew(LlmDBMetaData()); +Ref LlmDBMetaData::create_int(String data_name) { + Ref data = memnew(LlmDBMetaData()); data->set_data_name(data_name); data->set_data_type(0); return data; } -LlmDBMetaData* LlmDBMetaData::create_real(String data_name) { - LlmDBMetaData* data = memnew(LlmDBMetaData()); +Ref LlmDBMetaData::create_real(String data_name) { + Ref data = memnew(LlmDBMetaData()); data->set_data_name(data_name); data->set_data_type(1); return data; } -LlmDBMetaData* LlmDBMetaData::create_text(String data_name) { - LlmDBMetaData* data = memnew(LlmDBMetaData()); +Ref LlmDBMetaData::create_text(String data_name) { + Ref data = memnew(LlmDBMetaData()); data->set_data_name(data_name); data->set_data_type(2); return data; } -LlmDBMetaData* LlmDBMetaData::create_blob(String data_name) { - LlmDBMetaData* data = memnew(LlmDBMetaData()); +Ref LlmDBMetaData::create_blob(String data_name) { + Ref data = memnew(LlmDBMetaData()); data->set_data_name(data_name); data->set_data_type(2); return data; @@ -277,7 +279,7 @@ TypedArray LlmDB::get_meta() const { void LlmDB::set_meta(TypedArray p_meta) { bool is_id_valid = true; - int col_to_remove = -1; + std::vector cols_to_remove {}; if (p_meta.size() != 0) { @@ -286,27 +288,31 @@ void LlmDB::set_meta(TypedArray p_meta) { UtilityFunctions::print_verbose("Checking meta data " + String::num_int64(i)); if (p_meta[i].get_type() != Variant::NIL) { UtilityFunctions::print_verbose("Correct resource type"); - LlmDBMetaData* sd = Object::cast_to(p_meta[i]); + Ref sd = Object::cast_to(p_meta[i]); if (sd->get_data_name() == "id") { UtilityFunctions::printerr("Column " + String::num_int64(i) + " error: Id column must be the first column (0)"); - col_to_remove = i; + cols_to_remove.push_back(i); } } } - if (col_to_remove != -1) { - UtilityFunctions::printerr("Removing column " + String::num(col_to_remove)); - p_meta.remove_at(col_to_remove); + + // Remove from the end to make sure the indexes are correct + std::reverse(cols_to_remove.begin(), cols_to_remove.end()); + + for (int i : cols_to_remove) { + UtilityFunctions::printerr("Removing column " + String::num(i)); + p_meta.remove_at(i); } - LlmDBMetaData* sd0 = Object::cast_to(p_meta[0]); + Ref sd0 = Object::cast_to(p_meta[0]); if (sd0->get_data_name() == "id" && sd0->get_data_type() != LlmDBMetaDataType::TEXT) { UtilityFunctions::printerr("Id column should be TEXT type, removing"); p_meta.remove_at(0); } // Get again since it might get removed - sd0 = Object::cast_to(p_meta[0]); - if (sd0->get_data_name() != "id") { + Ref sd0_1 = Object::cast_to(p_meta[0]); + if (sd0_1->get_data_name() != "id") { UtilityFunctions::printerr("First column is not id"); is_id_valid = false; } @@ -485,7 +491,7 @@ void LlmDB::create_llm_tables() { UtilityFunctions::print_verbose("create_llm_tables: " + table_name); String statement = "CREATE TABLE IF NOT EXISTS " + table_name + " ("; for (int i = 0; i < meta.size(); i++) { - LlmDBMetaData* sd = Object::cast_to(meta[i]); + Ref sd = Object::cast_to(meta[i]); statement += "'" + sd->get_data_name() + "' "; statement += type_int_to_string(sd->get_data_type()); statement += ", "; @@ -507,7 +513,7 @@ void LlmDB::create_llm_tables() { String statement_meta = "CREATE TABLE IF NOT EXISTS " + meta_table_name + " ("; for (int i = 0; i < meta.size(); i++) { - LlmDBMetaData* sd = Object::cast_to(meta[i]); + Ref sd = Object::cast_to(meta[i]); statement_meta += " '" + sd->get_data_name() + "' "; statement_meta += type_int_to_string(sd->get_data_type()); if (i == 0) { @@ -632,7 +638,7 @@ bool LlmDB::is_table_valid(String p_table_name) { String name = String::utf8((char *) sqlite3_column_text(stmt, 1)); String type = String::utf8((char *) sqlite3_column_text(stmt, 2)); - LlmDBMetaData* sd = Object::cast_to(meta[i]); + Ref sd = Object::cast_to(meta[i]); if (name != sd->get_data_name()) { UtilityFunctions::printerr("Column name wrong, table : " + name + ", meta: " + sd->get_data_name()); @@ -668,7 +674,7 @@ void LlmDB::store_meta(Dictionary meta_dict) { Dictionary p_meta_dict = meta_dict.duplicate(false); PackedStringArray array_bind {PackedStringArray()}; for (int i = 0; i < meta.size(); i++) { - LlmDBMetaData* sd = Object::cast_to(meta[i]); + Ref sd = Object::cast_to(meta[i]); if(p_meta_dict.has(sd->get_data_name())) { Variant v = p_meta_dict.get(sd->get_data_name(), nullptr); if (v.get_type() != type_int_to_variant(sd->get_data_type())) { @@ -875,14 +881,14 @@ void LlmDB::insert_text_by_id(String id, String text) { String statement = "INSERT INTO " + table_name + " ("; for (int i = 0; i < meta.size(); i++) { - LlmDBMetaData* sd = Object::cast_to(meta[i]); + Ref sd = Object::cast_to(meta[i]); statement += sd->get_data_name() + ", "; } statement += "llm_text, embedding) VALUES (?, "; for (int i = 1; i < meta.size(); i++) { - LlmDBMetaData* sd = Object::cast_to(meta[i]); + Ref sd = Object::cast_to(meta[i]); statement += "(SELECT " + sd->get_data_name() + " FROM " + table_name + "_meta" + " WHERE id=?), "; } @@ -985,7 +991,7 @@ void LlmDB::insert_text_by_meta(Dictionary meta_dict, String text) { Dictionary p_meta_dict = meta_dict.duplicate(false); PackedStringArray array_bind {PackedStringArray()}; for (int i = 0; i < meta.size(); i++) { - LlmDBMetaData* sd = Object::cast_to(meta[i]); + Ref sd = Object::cast_to(meta[i]); if(p_meta_dict.has(sd->get_data_name())) { Variant v = p_meta_dict.get(sd->get_data_name(), nullptr); if (v.get_type() != type_int_to_variant(sd->get_data_type())) { diff --git a/src/llm_db.hpp b/src/llm_db.hpp index aff6d0b..660dff8 100644 --- a/src/llm_db.hpp +++ b/src/llm_db.hpp @@ -34,11 +34,11 @@ class LlmDBMetaData : public Resource { public: LlmDBMetaData(); ~LlmDBMetaData(); - static LlmDBMetaData* create(String data_name, int data_type); - static LlmDBMetaData* create_int(String data_name); - static LlmDBMetaData* create_real(String data_name); - static LlmDBMetaData* create_text(String data_name); - static LlmDBMetaData* create_blob(String data_name); + static Ref create(String data_name, int data_type); + static Ref create_int(String data_name); + static Ref create_real(String data_name); + static Ref create_text(String data_name); + static Ref create_blob(String data_name); String get_data_name() const; void set_data_name(const String p_data_name); int get_data_type() const;