Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chatbot #4523

Open
wants to merge 14 commits into
base: staging
Choose a base branch
from
Open

chatbot #4523

Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/spatial/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ class Config:
)
ANALTICS_URL = os.getenv("ANALTICS_URL")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
# Redis configuration (all values read from the environment at import time).
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")  # Default to 'localhost' if not set
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))  # Convert to int, default to 6379. NOTE(review): int() raises ValueError if the env var is a non-numeric string — confirm deployment env
REDIS_DB = int(os.getenv("REDIS_DB", 0))  # Convert to int, default to 0
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None)  # Default to None (no auth) if not set
class ProductionConfig(Config):
DEBUG = False
TESTING = False
Expand Down
10 changes: 8 additions & 2 deletions src/spatial/controllers/controllers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# controller/controller.py
from flask import Blueprint, request, jsonify
from views.chatbot_views import ChatbotView
from views.getis_services import SpatialDataHandler
from views.getis_confidence_services import SpatialDataHandler_confidence
from views.localmoran_services import SpatialDataHandler_moran
Expand All @@ -12,7 +13,7 @@
from views.satellite_predictions import SatellitePredictionView
from views.site_category_view import SiteCategorizationView
from views.site_selection_views import SiteSelectionView
from views.report_view import ReportView
from views.report_view import ReportView


controller_bp = Blueprint("controller", __name__)
Expand Down Expand Up @@ -78,4 +79,9 @@ def fetch_air_quality_without_llm():

@controller_bp.route("/air_quality_report_with_customised_prompt", methods=["POST"])
def fetch_air_quality_with_customised_prompt():
return ReportView.generate_air_quality_report_with_customised_prompt_gemini()
return ReportView.generate_air_quality_report_with_customised_prompt_gemini()

@controller_bp.route("/chatbot", methods=["POST"])
def Chatbot_Views():
    # Thin dispatcher: all request parsing/validation lives in ChatbotView.
    # NOTE(review): name violates PEP 8 snake_case, but renaming it would change
    # the Flask endpoint name derived from the function — confirm no url_for()
    # callers before renaming.
    return ChatbotView.chat_endpoint()

191 changes: 191 additions & 0 deletions src/spatial/models/chatbot_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import requests
import redis
import json
import google.generativeai as genai
import logging
import re
import threading
from flask import Flask
from urllib.parse import urlencode
from configure import Config

# Configure logging for the whole module (INFO and above, timestamped).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configure the Gemini API key (read from the environment via Config).
GOOGLE_API_KEY = Config.GOOGLE_API_KEY
genai.configure(api_key=GOOGLE_API_KEY)

# Initialize the Redis client; on failure the module degrades to uncached
# operation (redis_client is set to None and callers must check for that).
try:
    redis_client = redis.StrictRedis(
        host=Config.REDIS_HOST or 'localhost',
        port=Config.REDIS_PORT or 6379,
        db=Config.REDIS_DB or 0,
        password=Config.REDIS_PASSWORD or None,
        decode_responses=True  # return str instead of bytes from Redis
    )
    # Ping eagerly so misconfiguration is logged once at import/startup time.
    redis_client.ping()
    logging.info("Connected to Redis")
except Exception as e:
    logging.error(f"Error connecting to Redis: {e}")
    redis_client = None
# Lock used to serialize the cache write in DataFetcher across threads.
data_fetch_lock = threading.Lock()

class DataFetcher:
    """Fetches air quality analytics data, with a Redis read-through cache."""

    @staticmethod
    def fetch_air_quality_data(grid_id, start_time, end_time):
        """Fetch air quality data and cache it in Redis to avoid redundant API calls.

        Args:
            grid_id: Identifier of the monitoring grid to query.
            start_time: Period start (string, passed through to the API).
            end_time: Period end (string, passed through to the API).

        Returns:
            The parsed JSON payload (dict) on success, or None on any failure
            (missing configuration, HTTP error, or invalid JSON).
        """
        cache_key = f"air_quality:{grid_id}:{start_time}:{end_time}"

        # Best-effort cache read: a Redis outage must not break data fetching.
        cached_data = None
        if redis_client:
            try:
                cached_data = redis_client.get(cache_key)
            except Exception as e:
                logging.error(f"Error retrieving data from Redis: {e}")
        else:
            logging.error("Redis client not available, skipping cache check")

        if cached_data:
            logging.info(f"Retrieved cached data for {cache_key}")
            return json.loads(cached_data)

        token = Config.AIRQO_API_TOKEN
        analytics_url = Config.ANALTICS_URL
        if not token:
            logging.error("AIRQO_API_TOKEN is not set.")
            return None
        if not analytics_url:
            # Without this guard the request URL would literally be "None?token=...".
            logging.error("ANALTICS_URL is not set.")
            return None

        query_params = {'token': token}
        url = f"{analytics_url}?{urlencode(query_params)}"
        payload = {"grid_id": grid_id, "start_time": start_time, "end_time": end_time}
        logging.info(f"Fetching air quality data with payload: {payload}")

        try:
            response = requests.post(url, json=payload, timeout=5)
            response.raise_for_status()
            data = response.json()
            # Cache the response for 1 hour (setex = SET with EXpiry).
            # NOTE: the lock only serializes the cache write; it does not stop
            # two threads from fetching the same uncached key concurrently.
            with data_fetch_lock:
                if redis_client:
                    redis_client.setex(cache_key, 3600, json.dumps(data))
            logging.info(f"Data fetched and cached for grid_id: {grid_id}")
            return data
        except requests.exceptions.HTTPError as http_err:
            logging.error(f"HTTP error: {http_err}")
        except requests.exceptions.RequestException as req_err:
            logging.error(f"Request error: {req_err}")
        except ValueError as json_err:
            logging.error(f"JSON error: {json_err}")
        return None

