"""
Base Component class. Other components should inherit from this class.
"""
import json
import logging
import typing
from mltrace import client
from mltrace import utils as clientUtils
from mltrace.db import Store, PointerTypeEnum
from mltrace.entities import utils
from mltrace.entities.base import Base
import functools
import inspect
import io
import git
import sys
class Component(Base):
    """Base component abstraction. Other components should inherit from it.

    A component has a name, owner, description, optional tags, and lists of
    test classes to run before and after each traced execution. Decorate a
    function with :meth:`run` to record every call of it as a component run.
    """

    def __init__(
        self,
        name: str = "",
        owner: str = "",
        description: str = "",
        beforeTests: typing.Optional[list] = None,
        afterTests: typing.Optional[list] = None,
        tags: typing.Optional[typing.List[str]] = None,
    ):
        """Components abstraction.

        Components should have a name, owner, and lists of before and after
        tests to run. Optionally they will have tags.

        Args:
            name: Component name; component runs are initialized with it.
            owner: Owner of the component.
            description: Free-text description of the component.
            beforeTests: Test classes run by :meth:`beforeRun`. Each entry is
                instantiated with no arguments and its ``runTests`` method is
                called. Defaults to an empty list.
            afterTests: Test classes run by :meth:`afterRun`, same protocol.
                Defaults to an empty list.
            tags: Optional list of tags. Defaults to an empty list.
        """
        self._name = name
        self._owner = owner
        self._description = description
        # ``None`` sentinels instead of ``[]`` defaults: a default list would
        # be one object shared (and mutable via the properties) by every
        # Component created without that argument.
        self._tags = [] if tags is None else tags
        self._beforeTests = [] if beforeTests is None else beforeTests
        self._afterTests = [] if afterTests is None else afterTests

    def beforeRun(self, **kwargs):
        """Computation to execute before running a component.

        Instantiates each class listed in beforeTests and calls its
        ``runTests`` method with ``kwargs``.
        """
        for test in self._beforeTests:
            test().runTests(**kwargs)

    def afterRun(self, **local_vars):
        """Computation to execute after running a component.

        Instantiates each class listed in afterTests and calls its
        ``runTests`` method with ``local_vars``.
        """
        for test in self._afterTests:
            test().runTests(**local_vars)

    def run(
        self,
        inputs: typing.Optional[typing.List[str]] = None,
        outputs: typing.Optional[typing.List[str]] = None,
        input_vars: typing.Optional[typing.List[str]] = None,
        output_vars: typing.Optional[typing.List[str]] = None,
        input_kwargs: typing.Optional[typing.Dict[str, str]] = None,
        output_kwargs: typing.Optional[typing.Dict[str, str]] = None,
        endpoint: bool = False,
        staleness_threshold: int = (60 * 60 * 24 * 30),
        auto_log: bool = False,
        *user_args,
        **user_kwargs,
    ):
        """Decorator that records each call of the wrapped function.

        Usage::

            c = Component()

            @c.run
            def my_function(arg1, arg2):
                do_something()

        The beforeRun method is executed first, with the call's arguments.
        The function itself runs next. Its I/O pointers are then collected,
        and the component run is committed to the store. Finally, afterRun
        is executed with the values of the function's local variables at the
        end of the function.

        Args:
            inputs: Names logged directly as input pointers. When ``run`` is
                used as a bare decorator (``@c.run``), this parameter receives
                the decorated function instead.
            outputs: Names logged directly as output pointers.
            input_vars: Names of local variables of the decorated function
                whose final values are logged as input pointers. A list value
                is passed to ``store.get_io_pointers``. Any other non-None
                value is converted with ``str`` and passed to
                ``store.get_io_pointer``. ``None`` values are skipped.
            output_vars: Same as ``input_vars``, for output pointers.
            input_kwargs: Maps the name of a local variable holding pointer
                name(s) to the name of a local variable holding the
                corresponding value(s). If the first holds a list, the second
                must be a list of the same length.
            output_kwargs: Same as ``input_kwargs``, for output pointers.
            endpoint: If True, output pointers are requested with
                ``pointer_type=PointerTypeEnum.ENDPOINT``.
            staleness_threshold: Forwarded to ``store.commit_component_run``.
                The default is 60 * 60 * 24 * 30, i.e. 30 days if the unit is
                seconds (TODO confirm against the store).
            auto_log: If True, input pointers are also derived from the call's
                arguments (including defaults). Output pointers are derived
                from the remaining local variables. Both go through
                ``store.get_io_pointers_from_args``.
            *user_args: Accepted but unused.
            **user_kwargs: A truthy ``skip_before`` / ``skip_after`` skips the
                before/after tests. Every other entry
                ``test_name=local_name`` renames ``local_name`` to
                ``test_name`` in the keyword arguments given to the tests.
                For the before tests, only positional call arguments are
                renamed.

        Raises:
            ValueError: Raised in any of these cases:

                * the decorated function declares a parameter, or is called
                  with a keyword argument, named ``skip_before`` or
                  ``skip_after``;
                * a name in ``input_vars``/``output_vars``/``input_kwargs``/
                  ``output_kwargs`` is not among the function's locals;
                * a list key/value pair has mismatched lengths.
        """
        # Resolve ``None`` defaults (kept off the signature so no mutable
        # default object is shared between calls of ``run``).
        if inputs is None:
            inputs = []
        if outputs is None:
            outputs = []
        if input_vars is None:
            input_vars = []
        if output_vars is None:
            output_vars = []
        if input_kwargs is None:
            input_kwargs = {}
        if output_kwargs is None:
            output_kwargs = {}

        # Maps a local/argument name to the name the tests should receive.
        inv_user_kwargs = {v: k for k, v in user_kwargs.items()}
        key_names = {"skip_before", "skip_after"}
        # Extra keyword arguments for every store call creating an output
        # pointer; empty means "use the store's default pointer type".
        output_pointer_kwargs = (
            {"pointer_type": PointerTypeEnum.ENDPOINT} if endpoint else {}
        )

        def actual_decorator(func):
            # A function's parameter names never change, so compute them once.
            arg_names = inspect.getfullargspec(func).args

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # Reject reserved names before doing any database work.
                if key_names & set(arg_names) or key_names & set(kwargs):
                    raise ValueError(
                        "skip_before or skip_after cannot be in "
                        + f"the arguments of the function {func.__name__}"
                    )

                # Construct component run object
                store = Store(clientUtils.get_db_uri())
                component_run = store.initialize_empty_component_run(
                    self.name
                )
                component_run.set_start_timestamp()

                positional_args = dict(zip(arg_names, args))

                # Run before tests on the (renamed) call arguments.
                if not user_kwargs.get("skip_before"):
                    before_args = {
                        inv_user_kwargs.get(k, k): v
                        for k, v in positional_args.items()
                    }
                    self.beforeRun(**{**before_args, **kwargs})

                input_pointers = []
                output_pointers = []

                # Auto log inputs: parameter defaults, overridden by the
                # positional arguments, then by the keyword arguments.
                all_input_args = {}
                if auto_log:
                    all_input_args = {
                        k: p.default
                        for k, p in inspect.signature(func).parameters.items()
                        if p.default is not inspect.Parameter.empty
                    }
                    all_input_args.update(positional_args)
                    all_input_args.update(kwargs)
                    input_pointers += store.get_io_pointers_from_args(
                        **all_input_args
                    )

                # Run function
                local_vars, value = utils.run_func_capture_locals(
                    func, *args, **kwargs
                )
                component_run.set_end_timestamp()

                # Add input_vars and output_vars as pointers
                input_pointers += self._collect_var_pointers(
                    store, local_vars, input_vars, {}
                )
                output_pointers += self._collect_var_pointers(
                    store, local_vars, output_vars, output_pointer_kwargs
                )

                # Add input_kwargs and output_kwargs as pointers
                for key, val in input_kwargs.items():
                    pair = self._resolve_kwarg_pair(local_vars, key, val)
                    if pair is None:
                        continue
                    names, values = pair
                    if isinstance(names, list):
                        input_pointers += store.get_io_pointers(
                            names, values=values
                        )
                    else:
                        input_pointers.append(
                            store.get_io_pointer(str(names), values)
                        )
                for key, val in output_kwargs.items():
                    pair = self._resolve_kwarg_pair(local_vars, key, val)
                    if pair is None:
                        continue
                    names, values = pair
                    if isinstance(names, list):
                        output_pointers += store.get_io_pointers(
                            names, values, **output_pointer_kwargs
                        )
                    else:
                        output_pointers.append(
                            store.get_io_pointer(
                                str(names), values, **output_pointer_kwargs
                            )
                        )

                # Directly specified I/O (absent for a bare ``@c.run``).
                if not callable(inputs):
                    input_pointers += [
                        store.get_io_pointer(inp) for inp in inputs
                    ]
                    output_pointers += [
                        store.get_io_pointer(out, **output_pointer_kwargs)
                        for out in outputs
                    ]

                # If there were calls to mltrace.load and mltrace.save, log
                if "_mltrace_loaded_artifacts" in local_vars:
                    input_pointers += [
                        store.get_io_pointer(name, val)
                        for name, val in local_vars[
                            "_mltrace_loaded_artifacts"
                        ].items()
                    ]
                if "_mltrace_saved_artifacts" in local_vars:
                    output_pointers += [
                        store.get_io_pointer(name, val)
                        for name, val in local_vars[
                            "_mltrace_saved_artifacts"
                        ].items()
                    ]

                # Auto log outputs: every local that was not an input.
                if auto_log:
                    all_output_args = {
                        k: v
                        for k, v in local_vars.items()
                        if k not in all_input_args
                    }
                    output_pointers += store.get_io_pointers_from_args(
                        **all_output_args
                    )

                component_run.add_inputs(input_pointers)
                component_run.add_outputs(output_pointers)

                # Add git hash, git tags and source snapshot
                self._record_code_version(component_run, func)

                # Create component if it does not exist
                client.create_component(
                    self.name, self.description, self.owner, self.tags
                )
                # Set dependencies
                store.set_dependencies_from_inputs(component_run)
                # Commit component run object to the DB
                store.commit_component_run(
                    component_run, staleness_threshold=staleness_threshold
                )

                # Perform after run tests on the (renamed) final locals.
                if not user_kwargs.get("skip_after"):
                    after_run_args = {
                        inv_user_kwargs.get(k, k): v
                        for k, v in local_vars.items()
                    }
                    self.afterRun(**after_run_args)
                return value

            return wrapper

        if callable(inputs):
            # Used decorator without arguments
            return actual_decorator(inputs)
        # User passed in some kwargs
        return actual_decorator

    @staticmethod
    def _collect_var_pointers(store, local_vars, var_names, pointer_kwargs):
        """Return IOPointers for the local variables named in ``var_names``.

        A list value is passed to ``store.get_io_pointers``. Any other value
        is converted with ``str`` and passed to ``store.get_io_pointer``.
        ``pointer_kwargs`` is forwarded to those calls. ``None`` values are
        skipped.

        Raises:
            ValueError: If a name is not present in ``local_vars``.
        """
        pointers = []
        for var in var_names:
            if var not in local_vars:
                raise ValueError(f"Variable {var} not in current stack frame.")
            val = local_vars[var]
            if val is None:
                logging.debug(f"Variable {var} has value {val}.")
                continue
            if isinstance(val, list):
                pointers += store.get_io_pointers(val, **pointer_kwargs)
            else:
                pointers.append(
                    store.get_io_pointer(str(val), **pointer_kwargs)
                )
        return pointers

    @staticmethod
    def _resolve_kwarg_pair(local_vars, key, val):
        """Look up and validate a (names variable, values variable) pair.

        Returns ``(local_vars[key], local_vars[val])``. Returns ``None`` when
        ``local_vars[key]`` is ``None``, meaning the pair should be skipped.

        Raises:
            ValueError: If either name is missing from ``local_vars``. Also
                raised if ``local_vars[key]`` is a list and
                ``local_vars[val]`` is not a list of the same length.
        """
        if key not in local_vars or val not in local_vars:
            raise ValueError(f"({key}, {val}) not in current stack frame.")
        names = local_vars[key]
        values = local_vars[val]
        if names is None:
            logging.debug(f"Variable {key} has value {names}.")
            return None
        if isinstance(names, list) and (
            not isinstance(values, list) or len(names) != len(values)
        ):
            raise ValueError(
                f'Value "{val}" does not have the same '
                + f'length as the key "{key}."'
            )
        return names, values

    @staticmethod
    def _record_code_version(component_run, func):
        """Attach git hash, git tags and ``func``'s source to the run.

        Each piece is best-effort: a missing git repository or unavailable
        source code is logged and skipped rather than failing the run.
        """
        try:
            repo = git.Repo(search_parent_directories=True)
            component_run.set_git_hash(str(repo.head.object.hexsha))
        except Exception:
            logging.info("No git repo found.")

        # Query once; the original looked the tags up twice.
        git_tags = client.get_git_tags()
        if git_tags is not None:
            component_run.set_git_tags(git_tags)

        try:
            func_source_code = inspect.getsource(func)
        except (OSError, TypeError):
            # e.g. functions defined in a REPL or via exec have no source.
            logging.info(
                f"Source code of {func.__name__} is unavailable; "
                + "no code snapshot recorded."
            )
            return
        # UTF-8 (not ASCII) so non-ASCII source does not crash the run; the
        # size limit is applied to the stored bytes.
        snapshot = func_source_code.encode("utf-8")
        if len(snapshot) < 2 ** 16:
            component_run.set_code_snapshot(snapshot)

    @property
    def name(self) -> str:
        return self._name

    @property
    def owner(self) -> str:
        return self._owner

    @property
    def description(self) -> str:
        return self._description

    @property
    def tags(self) -> typing.List[str]:
        return self._tags

    @property
    def beforeTests(self) -> list:
        return self._beforeTests

    @property
    def afterTests(self) -> list:
        return self._afterTests

    def __repr__(self):
        """Return ``to_dictionary()`` as JSON, without the test lists."""
        params = self.to_dictionary()
        # The test lists hold classes (instantiated in beforeRun/afterRun),
        # which json.dumps presumably cannot serialize, so they are dropped.
        # ``pop`` tolerates a dictionary that omits them.
        params.pop("beforeTests", None)
        params.pop("afterTests", None)
        return json.dumps(params)