"""
Visualization tools using Jupyter notebooks
"""
import os
import re
import math
import inspect
import numpy as np
from .raster import Raster
from .vector import Vector
import matplotlib.pyplot as plt
import sys
from osgeo import ogr
if sys.version_info >= (3, 0):
from tkinter import filedialog, Tk
else:
from Tkinter import Tk
import tkFileDialog as filedialog
from IPython.display import display
from ipywidgets import interactive_output, IntSlider, Layout, Checkbox, Button, HBox, Label, VBox
from mpl_toolkits.axes_grid1 import make_axes_locatable
[docs]class RasterViewer(object):
"""
| Plotting utility for displaying a geocoded image stack file.
| On moving the slider, the band at the slider position is read from the file and displayed.
| By clicking on the band image display, you can display time series profiles.
| The collected profiles can be saved to a csv file.
Parameters
----------
filename: str
the name of the file to display
cmap: str
the color map name for displaying the image.
See :class:`matplotlib.colors.Colormap`.
band_indices: list or None
a list of indices for renaming the individual band indices in `filename`;
e.g. -70:70, instead of the raw band indices, e.g. 1:140.
The number of unique elements must be of same length as the number of bands in `filename`.
band_names: list or None
alternative names to assign to the individual bands
pmin: int
the minimum percentile for linear histogram stretching
pmax: int
the maximum percentile for linear histogram stretching
zmin: int or float or None
the minimum value of the displayed data range; overrides `pmin`
zmax: int or float or None
the maximum value of the displayed data range; overrides `pmax`
ts_convert: function or None
a function to read time stamps from the band names
title: str or None
the plot title to be displayed; per default, if set to `None`: `Figure 1`, `Figure 2`, ...
datalabel: str
a label for the units of the displayed data. This also supports LaTeX mathematical notation.
See `Text rendering With LaTeX <https://matplotlib.org/users/usetex.html>`_.
spectrumlabel: str
a label for the x-axis of the vertical spectra
fontsize: int
the label text font size
custom: list or None
Custom functions for plotting figures in additional subplots.
Each figure will be updated upon click on the major map display.
Each function is required to take at least an argument `axis`.
Furthermore, the following optional arguments are supported:
* `values` (:py:obj:`list`): the time series values collected from the last click
* `timestamps` (:py:obj:`list`): the time stamps as returned by `ts_convert`
* `band` (:py:obj:`int`): the index of the currently displayed band
* `x` (:py:obj:`float`): the x map coordinate in units of the image CRS
* `y` (:py:obj:`float`): the y map coordinate in units of the image CRS
Additional subplots are automatically added in a row-major order.
The list may contain `None` elements to leave certain subplots empty for later usage.
This might be useful for plots which are not to be updated each time the map display is clicked on.
See Also
--------
:func:`matplotlib.pyplot.imshow`
"""
def __init__(self, filename, cmap='jet', band_indices=None, band_names=None, pmin=2, pmax=98, zmin=None, zmax=None,
ts_convert=None, title=None, datalabel='data', spectrumlabel='time', fontsize=8, custom=None):
self.ts_convert = ts_convert
self.custom = custom
self.filename = filename
with Raster(filename) as ras:
self.rows = ras.rows
self.cols = ras.cols
self.bands = ras.bands
self.epsg = ras.epsg
self.crs = ras.srs
geo = ras.raster.GetGeoTransform()
self.nodata = ras.nodata
self.format = ras.format
if band_names is None:
self.bandnames = ras.bandnames
else:
self.bandnames = band_names
self.slider_readout = False
self.timestamps = range(0, self.bands) if ts_convert is None else [ts_convert(x) for x in self.bandnames]
self.datalabel = datalabel
self.spectrumlabel = spectrumlabel
xlab = self.crs.GetAxisName(None, 0)
ylab = self.crs.GetAxisName(None, 1)
self.xlab = xlab.lower() if xlab is not None else 'longitude'
self.ylab = ylab.lower() if ylab is not None else 'latitude'
self.xmin = geo[0]
self.ymax = geo[3]
self.xres = geo[1]
self.yres = abs(geo[5])
self.xmax = self.xmin + self.xres * self.cols
self.ymin = self.ymax - self.yres * self.rows
self.extent = (self.xmin, self.xmax, self.ymin, self.ymax)
self.pmin, self.pmax = pmin, pmax
self.zmin, self.zmax = zmin, zmax
# define some options for display of the widget box
self.layout = Layout(
display='flex',
flex_flow='column',
border='solid 2px',
align_items='stretch',
width='100%'
)
self.fontsize = fontsize
self.colormap = cmap
if band_indices is not None:
if len(list(set(band_indices))) != self.bands:
raise RuntimeError('length mismatch of unique provided band indices ({0}) '
'and image bands ({1})'.format(len(band_indices), self.bands))
else:
self.indices = sorted(band_indices)
else:
self.indices = range(1, self.bands + 1)
self.band = self.indices[len(self.indices) // 2]
# define a slider for changing a plotted image
self.slider = IntSlider(min=min(self.indices), max=max(self.indices), step=1, continuous_update=False,
value=self.band,
description='band',
style={'description_width': 'initial'},
readout=self.slider_readout)
# a simple checkbox to enable/disable stacking of vertical profiles into one plot
self.checkbox = Checkbox(value=True, description='stack vertical profiles', indent=False)
# a button to clear the vertical profile plot
self.clearbutton = Button(description='clear vertical plot')
self.clearbutton.on_click(lambda x: self.__init_vertical_plot())
self.write_csv = Button(description='export csv')
self.write_csv.on_click(lambda x: self.csv())
self.write_shp = Button(description='export shp')
self.write_shp.on_click(lambda x: self.shp())
if self.format == 'ENVI':
self.sliderlabel = Label(value=self.bandnames[self.slider.value], layout={'width': '500px'})
children = [HBox([self.slider, self.sliderlabel]),
HBox([self.checkbox, self.clearbutton, self.write_csv, self.write_shp])]
else:
children = [self.slider, HBox([self.checkbox, self.clearbutton, self.write_csv, self.write_shp])]
form = VBox(children=children, layout=self.layout)
display(form)
# self.fig = plt.figure(num=title)
if custom is None:
self.fig, axes = plt.subplots(1, 2, num=title)
# left display (image)
self.ax1 = axes[0]
# right display (time series)
self.ax2 = axes[1]
else:
rows = math.ceil(len(custom) / 2) + 1
self.fig, axes = plt.subplots(rows, 2, num=title)
self.ax1, self.ax2 = axes[0]
self.cax = np.ravel(axes[1:])
self.ax1.get_xaxis().get_major_formatter().set_useOffset(False)
self.ax1.get_yaxis().get_major_formatter().set_useOffset(False)
self.ax1.set_xlabel(self.xlab, fontsize=self.fontsize)
self.ax1.set_ylabel(self.ylab, fontsize=self.fontsize)
self.ax1.tick_params(axis='both', which='major', labelsize=self.fontsize)
self.ax2.tick_params(axis='both', which='major', labelsize=self.fontsize)
# format the values displayed for the mouse pointer
self.ax1.format_coord = self.__format_coord
# add a cross-hair to the horizontal slice plot
self.x_coord, self.y_coord = self.__img2map(0, 0)
self.lhor = self.ax1.axhline(self.y_coord, linewidth=1, color='r')
self.lver = self.ax1.axvline(self.x_coord, linewidth=1, color='r')
# set up the vertical profile plot
self.__init_vertical_plot()
# make the figure responds to mouse clicks by executing method __onclick
self.cid1 = self.fig.canvas.mpl_connect('button_press_event', self.__onclick)
# enable interaction with the slider
out = interactive_output(self.__onslide, {'h': self.slider})
plt.tight_layout()
def __onslide(self, h):
self.band = self.indices.index(h)
mat = self.__read_band(self.band + 1)
masked = np.ma.array(mat, mask=np.isnan(mat))
pmin, pmax = np.percentile(masked.compressed(), (self.pmin, self.pmax))
vmin = self.zmin if self.zmin is not None else pmin
vmax = self.zmax if self.zmax is not None else pmax
cmap = plt.get_cmap(self.colormap)
cmap.set_bad('white')
title = self.bandnames[self.slider.value - 1]
self.ax1.set_title(title, fontsize=self.fontsize)
self.ax1.imshow(masked, vmin=vmin, vmax=vmax, extent=self.extent, cmap=cmap)
if hasattr(self, 'sliderlabel'):
self.sliderlabel.value = title
self.__set_colorbar(self.ax1)
self.vline.set_xdata(self.timestamps[self.slider.value])
def __read_band(self, band):
with Raster(self.filename) as ras:
mat = ras.matrix(band)
return mat
def __img2map(self, x, y):
x_map = self.xmin + self.xres * x
y_map = self.ymax - self.yres * y
return x_map, y_map
def __map2img(self, x, y):
x_img = int((x - self.xmin) / self.xres)
y_img = int((self.ymax - y) / self.yres)
return x_img, y_img
def __reset_crosshair(self):
"""
redraw the cross-hair on the horizontal slice plot
Parameters
----------
x: int
the x image coordinate
y: int
the y image coordinate
Returns
-------
"""
self.lhor.set_ydata(self.y_coord)
self.lver.set_xdata(self.x_coord)
def __init_vertical_plot(self):
"""
set up the vertical profile plot
Returns
-------
"""
# clear the plot if lines have already been drawn on it
if len(self.ax2.lines) > 0:
self.ax2.cla()
# set up the vertical profile plot
self.ax2.set_ylabel(self.datalabel, fontsize=self.fontsize)
self.ax2.set_xlabel(self.spectrumlabel, fontsize=self.fontsize)
self.ax2.set_title('vertical point profiles', fontsize=self.fontsize)
self.ax2.set_xlim([min(self.timestamps), max(self.timestamps)])
# plot vertical line at the slider position
self.vline = self.ax2.axvline(self.timestamps[self.slider.value], color='black')
def __onclick(self, event):
"""
respond to mouse clicks in the plot.
This function responds to clicks on the first (horizontal slice) plot and updates the vertical profile and
slice plots
Parameters
----------
event: matplotlib.backend_bases.MouseEvent
the click event object containing image coordinates
"""
# only do something if the first plot has been clicked on
if event.inaxes == self.ax1:
# retrieve the click coordinates
self.x_coord = event.xdata
self.y_coord = event.ydata
# redraw the cross-hair
self.__reset_crosshair()
# read the time series at the clicked coordinate
with Raster(self.filename)[self.y_coord, self.x_coord, :] as ras:
timeseries = ras.array()
# convert the map coordinates collected at the click to image pixel coordinates
x, y = self.__map2img(self.x_coord, self.y_coord)
# redraw/clear the vertical profile plot in case stacking is disabled
if not self.checkbox.value:
self.__init_vertical_plot()
# plot the vertical profile
label = 'x: {0:03}; y: {1:03}'.format(x, y)
self.ax2.plot(self.timestamps, timeseries, label=label)
self.ax2_legend = self.ax2.legend(loc=0, prop={'size': 7}, markerscale=1)
if self.custom is not None:
for i, func in enumerate(self.custom):
if func is not None:
self.cax[i].cla()
args = self.__argcheck(function=func,
axis=self.cax[i],
values=timeseries)
func(**args)
plt.tight_layout()
[docs] def csv(self, outname=None):
"""
write the collected samples to a CSV file
Parameters
----------
outname: str
the name of the file to write; if left at the default `None`, a graphical file selection dialog is opened
Returns
-------
"""
# the first line is the vertical band line and is thus excluded
profiles = self.ax2.get_lines()[1:]
if len(profiles) == 0:
return
if outname is None:
root = Tk()
# Hide the main window
root.withdraw()
outname = filedialog.asksaveasfilename(initialdir=os.path.expanduser('~'),
defaultextension='.csv',
filetypes=(('csv', '*.csv'),
('all files', '*.*')))
if outname is None:
return
with open(outname, 'w') as csv:
csv.write('id;bandname;row;column;xdata;ydata\n')
for i in range(0, len(profiles)):
line = profiles[i]
xdata = line.get_xdata()
ydata = line.get_ydata()
# get the row and column indices of the profile
legend_text = self.ax2.get_legend().texts[i].get_text()
legend_items = re.sub('[xy: ]', '', legend_text).split(';')
col, row = [int(x) for x in legend_items]
for j in range(0, self.bands):
entry = '{};{};{};{};{};{}\n'.format(i + 1, self.bandnames[j], row, col, xdata[j], ydata[j])
csv.write(entry)
csv.close()
[docs] def get_current_profile(self):
"""
Returns
-------
list
the values of the most recently plotted time series
"""
profiles = self.ax2.get_lines()[1:]
if len(profiles) == 0:
return []
else:
line = profiles[-1]
return line.get_ydata()
[docs] def shp(self, outname=None):
"""
write the collected samples to a CSV file
Parameters
----------
outname: str
the name of the file to write; if left at the default `None`, a graphical file selection dialog is opened
Returns
-------
"""
# the first line is the vertical band line and is thus excluded
profiles = self.ax2.get_lines()[1:]
if len(profiles) == 0:
return
if outname is None:
root = Tk()
# Hide the main window
root.withdraw()
outname = filedialog.asksaveasfilename(initialdir=os.path.expanduser('~'),
defaultextension='.shp',
filetypes=(('shp', '*.shp'),
('all files', '*.*')))
if outname is None:
return
layername = os.path.splitext(os.path.basename(outname))[0]
with Vector(driver='Memory') as points:
points.addlayer(layername, self.crs, 1)
fieldnames = ['b{}'.format(i) for i in range(0, self.bands)]
for field in fieldnames:
points.addfield(field, ogr.OFTReal)
for i, line in enumerate(profiles):
# get the data values from the profile
ydata = line.get_ydata().tolist()
# get the row and column indices of the profile
legend_text = self.ax2.get_legend().texts[i].get_text()
legend_items = re.sub('[xy: ]', '', legend_text).split(';')
col, row = [int(x) for x in legend_items]
# convert the pixel indices to map coordinates
x, y = self.__img2map(col, row)
# create a new point geometry
point = ogr.Geometry(ogr.wkbPoint)
point.AddPoint(x, y)
fields = {}
# create a field lookup dictionary
for j, value in enumerate(ydata):
if np.isnan(value):
value = -9999
fields[fieldnames[j]] = value
# add the new feature to the layer
points.addfeature(point, fields=fields)
point = None
points.write(outname, 'ESRI Shapefile')
lookup = os.path.splitext(outname)[0] + '_lookup.csv'
with open(lookup, 'w') as csv:
content = [';'.join(x) for x in zip(fieldnames, self.bandnames)]
csv.write('id;bandname\n')
csv.write('\n'.join(content))
def __set_colorbar(self, axis, label=None):
if len(axis.images) > 1:
axis.images[0].colorbar.remove()
del axis.images[0]
divider = make_axes_locatable(axis)
cax = divider.append_axes('right', size='5%', pad=0.05)
self.cbar = self.fig.colorbar(axis.images[0], cax=cax)
self.cbar.ax.tick_params(axis='both', which='major', labelsize=self.fontsize)
if label is not None:
self.cbar.ax.set_ylabel(label, fontsize=self.fontsize)
def __argcheck(self, function, axis, values):
args = locals()
del args['function']
args['timestamps'] = self.timestamps
args['band'] = self.band
args['x'] = self.x_coord
args['y'] = self.y_coord
fargs = inspect.getfullargspec(function).args
for required in ['axis']:
if required not in fargs:
raise TypeError("missing argument '{}'".format(required))
return {key: value for key, value in args.items() if key in fargs}
def __format_coord(self, x, y):
text_pointer = 'x, y: {0}, {1}; ' \
+ self.xlab + ', ' + self.ylab \
+ ': {2:.2f}, {3:.2f}; value:'
x_img, y_img = self.__map2img(x, y)
return text_pointer.format(x_img, y_img, x, y)