-
Notifications
You must be signed in to change notification settings - Fork 23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
chatbot #4523
Open
wabinyai
wants to merge
14
commits into
staging
Choose a base branch
from
chatbot
base: staging
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
chatbot #4523
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
5ee7124
chatbot
wabinyai fc8c96c
status
wabinyai d316b48
models
wabinyai 6f000e7
gemini
wabinyai 5b8b3e4
redis
wabinyai bafdfa8
llm
wabinyai c4fa495
model
wabinyai 7592b68
bot
wabinyai 269e289
flask
wabinyai 65927d4
redis update
wabinyai 24270e1
fre
wabinyai e05ee1c
readme
wabinyai a66eaf7
freed
wabinyai 6ab8b04
setex
wabinyai File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import requests | ||
import redis | ||
import json | ||
import google.generativeai as genai | ||
import logging | ||
import re | ||
import threading | ||
from flask import Flask | ||
from urllib.parse import urlencode | ||
from configure import Config | ||
|
||
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configure API keys
# NOTE(review): genai.configure is called at import time; an invalid/missing key
# will not fail here but on the first generate_content call — confirm desired.
GOOGLE_API_KEY = Config.GOOGLE_API_KEY
genai.configure(api_key=GOOGLE_API_KEY)

# Initialize Redis client
# On any connection failure redis_client is set to None; all callers below must
# (and do) check for None before using it, so the app degrades to cache-less mode.
try:
    redis_client = redis.StrictRedis(
        host=Config.REDIS_HOST or 'localhost',
        port=Config.REDIS_PORT or 6379,
        db=Config.REDIS_DB or 0,
        password=Config.REDIS_PASSWORD or None,
        decode_responses=True  # return str instead of bytes from GET/SETEX
    )
    # Test the connection eagerly so a bad config is logged at startup.
    redis_client.ping()
    logging.info("Connected to Redis")
except Exception as e:
    logging.error(f"Error connecting to Redis: {e}")
    redis_client = None
# lock for thread safety to prevent race conditions when multiple users request the same data.
data_fetch_lock = threading.Lock()
|
||
class DataFetcher:
    """Fetches air-quality analytics from the AirQo API with a Redis read-through cache."""

    @staticmethod
    def fetch_air_quality_data(grid_id, start_time, end_time):
        """Fetch air quality data and cache it in Redis to avoid redundant API calls.

        Args:
            grid_id: Identifier of the monitoring grid to query.
            start_time: Period start (string, passed through to the API).
            end_time: Period end (string, passed through to the API).

        Returns:
            The decoded JSON response dict, or None on any failure.
        """
        cache_key = f"air_quality:{grid_id}:{start_time}:{end_time}"

        # Check if data is cached in Redis (best-effort: cache errors never fail the request).
        cached_data = None
        if redis_client:
            try:
                cached_data = redis_client.get(cache_key)
            except Exception as e:
                logging.error(f"Error retrieving data from Redis: {e}")
        else:
            logging.error("Redis client not available, skipping cache check")

        if cached_data:
            # A corrupt cache entry must not abort the request — fall through to the API.
            try:
                logging.info(f"Retrieved cached data for {cache_key}")
                return json.loads(cached_data)
            except ValueError as e:
                logging.error(f"Corrupt cache entry for {cache_key}, refetching: {e}")

        token = Config.AIRQO_API_TOKEN
        # NOTE(review): "ANALTICS" is a typo in the Config attribute name; kept
        # as-is to match configure.Config — fix there first, then here.
        analytics_url = Config.ANALTICS_URL
        if not token:
            logging.error("AIRQO_API_TOKEN is not set.")
            return None

        query_params = {'token': token}
        url = f"{analytics_url}?{urlencode(query_params)}"
        payload = {"grid_id": grid_id, "start_time": start_time, "end_time": end_time}
        logging.info(f"Fetching air quality data with payload: {payload}")

        try:
            response = requests.post(url, json=payload, timeout=5)
            response.raise_for_status()
            data = response.json()
            # Cache response in Redis for 1 hour (setex = SET with expiry).
            # The lock serializes cache writes; a cache failure must not discard
            # data we already fetched successfully, hence the inner try/except.
            try:
                with data_fetch_lock:
                    if redis_client:
                        redis_client.setex(cache_key, 3600, json.dumps(data))
            except Exception as cache_err:
                logging.error(f"Error caching data in Redis: {cache_err}")
            logging.info(f"Data fetched and cached for grid_id: {grid_id}")
            return data
        except requests.exceptions.HTTPError as http_err:
            logging.error(f"HTTP error: {http_err}")
        except requests.exceptions.RequestException as req_err:
            logging.error(f"Request error: {req_err}")
        except ValueError as json_err:
            logging.error(f"JSON error: {json_err}")
        return None
