# Source code for psynet.utils

import base64
import contextlib
import gettext
import glob
import hashlib
import importlib
import importlib.util
import inspect
import json
import logging
import os
import re
import sys
import time
from collections import OrderedDict
from datetime import datetime
from functools import cache, lru_cache, reduce, wraps
from os.path import abspath, dirname, exists
from os.path import join as join_path
from pathlib import Path
from typing import Type, Union
from urllib.parse import ParseResult, urlparse

import click
import html2text
import jsonpickle
import pexpect
from _hashlib import HASH as Hash
from babel.support import Translations
from dallinger.config import experiment_available
from flask import url_for
from flask.globals import current_app, request
from flask.templating import Environment, _render


def get_logger():
    """Return the root logger, through which PsyNet emits all its log output."""
    return logging.getLogger()


# Module-level logger shared by all helpers in this file.
logger = get_logger()
# Absolute path of the directory holding the gettext locale data shipped with psynet.
LOCALES_DIR = join_path(abspath(dirname(__file__)), "locales")


class NoArgumentProvided:
    """
    Sentinel used in place of ``None`` as a default argument value, so that an
    argument the user intentionally set to ``None`` can be distinguished from
    an argument the user did not provide at all.
    """
def deep_copy(x):
    """
    Deep-copy ``x`` by round-tripping it through jsonpickle.

    Raises whatever jsonpickle raises on failure, after logging the object.
    """
    try:
        return jsonpickle.decode(jsonpickle.encode(x))
    except Exception:
        logger.error(f"Failed to copy the following object: {x}")
        raise


def get_arg_from_dict(x, desired: str, use_default=False, default=None):
    """
    Look up ``desired`` in the dict ``x``.

    If the key is missing, return ``default`` when ``use_default`` is true,
    otherwise raise ``KeyError`` (now carrying the missing key for easier
    debugging; previously the key name was omitted).
    """
    if desired in x:
        return x[desired]
    if use_default:
        return default
    raise KeyError(desired)


def sql_sample_one(x):
    """Return one uniformly random row from the SQLAlchemy query ``x``."""
    from sqlalchemy.sql import func

    return x.order_by(func.random()).first()


def dict_to_js_vars(x):
    """
    Render a dict as a string of JavaScript ``var`` declarations, each parsed
    from a JSON literal. Returns "" for an empty dict (previously this crashed
    because ``reduce`` was called on an empty sequence with no initializer).
    """
    return "".join(
        f"var {key} = JSON.parse('{json.dumps(value)}'); " for key, value in x.items()
    )
def call_function(function, *args, **kwargs):
    """
    Calls a function with ``*args`` and ``**kwargs``, but omits any ``**kwargs``
    that are not requested explicitly.
    """
    accepted = get_args(function)
    filtered_kwargs = {name: value for name, value in kwargs.items() if name in accepted}
    return function(*args, **filtered_kwargs)
def call_function_with_context(function, *args, **kwargs):
    """
    Call ``function`` after injecting PsyNet context objects (``experiment``,
    ``participant``, ``assets``, ``nodes``, ``trial_maker``) into its keyword
    arguments, delegating to ``call_function`` so that only the kwargs the
    function actually declares are passed through.
    """
    from psynet.participant import Participant
    from psynet.trial.main import Trial

    # NoArgumentProvided distinguishes "caller omitted this" from "caller passed None".
    participant = kwargs.get("participant", NoArgumentProvided)
    experiment = kwargs.get("experiment", NoArgumentProvided)
    assets = kwargs.get("assets", NoArgumentProvided)
    nodes = kwargs.get("nodes", NoArgumentProvided)
    trial_maker = kwargs.get("trial_maker", NoArgumentProvided)

    requested = get_args(function)

    if experiment == NoArgumentProvided:
        from .experiment import get_experiment

        experiment = get_experiment()

    # Only assemble assets if the function actually asks for them and the
    # caller didn't supply its own collection.
    if "assets" in requested and assets == NoArgumentProvided:
        assets = {}
        for asset in experiment.global_assets:
            if asset.module_id is None:
                # Module-independent global assets are always visible.
                assets[asset.local_key] = asset
            elif participant != NoArgumentProvided:
                assert isinstance(participant, Participant)
                # Module-scoped assets only apply if the participant is
                # currently in that module.
                if (
                    participant.module_state
                    and asset.module_id == participant.module_state.module_id
                ):
                    assets[asset.local_key] = asset
        if participant != NoArgumentProvided:
            assert isinstance(participant, Participant)
            if participant.module_state:
                # Participant-specific assets override global ones on key clash.
                assets = {
                    **assets,
                    **participant.module_state.assets,
                }

    if participant != NoArgumentProvided and participant.module_state:
        if "nodes" in requested and nodes == NoArgumentProvided:
            nodes = []
            for node in experiment.global_nodes:
                if node.module_id is None:
                    nodes.append(node)
                elif node.module_id == participant.module_state.module_id:
                    nodes.append(node)
            nodes += participant.module_state.nodes

    if "trial_maker" in requested and trial_maker == NoArgumentProvided:
        if (
            participant != NoArgumentProvided
            and participant.in_module
            and isinstance(participant.current_trial, Trial)
        ):
            trial_maker = participant.current_trial.trial_maker

    # Explicit caller kwargs (spread last) take precedence over injected context.
    new_kwargs = {
        "experiment": experiment,
        "participant": participant,
        "assets": assets,
        "nodes": nodes,
        "trial_maker": trial_maker,
        **kwargs,
    }
    return call_function(function, *args, **new_kwargs)


# Fallback values for config keys that may be absent from config.txt.
config_defaults = {
    "keep_old_chrome_windows_in_debug_mode": False,
}


def get_config():
    """Return the (loaded) Dallinger config object."""
    from dallinger.config import get_config as dallinger_get_config

    config = dallinger_get_config()
    if not config.ready:
        config.load()
    return config


def get_from_config(key):
    """Read ``key`` from the config, applying ``config_defaults`` where defined."""
    global config_defaults
    config = get_config()
    if not config.ready:
        config.load()

    if key in config_defaults:
        return config.get(key, default=config_defaults[key])
    else:
        return config.get(key)


def get_args(f):
    """Return the parameter names of callable ``f`` as a list of strings."""
    return [str(x) for x in inspect.signature(f).parameters]


def check_function_args(f, args, need_all=True):
    """
    Validate that callable ``f`` declares exactly ``args`` (when ``need_all``)
    or a subset of ``args`` otherwise; raises TypeError/ValueError on mismatch.
    """
    if not callable(f):
        raise TypeError("<f> is not a function (but it should be).")
    actual = [str(x) for x in inspect.signature(f).parameters]
    if need_all:
        if actual != list(args):
            raise ValueError(f"Invalid argument list: {actual}")
    else:
        for a in actual:
            if a not in args:
                raise ValueError(f"Invalid argument: {a}")
    return True
def get_object_from_module(module_name: str, object_name: str):
    """
    Finds and returns an object from a module.

    Parameters
    ----------

    module_name
        The name of the module.

    object_name
        The name of the object.
    """
    module = importlib.import_module(module_name)
    return getattr(module, object_name)
def log_time_taken(fun):
    """Decorator that logs the wall-clock time taken by each call to ``fun``."""

    @wraps(fun)
    def wrapper(*args, **kwargs):
        with time_logger(fun.__name__):
            return fun(*args, **kwargs)

    return wrapper
def negate(f):
    """
    Negates a function.

    Parameters
    ----------

    f
        Function to negate.
    """

    @wraps(f)
    def negated(*args, **kwargs):
        return not f(*args, **kwargs)

    return negated
def linspace(lower, upper, length: int):
    """
    Returns a list of equally spaced numbers between two closed bounds.

    Parameters
    ----------

    lower : number
        The lower bound.

    upper : number
        The upper bound.

    length : int
        The length of the resulting list.
    """
    return [lower + i * (upper - lower) / (length - 1) for i in range(length)]
