Skip to content

Commit d0adfc0

Browse files
authored
Add local client to be used with dria-cli (#4)
* add local client to be used with dria-cli * depracate `query` * bump version
1 parent 32dbfca commit d0adfc0

File tree

9 files changed

+230
-69
lines changed

9 files changed

+230
-69
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ const dria = new Dria({ apiKey });
9797

9898
contractId = await dria.create(
9999
"My New Contract,
100-
"jinaai/jina-embeddings-v2-base-en",
100+
"jina-embeddings-v2-base-en",
101101
"Science",
102102
);
103103
dria.contractId = contractId;

package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "dria",
3-
"version": "0.0.3",
3+
"version": "0.0.4",
44
"license": "Apache-2.0",
55
"author": "FirstBatch Team <[email protected]>",
66
"contributors": [
@@ -23,6 +23,7 @@
2323
"lint": "eslint '**/*.ts' && echo 'All good.'",
2424
"test": "bun test --timeout 15000",
2525
"t": "bun run test",
26+
"test:local": "LOCAL_TEST=true bun test local --timeout 15000",
2627
"proto:code": "npx pbjs ./proto/insert.proto -w commonjs -t static-module -o ./proto/insert.js",
2728
"proto:type": "npx pbts ./proto/insert.js -o ./proto/insert.d.ts",
2829
"proto": "bun proto:code && bun proto:type"

src/clients/common.ts

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import { AxiosInstance } from "axios";
2+
3+
/**
4+
* A utility class that exposes `post` and `get` requests
5+
* for other clients to use. The constructor takes in an Axios instance.
6+
*/
7+
export class DriaCommon {
8+
constructor(protected readonly client: AxiosInstance) {}
9+
10+
/**
11+
* A POST request wrapper.
12+
* @param url request URL
13+
* @param body request body
14+
* @template T type of response body
15+
* @returns parsed response body
16+
*/
17+
protected async post<T = unknown>(url: string, body: unknown) {
18+
const res = await this.client.post<{ success: boolean; data: T; code: number }>(url, body);
19+
if (res.status !== 200) {
20+
throw `Dria API (POST) failed with ${res.statusText} (${res.status}).\n${res.data}`;
21+
}
22+
return res.data.data;
23+
}
24+
25+
/**
26+
* A GET request wrapper.
27+
* @param url request URL
28+
* @param params query parameters
29+
* @template T type of response body
30+
* @returns parsed response body
31+
*/
32+
protected async get<T = unknown>(url: string, params: Record<string, unknown> = {}) {
33+
const res = await this.client.get<{ success: boolean; data: T; code: number }>(url, { params });
34+
if (res.status !== 200) {
35+
throw `Dria API (GET) failed with ${res.statusText} (${res.status}).\n${res.data}`;
36+
}
37+
return res.data.data;
38+
}
39+
}

src/dria.ts renamed to src/clients/dria.ts

Lines changed: 29 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import Axios from "axios";
2-
import type { AxiosInstance } from "axios";
3-
import { encodeBatchTexts, encodeBatchVectors } from "./proto";
4-
import { SearchOptions, QueryOptions, BatchVectors, BatchTexts, MetadataType } from "./schemas";
5-
import { CategoryTypes, DriaParams, ModelTypes } from "./types";
6-
import constants from "./constants";
2+
import { encodeBatchTexts, encodeBatchVectors } from "../proto";
3+
import { SearchOptions, QueryOptions, BatchVectors, BatchTexts, MetadataType } from "../schemas";
4+
import { CategoryTypes, DriaParams, ModelTypes } from "../types";
5+
import constants from "../constants";
6+
import { DriaCommon } from "./common";
77

88
/**
9-
* Dria JS Client
9+
* ## Dria Client
1010
*
1111
* @param params optional API key and contract ID.
1212
*
@@ -40,8 +40,7 @@ import constants from "./constants";
4040
* dria.contractId = contractId;
4141
*/
4242
// eslint-disable-next-line @typescript-eslint/no-explicit-any
43-
export class Dria<T extends MetadataType = any> {
44-
protected client: AxiosInstance;
43+
export class Dria<T extends MetadataType = any> extends DriaCommon {
4544
contractId: string | undefined;
4645
/** Cached contract models. */
4746
private models: Record<string, ModelTypes> = {};
@@ -50,18 +49,21 @@ export class Dria<T extends MetadataType = any> {
5049
const apiKey = params.apiKey ?? process.env.DRIA_API_KEY;
5150
if (!apiKey) throw new Error("Missing Dria API key.");
5251

52+
super(
53+
Axios.create({
54+
headers: {
55+
"x-api-key": apiKey,
56+
"Content-Type": "application/json",
57+
"Accept-Encoding": "gzip, deflate, br",
58+
Connection: "keep-alive",
59+
Accept: "*/*",
60+
},
61+
// lets us handle the errors
62+
validateStatus: () => true,
63+
}),
64+
);
65+
5366
this.contractId = params.contractId;
54-
this.client = Axios.create({
55-
headers: {
56-
"x-api-key": apiKey,
57-
"Content-Type": "application/json",
58-
"Accept-Encoding": "gzip, deflate, br",
59-
Connection: "keep-alive",
60-
Accept: "*/*",
61-
},
62-
// lets us handle the errors
63-
validateStatus: () => true,
64-
});
6567
}
6668

6769
/** A text-based search.
@@ -79,7 +81,7 @@ export class Dria<T extends MetadataType = any> {
7981
async search(text: string, options: SearchOptions = {}) {
8082
options = SearchOptions.parse(options);
8183
const contractId = this.getContractId();
82-
return await this.post<{ id: number; metadata: string; score: number }[]>(constants.DRIA_SEARCH_URL + "/search", {
84+
return await this.post<{ id: number; metadata: string; score: number }[]>(constants.DRIA.SEARCH_URL + "/search", {
8385
query: text,
8486
top_n: options.topK,
8587
level: options.level,
@@ -103,7 +105,7 @@ export class Dria<T extends MetadataType = any> {
103105
async query<M extends MetadataType = T>(vector: number[], options: QueryOptions = {}) {
104106
options = QueryOptions.parse(options);
105107
const data = await this.post<{ id: number; metadata: string; score: number }[]>(
106-
constants.DRIA_SEARCH_URL + "/query",
108+
constants.DRIA.SEARCH_URL + "/query",
107109
{ vector, contract_id: this.getContractId(), top_n: options.topK },
108110
);
109111
return data.map((d) => ({ ...d, metadata: JSON.parse(d.metadata) as M }));
@@ -119,7 +121,7 @@ export class Dria<T extends MetadataType = any> {
119121
*/
120122
async fetch<M extends MetadataType = T>(ids: number[]) {
121123
if (ids.length === 0) throw "No IDs provided.";
122-
const data = await this.post<{ metadata: string[]; vectors: number[][] }>(constants.DRIA_SEARCH_URL + "/fetch", {
124+
const data = await this.post<{ metadata: string[]; vectors: number[][] }>(constants.DRIA.SEARCH_URL + "/fetch", {
123125
id: ids,
124126
contract_id: this.getContractId(),
125127
});
@@ -145,7 +147,7 @@ export class Dria<T extends MetadataType = any> {
145147
items = BatchVectors.parse(items) as BatchVectors<M>;
146148
const encodedData = encodeBatchVectors(items);
147149
const contractId = this.getContractId();
148-
const data = await this.post<string>(constants.DRIA_INSERT_URL + "/insert_vector", {
150+
const data = await this.post<string>(constants.DRIA.INSERT_URL + "/insert_vector", {
149151
data: encodedData,
150152
batch_size: items.length,
151153
model: await this.getModel(contractId),
@@ -170,7 +172,7 @@ export class Dria<T extends MetadataType = any> {
170172
items = BatchTexts.parse(items) as BatchTexts<M>;
171173
const encodedData = encodeBatchTexts(items);
172174
const contractId = this.getContractId();
173-
const data = await this.post<string>(constants.DRIA_INSERT_URL + "/insert_text", {
175+
const data = await this.post<string>(constants.DRIA.INSERT_URL + "/insert_text", {
174176
data: encodedData,
175177
batch_size: items.length,
176178
model: await this.getModel(contractId),
@@ -196,7 +198,7 @@ export class Dria<T extends MetadataType = any> {
196198
* // you can now make queries, or insert data there
197199
*/
198200
async create(name: string, embedding: ModelTypes, category: CategoryTypes, description: string = "") {
199-
const data = await this.post<{ contract_id: string }>(constants.DRIA_API_URL + "/v1/knowledge/index/create", {
201+
const data = await this.post<{ contract_id: string }>(constants.DRIA.API_URL + "/v1/knowledge/index/create", {
200202
name,
201203
embedding,
202204
category,
@@ -214,7 +216,7 @@ export class Dria<T extends MetadataType = any> {
214216
*/
215217
async delete(contractId: string) {
216218
// expect message to be `true`
217-
const data = await this.post<{ message: boolean }>(constants.DRIA_API_URL + "/v1/knowledge/remove", {
219+
const data = await this.post<{ message: boolean }>(constants.DRIA.API_URL + "/v1/knowledge/remove", {
218220
contract_id: contractId,
219221
});
220222
return data.message;
@@ -231,7 +233,7 @@ export class Dria<T extends MetadataType = any> {
231233
if (contractId in this.models) {
232234
return this.models[contractId];
233235
} else {
234-
const data = await this.get<{ model: string }>(constants.DRIA_API_URL + "/v1/knowledge/index/get_model", {
236+
const data = await this.get<{ model: string }>(constants.DRIA.API_URL + "/v1/knowledge/index/get_model", {
235237
contract_id: contractId,
236238
});
237239
// memoize the model for later
@@ -247,37 +249,4 @@ export class Dria<T extends MetadataType = any> {
247249
if (this.contractId) return this.contractId;
248250
throw Error("ContractID was not set.");
249251
}
250-
251-
/**
252-
* A POST request wrapper.
253-
* @param url request URL
254-
* @param body request body
255-
* @template T type of response body
256-
* @returns parsed response body
257-
*/
258-
private async post<T = unknown>(url: string, body: unknown) {
259-
const res = await this.client.post<{ success: boolean; data: T; code: number }>(url, body);
260-
if (res.status !== 200) {
261-
console.log({ url, body });
262-
// console.log(res);
263-
throw `Dria API (POST) failed with ${res.statusText} (${res.status}).\n${res.data}`;
264-
}
265-
return res.data.data;
266-
}
267-
268-
/**
269-
* A GET request wrapper.
270-
* @param url request URL
271-
* @param params query parameters
272-
* @template T type of response body
273-
* @returns parsed response body
274-
*/
275-
private async get<T = unknown>(url: string, params: Record<string, unknown> = {}) {
276-
const res = await this.client.get<{ success: boolean; data: T; code: number }>(url, { params });
277-
if (res.status !== 200) {
278-
console.log(res.request);
279-
throw `Dria API (GET) failed with ${res.statusText} (${res.status}).\n${res.data}`;
280-
}
281-
return res.data.data;
282-
}
283252
}

src/clients/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
export * from "./dria";
2+
export * from "./local";

src/clients/local.ts

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import Axios from "axios";
2+
import { QueryOptions, BatchVectors, MetadataType } from "../schemas";
3+
import { DriaCommon } from "./common";
4+
5+
/**
6+
* ## Dria Local Client
7+
*
8+
* Dria local client is a convenience tool that allows one to use the served knowledge via [Dria Docker](https://github.com/firstbatchxyz/dria-docker).
9+
* The URL defaults to `http://localhost:8080`, but you can override it.
10+
*
11+
* Unlike the other Dria client, Dria local does not require an API key or a contract ID, since the locally served knowledge serves a single contract.
12+
* Furthermore, text-based input is not allowed as that requires an embedding model to be running on the side.
13+
*
14+
* @template T default type of metadata; a metadata in Dria is a single-level mapping, with string keys and values of type `string`, `number`
15+
*
16+
* @example
17+
* // connects to http://localhost:8080
18+
* const dria = new DriaLocal();
19+
*
20+
* @example
21+
* const dria = new DriaLocal("your-url");
22+
*/
23+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
24+
export class DriaLocal<T extends MetadataType = any> extends DriaCommon {
25+
public url: string;
26+
constructor(url: string = "http://localhost:8080") {
27+
super(
28+
Axios.create({
29+
baseURL: url,
30+
headers: {
31+
"Content-Type": "application/json",
32+
Connection: "keep-alive",
33+
Accept: "*/*",
34+
},
35+
// lets us handle the errors
36+
validateStatus: () => true,
37+
}),
38+
);
39+
this.url = url;
40+
}
41+
42+
/** A simple health-check. */
43+
async health() {
44+
try {
45+
await this.get("/health");
46+
return true;
47+
} catch {
48+
return false;
49+
}
50+
}
51+
52+
/** A vector-based query.
53+
* @param vector query vector.
54+
* @param options query options:
55+
* - `topK`: number of results to return.
56+
* @template M type of the metadata, defaults to type provided to the client.
57+
* @returns an array of `topK` results with id, metadata and the relevancy score.
58+
* @example
59+
* const res = await dria.query<{about: string}>([0.1, 0.92, ..., 0.16]);
60+
* console.log(res[0].metadata.about);
61+
*
62+
* @deprecated local query is disabled right now
63+
*/
64+
private async query<M extends MetadataType = T>(vector: number[], options: QueryOptions = {}) {
65+
options = QueryOptions.parse(options);
66+
const data = await this.post<{ id: number; metadata: M; score: number }[]>("/query", {
67+
vector,
68+
top_n: options.topK,
69+
});
70+
return data;
71+
}
72+
73+
/** Fetch vectors with the given IDs.
74+
* @param ids an array of ids.
75+
* @template M type of the metadata, defaults to type provided to the client.
76+
* @returns an array of metadatas belonging to the given vector IDs.
77+
* @example
78+
* const res = await dria.fetch([0])
79+
* console.log(res[0])
80+
*/
81+
async fetch<M extends MetadataType = T>(ids: number[]) {
82+
if (ids.length === 0) throw "No IDs provided.";
83+
const data = await this.post<M[]>("/fetch", {
84+
id: ids,
85+
});
86+
return data;
87+
}
88+
89+
/**
90+
* Insert a batch of vectors to your existing knowledge.
91+
* @param items batch of vectors with optional metadatas
92+
* @returns a string indicating success
93+
* @example
94+
* const batch = [
95+
* {vector: [...], metadata: {}},
96+
* {vector: [...], metadata: {foo: 'bar'}},
97+
* // ...
98+
* ]
99+
* await dria.insertVectors(batch);
100+
*/
101+
async insertVectors<M extends MetadataType = T>(items: BatchVectors<M>) {
102+
items = BatchVectors.parse(items) as BatchVectors<M>;
103+
const data = await this.post<string>("/insert_vector", {
104+
data: items,
105+
});
106+
return data;
107+
}
108+
}

src/constants/index.ts

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
export default {
2-
/** URL to make fetch / query / search requests */
3-
DRIA_SEARCH_URL: "https://search.dria.co/hnsw",
4-
/** URL to insert texts and vectors */
5-
DRIA_INSERT_URL: "https://search.dria.co/hnswt",
6-
/** URL to get model */
7-
DRIA_API_URL: "https://api.dria.co",
2+
// TODO: naming doesnt really make sense here...
3+
DRIA: {
4+
/** URL to make fetch / query / search requests */
5+
SEARCH_URL: "https://search.dria.co/hnsw",
6+
/** URL to insert texts and vectors */
7+
INSERT_URL: "https://search.dria.co/hnswt",
8+
/** URL to get model */
9+
API_URL: "https://api.dria.co",
10+
},
811
} as const;

src/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
export { Dria } from "./dria";
1+
export { Dria, DriaLocal } from "./clients";
22
export type { DriaParams } from "./types";

0 commit comments

Comments
 (0)