chatbot #4523

Open · wants to merge 19 commits into base: staging

18 changes: 18 additions & 0 deletions src/spatial/README.md
@@ -133,5 +133,23 @@ Ensure latitude and longitude have high precision (up to six decimal places) for

---

### **Chatbot Tool**

The Chatbot Tool provides a conversational interface for users to interact with the AirQo platform and receive information about air quality and related topics.

```http
POST http://127.0.0.1:5000/api/v2/spatial/chat
```

##### Example Request Body
```json
{
"grid_id": "659d036497e611001236cd1b",
"start_time": "2024-12-01T00:00",
"end_time": "2025-03-06T00:00",
"prompt": "With 300 words, write the air quality report"
}
```

This README provides an overview of the setup, API endpoints, and example requests. For further details, consult the official AirQo API documentation.
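For quick local testing, the endpoint can be exercised with a short client script. This is a sketch only: it assumes the Flask app from this PR is running locally on port 5000, and `ask_chatbot` is a hypothetical helper name, not part of the codebase.

```python
import json
from urllib import request

CHAT_URL = "http://127.0.0.1:5000/api/v2/spatial/chat"

def ask_chatbot(payload: dict, url: str = CHAT_URL, timeout: int = 60) -> dict:
    """POST a chat payload to the spatial chatbot endpoint and return the parsed JSON."""
    req = request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))

payload = {
    "grid_id": "659d036497e611001236cd1b",
    "start_time": "2024-12-01T00:00",
    "end_time": "2025-03-06T00:00",
    "prompt": "With 300 words, write the air quality report",
}
# answer = ask_chatbot(payload)  # uncomment with a running local server
```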

8 changes: 2 additions & 6 deletions src/spatial/app.py
@@ -10,12 +10,8 @@
# Add custom CORS header
@app.after_request
def add_cors_headers(response):
response.headers[
"Access-Control-Allow-Origin"
] = "*" # You can specify specific origins instead of '*'
response.headers[
"Access-Control-Allow-Headers"
] = "Content-Type, Authorization, X-Requested-With, X-Auth-Token"
response.headers["Access-Control-Allow-Origin"] = "*" # You can specify specific origins instead of '*'
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Auth-Token"
response.headers["Access-Control-Allow-Methods"] = "GET,PUT,POST,DELETE,OPTIONS"
response.headers["Access-Control-Allow-Credentials"] = "true"
return response
5 changes: 5 additions & 0 deletions src/spatial/configure.py
@@ -26,6 +26,11 @@ class Config:
)
ANALTICS_URL = os.getenv("ANALTICS_URL")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
# Redis configuration
REDIS_HOST = os.getenv("REDIS_HOST", "localhost") # Default to 'localhost' if not set
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379)) # Convert to int, default to 6379
REDIS_DB = int(os.getenv("REDIS_DB", 0)) # Convert to int, default to 0
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None) # Default to None if not set
class ProductionConfig(Config):
DEBUG = False
TESTING = False
10 changes: 8 additions & 2 deletions src/spatial/controllers/controllers.py
@@ -1,5 +1,6 @@
# controller/controller.py
from flask import Blueprint, request, jsonify
from views.chatbot_views import ChatbotView
from views.getis_services import SpatialDataHandler
from views.getis_confidence_services import SpatialDataHandler_confidence
from views.localmoran_services import SpatialDataHandler_moran
@@ -12,7 +13,7 @@
from views.satellite_predictions import SatellitePredictionView
from views.site_category_view import SiteCategorizationView
from views.site_selection_views import SiteSelectionView
from views.report_view import ReportView
from views.report_view import ReportView


controller_bp = Blueprint("controller", __name__)
@@ -78,4 +79,9 @@ def fetch_air_quality_without_llm():

@controller_bp.route("/air_quality_report_with_customised_prompt", methods=["POST"])
def fetch_air_quality_with_customised_prompt():
return ReportView.generate_air_quality_report_with_customised_prompt_gemini()
return ReportView.generate_air_quality_report_with_customised_prompt_gemini()

@controller_bp.route("/chatbot", methods=["POST"])
def Chatbot_Views():
return ChatbotView.chat_endpoint()

194 changes: 194 additions & 0 deletions src/spatial/models/chatbot_model.py
@@ -0,0 +1,194 @@
import requests
import redis
import json
import google.generativeai as genai
import logging
import re
import threading
from urllib.parse import urlencode
from configure import Config

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configure API keys
GOOGLE_API_KEY = Config.GOOGLE_API_KEY
🛠️ Refactor suggestion

Remove or relocate the Flask app initialization.

The app = Flask(__name__) call in a model file is a code smell. Model classes should be framework-agnostic wherever possible. If a centralized Flask application is needed, consider initializing it in a dedicated application or controller module.

genai.configure(api_key=GOOGLE_API_KEY)

# Initialize Redis client
try:
redis_client = redis.StrictRedis(
host=Config.REDIS_HOST or 'localhost',
port=Config.REDIS_PORT or 6379,
db=Config.REDIS_DB or 0,
password=Config.REDIS_PASSWORD or None,
decode_responses=True
)
# Test the connection
redis_client.ping()
logging.info("Connected to Redis")
except Exception as e:
logging.error(f"Error connecting to Redis: {e}")
redis_client = None
# lock for thread safety to prevent race conditions when multiple users request the same data.
data_fetch_lock = threading.Lock()

class DataFetcher:
@staticmethod
def fetch_air_quality_data(grid_id, start_time, end_time):
"""Fetch air quality data and cache it in Redis to avoid redundant API calls."""
cache_key = f"air_quality:{grid_id}:{start_time}:{end_time}"

# Check if data is cached in Redis
cached_data = None
if redis_client:
try:
cached_data = redis_client.get(cache_key)
except Exception as e:
logging.error(f"Error retrieving data from Redis: {e}")
else:
logging.error("Redis client not available, skipping cache check")

if cached_data:
logging.info(f"Retrieved cached data for {cache_key}")
return json.loads(cached_data)

token = Config.AIRQO_API_TOKEN
analytics_url = Config.ANALTICS_URL
if not token:
Comment on lines +57 to +58
🛠️ Refactor suggestion

Implement thread-safety in the LLM response method.

A threading lock is declared but never used. To ensure thread safety when making external calls or modifying shared data in _llm_response, wrap the logic in with self.lock:. This helps avoid race conditions if multiple threads access the method concurrently.

Also applies to: 94-114, 164-181

logging.error("AIRQO_API_TOKEN is not set.")
return None

query_params = {'token': token}
url = f"{analytics_url}?{urlencode(query_params)}"
payload = {"grid_id": grid_id, "start_time": start_time, "end_time": end_time}
logging.info(f"Fetching air quality data with payload: {payload}")

try:
response = requests.post(url, json=payload, timeout=10)
response.raise_for_status()
data = response.json()
# Cache response in Redis for 1 hour
Comment on lines +64 to +71
🛠️ Refactor suggestion

Refactor or remove duplicated code blocks.

The file contains two versions of _prepare_data_context, _rule_based_response, _llm_response, and chat methods. Duplicated logic can cause confusion and maintenance effort. Consider merging them into a single, well-structured set of methods.

Also applies to: 121-132, 72-93, 133-163, 94-114, 164-181, 115-120, 182-194

            # setex is the Redis command that sets a key with an expiration time
with data_fetch_lock:
if redis_client:
redis_client.setex(cache_key, 3600, json.dumps(data))
logging.info(f"Data fetched and cached for grid_id: {grid_id}")
return data
except requests.exceptions.HTTPError as http_err:
logging.error(f"HTTP error: {http_err}")
except requests.exceptions.RequestException as req_err:
logging.error(f"Request error: {req_err}")
except ValueError as json_err:
logging.error(f"JSON error: {json_err}")
return None
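The cache-key plus `setex` expiry pattern used above can be illustrated without a Redis server. The `TTLCache` class below is a purely illustrative in-memory stand-in, not part of the PR:

```python
import time

class TTLCache:
    """Illustrative in-memory stand-in for Redis get/setex semantics."""

    def __init__(self):
        self._store = {}  # key -> (expires_at, serialized value)

    def setex(self, key, ttl_seconds, value):
        # Store the value together with its absolute expiry time.
        self._store[key] = (time.monotonic() + ttl_seconds, value)

    def get(self, key):
        entry = self._store.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if time.monotonic() >= expires_at:
            del self._store[key]  # lazily evict expired entries
            return None
        return value

# Same key scheme as DataFetcher.fetch_air_quality_data:
cache = TTLCache()
cache_key = "air_quality:659d036497e611001236cd1b:2024-12-01T00:00:2025-03-06T00:00"
cache.setex(cache_key, 3600, '{"airquality": {}}')
```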

class AirQualityChatbot:
def __init__(self, air_quality_data):
self.data = air_quality_data or {}
self.grid_name = self.data.get('airquality', {}).get('sites', {}).get('grid name', ['Unknown'])[0]
self.annual_data = self.data.get('airquality', {}).get('annual_pm', [{}])[0] or {}
self.daily_mean_data = self.data.get('airquality', {}).get('daily_mean_pm', []) or []
self.diurnal = self.data.get('airquality', {}).get('diurnal', []) or []
self.monthly_data = self.data.get('airquality', {}).get('site_monthly_mean_pm', []) or []
self.site_names = [item.get('site_name', 'Unknown') for item in self.data.get('airquality', {}).get('site_annual_mean_pm', [])] or ['Unknown']
self.num_sites = self.data.get('airquality', {}).get('sites', {}).get('number_of_sites', 'Unknown')
self.starttime = self.data.get('airquality', {}).get('period', {}).get('startTime', '')[:10] or 'N/A'
self.endtime = self.data.get('airquality', {}).get('period', {}).get('endTime', '')[:10] or 'N/A'
self.annual_pm2_5 = self.annual_data.get("pm2_5_calibrated_value", 'N/A')
self.mean_pm2_5_by_site = self.data.get('airquality', {}).get('site_annual_mean_pm', [])

# Sort daily_mean_data to get the most recent measurement
if self.daily_mean_data:
sorted_daily = sorted(self.daily_mean_data, key=lambda x: x.get('date', ''), reverse=True)
self.today_pm2_5 = sorted_daily[0].get('pm2_5_calibrated_value', 'N/A') if sorted_daily else 'N/A'
self.today_date = sorted_daily[0].get('date', 'N/A') if sorted_daily else 'N/A'
else:
self.today_pm2_5 = 'N/A'
self.today_date = 'N/A'

self.peak_diurnal = max(self.diurnal, key=lambda x: x.get('pm2_5_calibrated_value', 0)) if self.diurnal else {}

try:
# Gemini model
self.gemini_model = genai.GenerativeModel('gemini-2.0-flash')
except Exception as e:
logging.error(f"Failed to initialize Gemini model: {e}")
self.gemini_model = None
self.lock = threading.Lock()

def _prepare_data_context(self):
"""Prepare a concise data context for the LLM."""
return (
f"AirQo data for {self.grid_name} ({self.starttime}-{self.endtime}): "
f"Annual PM2.5={self.annual_pm2_5} µg/m³, Sites={self.num_sites}, "
f"Most recent daily PM2.5={self.today_pm2_5} µg/m³ on {self.today_date}, "
f"Diurnal peak={self.peak_diurnal.get('pm2_5_calibrated_value', 'N/A')} µg/m³ at {self.peak_diurnal.get('hour', 'N/A')}:00, "
f"Site names={self.site_names}."
f"Monthly data={self.monthly_data}."
f"Daily mean data={self.daily_mean_data}."
f"Site annual mean data={self.mean_pm2_5_by_site}."
)

def _rule_based_response(self, user_prompt):
"""Handle common queries with precomputed responses."""
prompt = user_prompt.lower()

if re.search(r"(today|now).*air.*quality", prompt):
if self.today_pm2_5 != 'N/A':
return f"The most recent PM2.5 in {self.grid_name} is {self.today_pm2_5} µg/m³ on {self.today_date}."
return "No recent air quality data available."

if re.search(r"(worst|highest|peak).*time", prompt):
if self.peak_diurnal:
return f"Pollution peaks at {self.peak_diurnal.get('hour', 'N/A')}:00 with {self.peak_diurnal.get('pm2_5_calibrated_value', 'N/A')} µg/m³."
return "No diurnal data available."

if re.search(r"how.*many.*(site|sites|monitors)", prompt):
if self.num_sites != 'Unknown':
return f"There are {self.num_sites} monitoring sites in {self.grid_name}."
return "Number of sites is not available."

if re.search(r"(year|annual).*average", prompt):
if self.annual_pm2_5 != 'N/A':
return f"The annual PM2.5 average in {self.grid_name} is {self.annual_pm2_5} µg/m³."
return "Annual air quality data is not available."

if re.search(r"(?:where|which|list).*(?:site|sites|locations)", prompt):
if self.site_names != ['Unknown']:
return f"Monitoring sites in {self.grid_name}: {', '.join(self.site_names)}."
return "Site information is not available."

return None

def _llm_response(self, user_prompt):
"""Generate a response using the Gemini model for complex queries."""
if not self.gemini_model:
return "Language model is not available."

full_prompt = (
f"Data: {self._prepare_data_context()}\n"
f"User: {user_prompt}\n"
"Respond concisely and accurately based on the data."
)

try:
response = self.gemini_model.generate_content(full_prompt)
return response.text
except Exception as e:
logging.error(f"LLM error: {e}")
return "Sorry, I couldn't generate a response."

def chat(self, user_prompt):
"""Process user queries and return appropriate responses."""
if not self.data:
return "Air quality data is not available for the specified grid and time period."
if not user_prompt or not isinstance(user_prompt, str):
return "Please provide a valid question about air quality."
if len(user_prompt) > 500:
return "Your question is too long. Please keep it under 500 characters."

rule_response = self._rule_based_response(user_prompt)
if rule_response:
return rule_response
return self._llm_response(user_prompt)
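The rule-based fast path in `_rule_based_response` can be sketched in isolation. The snippet below replicates only the first rule, with hypothetical placeholder values, to show how regexes answer cheap queries before the LLM is consulted:

```python
import re

def rule_based_response(prompt, grid_name="Kampala", today_pm2_5=12.4, today_date="2025-03-05"):
    """Answer 'today/now ... air ... quality' questions from precomputed fields; else defer."""
    if re.search(r"(today|now).*air.*quality", prompt.lower()):
        return f"The most recent PM2.5 in {grid_name} is {today_pm2_5} µg/m³ on {today_date}."
    return None  # fall through to the LLM path

print(rule_based_response("What is today's air quality?"))
print(rule_based_response("Compare weekday and weekend trends"))
```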
2 changes: 1 addition & 1 deletion src/spatial/models/report_datafetcher.py
@@ -77,7 +77,7 @@ def __init__(self, data):


# Initialize models once in the constructor
self.gemini_model = genai.GenerativeModel('gemini-pro')
self.gemini_model = genai.GenerativeModel('gemini-2.0-flash')

def _prepare_base_info(self):
return (
1 change: 1 addition & 0 deletions src/spatial/requirements.txt
@@ -26,3 +26,4 @@ joblib~=1.4.2
lightgbm~=4.1.0
numpy
google-generativeai
redis
93 changes: 93 additions & 0 deletions src/spatial/views/chatbot_views.py
@@ -0,0 +1,93 @@
# views/chatbot_views.py
from flask import request, jsonify
from models.chatbot_model import AirQualityChatbot, DataFetcher
import logging
import uuid # For generating session IDs if not provided

logger = logging.getLogger(__name__)

class ChatbotView:
@staticmethod
def chat_endpoint():
"""
Handles chatbot API requests for air quality information with session management.
Expects JSON payload with grid_id, start_time, end_time, prompt, and optional session_id, session_title.
Returns JSON response with chatbot's answer and session metadata.
"""
# Validate request payload
payload = request.json
required_fields = ["grid_id", "start_time", "end_time", "prompt"]
if not payload or not all(key in payload for key in required_fields):
logger.error("Invalid payload: missing required fields")
return jsonify({
"error": "Missing required fields: grid_id, start_time, end_time, prompt",
"status": "failure"
}), 400

# Extract parameters
grid_id = payload["grid_id"]
start_time = payload["start_time"]
end_time = payload["end_time"]
user_prompt = payload["prompt"]

# Session metadata (optional)
session_id = payload.get("session_id", str(uuid.uuid4())) # Generate if not provided

# Automatically generate session_title from prompt (truncate to reasonable length)
session_title = payload.get("session_title") # Check if provided
if not session_title: # If not provided, generate from prompt
session_title = (user_prompt[:50] + "...") if len(user_prompt) > 50 else user_prompt

# Validate prompt
if not user_prompt or not isinstance(user_prompt, str):
logger.error(f"Invalid prompt received: {user_prompt}")
return jsonify({
"error": "No valid prompt provided",
"status": "failure"
}), 400
Comment on lines +41 to +47
🛠️ Refactor suggestion

Add additional validation for time format and grid_id.

The current implementation only validates the presence of grid_id, start_time, and end_time, and does some basic validation on the prompt. Consider adding more thorough validation for date/time formats and grid_id format/existence.

# After extracting parameters, add:
+        # Validate time formats
+        try:
+            # Use appropriate datetime parsing based on your expected format
+            from datetime import datetime
+            datetime.strptime(start_time, "%Y-%m-%dT%H:%M:%SZ")
+            datetime.strptime(end_time, "%Y-%m-%dT%H:%M:%SZ")
+            
+            # Ensure start_time is before end_time
+            if datetime.strptime(start_time, "%Y-%m-%dT%H:%M:%SZ") >= datetime.strptime(end_time, "%Y-%m-%dT%H:%M:%SZ"):
+                logger.error(f"Invalid time range: start_time must be before end_time")
+                return jsonify({
+                    "error": "Invalid time range: start_time must be before end_time",
+                    "status": "failure"
+                }), 400
+        except ValueError:
+            logger.error(f"Invalid time format for start_time or end_time")
+            return jsonify({
+                "error": "Invalid time format. Expected format: YYYY-MM-DDThh:mm:ssZ",
+                "status": "failure"
+            }), 400
+            
+        # Validate grid_id format if needed
+        # For example, if grid_id should be a UUID or follow a specific pattern
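One caveat on the suggested `strptime` check: the README's example timestamps (`2024-12-01T00:00`) omit seconds and the trailing `Z`, so they would fail that exact format. `datetime.fromisoformat` is a more forgiving alternative (a sketch, not part of the PR):

```python
from datetime import datetime

def parse_iso(ts: str) -> datetime:
    """Parse ISO-8601 timestamps with or without seconds, tolerating a trailing 'Z'."""
    return datetime.fromisoformat(ts.replace("Z", "+00:00"))

start = parse_iso("2024-12-01T00:00")
end = parse_iso("2025-03-06T00:00")
assert start < end  # the same ordering check the review recommends
```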


try:
# Fetch air quality data
logger.info(f"Fetching data for grid_id: {grid_id}, {start_time} to {end_time}")
air_quality_data = DataFetcher.fetch_air_quality_data(grid_id, start_time, end_time)

if not air_quality_data or 'airquality' not in air_quality_data:
logger.error(f"No valid air quality data returned for grid_id: {grid_id}")
return jsonify({
"error": "Failed to fetch air quality data",
"status": "failure"
}), 500

# Initialize chatbot and get response
chatbot = AirQualityChatbot(air_quality_data)
response = chatbot.chat(user_prompt)

if not response:
logger.warning(f"Empty response generated for prompt: {user_prompt}")
return jsonify({
"error": "No response generated",
"status": "failure"
}), 500

logger.info(f"Successfully processed request for {grid_id}, session_id: {session_id}")
return jsonify({
"response": response,
"status": "success",
"grid_id": grid_id,
"period": {
"start_time": start_time,
"end_time": end_time
},
"session": {
"session_id": session_id,
"session_title": session_title,
"timestamp": start_time # Optional: to track when the session started
}
}), 200

except Exception as e:
logger.error(f"Unhandled exception: {str(e)}")
return jsonify({
"error": "Internal server error",
"status": "failure"
}), 500
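For reference, a successful call to this handler returns a body shaped as follows (all values illustrative):

```json
{
  "response": "The most recent PM2.5 in Kampala is 12.4 µg/m³ on 2025-03-05.",
  "status": "success",
  "grid_id": "659d036497e611001236cd1b",
  "period": {
    "start_time": "2024-12-01T00:00",
    "end_time": "2025-03-06T00:00"
  },
  "session": {
    "session_id": "3f2b6c1e-9d8a-4e7b-8c2f-0a1b2c3d4e5f",
    "session_title": "What is today's air quality?",
    "timestamp": "2024-12-01T00:00"
  }
}
```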
Comment on lines +88 to +93
Enhance exception handling to prevent information exposure.

The current exception handling logs the full exception details but could potentially expose sensitive information in the response. The GitHub security bot previously flagged this issue pattern.

        except Exception as e:
            logger.error(f"Unhandled exception: {str(e)}")
+            # Generate a unique error reference ID that can be used to trace this error in logs
+            error_reference = str(uuid.uuid4())
+            logger.error(f"Error reference: {error_reference}")
            return jsonify({
                "error": "Internal server error", 
+               "error_reference": error_reference,
                "status": "failure"
            }), 500

The improved version generates a unique error reference ID that logs with the exception details but only returns the reference ID to the client. This allows support staff to locate the specific error in logs without exposing implementation details to users.


