Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DSERV-715-llm #361

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions catalog_llm_query/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from flask import Flask, request, jsonify
import json
from arango import ArangoClient
from langchain_community.graphs import ArangoGraph
from langchain.chains import ArangoGraphQAChain
from langchain_openai import ChatOpenAI
import os
from catalog_llm_query.aql_examples import AQL_EXAMPLES
from select_collections import select_collections
from langchain_community.callbacks import get_openai_callback


# Initialize Flask app
app = Flask(__name__)


def get_config():
    """Load and return the development configuration as a dict.

    Reads config/development.json relative to the current working
    directory, so the service must be started from the project root.
    """
    # Explicit encoding avoids depending on the platform default locale.
    with open('config/development.json', encoding='utf-8') as f:
        return json.load(f)


def initialize_arango_graph(config):
    """Connect to ArangoDB and wrap the database in a LangChain ArangoGraph."""
    credentials = config['auth']

    # Open a connection to the configured ArangoDB deployment and
    # authenticate against the target database.
    client = ArangoClient(hosts=config['connectionUri'])
    db = client.db(
        config['dbName'],
        username=credentials['username'],
        password=credentials['password'],
    )

    # ArangoGraph introspects the database schema on construction.
    return ArangoGraph(db)


def initialize_collection_names(collection_schema):
    """Return the name of every collection in the schema, preserving order."""
    names = []
    for entry in collection_schema:
        names.append(entry['collection_name'])
    return names


def initialize_llm(config):
    """Create the ChatOpenAI model named in the config.

    Also exports the API key through the OPENAI_API_KEY environment
    variable, which is where the OpenAI client libraries look for it.
    """
    os.environ['OPENAI_API_KEY'] = config['openai_api_key']
    # temperature=0 keeps AQL generation as deterministic as possible.
    return ChatOpenAI(temperature=0, model_name=config['openai_model'])


def ask_llm(question):
    """Answer a natural-language question by generating and running AQL.

    Narrows the graph schema to the collections relevant to the
    question, then lets ArangoGraphQAChain translate the question into
    AQL, execute it, and summarise the result.
    """
    relevant_names = select_collections(question, collection_names)
    scoped_graph = get_updated_graph(graph, collection_schema, relevant_names)

    chain = ArangoGraphQAChain.from_llm(
        model, graph=scoped_graph, verbose=True, allow_dangerous_requests=True)

    # Cap AQL query results at 5 rows so JSON output does not burn the
    # LLM token budget, and give up after 5 failed generation attempts
    # instead of retrying forever.
    chain.top_k = 5
    chain.max_aql_generation_attempts = 5

    # Include both the generated AQL and its raw JSON result in the
    # output dictionary (use `chain("...")` instead of
    # `chain.invoke("...")` to see this change).
    chain.return_aql_query = True
    chain.return_aql_result = True

    # Few-shot examples instruct the LLM to adapt its AQL-completion
    # style; they are passed into the AQL generation prompt template.
    chain.aql_examples = AQL_EXAMPLES

    # Track and log OpenAI token usage for this call.
    with get_openai_callback() as cb:
        response = chain.invoke(question)
        print(cb)
    return response


# Module-level initialisation: load config, connect to ArangoDB, and
# prepare the LLM once at import time so every request can reuse them.
config = get_config()
graph = initialize_arango_graph(config['database'])
# Snapshot the full collection schema and names before any query
# narrows the graph's schema (get_updated_graph assigns into
# graph.schema in place).
collection_schema = graph.schema['Collection Schema']
collection_names = initialize_collection_names(collection_schema)
model = initialize_llm(config)


def get_updated_graph(graph, collection_schema, selected_collection_names):
    """Narrow the graph's collection schema to the selected collections.

    Keeps, in selection order, the first schema entry whose
    'collection_name' matches each selected name; selected names with
    no matching entry are silently dropped.

    NOTE: this assigns into ``graph.schema`` in place and returns the
    same graph object, so callers must re-derive the selection from the
    original ``collection_schema`` on every call (as ask_llm does).
    """
    # Index the schema once so each selected name is an O(1) lookup
    # instead of a full scan. setdefault keeps the FIRST entry for a
    # duplicated name, matching the original first-match-then-break scan.
    by_name = {}
    for collection in collection_schema:
        by_name.setdefault(collection['collection_name'], collection)

    graph.schema['Collection Schema'] = [
        by_name[name] for name in selected_collection_names if name in by_name
    ]
    return graph


def build_response(block):
    """Return a copy of `block` with the catalog title attached."""
    response = dict(block)
    response['title'] = 'IGVF Catalog LLM Query'
    return response
# Create Flask endpoint


