Source code for mltrace.db.models

from __future__ import annotations
from datetime import datetime
from sqlalchemy.sql.schema import UniqueConstraint

from sqlalchemy.sql.sqltypes import Boolean
from mltrace.db.base import Base
from sqlalchemy import (
    ARRAY,
    Column,
    JSON,
    Index,
    String,
    LargeBinary,
    Integer,
    DateTime,
    Table,
    ForeignKey,
    Enum,
    PickleType,
    UniqueConstraint,
    text,
    Numeric,
)
from sqlalchemy.orm import relationship, backref
from sqlalchemy.schema import ForeignKeyConstraint

import enum
import typing


[docs]class PointerTypeEnum(str, enum.Enum): DATA = "DATA" MODEL = "MODEL" ENDPOINT = "ENDPOINT" UNKNOWN = "UNKNOWN"
# Tables for monitoring extensions output_table = Table( "outputs", Base.metadata, Column("timestamp", DateTime), Column("identifier", String), Column("task_name", String), Column("value", Numeric), Index("outputs_ts_name_asc", "timestamp", "task_name"), Index( "outputs_ts_name_desc", text("timestamp DESC"), "task_name", ), ) feedback_table = Table( "feedback", Base.metadata, Column("timestamp", DateTime), Column("identifier", String), Column("task_name", String), Column("value", Numeric), Index("feedback_ts_name_asc", "timestamp", "task_name"), Index( "feedback_ts_name_desc", text("timestamp DESC"), "task_name", ), ) component_tag_association = Table( "component_tags", Base.metadata, Column("component_name", String, ForeignKey("components.name")), Column("tag_name", String, ForeignKey("tags.name")), ) # Functionality for label tracking. label_io_pointer_association = Table( "labels_io_pointers", Base.metadata, Column("label", String, ForeignKey("labels.id"), index=True), Column("io_pointer_name", String), Column("io_pointer_value", LargeBinary), ForeignKeyConstraint( ["io_pointer_name", "io_pointer_value"], ["io_pointers.name", "io_pointers.value"], ), ) deleted_labels = Table( "deleted_labels", Base.metadata, Column("label", String, ForeignKey("labels.id"), primary_key=True), Column("deletion_request_time", DateTime), )
[docs]class Label(Base): __tablename__ = "labels" id = Column(String, primary_key=True) io_pointers = relationship( "IOPointer", secondary=label_io_pointer_association, cascade="all", backref="io_pointers", ) def __init__(self, id: str): self.id = id self.io_pointers = []
[docs]class Component(Base): __tablename__ = "components" name = Column(String, primary_key=True) description = Column(String) owner = Column(String) component_runs = relationship("ComponentRun", cascade="all, delete-orphan") tags = relationship( "Tag", secondary=component_tag_association, cascade="all" ) def __init__( self, name: str, description: str, owner: str, tags: typing.List[Tag] = [], ): self.name = name self.description = description self.owner = owner self.tags = tags
[docs] def add_tags(self, tags: typing.List[Tag]): self.tags = list(set(self.tags + tags))
[docs]class Tag(Base): __tablename__ = "tags" name = Column(String, primary_key=True) def __init__(self, name: str): self.name = name
[docs]class IOPointer(Base): __tablename__ = "io_pointers" name = Column(String, primary_key=True, nullable=False) value = Column(LargeBinary, primary_key=True, nullable=False) pointer_type = Column(Enum(PointerTypeEnum)) flag = Column(Boolean, default=False) labels = relationship( "Label", secondary=label_io_pointer_association, cascade="all", backref="labels", ) __table_args__ = (UniqueConstraint("name", "value", name="_iop_uc"),) def __init__( self, name: str, value: bytes = b"", pointer_type: PointerTypeEnum = PointerTypeEnum.UNKNOWN, labels=[], ): self.name = name self.value = value self.pointer_type = pointer_type self.flag = False self.labels = labels
[docs] def set_pointer_type(self, pointer_type: PointerTypeEnum): self.pointer_type = pointer_type
[docs] def set_flag(self): self.flag = True
[docs] def clear_flag(self): self.flag = False
[docs] def add_label(self, label: Label): self.labels = self.labels + [label]
[docs] def add_labels(self, labels: typing.list[Label]): self.labels = self.labels + labels
[docs] def dedup_labels(self): self.labels = list(set(self.labels))
component_run_input_association = Table( "component_runs_inputs", Base.metadata, Column("input_path_name", String), Column("input_path_value", LargeBinary), Column("component_run_id", Integer, ForeignKey("component_runs.id")), # UniqueConstraint( # "input_path_name", "input_path_value", name="inp_nameval" # ), ForeignKeyConstraint( ["input_path_name", "input_path_value"], ["io_pointers.name", "io_pointers.value"], ), ) component_run_output_association = Table( "component_runs_outputs", Base.metadata, Column("output_path_name", String), Column("output_path_value", LargeBinary), Column("component_run_id", Integer, ForeignKey("component_runs.id")), # UniqueConstraint( # "output_path_name", "output_path_value", name="out_nameval" # ), ForeignKeyConstraint( ["output_path_name", "output_path_value"], ["io_pointers.name", "io_pointers.value"], ), ) component_run_dependencies = Table( "component_run_dependencies", Base.metadata, Column( "component_run_id", Integer, ForeignKey("component_runs.id"), primary_key=True, ), Column( "depends_on_component_run_id", Integer, ForeignKey("component_runs.id"), primary_key=True, ), )
[docs]class ComponentRun(Base): __tablename__ = "component_runs" id = Column(Integer, primary_key=True) component_name = Column(String, ForeignKey("components.name")) notes = Column(String) git_hash = Column(String) git_tags = Column(PickleType) code_snapshot = Column(LargeBinary) start_timestamp = Column(DateTime) end_timestamp = Column(DateTime) mlflow_run_id = Column(String) mlflow_run_params = Column(PickleType) mlflow_run_metrics = Column(PickleType) inputs = relationship( "IOPointer", secondary=component_run_input_association, cascade="all", backref=backref("component_runs_inputs", lazy="joined"), ) outputs = relationship( "IOPointer", secondary=component_run_output_association, cascade="all", backref=backref("component_runs_outputs", lazy="joined"), ) dependencies = relationship( "ComponentRun", secondary=component_run_dependencies, primaryjoin=id == component_run_dependencies.c.component_run_id, secondaryjoin=id == component_run_dependencies.c.depends_on_component_run_id, backref="left_component_run_ids", cascade="all", ) stale = Column(PickleType) test_results = Column(JSON) def __init__(self, component_name): """Initialize ComponentRun, or an instance of a Component's 'run.'""" self.component_name = component_name self.notes = "" self.inputs = [] self.outputs = [] self.dependencies = [] self.stale = [] self.test_results = JSON.NULL
[docs] def set_mlflow_run_id(self, mlflow_run_id: str): """Call this function to set the mlflow component run id""" self.mlflow_run_id = mlflow_run_id
[docs] def set_mlflow_run_metrics(self, mlflow_run_metrics: dict): """Call this function to set the mlflow component run id""" self.mlflow_run_metrics = mlflow_run_metrics
[docs] def set_mlflow_run_params(self, mlflow_run_params: dict): """Call this function to set the mlflow component run id""" self.mlflow_run_params = mlflow_run_params
[docs] def set_start_timestamp(self, ts: datetime = None): """Call this function to set the start timestamp to a specific timestamp or now.""" if ts is None: ts = datetime.utcnow() if not isinstance(ts, datetime): raise TypeError("Timestamp must be of type datetime.") self.start_timestamp = ts
[docs] def set_end_timestamp(self, ts: datetime = None): """Call this function to set the end timestamp to a specific timestamp or now.""" if ts is None: ts = datetime.utcnow() if not isinstance(ts, datetime): raise TypeError("Timestamp must be of type datetime.") self.end_timestamp = ts
[docs] def set_code_snapshot(self, code_snapshot: bytes): """Code snapshot setter.""" self.code_snapshot = code_snapshot
[docs] def add_notes(self, notes: str): """Add notes describing details of component run""" if not isinstance(notes, str): raise TypeError("notes field must be of type str") self.notes = notes
[docs] def set_git_hash(self, git_hash: str): """Git hash setter.""" self.git_hash = git_hash
[docs] def set_git_tags(self, git_tags: typing.List[str]): """Git tag setter.""" self.git_tags = git_tags
[docs] def add_staleness_message(self, message: str): """Staleness indicator.""" self.stale = self.stale + [message]
[docs] def add_input(self, input: IOPointer): """Add a single input (instance of IOPointer).""" self._add_io(input, True)
[docs] def add_inputs(self, inputs: typing.List[IOPointer]): """Add a list of inputs (each element should be an instance of IOPointer).""" self._add_io(inputs, True)
[docs] def add_output(self, output: IOPointer): """ "Add a single output (instance of IOPointer).""" self._add_io(output, False)
[docs] def add_outputs(self, outputs: typing.List[IOPointer]): """Add a list of outputs (each element should be an instance of IOPointer).""" self._add_io(outputs, False)
def _add_io( self, elems: typing.Union[typing.List[IOPointer], IOPointer], input: bool, ): """Helper function to add inputs or outputs.""" # Elems can be a list or a single IOPointer. Set to a list. elems = [elems] if not isinstance(elems, list) else elems if input: self.inputs = list(set(self.inputs + elems)) else: self.outputs = list(set(self.outputs + elems))
[docs] def set_upstream( self, dependencies: typing.Union[typing.List[ComponentRun], ComponentRun], ): """Set dependencies for this ComponentRun. API similar to Airflow set_upstream.""" # Dependencies can be a list or a single ComponentRun. Set to a list. dependencies = ( [dependencies] if not isinstance(dependencies, list) else dependencies ) self.dependencies += dependencies # Drop duplicates self.dependencies = list(set(self.dependencies))
[docs] def check_completeness(self) -> dict: """Returns a dictionary of success indicator and error messages.""" status_dict = {"success": True, "msg": ""} if self.start_timestamp is None: status_dict["success"] = False status_dict[ "msg" ] += f"{self.component_name} ComponentRun has no start timestamp. " if self.end_timestamp is None: status_dict["success"] = False status_dict[ "msg" ] += f"{self.component_name} ComponentRun has no end timestamp. " # Show warnings if there are no dependencies or I/O. if len(self.inputs) == 0: status_dict[ "msg" ] += f"{self.component_name} ComponentRun has no inputs. " if len(self.outputs) == 0: status_dict[ "msg" ] += f"{self.component_name} ComponentRun has no outputs. " if len(self.dependencies) == 0: status_dict[ "msg" ] += f"{self.component_name} ComponentRun has no dependencies. " # Make sure there are no circular dependencies. if self.id and self.id in [x.id for x in self.dependencies]: status_dict["success"] = False status_dict["msg"] += ( f"{self.component_name} ComponentRun has a " + "circular dependency. " ) return status_dict
[docs] def set_test_result(self, test_results: JSON): self.test_results = test_results