def merge_dicts(*args, overwrite: bool):
    """
    Merges a collection of dictionaries, with later dictionaries taking
    precedence when the same key appears twice.

    Parameters
    ----------

    *args
        Dictionaries to merge.

    overwrite
        If ``True``, when the same key appears twice in multiple dictionaries,
        the key from the latter dictionary takes precedence.
        If ``False``, an error is thrown if such duplicates occur.
    """
    if not args:
        return {}
    return reduce(lambda x, y: merge_two_dicts(x, y, overwrite), args)
def merge_two_dicts(x: dict, y: dict, overwrite: bool):
    """
    Merges two dictionaries, with ``y`` taking precedence on shared keys.

    Parameters
    ----------

    x :
        First dictionary.

    y :
        Second dictionary.

    overwrite :
        If ``True``, duplicated keys resolve in favour of ``y``;
        if ``False``, a :class:`DuplicateKeyError` is raised instead.
    """
    if not overwrite:
        for key in y:
            if key in x:
                raise DuplicateKeyError(
                    f"Duplicate key {key} found in the dictionaries to be merged."
                )
    return {**x, **y}
class DuplicateKeyError(ValueError):
    """Raised when merged dictionaries share a key and overwriting is disallowed."""
def corr(x: list, y: list, method="pearson"):
    """Return the correlation (default Pearson) between two equal-length lists."""
    import pandas as pd

    frame = pd.DataFrame({"x": x, "y": y}, columns=["x", "y"])
    return float(frame.corr(method=method).at["x", "y"])


class DisableLogger:
    """Context manager that suppresses all logging for the duration of the block."""

    def __enter__(self):
        logging.disable(logging.CRITICAL)

    def __exit__(self, a, b, c):
        logging.disable(logging.NOTSET)
def query_yes_no(question, default="yes"):
    """
    Ask a yes/no question via raw_input() and return their answer.

    "question" is a string that is presented to the user.
    "default" is the presumed answer if the user just hits <Enter>.
    It must be "yes" (the default), "no" or None (meaning
    an answer is required of the user).

    The "answer" return value is True for "yes" or False for "no".
    """
    valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
    prompts = {None: " [y/n] ", "yes": " [Y/n] ", "no": " [y/N] "}
    if default not in prompts:
        raise ValueError("invalid default answer: '%s'" % default)
    prompt = prompts[default]

    while True:
        sys.stdout.write(question + prompt)
        choice = input().lower()
        if default is not None and choice == "":
            return valid[default]
        if choice in valid:
            return valid[choice]
        sys.stdout.write("Please respond with 'yes' or 'no' " "(or 'y' or 'n').\n")
def md5_object(x):
    """Return the hex MD5 digest of ``x``'s jsonpickle serialization."""
    string = jsonpickle.encode(x).encode("utf-8")
    hashed = hashlib.md5(string)
    return str(hashed.hexdigest())


# Backward-compatible alias.
hash_object = md5_object


