"""Provides class that allows access to an SQL Database that serves as the back-end to EasyVVUQ.
"""
import os
import json
import logging
import pandas as pd
import numpy as np
from sqlalchemy.sql import case
from sqlalchemy import create_engine, Column, Integer, String, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy import MetaData
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlalchemy import event
from .base import BaseCampaignDB
from easyvvuq import constants
from easyvvuq import ParamsSpecification
from easyvvuq.utils.helpers import easyvvuq_serialize, easyvvuq_deserialize
__copyright__ = """
Copyright 2018 Robin A. Richardson, David W. Wright
This file is part of EasyVVUQ
EasyVVUQ is free software: you can redistribute it and/or modify
it under the terms of the Lesser GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
EasyVVUQ is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Lesser GNU General Public License for more details.
You should have received a copy of the Lesser GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
__license__ = "LGPL"
COMMIT_RATE = 50000
logger = logging.getLogger(__name__)
Base = declarative_base()
[docs]class DBInfoTable(Base):
"""An SQLAlchemy schema for the database information table.
"""
__tablename__ = 'db_info'
id = Column(Integer, primary_key=True)
next_run = Column(Integer)
[docs]class CampaignTable(Base):
"""An SQLAlchemy schema for the campaign information table.
"""
__tablename__ = 'campaign_info'
id = Column(Integer, primary_key=True)
name = Column(String, unique=True)
easyvvuq_version = Column(String)
campaign_dir_prefix = Column(String)
campaign_dir = Column(String)
runs_dir = Column(String)
sampler = Column(Integer, ForeignKey('sampler.id'))
active_app = Column(Integer, ForeignKey('app.id'))
[docs]class AppTable(Base):
"""An SQLAlchemy schema for the app table.
"""
__tablename__ = 'app'
id = Column(Integer, primary_key=True)
name = Column(String, unique=True)
params = Column(String)
actions = Column(String)
[docs]class RunTable(Base):
"""An SQLAlchemy schema for the run table.
"""
__tablename__ = 'run'
id = Column(Integer, primary_key=True)
run_name = Column(String, index=True)
app = Column(Integer, ForeignKey('app.id'))
params = Column(String)
status = Column(Integer)
run_dir = Column(String)
result = Column(String, default="{}")
execution_info = Column(String, default="{}")
campaign = Column(Integer, ForeignKey('campaign_info.id'))
sampler = Column(Integer, ForeignKey('sampler.id'))
iteration = Column(Integer, default=0)
[docs]class SamplerTable(Base):
"""An SQLAlchemy schema for the run table.
"""
__tablename__ = 'sampler'
id = Column(Integer, primary_key=True)
sampler = Column(String)
[docs]@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA synchronous = OFF")
cursor.execute("PRAGMA journal_mode = OFF")
cursor.close()
[docs]class CampaignDB(BaseCampaignDB):
"""An interface between the campaign database and the campaign.
Parameters
----------
location: str
database URI as needed by SQLAlchemy
"""
def __init__(self, location=None):
if location is not None:
self.engine = create_engine(location)
else:
self.engine = create_engine('sqlite://')
self.commit_counter = 0
session_maker = sessionmaker(bind=self.engine)
self.session = session_maker()
Base.metadata.create_all(self.engine, checkfirst=True)
[docs] def resume_campaign(self, name):
"""Resumes campaign.
Parameters
----------
name: str
Name of the Campaign to resume. Must already exist in the database.
"""
info = self.session.query(
CampaignTable).filter_by(name=name).first()
if info is None:
raise ValueError('Campaign with the given name not found.')
db_info = self.session.query(DBInfoTable).first()
self._next_run = db_info.next_run
[docs] def create_campaign(self, info):
"""Creates a new campaign in the database.
Parameters
----------
info: CampaignInfo
This `easyvvuq.data_structs.CampaignInfo` will contain information
needed to construct the Campaign table.
"""
is_db_empty = (self.session.query(CampaignTable).first() is None)
version_check = self.session.query(
CampaignTable).filter(CampaignTable.easyvvuq_version != info.easyvvuq_version).all()
if (not is_db_empty) and (len(version_check) != 0):
raise RuntimeError('Database contains campaign created with an incompatible' +
' version of EasyVVUQ!')
self._next_run = 1
self.session.add(CampaignTable(**info.to_dict(flatten=True)))
self.session.add(DBInfoTable(next_run=self._next_run))
self.session.commit()
[docs] def get_active_app(self):
"""Returns active app table.
Returns
-------
AppTable
"""
return self.session.query(AppTable, CampaignTable).filter(
AppTable.id == CampaignTable.active_app).first()
[docs] def campaign_exists(self, name):
"""Check if campaign specified by that name already exists.
Parameters
----------
name: str
Returns
-------
bool
True if such a campaign already exists, False otherwise
"""
result = self.session.query(CampaignTable).filter(
CampaignTable.name == name).all()
return len(result) > 0
[docs] def app(self, name=None):
"""Get app information. Specific applications selected by `name`,
otherwise first entry in database 'app' selected.
Parameters
----------
name : str or None
Name of selected app, if `None` given then first app will be
selected.
Returns
-------
dict
Information about the application.
"""
if name is None:
selected = self.session.query(AppTable).all()
else:
selected = self.session.query(AppTable).filter_by(name=name).all()
if len(selected) == 0:
message = f'No entry for app: ({name}).'
logger.critical(message)
raise RuntimeError(message)
selected_app = selected[0]
app_dict = {
'id': selected_app.id,
'name': selected_app.name,
'params': ParamsSpecification.deserialize(selected_app.params),
'actions': selected_app.actions,
}
return app_dict
[docs] def set_active_app(self, name):
"""Set an app specified by name as active.
Parameters
----------
name: str
name of the app to set as active
"""
selected = self.session.query(AppTable).filter_by(name=name).all()
if len(selected) == 0:
raise RuntimeError('no such app - {}'.format(name))
assert (not (len(selected) > 1))
app = selected[0]
self.session.query(CampaignTable).update({'active_app': app.id})
self.session.commit()
[docs] def add_app(self, app_info):
"""Add application to the 'app' table.
Parameters
----------
app_info: AppInfo
Application definition.
"""
# Check that no app with same name exists
name = app_info.name
selected = self.session.query(AppTable).filter_by(name=name).all()
if len(selected) > 0:
message = (
f'There is already an app in this database with name {name}'
f'(found {len(selected)}).'
)
logger.critical(message)
raise RuntimeError(message)
app_dict = app_info.to_dict(flatten=True)
db_entry = AppTable(**app_dict)
self.session.add(db_entry)
self.session.commit()
[docs] def replace_actions(self, app_name, actions):
"""Replace actions for an app with a given name.
Parameters
----------
app_name: str
Name of the app.
actions: Actions
`Actions` instance, will replace the current `Actions` of an app.
"""
self.session.query(AppTable).filter_by(name=app_name).update(
{'actions': easyvvuq_serialize(actions)})
self.session.commit()
[docs] def add_sampler(self, sampler_element):
"""Add new Sampler to the 'sampler' table.
Parameters
----------
sampler_element: Sampler
An EasyVVUQ sampler.
Returns
-------
int
The sampler `id` in the database.
"""
db_entry = SamplerTable(sampler=easyvvuq_serialize(sampler_element))
self.session.add(db_entry)
self.session.commit()
return db_entry.id
[docs] def update_sampler(self, sampler_id, sampler_element):
"""Update the state of the Sampler with id 'sampler_id' to
that in the passed 'sampler_element'
Parameters
----------
sampler_id: int
The id of the sampler in the db to update
sampler_element: Sampler
The sampler that should be used as the new state
"""
selected = self.session.query(SamplerTable).get(sampler_id)
selected.sampler = easyvvuq_serialize(sampler_element)
self.session.commit()
[docs] def resurrect_sampler(self, sampler_id):
"""Return the sampler object corresponding to id sampler_id in the database.
It is deserialized from the state stored in the database.
Parameters
----------
sampler_id: int
The id of the sampler to resurrect
Returns
-------
Sampler
The 'live' sampler object, deserialized from the state in the db
"""
try:
serialized_sampler = self.session.query(SamplerTable).get(sampler_id).sampler
sampler = easyvvuq_deserialize(serialized_sampler.encode('utf-8'))
except AttributeError:
sampler = None
return sampler
[docs] def resurrect_app(self, app_name):
"""Return the 'live' encoder, decoder and collation objects corresponding to the app with
name 'app_name' in the database. They are deserialized from the states previously
stored in the database.
Parameters
----------
app_name: string
Name of the app to resurrect
Returns
-------
Actions
The 'live' `Actions` object associated with this app. Used to execute the simulation
associated with the app as well as do any pre- and post-processing.
"""
app_info = self.app(app_name)
actions = easyvvuq_deserialize(app_info['actions'])
return actions
[docs] def add_runs(self, run_info_list=None, run_prefix='run_', iteration=0):
"""Add list of runs to the `runs` table in the database.
Parameters
----------
run_info_list: List of RunInfo objects
Each RunInfo object contains relevant run fields: params, status (where in the
EasyVVUQ workflow is this RunTable), campaign (id number), sample, app
run_prefix: str
Prefix for run name
iteration: int
Iteration number used by iterative workflows. For example, MCMC. Can be left
as default zero in other cases.
"""
# Add all runs to RunTable
commit_counter = 0
for run_info in run_info_list:
run_info.run_name = f"{run_prefix}{self._next_run}"
run_info.iteration = iteration
run = RunTable(**run_info.to_dict(flatten=True))
self.session.add(run)
self._next_run += 1
commit_counter += 1
if commit_counter % COMMIT_RATE == 0:
self.session.commit()
# Update run and ensemble counters in db
db_info = self.session.query(DBInfoTable).first()
db_info.next_run = self._next_run
self.session.commit()
@staticmethod
def _run_to_dict(run_row):
"""Convert the provided row from 'runs' table into a dictionary
Parameters
----------
run_row: RunTable
Information on a particular run in the database.
Returns
-------
dict
Contains run information (keys = run_name, params, status, sample,
campaign and app)
"""
run_info = {
'run_name': run_row.run_name,
'params': json.loads(run_row.params),
'status': constants.Status(run_row.status),
'sampler': run_row.sampler,
'campaign': run_row.campaign,
'app': run_row.app,
'result': run_row.result,
'run_dir': run_row.run_dir
}
return run_info
[docs] def set_dir_for_run(self, run_name, run_dir, campaign=None, sampler=None):
"""Set the 'run_dir' path for the specified run in the database.
Parameters
----------
run_name: str
Name of run to filter for.
run_dir: str
Directory path associated to set for this run.
campaign: int or None
Campaign id to filter for.
sampler: int or None
Sample id to filter for.
"""
filter_options = {'run_name': run_name}
if campaign:
filter_options['campaign'] = campaign
if sampler:
filter_options['sampler'] = sampler
selected = self.session.query(RunTable).filter_by(**filter_options)
if selected.count() != 1:
logging.critical('Multiple runs selected - using the first')
selected = selected.first()
selected.run_dir = run_dir
self.session.commit()
[docs] def get_run_status(self, run_id, campaign=None, sampler=None):
"""Return the status (enum) for the run with name 'run_name' (and, optionally,
filtering for campaign and sampler by id)
Parameters
----------
run_id: int
id of the run
campaign: int
ID of the desired Campaign
sampler: int
ID of the desired Sampler
Returns
-------
enum(Status)
Status of the run.
"""
filter_options = {'id': run_id}
if campaign:
filter_options['campaign'] = campaign
if sampler:
filter_options['sampler'] = sampler
selected = self.session.query(RunTable).filter_by(**filter_options)
if selected.count() != 1:
logging.critical('Multiple runs selected - using the first')
selected = selected.first()
return constants.Status(selected.status)
[docs] def set_run_statuses(self, run_id_list, status):
"""Set the specified 'status' (enum) for all runs in the list run_id_list
Parameters
----------
run_id_list: list of int
a list of run ids
status: enum(Status)
The new status all listed runs should now have
"""
self.session.query(RunTable).filter(
RunTable.id.in_(run_id_list)).update(
{RunTable.status: status}, synchronize_session='fetch')
self.session.commit()
[docs] def campaigns(self):
"""Get list of campaigns for which information is stored in the
database.
Returns
-------
list
Campaign names.
"""
return [c.name for c in self.session.query(CampaignTable).all()]
def _get_campaign_info(self, campaign_name=None):
"""Retrieves Campaign info based on name.
Parameters
----------
campaign_name: str
Name of campaign to select.
Returns
-------
SQLAlchemy query for campaign with this name.
"""
assert (isinstance(campaign_name, str) or campaign_name is None)
query = self.session.query(CampaignTable)
if campaign_name is None:
campaign_info = query
else:
campaign_info = query.filter_by(name=campaign_name).all()
if campaign_name is not None:
if len(campaign_info) > 1:
logger.warning(
'More than one campaign selected - using first one.')
elif len(campaign_info) == 0:
message = 'No campaign available.'
logger.critical(message)
raise RuntimeError(message)
return campaign_info[0]
return campaign_info.first()
[docs] def get_campaign_id(self, name):
"""Return the (database) id corresponding to the campaign with name 'name'.
Parameters
----------
name: str
Name of the campaign.
Returns
-------
int
The id of the campaign with the specified name
"""
selected = self.session.query(
CampaignTable.name.label(name),
CampaignTable.id).filter(CampaignTable.name == name).all()
if len(selected) == 0:
msg = f"No campaign with name {name} found in campaign database"
logger.error(msg)
raise RuntimeError(msg)
if len(selected) > 1:
msg = (
f"More than one campaign with name {name} found in"
f"campaign database. Database state is compromised."
)
logger.error(msg)
raise RuntimeError(msg)
# Return the database ID for the specified campaign
return selected[0][1]
[docs] def get_sampler_id(self, campaign_id):
"""Return the (database) id corresponding to the sampler currently set
for the campaign with id 'campaign_id'
Parameters
----------
campaign_id: int
ID of the campaign.
Returns
-------
int
The id of the sampler set for the specified campaign
"""
sampler_id = self.session.query(CampaignTable).get(campaign_id).sampler
return sampler_id
[docs] def set_sampler(self, campaign_id, sampler_id):
"""Set specified campaign to be using specified sampler
Parameters
----------
campaign_id: int
ID of the campaign.
sampler_id: int
ID of the sampler.
"""
self.session.query(CampaignTable).get(campaign_id).sampler = sampler_id
self.session.commit()
[docs] def campaign_dir(self, campaign_name=None):
"""Get campaign directory for `campaign_name`.
Parameters
----------
campaign_name: str
Name of campaign to select
Returns
-------
str
Path to campaign directory.
"""
return self._get_campaign_info(campaign_name=campaign_name).campaign_dir
def _select_runs(
self,
name=None,
campaign=None,
sampler=None,
status=None,
not_status=None,
app_id=None):
"""Select all runs in the database which match the input criteria.
Parameters
----------
name: str
Name of run to filter for.
campaign: int or None
Campaign id to filter for.
sampler: int or None
Sampler id to filter for.
status: enum(Status) or None
Status string to filter for.
not_status: enum(Status) or None
Exclude runs with this status string
app_id: int or None
App id to filter for.
Returns
-------
sqlalchemy.orm.query.Query
Selected runs from the database run table.
"""
filter_options = {}
if name:
filter_options['run_name'] = name
if campaign:
filter_options['campaign'] = campaign
if sampler:
filter_options['sampler'] = sampler
if status:
filter_options['status'] = status
if app_id:
filter_options['app'] = app_id
# Note that for some databases this can be sped up with a yield_per(), but not all
selected = self.session.query(RunTable).filter_by(
**filter_options).filter(RunTable.status != not_status)
return selected
[docs] def run(self, name, campaign=None, sampler=None, status=None, not_status=None, app_id=None):
"""Get the information for a specified run.
Parameters
----------
name: str
Name of run to filter for.
campaign: int or None
Campaign id to filter for.
sampler: int or None
Sampler id to filter for.
status: enum(Status) or None
Status string to filter for.
not_status: enum(Status) or None
Exclude runs with this status string
app_id: int or None
App id to filter for.
Returns
-------
dict
Containing run information (run_name, params, status, sample,
campaign, app)
"""
selected = self._select_runs(
name=name,
campaign=campaign,
sampler=sampler,
status=status,
not_status=not_status,
app_id=app_id)
if selected.count() != 1:
logging.warning('Multiple runs selected - using the first')
selected = selected.first()
return self._run_to_dict(selected)
[docs] def runs(self, campaign=None, sampler=None, status=None, not_status=None, app_id=None):
"""A generator to return all run information for selected `campaign` and `sampler`.
Parameters
----------
campaign: int or None
Campaign id to filter for.
sampler: int or None
Sampler id to filter for.
status: enum(Status) or None
Status string to filter for.
not_status: enum(Status) or None
Exclude runs with this status string
app_id: int or None
App id to filter for.
Yields
------
dict
Information on each selected run (key = run_name, value = dict of
run information fields.), one at a time.
"""
selected = self._select_runs(
campaign=campaign,
sampler=sampler,
status=status,
not_status=not_status,
app_id=app_id)
for r in selected:
yield r.id, self._run_to_dict(r)
[docs] def run_ids(self, campaign=None, sampler=None, status=None, not_status=None, app_id=None):
"""A generator to return all run IDs for selected `campaign` and `sampler`.
Parameters
----------
campaign: int or None
Campaign id to filter for.
sampler: int or None
Sampler id to filter for.
status: enum(Status) or None
Status string to filter for.
not_status: enum(Status) or None
Exclude runs with this status string
app_id: int or None
App id to filter for.
Yields
------
str
run ID for each selected run, one at a time.
"""
selected = self._select_runs(
campaign=campaign,
sampler=sampler,
status=status,
not_status=not_status,
app_id=app_id)
for r in selected:
yield r.run_name
[docs] def get_num_runs(self, campaign=None, sampler=None, status=None, not_status=None):
"""Returns the number of runs matching the filtering criteria.
Parameters
----------
campaign: int or None
Campaign id to filter for.
sampler: int or None
Sampler id to filter for.
status: enum(Status) or None
Status string to filter for.
not_status: enum(Status) or None
Exclude runs with this status string
Returns
-------
int
The number of runs in the database matching the filtering criteria
"""
selected = self._select_runs(
campaign=campaign,
sampler=sampler,
status=status,
not_status=not_status)
return selected.count()
[docs] def runs_dir(self, campaign_name=None):
"""Get the directory used to store run information for `campaign_name`.
Parameters
----------
campaign_name: str
Name of the selected campaign.
Returns
-------
str
Path containing run outputs.
"""
return self._get_campaign_info(campaign_name=campaign_name).runs_dir
[docs] def store_result(self, run_id, result, change_status=True):
"""Stores results of a simulation inside the RunTable given a run id.
Parameters
----------
run_id: int
The id of a run to store the results in. This will be the run with which these
results are associated with. Namely the run that has the inputs used to generate
these results.
result: dict
Results in dictionary form. This is the same format as used by the `Decoder`.
change_status: bool
If set to False will not update the runs' status to COLLATED. This is sometimes
useful in scenarios where you want several apps to work on the same runs.
"""
self.commit_counter += 1
def convert_nonserializable(obj):
if isinstance(obj, np.int64):
return int(obj)
raise TypeError('Unknown type:', type(obj))
result_ = result['result']
result.pop('result')
result.pop('run_info')
if change_status:
self.session.query(RunTable).\
filter(RunTable.id == run_id).\
update({'result': json.dumps(result_, default=convert_nonserializable),
'status': constants.Status.COLLATED,
'run_dir': result['rundir']})
else:
self.session.query(RunTable).\
filter(RunTable.id == run_id).\
update({'result': json.dumps(result_, default=convert_nonserializable),
'run_dir': result['rundir']})
if self.commit_counter % COMMIT_RATE == 0:
self.session.commit()
[docs] def store_results(self, app_name, results):
"""Stores the results from a given run in the database.
Parameters
----------
run_name: str
name of the run
results: dict
dictionary with the results (from the decoder)
"""
try:
app_id = self.session.query(AppTable).filter(AppTable.name == app_name).all()[0].id
except IndexError:
raise RuntimeError("app with the name {} not found".format(app_name))
commit_counter = 0
for run_id, result in results:
try:
self.session.query(RunTable).\
filter(RunTable.id == run_id, RunTable.app == app_id).\
update({'result': json.dumps(result), 'status': constants.Status.COLLATED})
commit_counter += 1
if commit_counter % COMMIT_RATE == 0:
self.session.commit()
except IndexError:
raise RuntimeError("no runs with name {} found".format(run_id))
self.session.commit()
[docs] def get_results(self, app_name, sampler_id, status=constants.Status.COLLATED, iteration=-1):
"""Returns the results as a pandas DataFrame.
Parameters
----------
app_name: str
Name of the app to return data for.
sampler_id: int
ID of the sampler.
status: STATUS
Run status to filter for.
iteration: int
If a positive integer will return the results for a given iteration only.
Returns
-------
DataFrame
Will construct a `DataFrame` from the decoder output dictionaries.
"""
try:
app_id = self.session.query(AppTable).filter(AppTable.name == app_name).all()[0].id
except IndexError:
raise RuntimeError("app with the name {} not found".format(app_name))
pd_result = {}
query = self.session.query(RunTable).\
filter(RunTable.app == app_id).\
filter(RunTable.sampler == sampler_id).\
filter(RunTable.status == status)
# if only a specific iteration is requested filter it out
if iteration >= 0:
query = query.filter(RunTable.iteration == iteration)
for row in query:
params = {'run_id': row.id}
params['iteration'] = row.iteration
params = {**params, **json.loads(row.params)}
result = json.loads(row.result)
pd_dict = {**params, **result}
for key in pd_dict.keys():
if not isinstance(pd_dict[key], list):
try:
pd_result[(key, 0)].append(pd_dict[key])
except KeyError:
pd_result[(key, 0)] = [pd_dict[key]]
else:
for i, elt in enumerate(pd_dict[key]):
try:
pd_result[(key, i)].append(pd_dict[key][i])
except KeyError:
pd_result[(key, i)] = [pd_dict[key][i]]
try:
return pd.DataFrame(pd_result)
except ValueError:
raise RuntimeError(
'the results received from the database seem to be malformed - commonly because a vector quantity of interest changes dimensionality')
[docs] def relocate(self, new_path, campaign_name):
"""Update all runs in the db with the new campaign path.
Parameters
----------
new_path: str
new runs directory
campaign_name: str
name of the campaign
"""
campaign_id = self.get_campaign_id(campaign_name)
campaign_info = self.session.query(CampaignTable).\
filter(CampaignTable.id == campaign_id).first()
path, runs_dir = os.path.split(campaign_info.runs_dir)
self.session.query(CampaignTable).\
filter(CampaignTable.id == campaign_id).\
update({'campaign_dir': str(new_path),
'runs_dir': str(os.path.join(new_path, runs_dir))})
self.session.commit()
[docs] def dump(self):
"""Dump the database as JSON for debugging purposes.
Returns
-------
dict
A database dump in JSON format.
"""
meta = MetaData()
meta.reflect(bind=self.engine)
result = {}
for table in meta.sorted_tables:
result[table.name] = [dict(row) for row in self.engine.execute(table.select())]
return json.dumps(result)