"""
Mock CASTEP code.
A more advanced way of mocking. The inputs of a calculation are hashed
and matched against a repository of calculations that have already been run.
This way we can test workchains without the need for
injecting test code into the workchain logic itself.
"""
import hashlib
import logging
import os
import shutil
from pathlib import Path
from typing import Union
import numpy as np
from aiida.repository.common import FileType
from castepinput import CellInput, ParamInput
# pylint: disable=logging-format-interpolation, import-outside-toplevel
# Names of the input files; `MockRegistry.upload_calc` copies objects with
# these names into the "inp" folder of a registered calculation.
INPUT_OBJECTS = ("aiida.param", "aiida.cell")
# Default seed name: the stem of the ".param"/".cell" files that are hashed.
SEED_NAME = "aiida"
# Object names that are always skipped when uploading a calculation (none by default).
EXCLUDED = tuple()
def get_hash(dict_obj):
    """
    Return the hash for a dictionary of arbitrary items.

    This is not meant to be robust for all cases, but should be OK for
    matching hashes of calculation inputs.
    The point here is to make the value invariant to the permutation of key orders.

    :param dict_obj: A dictionary to hash. A list is also accepted and is
        treated as a dictionary keyed by the element indices.
    :returns: A tuple ``(hexdigest, records)`` where ``hexdigest`` is the MD5
        hex digest and ``records`` is the sorted list of lower-cased, encoded
        ``key:value`` records that were fed into the digest.

    The input is never modified: numpy arrays are normalised on a copy.
    """
    # If a list is passed - convert it to a dictionary with keys being the indices
    if isinstance(dict_obj, list):
        dict_obj = dict(enumerate(dict_obj))

    rec = []
    for key, value in dict_obj.items():
        key = repr(key)
        # For numpy/list with floating point zero (0.0) we have to convert -0.0 to 0.0
        # as they should be equivalent
        if isinstance(value, np.ndarray):
            # Normalise a copy so that hashing does not alter the caller's array
            value = value.copy()
            value[value == 0] = 0
        elif isinstance(value, list) and 0 in value:
            value = [type(tmp)(0) if tmp == 0 else tmp for tmp in value]

        # Handle if value itself is float zero
        if isinstance(value, float) and value == 0:
            value = 0.0

        if isinstance(value, (dict, list)):
            # Nested containers contribute their own (order-invariant) digest
            rec.append(key + ":" + get_hash(value)[0])
        else:
            # Use the string representation, tagged with the type so that
            # e.g. 1 and "1" do not collide
            rec.append(key + ":" + repr(value) + ":" + repr(type(value)))

    # Sort so the original key order does not matter; lower-case so that
    # string keys differing only in case are treated as the same
    base = [record.encode().lower() for record in sorted(rec)]

    # Compute the hash
    md5 = hashlib.md5()
    for item in base:
        md5.update(item)
    return md5.hexdigest(), base
