Skip to content

Commit

Permalink
Refactor producer/consumer logic, add to operations (#52)
Browse files Browse the repository at this point in the history
# Description

- Refactors the way we're attaching producers and consumers
- Adds input/output tensors to operation list
  • Loading branch information
GregHattJr authored Aug 8, 2024
2 parents 881a8be + 3168584 commit fa3228a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 28 deletions.
12 changes: 9 additions & 3 deletions backend/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
PrimaryKeyConstraint("operation_id", "output_index", "tensor_id"),
)


stack_traces = Table(
"stack_traces",
db.metadata,
Expand All @@ -69,7 +68,6 @@
PrimaryKeyConstraint("operation_id", "stack_trace")
)


buffers = Table(
"buffers",
db.metadata,
Expand Down Expand Up @@ -104,14 +102,20 @@
Column("cb_limit", Integer),
)


class Device(db.Model):
__table__ = devices



class Tensor(db.Model):
__table__ = tensors

def producers(self):
return [c.operation_id for c in OutputTensor.query.filter_by(tensor_id=self.tensor_id)]

def consumers(self):
return [c.operation_id for c in InputTensor.query.filter_by(tensor_id=self.tensor_id)]


class Buffer(db.Model):
__table__ = buffers
Expand All @@ -126,11 +130,13 @@ class InputTensor(db.Model):
class StackTrace(db.Model):
__table__ = stack_traces


class OutputTensor(db.Model):
__table__ = output_tensors
tensor = db.relationship("Tensor", backref="output")



class Operation(db.Model):
__table__ = operations
arguments = db.relationship("OperationArgument", backref="operation")
Expand Down
11 changes: 7 additions & 4 deletions backend/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional

from marshmallow import fields, validates

from backend.extensions import ma
Expand All @@ -19,15 +17,20 @@
class TensorSchema(ma.SQLAlchemyAutoSchema):
class Meta:
model = Tensor

shape = ma.auto_field()
address = ma.auto_field()
consumers = fields.List(fields.Integer, default=[])
producers = fields.List(fields.Integer, default=[])


class StackTraceSchema(ma.SQLAlchemyAutoSchema):
class Meta:
model = StackTrace

stack_trace = ma.Function(lambda obj: obj.stack_trace or "")


class InputOutputSchema(object):
# TODO - We can probably create a model to avoid backrefs
shape = ma.Function(lambda obj: obj.tensor.shape)
Expand All @@ -38,8 +41,8 @@ class InputOutputSchema(object):
buffer_type = ma.Function(lambda obj: obj.tensor.buffer_type)
dtype = ma.Function(lambda obj: obj.tensor.dtype)
tensor_id = ma.Function(lambda obj: obj.tensor.tensor_id)
consumers = fields.List(fields.Integer, default=[])
producers = fields.List(fields.Integer, default=[])
producers = ma.Function(lambda obj: obj.tensor.producers())
consumers = ma.Function(lambda obj: obj.tensor.consumers())
operation_id = ma.auto_field()


Expand Down
27 changes: 6 additions & 21 deletions backend/views.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from http import HTTPStatus
import json
from pathlib import Path
import shutil
from http import HTTPStatus
from pathlib import Path

from flask import Blueprint, Response, current_app, request

from backend.models import (
Device,
Operation,
Buffer,
InputTensor,
OutputTensor,
Tensor,
StackTrace,
)
Expand Down Expand Up @@ -47,23 +45,11 @@ def operation_list():
many=True,
exclude=[
"buffers",
"input_tensors",
"output_tensors",
"operation_id",
],
).dump(operations)


def attach_producers_consumers(t: Tensor):
t.consumers = [
c.operation_id for c in InputTensor.query.filter_by(tensor_id=t.tensor_id)
]
t.producers = [
c.operation_id for c in OutputTensor.query.filter_by(tensor_id=t.tensor_id)
]
return t


@api.route("/operations/<operation_id>", methods=["GET"])
def operation_detail(operation_id):
operation = Operation.query.get(operation_id)
Expand All @@ -80,10 +66,10 @@ def operation_detail(operation_id):
stack_trace_dump = StackTraceSchema().dump(stack_trace, many=False)
stack_trace_value = stack_trace_dump.get("stack_trace")
input_tensors = InputTensorSchema().dump(
map(attach_producers_consumers, operation.input_tensors), many=True
operation.input_tensors, many=True
)
output_tensors = OutputTensorSchema().dump(
map(attach_producers_consumers, operation.output_tensors), many=True
operation.output_tensors, many=True
)

return dict(
Expand Down Expand Up @@ -125,7 +111,7 @@ def get_config():

@api.route("/tensors", methods=["GET"])
def get_tensors():
tensors = map(attach_producers_consumers, Tensor.query.all())
tensors = Tensor.query.all()
return TensorSchema().dump(tensors, many=True)


Expand All @@ -134,7 +120,7 @@ def get_tensor(tensor_id):
tensor = Tensor.query.get(tensor_id)
if not tensor:
return Response(status=HTTPStatus.NOT_FOUND)
return TensorSchema().dump(attach_producers_consumers(tensor))
return TensorSchema().dump(tensor)


@api.route(
Expand All @@ -146,7 +132,6 @@ def get_tensor(tensor_id):
def create_upload_files():
"""
Copies the folder upload into the active data directory
:param files:
:return:
"""
files = request.files.getlist("files")
Expand Down

0 comments on commit fa3228a

Please sign in to comment.