Skip to content

Commit c00e886

Browse files
authored
fix base score and msg format (#66)
* fix err msg formta * fix base score miss * update version & changelog
1 parent e6b7d74 commit c00e886

File tree

9 files changed

+26
-10
lines changed

9 files changed

+26
-10
lines changed

Diff for: CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111

1212
> please add your unreleased change here.
1313
14+
## 20240524 - 0.3.1b0
15+
16+
- [Bugfix] fix tree predict base score miss
17+
- [Bugfix] fix http adapater error msg format failed
18+
1419
## 20240423 - 0.3.0b0
1520

1621
- [Feature] Add Trace function

Diff for: secretflow_serving/feature_adapter/http_adapter.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ void HttpFeatureAdapter::OnFetchFeature(const Request& request,
147147
SetSpanAttrs(span, span_option);
148148

149149
SERVING_ENFORCE(span_option.code == errors::ErrorCode::OK, span_option.code,
150-
span_option.msg);
150+
"{}", span_option.msg);
151151
response->header->mutable_data()->swap(
152152
*spi_response.mutable_header()->mutable_data());
153153
response->features =

Diff for: secretflow_serving/framework/execute_context.cc

+3-4
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@ void ExecuteContext::CheckAndUpdateResponse() {
2626
void ExecuteContext::CheckAndUpdateResponse(
2727
const apis::ExecuteResponse& exec_res) {
2828
if (!CheckStatusOk(exec_res.status())) {
29-
SERVING_THROW(
30-
exec_res.status().code(),
31-
fmt::format("{} exec failed: code({}), {}", target_id_,
32-
exec_res.status().code(), exec_res.status().msg()));
29+
SERVING_THROW(exec_res.status().code(), "{} exec failed: code({}), {}",
30+
target_id_, exec_res.status().code(),
31+
exec_res.status().msg());
3332
}
3433
MergeResonseHeader(exec_res);
3534
}

Diff for: secretflow_serving/framework/execute_context.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ class RemoteExecute : public ExecuteBase,
209209
if (span_option.code == errors::ErrorCode::OK) {
210210
exec_ctx_.MergeResonseHeader();
211211
} else {
212-
SERVING_THROW(span_option.code, span_option.msg);
212+
SERVING_THROW(span_option.code, "{}", span_option.msg);
213213
}
214214
}
215215

Diff for: secretflow_serving/ops/tree_ensemble_predict.cc

+8-2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ TreeEnsemblePredict::TreeEnsemblePredict(OpKernelOptions opts)
4040
GetNodeAttr<std::string>(opts_.node_def, *opts_.op_def, "algo_func");
4141
func_type_ = ParseLinkFuncType(func_type_str);
4242

43+
base_score_ =
44+
GetNodeAttr<double>(opts_.node_def, *opts_.op_def, "base_score");
45+
4346
BuildInputSchema();
4447
BuildOutputSchema();
4548
}
@@ -65,7 +68,8 @@ void TreeEnsemblePredict::DoCompute(ComputeContext* ctx) {
6568
arrow::DoubleBuilder builder;
6669
SERVING_CHECK_ARROW_STATUS(builder.Resize(merged_array->length()));
6770
for (int64_t i = 0; i < merged_array->length(); ++i) {
68-
auto score = ApplyLinkFunc(merged_array->Value(i), func_type_);
71+
auto score = merged_array->Value(i) + base_score_;
72+
score = ApplyLinkFunc(score, func_type_);
6973
SERVING_CHECK_ARROW_STATUS(builder.Append(score));
7074
}
7175
std::shared_ptr<arrow::Array> res_array;
@@ -90,7 +94,7 @@ void TreeEnsemblePredict::BuildOutputSchema() {
9094
}
9195

9296
REGISTER_OP_KERNEL(TREE_ENSEMBLE_PREDICT, TreeEnsemblePredict)
93-
REGISTER_OP(TREE_ENSEMBLE_PREDICT, "0.0.1",
97+
REGISTER_OP(TREE_ENSEMBLE_PREDICT, "0.0.2",
9498
"Accept the weighted results from multiple trees (`TREE_SELECT` + "
9599
"`TREE_MERGE`), merge them, and obtain the final prediction result "
96100
"of the tree ensemble.")
@@ -101,6 +105,8 @@ REGISTER_OP(TREE_ENSEMBLE_PREDICT, "0.0.1",
101105
.StringAttr("output_col_name",
102106
"The column name of tree ensemble predict score", false, false)
103107
.Int32Attr("num_trees", "The number of ensemble's tree", false, false)
108+
.DoubleAttr("base_score", "The initial prediction score, global bias.",
109+
false, true, 0.0)
104110
.StringAttr(
105111
"algo_func",
106112
"Optional value: "

Diff for: secretflow_serving/ops/tree_ensemble_predict.h

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class TreeEnsemblePredict : public OpKernel {
3737

3838
int32_t num_trees_;
3939
LinkFunctionType func_type_;
40+
41+
double base_score_ = 0.0;
4042
};
4143

4244
} // namespace secretflow::serving::op

Diff for: secretflow_serving/ops/tree_ensemble_predict_test.cc

+4
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ TEST_P(TreeEnsemblePredictParamTest, Works) {
4747
},
4848
"output_col_name": {
4949
"s": "scores"
50+
},
51+
"base_score": {
52+
"d": 0.1
5053
}
5154
}
5255
}
@@ -96,6 +99,7 @@ TEST_P(TreeEnsemblePredictParamTest, Works) {
9699
for (size_t col = 1; col < param.tree_weights.size(); ++col) {
97100
score += param.tree_weights[col][row];
98101
}
102+
score += 0.1;
99103
SERVING_CHECK_ARROW_STATUS(expect_res_builder.Append(
100104
ApplyLinkFunc(score, ParseLinkFuncType(param.algo_func))));
101105
}

Diff for: secretflow_serving_lib/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313
# limitations under the License.
1414

1515

16-
__version__ = "0.3.0b0"
16+
__version__ = "0.3.1b0"

Diff for: version.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313
# limitations under the License.
1414

1515

16-
__version__ = "0.3.0b0"
16+
__version__ = "0.3.1b0"

0 commit comments

Comments
 (0)