Source code for mltrace.db.models

from __future__ import annotations
from datetime import datetime
from sqlalchemy.sql.schema import UniqueConstraint

from sqlalchemy.sql.sqltypes import Boolean
from mltrace.db.base import Base
from sqlalchemy import (
    Column,
    String,
    LargeBinary,
    Integer,
    DateTime,
    Table,
    ForeignKey,
    Enum,
    PickleType,
    UniqueConstraint,
)
from sqlalchemy.orm import relationship, backref
from sqlalchemy.schema import ForeignKeyConstraint

import enum
import typing


class PointerTypeEnum(str, enum.Enum):
    """Category of artifact that an IOPointer refers to.

    Because the enum also subclasses ``str``, every member compares equal
    to its plain string value (e.g. ``PointerTypeEnum.DATA == "DATA"``).
    """

    DATA = "DATA"
    MODEL = "MODEL"
    ENDPOINT = "ENDPOINT"
    UNKNOWN = "UNKNOWN"
# Many-to-many link table between components and tags: each row pairs a
# ``components.name`` with a ``tags.name``.
component_tag_association = Table(
    "component_tags",
    Base.metadata,
    Column("component_name", String, ForeignKey("components.name")),
    Column("tag_name", String, ForeignKey("tags.name")),
)
class Component(Base):
    """A named, owned pipeline component that groups many ComponentRuns.

    Stored in the ``components`` table; tags are linked through the
    ``component_tags`` association table.
    """

    __tablename__ = "components"

    name = Column(String, primary_key=True)
    description = Column(String)
    owner = Column(String)
    # "delete-orphan": a run removed from this collection is deleted, and
    # "all" cascades deletion of the component to its runs.
    component_runs = relationship(
        "ComponentRun", cascade="all, delete-orphan"
    )
    tags = relationship(
        "Tag", secondary=component_tag_association, cascade="all"
    )

    def __init__(
        self,
        name: str,
        description: str,
        owner: str,
        tags: typing.Optional[typing.List[Tag]] = None,
    ):
        """Create a component.

        Args:
            name: Unique component name (primary key).
            description: Free-form description.
            owner: Owner of the component.
            tags: Initial tags; defaults to no tags.
        """
        self.name = name
        self.description = description
        self.owner = owner
        # Copy into a fresh list so neither the caller's list nor a shared
        # default object is aliased by this instance.
        self.tags = list(tags) if tags is not None else []

    def add_tags(self, tags: typing.List[Tag]):
        """Attach ``tags`` to this component, skipping ones already present.

        Existing tags keep their order; newly added tags follow in the order
        given. Duplicates are detected by the objects' hash/equality.
        """
        # dict.fromkeys dedupes while preserving first-seen order, unlike a
        # set round-trip whose ordering is arbitrary.
        self.tags = list(dict.fromkeys(list(self.tags) + list(tags)))
class Tag(Base):
    """A free-form label for components; the name is the primary key."""

    __tablename__ = "tags"

    name = Column(String, primary_key=True)

    def __init__(self, name: str):
        """Create a tag called ``name``."""
        self.name = name
class IOPointer(Base):
    """Reference to an input/output artifact, keyed by ``(name, value)``.

    ``pointer_type`` classifies the artifact (a PointerTypeEnum) and
    ``flag`` is a boolean marker toggled by ``set_flag``/``clear_flag``.
    """

    __tablename__ = "io_pointers"

    # ``name`` and ``value`` together form the composite primary key.
    name = Column(String, primary_key=True, nullable=False)
    value = Column(LargeBinary, primary_key=True, nullable=False)
    pointer_type = Column(Enum(PointerTypeEnum))
    flag = Column(Boolean, default=False)

    __table_args__ = (UniqueConstraint("name", "value", name="_iop_uc"),)

    def __init__(
        self,
        name: str,
        value: bytes = b"",
        pointer_type: PointerTypeEnum = PointerTypeEnum.UNKNOWN,
    ):
        """Create a pointer; new pointers always start unflagged."""
        self.name = name
        self.value = value
        self.set_pointer_type(pointer_type)
        self.clear_flag()

    def set_pointer_type(self, pointer_type: PointerTypeEnum):
        """Reclassify what kind of artifact this pointer refers to."""
        self.pointer_type = pointer_type

    def set_flag(self):
        """Mark this pointer as flagged."""
        self.flag = True

    def clear_flag(self):
        """Remove the flag from this pointer."""
        self.flag = False
