"""Excitation lists base classes

"""
from math import sqrt
import numpy as np

from ase.units import Ha

import gpaw.mpi as mpi
from gpaw.io.logger import GPAWLogger


def get_filehandle(cls, filename, mode='r'):
    """Open *filename* and return the resulting file object.

    As a side effect the name is stored on ``cls.filename``.

    File names ending in ``'.gz'`` are opened with :func:`gzip.open` in
    text mode (``mode + 't'``).  If the :mod:`gzip` module cannot be
    imported, or for any other file name, the builtin :func:`open` is
    used with *mode* unchanged.
    """
    cls.filename = filename
    if not filename.endswith('.gz'):
        return open(filename, mode)
    try:
        import gzip
        return gzip.open(filename, mode + 't')
    except ModuleNotFoundError:
        # gzip unavailable: fall back to a plain open of the same path
        return open(filename, mode)


class ExcitationLogger(GPAWLogger):
    """A :class:`GPAWLogger` whose ``header`` hook does nothing."""

    def header(self):
        """Override the parent's ``header`` with a no-op."""
        return None


class ExcitationList(list):
    """General Excitation List class.

    A ``list`` of excitation objects with helpers to collect energies and
    evaluate sum rules, plus element-wise arithmetic.  Elements are expected
    to provide ``get_energy()``, ``get_dipole_me()`` and
    ``get_oscillator_strength()`` for the corresponding methods below.
    """
    def __init__(self, log=None, txt='-'):
        """Create an empty excitation list.

        Parameters
        ----------
        log:
            Logger object to use.  If ``None``, a new
            :class:`ExcitationLogger` on ``mpi.world`` is created and its
            ``fd`` attribute is set to *txt*.
        txt:
            Value assigned to the new logger's ``fd`` (only used when
            *log* is ``None``).
        """
        # initialise empty list
        list.__init__(self)
        # factor converting energies (Hartree) to eV
        self.energy_to_eV_scale = Ha

        # set output
        if log is not None:
            self.log = log
        else:
            self.log = ExcitationLogger(world=mpi.world)
            self.log.fd = txt

    @property
    def calc(self):
        # NOTE(review): always raises ZeroDivisionError; this appears to be a
        # deliberate tripwire so any remaining access to ``.calc`` fails
        # loudly.  An explicit exception with a message would be clearer.
        1 / 0

    def get_energies(self):
        """Get excitation energies in Hartrees"""
        return np.array([ex.get_energy() for ex in self])

    def get_trk(self):
        """Evaluate the Thomas Reiche Kuhn sum rule.

        Returns a length-3 array ``2 * sum_i E_i * |me_i|**2`` per Cartesian
        component, where ``me_i`` is ``ex.get_dipole_me()``.
        """
        trkm = np.zeros(3)
        for ex in self:
            me = ex.get_dipole_me()
            trkm += ex.get_energy() * (me.real ** 2 + me.imag ** 2)
        return 2. * trkm  # scale to get the number of electrons XXX spinpol ?

    def get_polarizabilities(self, lmax=7):
        """Calculate the Polarisabilities
        see Jamorski et al. J. Chem. Phys. 104 (1996) 5134

        Returns an array ``S`` of length ``lmax + 1`` with
        ``S[l] = sum_i f_i * E_i ** (-2 * l)``, where ``E_i`` is the
        excitation energy and ``f_i`` is element 0 of
        ``ex.get_oscillator_strength()``.
        """
        S = np.zeros(lmax + 1)
        for ex in self:
            e = ex.get_energy()
            f = ex.get_oscillator_strength()[0]
            for l in range(lmax + 1):
                S[l] += e ** (-2 * l) * f
        return S

    def __truediv__(self, x):
        """Divide every element by the number *x* (see ``__mul__``)."""
        return self.__mul__(1. / x)

    # Python 2 spelling kept for backward compatibility
    __div__ = __truediv__

    def __rmul__(self, x):
        """Support ``number * excitation_list``."""
        return self.__mul__(x)

    def __mul__(self, x):
        """Multiply every element by the number *x*.

        Returns a new instance of the same class (constructed with default
        arguments) that carries over ``self.dtype``.

        Raises
        ------
        RuntimeError
            If *x* is not a real number.
        """
        # np.integer / np.floating also admit numpy scalars such as
        # np.int64, which is not a subclass of int.
        if not isinstance(x, (float, int, np.integer, np.floating)):
            raise RuntimeError('not a number')
        result = self.__class__()
        result.dtype = self.dtype
        for kss in self:
            result.append(x * kss)
        return result

    def __sub__(self, other):
        """Element-wise difference of two lists of equal length."""
        result = self.__class__()
        result.dtype = self.dtype
        assert len(self) == len(other)
        for kss, ksso in zip(self, other):
            result.append(kss - ksso)
        return result

    def __str__(self):
        """Header line with the class and count, then one line per entry."""
        header = '# ' + str(type(self))
        if len(self) != 0:
            header += ', %d excitations:' % len(self)
        lines = [header] + ['#  ' + ex.__str__() for ex in self]
        return '\n'.join(lines) + '\n'


