from datetime import datetime
from mltrace.db.utils import (
_create_engine_wrapper,
_initialize_db_tables,
_drop_everything,
_map_extension_to_enum,
_hash_value,
_get_data_and_model_args,
_load,
_save,
)
from mltrace.db import (
Component,
ComponentRun,
IOPointer,
PointerTypeEnum,
Tag,
component_run_output_association,
)
from mltrace.db.models import component_run_output_association
from sqlalchemy import func, and_
from sqlalchemy.orm import sessionmaker, joinedload
from sqlalchemy.sql.expression import Tuple
import ctypes
import hashlib
import logging
import typing
[docs]class Store(object):
"""Helper methods to interact with the db."""
def __init__(self, uri: str, delete_first: bool = False):
"""
Creates the postgres database for the store. Raises exception if uri
isn't prefixed with postgresql://.
Args:
uri (str): URI string to connect to the SQLAlchemy database.
delete_first (bool): Whether all the tables in the db should be
deleted.
"""
if uri.lower().strip() == "test":
uri = "sqlite:///:memory:"
elif not uri.startswith("postgresql://"):
raise RuntimeError(
"Database URI must be prefixed with `postgresql://`"
)
self.engine = _create_engine_wrapper(uri)
# TODO(shreyashankar) remove this line
if delete_first:
_drop_everything(self.engine)
# TODO(shreyashankar) check existing tables against expected tables
_initialize_db_tables(self.engine)
# Initialize session
self.Session = sessionmaker(self.engine)
self.session = self.Session()
def __del__(self):
"""On destruction, close session."""
self.session.close()
def create_component(
    self,
    name: str,
    description: str,
    owner: str,
    tags: typing.Optional[typing.List[str]] = None,
):
    """Creates a component entity in the database if it does
    not already exist.

    Args:
        name (str): Name of the component; used to check for an existing
            component.
        description (str): Description stored on the new component.
        owner (str): Owner stored on the new component.
        tags (list of str, optional): Tag names to attach. Defaults to no
            tags. Duplicate names are collapsed.

    Returns:
        None. When a component with this name already exists, nothing is
        added or committed.
    """
    # A None default avoids sharing one list object across calls
    tag_names = [] if tags is None else tags

    if self.get_component(name):
        logging.info(f'Component with name "{name}" already exists.')
        return

    # Add to the DB if it is not already there
    logging.info(
        f'Creating new Component with name "{name}", description '
        + f'"{description}", owner "{owner}", and tags "{tag_names}".'
    )
    # De-duplicate the names first so each tag is looked up only once;
    # dict.fromkeys keeps the caller's ordering.
    tag_objects = [self.get_tag(t) for t in dict.fromkeys(tag_names)]
    component = Component(
        name=name, description=description, owner=owner, tags=tag_objects
    )
    self.session.add(component)
    self.session.commit()
def get_component(self, name: str) -> Component:
    """Retrieves the component with the given name.

    Returns the first row of the query, which is None when no component
    matches.
    """
    query = self.session.query(Component).outerjoin(Tag, Component.tags)
    return query.filter(Component.name == name).first()
def get_component_run(self, id: str) -> ComponentRun:
    """Retrieves the component run with the given id.

    Returns the first row of the query, which is None when no component
    run matches.
    """
    matching_runs = self.session.query(ComponentRun).filter(
        ComponentRun.id == id
    )
    return matching_runs.first()
def unflag_all(self):
    """Clears the flag on every flagged IOPointer, then commits."""
    flagged_query = self.session.query(IOPointer).filter(
        IOPointer.flag.is_(True)
    )
    for pointer in flagged_query.all():
        pointer.clear_flag()
    self.session.commit()
def initialize_empty_component_run(
    self, component_name: str
) -> ComponentRun:
    """Builds a new ComponentRun for the named component.

    The run is only constructed and returned; it is neither added to the
    session nor committed here.
    """
    return ComponentRun(component_name=component_name)
def get_tag(self, name: str) -> Tag:
    """Returns the tag with the given name, creating it if needed.

    Args:
        name (str): Name of the tag. This was previously declared as
            ``name=str``, which made the ``str`` type itself the default
            value; it is an annotation now.

    Returns:
        The existing Tag with this name, or a newly created Tag that has
        been added to the session and committed.
    """
    # Only one row is needed, so don't load every match
    existing = self.session.query(Tag).filter(Tag.name == name).first()
    if existing is not None:
        return existing

    # Must create new Tag
    logging.info(f'Creating new Tag with name "{name}".')
    tag = Tag(name)
    self.session.add(tag)
    self.session.commit()
    return tag
