Source code for spatialist.explorer

import os
import re
import numpy as np
from .raster import Raster
from .vector import Vector
import matplotlib.pyplot as plt

import sys

from osgeo import ogr

if sys.version_info >= (3, 0):
    from tkinter import filedialog, Tk
else:
    from Tkinter import Tk
    import tkFileDialog as filedialog

from IPython.display import display
from ipywidgets import interactive_output, IntSlider, Layout, Checkbox, Button, HBox, Label, VBox
from mpl_toolkits.axes_grid1 import make_axes_locatable

"""
This module is intended for gathering functionalities for plotting spatial data with jupyter notebooks
"""


[docs]class RasterViewer(object): """ | Plotting utility for displaying a geocoded image stack file. | On moving the slider, the band at the slider position is read from the file and displayed. | By clicking on the band image display, you can display time series profiles. | The collected profiles can be saved to a csv file. Parameters ---------- filename: str the name of the file to display cmap: str the color map name for displaying the image. See :class:`matplotlib.colors.Colormap`. band_indices: list or None a list of indices for renaming the individual band indices in `filename`; e.g. -70:70, instead of the raw band indices, e.g. 1:140. The number of unique elements must of same length as the number of bands in `filename`. band_names: list or None alternative names to assign to the individual bands pmin: int the minimum percentile for linear histogram stretching pmax: int the maximum percentile for linear histogram stretching zmin: int or float the minimum value of the displayed data range; overrides pmin zmax: int of float the maximum value of the displayed data range; overrides pmax ts_convert: function or None a function to read time stamps from the band names title: str or None the plot title to be displayed; per default, if set to `None`: `Figure 1`, `Figure 2`, ... datalabel: str a label for the units of the displayed data. This also supports LaTeX mathematical notation. See `Text rendering With LaTeX <https://matplotlib.org/users/usetex.html>`_. spectrumlabel: str a label for the x-axis of the vertical spectra See Also -------- :func:`matplotlib.pyplot.imshow` """ def __init__(self, filename, cmap='jet', band_indices=None, band_names=None, pmin=2, pmax=98, zmin=None, zmax=None, ts_convert=None, title=None, datalabel='data', spectrumlabel='time', fontsize=8): self.ts_convert = ts_convert self.filename = filename with Raster(filename) as ras: self.rows = ras.rows self.cols = ras.cols self.bands = ras.bands self.epsg = ras.epsg self.crs = ras.srs geo = ras.raster.GetGeoTransform() self.nodata = ras.nodata self.format = ras.format if band_names is None: self.bandnames = ras.bandnames else: self.bandnames = band_names self.slider_readout = False self.timestamps = range(0, self.bands) if ts_convert is None else [ts_convert(x) for x in self.bandnames] self.datalabel = datalabel self.spectrumlabel = spectrumlabel xlab = self.crs.GetAxisName(None, 0) ylab = self.crs.GetAxisName(None, 1) self.xlab = xlab.lower() if xlab is not None else 'longitude' self.ylab = ylab.lower() if ylab is not None else 'latitude' self.xmin = geo[0] self.ymax = geo[3] self.xres = geo[1] self.yres = abs(geo[5]) self.xmax = self.xmin + self.xres * self.cols self.ymin = self.ymax - self.yres * self.rows self.extent = (self.xmin, self.xmax, self.ymin, self.ymax) self.pmin, self.pmax = pmin, pmax self.zmin, self.zmax = zmin, zmax # define some options for display of the widget box self.layout = Layout( display='flex', flex_flow='column', border='solid 2px', align_items='stretch', width='100%' ) self.fontsize = fontsize self.colormap = cmap if band_indices is not None: if len(list(set(band_indices))) != self.bands: raise RuntimeError('length mismatch of unique provided band indices ({0}) ' 'and image bands ({1})'.format(len(band_indices), self.bands)) else: self.indices = sorted(band_indices) else: self.indices = range(1, self.bands + 1) # define a slider for changing a plotted image self.slider = IntSlider(min=min(self.indices), max=max(self.indices), step=1, continuous_update=False, value=self.indices[len(self.indices) // 2], description='band', style={'description_width': 'initial'}, readout=self.slider_readout) # a simple checkbox to enable/disable stacking of vertical profiles into one plot self.checkbox = Checkbox(value=True, description='stack vertical profiles', indent=False) # a button to clear the vertical profile plot self.clearbutton = Button(description='clear vertical plot') self.clearbutton.on_click(lambda x: self.__init_vertical_plot()) self.write_csv = Button(description='export csv') self.write_csv.on_click(lambda x: self.csv()) self.write_shp = Button(description='export shp') self.write_shp.on_click(lambda x: self.shp()) if self.format == 'ENVI': self.sliderlabel = Label(value=self.bandnames[self.slider.value], layout={'width': '500px'}) children = [HBox([self.slider, self.sliderlabel]), HBox([self.checkbox, self.clearbutton, self.write_csv, self.write_shp])] else: children = [self.slider, HBox([self.checkbox, self.clearbutton, self.write_csv, self.write_shp])] form = VBox(children=children, layout=self.layout) display(form) self.fig = plt.figure(num=title) # left display (image) self.ax1 = self.fig.add_subplot(121) # right display (time series) self.ax2 = self.fig.add_subplot(122) # self.ax1 = plt.gca() self.ax1.get_xaxis().get_major_formatter().set_useOffset(False) self.ax1.get_yaxis().get_major_formatter().set_useOffset(False) self.ax1.set_xlabel(self.xlab, fontsize=self.fontsize) self.ax1.set_ylabel(self.ylab, fontsize=self.fontsize) self.ax1.tick_params(axis='both', which='major', labelsize=self.fontsize) self.ax2.tick_params(axis='both', which='major', labelsize=self.fontsize) # format the values displayed for the mouse pointer text_pointer = self.ylab + '={0:.2f}, ' + self.xlab + '={1:.2f}, value=' self.ax1.format_coord = lambda x, y: text_pointer.format(y, x) # add a cross-hair to the horizontal slice plot self.x_coord, self.y_coord = self.__img2map(0, 0) self.lhor = self.ax1.axhline(self.y_coord, linewidth=1, color='r') self.lver = self.ax1.axvline(self.x_coord, linewidth=1, color='r') # set up the vertical profile plot self.__init_vertical_plot() # make the figure responds to mouse clicks by executing method __onclick self.cid1 = self.fig.canvas.mpl_connect('button_press_event', self.__onclick) # enable interaction with the slider out = interactive_output(self.__onslide, {'h': self.slider}) plt.tight_layout() def __onslide(self, h): mat = self.__read_band(self.indices.index(h) + 1) masked = np.ma.array(mat, mask=np.isnan(mat)) pmin, pmax = np.percentile(masked.compressed(), (self.pmin, self.pmax)) vmin = self.zmin if self.zmin is not None else pmin vmax = self.zmax if self.zmax is not None else pmax cmap = plt.get_cmap(self.colormap) cmap.set_bad('white') title = self.bandnames[self.slider.value - 1] self.ax1.set_title(title, fontsize=self.fontsize) self.ax1.imshow(masked, vmin=vmin, vmax=vmax, extent=self.extent, cmap=cmap) if hasattr(self, 'sliderlabel'): self.sliderlabel.value = title self._set_colorbar(self.ax1) self.vline.set_xdata(self.timestamps[self.slider.value]) def __read_band(self, band): with Raster(self.filename) as ras: mat = ras.matrix(band) return mat def __read_timeseries(self, x, y): with Raster(self.filename) as ras: vals = ras.raster.ReadAsArray(xoff=x, yoff=y, xsize=1, ysize=1) if isinstance(self.nodata, list): for i, x in enumerate(vals): if x == self.nodata[i]: vals[i] = np.nan else: vals[vals == self.nodata] = np.nan return vals.reshape(vals.shape[0]) def __img2map(self, x, y): x_map = self.xmin + self.xres * x y_map = self.ymax - self.yres * y return x_map, y_map def __map2img(self, x, y): x_img = int((x - self.xmin) / self.xres) y_img = int((self.ymax - y) / self.yres) return x_img, y_img def __reset_crosshair(self): """ redraw the cross-hair on the horizontal slice plot Parameters ---------- x: int the x image coordinate y: int the y image coordinate Returns ------- """ self.lhor.set_ydata(self.y_coord) self.lver.set_xdata(self.x_coord) def __init_vertical_plot(self): """ set up the vertical profile plot Returns ------- """ # clear the plot if lines have already been drawn on it if len(self.ax2.lines) > 0: self.ax2.cla() # set up the vertical profile plot self.ax2.set_ylabel(self.datalabel, fontsize=self.fontsize) self.ax2.set_xlabel(self.spectrumlabel, fontsize=self.fontsize) self.ax2.set_title('vertical point profiles', fontsize=self.fontsize) self.ax2.set_xlim([min(self.timestamps), max(self.timestamps)]) # plot vertical line at the slider position self.vline = self.ax2.axvline(self.timestamps[self.slider.value], color='black') def __onclick(self, event): """ respond to mouse clicks in the plot. This function responds to clicks on the first (horizontal slice) plot and updates the vertical profile and slice plots Parameters ---------- event: matplotlib.backend_bases.MouseEvent the click event object containing image coordinates """ # only do something if the first plot has been clicked on if event.inaxes == self.ax1: # retrieve the click coordinates self.x_coord = event.xdata self.y_coord = event.ydata # redraw the cross-hair self.__reset_crosshair() # convert the map coordinates collected at the click to image pixel coordinates x, y = self.__map2img(self.x_coord, self.y_coord) # read the time series at the image coordinates subset_vertical = self.__read_timeseries(x, y) # redraw/clear the vertical profile plot in case stacking is disabled if not self.checkbox.value: self.__init_vertical_plot() # plot the vertical profile label = 'x: {0:03}; y: {1:03}'.format(x, y) self.ax2.plot(self.timestamps, subset_vertical, label=label) self.ax2_legend = self.ax2.legend(loc=0, prop={'size': 7}, markerscale=1)
[docs] def csv(self, outname=None): # the first line is the vertical band line and is thus excluded profiles = self.ax2.get_lines()[1:] if len(profiles) == 0: return if outname is None: root = Tk() # Hide the main window root.withdraw() outname = filedialog.asksaveasfilename(initialdir=os.path.expanduser('~'), defaultextension='.csv', filetypes=(('csv', '*.csv'), ('all files', '*.*'))) if outname is None: return with open(outname, 'w') as csv: csv.write('id;bandname;row;column;xdata;ydata\n') for i in range(0, len(profiles)): line = profiles[i] xdata = line.get_xdata() ydata = line.get_ydata() # get the row and column indices of the profile legend_text = self.ax2.get_legend().texts[i].get_text() legend_items = re.sub('[xy: ]', '', legend_text).split(';') col, row = [int(x) for x in legend_items] for j in range(0, self.bands): entry = '{};{};{};{};{};{}\n'.format(i + 1, self.bandnames[j], row, col, xdata[j], ydata[j]) csv.write(entry) csv.close()
[docs] def shp(self, outname=None): # the first line is the vertical band line and is thus excluded profiles = self.ax2.get_lines()[1:] if len(profiles) == 0: return if outname is None: root = Tk() # Hide the main window root.withdraw() outname = filedialog.asksaveasfilename(initialdir=os.path.expanduser('~'), defaultextension='.shp', filetypes=(('shp', '*.shp'), ('all files', '*.*'))) if outname is None: return layername = os.path.splitext(os.path.basename(outname))[0] with Vector(driver='Memory') as points: points.addlayer(layername, self.crs, 1) fieldnames = ['b{}'.format(i) for i in range(0, self.bands)] for field in fieldnames: points.addfield(field, ogr.OFTReal) for i, line in enumerate(profiles): # get the data values from the profile ydata = line.get_ydata().tolist() # get the row and column indices of the profile legend_text = self.ax2.get_legend().texts[i].get_text() legend_items = re.sub('[xy: ]', '', legend_text).split(';') col, row = [int(x) for x in legend_items] # convert the pixel indices to map coordinates x, y = self.__img2map(col, row) # create a new point geometry point = ogr.Geometry(ogr.wkbPoint) point.AddPoint(x, y) fields = {} # create a field lookup dictionary for j, value in enumerate(ydata): if np.isnan(value): value = -9999 fields[fieldnames[j]] = value # add the new feature to the layer points.addfeature(point, fields=fields) point = None points.write(outname, 'ESRI Shapefile') lookup = os.path.splitext(outname)[0] + '_lookup.csv' with open(lookup, 'w') as csv: content = [';'.join(x) for x in zip(fieldnames, self.bandnames)] csv.write('id;bandname\n') csv.write('\n'.join(content))
def _set_colorbar(self, axis, label=None): if len(axis.images) > 1: axis.images[0].colorbar.remove() del axis.images[0] divider = make_axes_locatable(axis) cax = divider.append_axes('right', size='5%', pad=0.05) self.cbar = self.fig.colorbar(axis.images[0], cax=cax) self.cbar.ax.tick_params(axis='both', which='major', labelsize=self.fontsize) if label is not None: self.cbar.ax.set_ylabel(label, fontsize=self.fontsize)