-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathinput_parser.cpp
More file actions
128 lines (109 loc) · 5.86 KB
/
input_parser.cpp
File metadata and controls
128 lines (109 loc) · 5.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#include "flock/functions/input_parser.hpp"
#include "duckdb/common/operator/cast_operators.hpp"
namespace flock {
// Helper function to validate and clean context column, handling NULL values
static void ValidateAndCleanContextColumn(nlohmann::json& column, const std::initializer_list<const char*>& allowed_keys) {
std::string column_type = "";
bool has_type = false;
bool has_transcription_model = false;
for (const auto& key: allowed_keys) {
if (key != std::string("data")) {
bool key_exists = column.contains(key);
bool is_null = key_exists && column[key].get<std::string>() == "NULL";
if (key == std::string("type") && key_exists && !is_null) {
column_type = column[key].get<std::string>();
has_type = true;
} else if (key == std::string("transcription_model") && key_exists && !is_null) {
has_transcription_model = true;
} else if (!key_exists || is_null) {
column.erase(key);
}
}
}
// Validate transcription_model is only used with audio type
if (has_transcription_model && column_type != "audio") {
std::string type_display = has_type ? column_type : "tabular";
throw std::runtime_error(duckdb_fmt::format("Argument 'transcription_model' is not supported for data type '{}'. It can only be used with type 'audio'.", type_display));
}
// Validate that audio type requires transcription_model
if (has_type && column_type == "audio" && !has_transcription_model) {
throw std::runtime_error("Argument 'transcription_model' is required when type is 'audio'.");
}
}
nlohmann::json CastVectorOfStructsToJson(const duckdb::Vector& struct_vector, const int size) {
nlohmann::json struct_json;
for (auto i = 0; i < size; i++) {
for (auto j = 0; j < static_cast<int>(duckdb::StructType::GetChildCount(struct_vector.GetType())); j++) {
const auto key = duckdb::StructType::GetChildName(struct_vector.GetType(), j);
auto value = duckdb::StructValue::GetChildren(struct_vector.GetValue(i))[j];
if (key == "context_columns") {
if (value.GetTypeMutable().id() != duckdb::LogicalTypeId::LIST) {
throw std::runtime_error("Expected 'context_columns' to be a list.");
}
auto context_columns = duckdb::ListValue::GetChildren(value);
for (auto context_column_idx = 0; context_column_idx < static_cast<int>(context_columns.size()); context_column_idx++) {
auto context_column = context_columns[context_column_idx];
auto context_column_json = CastVectorOfStructsToJson(duckdb::Vector(context_column), 1);
auto allowed_keys = {"name", "data", "type", "detail", "transcription_model"};
for (const auto& key: context_column_json.items()) {
if (std::find(std::begin(allowed_keys), std::end(allowed_keys), key.key()) == std::end(allowed_keys)) {
throw std::runtime_error(duckdb_fmt::format("Unexpected key in 'context_columns': {}", key.key()));
}
}
auto required_keys = {"data"};
for (const auto& key: required_keys) {
if (!context_column_json.contains(key)) {
throw std::runtime_error(duckdb_fmt::format("Expected 'context_columns' to contain key: {}", key));
}
}
if (struct_json.contains("context_columns") && struct_json["context_columns"].size() == context_columns.size()) {
struct_json["context_columns"][context_column_idx]["data"].push_back(context_column_json["data"]);
} else {
struct_json["context_columns"].push_back(context_column_json);
ValidateAndCleanContextColumn(struct_json["context_columns"][context_column_idx], allowed_keys);
struct_json["context_columns"][context_column_idx]["data"] = nlohmann::json::array();
struct_json["context_columns"][context_column_idx]["data"].push_back(context_column_json["data"]);
}
}
} else if (key == "batch_size") {
if (value.GetTypeMutable() != duckdb::LogicalType::INTEGER) {
throw std::runtime_error("Expected 'batch_size' to be an integer.");
}
struct_json[key] = value.GetValue<int>();
} else {
struct_json[key] = value.ToString();
}
}
}
return struct_json;
}
nlohmann::json CastValueToJson(const duckdb::Value& value) {
nlohmann::json result;
if (value.IsNull()) {
return result;
}
auto& value_type = value.type();
if (value_type.id() == duckdb::LogicalTypeId::STRUCT) {
auto& children = duckdb::StructValue::GetChildren(value);
auto child_count = duckdb::StructType::GetChildCount(value_type);
for (idx_t i = 0; i < child_count; i++) {
auto key = duckdb::StructType::GetChildName(value_type, i);
auto& child_value = children[i];
if (!child_value.IsNull()) {
// Recursively convert child values
if (child_value.type().id() == duckdb::LogicalTypeId::STRUCT) {
result[key] = CastValueToJson(child_value);
} else if (child_value.type().id() == duckdb::LogicalTypeId::INTEGER) {
result[key] = child_value.GetValue<int>();
} else {
result[key] = child_value.ToString();
}
}
}
}
else {
result = nlohmann::json(value.ToString());
}
return result;
}
}// namespace flock