diff --git a/src/spectramap/spectramap.py b/src/spectramap/spectramap.py index a2136dc..7a7b6d1 100644 --- a/src/spectramap/spectramap.py +++ b/src/spectramap/spectramap.py @@ -1,4 +1,5 @@ -#Others +#Others +import os import sys import numpy as np import spc_spectra as spc @@ -11,6 +12,7 @@ import numpy as np import colorcet as cc from sklearn.cluster import KMeans +import os ############################################# # Internal functions @@ -472,7 +474,24 @@ def airPLS(x, strength, porder = 1, itermax = 50): ######################################## class intensity_object: - def __init__(self, name, data, position, resolutionx, resolutiony): + def __init__(self, name, data, position, resolutionx, resolutiony, output_dir=None): + """Initialize an intensity object for spectral mapping. + + Parameters + ---------- + name : str + Name of the intensity object + data : array-like + Intensity data + position : array-like + Position data with shape (n, 2) for x,y coordinates + resolutionx : float + Resolution in x direction + resolutiony : float + Resolution in y direction + output_dir : str, optional + Directory to save output plots. If None, plots won't be saved by default + """ self.name = name self.position = position self.data = data.copy() @@ -480,11 +499,33 @@ def __init__(self, name, data, position, resolutionx, resolutiony): self.num_stepsy = int(np.max(position[:,1])) + 1 self.resolutionx = resolutionx self.resolutiony = resolutiony + self.output_dir = output_dir + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir) - - ### show the intensity map using position and data - def show_map(self, xlabel, ylabel, ybarlabel, title, rotate=False, order = 'F'): - + def show_map(self, xlabel, ylabel, ybarlabel, title, save_path=None, display=True, rotate=False, order='F'): + """Show and/or save the intensity map using position and data. + + Parameters + ---------- + xlabel : str + Label for x-axis + ylabel : str + Label for y-axis + ybarlabel : str + Label for colorbar + title : str + Title of the plot + save_path : str, optional + Path where to save the plot. If None but output_dir is set, + will save to output_dir/name_title.png + display : bool, optional + Whether to display the plot (default True) + rotate : bool, optional + Whether to rotate the plot (default False) + order : str, optional + Order for reshaping the data (default 'F') + """ fig, ax = plt.subplots() cmap = cc.cm.linear_protanopic_deuteranopic_kbw_5_98_c40 @@ -492,23 +533,36 @@ def show_map(self, xlabel, ylabel, ybarlabel, title, rotate=False, order = 'F'): plane = np.rot90(self.data.reshape((self.num_stepsx, self.num_stepsy), order=order)) diameterx = self.num_stepsy * self.resolutiony diametery = self.num_stepsx * self.resolutionx - xticks = np.linspace(0, diameterx, 6) # Set custom xtick values - yticks = np.linspace(0, diametery, 6) # Set custom ytick values + xticks = np.linspace(0, diameterx, 6) + yticks = np.linspace(0, diametery, 6) ax.set(xlabel=xlabel, ylabel=ylabel, xticks=yticks, yticks=xticks) else: plane = self.data.reshape((self.num_stepsx, self.num_stepsy), order=order) diameterx = self.num_stepsx * self.resolutionx diametery = self.num_stepsy * self.resolutiony - xticks = np.linspace(0, diameterx, 6) # Set custom xtick values - yticks = np.linspace(0, diametery, 6) # Set custom ytick values + xticks = np.linspace(0, diameterx, 6) + yticks = np.linspace(0, diametery, 6) ax.set(xlabel=xlabel, ylabel=ylabel, xticks=xticks, yticks=yticks) img = ax.imshow(plane, cmap=cmap) cbar = fig.colorbar(img, ax=ax) cbar.set_label(ybarlabel) ax.set_title(title) - plt.show() - + + # Handle saving the plot + if save_path: + plt.savefig(save_path, bbox_inches='tight', dpi=300) + elif self.output_dir: + # Create a default filename based on object name and plot title + filename = f"{self.name}_{title.replace(' ', '_')}.png" + save_path = os.path.join(self.output_dir, filename) + plt.savefig(save_path, bbox_inches='tight', dpi=300) + + if display: + plt.show() + else: + plt.close(fig) + def get_data(self): return (self.data)