||
class AirQualityChatbot:
    """Answers air-quality questions for a grid/period.

    Common queries are handled by fast rule-based matchers; anything else is
    delegated to the Gemini LLM with a compact data context.
    """

    def __init__(self, air_quality_data):
        """Extract the fields used for answers from the raw analytics payload.

        Args:
            air_quality_data: Decoded API response (may be None/empty; all
                lookups degrade to 'Unknown'/'N/A' defaults).
        """
        self.data = air_quality_data or {}
        airquality = self.data.get('airquality', {}) or {}

        self.grid_name = airquality.get('sites', {}).get('grid name', ['Unknown'])[0]
        self.annual_data = airquality.get('annual_pm', [{}])[0] or {}
        self.daily_mean_data = airquality.get('daily_mean_pm', []) or []
        self.diurnal = airquality.get('diurnal', []) or []
        self.monthly_data = airquality.get('site_monthly_mean_pm', []) or []
        self.site_names = [
            item.get('site_name', 'Unknown')
            for item in airquality.get('site_annual_mean_pm', [])
        ] or ['Unknown']
        self.num_sites = airquality.get('sites', {}).get('number_of_sites', 'Unknown')

        # dict.get only substitutes the default when the KEY is missing, so an
        # explicit None value would crash the [:10] slice — coalesce first.
        period = airquality.get('period', {}) or {}
        self.starttime = (period.get('startTime') or '')[:10] or 'N/A'
        self.endtime = (period.get('endTime') or '')[:10] or 'N/A'
        self.annual_pm2_5 = self.annual_data.get("pm2_5_calibrated_value", 'N/A')

        # Sort daily_mean_data to get the most recent measurement.
        if self.daily_mean_data:
            sorted_daily = sorted(self.daily_mean_data, key=lambda x: x.get('date', ''), reverse=True)
            self.today_pm2_5 = sorted_daily[0].get('pm2_5_calibrated_value', 'N/A') if sorted_daily else 'N/A'
            self.today_date = sorted_daily[0].get('date', 'N/A') if sorted_daily else 'N/A'
        else:
            self.today_pm2_5 = 'N/A'
            self.today_date = 'N/A'

        # `or 0` guards against explicit None readings, which would make the
        # max() comparison raise TypeError.
        self.peak_diurnal = (
            max(self.diurnal, key=lambda x: x.get('pm2_5_calibrated_value') or 0)
            if self.diurnal else {}
        )

        try:
            # Gemini model; failures degrade to rule-based answers only.
            self.gemini_model = genai.GenerativeModel('gemini-2.0-flash')
        except Exception as e:
            logging.error(f"Failed to initialize Gemini model: {e}")
            self.gemini_model = None
        self.lock = threading.Lock()

    def _prepare_data_context(self):
        """Prepare a concise data context for the LLM."""
        return (
            f"AirQo data for {self.grid_name} ({self.starttime}-{self.endtime}): "
            f"Annual PM2.5={self.annual_pm2_5} µg/m³, Sites={self.num_sites}, "
            f"Most recent daily PM2.5={self.today_pm2_5} µg/m³ on {self.today_date}, "
            f"Diurnal peak={self.peak_diurnal.get('pm2_5_calibrated_value', 'N/A')} µg/m³ at {self.peak_diurnal.get('hour', 'N/A')}:00, "
            f"Site names={self.site_names}."
        )

    def _rule_based_response(self, user_prompt):
        """Handle common queries with precomputed responses.

        Returns the canned answer string, or None when no rule matches
        (the caller then falls back to the LLM).
        """
        prompt = user_prompt.lower()

        if re.search(r"(today|now).*air.*quality", prompt):
            if self.today_pm2_5 != 'N/A':
                return f"The most recent PM2.5 in {self.grid_name} is {self.today_pm2_5} µg/m³ on {self.today_date}."
            return "No recent air quality data available."

        if re.search(r"(worst|highest|peak).*time", prompt):
            if self.peak_diurnal:
                return f"Pollution peaks at {self.peak_diurnal.get('hour', 'N/A')}:00 with {self.peak_diurnal.get('pm2_5_calibrated_value', 'N/A')} µg/m³."
            return "No diurnal data available."

        if re.search(r"how.*many.*(site|sites|monitors)", prompt):
            if self.num_sites != 'Unknown':
                return f"There are {self.num_sites} monitoring sites in {self.grid_name}."
            return "Number of sites is not available."

        if re.search(r"(year|annual).*average", prompt):
            if self.annual_pm2_5 != 'N/A':
                return f"The annual PM2.5 average in {self.grid_name} is {self.annual_pm2_5} µg/m³."
            return "Annual air quality data is not available."

        # The alternatives must be grouped: the previous pattern
        # `(where|which|list).*site|sites|locations` matched ANY prompt that
        # merely contained "sites" or "locations" due to `|` precedence.
        if re.search(r"(where|which|list).*(sites?|locations)", prompt):
            if self.site_names != ['Unknown']:
                return f"Monitoring sites in {self.grid_name}: {', '.join(self.site_names)}."
            return "Site information is not available."

        return None

    def _llm_response(self, user_prompt):
        """Generate a response using the Gemini model for complex queries."""
        if not self.gemini_model:
            return "Language model is not available."

        full_prompt = (
            f"Data: {self._prepare_data_context()}\n"
            f"User: {user_prompt}\n"
            "Respond concisely and accurately based on the data."
        )

        try:
            response = self.gemini_model.generate_content(full_prompt)
            return response.text
        except Exception as e:
            logging.error(f"LLM error: {e}")
            return "Sorry, I couldn't generate a response."

    def chat(self, user_prompt):
        """Process user queries and return appropriate responses."""
        if not self.data:
            return "Air quality data is not available for the specified grid and time period."
        if not user_prompt or not isinstance(user_prompt, str):
            return "Please provide a valid question about air quality."
        if len(user_prompt) > 500:
            return "Your question is too long. Please keep it under 500 characters."

        rule_response = self._rule_based_response(user_prompt)
        if rule_response:
            return rule_response
        return self._llm_response(user_prompt)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,3 +26,4 @@ joblib~=1.4.2 | |