# MD5 hashing code:
# https://stackoverflow.com/a/54477583/8454486
def md5_update_from_file(filename: Union[str, Path], hash: Hash) -> Hash:
    """Feed the contents of ``filename`` into ``hash`` in 4 KiB chunks; return ``hash``."""
    if not Path(filename).is_file():
        # Fix: the error message previously contained no placeholder, so the
        # offending path was never reported; include it for diagnosability.
        raise FileNotFoundError(f"File not found: {filename}")
    with open(str(filename), "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash.update(chunk)
    return hash


def md5_file(filename: Union[str, Path]) -> str:
    """Return the hex MD5 digest of a file's contents."""
    return str(md5_update_from_file(filename, hashlib.md5()).hexdigest())


def md5_update_from_dir(directory: Union[str, Path], hash: Hash) -> Hash:
    """
    Recursively feed a directory tree (names and file contents) into ``hash``.

    Entries are visited in case-insensitive path order so the digest is
    deterministic across filesystems.
    """
    assert Path(directory).is_dir()
    for path in sorted(Path(directory).iterdir(), key=lambda p: str(p).lower()):
        hash.update(path.name.encode())
        if path.is_file():
            hash = md5_update_from_file(path, hash)
        elif path.is_dir():
            hash = md5_update_from_dir(path, hash)
    return hash


def md5_directory(directory: Union[str, Path]) -> str:
    """Return the hex MD5 digest of an entire directory tree."""
    return str(md5_update_from_dir(directory, hashlib.md5()).hexdigest())


def format_hash(hashed, digits=32):
    """Return a URL-safe base64 rendering of ``hashed.digest()``, truncated to ``digits`` chars."""
    return base64.urlsafe_b64encode(hashed.digest())[:digits].decode("utf-8")


def import_module(name, source):
    """
    Load and execute the module at file path ``source`` under module name ``name``.

    Returns the loaded module object (previously the module was discarded;
    returning it is backward-compatible and more useful).
    """
    spec = importlib.util.spec_from_file_location(name, source)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def serialise_datetime(x):
    """ISO-format a datetime, passing ``None`` through unchanged."""
    if x is None:
        return None
    return x.isoformat()


def unserialise_datetime(x):
    """Parse an ISO-format datetime string, passing ``None`` through unchanged."""
    if x is None:
        return None
    return datetime.fromisoformat(x)


def clamp(x):
    """Clamp ``x`` into the byte range [0, 255]."""
    return max(0, min(x, 255))


def rgb_to_hex(r, g, b):
    """Convert RGB components (rounded and clamped to 0-255) to a ``#rrggbb`` string."""
    return "#{0:02x}{1:02x}{2:02x}".format(
        clamp(round(r)), clamp(round(g)), clamp(round(b))
    )
def serialise(obj):
    """Serialise objects not serialisable by default"""
    if isinstance(obj, datetime):
        # Datetimes are rendered in ISO 8601 format.
        return obj.isoformat()
    raise TypeError("Type %s is not serialisable" % type(obj))
def format_datetime(datetime):
    """Format a datetime as ``YYYY-MM-DD HH:MM:SS``."""
    return datetime.strftime("%Y-%m-%d %H:%M:%S")


def model_name_to_snake_case(model_name):
    """Convert a CamelCase model name to snake_case."""
    return re.sub(r"(?<!^)(?=[A-Z])", "_", model_name).lower()


def json_to_data_frame(json_data):
    """Build a DataFrame from a list of dicts, with columns in first-seen key order."""
    import pandas as pd

    columns = []
    for row in json_data:
        for key in row.keys():
            if key not in columns:
                columns.append(key)
    return pd.DataFrame.from_records(json_data, columns=columns)


def wait_until(
    condition, max_wait, poll_interval=0.5, error_message=None, *args, **kwargs
):
    """
    Poll ``condition(*args, **kwargs)`` until it returns truthy, checking every
    ``poll_interval`` seconds for up to ``max_wait`` seconds; returns ``True``
    on success and raises ``RuntimeError`` on timeout.
    """
    if condition(*args, **kwargs):
        return True
    elapsed = 0.0
    while elapsed <= max_wait:
        time.sleep(poll_interval)
        elapsed += poll_interval
        if condition(*args, **kwargs):
            return True
    if error_message is None:
        error_message = "Condition was not satisfied within the required time interval."
    raise RuntimeError(error_message)


def wait_while(condition, **kwargs):
    """Block until ``condition()`` becomes falsy (see :func:`wait_until`)."""
    wait_until(lambda: not condition(), **kwargs)


def strip_url_parameters(url):
    """Return ``url`` with its params, query string, and fragment removed."""
    parts = urlparse(url)
    return ParseResult(
        scheme=parts.scheme,
        netloc=parts.netloc,
        path=parts.path,
        params=None,
        query=None,
        fragment=None,
    ).geturl()


def is_valid_html5_id(str):
    """An HTML5 id must be non-empty and contain no spaces."""
    return bool(str) and " " not in str


def pretty_format_seconds(seconds):
    """Render a duration in seconds as e.g. ``"2 min 30 sec"``."""
    minutes, remainder = divmod(seconds, 60)
    remainder = round(remainder)
    formatted = f"{round(minutes)} min"
    if remainder > 0:
        formatted += f" {remainder} sec"
    return formatted


def pretty_log_dict(dict, spaces_for_indentation=0):
    """Render a dict as indented ``key: value`` lines, quoting string values."""
    rendered_lines = []
    for key, value in dict.items():
        shown = f'"{value}"' if isinstance(value, str) else value
        rendered_lines.append(
            " " * spaces_for_indentation + "{}: {}".format(key, shown)
        )
    return "\n".join(rendered_lines)
def require_exp_directory(f):
    """Decorator to verify that a command is run inside a valid PsyNet experiment directory."""
    error_one = "The current directory is not a valid PsyNet experiment."
    error_two = "There are problems with the current experiment. Please check with `dallinger verify`."

    @wraps(f)
    def wrapper(*args, **kwargs):
        # experiment_available() raises ValueError when the experiment is broken.
        try:
            valid = experiment_available()
        except ValueError:
            raise click.UsageError(error_two)
        if not valid:
            raise click.UsageError(error_one)
        ensure_config_txt_exists()
        return f(*args, **kwargs)

    return wrapper
def ensure_config_txt_exists():
    """Create an empty ``config.txt`` in the working directory if it is missing."""
    path = Path("config.txt")
    if not path.exists():
        path.touch()
def require_requirements_txt(f):
    """Decorator to verify that a command is run inside a directory which contains a requirements.txt file."""

    @wraps(f)
    def wrapper(*args, **kwargs):
        if not Path("requirements.txt").exists():
            raise click.UsageError(
                "The current directory does not contain a requirements.txt file."
            )
        return f(*args, **kwargs)

    return wrapper
def get_language():
    """
    Returns the language selected in config.txt, falling back to ``"en"``
    when the config does not specify one. (Note: despite earlier docs, no
    ``KeyError`` is raised — ``config.get`` is called with a default.)

    Returns
    -------

    A string, for example "en".
    """
    config = get_config()
    if not config.ready:
        config.load()
    return config.get("language", "en")
def _render_with_translations(
    locale, template_name=None, template_string=None, all_template_args=None
):
    """
    Render a Jinja template with gettext translations installed for ``locale``.

    Exactly one of ``template_name`` / ``template_string`` must be given.
    The config dict is always exposed to the template as ``config``.
    """
    from psynet.utils import get_config

    if all_template_args is None:
        all_template_args = {}
    all_template_args["config"] = dict(get_config().as_dict().items())

    assert [template_name, template_string].count(
        None
    ) == 1, "Only one of template_name or template_string should be provided."

    app = current_app._get_current_object()  # type: ignore[attr-defined]
    gettext, pgettext = get_translator(locale)
    # Expose the translation helpers (and url_for) to templates by name.
    gettext_functions = [gettext, pgettext, url_for]
    gettext_abbr = {_f.__name__: _f for _f in gettext_functions}
    translation = Translations.load("translations", [locale])
    environment = Environment(
        loader=app.jinja_env.loader, extensions=["jinja2.ext.i18n"], app=app
    )
    environment.install_gettext_translations(translation)
    environment.globals.update(**gettext_abbr)
    if template_name is not None:
        template = environment.get_template(template_name)
    else:
        template = environment.from_string(template_string)
    return _render(app, template, all_template_args)


def render_template_with_translations(template_name, locale=None, **kwargs):
    """Render a template file with translations; ``kwargs`` become template args."""
    return _render_with_translations(
        template_name=template_name, locale=locale, all_template_args=kwargs
    )


def render_string_with_translations(template_string, locale=None, **kwargs):
    """Render a template string with translations; ``kwargs`` become template args."""
    return _render_with_translations(
        template_string=template_string, locale=locale, all_template_args=kwargs
    )


@cache
def get_translator(
    locale=None,
    module="psynet",
    locales_dir=LOCALES_DIR,
):
    """
    Return ``(gettext, pgettext)`` translation functions for ``locale``.

    If ``locale`` is None, it is inferred from the current request's
    participant (via assignmentId/workerId/participantId query parameters),
    falling back to the configured experiment language. The .po file is
    compiled to .mo on demand or when it is newer than the existing .mo.
    """
    from psynet.internationalization import compile_mo

    if locale is None:
        try:
            GET = request.args.to_dict()
            possible_keys = ["assignmentId", "workerId", "participantId"]
            from psynet.participant import Participant

            if any([key in GET for key in possible_keys]):
                # Bug fix: the dict was probed with camelCase keys but then
                # read with snake_case keys (e.g. GET["assignment_id"]),
                # which always raised KeyError and was silently swallowed
                # below, so the participant's locale was never found.
                if "assignmentId" in GET:
                    participant = Participant.query.filter_by(
                        assignment_id=GET["assignmentId"]
                    ).one()
                elif "workerId" in GET:
                    participant = Participant.query.filter_by(
                        worker_id=int(GET["workerId"])
                    ).one()
                elif "participantId" in GET:
                    participant = Participant.query.filter_by(
                        id=GET["participantId"]
                    ).one()
                locale = participant.var.locale
        except Exception:
            # Best-effort: fall back to the configured language below.
            pass
    if locale is None:
        locale = get_language()

    mo_path = join_path(locales_dir, locale, "LC_MESSAGES", f"{module}.mo")
    po_path = join_path(locales_dir, locale, "LC_MESSAGES", f"{module}.po")

    if exists(mo_path):
        if os.path.getmtime(po_path) > os.path.getmtime(mo_path):
            logger.info(f"Compiling translation again, because {po_path} was updated.")
            compile_mo(po_path)
        translator = gettext.translation(module, locales_dir, [locale])
    elif exists(po_path):
        logger.info(f"Compiling translation file on demand {po_path}.")
        compile_mo(po_path)
        translator = gettext.translation(module, locales_dir, [locale])
    else:
        if locale != "en":
            logger.warning(f"No translation file found for locale {locale}.")
        translator = gettext.NullTranslations()
    return translator.gettext, translator.pgettext


# ISO 639-1 two-letter language codes.
ISO_639_1_CODES = [
    "ab", "aa", "af", "ak", "sq", "am", "ar", "an", "hy", "as", "av", "ae",
    "ay", "az", "bm", "ba", "eu", "be", "bn", "bh", "bi", "bs", "br", "bg",
    "my", "ca", "ch", "ce", "ny", "zh", "cv", "kw", "co", "cr", "hr", "cs",
    "da", "dv", "nl", "dz", "en", "eo", "et", "ee", "fo", "fj", "fi", "fr",
    "ff", "gl", "ka", "de", "el", "gn", "gu", "ht", "ha", "he", "hz", "hi",
    "ho", "hu", "ia", "id", "ie", "ga", "ig", "ik", "io", "is", "it", "iu",
    "ja", "jv", "kl", "kn", "kr", "ks", "kk", "km", "ki", "rw", "ky", "kv",
    "kg", "ko", "ku", "kj", "la", "lb", "lg", "li", "ln", "lo", "lt", "lu",
    "lv", "gv", "mk", "mg", "ms", "ml", "mt", "mi", "mr", "mh", "mn", "na",
    "nv", "nd", "ne", "ng", "nb", "nn", "no", "ii", "nr", "oc", "oj", "cu",
    "om", "or", "os", "pa", "pi", "fa", "pl", "ps", "pt", "qu", "rm", "rn",
    "ro", "ru", "sa", "sc", "sd", "se", "sh", "sm", "sg", "sr", "gd", "sn",
    "si", "sk", "sl", "so", "st", "es", "su", "sw", "ss", "sv", "ta", "te",
    "tg", "th", "ti", "bo", "tk", "tl", "tn", "to", "tr", "ts", "tt", "tw",
    "ty", "ug", "uk", "ur", "uz", "ve", "vi", "vo", "wa", "cy", "wo", "fy",
    "xh", "yi", "yo", "za",
]


def get_available_locales(locales_dir=LOCALES_DIR):
    """Return the locale codes for which a subdirectory exists in ``locales_dir``."""
    return [
        f for f in os.listdir(locales_dir) if os.path.isdir(join_path(locales_dir, f))
    ]
def countries(locale=None):
    """
    List compiled using the pycountry package v20.7.3 with ::

        sorted([(lang.alpha_2, lang.name) for lang in pycountry.countries if hasattr(lang, 'alpha_2')], key=lambda country: country[1])
    """
    # Each entry pairs an ISO 3166-1 alpha-2 code with a country name wrapped in
    # pgettext under the "country_name" context so it can be translated.
    _, _p = get_translator(locale)
    return [
        ("AF", _p("country_name", "Afghanistan")),
        ("AL", _p("country_name", "Albania")),
        ("DZ", _p("country_name", "Algeria")),
        ("AS", _p("country_name", "American Samoa")),
        ("AD", _p("country_name", "Andorra")),
        ("AO", _p("country_name", "Angola")),
        ("AI", _p("country_name", "Anguilla")),
        ("AQ", _p("country_name", "Antarctica")),
        ("AG", _p("country_name", "Antigua and Barbuda")),
        ("AR", _p("country_name", "Argentina")),
        ("AM", _p("country_name", "Armenia")),
        ("AW", _p("country_name", "Aruba")),
        ("AU", _p("country_name", "Australia")),
        ("AT", _p("country_name", "Austria")),
        ("AZ", _p("country_name", "Azerbaijan")),
        ("BS", _p("country_name", "Bahamas")),
        ("BH", _p("country_name", "Bahrain")),
        ("BD", _p("country_name", "Bangladesh")),
        ("BB", _p("country_name", "Barbados")),
        ("BY", _p("country_name", "Belarus")),
        ("BE", _p("country_name", "Belgium")),
        ("BZ", _p("country_name", "Belize")),
        ("BJ", _p("country_name", "Benin")),
        ("BM", _p("country_name", "Bermuda")),
        ("BT", _p("country_name", "Bhutan")),
        ("BO", _p("country_name", "Bolivia")),
        ("BQ", _p("country_name", "Bonaire, Sint Eustatius and Saba")),
        ("BA", _p("country_name", "Bosnia and Herzegovina")),
        ("BW", _p("country_name", "Botswana")),
        ("BV", _p("country_name", "Bouvet Island")),
        ("BR", _p("country_name", "Brazil")),
        ("IO", _p("country_name", "British Indian Ocean Territory")),
        ("BN", _p("country_name", "Brunei Darussalam")),
        ("BG", _p("country_name", "Bulgaria")),
        ("BF", _p("country_name", "Burkina Faso")),
        ("BI", _p("country_name", "Burundi")),
        ("CV", _p("country_name", "Cabo Verde")),
        ("KH", _p("country_name", "Cambodia")),
        ("CM", _p("country_name", "Cameroon")),
        ("CA", _p("country_name", "Canada")),
        ("KY", _p("country_name", "Cayman Islands")),
        ("CF", _p("country_name", "Central African Republic")),
        ("TD", _p("country_name", "Chad")),
        ("CL", _p("country_name", "Chile")),
        ("CN", _p("country_name", "China")),
        ("CX", _p("country_name", "Christmas Island")),
        ("CC", _p("country_name", "Cocos Islands")),
        ("CO", _p("country_name", "Colombia")),
        ("KM", _p("country_name", "Comoros")),
        ("CG", _p("country_name", "Congo")),
        ("CD", _p("country_name", "Congo (Democratic Republic)")),
        ("CK", _p("country_name", "Cook Islands")),
        ("CR", _p("country_name", "Costa Rica")),
        ("HR", _p("country_name", "Croatia")),
        ("CU", _p("country_name", "Cuba")),
        ("CW", _p("country_name", "Curaçao")),
        ("CY", _p("country_name", "Cyprus")),
        ("CZ", _p("country_name", "Czechia")),
        ("CI", _p("country_name", "Côte d'Ivoire")),
        ("DK", _p("country_name", "Denmark")),
        ("DJ", _p("country_name", "Djibouti")),
        ("DM", _p("country_name", "Dominica")),
        ("DO", _p("country_name", "Dominican Republic")),
        ("EC", _p("country_name", "Ecuador")),
        ("EG", _p("country_name", "Egypt")),
        ("SV", _p("country_name", "El Salvador")),
        ("GQ", _p("country_name", "Equatorial Guinea")),
        ("ER", _p("country_name", "Eritrea")),
        ("EE", _p("country_name", "Estonia")),
        ("SZ", _p("country_name", "Eswatini")),
        ("ET", _p("country_name", "Ethiopia")),
        ("FK", _p("country_name", "Falkland Islands (Malvinas)")),
        ("FO", _p("country_name", "Faroe Islands")),
        ("FJ", _p("country_name", "Fiji")),
        ("FI", _p("country_name", "Finland")),
        ("FR", _p("country_name", "France")),
        ("GF", _p("country_name", "French Guiana")),
        ("PF", _p("country_name", "French Polynesia")),
        ("TF", _p("country_name", "French Southern Territories")),
        ("GA", _p("country_name", "Gabon")),
        ("GM", _p("country_name", "Gambia")),
        ("GE", _p("country_name", "Georgia")),
        ("DE", _p("country_name", "Germany")),
        ("GH", _p("country_name", "Ghana")),
        ("GI", _p("country_name", "Gibraltar")),
        ("GR", _p("country_name", "Greece")),
        ("GL", _p("country_name", "Greenland")),
        ("GD", _p("country_name", "Grenada")),
        ("GP", _p("country_name", "Guadeloupe")),
        ("GU", _p("country_name", "Guam")),
        ("GT", _p("country_name", "Guatemala")),
        ("GG", _p("country_name", "Guernsey")),
        ("GN", _p("country_name", "Guinea")),
        ("GW", _p("country_name", "Guinea-Bissau")),
        ("GY", _p("country_name", "Guyana")),
        ("HT", _p("country_name", "Haiti")),
        ("HM", _p("country_name", "Heard Island and McDonald Islands")),
        ("VA", _p("country_name", "Vatican City State")),
        ("HN", _p("country_name", "Honduras")),
        ("HK", _p("country_name", "Hong Kong")),
        ("HU", _p("country_name", "Hungary")),
        ("IS", _p("country_name", "Iceland")),
        ("IN", _p("country_name", "India")),
        ("ID", _p("country_name", "Indonesia")),
        ("IR", _p("country_name", "Iran")),
        ("IQ", _p("country_name", "Iraq")),
        ("IE", _p("country_name", "Ireland")),
        ("IM", _p("country_name", "Isle of Man")),
        ("IL", _p("country_name", "Israel")),
        ("IT", _p("country_name", "Italy")),
        ("JM", _p("country_name", "Jamaica")),
        ("JP", _p("country_name", "Japan")),
        ("JE", _p("country_name", "Jersey")),
        ("JO", _p("country_name", "Jordan")),
        ("KZ", _p("country_name", "Kazakhstan")),
        ("KE", _p("country_name", "Kenya")),
        ("KI", _p("country_name", "Kiribati")),
        ("KP", _p("country_name", "North Korea")),
        ("KR", _p("country_name", "South Korea")),
        ("KW", _p("country_name", "Kuwait")),
        ("KG", _p("country_name", "Kyrgyzstan")),
        ("LA", _p("country_name", "Lao")),
        ("LV", _p("country_name", "Latvia")),
        ("LB", _p("country_name", "Lebanon")),
        ("LS", _p("country_name", "Lesotho")),
        ("LR", _p("country_name", "Liberia")),
        ("LY", _p("country_name", "Libya")),
        ("LI", _p("country_name", "Liechtenstein")),
        ("LT", _p("country_name", "Lithuania")),
        ("LU", _p("country_name", "Luxembourg")),
        ("MO", _p("country_name", "Macao")),
        ("MG", _p("country_name", "Madagascar")),
        ("MW", _p("country_name", "Malawi")),
        ("MY", _p("country_name", "Malaysia")),
        ("MV", _p("country_name", "Maldives")),
        ("ML", _p("country_name", "Mali")),
        ("MT", _p("country_name", "Malta")),
        ("MH", _p("country_name", "Marshall Islands")),
        ("MQ", _p("country_name", "Martinique")),
        ("MR", _p("country_name", "Mauritania")),
        ("MU", _p("country_name", "Mauritius")),
        ("YT", _p("country_name", "Mayotte")),
        ("MX", _p("country_name", "Mexico")),
        ("FM", _p("country_name", "Micronesia")),
        ("MD", _p("country_name", "Moldova")),
        ("MC", _p("country_name", "Monaco")),
        ("MN", _p("country_name", "Mongolia")),
        ("ME", _p("country_name", "Montenegro")),
        ("MS", _p("country_name", "Montserrat")),
        ("MA", _p("country_name", "Morocco")),
        ("MZ", _p("country_name", "Mozambique")),
        ("MM", _p("country_name", "Myanmar")),
        ("NA", _p("country_name", "Namibia")),
        ("NR", _p("country_name", "Nauru")),
        ("NP", _p("country_name", "Nepal")),
        ("NL", _p("country_name", "Netherlands")),
        ("NC", _p("country_name", "New Caledonia")),
        ("NZ", _p("country_name", "New Zealand")),
        ("NI", _p("country_name", "Nicaragua")),
        ("NE", _p("country_name", "Niger")),
        ("NG", _p("country_name", "Nigeria")),
        ("NU", _p("country_name", "Niue")),
        ("NF", _p("country_name", "Norfolk Island")),
        ("MK", _p("country_name", "North Macedonia")),
        ("MP", _p("country_name", "Northern Mariana Islands")),
        ("NO", _p("country_name", "Norway")),
        ("OM", _p("country_name", "Oman")),
        ("PK", _p("country_name", "Pakistan")),
        ("PW", _p("country_name", "Palau")),
        ("PS", _p("country_name", "Palestine")),
        ("PA", _p("country_name", "Panama")),
        ("PG", _p("country_name", "Papua New Guinea")),
        ("PY", _p("country_name", "Paraguay")),
        ("PE", _p("country_name", "Peru")),
        ("PH", _p("country_name", "Philippines")),
        ("PN", _p("country_name", "Pitcairn")),
        ("PL", _p("country_name", "Poland")),
        ("PT", _p("country_name", "Portugal")),
        ("PR", _p("country_name", "Puerto Rico")),
        ("QA", _p("country_name", "Qatar")),
        ("RO", _p("country_name", "Romania")),
        ("RU", _p("country_name", "Russian Federation")),
        ("RW", _p("country_name", "Rwanda")),
        ("RE", _p("country_name", "Réunion")),
        ("BL", _p("country_name", "Saint Barthélemy")),
        ("SH", _p("country_name", "Saint Helena, Ascension and Tristan da Cunha")),
        ("KN", _p("country_name", "Saint Kitts and Nevis")),
        ("LC", _p("country_name", "Saint Lucia")),
        ("PM", _p("country_name", "Saint Pierre and Miquelon")),
        ("VC", _p("country_name", "Saint Vincent and the Grenadines")),
        ("WS", _p("country_name", "Samoa")),
        ("SM", _p("country_name", "San Marino")),
        ("ST", _p("country_name", "Sao Tome and Principe")),
        ("SA", _p("country_name", "Saudi Arabia")),
        ("SN", _p("country_name", "Senegal")),
        ("RS", _p("country_name", "Serbia")),
        ("SC", _p("country_name", "Seychelles")),
        ("SL", _p("country_name", "Sierra Leone")),
        ("SG", _p("country_name", "Singapore")),
        ("SX", _p("country_name", "Sint Maarten")),
        ("SK", _p("country_name", "Slovakia")),
        ("SI", _p("country_name", "Slovenia")),
        ("SB", _p("country_name", "Solomon Islands")),
        ("SO", _p("country_name", "Somalia")),
        ("ZA", _p("country_name", "South Africa")),
        ("GS", _p("country_name", "South Georgia and the South Sandwich Islands")),
        ("SS", _p("country_name", "South Sudan")),
        ("ES", _p("country_name", "Spain")),
        ("LK", _p("country_name", "Sri Lanka")),
        ("SD", _p("country_name", "Sudan")),
        ("SR", _p("country_name", "Suriname")),
        ("SJ", _p("country_name", "Svalbard and Jan Mayen")),
        ("SE", _p("country_name", "Sweden")),
        ("CH", _p("country_name", "Switzerland")),
        ("SY", _p("country_name", "Syria")),
        ("TW", _p("country_name", "Taiwan")),
        ("TJ", _p("country_name", "Tajikistan")),
        ("TZ", _p("country_name", "Tanzania")),
        ("TH", _p("country_name", "Thailand")),
        ("TL", _p("country_name", "Timor-Leste")),
        ("TG", _p("country_name", "Togo")),
        ("TK", _p("country_name", "Tokelau")),
        ("TO", _p("country_name", "Tonga")),
        ("TT", _p("country_name", "Trinidad and Tobago")),
        ("TN", _p("country_name", "Tunisia")),
        ("TR", _p("country_name", "Turkey")),
        ("TM", _p("country_name", "Turkmenistan")),
        ("TC", _p("country_name", "Turks and Caicos Islands")),
        ("TV", _p("country_name", "Tuvalu")),
        ("UG", _p("country_name", "Uganda")),
        ("UA", _p("country_name", "Ukraine")),
        ("AE", _p("country_name", "United Arab Emirates")),
        ("GB", _p("country_name", "United Kingdom")),
        ("US", _p("country_name", "United States")),
        ("UM", _p("country_name", "United States Minor Outlying Islands")),
        ("UY", _p("country_name", "Uruguay")),
        ("UZ", _p("country_name", "Uzbekistan")),
        ("VU", _p("country_name", "Vanuatu")),
        ("VE", _p("country_name", "Venezuela")),
        ("VN", _p("country_name", "Vietnam")),
        ("VG", _p("country_name", "Virgin Islands (British)")),
        ("VI", _p("country_name", "Virgin Islands (U.S.)")),
        ("WF", _p("country_name", "Wallis and Futuna")),
        ("EH", _p("country_name", "Western Sahara")),
        ("YE", _p("country_name", "Yemen")),
        ("ZM", _p("country_name", "Zambia")),
        ("ZW", _p("country_name", "Zimbabwe")),
        ("AX", _p("country_name", "Åland Islands")),
    ]
def languages(locale=None):
    """
    List compiled using the pycountry package v20.7.3 with ::

        sorted([(lang.alpha_2, lang.name) for lang in pycountry.languages if hasattr(lang, 'alpha_2')], key=lambda country: country[1])
    """
    # Each entry pairs a language code (mostly ISO 639-1, plus a few extras
    # such as "zh-cn" or "ceb") with a name wrapped in pgettext under the
    # "language_name" context so it can be translated.
    _, _p = get_translator(locale)
    return [
        ("ab", _p("language_name", "Abkhazian")),
        ("aa", _p("language_name", "Afar")),
        ("af", _p("language_name", "Afrikaans")),
        ("ak", _p("language_name", "Akan")),
        ("sq", _p("language_name", "Albanian")),
        ("am", _p("language_name", "Amharic")),
        ("ar", _p("language_name", "Arabic")),
        ("an", _p("language_name", "Aragonese")),
        ("hy", _p("language_name", "Armenian")),
        ("as", _p("language_name", "Assamese")),
        ("av", _p("language_name", "Avaric")),
        ("ae", _p("language_name", "Avestan")),
        ("ay", _p("language_name", "Aymara")),
        ("az", _p("language_name", "Azerbaijani")),
        ("bm", _p("language_name", "Bambara")),
        ("ba", _p("language_name", "Bashkir")),
        ("eu", _p("language_name", "Basque")),
        ("be", _p("language_name", "Belarusian")),
        ("bn", _p("language_name", "Bengali")),
        ("bi", _p("language_name", "Bislama")),
        ("bs", _p("language_name", "Bosnian")),
        ("br", _p("language_name", "Breton")),
        ("bg", _p("language_name", "Bulgarian")),
        ("my", _p("language_name", "Burmese")),
        ("ca", _p("language_name", "Catalan")),
        ("km", _p("language_name", "Central Khmer")),
        ("ch", _p("language_name", "Chamorro")),
        ("ce", _p("language_name", "Chechen")),
        ("zh", _p("language_name", "Chinese")),
        ("zh-cn", _p("language_name", "Chinese")),
        ("cu", _p("language_name", "Church Slavic")),
        ("cv", _p("language_name", "Chuvash")),
        ("kw", _p("language_name", "Cornish")),
        ("co", _p("language_name", "Corsican")),
        ("cr", _p("language_name", "Cree")),
        ("hr", _p("language_name", "Croatian")),
        ("ceb", _p("language_name", "Cebuano")),
        ("cs", _p("language_name", "Czech")),
        ("da", _p("language_name", "Danish")),
        ("dv", _p("language_name", "Dhivehi")),
        ("nl", _p("language_name", "Dutch")),
        ("dz", _p("language_name", "Dzongkha")),
        ("en", _p("language_name", "English")),
        ("eo", _p("language_name", "Esperanto")),
        ("et", _p("language_name", "Estonian")),
        ("ee", _p("language_name", "Ewe")),
        ("fo", _p("language_name", "Faroese")),
        ("fj", _p("language_name", "Fijian")),
        ("fi", _p("language_name", "Finnish")),
        ("fr", _p("language_name", "French")),
        ("ff", _p("language_name", "Fulah")),
        ("gl", _p("language_name", "Galician")),
        ("lg", _p("language_name", "Ganda")),
        ("ka", _p("language_name", "Georgian")),
        ("de", _p("language_name", "German")),
        ("got", _p("language_name", "Gothic")),
        ("gn", _p("language_name", "Guarani")),
        ("gu", _p("language_name", "Gujarati")),
        ("ht", _p("language_name", "Haitian")),
        ("ha", _p("language_name", "Hausa")),
        ("haw", _p("language_name", "Hawaiian")),
        ("he", _p("language_name", "Hebrew")),
        ("hz", _p("language_name", "Herero")),
        ("hi", _p("language_name", "Hindi")),
        ("ho", _p("language_name", "Hiri Motu")),
        ("hmn", _p("language_name", "Hmong")),
        ("hu", _p("language_name", "Hungarian")),
        ("is", _p("language_name", "Icelandic")),
        ("io", _p("language_name", "Ido")),
        ("ig", _p("language_name", "Igbo")),
        ("id", _p("language_name", "Indonesian")),
        ("ia", _p("language_name", "Interlingua")),
        ("ie", _p("language_name", "Interlingue")),
        ("iu", _p("language_name", "Inuktitut")),
        ("ik", _p("language_name", "Inupiaq")),
        ("ga", _p("language_name", "Irish")),
        ("it", _p("language_name", "Italian")),
        ("ja", _p("language_name", "Japanese")),
        ("jv", _p("language_name", "Javanese")),
        ("jw", _p("language_name", "Javanese")),
        ("kl", _p("language_name", "Kalaallisut")),
        ("kn", _p("language_name", "Kannada")),
        ("kr", _p("language_name", "Kanuri")),
        ("ks", _p("language_name", "Kashmiri")),
        ("kk", _p("language_name", "Kazakh")),
        ("ki", _p("language_name", "Kikuyu")),
        ("rw", _p("language_name", "Kinyarwanda")),
        ("ky", _p("language_name", "Kirghiz")),
        ("kv", _p("language_name", "Komi")),
        ("kg", _p("language_name", "Kongo")),
        ("ko", _p("language_name", "Korean")),
        ("kj", _p("language_name", "Kuanyama")),
        ("ku", _p("language_name", "Kurdish")),
        ("lo", _p("language_name", "Lao")),
        ("la", _p("language_name", "Latin")),
        ("lv", _p("language_name", "Latvian")),
        ("li", _p("language_name", "Limburgan")),
        ("ln", _p("language_name", "Lingala")),
        ("lt", _p("language_name", "Lithuanian")),
        ("lu", _p("language_name", "Luba-Katanga")),
        ("lb", _p("language_name", "Luxembourgish")),
        ("mk", _p("language_name", "Macedonian")),
        ("mg", _p("language_name", "Malagasy")),
        ("ms", _p("language_name", "Malay")),
        ("ml", _p("language_name", "Malayalam")),
        ("mt", _p("language_name", "Maltese")),
        ("gv", _p("language_name", "Manx")),
        ("mi", _p("language_name", "Maori")),
        ("mr", _p("language_name", "Marathi")),
        ("mh", _p("language_name", "Marshallese")),
        ("el", _p("language_name", "Greek")),
        ("mn", _p("language_name", "Mongolian")),
        ("na", _p("language_name", "Nauru")),
        ("nv", _p("language_name", "Navajo")),
        ("ng", _p("language_name", "Ndonga")),
        ("ne", _p("language_name", "Nepali")),
        ("nd", _p("language_name", "North Ndebele")),
        ("se", _p("language_name", "Northern Sami")),
        ("no", _p("language_name", "Norwegian")),
        ("nb", _p("language_name", "Norwegian Bokmål")),
        ("nn", _p("language_name", "Norwegian Nynorsk")),
        ("ny", _p("language_name", "Nyanja")),
        ("oc", _p("language_name", "Occitan")),
        ("oj", _p("language_name", "Ojibwa")),
        ("or", _p("language_name", "Oriya")),
        ("om", _p("language_name", "Oromo")),
        ("os", _p("language_name", "Ossetian")),
        ("pi", _p("language_name", "Pali")),
        ("pa", _p("language_name", "Panjabi")),
        ("fa", _p("language_name", "Persian")),
        ("pl", _p("language_name", "Polish")),
        ("pt", _p("language_name", "Portuguese")),
        ("ps", _p("language_name", "Pushto")),
        ("qu", _p("language_name", "Quechua")),
        ("ro", _p("language_name", "Romanian")),
        ("rm", _p("language_name", "Romansh")),
        ("rn", _p("language_name", "Rundi")),
        ("ru", _p("language_name", "Russian")),
        ("sm", _p("language_name", "Samoan")),
        ("sg", _p("language_name", "Sango")),
        ("sa", _p("language_name", "Sanskrit")),
        ("sc", _p("language_name", "Sardinian")),
        ("gd", _p("language_name", "Scottish Gaelic")),
        ("sr", _p("language_name", "Serbian")),
        ("sh", _p("language_name", "Serbo-Croatian")),
        ("sn", _p("language_name", "Shona")),
        ("ii", _p("language_name", "Sichuan Yi")),
        ("sd", _p("language_name", "Sindhi")),
        ("si", _p("language_name", "Sinhala")),
        ("sk", _p("language_name", "Slovak")),
        ("sl", _p("language_name", "Slovenian")),
        ("so", _p("language_name", "Somali")),
        ("nr", _p("language_name", "South Ndebele")),
        ("st", _p("language_name", "Southern Sotho")),
        ("es", _p("language_name", "Spanish")),
        ("su", _p("language_name", "Sundanese")),
        ("sw", _p("language_name", "Swahili")),
        ("ss", _p("language_name", "Swati")),
        ("sv", _p("language_name", "Swedish")),
        ("zh-tw", _p("language_name", "Taiwanese")),
        ("tl", _p("language_name", "Tagalog")),
        ("ty", _p("language_name", "Tahitian")),
        ("tg", _p("language_name", "Tajik")),
        ("ta", _p("language_name", "Tamil")),
        ("tt", _p("language_name", "Tatar")),
        ("te", _p("language_name", "Telugu")),
        ("th", _p("language_name", "Thai")),
        ("bo", _p("language_name", "Tibetan")),
        ("ti", _p("language_name", "Tigrinya")),
        ("to", _p("language_name", "Tonga")),
        ("ts", _p("language_name", "Tsonga")),
        ("tn", _p("language_name", "Tswana")),
        ("tr", _p("language_name", "Turkish")),
        ("tk", _p("language_name", "Turkmen")),
        ("tw", _p("language_name", "Twi")),
        ("ug", _p("language_name", "Uighur")),
        ("uk", _p("language_name", "Ukrainian")),
        ("ur", _p("language_name", "Urdu")),
        ("uz", _p("language_name", "Uzbek")),
        ("ve", _p("language_name", "Venda")),
        ("vi", _p("language_name", "Vietnamese")),
        ("vo", _p("language_name", "Volapük")),
        ("wa", _p("language_name", "Walloon")),
        ("cy", _p("language_name", "Welsh")),
        ("hyw", _p("language_name", "Western Armenian")),
        ("fy", _p("language_name", "Western Frisian")),
        ("wo", _p("language_name", "Wolof")),
        ("xh", _p("language_name", "Xhosa")),
        ("yi", _p("language_name", "Yiddish")),
        ("yo", _p("language_name", "Yoruba")),
        ("za", _p("language_name", "Zhuang")),
        ("zu", _p("language_name", "Zulu")),
    ]
def _get_entity_dict_from_tuple_list(tuple_list, sort_by_value): dictionary = dict( zip([key for key, value in tuple_list], [value for key, value in tuple_list]) ) if sort_by_value: return dict(OrderedDict(sorted(dictionary.items(), key=lambda t: t[1]))) else: return dictionary def get_language_dict(locale, sort_by_name=True): return _get_entity_dict_from_tuple_list(languages(locale), sort_by_name) def get_country_dict(locale, sort_by_name=True): return _get_entity_dict_from_tuple_list(countries(locale), sort_by_name) def sample_from_surface_of_unit_sphere(n_dimensions): import numpy as np res = np.random.randn(n_dimensions, 1) res /= np.linalg.norm(res, axis=0) return res[:, 0].tolist() class ClassPropertyDescriptor(object): def __init__(self, fget, fset=None): self.fget = fget self.fset = fset def __get__(self, obj, cls=None): if cls is None: cls = type(obj) return self.fget.__get__(obj, cls)() def __set__(self, obj, value): if not self.fset: raise AttributeError("can't set attribute") type_ = type(obj) return self.fset.__get__(obj, type_)(value) def setter(self, func): if not isinstance(func, (classmethod, staticmethod)): func = classmethod(func) self.fset = func return self
def classproperty(func):
    """
    Defines an analogous version of @property but for classes, after
    https://stackoverflow.com/questions/5189699/how-to-make-a-class-property.
    """
    wrapped = func if isinstance(func, (classmethod, staticmethod)) else classmethod(func)
    return ClassPropertyDescriptor(wrapped)
def run_subprocess_with_live_output(command, timeout=None, cwd=None):
    """
    Run ``command`` in a bash shell, echoing its output to stdout as it
    arrives. If the command exits with a non-zero status, the current
    process exits with that same status.

    Parameters
    ----------
    command : Shell command string to execute.
    timeout : Optional timeout (seconds) passed to ``pexpect.spawn``.
    cwd : Optional working directory for the child process.
    """
    escaped = command.replace('"', '\\"').replace("'", "\\'")
    process = pexpect.spawn(f'bash -c "{escaped}"', timeout=timeout, cwd=cwd)
    while not process.eof():
        line = process.readline().decode("utf-8")
        print(line, end="")
    process.close()
    # Fix: pexpect sets exitstatus to None when the child was terminated by a
    # signal; the previous `exitstatus > 0` comparison then raised TypeError.
    if process.exitstatus is not None and process.exitstatus > 0:
        sys.exit(process.exitstatus)


def get_extension(path):
    """
    Return the file extension of ``path`` (including the leading dot),
    or "" when ``path`` is empty/None or has no extension.
    """
    if not path:
        return ""
    _, extension = os.path.splitext(path)
    return extension


# Backported from Python 3.9
def cache(user_function, /):
    'Simple lightweight unbounded cache. Sometimes called "memoize".'
    # Equivalent to functools.cache on Python >= 3.9.
    memoize = lru_cache(maxsize=None)
    return memoize(user_function)
def organize_by_key(lst, key, sort_key=None):
    """
    Sorts a list of items into groups.

    Parameters
    ----------

    lst :
        List to sort.

    key :
        Function applied to elements of ``lst`` which defines the grouping key.

    sort_key :
        Optional function used to sort the items within each group in place.

    Returns
    -------

    A dictionary keyed by the outputs of ``key``; groups appear in order of
    first occurrence.
    """
    out = {}
    for obj in lst:
        out.setdefault(key(obj), []).append(obj)
    if sort_key:
        for group in out.values():
            group.sort(key=sort_key)
    return out
@contextlib.contextmanager
def working_directory(path):
    """
    Context manager that temporarily switches the process's current working
    directory to ``path``, restoring the previous directory on exit (even if
    the body raises).
    """
    previous = os.getcwd()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(previous)
def get_custom_sql_classes():
    """
    Returns
    -------

    A dictionary of all custom SQLAlchemy classes defined in the local
    experiment (excluding any which are defined within packages).
    """

    def _collect():
        # Raises KeyError if the experiment module hasn't been imported yet.
        experiment_module = sys.modules["dallinger_experiment"]
        found = {}
        for _, submodule in inspect.getmembers(experiment_module):
            for _, candidate in inspect.getmembers(submodule):
                if (
                    inspect.isclass(candidate)
                    and candidate.__module__.startswith("dallinger_experiment")
                    and hasattr(candidate, "_sa_registry")
                ):
                    found[candidate.__name__] = candidate
        return found

    try:
        return _collect()
    except KeyError:
        # Experiment not yet loaded: import it, then retry.
        from psynet.experiment import import_local_experiment

        import_local_experiment()
        return _collect()
def make_parents(path):
    """
    Creates the parent directories for a specified file if they don't exist
    already.

    Returns
    -------

    The original path.
    """
    parent_dir = os.path.dirname(path)
    Path(parent_dir).mkdir(parents=True, exist_ok=True)
    return path
def bytes_to_megabytes(bytes):
    """Convert a byte count to megabytes (1 MB = 1024 * 1024 bytes)."""
    return bytes / (1024 * 1024)


def get_file_size_mb(path):
    """Return the size in megabytes of the file at ``path``."""
    bytes = os.path.getsize(path)
    return bytes_to_megabytes(bytes)


def get_folder_size_mb(path):
    """
    Return the combined size in megabytes of the immediate entries of ``path``.

    NOTE(review): ``os.scandir`` is not recursive, so the contents of nested
    subdirectories are not included — confirm this is intended before relying
    on it for deep folder trees.
    """
    bytes = sum(entry.stat().st_size for entry in os.scandir(path))
    return bytes_to_megabytes(bytes)


# def run_async_command_locally(fun, *args, **kwargs):
#     """
#     This is for when want to run a command asynchronously (so that it doesn't block current execution)
#     but locally (so that we know we have access to local files).
#     """
#
#     def wrapper():
#         f = io.StringIO()
#         with contextlib.redirect_stdout(f):
#             try:
#                 fun(*args, **kwargs)
#             except Exception:
#                 print(traceback.format_exc())
#         log_to_redis(f.getvalue())
#
#     import threading
#
#     thr = threading.Thread(target=wrapper)
#     thr.start()


# def log_to_redis(msg):
#     """
#     This passes the message to the Redis queue to be printed by the worker that picks it up.
#     This is useful for logging from processes that don't have access to the main logger.
#     """
#     q = Queue("default", connection=redis_conn)
#     q.enqueue_call(
#         func=logger.info, args=(), kwargs=dict(msg=msg), timeout=1e10, at_front=True
#     )


@contextlib.contextmanager
def disable_logger():
    """
    Context manager that disables all logging while its body runs.

    Fix: re-enabling now happens in a ``finally`` clause, so logging is
    restored even if the body raises (previously an exception left logging
    disabled process-wide).
    """
    logging.disable(sys.maxsize)
    try:
        yield
    finally:
        logging.disable(logging.NOTSET)


def clear_all_caches():
    """Clear every live ``functools.lru_cache`` wrapper in the process."""
    import functools
    import gc

    for obj in gc.get_objects():
        try:
            if isinstance(obj, functools._lru_cache_wrapper):
                obj.cache_clear()
        except ReferenceError:
            # Object died while we were inspecting it; skip it.
            pass


@contextlib.contextmanager
def log_pexpect_errors(process):
    """
    On a pexpect EOF/TIMEOUT error, print the process's buffered output
    (``process.before``) before re-raising, to aid debugging.
    """
    try:
        yield
    except (pexpect.EOF, pexpect.TIMEOUT) as err:
        print(f"A {err} error occurred. Printing process logs:")
        print(process.before)
        raise


# This seemed like a good idea for preventing cases where people use random functions
# in code blocks, page makers, etc. In practice however it didn't work, because
# some library functions tamper with the random state in a hidden way,
# making the check have too many false positives.
# # @contextlib.contextmanager # def disallow_random_functions(func_name, func=None): # random_state = random.getstate # numpy_random_state = numpy.random.get_state() # # yield # # if ( # random.getstate() != random_state # or numpy.random.get_state() != numpy_random_state # ): # message = ( # "It looks like you used Python's random number generator within " # f"your {func_name} code. This is disallowed because it allows your " # "experiment to get into inconsistent states. Instead you should generate " # "call any random number generators within code blocks, for_loop() constructs, " # "Trial.make_definition methods, or similar." # ) # if func: # message += "\n" # message += "Offending code:\n" # message += inspect.getsource(func) # # raise RuntimeError(message)
def is_method_overridden(obj, ancestor: Type, method: str):
    """
    Test whether a method has been overridden.

    Parameters
    ----------

    obj :
        Object to test.

    ancestor :
        Ancestor class to test against.

    method :
        Method name.

    Returns
    -------

    Returns ``True`` if the method has been overridden somewhere between
    ``ancestor`` and ``obj``'s class, or ``False`` if the object still shares
    the method with its ancestor.
    (Fix: the previous docstring had these two cases inverted relative to the
    implementation.)
    """
    return getattr(obj.__class__, method) != getattr(ancestor, method)
@contextlib.contextmanager
def time_logger(label, threshold=0.01):
    """
    Context manager that times its body.

    Yields a dict with keys ``time_started``, ``time_finished`` and
    ``time_taken`` (the latter two filled in on normal exit). If the body
    takes longer than ``threshold`` seconds, the duration is logged.

    NOTE(review): if the body raises, the dict is left incomplete and nothing
    is logged — presumably intentional; confirm before changing.
    """
    log = {
        "time_started": time.monotonic(),
        "time_finished": None,
        "time_taken": None,
    }
    yield log
    log["time_finished"] = time.monotonic()
    log["time_taken"] = log["time_finished"] - log["time_started"]
    if log["time_taken"] > threshold:
        logger.info(
            "Task '%s' took %.3f s",
            label,
            log["time_taken"],
        )


@contextlib.contextmanager
def log_level(logger: logging.Logger, level):
    """
    Context manager that temporarily sets ``logger``'s level to ``level``.

    Fix: the original level is now restored in a ``finally`` clause, so an
    exception raised inside the block no longer leaves the logger stuck at
    the temporary level.
    """
    original_level = logger.level
    logger.setLevel(level)
    try:
        yield
    finally:
        logger.setLevel(original_level)


def get_psynet_root():
    """Return the root directory of the PsyNet checkout (parent of the package dir)."""
    import psynet

    return Path(psynet.__file__).parent.parent


def list_experiment_dirs(for_ci_tests=False, ci_node_total=None, ci_node_index=None):
    """
    List all demo and test experiment directories (those containing an
    ``experiment.py``), sorted alphabetically.

    Parameters
    ----------

    for_ci_tests :
        If ``True``, skip directories that cannot run in CI.

    ci_node_total, ci_node_index :
        If both provided, return only the subset of directories assigned to
        this (1-indexed) CI node.
    """
    demo_root = get_psynet_root() / "demos"
    test_experiments_root = get_psynet_root() / "tests/experiments"
    dirs = sorted(
        [
            dir_
            for root in [demo_root, test_experiments_root]
            for dir_, sub_dirs, files in os.walk(root)
            if (
                "experiment.py" in files
                and not dir_.endswith("/develop")
                and (
                    not for_ci_tests
                    or not (
                        # Skip the recruiter demos because they're not meaningful to run here
                        "recruiters" in dir_
                        # Skip the gibbs_video demo because it relies on ffmpeg which is not installed
                        # in the CI environment
                        or dir_.endswith("/gibbs_video")
                    )
                )
            )
        ]
    )
    if ci_node_total is not None and ci_node_index is not None:
        dirs = with_parallel_ci(dirs, ci_node_total, ci_node_index)
    return dirs


def with_parallel_ci(paths, ci_node_total, ci_node_index):
    """
    Return the subset of ``paths`` assigned to CI node ``ci_node_index``
    (1-indexed), distributing items round-robin across ``ci_node_total`` nodes.
    """
    index = ci_node_index - 1  # 1-indexed to 0-indexed
    assert 0 <= index < ci_node_total
    return [paths[i] for i in range(len(paths)) if i % ci_node_total == index]


def list_isolated_tests(ci_node_total=None, ci_node_index=None):
    """
    List all isolated test files (``tests/isolated`` and its demos/
    experiments/features subdirectories), optionally restricted to the subset
    for one CI node.
    """
    isolated_tests_root = get_psynet_root() / "tests" / "isolated"
    isolated_tests_demos = isolated_tests_root / "demos"
    isolated_tests_experiments = isolated_tests_root / "experiments"
    isolated_tests_features = isolated_tests_root / "features"
    tests = []
    for directory in [
        isolated_tests_root,
        isolated_tests_demos,
        isolated_tests_experiments,
        isolated_tests_features,
    ]:
        tests.extend(glob.glob(str(directory / "*.py")))
    if ci_node_total is not None and ci_node_index is not None:
        tests = with_parallel_ci(tests, ci_node_total, ci_node_index)
    return tests


# Check TODOs
class PatternDir:
    """
    Pairs a TODO comment pattern with a glob of files to scan.

    NOTE(review): defining ``__dict__`` as a method shadows the usual
    instance ``__dict__`` descriptor — kept as-is because
    ``_aggregate_todos`` calls ``pattern_dir.__dict__()``.
    """

    def __init__(self, pattern, glob_dir):
        self.pattern = pattern
        self.glob_dir = glob_dir

    def __dict__(self):
        return {"pattern": self.pattern, "glob_dir": self.glob_dir}


def _check_todos(pattern, glob_dir):
    """
    Count lines starting with ``pattern`` in each file matching ``glob_dir``.

    Returns a dict mapping ``(path, pattern)`` to the number of matching
    lines; files with no matches are omitted.
    """
    from glob import iglob

    todo_count = {}
    for path in list(iglob(glob_dir, recursive=True)):
        key = (path, pattern)
        with open(path, "r") as f:
            line_has_todo = [line.strip().startswith(pattern) for line in f.readlines()]
        if any(line_has_todo):
            todo_count[key] = sum(line_has_todo)
    return todo_count


def _aggregate_todos(pattern_dirs: "list[PatternDir]"):
    """Merge the TODO counts from each ``PatternDir`` into a single dict."""
    todo_count = {}
    for pattern_dir in pattern_dirs:
        todo_count.update(_check_todos(**pattern_dir.__dict__()))
    return todo_count


def check_todos_before_deployment():
    """
    Assert that the experiment folder contains no TODO comments (Python or
    Javascript style), unless the SKIP_TODO_CHECK environment variable is set.
    """
    if os.environ.get("SKIP_TODO_CHECK"):
        print(
            "SKIP_TODO_CHECK is set so we will not check if there are any TODOs in the experiment folder."
        )
        return
    todo_count = _aggregate_todos(
        [
            # For now only limit to comments specific to the experiment logic (i.e. Python and JS)
            PatternDir("# TODO", "**/*.py"),  # Python comments
            PatternDir("// TODO", "**/*.py"),  # Javascript comment in py files
            PatternDir("// TODO", "**/*.html"),  # Javascript comment in html files
            PatternDir("// TODO", "**/*.js"),  # Javascript comment in js files
        ]
    )
    file_names = [key[0] for key in todo_count.keys()]
    total_todo_count = sum(todo_count.values())
    n_files = len(set(file_names))
    assert len(todo_count) == 0, (
        f"You have {total_todo_count} TODOs in {n_files} file(s) in your experiment folder. "
        "Please fix them or remove them before deploying. "
        "To view all TODOs in your project in PyCharm, go to 'View' > 'Tool Windows' > 'TODO'. "
        "You can skip this check by writing `export SKIP_TODO_CHECK=1` (without quotes) in your terminal."
    )


def as_plain_text(html):
    """Convert an HTML fragment to plain text with whitespace collapsed to single spaces."""
    text = html2text.HTML2Text().handle(str(html))
    pattern = re.compile(r"\s+")
    text = re.sub(pattern, " ", text).strip()
    return text