class MockRegistry:
    """
    A class to create and manage a registry of completed calculations.

    Calculations are identified using the hash of the parsed inputs.
    Every registered calculation is a folder below ``base_path`` that holds
    an ``inp`` sub-folder (the input files) and an ``out`` sub-folder
    (everything else that was deposited).
    """

    def __init__(self, base_path):
        """
        Instantiate a registry and scan ``base_path`` for calculations.

        :param base_path: Root folder of the calculation repository.
        """
        self.base_path = Path(base_path)
        # hash -> absolute path of the calculation folder
        self.reg_hash = {}
        # path relative to ``base_path`` (as str) -> hash
        self.reg_name = {}
        self.logger = logging.getLogger("aiida_castep.utils.mock_code.MockRegistry")
        self._setup_logger()
        self.scan()

    def _setup_logger(self, level=logging.INFO):
        """
        Set the logger level and attach a formatted stream handler.

        :param level: Logging level applied to both the logger and the handler.
        """
        self.logger.setLevel(level)
        # ``getLogger`` returns the same object for every instance, so only
        # attach a handler when the logger does not have one yet; otherwise
        # each new registry would duplicate every message.
        if self.logger.handlers:
            return
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        handler = logging.StreamHandler()
        handler.setLevel(level)
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)

    def scan(self):
        """
        Scan the base folder and register every folder that has an ``out`` child.
        """
        for output_folder in Path(self.base_path).glob("**/out"):
            calc_base_folder = output_folder.parent.absolute()
            self._register_folder(calc_base_folder)

    def get_path_by_hash(self, hash_val):
        """
        Return the calculation folder registered for a given hash.

        :raises KeyError: If the hash is not registered.
        """
        return Path(self.reg_hash[hash_val])

    def get_path_by_name(self, name):
        """
        Return the calculation folder registered under a given name.

        The name is the path of the calculation relative to ``base_path``.

        :raises KeyError: If the name is not registered.
        """
        return Path(self.reg_hash[self.reg_name[name]])

    @staticmethod
    def compute_hash(input_folder: Path, seed=SEED_NAME):
        """
        Compute the hash of an input folder.

        :param input_folder: Folder containing ``<seed>.param`` and ``<seed>.cell``.
        :param seed: Seed name (file stem) of the input files.
        :returns: The hex digest computed by ``get_hash`` from both parsed files.
        """
        # Accept plain strings as well as Path objects
        input_folder = Path(input_folder)
        param_file = ParamInput.from_file(input_folder / f"{seed}.param", plain=True)
        cell_file = CellInput.from_file(input_folder / f"{seed}.cell", plain=True)
        return get_hash({"CELL": dict(cell_file), "PARAM": dict(param_file)})[0]

    def _create_calc_folders(self, rel_path):
        """
        Create the ``inp``/``out`` folders of a new calculation.

        :returns: Tuple ``(calc_base, inp_folder, out_folder)``.
        :raises FileExistsError: If anything already exists at the target path.
        """
        repo_calc_base = self.base_path / rel_path
        if repo_calc_base.exists():
            raise FileExistsError(
                f"There is already a directory at {repo_calc_base.resolve()}."
            )
        repo_in = repo_calc_base / "inp"
        repo_out = repo_calc_base / "out"
        repo_in.mkdir(parents=True)
        repo_out.mkdir(parents=True)
        return repo_calc_base, repo_in, repo_out

    def upload_calc(
        self, folder: Path, rel_path: Union[Path, str], excluded_object=None
    ):
        """
        Register a calculation folder to the repository.

        :param folder: Folder holding the inputs and outputs of the calculation.
        :param rel_path: Destination path relative to ``base_path``.
        :param excluded_object: Optional names of objects that should not be copied.
        :raises FileExistsError: If the destination already exists.
        """
        folder = Path(folder)
        inp = list(INPUT_OBJECTS)
        excluded = list(EXCLUDED)
        if excluded_object:
            excluded.extend(excluded_object)

        repo_calc_base, repo_in, repo_out = self._create_calc_folders(rel_path)

        # Input files go to "inp"; everything else that is not excluded goes to "out"
        for obj in folder.glob("*"):
            if obj.name in inp:
                shutil.copy2(obj, repo_in)
            elif obj.name not in excluded:
                if obj.is_file():
                    shutil.copy2(obj, repo_out)
                elif obj.is_dir():
                    shutil.copytree(obj, repo_out / obj.name)

        # Update the hash table
        self._register_folder(repo_calc_base)

    def _register_folder(self, calc_base: Path):
        """
        Register a folder inside the repository.

        :param calc_base: Calculation folder (containing ``inp``), located
            below ``base_path``. May be absolute or relative.
        """
        # Compare absolute paths on both sides: ``scan`` passes absolute paths
        # while ``base_path`` may be relative, and ``relative_to`` raises when
        # only one of the two is absolute.
        calc_base = Path(calc_base).absolute()
        rel = calc_base.relative_to(self.base_path.absolute())
        # Compute the hash
        hash_val = self.compute_hash(calc_base / "inp")
        # Link absolute path to hash, and hash to relative path (used as name)
        self.reg_hash[hash_val] = calc_base
        self.reg_name[str(rel)] = hash_val

    def extract_calc_by_hash(self, hash_val, dst):
        """
        Copy the stored outputs of a registered calculation into a folder.

        :param hash_val: Hash of the calculation inputs.
        :param dst: Destination folder; created if it does not exist.
        :raises KeyError: If the hash is not registered.
        """
        out_folder = self.get_path_by_hash(hash_val) / "out"
        dst = Path(dst)
        dst.mkdir(parents=True, exist_ok=True)
        for obj in out_folder.glob("*"):
            if obj.is_file():
                shutil.copy2(obj, dst)
            elif obj.is_dir():
                shutil.copytree(obj, dst / obj.name)

    def upload_aiida_calc(
        self, calc_node, rel_path: Union[str, Path], excluded_names=None
    ):
        """
        Register the repository content of an aiida calculation node.

        :param calc_node: Calculation node; its own objects become the inputs
            and the objects of ``outputs.retrieved`` become the outputs.
        :param rel_path: Destination path relative to ``base_path``.
        :param excluded_names: Optional names of objects that should not be copied.
        :raises FileExistsError: If the destination already exists.
        """
        repo_calc_base, repo_in, repo_out = self._create_calc_folders(rel_path)

        exclude = list(EXCLUDED)
        if excluded_names:
            exclude.extend(excluded_names)

        # Copy the input objects
        for obj in calc_node.list_objects():
            if obj.name in exclude:
                continue
            copy_from_aiida(obj.name, calc_node, repo_in)

        # Copy the retrieved objects
        for obj in calc_node.outputs.retrieved.list_objects():
            if obj.name in exclude:
                continue
            copy_from_aiida(obj.name, calc_node.outputs.retrieved, repo_out)

        self.logger.info("Calculation %s has been registered", calc_node)
        self._register_folder(repo_calc_base)

    def upload_aiida_work(self, worknode, rel_path: Union[str, Path]):
        """
        Upload all calculations in a workchain node.

        Each ``castep.castep`` calculation job among the called descendants is
        stored under ``rel_path/calc-NNN``, numbered in order of creation time.

        :param worknode: Workchain node whose descendants are uploaded.
        :param rel_path: Destination path relative to ``base_path``.
        """
        from aiida.orm import CalcJobNode
        from aiida.plugins import CalculationFactory

        calc_class = CalculationFactory("castep.castep")
        to_upload = [
            node
            for node in worknode.called_descendants
            if isinstance(node, CalcJobNode) and node.process_class is calc_class
        ]
        to_upload.sort(key=lambda x: x.ctime)
        self.logger.info(
            "Collected %s nodes to upload under name %s.", to_upload, rel_path
        )

        for idx, node in enumerate(to_upload):
            rel = Path(rel_path) / f"calc-{idx:03d}"
            self.upload_aiida_calc(node, rel)
        self.logger.info("WorkChain %s has been uploaded.", worknode)
