Skip to content

Commit 04bc513

Browse files
authored
Replace ORM Querying and Marshmallow Serialization (#112)
# Description ## Removes ORM The use of SQLAlchemy's ORM framework (and its integration w/ Flask) isn't a fit for the way that we're operating on databases within the application. Rather than having a single database that hosts data for a single application, we allow the user to dynamically swap the target database that they are using - because of this, problems arise with the ORM's session/caching logic. This frees us up to implement SSH-based querying, multi-user environments and multi-tab environments by allowing the target database to be selected on a per-request basis. SQLAlchemy dependencies and configuration settings have been removed. ## New Queries/Dataclasses Queries have been moved to queries.py - these queries are optimized to perform as few reads on the file system as possible since performing complex joins against the SQLite database can result in thousands of file read operations, which are considerably slower than filtering/joining/parsing the data programmatically. ## New Serializers A new serializers.py file has been added that contains the logic for serializing the query results in the application. This replaces Marshmallow, which integrated with the ORM to automagically serialize nested relationships of objects defined by the ORM. ## Minor Changes - Adds method to check for existence of tables to support backwards compatibility - Adds check against new table, captured_graph - Adds function name to timing functions. - Adds timing functions to other DB operations - Minor renaming - Adds 404 handlers for missing records - Adds handlers for missing database path - Changes database path to be on a per-request basis (path is stubbed, dynamic path is next)
2 parents b886931 + b395399 commit 04bc513

File tree

11 files changed

+721
-407
lines changed

11 files changed

+721
-407
lines changed

backend/ttnn_visualizer/app.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def catch_all(path):
7070

7171

7272
def extensions(app: flask.Flask):
73-
from ttnn_visualizer.extensions import flask_static_digest, db, ma
73+
from ttnn_visualizer.extensions import flask_static_digest
7474

7575
"""
7676
Register 0 or more extensions (mutates the app passed in).
@@ -79,8 +79,6 @@ def extensions(app: flask.Flask):
7979
:return: None
8080
"""
8181

82-
db.init_app(app)
83-
ma.init_app(app)
8482
flask_static_digest.init_app(app)
8583

8684
# For automatically reflecting table data

backend/ttnn_visualizer/extensions.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,4 @@
1-
from flask_marshmallow import Marshmallow
2-
from flask_sqlalchemy import SQLAlchemy
31
from flask_static_digest import FlaskStaticDigest
42

53

6-
class SQLiteAlchemy(SQLAlchemy):
7-
def apply_driver_hacks(self, app, info, options):
8-
options.update(
9-
{
10-
"isolation_level": "AUTOCOMMIT",
11-
}
12-
)
13-
super(SQLiteAlchemy, self).apply_driver_hacks(app, info, options)
14-
15-
16-
db = SQLiteAlchemy()
174
flask_static_digest = FlaskStaticDigest()
18-
ma = Marshmallow()

backend/ttnn_visualizer/models.py

Lines changed: 151 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -1,181 +1,152 @@
1+
import dataclasses
2+
import enum
13
import json
2-
3-
from sqlalchemy import (
4-
Column,
5-
PrimaryKeyConstraint,
6-
Table,
7-
Integer,
8-
String,
9-
JSON,
10-
types,
11-
Text,
12-
Float,
13-
TypeDecorator,
14-
)
15-
from sqlalchemy.orm import relationship, Mapped, mapped_column
16-
17-
from ttnn_visualizer.extensions import db
18-
19-
operations = Table(
20-
"operations",
21-
db.metadata,
22-
Column("operation_id", Integer, primary_key=True),
23-
Column("name", String),
24-
Column("duration", Float),
25-
)
26-
27-
# TODO Ask about PK for this table (UUID for instance)
28-
operation_arguments = Table(
29-
"operation_arguments",
30-
db.metadata,
31-
Column("operation_id", db.ForeignKey("operations.operation_id")),
32-
Column("name", Text),
33-
Column("value", Text),
34-
PrimaryKeyConstraint("operation_id", "value", "name"),
35-
)
36-
37-
# TODO Ask about PK for this table (UUID for instance)
38-
input_tensors = Table(
39-
"input_tensors",
40-
db.metadata,
41-
Column("operation_id", db.ForeignKey("operations.operation_id")),
42-
Column("input_index", Integer),
43-
Column("tensor_id", db.ForeignKey("tensors.tensor_id")),
44-
PrimaryKeyConstraint("operation_id", "input_index", "tensor_id"),
45-
)
46-
47-
tensors = Table(
48-
"tensors",
49-
db.metadata,
50-
Column("tensor_id", Integer, primary_key=True),
51-
Column("shape", Text),
52-
Column("dtype", Text),
53-
Column("layout", Text),
54-
Column("memory_config", Text),
55-
Column("device_id", Integer),
56-
Column("address", Integer),
57-
Column("buffer_type", Integer),
58-
)
59-
60-
output_tensors = Table(
61-
"output_tensors",
62-
db.metadata,
63-
Column("operation_id", db.ForeignKey("operations.operation_id")),
64-
Column("output_index", Integer),
65-
Column("tensor_id", db.ForeignKey("tensors.tensor_id")),
66-
PrimaryKeyConstraint("operation_id", "output_index", "tensor_id"),
67-
)
68-
69-
stack_traces = Table(
70-
"stack_traces",
71-
db.metadata,
72-
Column("operation_id", db.ForeignKey("operations.operation_id")),
73-
Column("stack_trace", Text),
74-
PrimaryKeyConstraint("operation_id", "stack_trace"),
75-
)
76-
77-
buffers = Table(
78-
"buffers",
79-
db.metadata,
80-
Column("operation_id", db.ForeignKey("operations.operation_id")),
81-
Column("device_id", db.ForeignKey("devices.device_id")),
82-
Column("address", Integer),
83-
Column("max_size_per_bank", Integer),
84-
Column("buffer_type", Integer),
85-
PrimaryKeyConstraint("operation_id", "device_id", "address", "max_size_per_bank"),
86-
)
87-
88-
devices = Table(
89-
"devices",
90-
db.metadata,
91-
Column("device_id", Integer, primary_key=True),
92-
Column("num_y_cores", Integer),
93-
Column("num_x_cores", Integer),
94-
Column("num_y_compute_cores", Integer),
95-
Column("num_x_compute_cores", Integer),
96-
Column("worker_l1_size", Integer),
97-
Column("l1_num_banks", Integer),
98-
Column("l1_bank_size", Integer),
99-
Column("address_at_first_l1_bank", Integer),
100-
Column("address_at_first_l1_cb_buffer", Integer),
101-
Column("num_banks_per_storage_core", Integer),
102-
Column("num_compute_cores", Integer),
103-
Column("num_storage_cores", Integer),
104-
Column("total_l1_memory", Integer),
105-
Column("total_l1_for_tensors", Integer),
106-
Column("total_l1_for_interleaved_buffers", Integer),
107-
Column("total_l1_for_sharded_buffers", Integer),
108-
Column("cb_limit", Integer),
109-
)
110-
111-
device_operations = Table(
112-
"captured_graph",
113-
db.metadata,
114-
Column("operation_id", db.ForeignKey("operations.operation_id")),
115-
Column(
116-
"captured_graph",
117-
Text,
118-
),
119-
PrimaryKeyConstraint("operation_id", "captured_graph"),
120-
)
121-
122-
123-
class Device(db.Model):
124-
__table__ = devices
125-
126-
127-
class Tensor(db.Model):
128-
__table__ = tensors
129-
input_tensors = relationship("InputTensor", back_populates="tensor", lazy="joined")
130-
output_tensors = relationship(
131-
"OutputTensor", back_populates="tensor", lazy="joined"
132-
)
133-
134-
@property
135-
def producers(self):
136-
return [i.operation_id for i in self.output_tensors]
137-
138-
@property
139-
def consumers(self):
140-
return [i.operation_id for i in self.input_tensors]
141-
142-
143-
class Buffer(db.Model):
144-
__table__ = buffers
145-
device = relationship("Device")
146-
147-
148-
class InputTensor(db.Model):
149-
__table__ = input_tensors
150-
tensor = db.relationship(
151-
"Tensor", back_populates="input_tensors", innerjoin=True, lazy="joined"
152-
)
153-
154-
155-
class StackTrace(db.Model):
156-
__table__ = stack_traces
157-
158-
159-
class OutputTensor(db.Model):
160-
__table__ = output_tensors
161-
tensor = db.relationship(
162-
"Tensor", back_populates="output_tensors", innerjoin=True, lazy="joined"
163-
)
164-
165-
166-
class Operation(db.Model):
167-
__table__ = operations
168-
arguments = db.relationship("OperationArgument", lazy="joined")
169-
inputs = db.relationship("InputTensor", lazy="joined")
170-
outputs = db.relationship("OutputTensor", lazy="joined")
171-
stack_trace = db.relationship("StackTrace", lazy="joined")
172-
buffers = db.relationship("Buffer")
173-
device_operations = db.relationship("DeviceOperation", uselist=False, lazy="joined")
174-
175-
176-
class OperationArgument(db.Model):
177-
__table__ = operation_arguments
178-
179-
180-
class DeviceOperation(db.Model):
181-
__table__ = device_operations
4+
from json import JSONDecodeError
5+
6+
7+
class BufferType(enum.Enum):
    """Memory pool a buffer belongs to.

    The integer values mirror the codes stored in the SQLite
    ``buffer_type`` columns, so ``BufferType(raw_code)`` both validates
    and names a raw database value.
    """

    DRAM = 0
    L1 = 1
    SYSTEM_MEMORY = 2
    L1_SMALL = 3
    TRACE = 4
13+
14+
15+
@dataclasses.dataclass
class Operation:
    """One row of the ``operations`` table."""

    operation_id: int  # primary key
    name: str  # operation name as recorded by the tracer
    duration: float  # recorded duration (units not specified in source)
20+
21+
22+
@dataclasses.dataclass
class Device:
    """One row of the ``devices`` table: static properties of a device.

    Field order matches the table's column order so rows can be splatted
    straight into the constructor (``Device(*row)``).
    """

    device_id: int  # primary key
    # Core-grid dimensions.
    num_y_cores: int
    num_x_cores: int
    num_y_compute_cores: int
    num_x_compute_cores: int
    # L1 memory layout.
    worker_l1_size: int
    l1_num_banks: int
    l1_bank_size: int
    address_at_first_l1_bank: int
    address_at_first_l1_cb_buffer: int
    # Core/bank allocation counts.
    num_banks_per_storage_core: int
    num_compute_cores: int
    num_storage_cores: int
    # Aggregate L1 budgets.
    total_l1_memory: int
    total_l1_for_tensors: int
    total_l1_for_interleaved_buffers: int
    total_l1_for_sharded_buffers: int
    cb_limit: int
42+
43+
44+
@dataclasses.dataclass
class DeviceOperation:
    """Captured-graph record for an operation (``captured_graph`` table).

    ``captured_graph`` arrives as JSON text; ``__post_init__`` parses it
    into a list of node dicts, renaming each node's ``counter`` key to
    ``id`` (the key the frontend expects). Unparseable or NULL values
    degrade to an empty list instead of raising.
    """

    operation_id: int
    captured_graph: str  # JSON text on input; list[dict] after __post_init__

    def __post_init__(self):
        try:
            nodes = json.loads(self.captured_graph)
            for node in nodes:
                # Rename the tracer's "counter" key to "id".
                # NOTE(review): a node without "counter" raises KeyError,
                # as in the original — confirm the tracer always emits it.
                node["id"] = node.pop("counter")
            self.captured_graph = nodes
        except (JSONDecodeError, TypeError):
            # JSONDecodeError: malformed JSON text.
            # TypeError: non-string input, e.g. None from a NULL column —
            # previously uncaught; now degrades to an empty graph too.
            self.captured_graph = []
60+
61+
62+
@dataclasses.dataclass
class Buffer:
    """One row of the ``buffers`` table: a per-operation device buffer."""

    operation_id: int
    device_id: int
    address: int
    max_size_per_bank: int
    # Raw integer code on input; __post_init__ validates it against
    # BufferType and keeps the integer value (None passes through).
    buffer_type: BufferType

    def __post_init__(self):
        # Round-tripping through BufferType validates the code without
        # changing the serialized (integer) representation.
        if self.buffer_type is not None:
            self.buffer_type = BufferType(self.buffer_type).value
74+
75+
76+
@dataclasses.dataclass
class BufferPage:
    """A single page of a device buffer (per-core, per-bank granularity)."""

    operation_id: int
    device_id: int
    address: int
    core_y: int
    core_x: int
    bank_id: int
    page_index: int
    page_address: int
    page_size: int
    # Raw integer code on input; __post_init__ validates it against
    # BufferType and keeps the integer value (None passes through).
    buffer_type: BufferType

    def __post_init__(self):
        # Round-tripping through BufferType validates the code without
        # changing the serialized (integer) representation.
        if self.buffer_type is not None:
            self.buffer_type = BufferType(self.buffer_type).value
93+
94+
95+
@dataclasses.dataclass
class ProducersConsumers:
    """Operations that write (produce) and read (consume) a tensor."""

    tensor_id: int
    producers: list[int]  # operation_ids that output this tensor
    consumers: list[int]  # operation_ids that take this tensor as input
100+
101+
102+
@dataclasses.dataclass
class Tensor:
    """One row of the ``tensors`` table."""

    tensor_id: int  # primary key
    shape: str
    dtype: str
    layout: str
    memory_config: str
    device_id: int
    address: int
    # Raw integer code on input; __post_init__ validates it against
    # BufferType and keeps the integer value (None passes through).
    buffer_type: BufferType

    def __post_init__(self):
        # Round-tripping through BufferType validates the code without
        # changing the serialized (integer) representation.
        if self.buffer_type is not None:
            self.buffer_type = BufferType(self.buffer_type).value
117+
118+
119+
@dataclasses.dataclass
class InputTensor:
    """Join row linking an operation to one of its input tensors."""

    operation_id: int
    input_index: int  # position of the tensor in the operation's inputs
    tensor_id: int
124+
125+
126+
@dataclasses.dataclass
class OutputTensor:
    """Join row linking an operation to one of its output tensors."""

    operation_id: int
    output_index: int  # position of the tensor in the operation's outputs
    tensor_id: int
131+
132+
133+
@dataclasses.dataclass
class TensorComparisonRecord:
    """Result of comparing a tensor against its golden reference."""

    tensor_id: int
    golden_tensor_id: int
    matches: bool  # whether the comparison met the desired threshold
    # NOTE(review): was annotated ``bool`` but holds the desired PCC
    # threshold — a numeric value parallel to ``actual_pcc``. Annotation
    # fixed to ``float``; runtime behavior is unchanged (dataclass field
    # types are not enforced). Confirm against the producing query.
    desired_pcc: float
    actual_pcc: float
140+
141+
142+
@dataclasses.dataclass
class OperationArgument:
    """One named argument of an operation (``operation_arguments`` table)."""

    operation_id: int
    name: str  # argument name
    value: str  # argument value, stored as text
147+
148+
149+
@dataclasses.dataclass
class StackTrace:
    """Stack trace captured for an operation (``stack_traces`` table)."""

    operation_id: int
    stack_trace: str  # raw stack-trace text

0 commit comments

Comments
 (0)