def get_io_pointers(
    self,
    names: typing.List[str],
    values: typing.List[typing.Any] = None,
    pointer_type: PointerTypeEnum = None,
) -> typing.List[IOPointer]:
    """Creates io pointers around the specified path names. Retrieves
    existing io pointers if they exist in the DB, otherwise creates new
    ones.

    Args:
        names: Names of the io pointers.
        values: Values paired positionally with names; each is passed
            through _hash_value. When omitted or empty, every name is
            paired with b"".
        pointer_type: Type given to every newly created pointer. When
            None, the type of each new pointer is inferred from its own
            name via _map_extension_to_enum.

    Returns:
        The pointers already in the DB followed by the newly created
        ones, the latter in the order their names were given.
    """
    # Nothing to look up; skip the empty IN query
    if not names:
        return []

    hashed_values = (
        [_hash_value(v) for v in values] if values else [b""] * len(names)
    )
    # dict.fromkeys drops duplicate (name, value) pairs but keeps order,
    # which keeps the result deterministic.
    requested = list(dict.fromkeys(zip(names, hashed_values)))

    existing = (
        self.session.query(IOPointer)
        .filter(Tuple(IOPointer.name, IOPointer.value).in_(requested))
        .all()
    )
    existing_keys = {(iop.name, iop.value) for iop in existing}
    missing = [pair for pair in requested if pair not in existing_keys]
    if not missing:
        return existing

    # Create new IOPointers. The type is inferred per name rather than
    # from a single arbitrary name, so mixed batches get the right type.
    new_iops = [
        IOPointer(
            name=name,
            value=value,
            pointer_type=(
                pointer_type
                if pointer_type is not None
                else _map_extension_to_enum(name)
            ),
        )
        for name, value in missing
    ]
    self.session.add_all(new_iops)
    self.session.commit()
    return existing + new_iops
def get_io_pointer(
    self,
    name: str,
    value: typing.Any = "",
    pointer_type: PointerTypeEnum = None,
    create=True,
) -> IOPointer:
    """Creates an io pointer around the specified path.
    Retrieves existing io pointer if exists in DB,
    otherwise creates a new one if create flag is set.

    Args:
        name: Name of the io pointer.
        value: Value of the io pointer; it is passed through _hash_value
            before being compared or stored.
        pointer_type: Type for a newly created pointer. When None, it is
            inferred from name via _map_extension_to_enum.
        create: When False, a missing pointer raises instead of being
            created.

    Returns:
        The matching IOPointer, either existing or newly committed.

    Raises:
        RuntimeError: If no pointer matches and create is False.
    """
    hval = _hash_value(value)

    # Only the most recent write under this name matters for the warning,
    # so fetch a single row instead of the whole history.
    outputs = component_run_output_association.c
    last_write = (
        self.session.query(outputs.output_path_value)
        .filter(outputs.output_path_name == name)
        .order_by(outputs.component_run_id.desc())
        .first()
    )
    if last_write is not None and bytes(last_write[0]) != hval:
        logging.warning(
            f'IOPointer with name "{name}" has a different value '
            + "from the last write."
        )

    existing = (
        self.session.query(IOPointer)
        .filter(and_(IOPointer.name == name, IOPointer.value == hval))
        .first()
    )
    # Return existing object
    if existing is not None:
        return existing

    # Must create new IOPointer
    if create is False:
        raise RuntimeError(
            f"IOPointer with name {name} does not exist. Set create"
            + " flag to True if you would like to create it."
        )

    logging.info(f'Creating new IOPointer with name "{name}".')
    if pointer_type is None:
        pointer_type = _map_extension_to_enum(name)

    iop = IOPointer(name=name, value=hval, pointer_type=pointer_type)
    self.session.add(iop)
    self.session.commit()
    return iop
def delete_component(self, component: Component):
    """Passes the component to the session for deletion and logs it.

    No commit is issued here.
    """
    self.session.delete(component)
    deleted_name = component.name
    logging.info(
        f'Successfully deleted Component with name "{deleted_name}".'
    )
def delete_component_run(self, component_run: ComponentRun):
    """Passes the component run to the session for deletion and logs it.

    No commit is issued here.
    """
    self.session.delete(component_run)
    run_id = component_run.id
    run_name = component_run.component_name
    logging.info(
        f'Successfully deleted ComponentRun with id "{run_id}"'
        f' and name "{run_name}".'
    )