class Excitation:
    """Base class for a single excitation.

    The methods read attributes that must be set by subclasses or users:
    ``energy`` (transition energy), ``me`` (length-form matrix element,
    see :meth:`get_dipole_tensor`), ``fij`` and ``muv`` for velocity-form
    quantities, and ``mur`` / ``magn`` for rotatory strengths.
    """

    def get_energy(self):
        """Get the excitations energy relative to the ground state energy
        in Hartrees.
        """
        return self.energy

    def get_dipole_me(self, form='r'):
        """return the excitations dipole matrix element
        including the occupation factor sqrt(fij)

        form='r' (length form):   ``self.me / sqrt(self.energy)``
        form='v' (velocity form): ``-sqrt(self.fij) * self.muv``

        Raises RuntimeError for any other *form*.
        """
        if form == 'r':
            # length form
            return self.me / sqrt(self.energy)
        elif form == 'v':
            # velocity form
            return - np.sqrt(self.fij) * self.muv
        else:
            # f-string so a non-str form still yields RuntimeError
            raise RuntimeError(f'Unknown form >{form}<')

    def get_dipole_tensor(self, form='r'):
        """Return the "oscillator strength tensor"

        self.me is assumed to be::

          form='r': sqrt(f * E) * <I|r|J>,
          form='v': sqrt(f / E) * <I|d/(dr)|J>

        for f = multiplicity, E = transition energy and initial and
        final states::

          |I>, |J>

        The result is ``2 * outer(me, conj(me))``.  Raises RuntimeError
        for an unknown *form*.
        """

        if form == 'r':
            # length form
            me = self.me
        elif form == 'v':
            # velocity form
            me = self.muv * np.sqrt(self.fij * self.energy)
        else:
            raise RuntimeError(f'Unknown form >{form}<')

        return 2 * np.outer(me, me.conj())

    def get_oscillator_strength(self, form='r'):
        """Return the excitations dipole oscillator strength.

        Returns ``[mean, xx, yy, zz]``: the real diagonal of
        :meth:`get_dipole_tensor` preceded by its average.
        """
        me2_c = self.get_dipole_tensor(form).diagonal().real
        return np.array([np.sum(me2_c) / 3.] + me2_c.tolist())

    def get_rotatory_strength(self, form='r', units='cgs'):
        """Return rotatory strength ``-pre * dot(mu, magn)``.

        *mu* is ``self.mur`` (form='r') or ``self.muv`` (form='v');
        *pre* depends on *units* ('cgs' or 'a.u.').

        Raises
        ------
        RuntimeError
            If the magnetic moment ``magn`` is missing or ``None``, or if
            *units* / *form* is unknown.
        """
        # the base class never sets ``magn``; treat "unset" like None
        magn = getattr(self, 'magn', None)
        if magn is None:
            raise RuntimeError('Magnetic moment not available.')

        if units == 'cgs':
            # 10^-40 esu cm erg / G
            # = 3.33564095 * 10^-15 A^2 m^3 s
            # conversion factor after
            # T. B. Pedersen and A. E. Hansen,
            # Chem. Phys. Lett. 246 (1995) 1
            # pre = 471.43
            # From TurboMole
            pre = 64604.8164
        elif units == 'a.u.':
            pre = 1.
        else:
            raise RuntimeError(f'Unknown units >{units}<')

        if form == 'r':
            # length form
            mu = self.mur
        elif form == 'v':
            # velocity form
            mu = self.muv
        else:
            raise RuntimeError(f'Unknown form >{form}<')

        return -pre * np.dot(mu, magn)

    def set_energy(self, E):
        """Set the excitations energy relative to the ground state energy"""
        self.energy = E