# Links each ComponentRun to the IOPointers it read. The pointer side is a
# composite foreign key onto io_pointers (name, value); no uniqueness
# constraint is placed on the pointer columns of this table.
component_run_input_association = Table(
    "component_runs_inputs",
    Base.metadata,
    Column("input_path_name", String),
    Column("input_path_value", LargeBinary),
    Column("component_run_id", Integer, ForeignKey("component_runs.id")),
    ForeignKeyConstraint(
        ["input_path_name", "input_path_value"],
        ["io_pointers.name", "io_pointers.value"],
    ),
)

# Links each ComponentRun to the IOPointers it wrote; same layout as the
# inputs table above.
component_run_output_association = Table(
    "component_runs_outputs",
    Base.metadata,
    Column("output_path_name", String),
    Column("output_path_value", LargeBinary),
    Column("component_run_id", Integer, ForeignKey("component_runs.id")),
    ForeignKeyConstraint(
        ["output_path_name", "output_path_value"],
        ["io_pointers.name", "io_pointers.value"],
    ),
)

# Self-referential edge list between runs: a row records that
# ``component_run_id`` depends on ``depends_on_component_run_id``. The pair
# is the composite primary key, so each edge is stored at most once.
component_run_dependencies = Table(
    "component_run_dependencies",
    Base.metadata,
    Column(
        "component_run_id",
        Integer,
        ForeignKey("component_runs.id"),
        primary_key=True,
    ),
    Column(
        "depends_on_component_run_id",
        Integer,
        ForeignKey("component_runs.id"),
        primary_key=True,
    ),
)
class ComponentRun(Base):
    """A single execution ("run") of a Component.

    Records timing, git/code metadata, free-form notes, staleness messages,
    the IOPointers read and written, and the other ComponentRuns this run
    depends on.
    """

    __tablename__ = "component_runs"

    id = Column(Integer, primary_key=True)
    component_name = Column(String, ForeignKey("components.name"))
    notes = Column(String)
    git_hash = Column(String)
    git_tags = Column(PickleType)
    code_snapshot = Column(LargeBinary)
    start_timestamp = Column(DateTime)
    end_timestamp = Column(DateTime)
    inputs = relationship(
        "IOPointer",
        secondary=component_run_input_association,
        cascade="all",
        backref=backref("component_runs_inputs", lazy="joined"),
    )
    outputs = relationship(
        "IOPointer",
        secondary=component_run_output_association,
        cascade="all",
        backref=backref("component_runs_outputs", lazy="joined"),
    )
    # Self-referential many-to-many: ``dependencies`` holds the runs this run
    # depends on; the ``left_component_run_ids`` backref holds the runs that
    # depend on this one.
    # NOTE(review): cascade="all" includes "delete"; on a many-to-many this
    # may delete the runs this run depends on when it is deleted — confirm
    # that is intended.
    dependencies = relationship(
        "ComponentRun",
        secondary=component_run_dependencies,
        primaryjoin=id == component_run_dependencies.c.component_run_id,
        secondaryjoin=id
        == component_run_dependencies.c.depends_on_component_run_id,
        backref="left_component_run_ids",
        cascade="all",
    )
    # Pickled list of staleness messages (see add_staleness_message).
    stale = Column(PickleType)

    def __init__(self, component_name: str):
        """Initialize ComponentRun, or an instance of a Component's 'run.'"""
        self.component_name = component_name
        self.notes = ""
        self.inputs = []
        self.outputs = []
        self.dependencies = []
        self.stale = []

    @staticmethod
    def _resolve_timestamp(ts: typing.Optional[datetime]) -> datetime:
        """Return ``ts``, or the current UTC time when ``ts`` is None.

        Raises:
            TypeError: if ``ts`` is neither None nor a datetime.
        """
        if ts is None:
            # Naive UTC, consistent with the DateTime columns, which are
            # declared without timezone=True.
            ts = datetime.utcnow()
        if not isinstance(ts, datetime):
            raise TypeError("Timestamp must be of type datetime.")
        return ts

    def set_start_timestamp(self, ts: typing.Optional[datetime] = None):
        """Call this function to set the start timestamp to a specific
        timestamp or now (naive UTC).

        Raises:
            TypeError: if ``ts`` is neither None nor a datetime.
        """
        self.start_timestamp = self._resolve_timestamp(ts)

    def set_end_timestamp(self, ts: typing.Optional[datetime] = None):
        """Call this function to set the end timestamp to a specific
        timestamp or now (naive UTC).

        Raises:
            TypeError: if ``ts`` is neither None nor a datetime.
        """
        self.end_timestamp = self._resolve_timestamp(ts)

    def set_code_snapshot(self, code_snapshot: bytes):
        """Code snapshot setter."""
        self.code_snapshot = code_snapshot

    def add_notes(self, notes: str):
        """Set notes describing details of the component run.

        Note: this replaces any existing notes rather than appending.

        Raises:
            TypeError: if ``notes`` is not a str.
        """
        if not isinstance(notes, str):
            raise TypeError("notes field must be of type str")
        self.notes = notes

    def set_git_hash(self, git_hash: str):
        """Git hash setter."""
        self.git_hash = git_hash

    def set_git_tags(self, git_tags: typing.List[str]):
        """Git tag setter."""
        self.git_tags = git_tags

    def add_staleness_message(self, message: str):
        """Append a staleness message to this run."""
        # Reassign instead of appending in place: a plain PickleType column
        # is not mutation-tracked, so in-place edits would not be persisted.
        # ``or []`` covers rows loaded with a NULL ``stale`` (``__init__`` is
        # not run when an instance is loaded from the database).
        self.stale = list(self.stale or []) + [message]

    def add_input(self, input: IOPointer):
        """Add a single input (instance of IOPointer)."""
        self._add_io(input, True)

    def add_inputs(self, inputs: typing.List[IOPointer]):
        """Add a list of inputs (each element should be an instance of IOPointer)."""
        self._add_io(inputs, True)

    def add_output(self, output: IOPointer):
        """Add a single output (instance of IOPointer)."""
        self._add_io(output, False)

    def add_outputs(self, outputs: typing.List[IOPointer]):
        """Add a list of outputs (each element should be an instance of IOPointer)."""
        self._add_io(outputs, False)

    def _add_io(
        self,
        elems: typing.Union[typing.List[IOPointer], IOPointer],
        input: bool,
    ):
        """Helper that adds inputs (``input=True``) or outputs (``input=False``).

        Duplicates are dropped. Existing entries keep their order, and new
        entries follow in the order given.
        """
        # Elems can be a list or a single IOPointer; normalize to a list.
        if not isinstance(elems, list):
            elems = [elems]
        # dict.fromkeys dedupes while preserving first-seen order (a set
        # round-trip would reorder arbitrarily).
        if input:
            self.inputs = list(dict.fromkeys(list(self.inputs) + elems))
        else:
            self.outputs = list(dict.fromkeys(list(self.outputs) + elems))

    def set_upstream(
        self,
        dependencies: typing.Union[typing.List[ComponentRun], ComponentRun],
    ):
        """Set dependencies for this ComponentRun. API similar to Airflow
        set_upstream.

        New dependencies are appended after existing ones; duplicates are
        dropped.
        """
        # Dependencies can be a list or a single ComponentRun; normalize.
        if not isinstance(dependencies, list):
            dependencies = [dependencies]
        # Drop duplicates while keeping first-seen order.
        self.dependencies = list(
            dict.fromkeys(list(self.dependencies) + dependencies)
        )

    def check_completeness(self) -> dict:
        """Returns a dictionary of success indicator and error messages.

        Returns:
            ``{"success": bool, "msg": str}``. ``success`` is False when a
            timestamp is missing or the run depends on itself. ``msg``
            concatenates those errors together with warnings for missing
            inputs, outputs, or dependencies.
        """
        success = True
        messages = []
        name = self.component_name

        # Missing timestamps are hard errors.
        if self.start_timestamp is None:
            success = False
            messages.append(f"{name} ComponentRun has no start timestamp. ")
        if self.end_timestamp is None:
            success = False
            messages.append(f"{name} ComponentRun has no end timestamp. ")

        # Missing I/O or dependencies only produce warnings.
        if not self.inputs:
            messages.append(f"{name} ComponentRun has no inputs. ")
        if not self.outputs:
            messages.append(f"{name} ComponentRun has no outputs. ")
        if not self.dependencies:
            messages.append(f"{name} ComponentRun has no dependencies. ")

        # A run must not depend on itself. Compare by identity as well as by
        # id so that this is also caught before the run is flushed (when
        # ``id`` is still None).
        if any(
            dep is self or (self.id and dep.id == self.id)
            for dep in self.dependencies
        ):
            success = False
            messages.append(f"{name} ComponentRun has a circular dependency. ")

        return {"success": success, "msg": "".join(messages)}