Skip to content

Commit 499f706

Browse files
committed
add options argument in NewModel method
1 parent 72fb6b9 commit 499f706

File tree

3 files changed

+70
-5
lines changed

3 files changed

+70
-5
lines changed

genai/main_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func newTestModel(t *testing.T, rt *Runtime) *Model {
3737
t.Skip("Test model not available. Set ONNXRUNTIME_GENAI_MODEL_PATH environment variable.")
3838
}
3939

40-
model, err := rt.NewModel(testModelPath)
40+
model, err := rt.NewModel(testModelPath, nil)
4141
if err != nil {
4242
t.Fatalf("Failed to create model: %v", err)
4343
}

genai/model_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ func TestNewModel(t *testing.T) {
1111

1212
rt := newTestRuntime(t)
1313

14-
model, err := rt.NewModel(testModelPath)
14+
model, err := rt.NewModel(testModelPath, nil)
1515
if err != nil {
1616
t.Fatalf("Failed to create model: %v", err)
1717
}

genai/runtime.go

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,81 @@ func (r *Runtime) Close() error {
8585
return nil
8686
}
8787

88+
// ProviderOptions contains configuration options for an execution provider.
89+
type ProviderOptions map[string]string
90+
91+
// ModelOptions configures options for creating a model.
92+
type ModelOptions struct {
93+
// Providers specifies the execution providers to use, in order of preference.
94+
Providers []string
95+
96+
// ProviderOptions specifies options for each provider.
97+
// The key is the provider name, and the value is a map of option key-value pairs.
98+
ProviderOptions map[string]ProviderOptions
99+
}
100+
88101
// NewModel loads a model from the specified directory path.
89102
// The path should point to a directory containing the model files
90103
// (e.g., genai_config.json, model.onnx, tokenizer files).
91-
func (r *Runtime) NewModel(modelPath string) (*Model, error) {
104+
// If options is nil, default options will be used.
105+
func (r *Runtime) NewModel(modelPath string, options *ModelOptions) (*Model, error) {
92106
pathBytes := stringToBytes(modelPath)
93107

108+
// If no options or no providers specified, use the simple CreateModel
109+
if options == nil || len(options.Providers) == 0 {
110+
var modelPtr api.OgaModel
111+
result := r.funcs.CreateModel(&pathBytes[0], &modelPtr)
112+
if err := resultError(r.funcs, result); err != nil {
113+
return nil, fmt.Errorf("failed to create model: %w", err)
114+
}
115+
116+
return &Model{
117+
ptr: modelPtr,
118+
runtime: r,
119+
}, nil
120+
}
121+
122+
// Create config for provider configuration
123+
var configPtr api.OgaConfig
124+
result := r.funcs.CreateConfig(&pathBytes[0], &configPtr)
125+
if err := resultError(r.funcs, result); err != nil {
126+
return nil, fmt.Errorf("failed to create config: %w", err)
127+
}
128+
defer r.funcs.DestroyConfig(configPtr)
129+
130+
// Clear existing providers
131+
result = r.funcs.ConfigClearProviders(configPtr)
132+
if err := resultError(r.funcs, result); err != nil {
133+
return nil, fmt.Errorf("failed to clear providers: %w", err)
134+
}
135+
136+
// Append providers
137+
for _, provider := range options.Providers {
138+
providerBytes := stringToBytes(provider)
139+
result = r.funcs.ConfigAppendProvider(configPtr, &providerBytes[0])
140+
if err := resultError(r.funcs, result); err != nil {
141+
return nil, fmt.Errorf("failed to append provider %q: %w", provider, err)
142+
}
143+
144+
// Set provider options if specified
145+
if options.ProviderOptions != nil {
146+
if providerOpts, ok := options.ProviderOptions[provider]; ok {
147+
for key, value := range providerOpts {
148+
keyBytes := stringToBytes(key)
149+
valueBytes := stringToBytes(value)
150+
result = r.funcs.ConfigSetProviderOption(configPtr, &providerBytes[0], &keyBytes[0], &valueBytes[0])
151+
if err := resultError(r.funcs, result); err != nil {
152+
return nil, fmt.Errorf("failed to set provider option %q=%q for %q: %w", key, value, provider, err)
153+
}
154+
}
155+
}
156+
}
157+
}
158+
94159
var modelPtr api.OgaModel
95-
result := r.funcs.CreateModel(&pathBytes[0], &modelPtr)
160+
result = r.funcs.CreateModelFromConfig(configPtr, &modelPtr)
96161
if err := resultError(r.funcs, result); err != nil {
97-
return nil, fmt.Errorf("failed to create model: %w", err)
162+
return nil, fmt.Errorf("failed to create model from config: %w", err)
98163
}
99164

100165
return &Model{

0 commit comments

Comments
 (0)