diff --git a/cmeutils/gsd_utils.py b/cmeutils/gsd_utils.py index 88893d3..60a6723 100644 --- a/cmeutils/gsd_utils.py +++ b/cmeutils/gsd_utils.py @@ -380,6 +380,131 @@ def xml_to_gsd(xmlfile, gsdfile): print(f"XML data written to {gsdfile}") +def trim_snapshot_molecules(parent_snapshot, mol_indices): + """Given a snapshot of a system, trim the snapshot to only include + a subset of the molecules. + + Parameters + ---------- + parent_snapshot : gsd.hoomd.Frame + The snapshot to read in. + mol_indices : list of np.ndarray + List of arrays where each array contains the indices + of the particles in a molecule to include. + + Returns + ------- + gsd.hoomd.Frame + The new snapshot with only the specified molecules. + + Notes + ----- + See cmetuils.gsd_utils.get_molecule_cluster for a method to obtain + mol_indices. + + """ + new_snap = gsd.hoomd.Frame() + new_snap.configuration.box = parent_snapshot.configuration.box + new_snap.particles.N = sum(len(i) for i in mol_indices) + + # Write out particle info + for attr in [ + "position", + "mass", + "velocity", + "orientation", + "image", + "diameter", + "angmom", + "typeid", + ]: + setattr( + new_snap.particles, + attr, + np.concatenate( + list( + getattr(parent_snapshot.particles, attr)[i] + for i in mol_indices + ) + ), + ) + new_snap.particles.types = parent_snapshot.particles.types + + particle_index_map = dict() + count = 0 + for indices in mol_indices: + for i in indices: + particle_index_map[i] = count + count += 1 + + # Write out bond info + mol_bond_groups = [] + mol_bond_ids = [] + for indices in mol_indices: + mask = np.any( + np.isin(parent_snapshot.bonds.group, indices.flatten()), axis=1 + ) + parent_mol_bonds = parent_snapshot.bonds.group[np.where(mask)[0]] + parent_mol_bond_typeids = parent_snapshot.bonds.typeid[ + np.where(mask)[0] + ] + new_mol_bonds = np.vectorize(particle_index_map.get)(parent_mol_bonds) + mol_bond_groups.append(new_mol_bonds) + mol_bond_ids.append(parent_mol_bond_typeids) + + new_snap.bonds.types = parent_snapshot.bonds.types + new_snap.bonds.group = np.concatenate(mol_bond_groups) + new_snap.bonds.typeid = np.concatenate(mol_bond_ids) + new_snap.bonds.N = sum(len(i) for i in mol_bond_ids) + + # Write out angle info + mol_angle_groups = [] + mol_angle_ids = [] + for indices in mol_indices: + mask = np.any( + np.isin(parent_snapshot.angles.group, indices.flatten()), axis=1 + ) + parent_mol_angles = parent_snapshot.angles.group[np.where(mask)[0]] + parent_mol_angle_typeids = parent_snapshot.angles.typeid[ + np.where(mask)[0] + ] + new_mol_angles = np.vectorize(particle_index_map.get)(parent_mol_angles) + mol_angle_groups.append(new_mol_angles) + mol_angle_ids.append(parent_mol_angle_typeids) + + new_snap.angles.types = parent_snapshot.angles.types + new_snap.angles.group = np.concatenate(mol_angle_groups) + new_snap.angles.typeid = np.concatenate(mol_angle_ids) + new_snap.angles.N = sum(len(i) for i in mol_angle_ids) + + # Write out dihedral info + mol_dihedral_groups = [] + mol_dihedral_ids = [] + for indices in mol_indices: + mask = np.any( + np.isin(parent_snapshot.dihedrals.group, indices.flatten()), axis=1 + ) + parent_mol_dihedrals = parent_snapshot.dihedrals.group[ + np.where(mask)[0] + ] + parent_mol_dihedral_typeids = parent_snapshot.dihedrals.typeid[ + np.where(mask)[0] + ] + new_mol_dihedrals = np.vectorize(particle_index_map.get)( + parent_mol_dihedrals + ) + mol_dihedral_groups.append(new_mol_dihedrals) + mol_dihedral_ids.append(parent_mol_dihedral_typeids) + + new_snap.dihedrals.types = parent_snapshot.dihedrals.types + new_snap.dihedrals.group = np.concatenate(mol_dihedral_groups) + new_snap.dihedrals.typeid = np.concatenate(mol_dihedral_ids) + new_snap.dihedrals.N = sum(len(i) for i in mol_dihedral_ids) + + new_snap.validate() + return new_snap + + def identify_snapshot_connections(snapshot): """Identify angle and dihedral connections in a snapshot from bonds.