#!/users/keimre/soft/miniconda3/bin/python
"""Plot Quantum ESPRESSO band structures from a new-format XML file.

Usage: ``<script> FILENAME`` -- reads the QE XML file and writes
``./bands.png``.
"""
import argparse
import xml.etree.ElementTree as et

import numpy as np

ang_2_bohr = 1.0/0.52917721067
hart_2_ev = 27.21138602

# (coordinates in units of 2*pi/a, tick label). Several coordinates may map
# to the same label.
HIGH_SYMMETRY_POINTS = (
    ((0.0, 0.0, 0.0), 'G'),
    ((-0.5, 0.5, 0.5), 'L'),
    ((0.5, 0.5, 0.5), 'L'),
    ((-1.0, 0.0, 0.0), 'X'),
    ((0.5, 0.0, 0.5), 'X'),
)


def read_band_data_new_xml(xml_file):
    """
    Reads data from QE bands calculations (new XML)

    Args:
    - xml_file: path or binary file object accepted by ElementTree.parse

    Returns:
    - kpts[i_kpt] = [kx, ky, kz] in [2*pi/a]
    - eigvals[i_kpt, i_band] in [eV] (converted from Hartree via hart_2_ev)
    - fermi_en in [eV], read from <fermi_energy>, falling back to
      <highestOccupiedLevel>; None if neither element is present

    Raises:
    - ValueError if the file has no output/band_structure element
    """
    root = et.parse(xml_file).getroot()
    band_node = root.find('output/band_structure')
    if band_node is None:
        raise ValueError(
            "no output/band_structure element found in %r" % (xml_file,))

    # QE does not always write <fermi_energy> (typically it is absent for
    # fixed-occupation runs, which store <highestOccupiedLevel> instead).
    fermi_node = band_node.find('fermi_energy')
    if fermi_node is None:
        fermi_node = band_node.find('highestOccupiedLevel')
    fermi_en = None if fermi_node is None else float(fermi_node.text) * hart_2_ev

    kpts = []
    eigvals = []
    for ks_node in band_node.findall('ks_energies'):
        kpts.append(
            np.array(ks_node.find('k_point').text.split(), dtype=float))
        eigvals.append(
            np.array(ks_node.find('eigenvalues').text.split(), dtype=float)
            * hart_2_ev)

    return np.array(kpts), np.array(eigvals), fermi_en


def k_path_distances(kpts):
    """Return the cumulative path length along consecutive k-points.

    Entry 0 is 0.0; entry i adds the Euclidean distance between k-points
    i-1 and i (same units as kpts). An empty input gives an empty array.
    """
    kpts = np.asarray(kpts, dtype=float)
    if len(kpts) == 0:
        return np.zeros(0)
    steps = np.linalg.norm(np.diff(kpts, axis=0), axis=1)
    return np.concatenate(([0.0], np.cumsum(steps)))


def high_symmetry_ticks(kpts, k_dist, points=HIGH_SYMMETRY_POINTS):
    """Return [(path_position, label), ...] for high-symmetry k-points.

    Every occurrence along the path gets its own tick, so a point visited
    more than once (e.g. G-X-...-G) is labelled at each visit.

    Args:
    - kpts: k-point coordinates, one row per point
    - k_dist: path position of each k-point (see k_path_distances)
    - points: iterable of (coordinates, label) pairs to look for
    """
    ticks = []
    for kpt, pos in zip(kpts, k_dist):
        for coords, label in points:
            if np.allclose(kpt, coords):
                ticks.append((pos, label))
    return ticks


def plot_bands(k_dist, eigvals, ticks, out_file="./bands.png"):
    """Plot eigvals against k_dist, draw labelled high-symmetry lines, save."""
    # The backend must be selected before pyplot is imported. Importing here
    # keeps the reader functions above usable without matplotlib.
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    positions = [pos for pos, _ in ticks]
    labels = [label for _, label in ticks]

    plt.figure(figsize=(6, 10))
    plt.plot(k_dist, eigvals, 'b', lw=2.0)
    plt.xlim([0.0, np.max(k_dist)])
    plt.ylim([np.min(eigvals)-2, np.max(eigvals)+2])
    plt.xticks(positions, labels)
    for pos in positions:
        plt.axvline(pos, color='gray')
    plt.ylabel("energy [eV]")
    plt.savefig(out_file, dpi=200)
    plt.close()


def main():
    """Parse the command line, read the QE XML file and write ./bands.png."""
    parser = argparse.ArgumentParser(
        description='Plots QE bands.')
    parser.add_argument(
        'qe_xml',
        metavar='FILENAME',
        help='QE xml file.')
    args = parser.parse_args()

    kpts, eigvals, _fermi_en = read_band_data_new_xml(args.qe_xml)
    k_dist = k_path_distances(kpts)
    ticks = high_symmetry_ticks(kpts, k_dist)
    plot_bands(k_dist, eigvals, ticks, "./bands.png")


if __name__ == "__main__":
    main()