### Utilities for converting between species names and isotopologue names.
### Copied from https://github.com/bmorris3/shone/tree/refs/heads/opacity-downloads

import re
import numpy as np
from periodictable import elements


__all__ = [
    'isotopologue_to_species',
    'species_name_to_common_isotopologue_name'
]


def isotopologue_to_species(isotopologue):
    """
    Strip the isotope mass numbers from an isotopologue name.

    The name is split on ``'-'``; within each piece, every run of
    non-digit characters together with any digits that directly follow it
    is kept, which discards the leading mass number. For example,
    "1H2-16O" becomes "H2O" and "48Ti-16O" becomes "TiO".

    Parameters
    ----------
    isotopologue : str
        Isotopologue name, like "1H2-16O".

    Returns
    -------
    common_name : str
        Common name, like "H2O". If nothing is left after stripping (for
        example the input is empty or consists only of digits), the input
        is returned unchanged.
    """
    fragments = [
        token
        for part in isotopologue.split('-')
        for token in re.findall(r'\D+\d*', part)
    ]
    common_name = ''.join(fragments)

    if common_name:
        return common_name
    return isotopologue


def _most_abundant_mass_number(entry):
    """
    Return the mass number used to label ``entry`` in an isotopologue name.

    Parameters
    ----------
    entry : object
        Entry looked up by symbol on ``periodictable.elements``.

    Returns
    -------
    mass_number : int
        Mass number of the isotope with the largest natural abundance, or
        the rounded atomic mass when no isotope reports a positive abundance.
    """
    # Entries that already are one specific isotope (e.g. ``elements.D``)
    # carry their own mass number; use it directly.
    own_number = getattr(entry, 'isotope', None)
    if isinstance(own_number, int) and own_number > 0:
        return own_number

    # Rounding the average atomic mass does not always give the dominant
    # isotope (e.g. Ni averages 58.69 but 58Ni is the most abundant), so
    # pick the isotope with the highest natural abundance instead.
    best_number, best_abundance = None, 0
    for number in entry.isotopes:
        abundance = getattr(entry[number], 'abundance', 0) or 0
        if abundance > best_abundance:
            best_number, best_abundance = number, abundance

    if best_number is None:
        # No abundance data at all: fall back to the rounded atomic mass.
        return round(entry.mass)
    return best_number


def species_name_to_common_isotopologue_name(species):
    """
    Convert generic species name, like "H2O", to isotopologue name like "1H2-16O".

    Every element symbol is prefixed with the mass number of its most
    abundant natural isotope and the parts are joined with ``'-'``.

    A species made of a single element symbol is returned as that bare
    symbol, and any multiplier is dropped (e.g. "Na" gives "Na" and "H2"
    gives "H").

    Parameters
    ----------
    species : str
        Generic name, like "H2O".

    Returns
    -------
    isotopologue_name : str
        Isotopologue name, like "1H2-16O".

    Raises
    ------
    ValueError
        If ``species`` contains no element symbol (empty or digits only).
    AttributeError
        If a parsed symbol is not an attribute of ``periodictable.elements``.
    """
    atoms = []
    multipliers = []

    # Walk over each run of non-digit characters and the digits (if any)
    # that directly follow it; leading digits in ``species`` are ignored.
    for letters, digits in re.findall(r'(\D+)(\d*)', species):
        # A run may hold several symbols ("TiO" -> "Ti", "O"): split before
        # every upper-case letter and after every lower-case letter.
        symbols = [
            sym for sym in re.split(r"(?<=[a-z])|(?=[A-Z])", letters) if sym
        ]
        atoms.extend(symbols)
        # Only the last symbol of the run is followed by the digits.
        multipliers.extend([1] * (len(symbols) - 1))
        multipliers.append(int(digits) if digits else 1)

    if not atoms:
        raise ValueError(
            "No element symbol found in species name: {!r}".format(species)
        )

    # Look up every symbol (also for a single atom) so that unknown symbols
    # are still rejected by the periodic table lookup.
    mass_numbers = [
        _most_abundant_mass_number(getattr(elements, atom)) for atom in atoms
    ]

    # If single atom, give only the name of the atom.
    # NOTE(review): this also applies to homonuclear molecules such as "H2";
    # confirm whether those should instead yield e.g. "1H2".
    if len(atoms) == 1:
        return atoms[0]

    return '-'.join(
        str(mass) + atom + (str(mult) if mult > 1 else '')
        for atom, mult, mass in zip(atoms, multipliers, mass_numbers)
    )
