import contextlib
import csv
import io
import os
import shutil
import tempfile
from typing import List, Optional
from zipfile import ZipFile
import dallinger.data
import dallinger.models
import postgres_copy
import psutil
import six
import sqlalchemy
from dallinger import db
from dallinger.command_line.docker_ssh import CONFIGURED_HOSTS
from dallinger.data import fix_autoincrement
from dallinger.db import Base as SQLBase # noqa
from dallinger.experiment_server import dashboard
from dallinger.models import Info # noqa
from dallinger.models import Network # noqa
from dallinger.models import Node # noqa
from dallinger.models import Notification # noqa
from dallinger.models import Question # noqa
from dallinger.models import Recruitment # noqa
from dallinger.models import Transformation # noqa
from dallinger.models import Transmission # noqa
from dallinger.models import Vector # noqa
from dallinger.models import SharedMixin, timenow # noqa
from jsonpickle.util import importable_name
from sqlalchemy import Column, String
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import deferred
from sqlalchemy.orm.session import close_all_sessions
from sqlalchemy.schema import (
DropConstraint,
DropTable,
ForeignKeyConstraint,
MetaData,
Table,
)
from tqdm import tqdm
from yaspin import yaspin
from . import field
from .field import PythonDict, is_basic_type
from .utils import classproperty, json_to_data_frame, organize_by_key
def get_db_tables():
"""
Lists the tables in the database.
Returns
-------
A dictionary mapping table names to the corresponding SQLAlchemy ``Table`` objects.
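Example
-------
A minimal sketch (table names assume a standard Dallinger/PsyNet deployment)::

    tables = get_db_tables()
    tables["participant"]  # the sqlalchemy Table behind the participant table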
"""
return db.Base.metadata.tables
def _get_superclasses_by_table():
"""
Returns
-------
A dictionary where the keys are the names of the database tables
and the values are the corresponding superclasses for those tables.
"""
mappers = list(db.Base.registry.mappers)
mapped_classes = [m.class_ for m in mappers]
mapped_classes_by_table = organize_by_key(mapped_classes, lambda x: x.__tablename__)
superclasses_by_table = {
cls: _get_superclass(class_list)
for cls, class_list in mapped_classes_by_table.items()
}
return superclasses_by_table
def _get_superclass(class_list):
"""
Given a list of classes, returns the class in that list that is a superclass of
all other classes in that list. Assumes that exactly one such class exists
in that list; if this is not true, an AssertionError is raised.
Parameters
----------
class_list :
List of classes to check.
Returns
-------
A single superclass.
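Example
-------
A sketch with hypothetical classes (not part of PsyNet)::

    class A:
        pass

    class B(A):
        pass

    _get_superclass([B, A])  # returns A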
"""
superclasses = [cls for cls in class_list if _is_global_superclass(cls, class_list)]
assert len(superclasses) == 1
cls = superclasses[0]
cls = _get_preferred_superclass_version(cls)
return cls
def _is_global_superclass(x, class_list):
"""
Parameters
----------
x :
Class to test
class_list :
List of classes to test against
Returns
-------
``True`` if ``x`` is a superclass of all elements of ``class_list``, ``False`` otherwise.
"""
return all([issubclass(cls, x) for cls in class_list])
def _get_preferred_superclass_version(cls):
"""
Given an SQLAlchemy superclass for SQLAlchemy-mapped objects (e.g. ``Info``),
looks to see if there is a preferred version of this superclass (e.g. ``Trial``)
that still covers all instances in the database.
Parameters
----------
cls :
Class to simplify
Returns
-------
A simplified class if one was found, otherwise the original class.
"""
import dallinger.models

import psynet.bot
import psynet.participant
import psynet.timeline
import psynet.trial.main
preferred_superclasses = {
dallinger.models.Info: psynet.trial.main.Trial,
psynet.bot.Bot: psynet.participant.Participant,
psynet.timeline._Response: psynet.timeline.Response,
}
proposed_cls = preferred_superclasses.get(cls)
if proposed_cls:
proposed_cls = preferred_superclasses[cls]
n_original_cls_instances = cls.query.count()
n_proposed_cls_instances = proposed_cls.query.count()
proposed_cls_has_equal_coverage = (
n_original_cls_instances == n_proposed_cls_instances
)
if proposed_cls_has_equal_coverage:
return proposed_cls
return cls
def _db_instance_to_dict(obj, scrub_pii: bool):
"""
Converts an ORM-mapped instance to a JSON-style representation.
Complex types (e.g. lists, dicts) are serialized to strings using
``psynet.serialize.serialize``.
Parameters
----------
obj
Object to convert.
scrub_pii
Whether to remove personally identifying information.
Returns
-------
JSON-style dictionary
"""
try:
data = obj.to_dict()
except AttributeError:
data = obj.__json__()
if "class" not in data:
data["class"] = obj.__class__.__name__ # for the Dallinger classes
if scrub_pii and hasattr(obj, "scrub_pii"):
data = obj.scrub_pii(data)
for key, value in data.items():
if not is_basic_type(value):
from .serialize import serialize
data[key] = serialize(value)
return data
def _prepare_db_export(scrub_pii: bool):
"""
Encodes the database to a JSON-style representation suitable for export.
Parameters
----------
scrub_pii
Whether to remove personally identifying information.
Returns
-------
A dictionary keyed by class names with lists of JSON-style
encoded class instances as values.
The keys correspond to the most-specific available class names,
e.g. ``CustomNetwork`` as opposed to ``Network``.
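Example
-------
A sketch of the returned structure (class names are hypothetical, rows abridged)::

    {
        "CustomNetwork": [{"id": 1, "class": "CustomNetwork", "role": "experiment"}],
        "Participant": [{"id": 1, "class": "Participant", "status": "working"}],
    }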
"""
from psynet.experiment import get_experiment
exp = get_experiment()
tables = get_db_tables().values()
obj_sql_by_table = [exp.pull_table(table) for table in tables]
obj_sql = [obj for sublist in obj_sql_by_table for obj in sublist]
obj_sql_by_cls = organize_by_key(obj_sql, key=lambda x: x.__class__.__name__)
obj_dict_by_cls = {
_cls_name: [
_db_instance_to_dict(obj, scrub_pii)
for obj in tqdm(_obj_sql_for_cls, desc=_cls_name)
]
for _cls_name, _obj_sql_for_cls in obj_sql_by_cls.items()
if _cls_name not in exp.export_classes_to_skip
}
return obj_dict_by_cls
def copy_db_table_to_csv(tablename, path):
# TODO - improve naming of copy_db_table_to_csv and dump_db_to_disk to clarify
# that the former is a Dallinger export and the latter is a PsyNet export
with tempfile.TemporaryDirectory() as tempdir:
dallinger.data.copy_db_to_csv(db.db_url, tempdir)
temp_filename = f"{tablename}.csv"
shutil.copyfile(os.path.join(tempdir, temp_filename), path)
def dump_db_to_disk(dir, scrub_pii: bool):
"""
Exports all database objects to JSON-style dictionaries
and writes them to CSV files, one for each class type.
Parameters
----------
dir
Directory to which the CSV files should be exported.
scrub_pii
Whether to remove personally identifying information.
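Example
-------
A minimal sketch (the directory and class names are hypothetical)::

    dump_db_to_disk("export/csv", scrub_pii=True)
    # writes e.g. export/csv/Participant.csv, export/csv/CustomTrial.csv, ...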
"""
from .utils import make_parents
objects_by_class = _prepare_db_export(scrub_pii)
for cls, objects in objects_by_class.items():
filename = cls + ".csv"
filepath = os.path.join(dir, filename)
with open(make_parents(filepath), "w") as file:
json_to_data_frame(objects).to_csv(file, index=False)
class InvalidDefinitionError(ValueError):
"""
Raised when an SQLAlchemy-mapped class is defined in an unconventional
location, e.g. inside a function or as a class attribute of another class.
"""
pass
checked_classes = set()
class SQLMixinDallinger(SharedMixin):
"""
We apply this Mixin class when subclassing Dallinger classes,
for example ``Network`` and ``Info``.
It adds a few useful exporting features,
but most importantly it adds automatic mapping logic,
so that polymorphic identities are constructed automatically from
class names instead of having to be specified manually.
For example:
::
from dallinger.models import Info
class CustomInfo(Info):
pass
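The polymorphic identity defaults to the fully qualified class name; to
customize it, set the ``polymorphic_identity`` attribute explicitly
(a sketch)::

    class CustomInfo(Info):
        polymorphic_identity = "custom_info"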
"""
polymorphic_identity = (
None # set this to a string if you want to customize your polymorphic identity
)
__extra_vars__ = {}
def __new__(cls, *args, **kwargs):
self = super().__new__(cls)
cls.check_validity()
return self
def __repr__(self):
base_class = get_sql_base_class(self).__name__
cls = self.__class__.__name__
return "{}-{}-{}".format(base_class, self.id, cls)
@declared_attr
def vars(cls):
return deferred(Column(PythonDict, default=lambda: {}, server_default="{}"))
@property
def var(self):
from .field import VarStore
return VarStore(self)
def to_dict(self):
"""
Determines the information that is shown for this object in the dashboard
and in the csv files generated by ``psynet export``.
"""
from psynet.trial import ChainNode
from psynet.trial.main import GenericTrialNode
x = {c: getattr(self, c) for c in self.sql_columns}
x["class"] = self.__class__.__name__
# This is a little hack we do for compatibility with the Dallinger
# network visualization, which relies on sources being explicitly labeled.
if isinstance(self, GenericTrialNode) or (
isinstance(self, ChainNode) and self.degree == 0
):
x["type"] = "TrialSource"
else:
x["type"] = x["class"]
# Dallinger also needs us to set a parameter called ``object_type``
# which is used to determine the visualization method.
base_class = get_sql_base_class(self)
x["object_type"] = base_class.__name__ if base_class else x["type"]
field.json_add_extra_vars(x, self)
field.json_clean(x, details=True)
field.json_format_vars(x)
return x
def __json__(self) -> dict:
"Used to transmit the item to the Dallinger dashboard"
data = self.to_dict()
for key, value in data.items():
if not is_basic_type(value):
data[key] = repr(value)
return data
@classproperty
def sql_columns(cls):
return cls.__mapper__.column_attrs.keys()
@classproperty
def inherits_table(cls):
for ancestor_cls in cls.__mro__[1:]:
if (
hasattr(ancestor_cls, "__tablename__")
and ancestor_cls.__tablename__ is not None
):
return True
return False
@classmethod
def ancestor_has_same_polymorphic_identity(cls, polymorphic_identity):
for ancestor_cls in cls.__mro__[1:]:
if (
hasattr(ancestor_cls, "polymorphic_identity")
and ancestor_cls.polymorphic_identity == polymorphic_identity
):
return True
return False
@declared_attr
def __mapper_args__(cls):
"""
This programmatic definition of polymorphic_identity and polymorphic_on
means that users can define new SQLAlchemy classes without any reference
to these SQLAlchemy constructs. Instead the polymorphic mappers are
constructed automatically based on class names.
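A sketch of the resulting arguments for a hypothetical subclass defined in an
experiment module named ``experiment``::

    class CustomTrial(Trial):
        pass

    # __mapper_args__ -> (roughly) {"polymorphic_identity": "experiment.CustomTrial"};
    # "polymorphic_on" is omitted because CustomTrial inherits Trial's table.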
"""
# If the class has a distinct polymorphic_identity attribute, use that
cls.check_validity()
if cls.polymorphic_identity and not cls.ancestor_has_same_polymorphic_identity(
cls.polymorphic_identity
):
polymorphic_identity = cls.polymorphic_identity
else:
# Otherwise, take the polymorphic_identity from the fully qualified class name
polymorphic_identity = importable_name(cls)
x = {"polymorphic_identity": polymorphic_identity}
if not cls.inherits_table:
x["polymorphic_on"] = cls.type
return x
__validity_checks_complete__ = False
@classmethod
def check_validity(cls):
if cls not in checked_classes:
cls._check_validity()
checked_classes.add(cls)
@classmethod
def _check_validity(cls):
if cls.defined_in_invalid_location():
raise InvalidDefinitionError(
f"Problem detected with the definition of class {cls.__name__}:"
"You are not allowed to define SQLAlchemy classes in unconventional places, "
"e.g. as class attributes of other classes, within functions, etc. - "
"it can cause some very hard to debug problems downstream, "
"for example silently breaking SQLAlchemy relationship updating. "
"You should instead define your class at the top level of a Python file."
)
@classmethod
def defined_in_invalid_location(cls):
from jsonpickle.util import importable_name
path = importable_name(cls)
family = path.split(".")
ancestors = family[:-1]
parent_path = ".".join(ancestors)
return parent_path != cls.__module__
# if "<locals>" in parent_path:
# return True
#
# parent = loadclass(parent_path)
# if parent is None or isclass(parent):
# return True
#
# return False
def scrub_pii(self, json):
"""
Removes personally identifying information from the object's JSON representation.
This is a destructive operation (it changes the input object).
"""
try:
del json["worker_id"]
except KeyError:
pass
return json
#
# @event.listens_for(SQLMixinDallinger, "after_insert", propagate=True)
# def after_insert(mapper, connection, target):
# # obj = unserialize(serialize(target))
# old_session = db.session
# db.session = db.scoped_session(db.session_factory) # db.create_scoped_session()
# obj = unserialize(serialize(target))
# obj.on_creation()
# # target.on_creation()
# db.session.commit()
# db.session = old_session
class SQLMixin(SQLMixinDallinger):
"""
We apply this mixin when creating our own SQL-backed classes from scratch.
For example:
::
from psynet.data import SQLBase, SQLMixin, register_table
@register_table
class Bird(SQLBase, SQLMixin):
__tablename__ = "bird"
class Sparrow(Bird):
pass
"""
@declared_attr
def type(cls):
return Column(String(50))
old_init_db = dallinger.db.init_db
def init_db(drop_all=False, bind=db.engine):
# Without these preliminary steps, the process can freeze --
# https://stackoverflow.com/questions/24289808/drop-all-freezes-in-flask-with-sqlalchemy
db.session.commit()
close_all_sessions()
with yaspin(
text="Initializing the database...",
color="green",
) as spinner:
old_init_db(drop_all, bind)
spinner.ok("✔")
import time
time.sleep(1) # Todo - remove this if it doesn't break the tests?
return db.session
dallinger.db.init_db = init_db
def drop_all_db_tables(bind=db.engine):
"""
Drops all tables from the Postgres database.
Includes a workaround for the fact that SQLAlchemy doesn't provide a CASCADE option to ``drop_all``,
which was causing errors with Dallinger's version of database resetting in ``init_db``.
(https://github.com/pallets-eco/flask-sqlalchemy/issues/722)
"""
from sqlalchemy.exc import ProgrammingError
engine = bind
db.session.commit()
con = engine.connect()
trans = con.begin()
all_fkeys, tables = list_fkeys()
for fkey in all_fkeys:
try:
con.execute(DropConstraint(fkey))
except ProgrammingError as err:
if "UndefinedTable" in str(err):
pass
else:
raise
for table in tables:
try:
con.execute(DropTable(table))
except ProgrammingError as err:
if "UndefinedTable" in str(err):
pass
else:
raise
trans.commit()
# Calling _old_drop_all helps clear up edge cases, such as the dropping of enum types
_old_drop_all(bind=bind)
def list_fkeys():
inspector = sqlalchemy.inspect(db.engine)
# We need to re-create a minimal metadata with only the required things to
# successfully emit drop constraints and tables commands for postgres (based
# on the actual schema of the running instance)
meta = MetaData()
tables = []
all_fkeys = []
for table_name in inspector.get_table_names():
fkeys = []
for fkey in inspector.get_foreign_keys(table_name):
if not fkey["name"]:
continue
fkeys.append(ForeignKeyConstraint((), (), name=fkey["name"]))
tables.append(Table(table_name, meta, *fkeys))
all_fkeys.extend(fkeys)
return all_fkeys, tables
_old_drop_all = dallinger.db.Base.metadata.drop_all
dallinger.db.Base.metadata.drop_all = drop_all_db_tables
# @contextlib.contextmanager
# def disable_foreign_key_constraints():
# db.session.execute("SET session_replication_role = replica;")
# # connection.execute("SET session_replication_role = replica;")
# yield
# db.session.execute("SET session_replication_role = DEFAULT;")
# This would have been useful for importing data, however in practice
# it caused the import process to hang.
#
@contextlib.contextmanager
def disable_foreign_key_constraints():
db.session.commit()
# con = db.engine.connect()
# trans = con.begin()
all_fkeys, tables = list_fkeys()
for fkey in all_fkeys:
# con.execute(DropConstraint(fkey))
db.session.execute(DropConstraint(fkey))
db.session.commit()
yield
# This code was meant to re-add the constraints afterwards, but it causes an error that we have not been
# able to debug, so we have disabled it. It should not be too much of a problem, though; SQLAlchemy
# should protect us from foreign key misuse anyway.
#
# for fkey in all_fkeys:
# # con.execute(AddConstraint(fkey))
# print(fkey)
# db.session.execute(AddConstraint(fkey))
#
# db.session.commit()
# trans.commit()
def _sql_dallinger_base_classes():
"""
These base classes define the basic object relational mappers for the
Dallinger database tables.
Returns
-------
A dictionary of base classes for Dallinger tables
keyed by Dallinger table names.
"""
from .participant import Participant
return {
"info": Info,
"network": Network,
"node": Node,
"notification": Notification,
"participant": Participant,
"question": Question,
"recruitment": Recruitment,
"transformation": Transformation,
"transmission": Transmission,
"vector": Vector,
}
# A dictionary of base classes for additional tables that are defined in PsyNet
# or by individual experiment implementations, keyed by table names.
# See also _sql_dallinger_base_classes().
_sql_psynet_base_classes = {}
def sql_base_classes():
"""
Lists the base classes underpinning the different SQL tables used by PsyNet,
including both base classes defined in Dallinger (e.g. ``Node``, ``Info``)
and additional classes defined in custom PsyNet tables.
Returns
-------
A dictionary of base classes (e.g. ``Node``), keyed by the corresponding
table names for those base classes (e.g. ``node``).
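Example
-------
A minimal sketch::

    bases = sql_base_classes()
    bases["network"]  # dallinger.models.Network
    bases["participant"]  # psynet.participant.Participant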
"""
return {
**_sql_dallinger_base_classes(),
**_sql_psynet_base_classes,
}
def get_sql_base_class(x):
"""
Return the SQLAlchemy base class of an object x, returning None if no such base class is found.
"""
for cls in sql_base_classes().values():
if isinstance(x, cls):
return cls
return None
def register_table(cls):
"""
This decorator should be applied whenever defining a new
SQLAlchemy table. For example:
::
@register_table
class Bird(SQLBase, SQLMixin):
__tablename__ = "bird"
"""
_sql_psynet_base_classes[cls.__tablename__] = cls
setattr(dallinger.models, cls.__name__, cls)
update_dashboard_models()
dallinger.data.table_names.append(cls.__tablename__)
return cls
def update_dashboard_models():
"Determines the list of objects in the dashboard database browser."
dashboard.BROWSEABLE_MODELS = sorted(
list(
{
"Participant",
"Network",
"Node",
"Trial",
"Response",
"Transformation",
"Transmission",
"Notification",
"Recruitment",
}
.union({cls.__name__ for cls in _sql_psynet_base_classes.values()})
.difference({"_Response"})
)
)
def ingest_to_model(
file,
model,
engine=None,
clear_columns: Optional[List] = None,
replace_columns: Optional[dict] = None,
):
"""
Imports a CSV file to the database.
The implementation is similar to ``dallinger.data.ingest_to_model``,
but incorporates a few extra parameters (``clear_columns``, ``replace_columns``)
and does not fail for tables without an ``id`` column.
Parameters
----------
file :
CSV file to import (a file handle, as created for example by ``open()``).
model :
SQLAlchemy class corresponding to the objects that should be created.
clear_columns :
Optional list of columns to clear when importing the CSV file.
This is useful in the case of foreign-key constraints (e.g. participant IDs).
replace_columns :
Optional dictionary of values to set for particular columns.
"""
"""
if engine is None:
engine = db.engine
if clear_columns or replace_columns:
with tempfile.TemporaryDirectory() as temp_dir:
patched_csv = os.path.join(temp_dir, "patched.csv")
patch_csv(file, patched_csv, clear_columns, replace_columns)
with open(patched_csv, "r") as patched_csv_file:
ingest_to_model(
    patched_csv_file,
    model,
    engine,
    clear_columns=None,
    replace_columns=None,
)
else:
inspector = sqlalchemy.inspect(engine)
reader = csv.reader(file)
columns = tuple('"{}"'.format(n) for n in next(reader))
with disable_foreign_key_constraints():
postgres_copy.copy_from(
file, model, engine, columns=columns, format="csv", HEADER=False
)
column_names = [x["name"] for x in inspector.get_columns(model.__table__)]
if "id" in column_names:
fix_autoincrement(engine, model.__table__.name)
def patch_csv(infile, outfile, clear_columns, replace_columns):
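"""
Writes a copy of the CSV file ``infile`` to ``outfile`` in which the columns
listed in ``clear_columns`` are blanked out and the columns in
``replace_columns`` are overwritten with the provided values.
"""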
import pandas as pd
df = pd.read_csv(infile)
_replace_columns = {**{col: pd.NA for col in clear_columns}, **replace_columns}
for col, value in _replace_columns.items():
df[col] = value
df.to_csv(outfile, index=False)
def ingest_zip(path, engine=None):
"""
Given a path to a zip file created with `export()`, recreate the
database with the data stored in the included .csv files.
This is a patched version of dallinger.data.ingest_zip that incorporates
support for custom tables.
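Example
-------
A minimal sketch (the archive path is hypothetical; the zip is expected to
contain ``data/<tablename>.csv`` files)::

    ingest_zip("experiment-export.zip")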
"""
if engine is None:
engine = db.engine
inspector = sqlalchemy.inspect(engine)
all_table_names = inspector.get_table_names()
import_order = [
"network",
"participant",
"response",
"node",
"info",
"notification",
"question",
"transformation",
"vector",
"transmission",
"asset",
]
for n in all_table_names:
if n not in import_order:
import_order.append(n)
with ZipFile(path, "r") as archive:
filenames = archive.namelist()
for tablename in import_order:
filename_template = f"data/{tablename}.csv"
matches = [f for f in filenames if filename_template in f]
if len(matches) == 0:
continue
elif len(matches) > 1:
raise IOError(
f"Multiple matches for {filename_template} found in archive: {matches}"
)
else:
filename = matches[0]
model = sql_base_classes()[tablename]
file = archive.open(filename)
if six.PY3:
file = io.TextIOWrapper(file, encoding="utf8", newline="")
ingest_to_model(file, model, engine)
dallinger.data.ingest_zip = ingest_zip
dallinger.data.ingest_to_model = ingest_to_model
def export_assets(
path,
include_private: bool,
experiment_assets_only: bool,
include_on_demand_assets: bool,
n_parallel=None,
server=None,
):
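"""
Exports assets from the database to ``path``.

``include_private`` controls whether assets marked as personal are exported;
``experiment_assets_only`` restricts the export to ``ExperimentAsset`` instances
(otherwise all ``Asset`` instances are considered); ``include_on_demand_assets``
determines whether ``OnDemandAsset`` instances are included; ``server``
optionally names a configured remote host (see ``CONFIGURED_HOSTS``) to export
from.
"""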
from joblib import Parallel, delayed
# Assumes we already have loaded the experiment into the local database,
# as would be the case if the function is called from psynet export.
if n_parallel:
n_jobs = n_parallel
else:
n_jobs = psutil.cpu_count()
if experiment_assets_only:
from .asset import ExperimentAsset as base_class
else:
from .asset import Asset as base_class
asset_query = db.session.query(base_class.id, base_class.personal)
if not include_private:
asset_query = asset_query.filter_by(personal=False)
asset_ids = [a.id for a in asset_query]
n_jobs = 1 # todo - fix - parallel (SSH?) export seems to cause a deadlock, so we disable it for now
Parallel(
n_jobs=n_jobs,
verbose=10,
backend="threading",
# backend="multiprocessing", # Slow compared to threading
)(
delayed(export_asset)(asset_id, path, include_on_demand_assets, server)
for asset_id in asset_ids
)
# Parallel(n_jobs=n_jobs)(delayed(db.session.close)() for _ in range(n_jobs))
# def close_parallel_db_sessions():
def export_asset(asset_id, root, include_on_demand_assets, server):
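"""
Exports a single asset (identified by ``asset_id``) underneath the directory
``root``, skipping on-demand assets unless ``include_on_demand_assets`` is set;
``server`` optionally names a configured remote host to export from.
"""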
from .asset import Asset, OnDemandAsset
from .experiment import import_local_experiment
from .utils import make_parents
if server is None:
ssh_host = None
ssh_user = None
else:
server_info = CONFIGURED_HOSTS[server]
ssh_host = server_info["host"]
ssh_user = server_info.get("user")
import_local_experiment()
a = Asset.query.filter_by(id=asset_id).one()
if not include_on_demand_assets and isinstance(a, OnDemandAsset):
return
path = os.path.join(root, a.export_path)
make_parents(path)
try:
a.export(path, ssh_host=ssh_host, ssh_user=ssh_user)
except Exception:
print(f"An error occurred when trying to export the asset with id: {asset_id}")
raise