lightgbm~=4.1.0 | ||
numpy | ||
google-generativeai | ||
redis |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from flask import request, jsonify | ||
# Assuming these are in models/chatbot_model.py based on the import | ||
from models.chatbot_model import AirQualityChatbot, DataFetcher | ||
import numpy as np | ||
import logging | ||
|
||
#logging.basicConfig(filename="report_log.log", level=logging.INFO, filemode="w")
# Module-level logger; handler/level configuration is left to the app entry point.
logger = logging.getLogger(__name__)
|
||
class ChatbotView:
    """Flask view wrapping the air-quality chatbot."""

    @staticmethod
    def chat_endpoint():
        """
        Handles chatbot API requests for air quality information.

        Expects a JSON payload with grid_id, start_time, end_time, and prompt.
        Returns a JSON response with the chatbot's answer, or an error message
        with status 400 (bad request) / 500 (server-side failure).
        """
        # silent=True returns None for a missing or malformed JSON body instead
        # of letting Flask raise its own 400/415, so every failure mode goes
        # through this view's consistent error payload.
        payload = request.get_json(silent=True)
        if not payload or not all(key in payload for key in ("grid_id", "start_time", "end_time", "prompt")):
            logger.error("Invalid payload: missing required fields")
            return jsonify({
                "error": "Missing required fields: grid_id, start_time, end_time, prompt",
                "status": "failure"
            }), 400

        # Extract parameters
        grid_id = payload["grid_id"]
        start_time = payload["start_time"]
        end_time = payload["end_time"]
        user_prompt = payload["prompt"]

        # Validate prompt
        if not user_prompt or not isinstance(user_prompt, str):
            logger.error(f"Invalid prompt received: {user_prompt}")
            return jsonify({
                "error": "No valid prompt provided",
                "status": "failure"
            }), 400

        try:
            # Fetch air quality data with logging
            logger.info(f"Fetching data for grid_id: {grid_id}, {start_time} to {end_time}")
            air_quality_data = DataFetcher.fetch_air_quality_data(grid_id, start_time, end_time)

            if not air_quality_data or 'airquality' not in air_quality_data:
                logger.error(f"No valid air quality data returned for grid_id: {grid_id}")
                return jsonify({
                    "error": "Failed to fetch air quality data",
                    "status": "failure"
                }), 500

            # Initialize chatbot and get response
            chatbot = AirQualityChatbot(air_quality_data)
            response = chatbot.chat(user_prompt)

            if not response:
                logger.warning(f"Empty response generated for prompt: {user_prompt}")
                return jsonify({
                    "error": "No response generated",
                    "status": "failure"
                }), 500

            logger.info(f"Successfully processed request for {grid_id}")
            return jsonify({
                "response": response,
                "status": "success",
                "grid_id": grid_id,
                "period": {
                    "start_time": start_time,
                    "end_time": end_time
                }
            }), 200

        except Exception as e:
            # logger.exception records the traceback, not just the message.
            logger.exception(f"Unhandled exception: {str(e)}")
            return jsonify({
                "error": "Internal server error",
                "status": "failure"
            }), 500
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Add additional validation for time format and grid_id.
The current implementation only validates the presence of `grid_id`, `start_time`, and `end_time`, and does some basic validation on the prompt. Consider adding more thorough validation for date/time formats and for the grid_id's format and existence.
📝 Committable suggestion