Skip to content

Commit e7cc669

Browse files
ajindal1pranavsharma
authored andcommitted
Set runtime options for Terminating session (#1011)
Add API option for setting any runtime options. In particular, add a terminate-session option for the user; the session remains in the terminated state until the user disables terminate_session. To terminate the session: SetRuntimeOption("terminate_session", "1"). To unset: SetRuntimeOption("terminate_session", "0"). --------- Co-authored-by: Pranav Sharma <[email protected]>
1 parent bc44d7b commit e7cc669

File tree

9 files changed

+136
-0
lines changed

9 files changed

+136
-0
lines changed

documents/Runtime_option.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Runtime Options
2+
3+
This file describes how to use the SetRuntimeOption API. It lists every key/value pair that the API currently accepts as input.
4+
5+
## Set Terminate
6+
7+
Set Terminate is a runtime option that terminates the current session, or resumes a session that was previously terminated. Once terminated, the session stays in the terminated state until the option is explicitly disabled. There are two valid ways to call Set Terminate.
8+
9+
To enable terminate, the valid pair is: ("terminate_session", "1")
10+
11+
To disable terminate, the valid pair is: ("terminate_session", "0")
12+
13+
Key: "terminate_session"
14+
15+
Accepted values: ("0", "1")

src/generators.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ std::string CurrentModulePath() {
3030
}
3131
#endif
3232

33+
// Guard helper: throws std::runtime_error when the supplied flag says the
// session is terminated, and does nothing when the flag is false.
// Callers pass their own State::session_terminated_ value.
void ThrowErrorIfSessionTerminated(bool is_session_terminated) {
34+
if (is_session_terminated)
35+
throw std::runtime_error("Session in Terminated state, exiting!");
36+
}
37+
3338
namespace Generators {
3439

3540
#if USE_CUDA
@@ -284,6 +289,7 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_
284289
}
285290

286291
void Generator::ComputeLogits() {
292+
ThrowErrorIfSessionTerminated(state_->session_terminated_);
287293
if (computed_logits_)
288294
throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken first");
289295

@@ -301,7 +307,25 @@ void Generator::ComputeLogits() {
301307
search_->ApplyRepetitionPenalty(search.repetition_penalty);
302308
}
303309

310+
void Generator::SetRuntimeOption(const char* key, const char* value) {
311+
// TODO: Need a better way to handle different keys
312+
// We can create a config manager to host all configurations and do comparison at that point
313+
if (strcmp(key, "terminate_session") == 0) {
314+
if (strcmp(value, "0") == 0) {
315+
state_->UnsetTerminate();
316+
} else if (strcmp(value, "1") == 0) {
317+
state_->SetTerminate();
318+
} else {
319+
// Value not expected
320+
throw std::runtime_error(std::string("terminate_session key value unexpected: ") + value);
321+
}
322+
} else {
323+
throw std::runtime_error(std::string("SetRuntimeOption key is not expected: ") + key);
324+
}
325+
}
326+
304327
bool Generator::IsDone() const {
328+
ThrowErrorIfSessionTerminated(state_->session_terminated_);
305329
if (computed_logits_)
306330
throw std::runtime_error("IsDone() can't be called in the middle of processing logits");
307331

@@ -313,7 +337,12 @@ bool Generator::IsDone() const {
313337
return is_done;
314338
}
315339

340+
// Returns the current value of state_->session_terminated_.
// This is a plain read: it does not call ThrowErrorIfSessionTerminated, so it
// can be used to query the flag while the session is terminated.
bool Generator::IsSessionTerminated() const {
341+
return state_->session_terminated_;
342+
}
343+
316344
void Generator::GenerateNextToken() {
345+
ThrowErrorIfSessionTerminated(state_->session_terminated_);
317346
if (!computed_logits_)
318347
throw std::runtime_error("Must call ComputeLogits before GenerateNextToken");
319348
computed_logits_ = false;

src/generators.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ using cudaStream_t = void*;
4040
#include "runtime_settings.h"
4141
#include "tensor.h"
4242

43+
void ThrowErrorIfSessionTerminated(bool is_session_terminated);
44+
4345
namespace Generators {
4446
struct Model;
4547
struct State;
@@ -108,7 +110,9 @@ struct Generator : LeakChecked<Generator> {
108110
Generator(const Model& model, const GeneratorParams& params);
109111

110112
bool IsDone() const;
113+
void SetRuntimeOption(const char* key, const char* value);
111114
void ComputeLogits();
115+
bool IsSessionTerminated() const;
112116
void GenerateNextToken();
113117

114118
DeviceMemorySpan<int32_t> GetSequence(size_t index) const;

src/models/model.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,18 @@ void State::Run(OrtSession& session, int new_batch_size) {
6565
}
6666
}
6767

68+
// Marks this state as terminated and forwards the request to the run options.
// The local flag is set first, so GetInput/GetOutput and the Generator entry
// points that check it start throwing.
// NOTE(review): session_terminated_ is declared as a plain bool, yet the
// SetTerminate test writes it from one thread while another thread reads it.
// Consider std::atomic<bool> - confirm the intended threading model.
void State::SetTerminate() {
69+
session_terminated_ = true;
70+
// Presumably asks ONNX Runtime to abort Run() calls using these options - confirm against the OrtRunOptions docs.
run_options_->SetTerminate();
71+
}
72+
73+
// Clears the terminated flag and forwards the request to the run options,
// so the guarded entry points stop throwing.
void State::UnsetTerminate() {
74+
session_terminated_ = false;
75+
// Presumably clears the terminate flag on the ONNX Runtime run options - confirm against the OrtRunOptions docs.
run_options_->UnsetTerminate();
76+
}
77+
6878
OrtValue* State::GetInput(const char* name) {
79+
ThrowErrorIfSessionTerminated(session_terminated_);
6980
for (size_t i = 0; i < input_names_.size(); i++) {
7081
if (std::strcmp(input_names_[i], name) == 0) {
7182
return inputs_[i];
@@ -75,6 +86,7 @@ OrtValue* State::GetInput(const char* name) {
7586
}
7687

7788
OrtValue* State::GetOutput(const char* name) {
89+
ThrowErrorIfSessionTerminated(session_terminated_);
7890
for (size_t i = 0; i < output_names_.size(); i++) {
7991
if (std::strcmp(output_names_[i], name) == 0) {
8092
return outputs_[i];

src/models/model.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ struct State {
3434
virtual const CapturedGraphInfo* GetCapturedGraphInfo() const { return nullptr; }
3535
virtual void Finalize() {}
3636

37+
void SetTerminate();
38+
void UnsetTerminate();
39+
mutable bool session_terminated_{};
3740
OrtValue* GetInput(const char* name);
3841

3942
virtual OrtValue* GetOutput(const char* name);

src/ort_genai.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,10 @@ struct OgaGenerator : OgaAbstract {
258258
return OgaGenerator_IsDone(this);
259259
}
260260

261+
// C++ convenience wrapper: returns the result of OgaGenerator_IsSessionTerminated for this generator.
bool IsSessionTerminated() const {
262+
return OgaGenerator_IsSessionTerminated(this);
263+
}
264+
261265
void ComputeLogits() {
262266
OgaCheckResult(OgaGenerator_ComputeLogits(this));
263267
}
@@ -266,6 +270,10 @@ struct OgaGenerator : OgaAbstract {
266270
OgaCheckResult(OgaGenerator_GenerateNextToken(this));
267271
}
268272

273+
// C++ convenience wrapper around OgaGenerator_SetRuntimeOption.
// The returned OgaResult is passed to OgaCheckResult (presumably converting a
// failure into a C++ exception - confirm against OgaCheckResult).
// Example: SetRuntimeOption("terminate_session", "1").
void SetRuntimeOption(const char* key, const char* value) {
274+
OgaCheckResult(OgaGenerator_SetRuntimeOption(this, key, value));
275+
}
276+
269277
size_t GetSequenceCount(size_t index) const {
270278
return OgaGenerator_GetSequenceCount(this, index);
271279
}

src/ort_genai_c.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,10 @@ bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator) {
271271
return reinterpret_cast<const Generators::Generator*>(generator)->IsDone();
272272
}
273273

274+
// C API entry point: reinterprets the opaque handle as a Generators::Generator
// and returns its IsSessionTerminated() result.
bool OGA_API_CALL OgaGenerator_IsSessionTerminated(const OgaGenerator* generator) {
275+
return reinterpret_cast<const Generators::Generator*>(generator)->IsSessionTerminated();
276+
}
277+
274278
OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator) {
275279
OGA_TRY
276280
reinterpret_cast<Generators::Generator*>(generator)->ComputeLogits();
@@ -285,6 +289,13 @@ OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator)
285289
OGA_CATCH
286290
}
287291

292+
// C API entry point: forwards key/value unchanged to Generator::SetRuntimeOption.
// Returns nullptr on success. Failures thrown by SetRuntimeOption are handled
// by the OGA_TRY / OGA_CATCH pair (presumably turned into an OgaResult - confirm
// against the macro definitions).
OgaResult* OGA_API_CALL OgaGenerator_SetRuntimeOption(OgaGenerator* generator, const char* key, const char* value) {
293+
OGA_TRY
294+
reinterpret_cast<Generators::Generator*>(generator)->SetRuntimeOption(key, value);
295+
return nullptr;
296+
OGA_CATCH
297+
}
298+
288299
OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator, const char* name, OgaTensor** out) {
289300
OGA_TRY
290301
auto& generator = *reinterpret_cast<const Generators::Generator*>(oga_generator);

src/ort_genai_c.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* generator);
278278
* \return True if the generator has finished generating all the sequences, false otherwise.
279279
*/
280280
OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator);
281+
/*
 * \brief Reports whether the generator's session is currently flagged as terminated.
 * \param[in] generator The generator to query.
 * \return True if the session is in the terminated state, false otherwise.
 */
OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsSessionTerminated(const OgaGenerator* generator);
281282

282283
/*
283284
* \brief Computes the logits from the model based on the input ids and the past state. The computed logits are stored in the generator.
@@ -287,6 +288,8 @@ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator);
287288
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator);
288289
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator);
289290

291+
/*
 * \brief Sets a runtime option on the generator from a string key/value pair.
 *        The only key currently handled is "terminate_session": value "1" terminates
 *        the session and value "0" clears the terminated state.
 * \param[in] generator The generator to set the option on.
 * \param[in] key The name of the runtime option.
 * \param[in] value The value for the runtime option.
 * \return OgaResult containing the error message if the key or value is not accepted.
 */
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_SetRuntimeOption(OgaGenerator* generator, const char* key, const char* value);
292+
290293
/*
291294
* \brief Returns a copy of the model output identified by the given name as an OgaTensor on CPU. The buffer is owned by returned OgaTensor
292295
* and will be released when the OgaTensor is destroyed

test/c_api_tests.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <iostream>
66
#include <ort_genai.h>
77
#include "../src/span.h"
8+
#include <thread>
9+
#include <vector>
810

911
#ifndef MODEL_PATH
1012
#define MODEL_PATH "../../test/test_models/"
@@ -283,6 +285,55 @@ TEST(CAPITests, GetOutputCAPI) {
283285
generator->GenerateNextToken();
284286
}
285287

288+
// Exercises the "terminate_session" runtime option: one thread runs the
// generation loop while a second thread requests termination, then the main
// thread checks the flag and clears it again. Only compiled when TEST_PHI2 is set.
TEST(CAPITests, SetTerminate) {
289+
#if TEST_PHI2
290+
291+
// Thread body that requests termination of the generator's session.
auto GeneratorSetTerminateCall = [](OgaGenerator* generator) {
292+
// Set Terminate
293+
generator->SetRuntimeOption("terminate_session", "1");
294+
};
295+
296+
// Thread body that runs the generation loop until it finishes or throws.
// NOTE(review): tokenizer_stream is accepted but never used inside this lambda.
auto GenerateOutput = [](OgaGenerator* generator, std::unique_ptr<OgaTokenizerStream> tokenizer_stream) {
297+
try {
298+
while (!generator->IsDone()) {
299+
generator->ComputeLogits();
300+
generator->GenerateNextToken();
301+
}
302+
}
303+
// Any exception from the loop is expected to be caused by termination, so
// the flag must already be set when we get here.
catch (const std::exception& e) {
304+
EXPECT_EQ(generator->IsSessionTerminated(), true);
305+
std::cout << "Session Terminated: " << e.what() << std::endl;
306+
}
307+
};
308+
309+
auto model = OgaModel::Create(PHI2_PATH);
310+
auto tokenizer = OgaTokenizer::Create(*model);
311+
auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);
312+
313+
const char* input_string = "She sells sea shells by the sea shore.";
314+
auto input_sequences = OgaSequences::Create();
315+
tokenizer->Encode(input_string, *input_sequences);
316+
auto params = OgaGeneratorParams::Create(*model);
317+
params->SetInputSequences(*input_sequences);
318+
params->SetSearchOption("max_length", 40);
319+
320+
auto generator = OgaGenerator::Create(*model, *params);
321+
// A freshly created generator must not report a terminated session.
EXPECT_EQ(generator->IsSessionTerminated(), false);
322+
std::vector<std::thread> threads;
323+
// The two threads race: termination may land mid-generation (the loop throws)
// or after the loop has already finished.
threads.push_back(std::thread(GenerateOutput, generator.get(), std::move(tokenizer_stream)));
324+
threads.push_back(std::thread(GeneratorSetTerminateCall, generator.get()));
325+
326+
for (auto& th : threads) {
327+
std::cout << "Waiting for threads completion" << std::endl;
328+
th.join(); // Wait for each thread to finish
329+
}
330+
// The terminate thread has completed, so the flag is set regardless of which
// thread won the race.
EXPECT_EQ(generator->IsSessionTerminated(), true);
331+
// Unset terminate
332+
generator->SetRuntimeOption("terminate_session", "0");
333+
EXPECT_EQ(generator->IsSessionTerminated(), false);
334+
#endif
335+
}
336+
286337
#if TEST_PHI2
287338

288339
struct Phi2Test {

0 commit comments

Comments
 (0)