Skip to content

Generator enhancements #674

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/reference-docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,10 @@ product_blocks:
-
```

In this section we define the product block(s) that are part of this product. The first
product block will automatically be the root product block.
In this section we define the product block(s) that are part of this product.

**A product configuration should contain exactly 1 root product block.** This means there should be 1 product block that is not used by any other product blocks within this product.
If the configuration does contain multiple root blocks, or none at all due to a cyclic dependency, then the generator will raise a helpful error.

The use of `name`, `type`, `tag` and `description` in the product block definition is equivalent
to the product definition above. The `fields` describe the product block resource types.
Expand Down
6 changes: 3 additions & 3 deletions orchestrator/cli/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def merge(


@app.command()
def upgrade(revision: str = typer.Argument(default=None, help="Rev id to upgrade to")) -> None:
def upgrade(revision: str = typer.Argument(help="Rev id to upgrade to")) -> None:
"""The `upgrade` command will upgrade the database to the specified revision.

Args:
Expand Down Expand Up @@ -228,7 +228,7 @@ def revision(
@app.command()
def history(
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
indicate_current: bool = typer.Option(False, "--current", "-c", help="Indicate current revision"),
indicate_current: bool = typer.Option(True, "--current", "-c", help="Indicate current revision"),
) -> None:
"""The `history` command lists Alembic revision history/changeset scripts in chronological order.

Expand Down Expand Up @@ -323,7 +323,7 @@ def migrate_workflows(
message: str = typer.Argument(..., help="Migration name"),
test: bool = typer.Option(False, help="Optional boolean if you don't want to generate a migration file"),
) -> tuple[list[dict], list[dict]] | None:
"""The `migrate-workflows` commanad creates a migration file based on the difference between workflows in the database and registered WorkflowInstances in your codebase.
"""The `migrate-workflows` command creates a migration file based on the difference between workflows in the database and registered WorkflowInstances in your codebase.

!!! warning "BACKUP YOUR DATABASE BEFORE USING THE MIGRATION!"

Expand Down
118 changes: 114 additions & 4 deletions orchestrator/cli/generator/generator/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator
import inspect
from collections.abc import Generator, Iterable
from importlib import import_module
from os import listdir, path
from pathlib import Path
from typing import Any

import structlog
from more_itertools import first
from more_itertools import first, one

from orchestrator.cli.generator.generator.enums import to_dict
from orchestrator.cli.generator.generator.settings import product_generator_settings as settings
from orchestrator.domain.base import ProductBlockModel
from orchestrator.utils.helpers import camel_to_snake, snake_to_camel

logger = structlog.getLogger(__name__)
Expand All @@ -44,10 +49,115 @@ def get_product_block_file_name(product_block: dict) -> str:
return get_product_block_variable(product_block)


def get_existing_product_blocks() -> dict[str, Any]:
"""Inspect the python code for existing product blocks."""

def yield_blocks() -> Generator:
def is_product_block(attribute: Any) -> bool:
return issubclass(attribute, ProductBlockModel)

if not path.exists(get_product_blocks_folder()):
logger.warning("Product block path does not exist", product_blocks_path=get_product_blocks_folder())
return

for pb_file in listdir(get_product_blocks_folder()):
name = pb_file.removesuffix(".py")
module_name = f"{get_product_blocks_module()}.{name}"

module = import_module(module_name)

classes = [obj for _, obj in inspect.getmembers(module, inspect.isclass) if obj.__module__ == module_name]

yield from ((klass.__name__, module_name) for klass in classes if is_product_block(klass))

return dict(yield_blocks())


def get_product_block_depends_on(
product_blocks: list[dict], include_existing_blocks: bool = False
) -> dict[str, set[str]]:
_product_block_types = {block["type"] for block in product_blocks}

def base_type(block_name: str) -> str:
block_type, _lifecycle = block_name.rsplit("Block", maxsplit=1)
return block_type

if include_existing_blocks:
existing_blocks = {base_type(block) for block in get_existing_product_blocks()}
_product_block_types.update(existing_blocks)

def dependencies(product_block: dict) -> Iterable[str]:
"""Find all product blocks which this product block depends on."""
for field in product_block.get("fields", []):
field_type = field.get("list_type", field["type"])
if field_type in _product_block_types:
yield field_type

return {product_block["type"]: set(dependencies(product_block)) for product_block in product_blocks}


def find_root_product_block(product_blocks: list[dict]) -> str | None:
block_dependencies = get_product_block_depends_on(product_blocks)

blocks_in_use = set().union(*block_dependencies.values())
root_blocks = block_dependencies.keys() - blocks_in_use
return one(
root_blocks,
too_short=ValueError(
"There should be exactly 1 root product block, found none. Please ensure there are no cyclic relations"
),
too_long=ValueError(f"There should be exactly 1 root product block, found multiple: {root_blocks}"),
)


def root_product_block(config: dict) -> dict:
product_blocks = config.get("product_blocks", [])
# TODO: multiple product_blocks will need more logic, ok for now
return product_blocks[0]
root_block_name = find_root_product_block(config.get("product_blocks", []))
return one(block for block in product_blocks if block["type"] == root_block_name)


def sort_product_blocks_by_dependencies(product_blocks: list[dict]) -> list[dict]:
"""Perform a 'Topological Sort' on the list of product blocks.

This ensures that a product's blocks are created bottom-up and that there is no cycle.
"""
block_dependencies = get_product_block_depends_on(product_blocks)

block_order: dict[str, int] = {}
order = 0
while block_dependencies:
cycle = True

for block, depends_on_blocks in list(block_dependencies.items()):
if depends_on_blocks - block_order.keys():
# Not all dependent blocks are resolved yet
continue
cycle = False
block_order[block] = order
order += 1
del block_dependencies[block]

if cycle:
raise ValueError(f"Cycle detected in product blocks: {block_dependencies}")

return sorted(product_blocks, key=lambda block: block_order[block["type"]])


def set_resource_types(product_blocks: list[dict], block_dependencies: dict[str, set[str]]) -> list[dict]:
"""Returns product blocks enriched with a list 'resource_types'.

Args:
product_blocks: product blocks to enrich
block_dependencies: mapping of product blocks to dependent blocks
"""

def resource_type_fields(product_block: dict) -> Iterable[dict]:
for field in product_block["fields"]:
field_type = field.get("list_type", field["type"])
if field_type not in block_dependencies[product_block["type"]]:
yield field

return [(block | {"resource_types": list(resource_type_fields(block))}) for block in product_blocks]


def insert_into_imports(content: list[str], new_import: str) -> list[str]:
Expand Down
64 changes: 42 additions & 22 deletions orchestrator/cli/generator/generator/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@
from alembic.util import rev_id
from jinja2 import Environment

from orchestrator.cli.generator.generator.helpers import get_product_types_module
from orchestrator.cli.generator.generator.helpers import (
find_root_product_block,
get_product_block_depends_on,
get_product_types_module,
set_resource_types,
sort_product_blocks_by_dependencies,
)
from orchestrator.cli.generator.generator.settings import product_generator_settings as settings

logger = structlog.getLogger(__name__)
Expand Down Expand Up @@ -134,25 +140,39 @@
if "data" not in heads:
create_data_head(context=context, depends_on=heads["schema"])

if migration_file := create_migration_file(message=f"add {config['name']}", head="data"):
if revision_info := get_revision_info(migration_file):
if "fixed_inputs" in config and config["fixed_inputs"]:
fixed_input_values = [
[(fixed_input["name"], str(value)) for value in fixed_input["values"]]
for fixed_input in config["fixed_inputs"]
if "values" in fixed_input
]
fixed_input_combinations = list(itertools.product(*fixed_input_values))
product_variants = [
(" ".join([config["name"]] + [value for name, value in combination]), combination)
for combination in fixed_input_combinations
]
else:
product_variants = [((config["name"]), ())]
template = environment.get_template("new_product_migration.j2")
content = template.render(product=config, product_variants=product_variants, **revision_info)
writer(migration_file, content)
update_subscription_model_registry(environment, config, product_variants, writer)

else:
if not (migration_file := create_migration_file(message=f"add {config['name']}", head="data")):
logger.error("Could not create migration file")
return

Check warning on line 145 in orchestrator/cli/generator/generator/migration.py

View check run for this annotation

Codecov / codecov/patch

orchestrator/cli/generator/generator/migration.py#L145

Added line #L145 was not covered by tests

if not (revision_info := get_revision_info(migration_file)):
logger.error("Could not get revision info from migration file", migration_file=migration_file)
return

Check warning on line 149 in orchestrator/cli/generator/generator/migration.py

View check run for this annotation

Codecov / codecov/patch

orchestrator/cli/generator/generator/migration.py#L148-L149

Added lines #L148 - L149 were not covered by tests

if fixed_inputs := config.get("fixed_inputs"):
fixed_input_values = [
[(fixed_input["name"], str(value)) for value in fixed_input["values"]]
for fixed_input in fixed_inputs
if "values" in fixed_input
]
fixed_input_combinations = list(itertools.product(*fixed_input_values))
product_variants = [
(" ".join([config["name"]] + [value for name, value in combination]), combination)
for combination in fixed_input_combinations
]
else:
product_variants = [((config["name"]), ())]

# Add depends_on_block_relations, sort the product blocks, set resource types and the root block
product_blocks = sort_product_blocks_by_dependencies(config.get("product_blocks", []))
block_depends_on = get_product_block_depends_on(product_blocks, include_existing_blocks=True)
config["root_product_block"] = find_root_product_block(product_blocks)
product_blocks = set_resource_types(product_blocks, block_depends_on)

for block in product_blocks:
block["depends_on_blocks"] = block_depends_on[block["type"]]

template = environment.get_template("new_product_migration.j2")
config["product_blocks"] = product_blocks
content = template.render(product=config, product_variants=product_variants, **revision_info)
writer(migration_file, content)
update_subscription_model_registry(environment, config, product_variants, writer)
5 changes: 3 additions & 2 deletions orchestrator/cli/generator/generator/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
get_product_types_folder,
get_product_types_module,
path_to_module,
root_product_block,
)
from orchestrator.cli.generator.generator.settings import product_generator_settings as settings

Expand All @@ -29,7 +30,7 @@ def generate_product(context: dict) -> None:

product = config["type"]
fixed_inputs = config.get("fixed_inputs", [])
product_blocks = config.get("product_blocks", [])
root_block = root_product_block(config)

non_standard_fixed_inputs = get_non_standard_fields(fixed_inputs)
int_enums = get_int_enums(fixed_inputs)
Expand All @@ -41,7 +42,7 @@ def generate_product(context: dict) -> None:
product_blocks_module=path_to_module(settings.FOLDER_PREFIX / settings.PRODUCT_BLOCKS_PATH),
product_types_module=get_product_types_module(),
non_standard_fixed_inputs=non_standard_fixed_inputs,
product_blocks=product_blocks,
root_block=root_block,
int_enums=int_enums,
str_enums=str_enums,
fixed_inputs=(to_dict(fixed_inputs) | to_dict(int_enums) | to_dict(str_enums)).values(),
Expand Down
Loading
Loading