import logging
import os
import astropy.units as u
import h5py
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from astropy.io.misc.hdf5 import read_table_hdf5
from exorad.log import Logger
from exorad.log import setLogLevel
# Keep matplotlib's own logger at WARNING so its DEBUG/INFO records are not emitted.
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)

# Colormap shared by every plot in this module.
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# the `matplotlib.colormaps` registry (available since 3.5) is its replacement.
try:
    cmap = matplotlib.colormaps["Set1"]
except AttributeError:  # Matplotlib < 3.5 has no colormap registry
    cmap = matplotlib.cm.get_cmap("Set1")
class Plotter(Logger):
    """
    Plotter class. It offers a fast and easy way to produce diagnostic plots on the produced data.

    Parameters
    -----------
    input_table: QTable
        table where to grab data and wl grid to plot.
    channels: dict
        dictionary describing the channels in the payload. Default is None
    payload: object
        stored on the instance as ``payload``; it is not read by any method
        of this class. Default is None
    """

    def __init__(self, input_table, channels=None, payload=None):
        self.set_log_name()
        self.inputTable = input_table
        self.channels = channels
        self.payload = payload
        # Filled in by plot_table / plot_efficiency and read back by save_fig.
        self.fig = None
        self.fig_efficiency = None

    def plot_bands(self, ax, scale="log", channel_edges=True):
        """
        It plots the channels bands behind the indicated axes.

        Parameters
        -----------
        ax: matplotlib.axes
            axes where to plot the bands
        scale: str
            scale passed to ``ax.set_xscale``. Default is "log".
        channel_edges: bool
            if True the x ticks are placed on the channel edges and formatted
            with a ScalarFormatter. Default is True.

        Returns
        --------
        matplotlib.axes.axes

        Note
        ----
        The Class input_table input parameter is required for this method to work.
        """
        ch_column = self.inputTable["chName"]
        # dict.fromkeys de-duplicates while keeping first-appearance order, so
        # the colour given to each band is reproducible from run to run
        # (iteration order of a set of strings is not).
        channels = list(dict.fromkeys(ch_column))
        norm = matplotlib.colors.Normalize(vmin=0.0, vmax=len(channels))

        tick_list = []
        for k, channel_name in enumerate(channels):
            in_channel = np.where(ch_column == channel_name)
            wl_min = min(self.inputTable["LeftBinEdge"][in_channel])
            wl_max = max(self.inputTable["RightBinEdge"][in_channel])
            # The edges may carry astropy units: keep only the bare numbers.
            if hasattr(wl_min, "unit"):
                wl_min = wl_min.value
            if hasattr(wl_max, "unit"):
                wl_max = wl_max.value

            ax.axvspan(
                wl_min,
                wl_max,
                alpha=0.1,
                zorder=0,
                color=cmap(norm(k)),
            )
            tick_list.extend((wl_min, wl_max))

        ax.set_xscale(scale)
        if channel_edges:
            ax.set_xticks(tick_list)
            ax.get_xaxis().set_major_formatter(
                matplotlib.ticker.ScalarFormatter()
            )
        return ax

    def plot_efficiency(self, scale="log", channel_edges=True):
        """
        It produces a figure with payload efficiency over wavelength.
        The quantities reported are quantum efficiency, transmission and the photon conversion efficiency (pce)
        computed as the product of the quantum efficiency and transmission.

        Parameters
        -----------
        scale: str
            x axis scale, forwarded to plot_bands. Default is "log".
        channel_edges: bool
            forwarded to plot_bands. Default is True.

        Returns
        --------
        matplotlib.pyplot.figure
            the produced figure, also stored in ``fig_efficiency``. If the
            channels parameter is missing an error is logged and None is returned.

        Note
        ----
        The Class channels input parameter is required for this method to work.
        """
        if not self.channels:
            self.error(
                "channels parameter is required for this method to work"
            )
            return None

        from matplotlib.lines import Line2D
        from exorad.models.signal import Signal

        keys = ["transmission", "qe"]
        labels = keys + ["pce"]
        pce_index = len(keys)
        # One colour per curve type: normalise over the number of curves so
        # that the three colours stay distinct whatever the number of channels.
        norm = matplotlib.colors.Normalize(vmin=0.0, vmax=len(labels))

        fig, ax = plt.subplots(1, 1, figsize=(10, 8))
        fig.suptitle("Payload photon conversion efficiency")

        for ch in self.channels:
            pce = None
            for i, key in enumerate(keys):
                data = self.channels[ch].built_instr["{}_data".format(key)]
                sig = Signal(
                    wl_grid=data["wl_grid"]["value"]
                    * u.Unit(data["wl_grid"]["unit"]),
                    data=data["data"]["value"],
                )
                ax.plot(sig.wl_grid, sig.data, color=cmap(norm(i)), zorder=10)
                if pce is None:
                    pce = sig
                else:
                    # NOTE(review): pce is the first Signal itself, so the
                    # product below is accumulated in place in its data;
                    # whether that also touches the arrays held in
                    # built_instr depends on Signal copying its input - confirm.
                    sig.spectral_rebin(pce.wl_grid)
                    pce.data *= sig.data
            ax.plot(
                pce.wl_grid,
                pce.data,
                color=cmap(norm(pce_index)),
                zorder=10,
            )

        # Proxy artists: one legend entry per curve type instead of one per channel.
        lines = [
            Line2D([0], [0], color=cmap(norm(i)), lw=4)
            for i in range(len(labels))
        ]

        ax.grid(zorder=0)
        ax = self.plot_bands(ax, scale, channel_edges)
        ax.legend(handles=lines, labels=labels)
        ax.set_title("Photon conversion efficiency")
        ax.set_xlabel(r"Wavelength [$\mu m$]")
        ax.set_ylabel("efficiency")
        self.fig_efficiency = fig
        return fig

    def _format_log_yaxis(self, ax, ylim, default_bottom):
        """
        Apply the y axis setup shared by plot_signal and plot_noise:
        limits, log scale, decade/sub-decade ticks, y grid and legend.
        """
        ax.set_ylim(ylim if ylim else default_bottom)
        # The scale must be set before the locators: set_yscale installs the
        # default locators and formatters of the new scale, which would
        # otherwise discard the custom ones configured below.
        ax.set_yscale("log")
        ax.yaxis.set_major_locator(
            matplotlib.ticker.LogLocator(base=10, numticks=12)
        )
        ax.yaxis.set_minor_locator(
            matplotlib.ticker.LogLocator(
                base=10.0, subs=(0.2, 0.4, 0.6, 0.8), numticks=12
            )
        )
        ax.yaxis.set_minor_formatter(matplotlib.ticker.NullFormatter())
        ax.grid(axis="y", which="minor", alpha=0.3)
        ax.grid(axis="y", which="major", alpha=0.5)
        # The legend is anchored below the axes.
        ax.legend(
            prop={"size": 12},
            loc="upper left",
            ncol=3,
            bbox_to_anchor=(0.05, -0.2),
            labelspacing=1.2,
            handlelength=1,
        )

    def plot_noise(self, ax, ylim=None, scale="log", channel_edges=True):
        """
        It plots the noise components found in the input table in the indicated axes.
        The plotted columns are the ones with "noise" or "custom" in their name.
        Apart from "total_noise", columns whose unit is not sqrt(hr) are divided
        by the "star_signal_inAperture" column before being plotted.

        Parameters
        -----------
        ax: matplotlib.axes
            axes where to plot the noises
        ylim: float
            if present, it sets the axes y bottom lim. Default is None.
        scale: str
            x axis scale, forwarded to plot_bands. Default is "log".
        channel_edges: bool
            forwarded to plot_bands. Default is True.

        Returns
        --------
        matplotlib.axes.axes

        Note
        ----
        The Class input_table input parameter is required for this method to work.
        """
        noise_keys = [
            x for x in self.inputTable.keys() if "noise" in x or "custom" in x
        ]
        self.debug("noise keys : {}".format(noise_keys))

        wavelength = self.inputTable["Wavelength"]
        for n in noise_keys:
            if n == "total_noise":
                ax.plot(
                    wavelength,
                    self.inputTable[n],
                    zorder=9,
                    lw=1,
                    c="k",
                    marker=".",
                    markersize=5,
                    label="total_noise",
                    alpha=0.8,
                )
                continue

            if self.inputTable[n].unit == u.hr**0.5:
                noise = self.inputTable[n]
            else:
                self.debug("{} rescaled by starSignal_inAperture".format(n))
                noise = (
                    self.inputTable[n]
                    / self.inputTable["star_signal_inAperture"]
                )
            ax.plot(
                wavelength,
                noise,
                zorder=9,
                lw=1,
                alpha=0.5,
                marker=".",
                label=n,
            )

        self._format_log_yaxis(ax, ylim, default_bottom=1e-7)
        ax.set_title("Noise Budget")
        ax.set_xlabel(r"Wavelength [$\mu m$]")
        ax.set_ylabel(r"relative noise [$\sqrt{{hr}}$]")
        ax = self.plot_bands(ax, scale, channel_edges)
        return ax

    def plot_signal(self, ax, ylim=None, scale="log", channel_edges=True):
        """
        It plots the signal components found in the input table in the indicated axes.
        The plotted columns are the ones with "signal" and without "noise" in their name.

        Parameters
        -----------
        ax: matplotlib.axes
            axes where to plot the signals
        ylim: float
            if present, it sets the axes y bottom lim. Default is None.
        scale: str
            x axis scale, forwarded to plot_bands. Default is "log".
        channel_edges: bool
            forwarded to plot_bands. Default is True.

        Returns
        --------
        matplotlib.axes.axes

        Note
        ----
        The Class input_table input parameter is required for this method to work.
        """
        keys = [
            x
            for x in self.inputTable.keys()
            if "signal" in x and "noise" not in x
        ]
        self.debug("signal keys : {}".format(keys))

        wavelength = self.inputTable["Wavelength"]
        for s in keys:
            ax.plot(
                wavelength,
                self.inputTable[s],
                zorder=9,
                lw=1,
                alpha=0.5,
                marker=".",
                label=s,
            )

        self._format_log_yaxis(ax, ylim, default_bottom=1e-3)
        ax.set_title("Signals")
        ax.set_xlabel(r"Wavelength [$\mu m$]")
        ax.set_ylabel("$ct/s$")
        ax = self.plot_bands(ax, scale, channel_edges)
        return ax

    def plot_table(
        self, sig_ylim=None, noise_ylim=None, scale="log", channel_edges=True
    ):
        """
        It produces a figure with signal and noise for the input table.
        The figure title is the "name" entry of the input table metadata.

        Parameters
        -----------
        sig_ylim: float
            y bottom lim for the signal axes, forwarded to plot_signal. Default is None.
        noise_ylim: float
            y bottom lim for the noise axes, forwarded to plot_noise. Default is None.
        scale: str
            x axis scale used for both axes. Default is "log".
        channel_edges: bool
            forwarded to plot_signal and plot_noise. Default is True.

        Returns
        --------
        matplotlib.pyplot.figure
        (matplotlib.axes.axes, matplotlib.axes.axes)

        Note
        ----
        The Class input_table input parameter is required for this method to work.
        """
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))
        fig.suptitle(self.inputTable.meta["name"])
        ax1 = self.plot_signal(
            ax1, ylim=sig_ylim, scale=scale, channel_edges=channel_edges
        )
        ax2 = self.plot_noise(
            ax2, ylim=noise_ylim, scale=scale, channel_edges=channel_edges
        )
        plt.tight_layout()
        # Leave room for the suptitle and for the legends placed below each axes.
        plt.subplots_adjust(top=0.9, bottom=0.22, hspace=0.7)
        self.fig = fig
        return fig, (ax1, ax2)

    def save_fig(self, name, efficiency=False):
        """
        It saves the produced figure.
        If the requested figure has not been produced yet, an error is logged
        and nothing is written.

        Parameters
        --------
        name: str
            figure name
        efficiency: bool
            if True it will save the efficiency plot instead of the table plot. Default is False.
        """
        fig = self.fig_efficiency if efficiency else self.fig
        # Check explicitly for a missing figure instead of catching
        # AttributeError, which would also hide errors raised while saving.
        if fig is None:
            self.error(
                "the indicated figure is not available. Check if you have produced it."
            )
            return
        fig.savefig("{}".format(name))
        self.info("plot saved in {}".format(name))