@app.route('/query', methods=['GET'])
def query():
    """GET /query?query=... — answer a natural-language catalog query."""
    user_query = request.args.get('query')

    # Reject requests missing the mandatory parameter up front.
    if not user_query:
        return jsonify({'error': 'Query parameter is required'}), 400

    try:
        answer = ask_llm(user_query)
        return jsonify(build_response(answer))
    except Exception as e:
        # Surface the failure to the caller alongside the query text.
        return jsonify({'query': user_query, 'error': str(e)}), 500


# Run the Flask app
# NOTE(review): debug=True enables the Werkzeug debugger/reloader and
# must not be used when this service is exposed beyond local development.
if __name__ == '__main__':
    app.run(debug=True)
33 changes: 33 additions & 0 deletions catalog_llm_query/aql_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Few-shot examples handed to ArangoGraphQAChain.aql_examples: each pair
# is a natural-language question (as an AQL `#` comment) followed by the
# AQL that answers it. This text is injected into the LLM prompt, so any
# edit here changes model behaviour.
AQL_EXAMPLES = """
# show me all the vairants that is in chromosome 1, position at 10000000?
WITH variants
FOR v IN variants
FILTER v.chr == "chr1" AND v.pos == 10000000
RETURN v
# Can you tell me the variant with SPDI of NC_000012.12:102855312:C:T is associated with what diseases?
WITH variants, variants_diseases, ontology_terms
FOR variant IN variants
FILTER variant.spdi == 'NC_000012.12:102855312:C:T'
FOR disease IN OUTBOUND variant variants_diseases
RETURN disease
# Show me all variants associated with cardiomyopathy
FOR v in variants
FILTER v._id IN (
FOR d IN variants_diseases
FILTER d._to IN (
FOR o in ontology_terms
FILTER o.name == 'cardiomyopathy'
RETURN o._id
)
RETURN d._from)
RETURN v
# What are the transcripts from the protein PARI_HUMAN?
FOR p IN proteins
FILTER p.name == 'PARI_HUMAN'
FOR t IN transcripts_proteins
FILTER t._to == p._id
RETURN t
"""
6 changes: 6 additions & 0 deletions catalog_llm_query/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Flask==2.2.3
python-arango==7.5.7
langchain==0.3.16
langchain-community==0.3.16
langchain-openai==0.3.2
openai==1.60.2
54 changes: 54 additions & 0 deletions catalog_llm_query/select_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import ast
import json

import openai


def create_prompt(input_text, categories):
    """Build the few-shot categorisation prompt for the LLM.

    Asks the model to map `input_text` onto one or more of `categories`
    (ArangoDB collection names) and to reply as JSON of the form
    {"category_names": [...]}, which select_collections then parses.
    """
    category_list = '\n'.join([f'- {cat}' for cat in categories])
    # The example answers use the exact JSON shape requested in the
    # instructions (the original examples were bare lists, and one had a
    # mismatched quote), so the few-shot examples and the instructions
    # no longer contradict each other.
    return f"""
Please categorize the following input into one or more of the predefined categories. Only return the category names and return the answer in json format. for example: {{"category_names": ["category1", "category2"]}}.
Input: {input_text}
Categories:
{category_list}
Here are examples you can learn from:
###
input: what diseases are associated with gene PAH?
answer: {{"category_names": ["genes", "diseases_genes"]}}
###
input: Tell me about gene PAH?
answer: {{"category_names": ["genes"]}}
###
input: The variant with SPDI of NC_000001.11:981168:A:G affects the expression of several genes. What are the genes that are affected?
answer: {{"category_names": ["genes", "variants_genes", "variants"]}}
###
input: Can you tell me the variant with SPDI of NC_000012.12:102855312:C:T is associated with what diseases?
answer: {{"category_names": ["variants", "variants_diseases", "ontology_terms"]}}
###
input: What does NEK5 interact with?
answer: {{"category_names": ["proteins", "proteins_proteins"]}}
###
"""


def select_collections(query, collection_names):
    """Ask the LLM which collections are relevant to `query`.

    Returns the list of collection names from the model's JSON answer,
    or the raw model output string when the answer cannot be parsed
    (callers treat that as a best-effort fallback).

    NOTE(review): relies on OPENAI_API_KEY being set in the environment
    (done by app.initialize_llm) — confirm this module is never used
    before that initialisation runs.
    """
    RESPONSE_FORMAT = {'type': 'json_object'}

    content = create_prompt(query, collection_names)
    # response_format forces the model to emit a single valid JSON object.
    response = openai.chat.completions.create(
        response_format=RESPONSE_FORMAT,
        model='gpt-4o',  # TODO(review): read the model name from config
        temperature=0,
        messages=[
            {'role': 'user', 'content': content},
        ]
    )
    output = response.choices[0].message.content
    try:
        # json.loads rather than ast.literal_eval: the API returns JSON,
        # which may contain true/false/null that literal_eval rejects.
        return json.loads(output)['category_names']
    except (ValueError, KeyError, TypeError):
        # Malformed or unexpectedly shaped answer: hand back the raw
        # output instead of failing (preserves the original fallback).
        return output