class MockCastep:
    """
    Mock CASTEP executable.

    Instead of running anything, the inputs found in the working directory
    are hashed and looked up in a registry of completed calculations.
    """

    def __init__(self, workdir: Union[str, Path], registry: "MockRegistry", seed="aiida"):
        """
        Mock CASTEP executable that copies over outputs from existing calculations.

        :param workdir: Folder containing the ``<seed>.param``/``<seed>.cell`` inputs.
        :param registry: Registry used to hash the inputs and look up the outputs.
        :param seed: Seed name (file stem) of the input files.
        """
        # Normalise to a Path so that a plain string works with path joining
        self.workdir = Path(workdir)
        self.registry = registry
        self.seed = seed

    def run(self, debug=True):
        """
        Run the mock CASTEP.

        Hashes the inputs in the working directory and asks the registry to
        extract the matching calculation into the working directory.

        :param debug: When true, print the target hash (and the registered
            hashes if there is no match).
        :raises ValueError: If no calculation is registered for the inputs.
        """
        hash_val = self.registry.compute_hash(self.workdir, seed=self.seed)
        if debug:
            print(f"Target hash value: {hash_val}")

        if hash_val not in self.registry.reg_hash:
            if debug:
                print(f"Registered hashes: {self.registry.reg_hash}")
            raise ValueError("The calculation is not registered!!")

        self.registry.extract_calc_by_hash(hash_val, self.workdir)

    @property
    def is_runnable(self) -> bool:
        """Return whether the mock code can be run, i.e. the inputs are registered."""
        # Use the configured seed so that this agrees with what ``run`` hashes
        hash_val = self.registry.compute_hash(self.workdir, seed=self.seed)
        return hash_val in self.registry.reg_hash
def copy_from_aiida(name: str, node, dst: Path):
    """
    Copy objects from aiida repository.

    :param name: The full name (including the parent path) of the object.
    :param node: Node object for which the objects in the repo to be copied.
    :param dst: Path of the destination folder.

    This is a recursive function so directory copying also works.
    Files are copied as raw bytes, so binary objects are reproduced exactly
    and no platform-dependent encoding or newline translation is applied.
    """
    obj = node.get_object(name)

    # If it is a directory, copy the contents one by one
    if obj.file_type == FileType.DIRECTORY:
        for sub_obj in node.list_objects(name):
            copy_from_aiida(os.path.join(name, sub_obj.name), node, dst)
        return

    # Anything else: make the parent directory if needed, then copy the bytes
    frepo_path = Path(dst) / name
    frepo_path.parent.mkdir(exist_ok=True, parents=True)
    with node.open(name, mode="rb") as fsource, open(frepo_path, "wb") as fdst:
        shutil.copyfileobj(fsource, fdst)