# Source code for aiida_castep.calculations.inpgen

"""
Module for generating text based CASTEP inputs
"""
import numpy as np
from aiida.common import InputValidationError, MultipleObjectsError
from castepinput import Block
from castepinput.inputs import CellInput, ParamInput

from aiida_castep.common import INPUT_LINKNAMES as in_ln

from ..data.otfg import OTFGData
from .utils import _lowercase_dict, _uppercase_dict, get_castep_ion_line

# pylint: disable=no-member, too-many-locals, too-many-statements, too-many-branches


class CastepInputGenerator:
    """
    Class for generating CASTEP inputs

    The methods read ``self.inputs`` (parameters, settings, structure, kpoints,
    pseudos and metadata), which is never set by this class itself and must be
    provided by the class it is combined with (presumably an AiiDA ``CalcJob``).
    """

    def __init__(self):
        """
        Initialise the object
        """
        # Initialize the underlying cell file and param file objects
        self.param_file = ParamInput()
        self.cell_file = CellInput()
        # (uuid, filename, filename) tuples of file-based pseudopotentials;
        # presumably appended to the calculation's local copy list by the caller
        self.local_copy_list_to_append = set()
        self.param_dict = {}
        self.settings_dict = {}

    def prepare_inputs(self, reset=True):
        """
        Prepare the inputs

        Normalise the ``parameters`` input, fill in defaults for ``iprint``,
        ``run_time`` and ``comment``, then populate ``self.cell_file`` and
        ``self.param_file``.

        :param reset: Reset the existing ``self.param_file``, ``self.cell_file``
            and ``self.local_copy_list_to_append`` before preparing
        """
        if reset:
            self.param_file = ParamInput()
            self.cell_file = CellInput()
            self.local_copy_list_to_append = set()

        param_dict = self.inputs[in_ln["parameters"]].get_dict()
        settings_node = self.inputs.get("settings", None)
        settings_dict = settings_node.get_dict() if settings_node else {}

        # Standardise: top level keys should be CAPITALIZED
        param_dict = _uppercase_dict(param_dict, dict_name="parameters")
        # Second level keys should be lowercased
        param_dict = {k: _lowercase_dict(v, dict_name=k) for k, v in param_dict.items()}
        # Make sure both sections exist so later lookups do not fail with a bare KeyError
        param_dict.setdefault("PARAM", {})
        param_dict.setdefault("CELL", {})

        # Set iprint to 1 unless given explicitly
        param_dict["PARAM"].setdefault("iprint", 1)

        # Set run_time using the wallclock limit of this calculation
        run_time = self.inputs.metadata.options.get("max_wallclock_seconds")
        if run_time:
            # Keep a 5% margin (presumably so CASTEP can stop cleanly before the
            # scheduler limit) and round down to whole minutes
            n_seconds = (run_time * 0.95 // 60) * 60
            # Do nothing if the calculated time is less than 3 minutes
            if n_seconds >= 180 and "run_time" not in param_dict["PARAM"]:
                param_dict["PARAM"]["run_time"] = int(n_seconds)

        # Set the default comment using the label of this calculation
        comment_str = self.inputs.metadata.get("label", None)
        if comment_str and "comment" not in param_dict["PARAM"]:
            param_dict["PARAM"]["comment"] = comment_str

        # Expose at the instance level
        self.param_dict = param_dict
        self.settings_dict = settings_dict

        # prepare the cell and param files
        self._prepare_cell_file()
        self._prepare_param_file()

    def _prepare_cell_file(self):
        """
        Prepare the cell file

        Writes, in order: the lattice vectors, the atomic positions (with
        optional initial spins and labels from the settings), the k-points,
        the keywords of the ``CELL`` section and the pseudopotentials.

        :raises InputValidationError: on inconsistent spins, missing k-points
            or unusable pseudopotentials
        :raises MultipleObjectsError: if ``species_pot`` is given in ``CELL``
        """
        structure = self.inputs[in_ln["structure"]]
        self.cell_file["LATTICE_CART"] = Block(
            ["{:18.10f} {:18.10f} {:18.10f}".format(*vector) for vector in structure.cell]
        )

        spin_list = self._prepare_positions(structure)
        if spin_list:
            self._check_spin_consistency(spin_list)

        self._prepare_kpoints()
        self._prepare_cell_keywords()
        self._prepare_pseudo_potentials()

    def _prepare_positions(self, structure):
        """
        Write the POSITIONS_ABS block of the cell file

        Per-site initial spins and labels are taken from, and removed from,
        the ``SPINS`` and ``LABELS`` entries of ``self.settings_dict``.

        :param structure: the input structure node
        :returns: the list of initial spins, or ``None`` if not given
        :raises InputValidationError: if SPINS/LABELS do not have one entry per site
        """
        spin_list = self.settings_dict.pop("SPINS", None)
        label_list = self.settings_dict.pop("LABELS", None)
        sites = structure.sites

        for entry_name, entries in (("SPINS", spin_list), ("LABELS", label_list)):
            if entries and len(entries) != len(sites):
                raise InputValidationError(
                    f"The number of {entry_name} ({len(entries)}) does not match "
                    f"the number of sites ({len(sites)})"
                )

        atomic_position_list = []
        mixture_count = 0
        for i, site in enumerate(sites):
            kind = structure.get_kind(site.kind_name)
            try:
                name = kind.symbol
                mixture = False
            except ValueError:
                # Mixed site: kind.symbol raises, so use the list of symbols
                name = kind.symbols
                mixture_count += 1
                mixture = True

            if mixture:
                # Tag every species of the mixture with the kind name
                name = [symbol + ":" + kind.name for symbol in name]
            elif name != kind.name:
                # Inequivalent atoms of the same element: use '<symbol>:<kind.name>'
                name = name + ":" + kind.name

            line = get_castep_ion_line(
                name,
                site.position,
                label=label_list[i] if label_list else None,
                spin=spin_list[i] if spin_list else None,
                occupation=kind.weights,
                mix_num=mixture_count,
            )
            atomic_position_list.append(line)

        self.cell_file["POSITIONS_ABS"] = Block(atomic_position_list)
        return spin_list

    def _check_spin_consistency(self, spin_list):
        """
        Check the total initial spin against the ``spin`` keyword of PARAM

        :param spin_list: per-site spins, scalars or (non-collinear) vectors
        :raises InputValidationError: if ``spin`` is set in PARAM and differs
            from the total spin of the cell
        """
        if isinstance(spin_list[0], (list, tuple)):
            # Non-collinear spins: sum the magnitudes of the spin vectors
            total_spin = np.linalg.norm(spin_list, axis=1).sum()
        else:
            total_spin = sum(s for s in spin_list if s)

        param_spin = self.param_dict["PARAM"].get("spin", None)
        # If no spin is given nothing is set automatically: since CASTEP 18.1
        # having the initial spins in the CELL file is enough
        if param_spin is not None and not self._spins_match(param_spin, total_spin):
            raise InputValidationError(
                "Inconsistent spin in cell and param files. "
                "Total spin: {} in cell file but {} in param file".format(
                    total_spin, param_spin
                )
            )

    @staticmethod
    def _spins_match(param_spin, total_spin):
        """
        Return True if ``param_spin`` equals ``total_spin`` within float tolerance

        Values that cannot be converted to a single float never match.
        """
        try:
            return bool(np.isclose(float(param_spin), float(total_spin)))
        except (TypeError, ValueError):
            return False

    @staticmethod
    def _has_offset(offset):
        """Return True if any component of a k-point mesh ``offset`` is non-zero"""
        if offset is None:
            return False
        return any(float(component) != 0.0 for component in offset)

    @staticmethod
    def _as_block(value):
        """
        Convert a non-empty list/tuple of strings into a Block

        Only the first element is inspected; any other value is returned unchanged.
        """
        if isinstance(value, (list, tuple)) and value and isinstance(value[0], str):
            return Block(value)
        return value

    def _prepare_kpoints(self):
        """
        Write the k-point mesh or explicit k-point list to the cell file

        Does nothing unless ``metadata.options.use_kpoints`` is set.

        :raises InputValidationError: if k-points are required but missing or empty
        """
        if not self.inputs.metadata.options.use_kpoints:
            return
        kpoints = self.inputs.get("kpoints")
        if not kpoints:
            raise InputValidationError("Kpoints required but not found in the input")

        try:
            mesh, offset = kpoints.get_kpoints_mesh()
        except AttributeError:
            pass  # Not defined as a mesh - expect an explicit list below
        else:
            self.cell_file["kpoints_mp_grid"] = "{} {} {}".format(*mesh)
            if self._has_offset(offset):
                self.cell_file["kpoints_mp_offset"] = "{} {} {}".format(*offset)
            return

        try:
            kpoints_list = kpoints.get_kpoints()
        except AttributeError as error:
            raise InputValidationError("No valid kpoints have been found") from error
        num_kpoints = len(kpoints_list)
        if num_kpoints == 0:
            raise InputValidationError("At least one k points must be provided")

        try:
            _, weights = kpoints.get_kpoints(also_weights=True)
        except AttributeError:
            # No weights stored: weight all k-points equally
            weights = np.ones(num_kpoints, dtype=float) / num_kpoints

        self.cell_file["KPOINTS_LIST"] = Block(
            [
                "{:18.10f} {:18.10f} {:18.10f} {:18.10f}".format(
                    kpoint[0], kpoint[1], kpoint[2], weight
                )
                for kpoint, weight in zip(kpoints_list, weights)
            ]
        )

    def _prepare_cell_keywords(self):
        """
        Write the keywords of the CELL section of the parameters to the cell file

        Lists/tuples of strings are written as blocks.

        :raises MultipleObjectsError: if a ``species_pot`` keyword is given directly
        """
        for key, value in self.param_dict["CELL"].items():
            if "species_pot" in key:
                raise MultipleObjectsError(
                    "Pseudopotentials should not be specified directly"
                )
            self.cell_file[key] = self._as_block(value)

    def _include_extra_kpoints(self, kpn_node, kpn_name, kpn_settings, report_fn=None):
        """
        Write extra kpoints to the cell

        :param kpn_node: node holding the k-points, either as a mesh or as a list
        :param kpn_name: prefix of the cell keywords, e.g. ``bs`` gives
            ``bs_kpoint_mp_grid`` or ``BS_KPOINT_LIST``
        :param kpn_settings: dict whose ``need_weights`` entry (must be ``True``)
            requests weights to be written for an explicit list
        :param report_fn: optional callable used to report a warning when
            weights are filled in automatically
        :raises InputValidationError: if no k-points can be found in ``kpn_node``
        """
        try:
            mesh, offset = kpn_node.get_kpoints_mesh()
        except AttributeError:
            pass  # Not defined as mesh - expect an explicit list below
        else:
            self.cell_file[f"{kpn_name}_kpoint_mp_grid"] = "{} {} {}".format(*mesh)
            if self._has_offset(offset):
                self.cell_file[f"{kpn_name}_kpoint_mp_offset"] = "{} {} {}".format(
                    *offset
                )
            return

        try:
            bs_kpts_list = kpn_node.get_kpoints()
        except AttributeError as error:
            raise InputValidationError(
                "No valid {}_kpoints have been found from node {}".format(
                    kpn_name.lower(), kpn_node.pk
                )
            ) from error
        num_kpoints = len(bs_kpts_list)
        if num_kpoints == 0:
            raise InputValidationError("At least one k points must be provided")

        if kpn_settings["need_weights"] is True:
            try:
                _, weights = kpn_node.get_kpoints(also_weights=True)
            except AttributeError:
                # No weights stored: fill with equal fractions
                weights = np.ones(num_kpoints, dtype=float) / num_kpoints
                if report_fn is not None:
                    report_fn(
                        "Warning:filling evenly distributed weights for {}_kpoints".format(
                            kpn_name
                        )
                    )
            extra_kpts_lines = [
                f"{kpoint[0]:18.10f} {kpoint[1]:18.10f} {kpoint[2]:18.10f} {weight:18.14f}"
                for kpoint, weight in zip(bs_kpts_list, weights)
            ]
        else:
            extra_kpts_lines = [
                f"{kpoint[0]:18.10f} {kpoint[1]:18.10f} {kpoint[2]:18.10f}"
                for kpoint in bs_kpts_list
            ]
        self.cell_file[f"{kpn_name}_kpoint_list".upper()] = Block(extra_kpts_lines)

    def _prepare_pseudo_potentials(self):
        """
        Prepare the pseudopotential part of the cell file

        Pseudopotentials are looked up in ``self.inputs.pseudos`` by kind name,
        or by ``<kind_name>_<symbol>`` for each element of a mixed kind.
        OTFG pseudopotentials are written as their string. File-based ones are
        referenced by filename and added to ``self.local_copy_list_to_append``.

        :raises InputValidationError: if a pseudopotential is missing or unusable
        """
        species_pot_map = {}
        pseudos = self.inputs.pseudos
        for kind in self.inputs[in_ln["structure"]].kinds:
            symbols = kind.symbols
            mixture = len(symbols) > 1
            for symbol in symbols:
                if symbol == kind.name:
                    pseudo_name = symbol
                else:
                    pseudo_name = symbol + ":" + kind.name
                # Mixtures define one pseudopotential per '<kind_name>_<symbol>'
                pseudo_key = kind.name + "_" + symbol if mixture else kind.name
                try:
                    ps_node = pseudos[pseudo_key]
                except KeyError as error:
                    raise InputValidationError(
                        f"No pseudopotential found for '{pseudo_key}'"
                    ) from error

                if isinstance(ps_node, OTFGData):
                    # On-the-fly generated: just add its string specification
                    species_pot_map[pseudo_name] = "{:5} {}".format(
                        pseudo_name, ps_node.string
                    )
                    continue

                # File based pseudopotential: reference it by filename and
                # copy it under the same name
                try:
                    species_pot_map[pseudo_name] = "{:5} {}".format(
                        pseudo_name, ps_node.filename
                    )
                    self.local_copy_list_to_append.add(
                        (ps_node.uuid, ps_node.filename, ps_node.filename)
                    )
                except Exception as error:
                    raise InputValidationError(
                        "Unknown node as pseudo: {}. Exception raised: {}".format(
                            ps_node, error
                        )
                    ) from error

        self.cell_file["SPECIES_POT"] = Block(list(species_pot_map.values()))

    def _prepare_param_file(self):
        """
        Prepare the content of PARAM file

        Lists/tuples of strings are written as blocks; other values as keywords.
        """
        for key, value in self.param_dict["PARAM"].items():
            self.param_file[key] = self._as_block(value)