5 changes: 4 additions & 1 deletion config/development.json
Original file line number Diff line number Diff line change
@@ -22,5 +22,8 @@
"password": ""
}
},
"cluster": ["0.0.0.0"]
"cluster": ["0.0.0.0"],
"openai_api_key": "XXXXXXXX",
"openai_model": "gpt-4o",
"catalog_llm_query": "http://127.0.0.1:5000/query?"
}
5 changes: 4 additions & 1 deletion src/__tests__/env.test.ts
Original file line number Diff line number Diff line change
@@ -16,7 +16,10 @@ describe('System configuration', () => {
username: 'user',
password: 'psswd'
}
}
},
openai_api_key: 'XXXXXXXX',
openai_model: 'gpt-model',
catalog_llm_query: 'http://127.0.0.1:5000/query?'
}

beforeEach(() => {
2 changes: 2 additions & 0 deletions src/database.ts
Original file line number Diff line number Diff line change
@@ -11,3 +11,5 @@ export const db = new Database({
password: dbConfig.auth.password
}
})

export const llmQueryUrl = envData.catalog_llm_query
5 changes: 4 additions & 1 deletion src/env.ts
Original file line number Diff line number Diff line change
@@ -17,7 +17,10 @@ const envSchema = z.object({
username: z.string(),
password: z.string()
})
})
}),
openai_api_key: z.string(),
openai_model: z.string(),
catalog_llm_query: z.string()
})

let config = envConfig
5 changes: 4 additions & 1 deletion src/env/development.json
Original file line number Diff line number Diff line change
@@ -12,5 +12,8 @@
"username": "username",
"password": "password"
}
}
},
"openai_api_key": "XXXXXXXX",
"openai_model": "gpt-4o",
"catalog_llm_query": "http://127.0.0.1:5000/query?"
}
5 changes: 4 additions & 1 deletion src/routers/datatypeRouters/nodes/_all.ts
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@ import { drugsRouters } from './drugs'
import { codingVariantsRouters } from './coding_variants'
import { genesStructureRouters } from './genes_structure'
import { pathwaysRouters } from './pathways'
import { llmQueryRouters } from './llm_query'

export const nodeRouters = {
...ontologyRouters,
@@ -26,5 +27,7 @@ export const nodeRouters = {
...studiesRouters,
...codingVariantsRouters,
...genesStructureRouters,
...pathwaysRouters
...pathwaysRouters,
...llmQueryRouters

}
48 changes: 48 additions & 0 deletions src/routers/datatypeRouters/nodes/llm_query.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import { z } from 'zod'
import { llmQueryUrl } from '../../../database'
import { publicProcedure } from '../../../trpc'
import { paramsFormatType } from '../_helpers'
import { descriptions } from '../descriptions'
import { TRPCError } from '@trpc/server'

// Shape of the incoming request: a single free-text query string.
const queryFormat = z.object({
  query: z.string()
})

// Shape of the response returned to the client: the original query plus
// the LLM-generated result text from the Flask service.
const outputFormat = z.object({
  query: z.string(),
  result: z.string()

})

// Forward the user's question to the Flask LLM service and unwrap its
// answer. Throws a BAD_REQUEST TRPCError when the service reports failure.
async function query (input: paramsFormatType): Promise<any> {
  const encodedQuery = encodeURIComponent(input.query as string)
  const response = await fetch(`${llmQueryUrl}query=${encodedQuery}`, {
    method: 'GET',
    headers: {
      'Content-Type': 'application/json'
    }
  })

  if (!response.ok) {
    throw new TRPCError({
      code: 'BAD_REQUEST',
      message: 'The query could not be executed.'
    })
  }

  const payload = await response.json()
  return {
    query: input.query,
    result: payload.result
  }
}

// Public tRPC procedure exposing the LLM query endpoint at /llm-query.
// NOTE(review): reuses descriptions.genes for the OpenAPI description —
// confirm whether a dedicated llm-query entry should be added to
// `descriptions` instead.
const llmQuery = publicProcedure
  .meta({ openapi: { method: 'GET', path: '/llm-query', description: descriptions.genes } })
  .input(queryFormat)
  .output(outputFormat)
  .query(async ({ input }) => await query(input))

export const llmQueryRouters = {
  llmQuery
}