|
| 1 | +import { useState } from 'react'; |
| 2 | +import DeckGL from '@deck.gl/react'; |
| 3 | +import { ScatterplotLayer } from '@deck.gl/layers'; |
| 4 | +import { Map } from 'react-map-gl/maplibre'; |
| 5 | +import { SAMPLE_DATASETS } from './dataset'; |
| 6 | + |
| 7 | +import { AiAssistant } from '@openassistant/ui'; |
| 8 | +import { histogramFunctionDefinition } from '@openassistant/echarts'; |
| 9 | +import { |
| 10 | + CallbackFunctionProps, |
| 11 | + CustomFunctionContext, |
| 12 | +} from '@openassistant/core'; |
| 13 | +import { queryDuckDBFunctionDefinition } from '@openassistant/duckdb'; |
| 14 | + |
| 15 | +type PointData = { |
| 16 | + index: number; |
| 17 | + longitude: number; |
| 18 | + latitude: number; |
| 19 | + revenue: number; |
| 20 | + population: number; |
| 21 | +}; |
| 22 | + |
| 23 | +export function App() { |
| 24 | + // Add state for filtered indices |
| 25 | + const [filteredIndices, setFilteredIndices] = useState<number[]>([]); |
| 26 | + |
| 27 | + // Add state for radius multiplier |
| 28 | + const [radiusMultiplier, setRadiusMultiplier] = useState<number>(1); |
| 29 | + |
| 30 | + // Initial viewport state |
| 31 | + const initialViewState = { |
| 32 | + longitude: -98.5795, // Center of continental US |
| 33 | + latitude: 39.8283, // Center of continental US |
| 34 | + zoom: 3, // Zoomed out to show entire country |
| 35 | + pitch: 0, |
| 36 | + bearing: 0, |
| 37 | + }; |
| 38 | + |
| 39 | + // Sample data point |
| 40 | + const data = SAMPLE_DATASETS.myVenues; |
| 41 | + |
| 42 | + // Add LLM instructions |
| 43 | + const instructions = `You are a data analyst. You can help users to analyze data including: |
| 44 | + - changing the radius of the points |
| 45 | + - filtering the points by state |
| 46 | + - querying the data using selected variables |
| 47 | + - create a histogram of the selected variable |
| 48 | +
|
| 49 | +When responding to user queries: |
| 50 | +1. Analyze if the task requires one or multiple function calls |
| 51 | +2. For each required function: |
| 52 | + - Identify the appropriate function to call |
| 53 | + - Determine all required parameters |
| 54 | + - If parameters are missing, ask the user to provide them |
| 55 | + - Please ask the user to confirm the parameters |
| 56 | + - If the user doesn't agree, try to provide variable functions to the user |
| 57 | + - Execute functions in a sequential order |
| 58 | +3. For SQL query, please help to generate select query clause using the content of the dataset: |
| 59 | + - please use double quotes for table name |
| 60 | + - please only use the columns that are in the dataset context |
| 61 | + - please try to use the aggregate functions if possible |
| 62 | +
|
| 63 | +Please use the following data context to answer the user's question: |
| 64 | +Dataset Name: myVenues |
| 65 | +Fields: |
| 66 | +- index |
| 67 | +- location |
| 68 | +- longitude |
| 69 | +- latitude |
| 70 | +- revenue |
| 71 | +- population |
| 72 | + `; |
| 73 | + |
| 74 | + // a llm tool to change the radius of the points |
| 75 | + function radiusFunctionDefinition(context: CustomFunctionContext<any>) { |
| 76 | + return { |
| 77 | + name: 'radius', |
| 78 | + description: 'Make the radius of the points larger or smaller', |
| 79 | + properties: { |
| 80 | + radiusMultiplier: { |
| 81 | + type: 'number', |
| 82 | + description: 'The multiplier for the radius of the points', |
| 83 | + }, |
| 84 | + }, |
| 85 | + required: ['radiusMultiplier'], |
| 86 | + callbackFunction: async (props: CallbackFunctionProps) => { |
| 87 | + const { functionName, functionArgs, functionContext } = props; |
| 88 | + const { radiusMultiplier } = functionArgs; |
| 89 | + |
| 90 | + const { changeRadius } = functionContext as { |
| 91 | + changeRadius: (radiusMultiplier: number) => void; |
| 92 | + }; |
| 93 | + changeRadius(Number(radiusMultiplier)); |
| 94 | + |
| 95 | + return { |
| 96 | + type: 'success', |
| 97 | + name: functionName, |
| 98 | + result: `Radius multiplier set to ${radiusMultiplier}`, |
| 99 | + }; |
| 100 | + }, |
| 101 | + callbackFunctionContext: context, |
| 102 | + }; |
| 103 | + } |
| 104 | + |
| 105 | + function highlightPoints(indices: number[]) { |
| 106 | + // highlight the points |
| 107 | + setFilteredIndices(indices); |
| 108 | + } |
| 109 | + |
| 110 | + const filterByStateCallbackFunctionContext = { |
| 111 | + points: SAMPLE_DATASETS.myVenues, |
| 112 | + }; |
| 113 | + |
| 114 | + function filterByStateCallback(props) { |
| 115 | + const { functionArgs, functionContext } = props; |
| 116 | + const { state, boundingBox } = functionArgs; |
| 117 | + const { points } = functionContext; |
| 118 | + // get the index of the points that fits inside the bounding box |
| 119 | + const filteredIndices = points |
| 120 | + .filter((point) => { |
| 121 | + const isInside = |
| 122 | + point.longitude >= boundingBox[0] && |
| 123 | + point.longitude <= boundingBox[2] && |
| 124 | + point.latitude >= boundingBox[1] && |
| 125 | + point.latitude <= boundingBox[3]; |
| 126 | + return isInside; |
| 127 | + }) |
| 128 | + .map((point) => point.index); |
| 129 | + |
| 130 | + // highlight the filtered points |
| 131 | + highlightPoints(filteredIndices); |
| 132 | + |
| 133 | + return { |
| 134 | + type: 'success', |
| 135 | + result: `${filteredIndices.length} points are filtered by state ${state} and bounding box ${boundingBox}`, |
| 136 | + }; |
| 137 | + } |
| 138 | + |
| 139 | + function filterByStateFunctionDefinition( |
| 140 | + callbackFunction, |
| 141 | + callbackFunctionContext |
| 142 | + ) { |
| 143 | + return { |
| 144 | + name: 'filterByState', |
| 145 | + description: 'Filter points by state', |
| 146 | + properties: { |
| 147 | + state: { |
| 148 | + type: 'string', |
| 149 | + description: 'The state to filter by', |
| 150 | + }, |
| 151 | + boundingBox: { |
| 152 | + type: 'array', |
| 153 | + description: |
| 154 | + 'The bounding box coordinates of the state. The format is [minLongitude, minLatitude, maxLongitude, maxLatitude]. If not provided, please try to use approximate bounding box of the state.00', |
| 155 | + items: { |
| 156 | + type: 'number', |
| 157 | + }, |
| 158 | + }, |
| 159 | + }, |
| 160 | + required: ['state'], |
| 161 | + callbackFunction, |
| 162 | + callbackFunctionContext, |
| 163 | + }; |
| 164 | + } |
| 165 | + |
| 166 | + // Define LLM tools |
| 167 | + const functionTools = [ |
| 168 | + histogramFunctionDefinition({ |
| 169 | + getValues: (datasetName: string, variableName: string) => { |
| 170 | + const dataset = SAMPLE_DATASETS[datasetName]; |
| 171 | + return dataset.map((item) => item[variableName]); |
| 172 | + }, |
| 173 | + onSelected: (datasetName: string, selectedIndices: number[]) => { |
| 174 | + console.log(datasetName, selectedIndices); |
| 175 | + setFilteredIndices([...selectedIndices]); |
| 176 | + }, |
| 177 | + config: { isDraggable: true, theme: 'light' }, |
| 178 | + }), |
| 179 | + radiusFunctionDefinition({ |
| 180 | + changeRadius: (radiusMultiplier: number) => { |
| 181 | + console.log('changeRadius', radiusMultiplier); |
| 182 | + setRadiusMultiplier(radiusMultiplier); |
| 183 | + }, |
| 184 | + }), |
| 185 | + filterByStateFunctionDefinition( |
| 186 | + filterByStateCallback, |
| 187 | + filterByStateCallbackFunctionContext |
| 188 | + ), |
| 189 | + queryDuckDBFunctionDefinition({ |
| 190 | + getValues: (datasetName, variableName) => { |
| 191 | + const dataset = SAMPLE_DATASETS[datasetName]; |
| 192 | + return dataset.map((row) => row[variableName]); |
| 193 | + }, |
| 194 | + config: { isDraggable: true }, |
| 195 | + }), |
| 196 | + ]; |
| 197 | + |
| 198 | + // Create a scatterplot layer with key prop for forcing updates |
| 199 | + const layers = [ |
| 200 | + new ScatterplotLayer<PointData>({ |
| 201 | + id: 'scatter-plot', |
| 202 | + data, |
| 203 | + pickable: true, |
| 204 | + opacity: 0.8, |
| 205 | + stroked: true, |
| 206 | + filled: true, |
| 207 | + radiusScale: 1, |
| 208 | + radiusMinPixels: 1, |
| 209 | + radiusMaxPixels: 100, |
| 210 | + lineWidthMinPixels: 1, |
| 211 | + getPosition: (d: PointData) => [d.longitude, d.latitude], |
| 212 | + getRadius: (d: PointData) => (d.revenue / 200) * radiusMultiplier, |
| 213 | + getFillColor: (d: PointData) => { |
| 214 | + return filteredIndices.includes(d.index) ? [255, 0, 0] : [0, 0, 255]; |
| 215 | + }, |
| 216 | + getLineColor: [0, 0, 0], |
| 217 | + updateTriggers: { |
| 218 | + getFillColor: [filteredIndices], |
| 219 | + getRadius: [radiusMultiplier], |
| 220 | + }, |
| 221 | + }), |
| 222 | + ]; |
| 223 | + |
| 224 | + const mapStyle = |
| 225 | + 'https://basemaps.cartocdn.com/gl/positron-nolabels-gl-style/style.json'; |
| 226 | + |
| 227 | + return ( |
| 228 | + <div className="flex flex-row w-screen h-screen"> |
| 229 | + <div className="w-[550px] h-[800px] m-4"> |
| 230 | + <AiAssistant |
| 231 | + name="My Assistant" |
| 232 | + apiKey="your-api-key" |
| 233 | + version="v1" |
| 234 | + modelProvider="openai" |
| 235 | + model="gpt-4o" |
| 236 | + welcomeMessage="Hello, how can I help you today?" |
| 237 | + instructions={instructions} |
| 238 | + functions={functionTools} |
| 239 | + /> |
| 240 | + </div> |
| 241 | + <div className="deckgl h-full w-full"> |
| 242 | + <DeckGL |
| 243 | + initialViewState={initialViewState} |
| 244 | + controller={true} |
| 245 | + layers={layers} |
| 246 | + style={{ position: 'relative' }} |
| 247 | + > |
| 248 | + <Map reuseMaps mapStyle={mapStyle} /> |
| 249 | + </DeckGL> |
| 250 | + </div> |
| 251 | + </div> |
| 252 | + ); |
| 253 | +} |
0 commit comments