Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 111 additions & 142 deletions core/services/workflow_service/controllers/compute_block_controller.py

Large diffs are not rendered by default.

42 changes: 22 additions & 20 deletions core/services/workflow_service/controllers/project_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,17 @@

from fastapi import HTTPException
from services.workflow_service.models.project import Project
import services.workflow_service.controllers.compute_block_controller as \
compute_block_controller
import services.workflow_service.controllers.template_controller as \
template_controller
from services.workflow_service.schemas.workflow import (
WorkflowTemplate
from services.workflow_service.controllers import (
compute_block_controller,
template_controller,
)
from services.workflow_service.schemas.workflow import WorkflowTemplate


def create_project(db: Session, name: str, current_user_uuid: UUID) -> UUID:
logging.debug(f"Creating project with name: {
name} for user: {current_user_uuid}")
logging.debug(
f"Creating project with name: {name} for user: {current_user_uuid}"
)

project: Project = Project()

Expand All @@ -33,20 +32,22 @@ def create_project(db: Session, name: str, current_user_uuid: UUID) -> UUID:


def create_project_from_template(
name: str,
template_identifier: str,
current_user_uuid: UUID
name: str,
template_identifier: str,
current_user_uuid: UUID,
project_name: str,
) -> UUID:
"""
This method will handle the creation of project, blocks and edges as
defined in the template.yaml
"""
db: Session = next(get_database())

template: WorkflowTemplate =\
template: WorkflowTemplate = (
template_controller.get_workflow_template_by_identifier(
template_identifier
)
)
required_blocks = template_controller.extract_block_urls_from_template(
template
)
Expand All @@ -59,16 +60,19 @@ def create_project_from_template(
try:
with db.begin():
project_id = create_project(db, name, current_user_uuid)
block_name_to_model, block_outputs_by_name, block_inputs_by_name =\
template_controller.configure_and_create_blocks(
G, db, unconfigured_blocks, project_id
)
(
block_name_to_model,
block_outputs_by_name,
block_inputs_by_name,
) = template_controller.configure_and_create_blocks(
G, db, unconfigured_blocks, project_id, project_name
)
template_controller.create_edges_from_template(
G,
db,
block_name_to_model,
block_outputs_by_name,
block_inputs_by_name
block_inputs_by_name,
)
return project_id
except Exception as e:
Expand Down Expand Up @@ -183,9 +187,7 @@ def read_projects_by_user_uuid(user_uuid: UUID) -> list[Project]:
db: Session = next(get_database())

projects = (
db.query(Project)
.filter(Project.users.contains([user_uuid]))
.all()
db.query(Project).filter(Project.users.contains([user_uuid])).all()
)

if not projects:
Expand Down
105 changes: 56 additions & 49 deletions core/services/workflow_service/controllers/template_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,19 @@
from scystream.sdk.config.models import (
ComputeBlock,
Entrypoint as SDKEntrypoint,
InputOutputModel
)
from services.workflow_service.models.block import (
Block
InputOutputModel,
)
from services.workflow_service.models.block import Block
from services.workflow_service.models.input_output import (
InputOutput, InputOutputType, DataType
InputOutput,
InputOutputType,
DataType,
)
from services.workflow_service.controllers.compute_block_controller import (
updated_configs_with_values, do_config_keys_match, create_compute_block,
create_stream_and_update_target_cfg
updated_configs_with_values,
do_config_keys_match,
create_compute_block,
create_stream_and_update_target_cfg,
)
from utils.config.defaults import (
get_file_cfg_defaults_dict,
Expand All @@ -47,16 +49,15 @@ def get_workflow_template_by_identifier(identifier: str) -> WorkflowTemplate:
"""
try:
registry = RepoRegistry()
repo_path = registry.get_repo(
ENV.WORKFLOW_TEMPLATE_REPO
)
repo_path = registry.get_repo(ENV.WORKFLOW_TEMPLATE_REPO)

file_path = os.path.join(repo_path, identifier)
if not os.path.isfile(file_path):
raise HTTPException(
status_code=404,
detail=f"Template file '{
identifier}' not found in repository."
detail=(
f"Template file '{identifier}' not found in repository."
),
)

with open(file_path, "r") as f:
Expand All @@ -66,18 +67,15 @@ def get_workflow_template_by_identifier(identifier: str) -> WorkflowTemplate:
except ValidationError as ve:
logging.warning(f"Validation failed for {identifier}: {ve}")
raise HTTPException(
status_code=422,
detail=f"Template validation failed: {ve}"
status_code=422, detail=f"Template validation failed: {ve}"
)


def get_workflow_templates() -> list[WorkflowTemplate]:
templates: WorkflowTemplate = []

registry = RepoRegistry()
repo_path = registry.get_repo(
ENV.WORKFLOW_TEMPLATE_REPO
)
repo_path = registry.get_repo(ENV.WORKFLOW_TEMPLATE_REPO)

for file in os.listdir(repo_path):
if not file.endswith((".yaml", ".yml")):
Expand Down Expand Up @@ -134,11 +132,12 @@ def build_workflow_graph(template: WorkflowTemplate):
from_block,
to_block,
input_identifier=inp.identifier,
output_identifier=inp.depends_on.output
output_identifier=inp.depends_on.output,
)
if not nx.is_directed_acyclic_graph(G):
raise HTTPException(
status_code=422, detail="Template defines a cyclic dependency.")
status_code=422, detail="Template defines a cyclic dependency."
)

# Assigning Positions
level_map = {}
Expand Down Expand Up @@ -167,7 +166,9 @@ def _build_io(
data_type: DataType,
description: str,
config: dict,
template_settings: dict | None = None
project_name: str,
block_name: str,
template_settings: dict | None = None,
) -> InputOutput:
"""
Constructs an InputOutput object, applying default values for outputs
Expand All @@ -179,14 +180,14 @@ def _build_io(
name=identifier,
data_type=data_type,
description=description,
config=config
config=config,
)

if io_type is InputOutputType.OUTPUT:
default_values = (
get_file_cfg_defaults_dict(identifier)
if data_type is DataType.FILE
else get_pg_cfg_defaults_dict(identifier)
else get_pg_cfg_defaults_dict(project_name, identifier, block_name)
)
io.config = updated_configs_with_values(io, default_values, data_type)

Expand All @@ -199,7 +200,9 @@ def _build_io(
def _configure_io_items(
template_ios: list[InputTemplate] | list[OutputTemplate],
unconfigured_ios: dict[str, InputOutputModel],
io_type: InputOutputType
io_type: InputOutputType,
project_name: str,
block_name: str,
) -> list[InputOutput]:
"""
Iterates over the compute blocks ios.
Expand Down Expand Up @@ -229,9 +232,10 @@ def _configure_io_items(
status_code=421,
detail=(
f"The keys used in the template to configure IO '{
template.identifier}' "
template.identifier
}' "
f"do not match those in the compute block definition."
)
),
)

configured.append(
Expand All @@ -241,7 +245,9 @@ def _configure_io_items(
data_type=data_type,
description=unconfigured_io.description,
config=unconfigured_io.config,
template_settings=template.settings
template_settings=template.settings,
project_name=project_name,
block_name=block_name,
)
)
else:
Expand All @@ -251,7 +257,9 @@ def _configure_io_items(
io_type=io_type,
data_type=data_type,
description=unconfigured_io.description,
config=unconfigured_io.config
config=unconfigured_io.config,
project_name=project_name,
block_name=block_name,
)
)

Expand All @@ -260,12 +268,10 @@ def _configure_io_items(

def _configure_block(
block_template: BlockTemplate,
unconfigured_entry: SDKEntrypoint
) -> (
ConfigType,
list[InputOutput],
list[InputOutput]
):
unconfigured_entry: SDKEntrypoint,
project_name: str,
block_name: str,
) -> (ConfigType, list[InputOutput], list[InputOutput]):
"""
This method returns:
:dict: the configuration from the template applied to the configuration
Expand All @@ -291,19 +297,23 @@ def _configure_block(
The Config-Keys provided by the template
do not match with the configs that the block
{block_template.name} offers.
"""
""",
)

configured_inputs: list[InputOutput] = _configure_io_items(
block_template.inputs or [],
unconfigured_entry.inputs or {},
InputOutputType.INPUT
InputOutputType.INPUT,
project_name,
block_name,
)

configured_outputs: list[InputOutput] = _configure_io_items(
block_template.outputs or [],
unconfigured_entry.outputs or {},
InputOutputType.OUTPUT
InputOutputType.OUTPUT,
project_name,
block_name,
)

return (configured_envs, configured_inputs, configured_outputs)
Expand All @@ -313,11 +323,10 @@ def configure_and_create_blocks(
G: nx.DiGraph,
db: Session,
unconfigured_blocks: dict[str, ComputeBlock],
project_id: UUID
project_id: UUID,
project_name: str,
) -> tuple[
dict[str, Block],
dict[str, dict[str, UUID]],
dict[str, dict[str, UUID]]
dict[str, Block], dict[str, dict[str, UUID]], dict[str, dict[str, UUID]]
]:
"""
Configures and creates blocks defined in the template graph.
Expand All @@ -344,26 +353,24 @@ def configure_and_create_blocks(
for block_name in nx.topological_sort(G):
block_template = G.nodes[block_name]["block"]
# 1. Validate wether Template Definition of Compute Block is correct
compute_block = unconfigured_blocks.get(
block_template.repo_url
)
compute_block = unconfigured_blocks.get(block_template.repo_url)
if compute_block is None:
raise HTTPException(
status_code=422,
detail=f"Block repo '{block_template.repo_url}' not found."
detail=f"Block repo '{block_template.repo_url}' not found.",
)

entrypoint = compute_block.entrypoints.get(block_template.entrypoint)
if entrypoint is None:
raise HTTPException(
status_code=422,
detail=f"Entrypoint '{block_template.entrypoint}' not found in\
block '{block_template.name}'."
block '{block_template.name}'.",
)

# 2. Configure the Block
configured_envs, inputs, outputs = _configure_block(
block_template, entrypoint
block_template, entrypoint, project_name, compute_block.name
)

# 3. Create the Block
Expand All @@ -383,7 +390,7 @@ def configure_and_create_blocks(
envs=configured_envs,
inputs=inputs,
outputs=outputs,
project_id=project_id
project_id=project_id,
)

# 4. Create the maps that "connect" template to database representation
Expand Down Expand Up @@ -426,13 +433,13 @@ def create_edges_from_template(
Dependency resolution failed for edge:
"{from_block} -> {to_block}
"({output_identifier}-> {input_identifier})
"""
""",
)

create_stream_and_update_target_cfg(
db,
upstream_block.uuid,
output_uuid,
downstream_block.uuid,
input_uuid
input_uuid,
)
Loading
Loading