class AirQualityChatbot:
    """Answers user questions about air quality for a monitoring grid.

    Fast, deterministic rule-based answers are used for common queries;
    anything else falls back to the Gemini language model.
    """

    def __init__(self, air_quality_data):
        """Pre-compute summary fields from the raw analytics payload.

        Args:
            air_quality_data: dict as returned by DataFetcher.fetch_air_quality_data
                (may be None/empty; all derived fields then fall back to 'N/A').
        """
        self.data = air_quality_data or {}
        aq = self.data.get('airquality', {})
        sites = aq.get('sites', {})
        # `or` fallbacks ensure empty lists / None values never raise IndexError.
        self.grid_name = (sites.get('grid name') or ['Unknown'])[0]
        self.annual_data = (aq.get('annual_pm') or [{}])[0] or {}
        self.daily_mean_data = aq.get('daily_mean_pm') or []
        self.diurnal = aq.get('diurnal') or []
        self.monthly_data = aq.get('site_monthly_mean_pm') or []
        self.site_names = [item.get('site_name', 'Unknown') for item in aq.get('site_annual_mean_pm', [])] or ['Unknown']
        self.num_sites = sites.get('number_of_sites', 'Unknown')
        period = aq.get('period', {})
        # `or ''` guards the [:10] slice against a None timestamp.
        self.starttime = (period.get('startTime') or '')[:10] or 'N/A'
        self.endtime = (period.get('endTime') or '')[:10] or 'N/A'
        self.annual_pm2_5 = self.annual_data.get("pm2_5_calibrated_value", 'N/A')

        # Sort daily_mean_data by date (ISO strings sort lexicographically)
        # to get the most recent measurement.
        if self.daily_mean_data:
            sorted_daily = sorted(self.daily_mean_data, key=lambda x: x.get('date', ''), reverse=True)
            self.today_pm2_5 = sorted_daily[0].get('pm2_5_calibrated_value', 'N/A') if sorted_daily else 'N/A'
            self.today_date = sorted_daily[0].get('date', 'N/A') if sorted_daily else 'N/A'
        else:
            self.today_pm2_5 = 'N/A'
            self.today_date = 'N/A'

        # `or 0` keeps max() comparisons valid when a reading is None.
        self.peak_diurnal = max(self.diurnal, key=lambda x: x.get('pm2_5_calibrated_value') or 0) if self.diurnal else {}

        try:
            # Gemini model used for free-form questions.
            self.gemini_model = genai.GenerativeModel('gemini-2.0-flash')
        except Exception as e:
            logging.error(f"Failed to initialize Gemini model: {e}")
            self.gemini_model = None
        self.lock = threading.Lock()

    def _prepare_data_context(self):
        """Prepare a concise data context for the LLM."""
        return (
            f"AirQo data for {self.grid_name} ({self.starttime}-{self.endtime}): "
            f"Annual PM2.5={self.annual_pm2_5} µg/m³, Sites={self.num_sites}, "
            f"Most recent daily PM2.5={self.today_pm2_5} µg/m³ on {self.today_date}, "
            f"Diurnal peak={self.peak_diurnal.get('pm2_5_calibrated_value', 'N/A')} µg/m³ at {self.peak_diurnal.get('hour', 'N/A')}:00, "
            f"Site names={self.site_names}."
        )

    def _rule_based_response(self, user_prompt):
        """Handle common queries with precomputed responses.

        Returns:
            A response string when a rule matches, else None (caller falls
            back to the LLM).
        """
        prompt = user_prompt.lower()

        if re.search(r"(today|now).*air.*quality", prompt):
            if self.today_pm2_5 != 'N/A':
                return f"The most recent PM2.5 in {self.grid_name} is {self.today_pm2_5} µg/m³ on {self.today_date}."
            return "No recent air quality data available."

        if re.search(r"(worst|highest|peak).*time", prompt):
            if self.peak_diurnal:
                return f"Pollution peaks at {self.peak_diurnal.get('hour', 'N/A')}:00 with {self.peak_diurnal.get('pm2_5_calibrated_value', 'N/A')} µg/m³."
            return "No diurnal data available."

        if re.search(r"how.*many.*(site|sites|monitors)", prompt):
            if self.num_sites != 'Unknown':
                return f"There are {self.num_sites} monitoring sites in {self.grid_name}."
            return "Number of sites is not available."

        if re.search(r"(year|annual).*average", prompt):
            if self.annual_pm2_5 != 'N/A':
                return f"The annual PM2.5 average in {self.grid_name} is {self.annual_pm2_5} µg/m³."
            return "Annual air quality data is not available."

        # Parenthesized alternation: the original pattern
        # r"(where|which|list).*site|sites|locations" matched ANY prompt
        # containing "sites" or "locations" due to `|` precedence.
        if re.search(r"(where|which|list).*(sites?|locations)", prompt):
            if self.site_names != ['Unknown']:
                return f"Monitoring sites in {self.grid_name}: {', '.join(self.site_names)}."
            return "Site information is not available."

        return None

    def _llm_response(self, user_prompt):
        """Generate a response using the Gemini model for complex queries."""
        if not self.gemini_model:
            return "Language model is not available."

        full_prompt = (
            f"Data: {self._prepare_data_context()}\n"
            f"User: {user_prompt}\n"
            "Respond concisely and accurately based on the data."
        )

        try:
            response = self.gemini_model.generate_content(full_prompt)
            return response.text
        except Exception as e:
            logging.error(f"LLM error: {e}")
            return "Sorry, I couldn't generate a response."

    def chat(self, user_prompt):
        """Process a user query and return the appropriate response string."""
        if not self.data:
            return "Air quality data is not available for the specified grid and time period."
        if not user_prompt or not isinstance(user_prompt, str):
            return "Please provide a valid question about air quality."
        if len(user_prompt) > 500:
            return "Your question is too long. Please keep it under 500 characters."

        rule_response = self._rule_based_response(user_prompt)
        if rule_response:
            return rule_response
        return self._llm_response(user_prompt)