def main():
    """
    Command line entry point.

    It reads the "targets" group of the input HDF5 file and, for every
    selected target, loads the table stored under ``<target>/table``,
    produces the signal and noise figure with Plotter.plot_table and saves
    it as ``<out>/<target>.png``.

    Raises
    ------
    ValueError
        if both a target number and a target name are given.
    """
    import argparse

    from exorad.__version__ import __version__
    from exorad.utils.ascii_art import ascii_plot
    from exorad.utils.util import parse_range

    parser = argparse.ArgumentParser(
        description="ExoRad {}".format(__version__)
    )
    parser.add_argument(
        "-i",
        "--input",
        dest="input",
        type=str,
        required=True,
        help="Input h5 file to pass",
    )
    parser.add_argument(
        "-o",
        "--out",
        dest="out",
        type=str,
        default="None",
        required=True,
        help="Output directory",
    )
    parser.add_argument(
        "-n",
        "--target-number",
        dest="target_number",
        type=str,
        default="all",
        required=False,
        help="A list or range of targets to run",
    )
    parser.add_argument(
        "-t",
        "--target-name",
        dest="target_name",
        type=str,
        default="None",
        required=False,
        help="name of the target to plot",
    )
    parser.add_argument(
        "-d",
        "--debug",
        dest="debug",
        default=False,
        required=False,
        help="log output on screen",
        action="store_true",
    )
    args = parser.parse_args()

    logger = logging.getLogger("exorad")
    logger.info(ascii_plot)
    logger.info("code version {}".format(__version__))
    if args.debug:
        setLogLevel(logging.DEBUG)

    # Reject inconsistent options before touching the file system.
    if args.target_number != "all" and args.target_name != "None":
        message = "you cannot use both target number and target name"
        logger.error(message)
        raise ValueError(message)

    if not os.path.exists(args.out):
        os.makedirs(args.out, exist_ok=True)
        logger.info("output directory created")

    logger.info("reading {}".format(args.input))
    # The file is only read: open it read-only and close it when done.
    with h5py.File(args.input, "r") as file:
        targets_dir = file["targets"]
        target_names = list(targets_dir.keys())
        targets_to_run_id = parse_range(args.target_number, len(target_names))
        targets_to_run = [target_names[n] for n in targets_to_run_id]

        if args.target_name != "None":
            targets_to_run = [
                target
                for target in targets_to_run
                if target == args.target_name
            ]
            if not targets_to_run:
                logger.warning(
                    "target {} not found in {}".format(
                        args.target_name, args.input
                    )
                )

        for target in targets_to_run:
            table_dir = targets_dir[target]["table"]
            table = read_table_hdf5(table_dir, path="table")
            plotter = Plotter(input_table=table)
            fig, _ = plotter.plot_table()
            plotter.save_fig(os.path.join(args.out, "{}.png".format(target)))
            # Close the figure just produced so figures do not pile up.
            plt.close(fig)