1
1
import Axios from "axios" ;
2
- import type { AxiosInstance } from "axios " ;
3
- import { encodeBatchTexts , encodeBatchVectors } from "./proto " ;
4
- import { SearchOptions , QueryOptions , BatchVectors , BatchTexts , MetadataType } from "./schemas " ;
5
- import { CategoryTypes , DriaParams , ModelTypes } from "./types " ;
6
- import constants from "./constants " ;
2
+ import { encodeBatchTexts , encodeBatchVectors } from "../proto " ;
3
+ import { SearchOptions , QueryOptions , BatchVectors , BatchTexts , MetadataType } from "../schemas " ;
4
+ import { CategoryTypes , DriaParams , ModelTypes } from "../types " ;
5
+ import constants from "../constants " ;
6
+ import { DriaCommon } from "./common " ;
7
7
8
8
/**
9
- * Dria JS Client
9
+ * ## Dria Client
10
10
*
11
11
* @param params optional API key and contract ID.
12
12
*
@@ -40,8 +40,7 @@ import constants from "./constants";
40
40
* dria.contractId = contractId;
41
41
*/
42
42
// eslint-disable-next-line @typescript-eslint/no-explicit-any
43
- export class Dria < T extends MetadataType = any > {
44
- protected client : AxiosInstance ;
43
+ export class Dria < T extends MetadataType = any > extends DriaCommon {
45
44
contractId : string | undefined ;
46
45
/** Cached contract models. */
47
46
private models : Record < string , ModelTypes > = { } ;
@@ -50,18 +49,21 @@ export class Dria<T extends MetadataType = any> {
50
49
const apiKey = params . apiKey ?? process . env . DRIA_API_KEY ;
51
50
if ( ! apiKey ) throw new Error ( "Missing Dria API key." ) ;
52
51
52
+ super (
53
+ Axios . create ( {
54
+ headers : {
55
+ "x-api-key" : apiKey ,
56
+ "Content-Type" : "application/json" ,
57
+ "Accept-Encoding" : "gzip, deflate, br" ,
58
+ Connection : "keep-alive" ,
59
+ Accept : "*/*" ,
60
+ } ,
61
+ // lets us handle the errors
62
+ validateStatus : ( ) => true ,
63
+ } ) ,
64
+ ) ;
65
+
53
66
this . contractId = params . contractId ;
54
- this . client = Axios . create ( {
55
- headers : {
56
- "x-api-key" : apiKey ,
57
- "Content-Type" : "application/json" ,
58
- "Accept-Encoding" : "gzip, deflate, br" ,
59
- Connection : "keep-alive" ,
60
- Accept : "*/*" ,
61
- } ,
62
- // lets us handle the errors
63
- validateStatus : ( ) => true ,
64
- } ) ;
65
67
}
66
68
67
69
/** A text-based search.
@@ -79,7 +81,7 @@ export class Dria<T extends MetadataType = any> {
79
81
async search ( text : string , options : SearchOptions = { } ) {
80
82
options = SearchOptions . parse ( options ) ;
81
83
const contractId = this . getContractId ( ) ;
82
- return await this . post < { id : number ; metadata : string ; score : number } [ ] > ( constants . DRIA_SEARCH_URL + "/search" , {
84
+ return await this . post < { id : number ; metadata : string ; score : number } [ ] > ( constants . DRIA . SEARCH_URL + "/search" , {
83
85
query : text ,
84
86
top_n : options . topK ,
85
87
level : options . level ,
@@ -103,7 +105,7 @@ export class Dria<T extends MetadataType = any> {
103
105
async query < M extends MetadataType = T > ( vector : number [ ] , options : QueryOptions = { } ) {
104
106
options = QueryOptions . parse ( options ) ;
105
107
const data = await this . post < { id : number ; metadata : string ; score : number } [ ] > (
106
- constants . DRIA_SEARCH_URL + "/query" ,
108
+ constants . DRIA . SEARCH_URL + "/query" ,
107
109
{ vector, contract_id : this . getContractId ( ) , top_n : options . topK } ,
108
110
) ;
109
111
return data . map ( ( d ) => ( { ...d , metadata : JSON . parse ( d . metadata ) as M } ) ) ;
@@ -119,7 +121,7 @@ export class Dria<T extends MetadataType = any> {
119
121
*/
120
122
async fetch < M extends MetadataType = T > ( ids : number [ ] ) {
121
123
if ( ids . length === 0 ) throw "No IDs provided." ;
122
- const data = await this . post < { metadata : string [ ] ; vectors : number [ ] [ ] } > ( constants . DRIA_SEARCH_URL + "/fetch" , {
124
+ const data = await this . post < { metadata : string [ ] ; vectors : number [ ] [ ] } > ( constants . DRIA . SEARCH_URL + "/fetch" , {
123
125
id : ids ,
124
126
contract_id : this . getContractId ( ) ,
125
127
} ) ;
@@ -145,7 +147,7 @@ export class Dria<T extends MetadataType = any> {
145
147
items = BatchVectors . parse ( items ) as BatchVectors < M > ;
146
148
const encodedData = encodeBatchVectors ( items ) ;
147
149
const contractId = this . getContractId ( ) ;
148
- const data = await this . post < string > ( constants . DRIA_INSERT_URL + "/insert_vector" , {
150
+ const data = await this . post < string > ( constants . DRIA . INSERT_URL + "/insert_vector" , {
149
151
data : encodedData ,
150
152
batch_size : items . length ,
151
153
model : await this . getModel ( contractId ) ,
@@ -170,7 +172,7 @@ export class Dria<T extends MetadataType = any> {
170
172
items = BatchTexts . parse ( items ) as BatchTexts < M > ;
171
173
const encodedData = encodeBatchTexts ( items ) ;
172
174
const contractId = this . getContractId ( ) ;
173
- const data = await this . post < string > ( constants . DRIA_INSERT_URL + "/insert_text" , {
175
+ const data = await this . post < string > ( constants . DRIA . INSERT_URL + "/insert_text" , {
174
176
data : encodedData ,
175
177
batch_size : items . length ,
176
178
model : await this . getModel ( contractId ) ,
@@ -196,7 +198,7 @@ export class Dria<T extends MetadataType = any> {
196
198
* // you can now make queries, or insert data there
197
199
*/
198
200
async create ( name : string , embedding : ModelTypes , category : CategoryTypes , description : string = "" ) {
199
- const data = await this . post < { contract_id : string } > ( constants . DRIA_API_URL + "/v1/knowledge/index/create" , {
201
+ const data = await this . post < { contract_id : string } > ( constants . DRIA . API_URL + "/v1/knowledge/index/create" , {
200
202
name,
201
203
embedding,
202
204
category,
@@ -214,7 +216,7 @@ export class Dria<T extends MetadataType = any> {
214
216
*/
215
217
async delete ( contractId : string ) {
216
218
// expect message to be `true`
217
- const data = await this . post < { message : boolean } > ( constants . DRIA_API_URL + "/v1/knowledge/remove" , {
219
+ const data = await this . post < { message : boolean } > ( constants . DRIA . API_URL + "/v1/knowledge/remove" , {
218
220
contract_id : contractId ,
219
221
} ) ;
220
222
return data . message ;
@@ -231,7 +233,7 @@ export class Dria<T extends MetadataType = any> {
231
233
if ( contractId in this . models ) {
232
234
return this . models [ contractId ] ;
233
235
} else {
234
- const data = await this . get < { model : string } > ( constants . DRIA_API_URL + "/v1/knowledge/index/get_model" , {
236
+ const data = await this . get < { model : string } > ( constants . DRIA . API_URL + "/v1/knowledge/index/get_model" , {
235
237
contract_id : contractId ,
236
238
} ) ;
237
239
// memoize the model for later
@@ -247,37 +249,4 @@ export class Dria<T extends MetadataType = any> {
247
249
if ( this . contractId ) return this . contractId ;
248
250
throw Error ( "ContractID was not set." ) ;
249
251
}
250
-
251
- /**
252
- * A POST request wrapper.
253
- * @param url request URL
254
- * @param body request body
255
- * @template T type of response body
256
- * @returns parsed response body
257
- */
258
- private async post < T = unknown > ( url : string , body : unknown ) {
259
- const res = await this . client . post < { success : boolean ; data : T ; code : number } > ( url , body ) ;
260
- if ( res . status !== 200 ) {
261
- console . log ( { url, body } ) ;
262
- // console.log(res);
263
- throw `Dria API (POST) failed with ${ res . statusText } (${ res . status } ).\n${ res . data } ` ;
264
- }
265
- return res . data . data ;
266
- }
267
-
268
- /**
269
- * A GET request wrapper.
270
- * @param url request URL
271
- * @param params query parameters
272
- * @template T type of response body
273
- * @returns parsed response body
274
- */
275
- private async get < T = unknown > ( url : string , params : Record < string , unknown > = { } ) {
276
- const res = await this . client . get < { success : boolean ; data : T ; code : number } > ( url , { params } ) ;
277
- if ( res . status !== 200 ) {
278
- console . log ( res . request ) ;
279
- throw `Dria API (GET) failed with ${ res . statusText } (${ res . status } ).\n${ res . data } ` ;
280
- }
281
- return res . data . data ;
282
- }
283
252
}
0 commit comments