fire2a.agglomerative_clustering

πŸ‘‹πŸŒŽ 🌲πŸ”₯

Raster clustering

Usage

Overview

  1. Choose your raster files
  2. Configure nodata, scaling strategies and weights in the config.toml file
  3. Choose "distance threshold" (or "number of clusters") for the Agglomerative clustering algorithm. Recommended:
    • Start with a distance threshold of 10.0; decrease it for more clusters or increase it for fewer
    • After calibrating the distance threshold, sieve small clusters (merging each into its biggest neighbor) with the --sieve integer_pixels_size option

Execution

# get command line help
python -m fire2a.agglomerative_clustering -h
python -m fire2a.agglomerative_clustering --help

# activate your qgis dev environment
source ~/pyqgisdev/bin/activate 
# execute 
(qgis) $ python -m fire2a.agglomerative_clustering -d 10.0

# windowsπŸ’© users should use QGIS's python
C:\PROGRA~1\QGIS33~1.3\bin\python-qgis.bat -m fire2a.agglomerative_clustering -d 10.0

More info on: How to windows πŸ’© using qgis's python

Preparation

1. Choose your raster files

  • Any GDAL compatible raster will be read
  • Place them all in the same directory where the script will be executed
  • "Quote them" if they have any non alphanumerical chars [a-zA-Z0-9]

2. Preprocessing configuration

See the config.toml file for an example configuration of the preprocessing steps. The file is structured as follows:

["filename.tif"]
no_data_strategy = "most_frequent"
scaling_strategy = "onehot"
fill_value = 0
weight = 1
  1. __scaling_strategy__

    • can be "standard", "robust", "onehot"
    • default is "robust"
    • Standard: (x-mean)/stddev
    • Robust: same but dropping the tails of the distribution
    • OneHot: __for CATEGORICAL DATA__
  2. __no_data_strategy__

    • can be "mean", "median", "most_frequent", "constant"
    • default is "mean"
    • categorical data should use "most_frequent" or "constant"
    • "constant" will use the value in __fill_value__ (see below)
    • SimpleImputer
  3. __fill_value__

    • used when __no_data_strategy__ is "constant"
    • default is 0
    • SimpleImputer
  4. __weight__

    • default is 1
    • used to give more importance to some features than others
    • This is done after the nodata imputation and scaling steps, before clustering (a minimal sketch of this order follows the list)
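
A minimal sketch (not part of the module) of that order on a single toy feature, assuming a hypothetical NoData value of -9999, the "mean" imputation strategy, the "robust" scaling strategy and a weight of 2.0:

import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import RobustScaler

nodata, weight = -9999.0, 2.0  # hypothetical values taken from config.toml
x = np.array([[1.0], [2.0], [nodata], [4.0], [100.0]])
x = SimpleImputer(missing_values=nodata, strategy="mean").fit_transform(x)  # 1. nodata imputation
x = RobustScaler().fit_transform(x)                                         # 2. per-feature scaling
x = (x - x.min()) / (x.max() - x.min()) * weight                            # 3. common [0, 1] range, then weight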

3. Clustering configuration

  1. The __Agglomerative__ clustering algorithm is used. The following parameters are mutually exclusive:
  • -n or --n_clusters: The number of clusters to form as well as the number of centroids to generate.
  • -d or --distance_threshold: The linkage distance threshold above which clusters will not be merged. With scaled inputs, start at 10.0 and work downward (0.0 computes the whole tree and can take a long time).
  • More parameters for clustering can be passed directly to the pipelie function as keyword arguments (see the sketch after this list)
  2. A __Sieve filter__ is applied to remove small clusters, using the GDAL sieve library
  • --sieve: Use the GDAL sieve filter to merge small clusters (given as a number of pixels) into their biggest neighbor
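
A minimal script-mode sketch of passing an extra keyword argument (here linkage, forwarded to AgglomerativeClustering) through the pipelie function; the raster names are hypothetical and observations/info_list are assembled the same way main() does it:

import numpy as np
from fire2a.raster import read_raster
from fire2a.agglomerative_clustering import check_shapes, pipelie

data1, info1 = read_raster("elevation.tif")  # hypothetical continuous raster
data2, info2 = read_raster("fuels.tif")      # hypothetical categorical raster
info1.update(fname="elevation.tif", no_data_strategy="mean", scaling_strategy="robust", fill_value=0, weight=1)
info2.update(fname="fuels.tif", no_data_strategy="most_frequent", scaling_strategy="onehot", fill_value=0, weight=1)
data_list, info_list = [data1, data2], [info1, info2]

height, width = check_shapes(data_list)
observations = np.column_stack([d.ravel() for d in data_list])
label_map, pipe1, pipe2 = pipelie(
    observations, info_list, height, width,
    n_clusters=None, distance_threshold=10.0, linkage="ward",  # linkage is forwarded to AgglomerativeClustering
)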

4. Post-processing

Outputs can be:

  • A raster file with the cluster labels and a polygon file with the cluster polygons
  • A polygon file with the cluster polygons, with an attribute holding the number of pixels in each cluster
  • A plot of the input data distributions, the rescaled data distributions, and the cluster size history and histogram (crashes QGIS on Windows)

Or use the --script option to return the label_map and the pipeline objects for further processing in another Python script:

from fire2a.agglomerative_clustering import main
label_map, pipe1, pipe2 = main(["-d", "10.0", "-s"])
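
For instance, the returned objects can then be inspected (a small sketch, assuming the call above):

import numpy as np
labels, counts = np.unique(label_map, return_counts=True)
print(f"{len(labels)} clusters, smallest {counts.min()} px, largest {counts.max()} px")
print(pipe2.named_steps["agglomerative_clustering"].model.n_clusters_)  # fitted sklearn estimator
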
  1#!/usr/bin/env python3
  2# fmt: off
  3"""πŸ‘‹πŸŒŽ 🌲πŸ”₯
  4# Raster clustering
  5## Usage
  6### Overview
  71. Choose your raster files
  82. Configure nodata, scaling strategies and weights in the `config.toml` file
  93. Choose "distance threshold" (or "number of clusters") for the [Agglomerative](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) clustering algorithm. Recommended:
 10   - Start with a distance threshold of 10.0; decrease it for more clusters or increase it for fewer
 11   - After calibrating the distance threshold; 
 12   - [Sieve](https://gdal.org/en/latest/programs/gdal_sieve.html) small clusters (merge them to the biggest neighbor) with the `--sieve integer_pixels_size` option 
 13
 14### Execution
 15```bash
 16# get command line help
 17python -m fire2a.agglomerative_clustering -h
 18python -m fire2a.agglomerative_clustering --help
 19
 20# activate your qgis dev environment
 21source ~/pyqgisdev/bin/activate 
 22# execute 
 23(qgis) $ python -m fire2a.agglomerative_clustering -d 10.0
 24
 25# windowsπŸ’© users should use QGIS's python
 26C:\\PROGRA~1\\QGIS33~1.3\\bin\\python-qgis.bat -m fire2a.agglomerative_clustering -d 10.0
 27```
 28[More info on: How to windows πŸ’© using qgis's python](https://github.com/fire2a/fire2a-lib/tree/main/qgis-launchers)
 29
 30### Preparation
 31#### 1. Choose your raster files
 32- Any [GDAL compatible](https://gdal.org/en/latest/drivers/raster/index.html) raster will be read
 33- Place them all in the same directory where the script will be executed
 34- "Quote them" if they have any non alphanumerical chars [a-zA-Z0-9]
 35
 36#### 2. Preprocessing configuration
 37See the `config.toml` file for example of the configuration of the preprocessing steps. The file is structured as follows:
 38
 39```toml
 40["filename.tif"]
 41no_data_strategy = "most_frequent"
 42scaling_strategy = "onehot"
 43fill_value = 0
 44weight = 1
 45```
 46
 471. __scaling_strategy__
 48   - can be "standard", "robust", "onehot"
 49   - default is "robust"
 50   - [Standard](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html): (x-mean)/stddev
 51   - [Robust](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.RobustScaler.html): same but dropping the tails of the distribution
 52   - [OneHot](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html): __for CATEGORICAL DATA__
 53
 542. __no_data_strategy__
 55   - can be "mean", "median", "most_frequent", "constant"
 56   - default is "mean"
 57   - categorical data should use "most_frequent" or "constant"
 58   - "constant" will use the value in __fill_value__ (see below)
 59   - [SimpleImputer](https://scikit-learn.org/stable/modules/generated/sklearn.impute.SimpleImputer.html)
 60
 613. __fill_value__
 62   - used when __no_data_strategy__ is "constant"
 63   - default is 0
 64   - [SimpleImputer](https://scikit-learn.org/stable/modules/generated/sklearn.impute.SimpleImputer.html)
 65
 664. __weight__
 67   - default is 1
 68   - used to give more importance to some features than others
 69   - This is done after the nodata imputation and scaling steps, before clustering
 70
 71
 72#### 3. Clustering configuration
 73
 741. __Agglomerative__ clustering algorithm is used. The following parameters are mutually exclusive:
 75- `-n` or `--n_clusters`: The number of clusters to form as well as the number of centroids to generate.
 76- `-d` or `--distance_threshold`: The linkage distance threshold above which clusters will not be merged. With scaled inputs, start at 10.0 and work downward (0.0 computes the whole tree and can take a long time).
 77- More [parameters](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) for clustering can be passed directly into the pipelie method as keyword arguments
 78
 792. __Sieve filter__ is applied to remove small clusters. The sieve filter is applied using the [GDAL sieve library](https://gdal.org/en/latest/programs/gdal_sieve.html#gdal-sieve)
 80- `--sieve`: Use GDAL sieve filter to merge small clusters (number of pixels) into the biggest neighbor
 81
 82#### 4. Post-processing
 83Outputs can be:
 84- A raster file with the cluster labels and a polygon file with the cluster polygons
 85- A polygon file with the cluster polygons, with attribute being the number of pixels in each cluster
 86- A plot of the input data distributions, the rescaled data distributions, and the cluster size history and histogram (crashes QGIS in windows)
 87
 88Or use the `--script` option to return the label_map and the pipeline object for further processing in another python script:
 89    ```python
 90    from fire2a.agglomerative_clustering import main
 91    label_map, pipe1, pipe2 = main(["-d", "10.0", "-s"])
 92    ```
 93"""
 94# fmt: on
 95# from IPython.terminal.embed import InteractiveShellEmbed
 96# InteractiveShellEmbed()()
 97import logging
 98import sys
 99from pathlib import Path
100
101import numpy as np
102from osgeo import gdal, ogr, osr
103from sklearn.base import BaseEstimator, TransformerMixin
104from sklearn.cluster import AgglomerativeClustering
105from sklearn.compose import ColumnTransformer
106from sklearn.impute import SimpleImputer
107from sklearn.neighbors import radius_neighbors_graph
108from sklearn.pipeline import Pipeline
109from sklearn.preprocessing import OneHotEncoder, RobustScaler, StandardScaler
110
111from fire2a.utils import fprint, read_toml
112
113try:
114    GDT = gdal.GDT_Int64
115except:
116    GDT = gdal.GDT_Int32
117try:
118    OFT = ogr.OFTInteger64
119except:
120    OFT = ogr.OFTInteger
121
122logger = logging.getLogger(__name__)
123
124
125def check_shapes(data_list):
126    """Check if all data arrays have the same shape and are 2D.
127    Returns the shape of the data arrays if they are all equal.
128    """
129    from functools import reduce
130
131    def equal_or_error(x, y):
132        """Check if x and y are equal, returns x if equal else raises a ValueError."""
133        if x == y:
134            return x
135        else:
136            raise ValueError("All data arrays must have the same shape")
137
138    shape = reduce(equal_or_error, (data.shape for data in data_list))
139    if len(shape) != 2:
140        raise ValueError("All data arrays must be 2D")
141    height, width = shape
142    return height, width
143
144
145def get_map_neighbors(height, width, num_neighbors=8):
146    """Get the neighbors of each cell in a 2D grid.
147    n_jobs=-1 uses all available cores.
148    """
149
150    grid_points = np.indices((height, width)).reshape(2, -1).T
151
152    nb4 = radius_neighbors_graph(grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1)
153    nb8 = radius_neighbors_graph(grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1)
154
155    # assert nb4.shape[0] == width * height
156    # assert nb8.shape[1] == width * height
157    # for n in range(width * height):
158    #     _, neighbors = np.nonzero(nb4[n])
159    #     assert 2<= len(neighbors) <= 4, f"{n=} {neighbors=}"
160    #     assert 3<= len(neighbors) <= 8, f"{n=} {neighbors=}"
161    return nb4, nb8
162
163
164class NoDataImputer(BaseEstimator, TransformerMixin):
165    """A custom Imputer that treats a specified nodata_value as np.nan and supports different strategies per column"""
166
167    def __init__(self, no_data_values, strategies, constants):
168        self.no_data_values = no_data_values
169        self.strategies = strategies
170        self.constants = constants
171        self.imputers = []
172        for no_data_value, strategy, constant in zip(no_data_values, strategies, constants):
173            if no_data_value:
174                self.imputers += [SimpleImputer(strategy=strategy, missing_values=no_data_value, fill_value=constant)]
175            else:
176                self.imputers += [SimpleImputer(strategy=strategy, fill_value=constant)]
177
178    def fit(self, X, y=None):
179        for i, imputer in enumerate(self.imputers):
180            imputer.fit(X[:, [i]], y)
181        return self
182
183    def transform(self, X):
184        for i, imputer in enumerate(self.imputers):
185            X[:, [i]] = imputer.transform(X[:, [i]])
186        self.output_data = X
187        return X
188
189
190class RescaleAllToCommonRange(BaseEstimator, TransformerMixin):
191    """A custom transformer that rescales all features to a common range [0, 1]"""
192
193    def __init__(self, weight_map):
194        self.weight_map = weight_map
195
196    def fit(self, X, y=None):
197        # Determine the combined range of all scaled features
198        self.min_val = [x.min() for x in X.T]
199        self.max_val = [x.max() for x in X.T]
200        return self
201
202    def transform(self, X):
203        # Rescale all features to match the common range
204        for i, (x, mi, ma) in enumerate(zip(X.T, self.min_val, self.max_val)):
205            if ma - mi == 0:
206                X.T[i] = x * self.weight_map[i]
207            else:
208                X.T[i] = (x - mi) / (ma - mi) * self.weight_map[i]
209        return X
210
211
212class CustomAgglomerativeClustering(BaseEstimator, TransformerMixin):
213    def __init__(self, height, width, neighbors=4, **kwargs):
214        self.height = height
215        self.width = width
216        self.neighbors = neighbors
217
218        self.grid_points = np.indices((height, width)).reshape(2, -1).T
219        if neighbors == 4:
220            connectivity = radius_neighbors_graph(
221                self.grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1
222            )
223        elif neighbors == 8:
224            connectivity = radius_neighbors_graph(
225                self.grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1
226            )
227
228        self.connectivity = connectivity
229        self.kwargs = kwargs
230        self.model = AgglomerativeClustering(connectivity=self.connectivity, **self.kwargs)
231
232    def fit(self, X, y=None):
233        logger.debug("not sure why, but this method is never called alas needed")
234        self.model.fit(X)
235        return self
236
237    def fit_predict(self, X, y=None):
238        self.input_data = X
239        return self.model.fit_predict(X)
240
241
242def pipelie(observations, info_list, height, width, **kwargs):
243    """A scipy pipeline to achieve Agglomerative Clustering with connectivity on 2d matrix
244    Steps are:
245    1. Impute missing values
246    2. Scale the features
247    3. Rescale all features to a common range
248    4. Cluster the data using Agglomerative Clustering with connectivity
249    5. Reshape the labels back to the original spatial map shape
250    6. Return the labels and the pipeline object
251
252    Args:
253        observations (np.ndarray): The input data to cluster (n_samples, n_features) shaped
254        info_list (list): A list of dictionaries containing information about each feature
255        height (int): The height of the spatial map
256        width (int): The width of the spatial map
257        kwargs: Additional keyword arguments for AgglomerativeClustering, at least one of n_clusters or distance_threshold
258
259    Returns:
260        np.ndarray: The labels of the clusters, reshaped to the original 2d spatial map shape
261        Pipeline: The pipeline object containing all the steps of the pipeline
262    """
263    # kwargs = {"n_clusters": args.n_clusters, "distance_threshold": args.distance_threshold}
264
265    # imputer strategies
266    no_data_values = [info["NoDataValue"] for info in info_list]
267    no_data_strategies = [info["no_data_strategy"] for info in info_list]
268    fill_values = [info["fill_value"] for info in info_list]
269    weights = [info["weight"] for info in info_list]
270    # scaling_strategies = [info["scaling_strategy"] for info in info_list]
271
272    # scaling strategies
273    index_map = {}
274    for strategy in ["robust", "standard", "onehot"]:
275        index_map[strategy] = [i for i, info in enumerate(info_list) if info["scaling_strategy"] == strategy]
276    # index_map
277    # !cat config.toml
278
279    # Create transformers for each type
280    robust_transformer = Pipeline(steps=[("robust_step", RobustScaler())])
281    standard_transformer = Pipeline(steps=[("standard_step", StandardScaler())])
282    onehot_transformer = Pipeline(steps=[("onehot_step", OneHotEncoder(sparse_output=False))])
283    # OneHotEncoder._n_features_outs):
284
285    # Combine transformers using ColumnTransformer
286    feature_scaler = ColumnTransformer(
287        transformers=[
288            ("robust", robust_transformer, index_map["robust"]),
289            ("standard", standard_transformer, index_map["standard"]),
290            ("onehot", onehot_transformer, index_map["onehot"]),
291        ]
292    )
293
294    # # Create a temporary directory for caching calculations
295    # # FOR ACCESING STEPS LATER ON VERY LARGE DATASETS
296    # import tempfile
297    # import joblib
298    # temp_dir = tempfile.mkdtemp()
299    # memory = joblib.Memory(location=temp_dir, verbose=0)
300
301    # Create and apply the pipeline
302    # part 1 until feature scaling
303    pipe1 = Pipeline(
304        # n_features_in_ : int
305        # feature_names_in_ : ndarray of shape (`n_features_in_`,)
306        steps=[
307            ("no_data_imputer", NoDataImputer(no_data_values, no_data_strategies, fill_values)),
308            ("feature_scaling", feature_scaler),
309        ],
310        # memory=memory,
311        verbose=True,
312    )
313    # map weights to new columns (onehot feature scaler creates one column per category)
314    obs1 = pipe1.fit_transform(observations)
315    cat_names = pipe1.named_steps["feature_scaling"]["onehot"].get_feature_names_out()
316    split_names = [name.split("_")[0] for name in cat_names]
317    cat_count = np.unique(split_names, return_counts=True)[1]
318    onehot_map = {}
319    for i, key in enumerate(index_map["onehot"]):
320        onehot_map[key] = cat_count[i]
321    # onehot_map = {key: cat_count[i] for i, key in enumerate(index_map["onehot"])}
322    weight_map = []
323    for name, idxs in index_map.items():
324        for idx in idxs:
325            if name == "onehot":
326                weight_map += [weights[idx]] * onehot_map[idx]
327                continue
328            weight_map += [weights[idx]]
329    # part 2 use weight_map and cluster
330    pipe2 = Pipeline(
331        steps=[
332            ("common_rescaling", RescaleAllToCommonRange(weight_map)),
333            ("agglomerative_clustering", CustomAgglomerativeClustering(height, width, neighbors=4, **kwargs)),
334        ],
335        # memory=memory,
336        verbose=True,
337    )
338
339    # apply pipeLIE
340    labels = pipe2.fit_predict(obs1)
341
342    # Reshape the labels back to the original spatial map shape
343    labels_reshaped = labels.reshape(height, width)
344    return labels_reshaped, pipe1, pipe2
345
346
347def write(
348    label_map,
349    width,
350    height,
351    output_raster="",
352    output_poly="output.shp",
353    authid="EPSG:3857",
354    geotransform=(0, 1, 0, 0, 0, 1),
355    nodata=None,
356    feedback=None,
357):
358
359    from fire2a.processing_utils import get_output_raster_format, get_vector_driver_from_filename
360
361    # setup drivers for raster and polygon output formats
362    if output_raster == "":
363        raster_driver = "MEM"
364    else:
365        try:
366            raster_driver = get_output_raster_format(output_raster, feedback=feedback)
367        except Exception:
368            raster_driver = "GTiff"
369    try:
370        poly_driver = get_vector_driver_from_filename(output_poly)
371    except Exception:
372        poly_driver = "ESRI Shapefile"
373
374    # create raster output
375    src_ds = gdal.GetDriverByName(raster_driver).Create(output_raster, width, height, 1, GDT)
376    src_ds.SetGeoTransform(geotransform)  # != 0 ?
377    src_ds.SetProjection(authid)  # != 0 ?
378    #  src_band = src_ds.GetRasterBand(1)
379    #  if nodata:
380    #      src_band.SetNoDataValue(nodata)
381    #  src_band.WriteArray(label_map)
382
383    # create polygon output
384    drv = ogr.GetDriverByName(poly_driver)
385    dst_ds = drv.CreateDataSource(output_poly)
386    sp_ref = osr.SpatialReference()
387    sp_ref.SetFromUserInput(authid)  # != 0 ?
388    dst_lyr = dst_ds.CreateLayer("clusters", srs=sp_ref, geom_type=ogr.wkbPolygon)
389    dst_lyr.CreateField(ogr.FieldDefn("DN", OFT))  # != 0 ?
390    dst_lyr.CreateField(ogr.FieldDefn("pixel_count", OFT))
391    # dst_lyr.CreateField(ogr.FieldDefn("area", OFT))
392    # dst_lyr.CreateField(ogr.FieldDefn("perimeter", OFT))
393
394    # 0 != gdal.Polygonize( srcband, maskband, dst_layer, dst_field, options, callback = gdal.TermProgress)
395
396    # FAIL: All together it merges labels into a single polygon
397    #  src_band = src_ds.GetRasterBand(1)
398    #  if nodata:
399    #      src_band.SetNoDataValue(nodata)
400    #  src_band.WriteArray(label_map)
401    # gdal.Polygonize(src_band, None, dst_lyr, 0, callback=gdal.TermProgress)  # , ["8CONNECTED=8"])
402
403    # B separado
404    # for loop for creating each label_map value into a different polygonized feature
405    mem_drv = ogr.GetDriverByName("Memory")
406    tmp_ds = mem_drv.CreateDataSource("tmp_ds")
407    # itera = iter(np.unique(label_map))
408    # cluster_id = next(itera)
409    areas = []
410    pixels = []
411    data = np.zeros_like(label_map)
412    for cluster_id, px_count in zip(*np.unique(label_map, return_counts=True)):
413        # temporarily write band
414        src_band = src_ds.GetRasterBand(1)
415        src_band.SetNoDataValue(0)
416        data[label_map == cluster_id] = 1
417        src_band.WriteArray(data)
418        # create feature
419        tmp_lyr = tmp_ds.CreateLayer("", srs=sp_ref)
420        gdal.Polygonize(src_band, src_band.GetMaskBand(), tmp_lyr, -1)
421        # unset tmp data
422        data[label_map == cluster_id] = 0
423        # set polygon feat
424        feat = tmp_lyr.GetNextFeature()
425        geom = feat.GetGeometryRef()
426        featureDefn = dst_lyr.GetLayerDefn()
427        feature = ogr.Feature(featureDefn)
428        feature.SetGeometry(geom)
429        feature.SetField("DN", float(cluster_id))
430        areas += [geom.GetArea()]
431        pixels += [px_count]
432        feature.SetField("pixel_count", float(px_count))
433        # feature.SetField("area", int(geom.GetArea()))
434        # feature.SetField("perimeter", int(geom.Boundary().Length()))
435        dst_lyr.CreateFeature(feature)
436
437    fprint(f"Polygon Areas: {min(areas)=} {max(areas)=}", level="info", feedback=feedback, logger=logger)
438    fprint(f"Cluster PixelCounts: {min(pixels)=} {max(pixels)=}", level="info", feedback=feedback, logger=logger)
439    # RESTART RASTER
440    # src_ds = None
441    # src_band = None
442    # src_ds = gdal.GetDriverByName(raster_driver).Create(output_raster, width, height, 1, GDT)
443    # src_ds.SetGeoTransform(geotransform)  # != 0 ?
444    # src_ds.SetProjection(authid)  # != 0 ?
445    src_band = src_ds.GetRasterBand(1)
446    if nodata:
447        src_band.SetNoDataValue(nodata)
448    else:
449        # useless paranoia ?
450        src_band.SetNoDataValue(-1)
451    src_band.WriteArray(label_map)
452    # close datasets
453    src_ds.FlushCache()
454    src_ds = None
455    dst_ds.FlushCache()
456    dst_ds = None
457    return True
458
459
460def plot(labels_reshaped, pipe1, pipe2, info_list, **kwargs):
461    """Plot the observed values of the input data, the rescaled data, and the cluster size history and histogram.
462    Args:
463        labels_reshaped (np.ndarray): The reshaped labels of the clusters
464        pipe1 (Pipeline): The first pipeline object containing imputer and feature scaling steps
465        pipe2 (Pipeline): The second pipeline object containing the rescaling and clustering steps
466        info_list (list): A list of dictionaries containing information about each feature
467        **kwargs: Additional keyword arguments
468            n_clusters (int): The number of clusters
469            distance_threshold (float): The linkage distance threshold
470            sieve (int): The number of pixels to use as a sieve filter
471            block (bool): Block the execution until the plot window is closed
472            filename (str): The filename to save the plot
473    """
474    from matplotlib import pyplot as plt
475
476    no_data_imputed = pipe1.named_steps["no_data_imputer"].output_data
477    pre_aggclu_data = pipe2.named_steps["agglomerative_clustering"].input_data
478
479    # filtrar onehot
480    num_onehots = sum([1 for i in info_list if i["scaling_strategy"] == "onehot"])
481    num_no_onehots = len(info_list) - num_onehots
482    pre_aggclu_data = pre_aggclu_data[:, :num_no_onehots]
483
484    # indices sin onehots
485    nohots_idxs = [i for i, info in enumerate(info_list) if info["scaling_strategy"] != "onehot"]
486
487    # filtrar onehot de no_data_treated
488    no_data_imputed = no_data_imputed[:, nohots_idxs]
489
490    # reordenados en robust y despues standard
491    rob_std_idxs = [i for i, j in enumerate(nohots_idxs) if info_list[j]["scaling_strategy"] == "robust"]
492    rob_std_idxs += [i for i, j in enumerate(nohots_idxs) if info_list[j]["scaling_strategy"] == "standard"]
493
494    # reordenar rob then std
495    pre_aggclu_data = pre_aggclu_data[:, rob_std_idxs]
496
497    names = [info_list[i]["fname"] for i in nohots_idxs]
498
499    fgs = np.array(plt.rcParams["figure.figsize"]) * 5
500    fig, axs = plt.subplots(3, 2, figsize=fgs)
501    suptitle = ""
502    if n_clusters := kwargs.get("n_clusters"):
503        suptitle = f"n_clusters: {n_clusters}"
504    if distance_threshold := kwargs.get("distance_threshold"):
505        suptitle = f"distance_threshold: {distance_threshold}"
506    if sieve := kwargs.get("sieve"):
507        suptitle += f", sieve: {sieve}"
508    if n_clusters or distance_threshold or sieve:
509        suptitle += f", resulting clusters: {len(np.unique(labels_reshaped))}"
510    suptitle += "\n(Not showing categorical data)"
511    fig.suptitle(suptitle)
512
513    # plot violin plot
514    axs[0, 0].violinplot(no_data_imputed, showmeans=False, showmedians=True, showextrema=True)
515    axs[0, 0].set_title("Violin Plot of NoData Imputed")
516    axs[0, 0].yaxis.grid(True)
517    axs[0, 0].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
518    axs[0, 0].set_ylabel("Observed values")
519
520    # plot boxplot
521    axs[0, 1].boxplot(no_data_imputed)
522    axs[0, 1].set_title("Box Plot of NoData Imputed")
523    axs[0, 1].yaxis.grid(True)
524    axs[0, 1].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
525    axs[0, 1].set_ylabel("Observed values")
526
527    # plot violin plot
528    axs[1, 0].violinplot(pre_aggclu_data, showmeans=False, showmedians=True, showextrema=True)
529    axs[1, 0].set_title("Violin Plot of Common Rescaled")
530    axs[1, 0].yaxis.grid(True)
531    axs[1, 0].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
532    axs[1, 0].set_ylabel("Adjusted range")
533
534    # plot boxplot
535    axs[1, 1].boxplot(pre_aggclu_data)
536    axs[1, 1].set_title("Box Plot of Common Rescaled")
537    axs[1, 1].yaxis.grid(True)
538    axs[1, 1].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
539    axs[1, 1].set_ylabel("Adjusted range")
540
541    # cluster history
542    unique_labels, counts = np.unique(labels_reshaped, return_counts=True)
543    axs[2, 0].plot(unique_labels, counts, marker="o", color="blue")
544    axs[2, 0].set_title("Cluster history size (in pixels)")
545    axs[2, 0].set_xlabel("Algorithm Step")
546    axs[2, 0].set_ylabel("Size (in pixels)")
547
548    # cluster histogram
549    axs[2, 1].hist(counts, log=True)
550    axs[2, 1].set_xlabel("Cluster Size (in pixels)")
551    axs[2, 1].set_ylabel("Number of Clusters")
552    axs[2, 1].set_title("Histogram of Cluster Sizes")
553
554    plt.tight_layout()
555    if filename := kwargs.get("filename"):
556        logger.info(f"Saving plot to {filename}")
557        plt.savefig(filename)
558    else:
559        if block := kwargs.get("block"):
560            plt.show(block=block)
561        else:
562            plt.show()
563
564
565def sieve_filter(data, threshold=2, connectedness=4, feedback=None):
566    """Apply a sieve filter to the data to remove small clusters. The sieve filter is applied using the GDAL library. https://gdal.org/en/latest/programs/gdal_sieve.html#gdal-sieve
567    Args:
568        data (np.ndarray): The input data to filter
569        threshold (int): The maximum number of pixels in a cluster to keep
570        connectedness (int): The number of connected pixels to consider when filtering 4 or 8
571        feedback (QgsTaskFeedback): A feedback object to report progress to use inside QGIS plugins
572    Returns:
573        np.ndarray: The filtered data
574    """
575    logger.info("Applying sieve filter")
576
577    height, width = data.shape
578    # fprint("antes", np.sort(np.unique(data, return_counts=True)), len(np.unique(data)), level="info", feedback=feedback, logger=logger)
579    num_clusters = len(np.unique(data))
580    src_ds = gdal.GetDriverByName("MEM").Create("sieve", width, height, 1, GDT)
581    src_band = src_ds.GetRasterBand(1)
582    src_band.WriteArray(data)
583    if 0 != gdal.SieveFilter(src_band, None, src_band, threshold, connectedness):
584        fprint("Error applying sieve filter", level="error", feedback=feedback, logger=logger)
585    else:
586        sieved = src_band.ReadAsArray()
587        src_band = None
588        src_ds = None
589        num_sieved = len(np.unique(sieved))
590        # fprint("despues", np.sort(np.unique(sieved, return_counts=True)), len(np.unique(sieved)), level="info", feedback=feedback, logger=logger)
591        fprint(
592            f"Reduced from {num_clusters} to {num_sieved} clusters, {num_clusters-num_sieved} less",
593            level="info",
594            feedback=feedback,
595            logger=logger,
596        )
597        fprint(
598            "Please try again increasing distance_threshold or reducing n_clusters instead...",
599            level="info",
600            feedback=feedback,
601            logger=logger,
602        )
603        # from matplotlib import pyplot as plt
604        # fig, (ax1, ax2) = plt.subplots(1, 2)
605        # ax1.imshow(data)
606        # ax1.set_title("before sieve" + str(len(np.unique(data))))
607        # ax2.imshow(sieved)
608        # ax2.set_title("after sieve" + str(len(np.unique(sieved))))
609        # plt.show()
610        # data = sieved
611        return sieved
612
613
614def arg_parser(argv=None):
615    """Parse command line arguments."""
616    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
617
618    parser = ArgumentParser(
619        description="Agglomerative Clustering with Connectivity for raster data",
620        formatter_class=ArgumentDefaultsHelpFormatter,
621        epilog="More at https://fire2a.github.io/fire2a-lib",
622    )
623    parser.add_argument(
624        "config_file",
625        nargs="?",
626        type=Path,
627        help="For each raster file, configure its preprocess: nodata & scaling methods",
628        default="config.toml",
629    )
630
631    aggclu = parser.add_mutually_exclusive_group(required=True)
632    aggclu.add_argument(
633        "-d",
634        "--distance_threshold",
635        type=float,
636        help="Distance threshold (a good starting point when scaling is 10, higher means less clusters, 0 could take a long time)",
637    )
638    aggclu.add_argument("-n", "--n_clusters", type=int, help="Number of clusters")
639
640    parser.add_argument("-or", "--output_raster", help="Output raster file, warning overwrites!", default="")
641    parser.add_argument("-op", "--output_poly", help="Output polygons file, warning overwrites!", default="output.gpkg")
642    parser.add_argument("-a", "--authid", type=str, help="Output raster authid", default="EPSG:3857")
643    parser.add_argument(
644        "-g", "--geotransform", type=str, help="Output raster geotransform", default="(0, 1, 0, 0, 0, 1)"
645    )
646    parser.add_argument(
647        "-nw",
648        "--no_write",
649        action="store_true",
650        help="Do not write outputs raster nor polygons",
651        default=False,
652    )
653    parser.add_argument(
654        "-s",
655        "--script",
656        action="store_true",
657        help="Run in script mode, returning the label_map and the pipeline object",
658        default=False,
659    )
660    parser.add_argument(
661        "--sieve",
662        type=int,
663        help="Use GDAL sieve filter to merge small clusters (number of pixels) into the biggest neighbor",
664    )
665    parser.add_argument("--verbose", "-v", action="count", default=0, help="WARNING:1, INFO:2, DEBUG:3")
666
667    plot = parser.add_argument_group(
668        "Plotting, Visually inspect input distributions: NoData treated observations, rescaled data, with violing plots and boxplots. Also check output clustering size history and histograms"
669    )
670    plot.add_argument(
671        "-p",
672        "--plots",
673        action="store_true",
674        help="Activate the plotting routines",
675    )
676    plot.add_argument(
677        "-b",
678        "--block",
679        action="store_false",
680        default=True,
681        help="Block the execution until the plot window is closed. Use False for interactive ipykernels or QGIS",
682    )
683    plot.add_argument(
684        "-f",
685        "--filename",
686        type=str,
687        help="Filename to save the plot. If not provided, matplotlib will raise a window",
688    )
689    args = parser.parse_args(argv)
690    args.geotransform = tuple(map(float, args.geotransform[1:-1].split(",")))
691    if Path(args.config_file).is_file() is False:
692        parser.error(f"File {args.config_file} not found")
693    return args
694
695
696def main(argv=None):
697    """
698
699    args = arg_parser(["-d","10.0", "-g","(0, 10, 0, 0, 0, 10)", "config2.toml"])
700    args = arg_parser(["-d","10.0"])
701    args = arg_parser(["-d","10.0", "config2.toml"])
702    args = arg_parser(["-n","10"])
703    """
704    if argv is sys.argv:
705        argv = sys.argv[1:]
706    args = arg_parser(argv)
707
708    if args.verbose != 0:
709        global logger
710        from fire2a import setup_logger
711
712        logger = setup_logger(verbosity=args.verbose)
713
714    logger.info("args %s", args)
715
716    # 2 LEE CONFIG
717    config = read_toml(args.config_file)
718    # logger.debug(config)
719
720    # 2.1 ADD DEFAULTS
721    for filename, file_config in config.items():
722        if "no_data_strategy" not in file_config:
723            config[filename]["no_data_strategy"] = "mean"
724        if "scaling_strategy" not in file_config:
725            config[filename]["scaling_strategy"] = "robust"
726        if "fill_value" not in file_config:
727            config[filename]["fill_value"] = 0
728        if "weight" not in file_config:
729            config[filename]["weight"] = 1
730    logger.debug(config)
731
732    # 3. LEE DATA
733    from fire2a.raster import read_raster
734
735    data_list, info_list = [], []
736    for filename, file_config in config.items():
737        data, info = read_raster(filename)
738        info["fname"] = Path(filename).name
739        info["no_data_strategy"] = file_config["no_data_strategy"]
740        info["scaling_strategy"] = file_config["scaling_strategy"]
741        info["fill_value"] = file_config["fill_value"]
742        info["weight"] = file_config["weight"]
743        data_list += [data]
744        info_list += [info]
745        logger.debug("%s", data[:2, :2])
746        logger.debug("%s", info)
747
748    # 4. VALIDAR 2d todos mismo shape
749    height, width = check_shapes(data_list)
750
751    # 5. lista[mapas] -> OBSERVACIONES
752    observations = np.column_stack([data.ravel() for data in data_list])
753
754    # 6. nodata -> feature scaling -> all scaling -> clustering
755    labels_reshaped, pipe1, pipe2 = pipelie(
756        observations,
757        info_list,
758        height,
759        width,
760        n_clusters=args.n_clusters,
761        distance_threshold=args.distance_threshold,
762    )  # insert more keyworded arguments to the clustering algorithm here!
763
764    # SIEVE
765    if args.sieve:
766        logger.info(f"Number of clusters before sieving: {len(np.unique(labels_reshaped))}")
767        labels_reshaped = sieve_filter(labels_reshaped, args.sieve)
768
769    logger.info(f"Final number of clusters: {len(np.unique(labels_reshaped))}")
770
771    # 7 debbuging plots
772    if args.plots:
773        plot(labels_reshaped, pipe1, pipe2, info_list, **vars(args))
774
775    # 8. ESCRIBIR RASTER
776    if not args.no_write:
777        if not write(
778            labels_reshaped,
779            width,
780            height,
781            output_raster=args.output_raster,
782            output_poly=args.output_poly,
783            authid=args.authid,
784            geotransform=args.geotransform,
785        ):
786            logger.error("Error writing output raster")
787
788    # 9. SCRIPT MODE
789    if args.script:
790        return labels_reshaped, pipe1, pipe2
791
792    return 0
793
794
795if __name__ == "__main__":
796    sys.exit(main(sys.argv))
logger = <Logger fire2a.agglomerative_clustering (WARNING)>
def check_shapes(data_list):
126def check_shapes(data_list):
127    """Check if all data arrays have the same shape and are 2D.
128    Returns the shape of the data arrays if they are all equal.
129    """
130    from functools import reduce
131
132    def equal_or_error(x, y):
133        """Check if x and y are equal, returns x if equal else raises a ValueError."""
134        if x == y:
135            return x
136        else:
137            raise ValueError("All data arrays must have the same shape")
138
139    shape = reduce(equal_or_error, (data.shape for data in data_list))
140    if len(shape) != 2:
141        raise ValueError("All data arrays must be 2D")
142    height, width = shape
143    return height, width

Check if all data arrays have the same shape and are 2D. Returns the shape of the data arrays if they are all equal.
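
For illustration, a minimal usage sketch (toy arrays, not from the module's docs):

import numpy as np
from fire2a.agglomerative_clustering import check_shapes
height, width = check_shapes([np.zeros((10, 20)), np.ones((10, 20))])  # -> (10, 20); raises ValueError on mismatch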

def get_map_neighbors(height, width, num_neighbors=8):
146def get_map_neighbors(height, width, num_neighbors=8):
147    """Get the neighbors of each cell in a 2D grid.
148    n_jobs=-1 uses all available cores.
149    """
150
151    grid_points = np.indices((height, width)).reshape(2, -1).T
152
153    nb4 = radius_neighbors_graph(grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1)
154    nb8 = radius_neighbors_graph(grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1)
155
156    # assert nb4.shape[0] == width * height
157    # assert nb8.shape[1] == width * height
158    # for n in range(width * height):
159    #     _, neighbors = np.nonzero(nb4[n])
160    #     assert 2<= len(neighbors) <= 4, f"{n=} {neighbors=}"
161    #     assert 3<= len(neighbors) <= 8, f"{n=} {neighbors=}"
162    return nb4, nb8

Get the neighbors of each cell in a 2D grid. n_jobs=-1 uses all available cores.
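
A minimal usage sketch (toy grid, not from the module's docs):

from fire2a.agglomerative_clustering import get_map_neighbors
nb4, nb8 = get_map_neighbors(height=3, width=4)
# both are sparse (12, 12) adjacency matrices: the 4-connected and 8-connected grid neighborhoods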

class NoDataImputer(sklearn.base.BaseEstimator, sklearn.base.TransformerMixin):
165class NoDataImputer(BaseEstimator, TransformerMixin):
166    """A custom Imputer that treats a specified nodata_value as np.nan and supports different strategies per column"""
167
168    def __init__(self, no_data_values, strategies, constants):
169        self.no_data_values = no_data_values
170        self.strategies = strategies
171        self.constants = constants
172        self.imputers = []
173        for no_data_value, strategy, constant in zip(no_data_values, strategies, constants):
174            if no_data_value:
175                self.imputers += [SimpleImputer(strategy=strategy, missing_values=no_data_value, fill_value=constant)]
176            else:
177                self.imputers += [SimpleImputer(strategy=strategy, fill_value=constant)]
178
179    def fit(self, X, y=None):
180        for i, imputer in enumerate(self.imputers):
181            imputer.fit(X[:, [i]], y)
182        return self
183
184    def transform(self, X):
185        for i, imputer in enumerate(self.imputers):
186            X[:, [i]] = imputer.transform(X[:, [i]])
187        self.output_data = X
188        return X

A custom Imputer that treats a specified nodata_value as np.nan and supports different strategies per column
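
A minimal usage sketch, with hypothetical per-column nodata values and strategies:

import numpy as np
from fire2a.agglomerative_clustering import NoDataImputer
X = np.array([[1.0, 7.0], [-9999.0, 8.0], [3.0, 7.0]])
imputer = NoDataImputer(no_data_values=[-9999.0, None], strategies=["mean", "most_frequent"], constants=[0, 0])
X = imputer.fit(X).transform(X)  # -9999 in the first column is replaced by that column's mean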

NoDataImputer(no_data_values, strategies, constants)
168    def __init__(self, no_data_values, strategies, constants):
169        self.no_data_values = no_data_values
170        self.strategies = strategies
171        self.constants = constants
172        self.imputers = []
173        for no_data_value, strategy, constant in zip(no_data_values, strategies, constants):
174            if no_data_value:
175                self.imputers += [SimpleImputer(strategy=strategy, missing_values=no_data_value, fill_value=constant)]
176            else:
177                self.imputers += [SimpleImputer(strategy=strategy, fill_value=constant)]
no_data_values
strategies
constants
imputers
def fit(self, X, y=None):
179    def fit(self, X, y=None):
180        for i, imputer in enumerate(self.imputers):
181            imputer.fit(X[:, [i]], y)
182        return self
def transform(self, X):
184    def transform(self, X):
185        for i, imputer in enumerate(self.imputers):
186            X[:, [i]] = imputer.transform(X[:, [i]])
187        self.output_data = X
188        return X
class RescaleAllToCommonRange(sklearn.base.BaseEstimator, sklearn.base.TransformerMixin):
191class RescaleAllToCommonRange(BaseEstimator, TransformerMixin):
192    """A custom transformer that rescales all features to a common range [0, 1]"""
193
194    def __init__(self, weight_map):
195        self.weight_map = weight_map
196
197    def fit(self, X, y=None):
198        # Determine the combined range of all scaled features
199        self.min_val = [x.min() for x in X.T]
200        self.max_val = [x.max() for x in X.T]
201        return self
202
203    def transform(self, X):
204        # Rescale all features to match the common range
205        for i, (x, mi, ma) in enumerate(zip(X.T, self.min_val, self.max_val)):
206            if ma - mi == 0:
207                X.T[i] = x * self.weight_map[i]
208            else:
209                X.T[i] = (x - mi) / (ma - mi) * self.weight_map[i]
210        return X

A custom transformer that rescales all features to a common range [0, 1]
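
A minimal usage sketch, with a hypothetical weight_map giving the second feature twice the importance:

import numpy as np
from fire2a.agglomerative_clustering import RescaleAllToCommonRange
X = np.array([[0.0, 10.0], [5.0, 20.0], [10.0, 30.0]])
X = RescaleAllToCommonRange(weight_map=[1.0, 2.0]).fit(X).transform(X)
# each column is mapped to [0, 1] then multiplied by its weight: column 0 -> [0, 0.5, 1], column 1 -> [0, 1, 2]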

RescaleAllToCommonRange(weight_map)
194    def __init__(self, weight_map):
195        self.weight_map = weight_map
weight_map
def fit(self, X, y=None):
197    def fit(self, X, y=None):
198        # Determine the combined range of all scaled features
199        self.min_val = [x.min() for x in X.T]
200        self.max_val = [x.max() for x in X.T]
201        return self
def transform(self, X):
203    def transform(self, X):
204        # Rescale all features to match the common range
205        for i, (x, mi, ma) in enumerate(zip(X.T, self.min_val, self.max_val)):
206            if ma - mi == 0:
207                X.T[i] = x * self.weight_map[i]
208            else:
209                X.T[i] = (x - mi) / (ma - mi) * self.weight_map[i]
210        return X
class CustomAgglomerativeClustering(sklearn.base.BaseEstimator, sklearn.base.TransformerMixin):
213class CustomAgglomerativeClustering(BaseEstimator, TransformerMixin):
214    def __init__(self, height, width, neighbors=4, **kwargs):
215        self.height = height
216        self.width = width
217        self.neighbors = neighbors
218
219        self.grid_points = np.indices((height, width)).reshape(2, -1).T
220        if neighbors == 4:
221            connectivity = radius_neighbors_graph(
222                self.grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1
223            )
224        elif neighbors == 8:
225            connectivity = radius_neighbors_graph(
226                self.grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1
227            )
228
229        self.connectivity = connectivity
230        self.kwargs = kwargs
231        self.model = AgglomerativeClustering(connectivity=self.connectivity, **self.kwargs)
232
233    def fit(self, X, y=None):
234        logger.debug("not sure why, but this method is never called alas needed")
235        self.model.fit(X)
236        return self
237
238    def fit_predict(self, X, y=None):
239        self.input_data = X
240        return self.model.fit_predict(X)

CustomAgglomerativeClustering(height, width, neighbors=4, **kwargs)
214    def __init__(self, height, width, neighbors=4, **kwargs):
215        self.height = height
216        self.width = width
217        self.neighbors = neighbors
218
219        self.grid_points = np.indices((height, width)).reshape(2, -1).T
220        if neighbors == 4:
221            connectivity = radius_neighbors_graph(
222                self.grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1
223            )
224        elif neighbors == 8:
225            connectivity = radius_neighbors_graph(
226                self.grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1
227            )
228
229        self.connectivity = connectivity
230        self.kwargs = kwargs
231        self.model = AgglomerativeClustering(connectivity=self.connectivity, **self.kwargs)
height
width
neighbors
grid_points
connectivity
kwargs
model
def fit(self, X, y=None):
233    def fit(self, X, y=None):
234        logger.debug("not sure why, but this method is never called alas needed")
235        self.model.fit(X)
236        return self
def fit_predict(self, X, y=None):
238    def fit_predict(self, X, y=None):
239        self.input_data = X
240        return self.model.fit_predict(X)
def pipelie(observations, info_list, height, width, **kwargs):
243def pipelie(observations, info_list, height, width, **kwargs):
244    """A scipy pipeline to achieve Agglomerative Clustering with connectivity on 2d matrix
245    Steps are:
246    1. Impute missing values
247    2. Scale the features
248    3. Rescale all features to a common range
249    4. Cluster the data using Agglomerative Clustering with connectivity
250    5. Reshape the labels back to the original spatial map shape
251    6. Return the labels and the pipeline object
252
253    Args:
254        observations (np.ndarray): The input data to cluster (n_samples, n_features) shaped
255        info_list (list): A list of dictionaries containing information about each feature
256        height (int): The height of the spatial map
257        width (int): The width of the spatial map
258        kwargs: Additional keyword arguments for AgglomerativeClustering, at least one of n_clusters or distance_threshold
259
260    Returns:
261        np.ndarray: The labels of the clusters, reshaped to the original 2d spatial map shape
262        Pipeline: The pipeline object containing all the steps of the pipeline
263    """
264    # kwargs = {"n_clusters": args.n_clusters, "distance_threshold": args.distance_threshold}
265
266    # imputer strategies
267    no_data_values = [info["NoDataValue"] for info in info_list]
268    no_data_strategies = [info["no_data_strategy"] for info in info_list]
269    fill_values = [info["fill_value"] for info in info_list]
270    weights = [info["weight"] for info in info_list]
271    # scaling_strategies = [info["scaling_strategy"] for info in info_list]
272
273    # scaling strategies
274    index_map = {}
275    for strategy in ["robust", "standard", "onehot"]:
276        index_map[strategy] = [i for i, info in enumerate(info_list) if info["scaling_strategy"] == strategy]
277    # index_map
278    # !cat config.toml
279
280    # Create transformers for each type
281    robust_transformer = Pipeline(steps=[("robust_step", RobustScaler())])
282    standard_transformer = Pipeline(steps=[("standard_step", StandardScaler())])
283    onehot_transformer = Pipeline(steps=[("onehot_step", OneHotEncoder(sparse_output=False))])
284    # OneHotEncoder._n_features_outs):
285
286    # Combine transformers using ColumnTransformer
287    feature_scaler = ColumnTransformer(
288        transformers=[
289            ("robust", robust_transformer, index_map["robust"]),
290            ("standard", standard_transformer, index_map["standard"]),
291            ("onehot", onehot_transformer, index_map["onehot"]),
292        ]
293    )
294
295    # # Create a temporary directory for caching calculations
296    # # FOR ACCESING STEPS LATER ON VERY LARGE DATASETS
297    # import tempfile
298    # import joblib
299    # temp_dir = tempfile.mkdtemp()
300    # memory = joblib.Memory(location=temp_dir, verbose=0)
301
302    # Create and apply the pipeline
303    # part 1 until feature scaling
304    pipe1 = Pipeline(
305        # n_features_in_ : int
306        # feature_names_in_ : ndarray of shape (`n_features_in_`,)
307        steps=[
308            ("no_data_imputer", NoDataImputer(no_data_values, no_data_strategies, fill_values)),
309            ("feature_scaling", feature_scaler),
310        ],
311        # memory=memory,
312        verbose=True,
313    )
314    # map weights to new columns (onehot feature scaler creates one column per category)
315    obs1 = pipe1.fit_transform(observations)
316    cat_names = pipe1.named_steps["feature_scaling"]["onehot"].get_feature_names_out()
317    split_names = [name.split("_")[0] for name in cat_names]
318    cat_count = np.unique(split_names, return_counts=True)[1]
319    onehot_map = {}
320    for i, key in enumerate(index_map["onehot"]):
321        onehot_map[key] = cat_count[i]
322    # onehot_map = {key: cat_count[i] for i, key in enumerate(index_map["onehot"])}
323    weight_map = []
324    for name, idxs in index_map.items():
325        for idx in idxs:
326            if name == "onehot":
327                weight_map += [weights[idx]] * onehot_map[idx]
328                continue
329            weight_map += [weights[idx]]
330    # part 2 use weight_map and cluster
331    pipe2 = Pipeline(
332        steps=[
333            ("common_rescaling", RescaleAllToCommonRange(weight_map)),
334            ("agglomerative_clustering", CustomAgglomerativeClustering(height, width, neighbors=4, **kwargs)),
335        ],
336        # memory=memory,
337        verbose=True,
338    )
339
340    # apply pipeLIE
341    labels = pipe2.fit_predict(obs1)
342
343    # Reshape the labels back to the original spatial map shape
344    labels_reshaped = labels.reshape(height, width)
345    return labels_reshaped, pipe1, pipe2

A scikit-learn pipeline that performs Agglomerative Clustering with connectivity on a 2D matrix. Steps are:

  1. Impute missing values
  2. Scale the features
  3. Rescale all features to a common range
  4. Cluster the data using Agglomerative Clustering with connectivity
  5. Reshape the labels back to the original spatial map shape
  6. Return the labels and the pipeline object

Args:
    observations (np.ndarray): The input data to cluster, shaped (n_samples, n_features)
    info_list (list): A list of dictionaries containing information about each feature
    height (int): The height of the spatial map
    width (int): The width of the spatial map
    kwargs: Additional keyword arguments for AgglomerativeClustering; at least one of n_clusters or distance_threshold

Returns:
    np.ndarray: The labels of the clusters, reshaped to the original 2d spatial map shape
    Pipeline: The pipeline object containing all the steps of the pipeline

def write( label_map, width, height, output_raster='', output_poly='output.shp', authid='EPSG:3857', geotransform=(0, 1, 0, 0, 0, 1), nodata=None, feedback=None):
348def write(
349    label_map,
350    width,
351    height,
352    output_raster="",
353    output_poly="output.shp",
354    authid="EPSG:3857",
355    geotransform=(0, 1, 0, 0, 0, 1),
356    nodata=None,
    feedback=None,
):

    from fire2a.processing_utils import get_output_raster_format, get_vector_driver_from_filename

    # setup drivers for raster and polygon output formats
    if output_raster == "":
        raster_driver = "MEM"
    else:
        try:
            raster_driver = get_output_raster_format(output_raster, feedback=feedback)
        except Exception:
            raster_driver = "GTiff"
    try:
        poly_driver = get_vector_driver_from_filename(output_poly)
    except Exception:
        poly_driver = "ESRI Shapefile"

    # create raster output
    src_ds = gdal.GetDriverByName(raster_driver).Create(output_raster, width, height, 1, GDT)
    src_ds.SetGeoTransform(geotransform)  # != 0 ?
    src_ds.SetProjection(authid)  # != 0 ?
    #  src_band = src_ds.GetRasterBand(1)
    #  if nodata:
    #      src_band.SetNoDataValue(nodata)
    #  src_band.WriteArray(label_map)

    # create polygon output
    drv = ogr.GetDriverByName(poly_driver)
    dst_ds = drv.CreateDataSource(output_poly)
    sp_ref = osr.SpatialReference()
    sp_ref.SetFromUserInput(authid)  # != 0 ?
    dst_lyr = dst_ds.CreateLayer("clusters", srs=sp_ref, geom_type=ogr.wkbPolygon)
    dst_lyr.CreateField(ogr.FieldDefn("DN", OFT))  # != 0 ?
    dst_lyr.CreateField(ogr.FieldDefn("pixel_count", OFT))
    # dst_lyr.CreateField(ogr.FieldDefn("area", OFT))
    # dst_lyr.CreateField(ogr.FieldDefn("perimeter", OFT))

    # 0 != gdal.Polygonize( srcband, maskband, dst_layer, dst_field, options, callback = gdal.TermProgress)

    # FAIL: polygonizing all labels at once merges them into a single polygon
    #  src_band = src_ds.GetRasterBand(1)
    #  if nodata:
    #      src_band.SetNoDataValue(nodata)
    #  src_band.WriteArray(label_map)
    # gdal.Polygonize(src_band, None, dst_lyr, 0, callback=gdal.TermProgress)  # , ["8CONNECTED=8"])

    # Approach B: polygonize each label separately
    # loop creating a separate polygonized feature for each label_map value
    mem_drv = ogr.GetDriverByName("Memory")
    tmp_ds = mem_drv.CreateDataSource("tmp_ds")
    # itera = iter(np.unique(label_map))
    # cluster_id = next(itera)
    areas = []
    pixels = []
    data = np.zeros_like(label_map)
    for cluster_id, px_count in zip(*np.unique(label_map, return_counts=True)):
        # temporarily write band
        src_band = src_ds.GetRasterBand(1)
        src_band.SetNoDataValue(0)
        data[label_map == cluster_id] = 1
        src_band.WriteArray(data)
        # create feature
        tmp_lyr = tmp_ds.CreateLayer("", srs=sp_ref)
        gdal.Polygonize(src_band, src_band.GetMaskBand(), tmp_lyr, -1)
        # unset tmp data
        data[label_map == cluster_id] = 0
        # set polygon feat
        feat = tmp_lyr.GetNextFeature()
        geom = feat.GetGeometryRef()
        featureDefn = dst_lyr.GetLayerDefn()
        feature = ogr.Feature(featureDefn)
        feature.SetGeometry(geom)
        feature.SetField("DN", float(cluster_id))
        areas += [geom.GetArea()]
        pixels += [px_count]
        feature.SetField("pixel_count", float(px_count))
        # feature.SetField("area", int(geom.GetArea()))
        # feature.SetField("perimeter", int(geom.Boundary().Length()))
        dst_lyr.CreateFeature(feature)

    fprint(f"Polygon Areas: {min(areas)=} {max(areas)=}", level="info", feedback=feedback, logger=logger)
    fprint(f"Cluster PixelCounts: {min(pixels)=} {max(pixels)=}", level="info", feedback=feedback, logger=logger)
    # RESTART RASTER
    # src_ds = None
    # src_band = None
    # src_ds = gdal.GetDriverByName(raster_driver).Create(output_raster, width, height, 1, GDT)
    # src_ds.SetGeoTransform(geotransform)  # != 0 ?
    # src_ds.SetProjection(authid)  # != 0 ?
    src_band = src_ds.GetRasterBand(1)
    if nodata:
        src_band.SetNoDataValue(nodata)
    else:
        # useless paranoia ?
        src_band.SetNoDataValue(-1)
    src_band.WriteArray(label_map)
    # close datasets
    src_ds.FlushCache()
    src_ds = None
    dst_ds.FlushCache()
    dst_ds = None
    return True
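
For orientation, here is a minimal sketch of calling write directly, in the same way main does further below. The label map is random and the output file names are hypothetical, so treat it as an illustration only (existing files are overwritten):

import numpy as np
from fire2a.agglomerative_clustering import write

label_map = np.random.randint(0, 5, size=(100, 100))  # stand-in for the clustering result
write(
    label_map,
    100,  # width
    100,  # height
    output_raster="labels.tif",   # hypothetical file names
    output_poly="labels.gpkg",
    authid="EPSG:3857",
    geotransform=(0, 1, 0, 0, 0, 1),
)
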
def plot(labels_reshaped, pipe1, pipe2, info_list, **kwargs):
462    """Plot the observed values of the input data, the rescaled data, and the cluster size history and histogram.
463    Args:
464        labels_reshaped (np.ndarray): The reshaped labels of the clusters
465        pipe1 (Pipeline): The first pipeline object containing imputer and feature scaling steps
466        pipe2 (Pipeline): The second pipeline object containing the rescaling and clustering steps
467        info_list (list): A list of dictionaries containing information about each feature
468        **kargs: Additional keyword arguments
469            n_clusters (int): The number of clusters
470            distance_threshold (float): The linkage distance threshold
471            sieve (int): The number of pixels to use as a sieve filter
472            block (bool): Block the execution until the plot window is closed
473            filename (str): The filename to save the plot
474    """
475    from matplotlib import pyplot as plt
476
477    no_data_imputed = pipe1.named_steps["no_data_imputer"].output_data
478    pre_aggclu_data = pipe2.named_steps["agglomerative_clustering"].input_data
479
480    # filtrar onehot
481    num_onehots = sum([1 for i in info_list if i["scaling_strategy"] == "onehot"])
482    num_no_onehots = len(info_list) - num_onehots
483    pre_aggclu_data = pre_aggclu_data[:, :num_no_onehots]
484
485    # indices sin onehots
486    nohots_idxs = [i for i, info in enumerate(info_list) if info["scaling_strategy"] != "onehot"]
487
488    # filtrar onehot de no_data_treated
489    no_data_imputed = no_data_imputed[:, nohots_idxs]
490
491    # reordenados en robust y despues standard
492    rob_std_idxs = [i for i, j in enumerate(nohots_idxs) if info_list[j]["scaling_strategy"] == "robust"]
493    rob_std_idxs += [i for i, j in enumerate(nohots_idxs) if info_list[j]["scaling_strategy"] == "standard"]
494
495    # reordenar rob then std
496    pre_aggclu_data = pre_aggclu_data[:, rob_std_idxs]
497
498    names = [info_list[i]["fname"] for i in nohots_idxs]
499
500    fgs = np.array(plt.rcParams["figure.figsize"]) * 5
501    fig, axs = plt.subplots(3, 2, figsize=fgs)
502    suptitle = ""
503    if n_clusters := kwargs.get("n_clusters"):
504        suptitle = f"n_clusters: {n_clusters}"
505    if distance_threshold := kwargs.get("distance_threshold"):
506        suptitle = f"distance_threshold: {distance_threshold}"
507    if sieve := kwargs.get("sieve"):
508        suptitle += f", sieve: {sieve}"
509    if n_clusters or distance_threshold or sieve:
510        suptitle += f", resulting clusters: {len(np.unique(labels_reshaped))}"
511    suptitle += "\n(Not showing categorical data)"
512    fig.suptitle(suptitle)
513
514    # plot violin plot
515    axs[0, 0].violinplot(no_data_imputed, showmeans=False, showmedians=True, showextrema=True)
516    axs[0, 0].set_title("Violin Plot of NoData Imputed")
517    axs[0, 0].yaxis.grid(True)
518    axs[0, 0].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
519    axs[0, 0].set_ylabel("Observed values")
520
521    # plot boxplot
522    axs[0, 1].boxplot(no_data_imputed)
523    axs[0, 1].set_title("Box Plot of NoData Imputed")
524    axs[0, 1].yaxis.grid(True)
525    axs[0, 1].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
526    axs[0, 1].set_ylabel("Observed values")
527
528    # plot violin plot
529    axs[1, 0].violinplot(pre_aggclu_data, showmeans=False, showmedians=True, showextrema=True)
530    axs[1, 0].set_title("Violin Plot of Common Rescaled")
531    axs[1, 0].yaxis.grid(True)
532    axs[1, 0].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
533    axs[0, 1].set_ylabel("Adjusted range")
534
535    # plot boxplot
536    axs[1, 1].boxplot(pre_aggclu_data)
537    axs[1, 1].set_title("Box Plot of Common Rescaled")
538    axs[1, 1].yaxis.grid(True)
539    axs[1, 1].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
540    axs[0, 1].set_ylabel("Adjusted range")
541
542    # cluster history
543    unique_labels, counts = np.unique(labels_reshaped, return_counts=True)
544    axs[2, 0].plot(unique_labels, counts, marker="o", color="blue")
545    axs[2, 0].set_title("Cluster history size (in pixels)")
546    axs[2, 0].set_xlabel("Algorithm Step")
547    axs[2, 0].set_ylabel("Size (in pixels)")
548
549    # cluster histogram
550    axs[2, 1].hist(counts, log=True)
551    axs[2, 1].set_xlabel("Cluster Size (in pixels)")
552    axs[2, 1].set_ylabel("Number of Clusters")
553    axs[2, 1].set_title("Histogram of Cluster Sizes")
554
555    plt.tight_layout()
556    if filename := kwargs.get("filename"):
557        logger.info(f"Saving plot to {filename}")
558        plt.savefig(filename)
559    else:
560        if block := kwargs.get("block"):
561            plt.show(block=block)
562        else:
563            plt.show()

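A usage sketch: plot is normally reached through the -p/--plots flag, but it can also be called directly on the objects produced by the pipeline (labels_reshaped, pipe1, pipe2, info_list). The sieve value and file name below are hypothetical; saving to a file avoids the interactive window that crashes QGIS on windows:

plot(
    labels_reshaped, pipe1, pipe2, info_list,
    distance_threshold=10.0,     # echoed in the figure title
    sieve=30,                    # hypothetical sieve size, also echoed in the title
    filename="diagnostics.png",  # save instead of opening a matplotlib window
)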

def sieve_filter(data, threshold=2, connectedness=4, feedback=None):
567    """Apply a sieve filter to the data to remove small clusters. The sieve filter is applied using the GDAL library. https://gdal.org/en/latest/programs/gdal_sieve.html#gdal-sieve
568    Args:
569        data (np.ndarray): The input data to filter
570        threshold (int): The maximum number of pixels in a cluster to keep
571        connectedness (int): The number of connected pixels to consider when filtering 4 or 8
572        feedback (QgsTaskFeedback): A feedback object to report progress to use inside QGIS plugins
573    Returns:
574        np.ndarray: The filtered data
575    """
576    logger.info("Applying sieve filter")
577
578    height, width = data.shape
579    # fprint("antes", np.sort(np.unique(data, return_counts=True)), len(np.unique(data)), level="info", feedback=feedback, logger=logger)
580    num_clusters = len(np.unique(data))
581    src_ds = gdal.GetDriverByName("MEM").Create("sieve", width, height, 1, GDT)
582    src_band = src_ds.GetRasterBand(1)
583    src_band.WriteArray(data)
584    if 0 != gdal.SieveFilter(src_band, None, src_band, threshold, connectedness):
585        fprint("Error applying sieve filter", level="error", feedback=feedback, logger=logger)
586    else:
587        sieved = src_band.ReadAsArray()
588        src_band = None
589        src_ds = None
590        num_sieved = len(np.unique(sieved))
591        # fprint("despues", np.sort(np.unique(sieved, return_counts=True)), len(np.unique(sieved)), level="info", feedback=feedback, logger=logger)
592        fprint(
593            f"Reduced from {num_clusters} to {num_sieved} clusters, {num_clusters-num_sieved} less",
594            level="info",
595            feedback=feedback,
596            logger=logger,
597        )
598        fprint(
599            "Please try again increasing distance_threshold or reducing n_clusters instead...",
600            level="info",
601            feedback=feedback,
602            logger=logger,
603        )
604        # from matplotlib import pyplot as plt
605        # fig, (ax1, ax2) = plt.subplots(1, 2)
606        # ax1.imshow(data)
607        # ax1.set_title("before sieve" + str(len(np.unique(data))))
608        # ax2.imshow(sieved)
609        # ax2.set_title("after sieve" + str(len(np.unique(sieved))))
610        # plt.show()
611        # data = sieved
612        return sieved

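A minimal sketch of sieving a label map on its own; the array and threshold are illustrative (threshold is in pixels):

import numpy as np
from fire2a.agglomerative_clustering import sieve_filter

label_map = np.random.randint(0, 50, size=(200, 200))  # stand-in label map
sieved = sieve_filter(label_map, threshold=30, connectedness=8)
print(f"{len(np.unique(label_map))} -> {len(np.unique(sieved))} clusters")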

def arg_parser(argv=None):
616    """Parse command line arguments."""
617    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
618
619    parser = ArgumentParser(
620        description="Agglomerative Clustering with Connectivity for raster data",
621        formatter_class=ArgumentDefaultsHelpFormatter,
622        epilog="More at https://fire2a.github.io/fire2a-lib",
623    )
624    parser.add_argument(
625        "config_file",
626        nargs="?",
627        type=Path,
628        help="For each raster file, configure its preprocess: nodata & scaling methods",
629        default="config.toml",
630    )
631
632    aggclu = parser.add_mutually_exclusive_group(required=True)
633    aggclu.add_argument(
634        "-d",
635        "--distance_threshold",
636        type=float,
637        help="Distance threshold (a good starting point when scaling is 10, higher means less clusters, 0 could take a long time)",
638    )
639    aggclu.add_argument("-n", "--n_clusters", type=int, help="Number of clusters")
640
641    parser.add_argument("-or", "--output_raster", help="Output raster file, warning overwrites!", default="")
642    parser.add_argument("-op", "--output_poly", help="Output polygons file, warning overwrites!", default="output.gpkg")
643    parser.add_argument("-a", "--authid", type=str, help="Output raster authid", default="EPSG:3857")
644    parser.add_argument(
645        "-g", "--geotransform", type=str, help="Output raster geotransform", default="(0, 1, 0, 0, 0, 1)"
646    )
647    parser.add_argument(
648        "-nw",
649        "--no_write",
650        action="store_true",
651        help="Do not write outputs raster nor polygons",
652        default=False,
653    )
654    parser.add_argument(
655        "-s",
656        "--script",
657        action="store_true",
658        help="Run in script mode, returning the label_map and the pipeline object",
659        default=False,
660    )
661    parser.add_argument(
662        "--sieve",
663        type=int,
664        help="Use GDAL sieve filter to merge small clusters (number of pixels) into the biggest neighbor",
665    )
666    parser.add_argument("--verbose", "-v", action="count", default=0, help="WARNING:1, INFO:2, DEBUG:3")
667
668    plot = parser.add_argument_group(
669        "Plotting, Visually inspect input distributions: NoData treated observations, rescaled data, with violing plots and boxplots. Also check output clustering size history and histograms"
670    )
671    plot.add_argument(
672        "-p",
673        "--plots",
674        action="store_true",
675        help="Activate the plotting routines",
676    )
677    plot.add_argument(
678        "-b",
679        "--block",
680        action="store_false",
681        default=True,
682        help="Block the execution until the plot window is closed. Use False for interactive ipykernels or QGIS",
683    )
684    plot.add_argument(
685        "-f",
686        "--filename",
687        type=str,
688        help="Filename to save the plot. If not provided, matplotlib will raise a window",
689    )
690    args = parser.parse_args(argv)
691    args.geotransform = tuple(map(float, args.geotransform[1:-1].split(",")))
692    if Path(args.config_file).is_file() is False:
693        parser.error(f"File {args.config_file} not found")
694    return args


def main(argv=None):
698    """
699
700    args = arg_parser(["-d","10.0", "-g","(0, 10, 0, 0, 0, 10)", "config2.toml"])
701    args = arg_parser(["-d","10.0"]])
702    args = arg_parser(["-d","10.0", "config2.toml"])
703    args = arg_parser(["-n","10"])
704    """
705    if argv is sys.argv:
706        argv = sys.argv[1:]
707    args = arg_parser(argv)
708
709    if args.verbose != 0:
710        global logger
711        from fire2a import setup_logger
712
713        logger = setup_logger(verbosity=args.verbose)
714
715    logger.info("args %s", args)
716
717    # 2 LEE CONFIG
718    config = read_toml(args.config_file)
719    # logger.debug(config)
720
721    # 2.1 ADD DEFAULTS
722    for filename, file_config in config.items():
723        if "no_data_strategy" not in file_config:
724            config[filename]["no_data_strategy"] = "mean"
725        if "scaling_strategy" not in file_config:
726            config[filename]["scaling_strategy"] = "robust"
727        if "fill_value" not in file_config:
728            config[filename]["fill_value"] = 0
729        if "weight" not in file_config:
730            config[filename]["weight"] = 1
731    logger.debug(config)
732
733    # 3. LEE DATA
734    from fire2a.raster import read_raster
735
736    data_list, info_list = [], []
737    for filename, file_config in config.items():
738        data, info = read_raster(filename)
739        info["fname"] = Path(filename).name
740        info["no_data_strategy"] = file_config["no_data_strategy"]
741        info["scaling_strategy"] = file_config["scaling_strategy"]
742        info["fill_value"] = file_config["fill_value"]
743        info["weight"] = file_config["weight"]
744        data_list += [data]
745        info_list += [info]
746        logger.debug("%s", data[:2, :2])
747        logger.debug("%s", info)
748
749    # 4. VALIDAR 2d todos mismo shape
750    height, width = check_shapes(data_list)
751
752    # 5. lista[mapas] -> OBSERVACIONES
753    observations = np.column_stack([data.ravel() for data in data_list])
754
755    # 6. nodata -> feature scaling -> all scaling -> clustering
756    labels_reshaped, pipe1, pipe2 = pipelie(
757        observations,
758        info_list,
759        height,
760        width,
761        n_clusters=args.n_clusters,
762        distance_threshold=args.distance_threshold,
763    )  # insert more keyworded arguments to the clustering algorithm here!
764
765    # SIEVE
766    if args.sieve:
767        logger.info(f"Number of clusters before sieving: {len(np.unique(labels_reshaped))}")
768        labels_reshaped = sieve_filter(labels_reshaped, args.sieve)
769
770    logger.info(f"Final number of clusters: {len(np.unique(labels_reshaped))}")
771
772    # 7 debbuging plots
773    if args.plots:
774        plot(labels_reshaped, pipe1, pipe2, info_list, **vars(args))
775
776    # 8. ESCRIBIR RASTER
777    if not args.no_write:
778        if not write(
779            labels_reshaped,
780            width,
781            height,
782            output_raster=args.output_raster,
783            output_poly=args.output_poly,
784            authid=args.authid,
785            geotransform=args.geotransform,
786        ):
787            logger.error("Error writing output raster")
788
789    # 9. SCRIPT MODE
790    if args.script:
791        return labels_reshaped, pipe1, pipe2
792
793    return 0

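Putting the pieces together, a sketch of the script mode enabled by -s/--script (it assumes config.toml and its rasters exist in the working directory; --no_write skips writing the raster and polygon outputs):

import numpy as np
import fire2a.agglomerative_clustering as agg

# roughly equivalent to: python -m fire2a.agglomerative_clustering -d 10.0 -s -nw
labels, pipe1, pipe2 = agg.main(["-d", "10.0", "--script", "--no_write"])
print(dict(zip(*np.unique(labels, return_counts=True))))  # pixels per cluster label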