322 changes: 320 additions & 2 deletions src/workflow/FileManager.py
@@ -1,7 +1,14 @@
import gzip
import string
import random
import shutil
import sqlite3

import pandas as pd
import pickle as pkl

from io import BytesIO
from pathlib import Path
from typing import Union, List

class FileManager:
@@ -19,12 +26,44 @@ class FileManager:
def __init__(
self,
workflow_dir: Path,
cache_path: Path,
):
"""
Initializes the FileManager object with the current workflow results directory and the cache path.
"""
self.workflow_dir = workflow_dir

# Set up caching
self.cache_path = cache_path
Path(self.cache_path, 'files').mkdir(parents=True, exist_ok=True)
self._connect_to_sql()

def _connect_to_sql(self):
"""
Opens the sqlite cache index and creates the index tables if they
do not exist yet.
"""
self.cache_connection = sqlite3.connect(
Path(self.cache_path, 'cache.db'), isolation_level=None
)
self.cache_cursor = self.cache_connection.cursor()
self.cache_cursor.execute("""
CREATE TABLE IF NOT EXISTS stored_data (
id TEXT PRIMARY KEY
);
""")
self.cache_cursor.execute("""
CREATE TABLE IF NOT EXISTS stored_files (
id TEXT PRIMARY KEY
);
""")

def __getstate__(self):
# sqlite3 connections and cursors cannot be pickled, so they are
# dropped here and re-created in __setstate__
state = self.__dict__.copy()
del state['cache_connection']
del state['cache_cursor']
return state

def __setstate__(self, state):
self.__dict__.update(state)
self._connect_to_sql()

def get_files(
self,
files: Union[List[Union[str, Path]], Path, str, List[List[str]]],
@@ -177,3 +216,282 @@ def _create_results_sub_dir(self, name: str = "") -> str:
path = Path(self.workflow_dir, "results", name)
path.mkdir(exist_ok=True)
return str(path)

def _get_column_list(self, table_name: str) -> List[str]:
"""
Get a list of columns in the table.

Args:
table_name (str): The name of the table.

Returns:
columns (List[str]): The names of the columns in the table.
"""
self.cache_cursor.execute(f"PRAGMA table_info({table_name});")
return [col[1] for col in self.cache_cursor.fetchall()]


def _add_column(self, table_name: str, column_name: str) -> None:
"""
Adds the given column to the cache table if it does not already
exist.

Args:
table_name (str): The name of the table
column_name (str): The name of the column
"""

# Fetch list of columns
columns = self._get_column_list(table_name)

# Add column to table if it does not exist
if column_name not in columns:
self.cache_cursor.execute(
f"ALTER TABLE {table_name} ADD COLUMN {column_name} TEXT;"
)

def _add_entry(self, table_name: str, dataset_id: str,
column_name: str, path: str) -> None:
"""
Adds an entry to the cache index.

Args:
table_name (str): The name of the table
dataset_id (str): The name of the dataset the data is
attached to.
column_name (str): The name of the column
path (str): The path to be inserted
"""

# Ensure column exists
self._add_column(table_name, column_name)

# Store reference; id and path are bound as query parameters so
# that quoting in paths cannot break the statement
self.cache_cursor.execute(f"""
INSERT INTO {table_name} (id, {column_name})
VALUES (?, ?)
ON CONFLICT(id)
DO UPDATE SET {column_name} = excluded.{column_name};
""", (dataset_id, str(path)))

def _store_data(self, dataset_id: str, name_tag: str, data) -> Path:
"""
Stores data as a cached file. Pandas DataFrames are stored as
parquet files, while all other data structures are stored as
compressed pickles.

Args:
dataset_id (str): The name of the dataset the data is
attached to.
name_tag (str): The name of the associated data structure.
data: Any pickleable data structure.

Returns:
file_path (Path): The file path of the stored file.
"""

path = Path(self.cache_path, 'files', dataset_id)
path.mkdir(parents=True, exist_ok=True)

# DataFrames are stored as apache parquet
if isinstance(data, pd.DataFrame):
path = Path(path, f"{name_tag}.pq")
with open(path, 'wb') as f:
data.to_parquet(f)
# Other data structures are stored as compressed pickle
else:
path = Path(path, f"{name_tag}.pkl.gz")
with gzip.open(path, 'wb') as f:
pkl.dump(data, f)

return path

def store_data(self, dataset_id: str, name_tag: str, data) -> None:
"""
Stores a given data structure.

Args:
dataset_id (str): The name of the dataset the data is
attached to.
name_tag (str): The name of the associated data structure.
data: Any pickleable data structure.
"""

# Store data structure as a file
data_path = self._store_data(dataset_id, name_tag, data)

# Store reference in index
self._add_entry('stored_data', dataset_id, name_tag, data_path)

def store_file(self, dataset_id: str, name_tag: str, file: Path | BytesIO,
remove: bool = True, file_name: str = None) -> None:
"""
Stores a given file.

Args:
dataset_id (str): The name of the dataset the data is
attached to.
name_tag (str): The name of the associated data structure.
file (Path or file-like): The file that should be stored.
remove (bool): Whether or not the file should be removed
after copying it. Only applies to Path inputs.
file_name (str): The name under which the file is stored,
including its extension. Only necessary if a file-like
object is used as input, since no suffix can be derived
from a BytesIO object.
"""

# Define storage path
if file_name is None:
file_name = f"{name_tag}{file.suffix}"

target_path = Path(
self.cache_path, 'files', dataset_id, file_name
)
target_path.parent.mkdir(parents=True, exist_ok=True)

# Store file in path
if isinstance(file, BytesIO):
with open(target_path, 'wb') as f:
f.write(file.getbuffer())
else:
file = Path(file)
shutil.copy(file, target_path)
if remove:
file.unlink()

# Store reference in index
self._add_entry('stored_files', dataset_id, name_tag, target_path)

def get_results_list(self, name_tags: List[str], partial=False) -> List[str]:
"""
Get the ids of all datasets that have stored results for the
specified fields.

Args:
name_tags (List[str]): The fields to be considered.
partial (bool): If True, datasets with results for any of the
fields are returned; otherwise results for all fields must
be present.

Returns:
results (List[str]): The matching dataset ids.
"""
# Some columns might not have been created yet (or ever).
available_columns = (
set(self._get_column_list('stored_data'))
| set(self._get_column_list('stored_files'))
)
name_tags = [n for n in name_tags if n in available_columns]
if len(name_tags) == 0:
return []

# Fetch matching ids; the UNION of the two LEFT JOINs emulates a
# FULL OUTER JOIN across the data and file index tables
selection_operator = 'OR' if partial else 'AND'
selection_statement = (
f" IS NOT NULL {selection_operator} ".join(name_tags)
+ " IS NOT NULL;"
)
self.cache_cursor.execute(f"""
SELECT id
FROM (
SELECT sd.id AS id, sd.*, sf.*
FROM stored_data sd
LEFT JOIN stored_files sf ON sd.id = sf.id

UNION

SELECT sf.id AS id, sd.*, sf.*
FROM stored_files sf
LEFT JOIN stored_data sd ON sf.id = sd.id
) combined
WHERE {selection_statement}
""")

return [row[0] for row in self.cache_cursor.fetchall()]

def get_results(self, dataset_id, name_tags, partial=False):
"""
Retrieves the stored results of a dataset for the given fields.
Stored files are returned as Path objects, stored data structures
are loaded back into memory. Missing fields raise a KeyError
unless partial is True.
"""
results = {}
# Retrieve files as Path objects
file_columns = self._get_column_list('stored_files')
file_columns = [c for c in file_columns if c in name_tags]
if len(file_columns) > 0:
self.cache_cursor.execute(f"""
SELECT {', '.join(file_columns)}
FROM stored_files
WHERE id = '{dataset_id}';
""")
result = self.cache_cursor.fetchone()
# No row for this dataset id: treat every field as missing
if result is None:
result = [None] * len(file_columns)
for c, r in zip(file_columns, result):
if r is None:
if partial:
continue
else:
raise KeyError(f"{c} does not exist for {dataset_id}")
results[c] = Path(r)
# Retrieve data as Python objects
data_columns = self._get_column_list('stored_data')
data_columns = [c for c in data_columns if c in name_tags]
if len(data_columns) > 0:
self.cache_cursor.execute(f"""
SELECT {', '.join(data_columns)}
FROM stored_data
WHERE id = '{dataset_id}';
""")
result = self.cache_cursor.fetchone()
# No row for this dataset id: treat every field as missing
if result is None:
result = [None] * len(data_columns)
for c, r in zip(data_columns, result):
if r is None:
if partial:
continue
else:
raise KeyError(f"{c} does not exist for {dataset_id}")
file_path = Path(r)
if file_path.suffix == '.pq':
data = pd.read_parquet(file_path)
else:
with gzip.open(file_path, 'rb') as f:
data = pkl.load(f)
results[c] = data
return results

def result_exists(self, dataset_id, name_tag):
"""
Checks whether a stored result or file exists for the given
dataset id and field.
"""
# Check which table is correct
if name_tag in self._get_column_list('stored_data'):
table = 'stored_data'
elif name_tag in self._get_column_list('stored_files'):
table = 'stored_files'
else:
return False

# Check if field value is set
self.cache_cursor.execute(f"""
SELECT {name_tag}
FROM {table}
WHERE id = '{dataset_id}' AND {name_tag} IS NOT NULL
""")
if self.cache_cursor.fetchone():
return True
return False

def remove_results(self, dataset_id):
"""
Removes all cached data, files and index entries for the given
dataset id.
"""
# Remove references
self.cache_cursor.execute(f"""
DELETE FROM stored_data
WHERE id = '{dataset_id}';
""")
self.cache_cursor.execute(f"""
DELETE FROM stored_files
WHERE id = '{dataset_id}';
""")

# Remove stored files; ignore_errors covers ids that never stored any
shutil.rmtree(Path(self.cache_path, 'files', dataset_id), ignore_errors=True)

def clear_cache(self):
"""
Deletes all cached files and re-creates the empty cache index tables.
"""
shutil.rmtree(Path(self.cache_path, 'files'))
Path(self.cache_path, 'files').mkdir()
self.cache_cursor.execute("DROP TABLE IF EXISTS stored_data;")
self.cache_cursor.execute("DROP TABLE IF EXISTS stored_files;")
self.cache_cursor.execute("""
CREATE TABLE IF NOT EXISTS stored_data (
id TEXT PRIMARY KEY
);
""")
self.cache_cursor.execute("""
CREATE TABLE IF NOT EXISTS stored_files (
id TEXT PRIMARY KEY
);
""")

10 changes: 8 additions & 2 deletions src/workflow/WorkflowManager.py
@@ -11,10 +11,16 @@

class WorkflowManager:
# Core workflow logic using the above classes
def __init__(self, name: str, workspace: str, share_cache: bool = False):
self.name = name
self.workflow_dir = Path(workspace, name.replace(" ", "-").lower())

if share_cache:
cache_path = Path(workspace, 'cache')
else:
cache_path = Path(self.workflow_dir, 'cache')

self.file_manager = FileManager(self.workflow_dir, cache_path)
self.logger = Logger(self.workflow_dir)
self.parameter_manager = ParameterManager(self.workflow_dir)
self.executor = CommandExecutor(self.workflow_dir, self.logger, self.parameter_manager)
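A sketch of how the new share_cache flag changes where the cache lives; the workspace path and workflow names are made up, and the import path is assumed.

from workflow.WorkflowManager import WorkflowManager  # assumed import path

# With share_cache=True both workflows resolve their cache to
# <workspace>/cache, so results stored by one run are visible to the other
wf_a = WorkflowManager("Proteomics QC", workspace="workspace", share_cache=True)
wf_b = WorkflowManager("Proteomics Quant", workspace="workspace", share_cache=True)

# With the default share_cache=False each workflow keeps a private cache
# under <workspace>/<workflow-name>/cache
wf_c = WorkflowManager("Isolated Run", workspace="workspace")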