2 changes: 1 addition & 1 deletion src/spatial/models/report_datafetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(self, data):


# Initialize models once in the constructor
self.gemini_model = genai.GenerativeModel('gemini-pro')
self.gemini_model = genai.GenerativeModel('gemini-2.0-flash')

def _prepare_base_info(self):
return (
Expand Down
1 change: 1 addition & 0 deletions src/spatial/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ joblib~=1.4.2
lightgbm~=4.1.0
numpy
google-generativeai
redis
80 changes: 80 additions & 0 deletions src/spatial/views/chatbot_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from flask import request, jsonify
# Assuming these are in models/chatbot_model.py based on the import
from models.chatbot_model import AirQualityChatbot, DataFetcher
import numpy as np
import logging

#logging.basicConfig(filename="report_log.log", level=logging.INFO, filemode="w")
logger = logging.getLogger(__name__)

class ChatbotView:
    """Flask view for the /chatbot endpoint."""

    @staticmethod
    def chat_endpoint():
        """
        Handles chatbot API requests for air quality information.

        Expects a JSON payload with grid_id, start_time, end_time, and prompt.
        Returns a JSON response with the chatbot's answer or an error message.
        """
        # get_json(silent=True) returns None instead of raising a BadRequest
        # (HTML error page) when the body is missing, malformed, or has the
        # wrong content type — keeping all error responses in the JSON shape.
        payload = request.get_json(silent=True)
        if not payload or not all(key in payload for key in ["grid_id", "start_time", "end_time", "prompt"]):
            logger.error("Invalid payload: missing required fields")
            return jsonify({
                "error": "Missing required fields: grid_id, start_time, end_time, prompt",
                "status": "failure"
            }), 400

        # Extract parameters
        grid_id = payload["grid_id"]
        start_time = payload["start_time"]
        end_time = payload["end_time"]
        user_prompt = payload["prompt"]

        # Validate prompt
        if not user_prompt or not isinstance(user_prompt, str):
            logger.error(f"Invalid prompt received: {user_prompt}")
            return jsonify({
                "error": "No valid prompt provided",
                "status": "failure"
            }), 400

        try:
            # Fetch air quality data with logging
            logger.info(f"Fetching data for grid_id: {grid_id}, {start_time} to {end_time}")
            air_quality_data = DataFetcher.fetch_air_quality_data(grid_id, start_time, end_time)

            if not air_quality_data or 'airquality' not in air_quality_data:
                logger.error(f"No valid air quality data returned for grid_id: {grid_id}")
                return jsonify({
                    "error": "Failed to fetch air quality data",
                    "status": "failure"
                }), 500

            # Initialize chatbot and get response
            chatbot = AirQualityChatbot(air_quality_data)
            response = chatbot.chat(user_prompt)

            if not response:
                logger.warning(f"Empty response generated for prompt: {user_prompt}")
                return jsonify({
                    "error": "No response generated",
                    "status": "failure"
                }), 500

            logger.info(f"Successfully processed request for {grid_id}")
            return jsonify({
                "response": response,
                "status": "success",
                "grid_id": grid_id,
                "period": {
                    "start_time": start_time,
                    "end_time": end_time
                }
            }), 200

        except Exception as e:
            # logger.exception records the traceback for post-mortem debugging.
            logger.exception(f"Unhandled exception: {str(e)}")
            return jsonify({
                "error": "Internal server error",
                "status": "failure"
            }), 500
Loading