@@ -138,73 +138,85 @@ export async function ensureTablesExist() {
138
138
}
139
139
140
140
// 获取模型价格,如果不存在则创建默认值
141
- export async function getOrCreateModelPrice (
142
- id : string ,
143
- name : string ,
144
- base_model_id ?: string
145
- ) : Promise < ModelPrice > {
141
+ export async function getOrCreateModelPrices (
142
+ models : Array < { id : string ; name : string ; base_model_id ?: string } >
143
+ ) : Promise < ModelPrice [ ] > {
146
144
let client : PoolClient | null = null ;
147
145
try {
148
146
client = await pool . connect ( ) ;
149
147
150
- // 首先尝试获取当前模型的价格
151
- const currentModelResult = await client . query < ModelPriceRow > (
152
- `SELECT * FROM model_prices WHERE id = $1` ,
153
- [ id ]
148
+ // 1. 首先获取所有已存在的模型价格
149
+ const modelIds = models . map ( ( m ) => m . id ) ;
150
+ const baseModelIds = models . map ( ( m ) => m . base_model_id ) . filter ( ( id ) => id ) ;
151
+
152
+ const existingModelsResult = await client . query < ModelPriceRow > (
153
+ `SELECT * FROM model_prices WHERE id = ANY($1::text[])` ,
154
+ [ modelIds ]
154
155
) ;
155
156
156
- // 如果模型不存在,并且有 base_model_id,尝试获取基础模型的价格
157
- let baseModelPrices = null ;
158
- if ( currentModelResult . rows . length === 0 && base_model_id ) {
159
- const baseModelResult = await client . query < ModelPriceRow > (
160
- `SELECT input_price, output_price, per_msg_price
161
- FROM model_prices
162
- WHERE id = $1` ,
163
- [ base_model_id ]
164
- ) ;
165
- if ( baseModelResult . rows . length > 0 ) {
166
- baseModelPrices = baseModelResult . rows [ 0 ] ;
167
- }
168
- }
157
+ // 2. 获取所有基础模型的价格
158
+ const baseModelsResult = await client . query < ModelPriceRow > (
159
+ `SELECT * FROM model_prices WHERE id = ANY($1::text[])` ,
160
+ [ baseModelIds ]
161
+ ) ;
169
162
170
- // 插入或更新模型价格
171
- const result = await client . query < ModelPriceRow > (
172
- `INSERT INTO model_prices (
173
- id,
174
- name,
175
- input_price,
176
- output_price,
177
- per_msg_price
178
- )
179
- VALUES (
180
- $1,
181
- $2,
182
- $3,
183
- $4,
184
- $5
185
- )
186
- ON CONFLICT (id) DO UPDATE
187
- SET name = $2
188
- RETURNING *` ,
189
- [
190
- id ,
191
- name ,
192
- baseModelPrices ?. input_price ?? null ,
193
- baseModelPrices ?. output_price ?? null ,
194
- baseModelPrices ?. per_msg_price ?? null ,
195
- ]
163
+ const existingModels = new Map (
164
+ existingModelsResult . rows . map ( ( row ) => [ row . id , row ] )
165
+ ) ;
166
+ const baseModels = new Map (
167
+ baseModelsResult . rows . map ( ( row ) => [ row . id , row ] )
196
168
) ;
197
169
198
- return {
199
- id : result . rows [ 0 ] . id ,
200
- name : result . rows [ 0 ] . name ,
201
- input_price : Number ( result . rows [ 0 ] . input_price ) ,
202
- output_price : Number ( result . rows [ 0 ] . output_price ) ,
203
- per_msg_price : Number ( result . rows [ 0 ] . per_msg_price ) ,
204
- updated_at : result . rows [ 0 ] . updated_at ,
205
- } ;
170
+ // 3. 批量插入或更新缺失的模型
171
+ const missingModels = models . filter ( ( m ) => ! existingModels . has ( m . id ) ) ;
172
+ if ( missingModels . length > 0 ) {
173
+ const values = missingModels . map ( ( m ) => {
174
+ const baseModel = m . base_model_id
175
+ ? baseModels . get ( m . base_model_id )
176
+ : null ;
177
+ return [
178
+ m . id ,
179
+ m . name ,
180
+ baseModel ?. input_price ?? null ,
181
+ baseModel ?. output_price ?? null ,
182
+ baseModel ?. per_msg_price ?? null ,
183
+ ] ;
184
+ } ) ;
185
+
186
+ const placeholders = values
187
+ . map (
188
+ ( _ , i ) =>
189
+ `($${ i * 5 + 1 } , $${ i * 5 + 2 } , $${ i * 5 + 3 } , $${ i * 5 + 4 } , $${
190
+ i * 5 + 5
191
+ } )`
192
+ )
193
+ . join ( "," ) ;
194
+
195
+ const result = await client . query < ModelPriceRow > (
196
+ `INSERT INTO model_prices (id, name, input_price, output_price, per_msg_price)
197
+ VALUES ${ placeholders }
198
+ ON CONFLICT (id) DO UPDATE
199
+ SET name = EXCLUDED.name
200
+ RETURNING *` ,
201
+ values . flat ( )
202
+ ) ;
203
+
204
+ result . rows . forEach ( ( row ) => existingModels . set ( row . id , row ) ) ;
205
+ }
206
+
207
+ return models . map ( ( m ) => {
208
+ const row = existingModels . get ( m . id ) ! ;
209
+ return {
210
+ id : row . id ,
211
+ name : row . name ,
212
+ input_price : Number ( row . input_price ) ,
213
+ output_price : Number ( row . output_price ) ,
214
+ per_msg_price : Number ( row . per_msg_price ) ,
215
+ updated_at : row . updated_at ,
216
+ } ;
217
+ } ) ;
206
218
} catch ( error ) {
207
- console . error ( "Error in getOrCreateModelPrice :" , error ) ;
219
+ console . error ( "Error in getOrCreateModelPrices :" , error ) ;
208
220
throw error ;
209
221
} finally {
210
222
if ( client ) {
0 commit comments