Skip to content

Commit

Permalink
Feature/improve operation query speed (#60)
Browse files Browse the repository at this point in the history
### Description
Changes the joins from operation to related models to be eagerly joined.
This was to address the performance impact of dynamically evaluating the
related fields from the main `operations` query. To avoid making another
query to resolve related attributes we fetch the related records through
the join eagerly.

### Notes
Performance of this query still seems like it could be faster but it is
a rather complex operation that reaches into most of the models in the
database.
  • Loading branch information
GregHattJr authored Aug 15, 2024
2 parents 31a99a7 + 738f576 commit 85370bf
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 32 deletions.
13 changes: 11 additions & 2 deletions backend/extensions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from flask_marshmallow import Marshmallow
from flask_sqlalchemy import SQLAlchemy
from flask_static_digest import FlaskStaticDigest
from flask_marshmallow import Marshmallow

db = SQLAlchemy()

class SQLiteAlchemy(SQLAlchemy):
def apply_driver_hacks(self, app, info, options):
options.update({
'isolation_level': 'AUTOCOMMIT',
})
super(SQLiteAlchemy, self).apply_driver_hacks(app, info, options)


db = SQLiteAlchemy()
flask_static_digest = FlaskStaticDigest()
ma = Marshmallow()
26 changes: 9 additions & 17 deletions backend/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Text,
Float,
)
from sqlalchemy.orm import relationship

from backend.extensions import db

Expand Down Expand Up @@ -109,22 +110,18 @@ class Device(db.Model):

class Tensor(db.Model):
__table__ = tensors
input_tensors = relationship("InputTensor", back_populates="tensor", lazy="joined")
output_tensors = relationship("OutputTensor", back_populates="tensor", lazy="joined")

def producers(self):
return [c.operation_id for c in OutputTensor.query.filter_by(tensor_id=self.tensor_id)]

def consumers(self):
return [c.operation_id for c in InputTensor.query.filter_by(tensor_id=self.tensor_id)]


class Buffer(db.Model):
__table__ = buffers
device = db.relationship("Device")


class InputTensor(db.Model):
__table__ = input_tensors
tensor = db.relationship("Tensor", backref="input")
tensor = db.relationship("Tensor", lazy="joined", back_populates="input_tensors")


class StackTrace(db.Model):
Expand All @@ -133,21 +130,16 @@ class StackTrace(db.Model):

class OutputTensor(db.Model):
__table__ = output_tensors
tensor = db.relationship("Tensor", backref="output")
tensor = db.relationship("Tensor", lazy="joined", back_populates="output_tensors")


class Operation(db.Model):
__table__ = operations
arguments = db.relationship("OperationArgument", backref="operation")
inputs = db.relationship("InputTensor", backref="operation")
outputs = db.relationship("OutputTensor", backref="operation")
buffers = db.relationship("Buffer", backref="operation")
stack_trace = db.relationship("StackTrace", backref="operation")
arguments = db.relationship("OperationArgument", lazy="joined")
inputs = db.relationship("InputTensor", lazy="joined")
outputs = db.relationship("OutputTensor", lazy="joined")
stack_trace = db.relationship("StackTrace")


class OperationArgument(db.Model):
__table__ = operation_arguments




25 changes: 14 additions & 11 deletions backend/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ class Meta:
shape = ma.auto_field()
address = ma.auto_field()
id = ma.Function(lambda obj: obj.tensor_id, dump_only=True)
producers = ma.Function(lambda obj: obj.producers(), dump_only=True)
consumers = ma.Function(lambda obj: obj.consumers(), dump_only=True)


class StackTraceSchema(ma.SQLAlchemyAutoSchema):
Expand All @@ -33,18 +31,23 @@ class Meta:


class InputOutputSchema(object):
# TODO - We can probably create a model to avoid backrefs
id = ma.Function(lambda obj: obj.tensor.tensor_id)
operation_id = ma.auto_field()
shape = ma.Function(lambda obj: obj.tensor.shape)
address = ma.Function(lambda obj: obj.tensor.address)
layout = ma.Function(lambda obj: obj.tensor.layout)
memory_config = ma.Function(lambda obj: obj.tensor.memory_config)
device_id = ma.Function(lambda obj: obj.tensor.device_id)
buffer_type = ma.Function(lambda obj: obj.tensor.buffer_type)
dtype = ma.Function(lambda obj: obj.tensor.dtype)
producers = ma.Function(lambda obj: obj.tensor.producers())
consumers = ma.Function(lambda obj: obj.tensor.consumers())
id = ma.Function(lambda obj: obj.tensor.tensor_id, dump_only=True)
operation_id = ma.auto_field()
consumers = fields.Method("get_consumers")
producers = fields.Method("get_producers")

def get_producers(self, obj):
return [ot.operation_id for ot in obj.tensor.output_tensors]

def get_consumers(self, obj):
return [it.operation_id for it in obj.tensor.input_tensors]


class OutputTensorSchema(ma.SQLAlchemyAutoSchema, InputOutputSchema):
Expand Down Expand Up @@ -89,10 +92,10 @@ class Meta:
id = ma.Function(lambda obj: obj.operation_id)
name = ma.auto_field()
duration = ma.auto_field()
buffers = ma.List(ma.Nested(BufferSchema))
outputs = ma.List(ma.Nested(OutputTensorSchema))
inputs = ma.List(ma.Nested(InputTensorSchema))
arguments = ma.List(ma.Nested(OperationArgumentsSchema))
buffers = ma.List(ma.Nested(BufferSchema()))
outputs = ma.List(ma.Nested(OutputTensorSchema()))
inputs = ma.List(ma.Nested(InputTensorSchema()))
arguments = ma.List(ma.Nested(OperationArgumentsSchema()))


# Filesystem Schemas
Expand Down
2 changes: 1 addition & 1 deletion backend/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Config(object):


class DevelopmentConfig(Config):
SQLALCHEMY_ECHO = True
pass


class TestingConfig(Config):
Expand Down
20 changes: 19 additions & 1 deletion backend/utils.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,20 @@
import logging
from functools import wraps
from timeit import default_timer

logger = logging.getLogger(__name__)

def str_to_bool(string_value):
return string_value.lower() in ("yes", "true", "t", "1")
return string_value.lower() in ("yes", "true", "t", "1")


def timer(f):
@wraps(f)
def wrapper(*args, **kwargs):
start_time = default_timer()
response = f(*args, **kwargs)
total_elapsed_time = default_timer() - start_time
logger.info(f"Elapsed time: {total_elapsed_time}")
return response

return wrapper
2 changes: 2 additions & 0 deletions backend/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
OperationSchema,
TensorSchema,
)
from backend.utils import timer

logger = logging.getLogger(__name__)

Expand All @@ -36,6 +37,7 @@ def health_check():


@api.route("/operations", methods=["GET"])
@timer
def operation_list():
operations = Operation.query.all()
return OperationSchema(
Expand Down

0 comments on commit 85370bf

Please sign in to comment.