Skip to content

Commit

Permalink
v0.6.0 (#161)
Browse files Browse the repository at this point in the history
- Buffers view: zoom
- Buffer view: linkable operations
- Buffers view: buffer table - initial rendering
- Buffers pages query api for Tensor sharding visualization
- Support for multiple per-tab parallel downloads
- Multithreading to support background jobs
- Real-time monitoring of sync/download/upload jobs
- Improved session management
- UX improvements to display/monitor background jobs
  • Loading branch information
aidemsined authored Oct 16, 2024
2 parents 5d09b00 + 032c899 commit 4b0320b
Show file tree
Hide file tree
Showing 44 changed files with 1,587 additions and 503 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ build
.env


sessions.db
*.db
66 changes: 48 additions & 18 deletions backend/ttnn_visualizer/app.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import logging
import shutil
import subprocess
from os import environ
from pathlib import Path

from flask import Flask
import flask
from dotenv import load_dotenv
from flask import Flask
from flask_cors import CORS
from werkzeug.debug import DebuggedApplication
from werkzeug.middleware.proxy_fix import ProxyFix
from flask_cors import CORS
from ttnn_visualizer import settings
from dotenv import load_dotenv
from ttnn_visualizer.sessions import init_sessions, CustomRequest, init_session_db

from ttnn_visualizer.settings import Config


def create_app(settings_override=None):
Expand All @@ -31,9 +31,8 @@ def create_app(settings_override=None):
flask_env = environ.get("FLASK_ENV", "development")

app = Flask(__name__, static_folder=static_assets_dir, static_url_path="/")
app.request_class = CustomRequest

app.config.from_object(getattr(settings, flask_env))
app.config.from_object(Config())

logging.basicConfig(level=app.config.get("LOG_LEVEL", "INFO"))

Expand All @@ -42,8 +41,7 @@ def create_app(settings_override=None):
if settings_override:
app.config.update(settings_override)

init_session_db()

app.config["USE_WEBSOCKETS"] = True # Set this based on environment
middleware(app)

app.register_blueprint(api)
Expand All @@ -61,7 +59,8 @@ def catch_all(path):


def extensions(app: flask.Flask):
from ttnn_visualizer.extensions import flask_static_digest
from ttnn_visualizer.extensions import flask_static_digest, db, socketio
from ttnn_visualizer.sockets import register_handlers

"""
Register 0 or more extensions (mutates the app passed in).
Expand All @@ -71,6 +70,20 @@ def extensions(app: flask.Flask):
"""

flask_static_digest.init_app(app)
socketio.init_app(app)
db.init_app(app)

app.config["SESSION_TYPE"] = "sqlalchemy"
app.config["SESSION_SQLALCHEMY"] = db

with app.app_context():
db.drop_all()

register_handlers(socketio)

# Create the tables within the application context
with app.app_context():
db.create_all()

# For automatically reflecting table data
# with app.app_context():
Expand All @@ -86,21 +99,38 @@ def middleware(app: flask.Flask):
:param app: Flask application instance
:return: None
"""
# Enable the Flask interactive debugger in the brower for development.
if app.debug:
app.wsgi_app = DebuggedApplication(app.wsgi_app, evalex=True)
# Only use the middleware if running in pure WSGI (HTTP requests)
if not app.config.get("USE_WEBSOCKETS"):
# Enable the Flask interactive debugger in the browser for development.
if app.debug:
app.wsgi_app = DebuggedApplication(app.wsgi_app, evalex=True)

# Set the real IP address into request.remote_addr when behind a proxy.
app.wsgi_app = ProxyFix(app.wsgi_app)
# Set the real IP address into request.remote_addr when behind a proxy.
app.wsgi_app = ProxyFix(app.wsgi_app)

# CORS configuration
origins = ["http://localhost:5173", "http://localhost:8000"]

init_sessions(app)

CORS(
app,
origins=origins,
)

return None


if __name__ == "__main__":
config = Config()

gunicorn_args = [
"gunicorn",
"-k",
config.GUNICORN_WORKER_CLASS,
"-w",
config.GUNICORN_WORKERS,
config.GUNICORN_APP_MODULE,
"-b",
config.GUNICORN_BIND,
]

subprocess.run(gunicorn_args)
17 changes: 14 additions & 3 deletions backend/ttnn_visualizer/appserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,29 @@

from gunicorn.app.wsgiapp import run

from ttnn_visualizer.settings import Config

app_dir = pathlib.Path(__file__).parent.resolve()
config_dir = app_dir.joinpath("config")
static_assets_dir = app_dir.joinpath("static")


def serve():
"""Run command for use in wheel package entrypoint"""
config = Config()

os.environ.setdefault("FLASK_ENV", "production")
os.environ.setdefault("STATIC_ASSETS", str(static_assets_dir))

sys.argv = [
"gunicorn",
"-c",
str(config_dir.joinpath("gunicorn.py").absolute()),
"ttnn_visualizer.app:create_app()",
"-k",
config.GUNICORN_WORKER_CLASS,
"-w",
config.GUNICORN_WORKERS,
"-b",
config.GUNICORN_BIND,
config.GUNICORN_APP_MODULE,
]

run()
12 changes: 5 additions & 7 deletions backend/ttnn_visualizer/config/gunicorn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# -*- coding: utf-8 -*-

import multiprocessing
import os
from pathlib import Path

from dotenv import load_dotenv

from ttnn_visualizer.utils import str_to_bool

# Load dotenv from root directory
Expand All @@ -15,11 +16,8 @@
accesslog = "-"
access_log_format = "%(h)s %(l)s %(u)s %(t)s '%(r)s' %(s)s %(b)s '%(f)s' '%(a)s' in %(D)sµs" # noqa: E501

workers = 1
threads = 1
reload = bool(str_to_bool(os.getenv("WEB_RELOAD", "false")))

# Currently no need for multithreading/workers
# workers = int(os.getenv("WEB_CONCURRENCY", multiprocessing.cpu_count() * 2))
# threads = int(os.getenv("PYTHON_MAX_THREADS", 1))
worker_class = "gevent"

reload = bool(str_to_bool(os.getenv("WEB_RELOAD", "false")))
workers = 1
91 changes: 0 additions & 91 deletions backend/ttnn_visualizer/database.py
Original file line number Diff line number Diff line change
@@ -1,91 +0,0 @@
import sqlite3
from logging import getLogger

logger = getLogger(__name__)


def create_update_database(sqlite_db_path):
"""
Creates or updates database with all tables
:param sqlite_db_path Path to target SQLite database
:return:
"""
sqlite_connection = sqlite3.connect(sqlite_db_path)
logger.info("Creating/updating SQLite database")
cursor = sqlite_connection.cursor()
cursor.execute(
"""CREATE TABLE IF NOT EXISTS devices
(
device_id int,
num_y_cores int,
num_x_cores int,
num_y_compute_cores int,
num_x_compute_cores int,
worker_l1_size int,
l1_num_banks int,
l1_bank_size int,
address_at_first_l1_bank int,
address_at_first_l1_cb_buffer int,
num_banks_per_storage_core int,
num_compute_cores int,
num_storage_cores int,
total_l1_memory int,
total_l1_for_tensors int,
total_l1_for_interleaved_buffers int,
total_l1_for_sharded_buffers int,
cb_limit int
)"""
)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS tensors
(tensor_id int UNIQUE, shape text, dtype text, layout text, memory_config text, device_id int, address int, buffer_type int)"""
)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS local_tensor_comparison_records
(tensor_id int UNIQUE, golden_tensor_id int, matches bool, desired_pcc bool, actual_pcc float)"""
)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS global_tensor_comparison_records
(tensor_id int UNIQUE, golden_tensor_id int, matches bool, desired_pcc bool, actual_pcc float)"""
)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS operations
(operation_id int UNIQUE, name text, duration float)"""
)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS operation_arguments
(operation_id int, name text, value text)"""
)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS stack_traces
(operation_id int, stack_trace text)"""
)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS input_tensors
(operation_id int, input_index int, tensor_id int)"""
)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS output_tensors
(operation_id int, output_index int, tensor_id int)"""
)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS buffers
(operation_id int, device_id int, address int, max_size_per_bank int, buffer_type int)"""
)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS buffer_pages
(operation_id int, device_id int, address int, core_y int, core_x int, bank_id int, page_index int, page_address int, page_size int, buffer_type int)"""
)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS nodes
(operation_id int, unique_id int, node_operation_id int, name text)"""
)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS edges
(operation_id int, source_unique_id int, sink_unique_id int, source_output_index int, sink_input_index int, key int)"""
)
cursor.execute(
"""CREATE TABLE IF NOT EXISTS captured_graph
(operation_id int, captured_graph text)"""
)
sqlite_connection.commit()
38 changes: 31 additions & 7 deletions backend/ttnn_visualizer/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,57 @@
from flask import request, abort
from pathlib import Path

from ttnn_visualizer.sessions import get_or_create_tab_session
from ttnn_visualizer.utils import get_report_path


def with_report_path(func):
@wraps(func)
def wrapper(*args, **kwargs):
target_report_path = getattr(request, "report_path", None)
if not target_report_path or not Path(target_report_path).exists():
from flask import current_app

tab_id = request.args.get("tabId")

if not tab_id:
current_app.logger.error("No tabId present on request, returning 404")
abort(404)

session = get_or_create_tab_session(tab_id=tab_id)
active_report = session.active_report

if not active_report:
current_app.logger.error(
f"No active report exists for tabId {tab_id}, returning 404"
)
# Raise 404 if report_path is missing or does not exist
abort(404)

report_path = get_report_path(active_report, current_app)
if not Path(report_path).exists():
current_app.logger.error(
f"Specified report path {report_path} does not exist, returning 404"
)
abort(404)

# Add the report path to the view's arguments
kwargs["report_path"] = target_report_path
kwargs["report_path"] = report_path
return func(*args, **kwargs)

return wrapper




def remote_exception_handler(func):
def remote_handler(*args, **kwargs):
from flask import current_app

from paramiko.ssh_exception import AuthenticationException
from paramiko.ssh_exception import NoValidConnectionsError
from paramiko.ssh_exception import SSHException
from ttnn_visualizer.exceptions import RemoteFolderException, NoProjectsException

from ttnn_visualizer.exceptions import (
RemoteFolderException,
NoProjectsException,
)

connection = args[0]

try:
Expand Down
6 changes: 6 additions & 0 deletions backend/ttnn_visualizer/extensions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from flask_socketio import SocketIO
from flask_static_digest import FlaskStaticDigest
from flask_sqlalchemy import SQLAlchemy


flask_static_digest = FlaskStaticDigest()
# Initialize Flask SQLAlchemy
db = SQLAlchemy()

socketio = SocketIO(cors_allowed_origins="*", async_mode="gevent")
17 changes: 17 additions & 0 deletions backend/ttnn_visualizer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from json import JSONDecodeError

from pydantic import BaseModel, Field
from sqlalchemy import Integer, Column, String, JSON

from ttnn_visualizer.extensions import db


class BufferType(enum.Enum):
Expand Down Expand Up @@ -165,6 +168,20 @@ class RemoteConnection(BaseModel):
path: str


class TabSession(db.Model):
__tablename__ = "tab_sessions"

id = Column(Integer, primary_key=True)
tab_id = Column(String, unique=True, nullable=False)
active_report = Column(JSON)
remote_connection = Column(JSON, nullable=True)

def __init__(self, tab_id, active_report, remote_connection=None):
self.tab_id = tab_id
self.active_report = active_report
self.remote_connection = remote_connection


class StatusMessage(BaseModel):
status: int
message: str
Expand Down
Loading

0 comments on commit 4b0320b

Please sign in to comment.