def delete_io_pointer(self, io_pointer: IOPointer):
    """Passes the io pointer to the session for deletion and logs it.

    No commit is issued here.
    """
    self.session.delete(io_pointer)
    deleted_name = io_pointer.name
    logging.info(
        f'Successfully deleted IOPointer with name "{deleted_name}".'
    )
def commit_component_run(
    self,
    component_run: ComponentRun,
    staleness_threshold: int = (60 * 60 * 24 * 30),
):
    """Commits a fully initialized component run to the DB.

    Before committing, each dependency is checked for staleness and any
    findings are recorded on the run via add_staleness_message.

    Args:
        component_run: The run to commit. Its check_completeness() result
            must report success.
        staleness_threshold: Seconds allowed between the start of a
            dependency and the start of this run before the dependency
            is reported as stale.

    Raises:
        RuntimeError: If check_completeness() does not report success;
            the message is the one it supplied.
    """
    completeness = component_run.check_completeness()
    if not completeness["success"]:
        raise RuntimeError(completeness["msg"])
    if completeness["msg"]:
        logging.info(completeness["msg"])

    seconds_per_day = 60 * 60 * 24

    # Check for staleness
    for dep in component_run.dependencies:
        # First case: the dependency started too long before this run
        gap = component_run.start_timestamp - dep.start_timestamp
        gap_seconds = gap.total_seconds()
        if gap_seconds > staleness_threshold:
            days_diff = int(gap_seconds // seconds_per_day)
            component_run.add_staleness_message(
                f"{dep.component_name} (ID {dep.id}) was run {days_diff}"
                " days ago."
            )

        # Second case: there is a newer run of the dependency. The
        # window's lower bound is the dependency's own start time, so
        # one matching run is discounted before counting fresher ones.
        window_runs = self.get_history(
            dep.component_name,
            limit=None,
            date_lower=dep.start_timestamp,
            date_upper=component_run.start_timestamp,
        )
        others = sum(1 for cr in window_runs if component_run.id != cr.id)
        num_fresher = others - 1
        if num_fresher > 0:
            run_or_runs = "run" if num_fresher == 1 else "runs"
            component_run.add_staleness_message(
                f"{dep.component_name} (ID {dep.id}) has "
                f"{num_fresher} fresher {run_or_runs} that "
                "began before this component run started."
            )

    # Warn user if there is a staleness message
    if component_run.stale:
        logging.warning(component_run.stale)

    # Commit to DB
    self.session.add(component_run)
    logging.info(
        "Committing ComponentRun of type "
        f'"{component_run.component_name}" to the database.'
    )
    self.session.commit()
def _traverse(
self,
node: ComponentRun,
depth: int,
node_list: typing.List[ComponentRun],
):
# Add node to node_list as the step
node_list.append((depth, node))
# Base case
if len(node.dependencies) == 0:
return
# Recurse on neighbors
for neighbor in node.dependencies:
self._traverse(neighbor, depth + 1, node_list)
def _web_trace_helper(self, component_run_object: ComponentRun):
"""Helper function that populates the dictionary of ComponentRuns for
the web trace. Returns dictionary and counter."""
res = {}
res["id"] = f"componentrun_{component_run_object.id}"
res["label"] = component_run_object.component_name
res["hasCaret"] = True
res["isExpanded"] = True
res["stale"] = component_run_object.stale
res["childNodes"] = []
for out in sorted(component_run_object.outputs, key=lambda x: x.name):
out_dict = {
"id": f"iopointer_{out.name}",
"label": out.name,
"hasCaret": False,
"parent": res["id"],
}
# Settle on icon
if out.pointer_type == PointerTypeEnum.DATA:
out_dict["icon"] = "database"
elif out.pointer_type == PointerTypeEnum.MODEL:
out_dict["icon"] = "function"
elif out.pointer_type == PointerTypeEnum.ENDPOINT:
out_dict["icon"] = "flow-end"
res["childNodes"].append(out_dict)
for dep in sorted(
component_run_object.dependencies, key=lambda x: x.id
):
child_res = self._web_trace_helper(dep)
res["childNodes"].append(child_res)
return res
def web_trace(self, output_id: str, last_only: bool = False):
    """Returns the list of ComponentRun trees to display in the UI.

    Looks up the ComponentRuns that have an output named output_id,
    newest start timestamp first, and converts each one with
    _web_trace_helper. With last_only set, only the first (newest) run
    is converted.

    Raises:
        RuntimeError: If no ComponentRun has an output with that name.
    """
    producing_runs = (
        self.session.query(ComponentRun)
        .outerjoin(IOPointer, ComponentRun.outputs)
        .order_by(ComponentRun.start_timestamp.desc())
        .filter(IOPointer.name == output_id)
        .all()
    )
    if not producing_runs:
        raise RuntimeError(f"ID {output_id} does not exist.")

    selected = producing_runs[:1] if last_only else producing_runs
    return [self._web_trace_helper(run) for run in selected]
def trace(self, output_id: str):
    """Builds the trace for an output id.

    Returns list of tuples (level, ComponentRun) where level is how
    many hops away the node is from the node that produced the
    output_id. The starting node is the ComponentRun with the latest
    start timestamp among those that have an output with this name.

    Raises:
        RuntimeError: If output_id is not a string, or if no ComponentRun
            has an output with that name.
    """
    if not isinstance(output_id, str):
        raise RuntimeError("Please specify an output id of string type.")

    producers = (
        self.session.query(ComponentRun)
        .outerjoin(IOPointer, ComponentRun.outputs)
        .order_by(ComponentRun.start_timestamp.desc())
        .filter(IOPointer.name == output_id)
    )
    latest_run = producers.first()
    if latest_run is None:
        raise RuntimeError(f"ID {output_id} does not exist.")

    steps = []
    self._traverse(latest_run, 0, steps)
    return steps
def trace_batch(self, output_ids: typing.List[str]):
    """Placeholder with no implementation yet; always returns None."""
    return None
def get_history(
    self,
    component_name: str,
    limit: int = 10,
    date_lower: typing.Union[datetime, str] = datetime.min,
    date_upper: typing.Union[datetime, str] = datetime.max,
) -> typing.List[ComponentRun]:
    """Gets lineage for the component, or a history of all its runs.

    Args:
        component_name: Name of the component whose runs are fetched.
        limit: Passed straight to the query's limit().
        date_lower: Inclusive lower bound on a run's start timestamp.
        date_upper: Inclusive upper bound on a run's start timestamp.

    Returns:
        The matching ComponentRuns, newest start timestamp first.
    """
    started_in_window = and_(
        ComponentRun.start_timestamp >= date_lower,
        ComponentRun.start_timestamp <= date_upper,
    )
    runs_query = (
        self.session.query(ComponentRun)
        .filter(ComponentRun.component_name == component_name)
        .filter(started_in_window)
        .order_by(ComponentRun.start_timestamp.desc())
        .limit(limit)
    )
    return runs_query.all()
def get_components(self, tag: str = "", owner: str = ""):
    """Returns a list of all the components associated with the specified
    owner and/or tags.

    An empty tag or owner means "do not filter on it"; with both empty,
    every component is returned.

    Raises:
        RuntimeError: If the search matches no components.
    """
    base_query = self.session.query(Component)

    if tag:
        # Tag filtering needs the join; owner is folded in when given
        criteria = Tag.name == tag
        if owner:
            criteria = and_(
                Tag.name == tag,
                Component.owner == owner,
            )
        found = base_query.join(Tag, Component.tags).filter(criteria).all()
    elif owner:
        found = (
            base_query.filter(Component.owner == owner)
            .options(joinedload("tags"))
            .all()
        )
    else:
        found = base_query.all()

    if not found:
        raise RuntimeError("Search yielded no components.")
    return found
def get_recent_run_ids(
    self, limit: int = 50, last_run_id=None
) -> typing.List[int]:
    """Returns a list of recent component run IDs.

    Args:
        limit: Maximum number of ids to return.
        last_run_id: When given, only runs whose start timestamp is at or
            before that run's start timestamp are returned, and that run
            itself is left out.

    Returns:
        Run ids converted to int, newest start timestamp first.

    Raises:
        RuntimeError: If last_run_id is given but no such run exists, or
            the run has no start timestamp.
    """
    ids_query = self.session.query(ComponentRun.id).order_by(
        ComponentRun.start_timestamp.desc()
    )

    if last_run_id:
        # Get start timestamp of last run id. The row itself has to be
        # checked: reading .start_timestamp off a missing row would raise
        # AttributeError before the intended RuntimeError.
        last_run = (
            self.session.query(ComponentRun)
            .filter(ComponentRun.id == last_run_id)
            .first()
        )
        ts = last_run.start_timestamp if last_run is not None else None
        if not ts:
            raise RuntimeError(
                f"Last run ID {last_run_id} does not exist."
            )

        # Restrict to runs that started at or before ts, minus the
        # reference run itself
        ids_query = ids_query.filter(
            and_(
                ComponentRun.start_timestamp <= ts,
                ComponentRun.id != last_run_id,
            )
        )

    return [int(row[0]) for row in ids_query.limit(limit).all()]
def add_notes_to_component_run(
    self, component_run_id: str, notes: str
) -> str:
    """Retrieves an existing component run, adds notes to it and commits.

    Returns:
        The run's notes attribute after the update.

    Raises:
        RuntimeError: If no component run has the given id.
    """
    target_run = self.get_component_run(component_run_id)
    if target_run:
        target_run.add_notes(notes)
        self.session.commit()
        return target_run.notes

    raise RuntimeError(
        f'ComponentRun with ID "{component_run_id}" not found.'
    )
def set_io_pointer_flag(self, output_id: str, value: bool):
    """Sets the flag property of an IOPointer.

    Args:
        output_id: Name of the IOPointer to look up (without creating
            it).
        value: Truthy to set the flag, falsy to clear it.

    Returns:
        The value that was passed in.

    Raises:
        RuntimeError: If the IOPointer lookup fails. The original lookup
            error is chained as the cause.
    """
    # Only the lookup is guarded, so a RuntimeError raised while flagging
    # or committing is not mislabelled as a missing pointer.
    try:
        iop = self.get_io_pointer(output_id, create=False)
    except RuntimeError as err:
        raise RuntimeError(
            f"IOPointer with name {output_id} does not exist."
        ) from err

    if value:
        iop.set_flag()
    else:
        iop.clear_flag()
    self.session.commit()
    return value
def review_flagged_outputs(
    self,
) -> typing.Tuple[
    typing.List[str], typing.List[typing.Tuple[ComponentRun, int]]
]:
    """Finds common ComponentRuns for a group of flagged outputs.

    Returns:
        A pair of (names of all flagged IOPointers, list of
        (ComponentRun, count) tuples). The count is the number of flagged
        outputs whose trace contains that run; the list is sorted by
        count and then by run id, both descending.
    """
    # Collate flagged outputs
    flagged_query = self.session.query(IOPointer).filter(
        IOPointer.flag.is_(True)
    )
    flagged_output_ids = [pointer.name for pointer in flagged_query.all()]

    # Perform traces for each output id, keeping each run once per trace
    runs_per_trace = [
        {run for _, run in self.trace(output_id)}
        for output_id in flagged_output_ids
    ]

    # Tally how many traces each ComponentRun appears in
    appearances = {}
    for runs in runs_per_trace:
        for run in runs:
            appearances[run] = appearances.get(run, 0) + 1

    # Sort by ComponentRun count & id, descending
    ranked = sorted(
        appearances.items(),
        key=lambda pair: (-pair[1], -pair[0].id),
    )
    return flagged_output_ids, ranked
def get_io_pointers_from_args(self, **kwargs):
    """Filters kwargs to data and model types,
    then gets corresponding IOPointers.

    For each filtered argument the value is hashed and resolved in this
    order: the IOPointer named by the most recent output-association row
    with that hash, else any IOPointer with that hash, else the value is
    saved with _save and a new IOPointer is obtained for the saved path.

    Returns:
        A list with one IOPointer per filtered argument.
    """
    filtered_args = _get_data_and_model_args(**kwargs)
    outputs = component_run_output_association.c
    pointers = []

    # Hash each arg and see if the corresponding IOPointer exists
    for arg_name, arg_value in filtered_args.items():
        hval = _hash_value(arg_value)

        latest_output = (
            self.session.query(outputs.output_path_name)
            .filter(outputs.output_path_value == hval)
            .order_by(outputs.component_run_id.desc())
            .first()
        )
        if latest_output:
            matches = (
                self.session.query(IOPointer)
                .filter(
                    and_(
                        IOPointer.name == latest_output[0],
                        IOPointer.value == hval,
                    )
                )
                .all()
            )
            pointers.append(matches[0])
        else:
            # See if IOPointer exists but not in output table
            unlinked = (
                self.session.query(IOPointer)
                .filter(IOPointer.value == hval)
                .first()
            )
            if unlinked:
                pointers.append(unlinked)
            else:
                # Save artifact and create new IOPointer
                pathname = _save(
                    arg_value, var_name=arg_name, from_client=False
                )
                pointers.append(self.get_io_pointer(pathname, arg_value))

    return pointers