fire2a.agglomerative_clustering

👋🌎 🌲🔥

Raster clustering

Usage

Overview

  1. Choose your raster files
  2. Configure nodata, scaling strategies and weights in the config.toml file
  3. Choose "distance threshold" (or "number of clusters") for the Agglomerative clustering algorithm. Recommended:
    • Start with a distance threshold of 10.0; decrease it for more clusters or increase it for fewer
    • After calibrating the distance threshold:
    • Sieve small clusters (merging them into their largest neighbor) with the --sieve integer_pixels_size option (see the calibration sketch below)
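
A minimal calibration sketch (hedged: it assumes the rasters and a config.toml are already in the working directory; -s returns the results instead of exiting, -nw skips writing outputs):

import numpy as np
from fire2a.agglomerative_clustering import main

for d in (20.0, 10.0, 5.0):
    label_map, pipe1, pipe2 = main(["-d", str(d), "-s", "-nw"])
    print(f"distance_threshold={d} -> {len(np.unique(label_map))} clusters")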

Execution

# get command line help
python -m fire2a.agglomerative_clustering -h
python -m fire2a.agglomerative_clustering --help

# activate your qgis dev environment
source ~/pyqgisdev/bin/activate 
# execute 
(qgis) $ python -m fire2a.agglomerative_clustering -d 10.0

# windows💩 users should use QGIS's python
C:\PROGRA~1\QGIS33~1.3\bin\python-qgis.bat -m fire2a.agglomerative_clustering -d 10.0

More info on: How to windows 💩 using qgis's python

Preparation

1. Choose your raster files

  • Any GDAL compatible raster will be read (see the quick readability check below)
  • Place them all in the same directory where the script will be executed
  • Quote their names in config.toml (e.g. "my file.tif") if they contain characters outside [a-zA-Z0-9]

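A quick readability check (a sketch; the *.tif glob is an assumption, adjust it to your file extensions):

from pathlib import Path
from osgeo import gdal

for path in sorted(Path(".").glob("*.tif")):
    ds = gdal.Open(str(path))
    print(path.name, "readable" if ds else "NOT readable")
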
2. Preprocessing configuration

See the config.toml file for an example configuration of the preprocessing steps. The file is structured as follows:

["filename.tif"]
no_data_strategy = "most_frequent"
scaling_strategy = "onehot"
fill_value = 0
weight = 1
  1. __scaling_strategy__

    • can be "standard", "robust", "onehot"
    • default is "robust"
    • Standard: (x-mean)/stddev
    • Robust: (x-median)/IQR, like Standard but robust to outliers in the tails of the distribution
    • OneHot: __for CATEGORICAL DATA__
  2. __no_data_strategy__

    • can be "mean", "median", "most_frequent", "constant"
    • default is "mean"
    • categorical data should use "most_frequent" or "constant"
    • "constant" will use the value in __fill_value__ (see below)
    • SimpleImputer
  3. __fill_value__

    • used when __no_data_strategy__ is "constant"
    • default is 0
    • SimpleImputer
  4. __weight__

    • default is 1
    • used to give more importance to some features than others
    • This weighting is applied after the nodata imputation and scaling steps, just before clustering (see the sketch below)
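
For intuition, a minimal sketch (hypothetical data and nodata value) of what each strategy corresponds to in scikit-learn; the module wires these together per column, so normally you only edit config.toml:

import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder, RobustScaler, StandardScaler

# a continuous column with -9999 as its nodata value (hypothetical)
col = np.array([[1.0], [2.0], [-9999.0], [4.0]])
imputed = SimpleImputer(missing_values=-9999.0, strategy="mean").fit_transform(col)
print(StandardScaler().fit_transform(imputed))  # scaling_strategy = "standard"
print(RobustScaler().fit_transform(imputed))    # scaling_strategy = "robust"

# a categorical column: "onehot" expands it into one 0/1 column per category
cats = np.array([[1], [3], [3], [7]])
print(OneHotEncoder(sparse_output=False).fit_transform(cats))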

3. Clustering configuration

  1. The __Agglomerative__ clustering algorithm is used. The following parameters are mutually exclusive:
  • -n or --n_clusters: The number of clusters to find.
  • -d or --distance_threshold: The linkage distance threshold above which clusters will not be merged. When scaling, start at 10.0 and work downward (0.0 could take a long time).
  • More parameters for clustering can be passed directly to the pipelie function as keyword arguments (see the sketch after this list)
  2. A __Sieve filter__ is optionally applied to remove small clusters, using the GDAL sieve library
  • --sieve: Use GDAL sieve filter to merge small clusters (number of pixels) into the biggest neighbor
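
A hedged sketch of forwarding an extra AgglomerativeClustering parameter (here linkage) through the pipelie function, using tiny synthetic stand-ins for the observations and per-raster info that main() normally builds from the rasters and config.toml:

import numpy as np
from fire2a.agglomerative_clustering import pipelie

height, width = 4, 4
rng = np.random.default_rng(0)
# two flattened "rasters": one continuous, one categorical (values 1 or 2)
observations = np.column_stack([rng.random(height * width), rng.integers(1, 3, height * width)])
info_list = [
    {"NoDataValue": None, "no_data_strategy": "mean", "fill_value": 0, "weight": 1, "scaling_strategy": "robust"},
    {"NoDataValue": None, "no_data_strategy": "most_frequent", "fill_value": 0, "weight": 1, "scaling_strategy": "onehot"},
]
labels, pipe1, pipe2 = pipelie(observations, info_list, height, width,
                               n_clusters=None, distance_threshold=1.0, linkage="ward")
print(labels.shape, len(np.unique(labels)))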

4. Post-processing

Outputs can be:

  • A raster file with the cluster labels and a polygon file with the cluster polygons
  • A polygon file with the cluster polygons, with an attribute holding the number of pixels in each cluster (see the reading sketch below)
  • A plot of the input data distributions, the rescaled data distributions, and the cluster size history and histogram (crashes QGIS in windows)
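
A sketch of reading the polygon output back with ogr to inspect each cluster's pixel count (assuming the default output.gpkg was written):

from osgeo import ogr

ds = ogr.Open("output.gpkg")
layer = ds.GetLayer("clusters")
for feature in layer:
    print(feature.GetField("DN"), feature.GetField("pixel_count"))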

Or use the --script option to return the label_map and the pipeline object for further processing in another python script:

from fire2a.agglomerative_clustering import main
label_map, pipe1, pipe2 = main(["-d", "10.0", "-s"])
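
For instance (same assumptions as above, plus -nw to avoid rewriting outputs), the returned objects can be inspected further; the fitted scikit-learn model lives inside pipe2:

import numpy as np
from fire2a.agglomerative_clustering import main

label_map, pipe1, pipe2 = main(["-d", "10.0", "-s", "-nw"])
ids, sizes = np.unique(label_map, return_counts=True)
print(f"{len(ids)} clusters, sizes from {sizes.min()} to {sizes.max()} pixels")
model = pipe2.named_steps["agglomerative_clustering"].model  # fitted sklearn AgglomerativeClustering
print(model.n_clusters_)
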
  1#!/usr/bin/env python3
  2# fmt: off
  3"""👋🌎 🌲🔥
  4# Raster clustering
  5## Usage
  6### Overview
  71. Choose your raster files
  82. Configure nodata, scaling strategies and weights in the `config.toml` file
  93. Choose "distance threshold" (or "number of clusters") for the [Agglomerative](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) clustering algorithm. Recommended:
 10   - Start with a distance threshold of 10.0; decrease it for more clusters or increase it for fewer
 11   - After calibrating the distance threshold:
 12   - [Sieve](https://gdal.org/en/latest/programs/gdal_sieve.html) small clusters (merging them into their largest neighbor) with the `--sieve integer_pixels_size` option
 13
 14### Execution
 15```bash
 16# get command line help
 17python -m fire2a.agglomerative_clustering -h
 18python -m fire2a.agglomerative_clustering --help
 19
 20# activate your qgis dev environment
 21source ~/pyqgisdev/bin/activate 
 22# execute 
 23(qgis) $ python -m fire2a.agglomerative_clustering -d 10.0
 24
 25# windows💩 users should use QGIS's python
 26C:\\PROGRA~1\\QGIS33~1.3\\bin\\python-qgis.bat -m fire2a.agglomerative_clustering -d 10.0
 27```
 28[More info on: How to windows 💩 using qgis's python](https://github.com/fire2a/fire2a-lib/tree/main/qgis-launchers)
 29
 30### Preparation
 31#### 1. Choose your raster files
 32- Any [GDAL compatible](https://gdal.org/en/latest/drivers/raster/index.html) raster will be read
 33- Place them all in the same directory where the script will be executed
 34- Quote their names in config.toml (e.g. "my file.tif") if they contain characters outside [a-zA-Z0-9]
 35
 36#### 2. Preprocessing configuration
 37See the `config.toml` file for an example configuration of the preprocessing steps. The file is structured as follows:
 38
 39```toml
 40["filename.tif"]
 41no_data_strategy = "most_frequent"
 42scaling_strategy = "onehot"
 43fill_value = 0
 44weight = 1
 45```
 46
 471. __scaling_strategy__
 48   - can be "standard", "robust", "onehot"
 49   - default is "robust"
 50   - [Standard](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html): (x-mean)/stddev
 51   - [Robust](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.RobustScaler.html): (x-median)/IQR, like Standard but robust to outliers in the tails of the distribution
 52   - [OneHot](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html): __for CATEGORICAL DATA__
 53
 542. __no_data_strategy__
 55   - can be "mean", "median", "most_frequent", "constant"
 56   - default is "mean"
 57   - categorical data should use "most_frequent" or "constant"
 58   - "constant" will use the value in __fill_value__ (see below)
 59   - [SimpleImputer](https://scikit-learn.org/stable/modules/generated/sklearn.impute.SimpleImputer.html)
 60
 613. __fill_value__
 62   - used when __no_data_strategy__ is "constant"
 63   - default is 0
 64   - [SimpleImputer](https://scikit-learn.org/stable/modules/generated/sklearn.impute.SimpleImputer.html)
 65
 664. __weight__
 67   - default is 1
 68   - used to give more importance to some features than others
 69   - This is done after the nodata imputation and scaling steps, before clustering
 70
 71
 72#### 3. Clustering configuration
 73
 741. The __Agglomerative__ clustering algorithm is used. The following parameters are mutually exclusive:
 75- `-n` or `--n_clusters`: The number of clusters to find.
 76- `-d` or `--distance_threshold`: The linkage distance threshold above which clusters will not be merged. When scaling, start at 10.0 and work downward (0.0 could take a long time).
 77- More [parameters](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) for clustering can be passed directly to the `pipelie` function as keyword arguments
 78
 792. A __Sieve filter__ is optionally applied to remove small clusters, using the [GDAL sieve library](https://gdal.org/en/latest/programs/gdal_sieve.html#gdal-sieve)
 80- `--sieve`: Use GDAL sieve filter to merge small clusters (number of pixels) into the biggest neighbor
 81
 82#### 4. Post-processing
 83Outputs can be:
 84- A raster file with the cluster labels and a polygon file with the cluster polygons
 85- A polygon file with the cluster polygons, with an attribute holding the number of pixels in each cluster
 86- A plot of the input data distributions, the rescaled data distributions, and the cluster size history and histogram (crashes QGIS in windows)
 87
 88Or use the `--script` option to return the label_map and the pipeline object for further processing in another python script:
 89    ```python
 90    from fire2a.agglomerative_clustering import main
 91    label_map, pipe1, pipe2 = main(["-d", "10.0", "-s"])
 92    ```
 93"""
 94# fmt: on
 95# from IPython.terminal.embed import InteractiveShellEmbed
 96# InteractiveShellEmbed()()
 97import logging
 98import sys
 99from pathlib import Path
100
101import numpy as np
102from sklearn.base import BaseEstimator, TransformerMixin
103from sklearn.cluster import AgglomerativeClustering
104from sklearn.compose import ColumnTransformer
105from sklearn.impute import SimpleImputer
106from sklearn.neighbors import radius_neighbors_graph
107from sklearn.pipeline import Pipeline
108from sklearn.preprocessing import OneHotEncoder, RobustScaler, StandardScaler
109
110from fire2a.utils import fprint, read_toml
111
112logger = logging.getLogger(__name__)
113
114
115def check_shapes(data_list):
116    """Check if all data arrays have the same shape and are 2D.
117    Returns the shape of the data arrays if they are all equal.
118    """
119    from functools import reduce
120
121    def equal_or_error(x, y):
122        """Check if x and y are equal, returns x if equal else raises a ValueError."""
123        if x == y:
124            return x
125        else:
126            raise ValueError("All data arrays must have the same shape")
127
128    shape = reduce(equal_or_error, (data.shape for data in data_list))
129    if len(shape) != 2:
130        raise ValueError("All data arrays must be 2D")
131    height, width = shape
132    return height, width
133
134
135def get_map_neighbors(height, width, num_neighbors=8):
136    """Get the neighbors of each cell in a 2D grid.
137    n_jobs=-1 uses all available cores.
138    """
139
140    grid_points = np.indices((height, width)).reshape(2, -1).T
141
142    nb4 = radius_neighbors_graph(grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1)
143    nb8 = radius_neighbors_graph(grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1)
144
145    # assert nb4.shape[0] == width * height
146    # assert nb8.shape[1] == width * height
147    # for n in range(width * height):
148    #     _, neighbors = np.nonzero(nb4[n])
149    #     assert 2<= len(neighbors) <= 4, f"{n=} {neighbors=}"
150    #     assert 3<= len(neighbors) <= 8, f"{n=} {neighbors=}"
151    return nb4, nb8
152
153
154class NoDataImputer(BaseEstimator, TransformerMixin):
155    """A custom Imputer that treats a specified nodata_value as np.nan and supports different strategies per column"""
156
157    def __init__(self, no_data_values, strategies, constants):
158        self.no_data_values = no_data_values
159        self.strategies = strategies
160        self.constants = constants
161        self.imputers = []
162        for no_data_value, strategy, constant in zip(no_data_values, strategies, constants):
163            if no_data_value:
164                self.imputers += [SimpleImputer(strategy=strategy, missing_values=no_data_value, fill_value=constant)]
165            else:
166                self.imputers += [SimpleImputer(strategy=strategy, fill_value=constant)]
167
168    def fit(self, X, y=None):
169        for i, imputer in enumerate(self.imputers):
170            imputer.fit(X[:, [i]], y)
171        return self
172
173    def transform(self, X):
174        for i, imputer in enumerate(self.imputers):
175            X[:, [i]] = imputer.transform(X[:, [i]])
176        self.output_data = X
177        return X
178
179
180class RescaleAllToCommonRange(BaseEstimator, TransformerMixin):
181    """A custom transformer that rescales all features to a common range [0, 1]"""
182
183    def __init__(self, weight_map):
184        self.weight_map = weight_map
185
186    def fit(self, X, y=None):
187        # Determine the combined range of all scaled features
188        self.min_val = [x.min() for x in X.T]
189        self.max_val = [x.max() for x in X.T]
190        return self
191
192    def transform(self, X):
193        # Rescale all features to match the common range
194        for i, (x, mi, ma) in enumerate(zip(X.T, self.min_val, self.max_val)):
195            if ma - mi == 0:
196                X.T[i] = x * self.weight_map[i]
197            else:
198                X.T[i] = (x - mi) / (ma - mi) * self.weight_map[i]
199        return X
200
201
202class CustomAgglomerativeClustering(BaseEstimator, TransformerMixin):
203    def __init__(self, height, width, neighbors=4, **kwargs):
204        self.height = height
205        self.width = width
206        self.neighbors = neighbors
207
208        self.grid_points = np.indices((height, width)).reshape(2, -1).T
209        if neighbors == 4:
210            connectivity = radius_neighbors_graph(
211                self.grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1
212            )
213        elif neighbors == 8:
214            connectivity = radius_neighbors_graph(
215                self.grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1
216            )
217
218        self.connectivity = connectivity
219        self.kwargs = kwargs
220        self.model = AgglomerativeClustering(connectivity=self.connectivity, **self.kwargs)
221
222    def fit(self, X, y=None):
223        logger.debug("not sure why, but this method is never called alas needed")
224        self.model.fit(X)
225        return self
226
227    def fit_predict(self, X, y=None):
228        self.input_data = X
229        return self.model.fit_predict(X)
230
231
232def pipelie(observations, info_list, height, width, **kwargs):
233    """A scikit-learn pipeline to achieve Agglomerative Clustering with connectivity on a 2D matrix
234    Steps are:
235    1. Impute missing values
236    2. Scale the features
237    3. Rescale all features to a common range
238    4. Cluster the data using Agglomerative Clustering with connectivity
239    5. Reshape the labels back to the original spatial map shape
240    6. Return the labels and the pipeline object
241
242    Args:
243        observations (np.ndarray): The input data to cluster (n_samples, n_features) shaped
244        info_list (list): A list of dictionaries containing information about each feature
245        height (int): The height of the spatial map
246        width (int): The width of the spatial map
247        kwargs: Additional keyword arguments for AgglomerativeClustering, at least one of n_clusters or distance_threshold
248
249    Returns:
250        np.ndarray: The labels of the clusters, reshaped to the original 2d spatial map shape
251        Pipeline: The pipeline object containing all the steps of the pipeline
252    """
253    # kwargs = {"n_clusters": args.n_clusters, "distance_threshold": args.distance_threshold}
254
255    # imputer strategies
256    no_data_values = [info["NoDataValue"] for info in info_list]
257    no_data_strategies = [info["no_data_strategy"] for info in info_list]
258    fill_values = [info["fill_value"] for info in info_list]
259    weights = [info["weight"] for info in info_list]
260    # scaling_strategies = [info["scaling_strategy"] for info in info_list]
261
262    # scaling strategies
263    index_map = {}
264    for strategy in ["robust", "standard", "onehot"]:
265        index_map[strategy] = [i for i, info in enumerate(info_list) if info["scaling_strategy"] == strategy]
266    # index_map
267    # !cat config.toml
268
269    # Create transformers for each type
270    robust_transformer = Pipeline(steps=[("robust_step", RobustScaler())])
271    standard_transformer = Pipeline(steps=[("standard_step", StandardScaler())])
272    onehot_transformer = Pipeline(steps=[("onehot_step", OneHotEncoder(sparse_output=False))])
273    # OneHotEncoder._n_features_outs):
274
275    # Combine transformers using ColumnTransformer
276    feature_scaler = ColumnTransformer(
277        transformers=[
278            ("robust", robust_transformer, index_map["robust"]),
279            ("standard", standard_transformer, index_map["standard"]),
280            ("onehot", onehot_transformer, index_map["onehot"]),
281        ]
282    )
283
284    # # Create a temporary directory for caching calculations
285    # # FOR ACCESING STEPS LATER ON VERY LARGE DATASETS
286    # import tempfile
287    # import joblib
288    # temp_dir = tempfile.mkdtemp()
289    # memory = joblib.Memory(location=temp_dir, verbose=0)
290
291    # Create and apply the pipeline
292    # part 1 until feature scaling
293    pipe1 = Pipeline(
294        # n_features_in_ : int
295        # feature_names_in_ : ndarray of shape (`n_features_in_`,)
296        steps=[
297            ("no_data_imputer", NoDataImputer(no_data_values, no_data_strategies, fill_values)),
298            ("feature_scaling", feature_scaler),
299        ],
300        # memory=memory,
301        verbose=True,
302    )
303    # map weights to new columns (onehot feature scaler creates one column per category)
304    obs1 = pipe1.fit_transform(observations)
305    cat_names = pipe1.named_steps["feature_scaling"]["onehot"].get_feature_names_out()
306    split_names = [name.split("_")[0] for name in cat_names]
307    cat_count = np.unique(split_names, return_counts=True)[1]
308    onehot_map = {}
309    for i, key in enumerate(index_map["onehot"]):
310        onehot_map[key] = cat_count[i]
311    # onehot_map = {key: cat_count[i] for i, key in enumerate(index_map["onehot"])}
312    weight_map = []
313    for name, idxs in index_map.items():
314        for idx in idxs:
315            if name == "onehot":
316                weight_map += [weights[idx]] * onehot_map[idx]
317                continue
318            weight_map += [weights[idx]]
319    # part 2 use weight_map and cluster
320    pipe2 = Pipeline(
321        steps=[
322            ("common_rescaling", RescaleAllToCommonRange(weight_map)),
323            ("agglomerative_clustering", CustomAgglomerativeClustering(height, width, neighbors=4, **kwargs)),
324        ],
325        # memory=memory,
326        verbose=True,
327    )
328
329    # apply pipeLIE
330    labels = pipe2.fit_predict(obs1)
331
332    # Reshape the labels back to the original spatial map shape
333    labels_reshaped = labels.reshape(height, width)
334    return labels_reshaped, pipe1, pipe2
335
336
337def write(
338    label_map,
339    width,
340    height,
341    output_raster="",
342    output_poly="output.shp",
343    authid="EPSG:3857",
344    geotransform=(0, 1, 0, 0, 0, 1),
345    nodata=None,
346    feedback=None,
347):
348    from osgeo import gdal, ogr, osr
349
350    from fire2a.processing_utils import get_output_raster_format, get_vector_driver_from_filename
351
352    # setup drivers for raster and polygon output formats
353    if output_raster == "":
354        raster_driver = "MEM"
355    else:
356        try:
357            raster_driver = get_output_raster_format(output_raster, feedback=feedback)
358        except Exception:
359            raster_driver = "GTiff"
360    try:
361        poly_driver = get_vector_driver_from_filename(output_poly)
362    except Exception:
363        poly_driver = "ESRI Shapefile"
364
365    # create raster output
366    src_ds = gdal.GetDriverByName(raster_driver).Create(output_raster, width, height, 1, gdal.GDT_Int64)
367    src_ds.SetGeoTransform(geotransform)  # != 0 ?
368    src_ds.SetProjection(authid)  # != 0 ?
369    #  src_band = src_ds.GetRasterBand(1)
370    #  if nodata:
371    #      src_band.SetNoDataValue(nodata)
372    #  src_band.WriteArray(label_map)
373
374    # create polygon output
375    drv = ogr.GetDriverByName(poly_driver)
376    dst_ds = drv.CreateDataSource(output_poly)
377    sp_ref = osr.SpatialReference()
378    sp_ref.SetFromUserInput(authid)  # != 0 ?
379    dst_lyr = dst_ds.CreateLayer("clusters", srs=sp_ref, geom_type=ogr.wkbPolygon)
380    dst_lyr.CreateField(ogr.FieldDefn("DN", ogr.OFTInteger64))  # != 0 ?
381    dst_lyr.CreateField(ogr.FieldDefn("pixel_count", ogr.OFTInteger64))
382    # dst_lyr.CreateField(ogr.FieldDefn("area", ogr.OFTInteger))
383    # dst_lyr.CreateField(ogr.FieldDefn("perimeter", ogr.OFTInteger))
384
385    # 0 != gdal.Polygonize( srcband, maskband, dst_layer, dst_field, options, callback = gdal.TermProgress)
386
387    # FAIL: All together it merges labels into a single polygon
388    #  src_band = src_ds.GetRasterBand(1)
389    #  if nodata:
390    #      src_band.SetNoDataValue(nodata)
391    #  src_band.WriteArray(label_map)
392    # gdal.Polygonize(src_band, None, dst_lyr, 0, callback=gdal.TermProgress)  # , ["8CONNECTED=8"])
393
394    # approach B: polygonize each label separately
395    # for loop for creating each label_map value into a different polygonized feature
396    mem_drv = ogr.GetDriverByName("Memory")
397    tmp_ds = mem_drv.CreateDataSource("tmp_ds")
398    # itera = iter(np.unique(label_map))
399    # cluster_id = next(itera)
400    areas = []
401    pixels = []
402    data = np.zeros_like(label_map)
403    for cluster_id, px_count in zip(*np.unique(label_map, return_counts=True)):
404        # temporarily write band
405        src_band = src_ds.GetRasterBand(1)
406        src_band.SetNoDataValue(0)
407        data[label_map == cluster_id] = 1
408        src_band.WriteArray(data)
409        # create feature
410        tmp_lyr = tmp_ds.CreateLayer("", srs=sp_ref)
411        gdal.Polygonize(src_band, src_band.GetMaskBand(), tmp_lyr, -1)
412        # unset tmp data
413        data[label_map == cluster_id] = 0
414        # set polygon feat
415        feat = tmp_lyr.GetNextFeature()
416        geom = feat.GetGeometryRef()
417        featureDefn = dst_lyr.GetLayerDefn()
418        feature = ogr.Feature(featureDefn)
419        feature.SetGeometry(geom)
420        feature.SetField("DN", float(cluster_id))
421        areas += [geom.GetArea()]
422        pixels += [px_count]
423        feature.SetField("pixel_count", float(px_count))
424        # feature.SetField("area", int(geom.GetArea()))
425        # feature.SetField("perimeter", int(geom.Boundary().Length()))
426        dst_lyr.CreateFeature(feature)
427
428    fprint(f"Polygon Areas: {min(areas)=} {max(areas)=}", level="info", feedback=feedback, logger=logger)
429    fprint(f"Cluster PixelCounts: {min(pixels)=} {max(pixels)=}", level="info", feedback=feedback, logger=logger)
430    # RESTART RASTER
431    # src_ds = None
432    # src_band = None
433    # src_ds = gdal.GetDriverByName(raster_driver).Create(output_raster, width, height, 1, gdal.GDT_Int64)
434    # src_ds.SetGeoTransform(geotransform)  # != 0 ?
435    # src_ds.SetProjection(authid)  # != 0 ?
436    src_band = src_ds.GetRasterBand(1)
437    if nodata:
438        src_band.SetNoDataValue(nodata)
439    else:
440        # useless paranoia ?
441        src_band.SetNoDataValue(-1)
442    src_band.WriteArray(label_map)
443    # close datasets
444    src_ds.FlushCache()
445    src_ds = None
446    dst_ds.FlushCache()
447    dst_ds = None
448    return True
449
450
451def plot(labels_reshaped, pipe1, pipe2, info_list, **kwargs):
452    """Plot the observed values of the input data, the rescaled data, and the cluster size history and histogram.
453    Args:
454        labels_reshaped (np.ndarray): The reshaped labels of the clusters
455        pipe1 (Pipeline): The first pipeline object containing imputer and feature scaling steps
456        pipe2 (Pipeline): The second pipeline object containing the rescaling and clustering steps
457        info_list (list): A list of dictionaries containing information about each feature
458        **kwargs: Additional keyword arguments
459            n_clusters (int): The number of clusters
460            distance_threshold (float): The linkage distance threshold
461            sieve (int): The number of pixels to use as a sieve filter
462            block (bool): Block the execution until the plot window is closed
463            filename (str): The filename to save the plot
464    """
465    from matplotlib import pyplot as plt
466
467    no_data_imputed = pipe1.named_steps["no_data_imputer"].output_data
468    pre_aggclu_data = pipe2.named_steps["agglomerative_clustering"].input_data
469
470    # filter out the onehot columns
471    num_onehots = sum([1 for i in info_list if i["scaling_strategy"] == "onehot"])
472    num_no_onehots = len(info_list) - num_onehots
473    pre_aggclu_data = pre_aggclu_data[:, :num_no_onehots]
474
475    # indices excluding onehots
476    nohots_idxs = [i for i, info in enumerate(info_list) if info["scaling_strategy"] != "onehot"]
477
478    # filter onehot columns out of the nodata-treated data
479    no_data_imputed = no_data_imputed[:, nohots_idxs]
480
481    # reordered: robust first, then standard
482    rob_std_idxs = [i for i, j in enumerate(nohots_idxs) if info_list[j]["scaling_strategy"] == "robust"]
483    rob_std_idxs += [i for i, j in enumerate(nohots_idxs) if info_list[j]["scaling_strategy"] == "standard"]
484
485    # reorder: robust then standard
486    pre_aggclu_data = pre_aggclu_data[:, rob_std_idxs]
487
488    names = [info_list[i]["fname"] for i in nohots_idxs]
489
490    fgs = np.array(plt.rcParams["figure.figsize"]) * 5
491    fig, axs = plt.subplots(3, 2, figsize=fgs)
492    suptitle = ""
493    if n_clusters := kwargs.get("n_clusters"):
494        suptitle = f"n_clusters: {n_clusters}"
495    if distance_threshold := kwargs.get("distance_threshold"):
496        suptitle = f"distance_threshold: {distance_threshold}"
497    if sieve := kwargs.get("sieve"):
498        suptitle += f", sieve: {sieve}"
499    if n_clusters or distance_threshold or sieve:
500        suptitle += f", resulting clusters: {len(np.unique(labels_reshaped))}"
501    suptitle += "\n(Not showing categorical data)"
502    fig.suptitle(suptitle)
503
504    # plot violin plot
505    axs[0, 0].violinplot(no_data_imputed, showmeans=False, showmedians=True, showextrema=True)
506    axs[0, 0].set_title("Violin Plot of NoData Imputed")
507    axs[0, 0].yaxis.grid(True)
508    axs[0, 0].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
509    axs[0, 0].set_ylabel("Observed values")
510
511    # plot boxplot
512    axs[0, 1].boxplot(no_data_imputed)
513    axs[0, 1].set_title("Box Plot of NoData Imputed")
514    axs[0, 1].yaxis.grid(True)
515    axs[0, 1].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
516    axs[0, 1].set_ylabel("Observed values")
517
518    # plot violin plot
519    axs[1, 0].violinplot(pre_aggclu_data, showmeans=False, showmedians=True, showextrema=True)
520    axs[1, 0].set_title("Violin Plot of Common Rescaled")
521    axs[1, 0].yaxis.grid(True)
522    axs[1, 0].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
523    axs[1, 0].set_ylabel("Adjusted range")
524
525    # plot boxplot
526    axs[1, 1].boxplot(pre_aggclu_data)
527    axs[1, 1].set_title("Box Plot of Common Rescaled")
528    axs[1, 1].yaxis.grid(True)
529    axs[1, 1].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
530    axs[1, 1].set_ylabel("Adjusted range")
531
532    # cluster history
533    unique_labels, counts = np.unique(labels_reshaped, return_counts=True)
534    axs[2, 0].plot(unique_labels, counts, marker="o", color="blue")
535    axs[2, 0].set_title("Cluster history size (in pixels)")
536    axs[2, 0].set_xlabel("Algorithm Step")
537    axs[2, 0].set_ylabel("Size (in pixels)")
538
539    # cluster histogram
540    axs[2, 1].hist(counts, log=True)
541    axs[2, 1].set_xlabel("Cluster Size (in pixels)")
542    axs[2, 1].set_ylabel("Number of Clusters")
543    axs[2, 1].set_title("Histogram of Cluster Sizes")
544
545    plt.tight_layout()
546    if filename := kwargs.get("filename"):
547        logger.info(f"Saving plot to {filename}")
548        plt.savefig(filename)
549    else:
550        if block := kwargs.get("block"):
551            plt.show(block=block)
552        else:
553            plt.show()
554
555
556def sieve_filter(data, threshold=2, connectedness=4, feedback=None):
557    """Apply a sieve filter to the data to remove small clusters. The sieve filter is applied using the GDAL library. https://gdal.org/en/latest/programs/gdal_sieve.html#gdal-sieve
558    Args:
559        data (np.ndarray): The input data to filter
560        threshold (int): The maximum number of pixels in a cluster to keep
561        connectedness (int): The number of connected pixels to consider when filtering 4 or 8
562        feedback (QgsTaskFeedback): A feedback object to report progress to use inside QGIS plugins
563    Returns:
564        np.ndarray: The filtered data
565    """
566    logger.info("Applying sieve filter")
567    from osgeo import gdal
568
569    height, width = data.shape
570    # fprint("antes", np.sort(np.unique(data, return_counts=True)), len(np.unique(data)), level="info", feedback=feedback, logger=logger)
571    num_clusters = len(np.unique(data))
572    src_ds = gdal.GetDriverByName("MEM").Create("sieve", width, height, 1, gdal.GDT_Int64)
573    src_band = src_ds.GetRasterBand(1)
574    src_band.WriteArray(data)
575    if 0 != gdal.SieveFilter(src_band, None, src_band, threshold, connectedness):
576        fprint("Error applying sieve filter", level="error", feedback=feedback, logger=logger)
577    else:
578        sieved = src_band.ReadAsArray()
579        src_band = None
580        src_ds = None
581        num_sieved = len(np.unique(sieved))
582        # fprint("despues", np.sort(np.unique(sieved, return_counts=True)), len(np.unique(sieved)), level="info", feedback=feedback, logger=logger)
583        fprint(
584            f"Reduced from {num_clusters} to {num_sieved} clusters, {num_clusters-num_sieved} fewer",
585            level="info",
586            feedback=feedback,
587            logger=logger,
588        )
589        fprint(
590            "Please try again increasing distance_threshold or reducing n_clusters instead...",
591            level="info",
592            feedback=feedback,
593            logger=logger,
594        )
595        # from matplotlib import pyplot as plt
596        # fig, (ax1, ax2) = plt.subplots(1, 2)
597        # ax1.imshow(data)
598        # ax1.set_title("before sieve" + str(len(np.unique(data))))
599        # ax2.imshow(sieved)
600        # ax2.set_title("after sieve" + str(len(np.unique(sieved))))
601        # plt.show()
602        # data = sieved
603        return sieved
604
605
606def arg_parser(argv=None):
607    """Parse command line arguments."""
608    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
609
610    parser = ArgumentParser(
611        description="Agglomerative Clustering with Connectivity for raster data",
612        formatter_class=ArgumentDefaultsHelpFormatter,
613        epilog="More at https://fire2a.github.io/fire2a-lib",
614    )
615    parser.add_argument(
616        "config_file",
617        nargs="?",
618        type=Path,
619        help="For each raster file, configure its preprocess: nodata & scaling methods",
620        default="config.toml",
621    )
622
623    aggclu = parser.add_mutually_exclusive_group(required=True)
624    aggclu.add_argument(
625        "-d",
626        "--distance_threshold",
627        type=float,
628        help="Distance threshold (a good starting point when scaling is 10, higher means fewer clusters, 0 could take a long time)",
629    )
630    aggclu.add_argument("-n", "--n_clusters", type=int, help="Number of clusters")
631
632    parser.add_argument("-or", "--output_raster", help="Output raster file, warning overwrites!", default="")
633    parser.add_argument("-op", "--output_poly", help="Output polygons file, warning overwrites!", default="output.gpkg")
634    parser.add_argument("-a", "--authid", type=str, help="Output raster authid", default="EPSG:3857")
635    parser.add_argument(
636        "-g", "--geotransform", type=str, help="Output raster geotransform", default="(0, 1, 0, 0, 0, 1)"
637    )
638    parser.add_argument(
639        "-nw",
640        "--no_write",
641        action="store_true",
642        help="Do not write outputs raster nor polygons",
643        default=False,
644    )
645    parser.add_argument(
646        "-s",
647        "--script",
648        action="store_true",
649        help="Run in script mode, returning the label_map and the pipeline object",
650        default=False,
651    )
652    parser.add_argument(
653        "--sieve",
654        type=int,
655        help="Use GDAL sieve filter to merge small clusters (number of pixels) into the biggest neighbor",
656    )
657    parser.add_argument("--verbose", "-v", action="count", default=0, help="WARNING:1, INFO:2, DEBUG:3")
658
659    plot = parser.add_argument_group(
660        "Plotting. Visually inspect input distributions: NoData-treated observations and rescaled data, with violin plots and boxplots. Also check output cluster size history and histograms"
661    )
662    plot.add_argument(
663        "-p",
664        "--plots",
665        action="store_true",
666        help="Activate the plotting routines",
667    )
668    plot.add_argument(
669        "-b",
670        "--block",
671        action="store_false",
672        default=True,
673        help="Do not block the execution until the plot window is closed (blocking is the default). Use for interactive ipykernels or QGIS",
674    )
675    plot.add_argument(
676        "-f",
677        "--filename",
678        type=str,
679        help="Filename to save the plot. If not provided, matplotlib will raise a window",
680    )
681    args = parser.parse_args(argv)
682    args.geotransform = tuple(map(float, args.geotransform[1:-1].split(",")))
683    if Path(args.config_file).is_file() is False:
684        parser.error(f"File {args.config_file} not found")
685    return args
686
687
688def main(argv=None):
689    """
690
691    args = arg_parser(["-d","10.0", "-g","(0, 10, 0, 0, 0, 10)", "config2.toml"])
692    args = arg_parser(["-d","10.0"])
693    args = arg_parser(["-d","10.0", "config2.toml"])
694    args = arg_parser(["-n","10"])
695    """
696    if argv is sys.argv:
697        argv = sys.argv[1:]
698    args = arg_parser(argv)
699
700    if args.verbose != 0:
701        global logger
702        from fire2a import setup_logger
703
704        logger = setup_logger(verbosity=args.verbose)
705
706    logger.info("args %s", args)
707
708    # 2. READ CONFIG
709    config = read_toml(args.config_file)
710    # logger.debug(config)
711
712    # 2.1 ADD DEFAULTS
713    for filename, file_config in config.items():
714        if "no_data_strategy" not in file_config:
715            config[filename]["no_data_strategy"] = "mean"
716        if "scaling_strategy" not in file_config:
717            config[filename]["scaling_strategy"] = "robust"
718        if "fill_value" not in file_config:
719            config[filename]["fill_value"] = 0
720        if "weight" not in file_config:
721            config[filename]["weight"] = 1
722    logger.debug(config)
723
724    # 3. READ DATA
725    from fire2a.raster import read_raster
726
727    data_list, info_list = [], []
728    for filename, file_config in config.items():
729        data, info = read_raster(filename)
730        info["fname"] = Path(filename).name
731        info["no_data_strategy"] = file_config["no_data_strategy"]
732        info["scaling_strategy"] = file_config["scaling_strategy"]
733        info["fill_value"] = file_config["fill_value"]
734        info["weight"] = file_config["weight"]
735        data_list += [data]
736        info_list += [info]
737        logger.debug("%s", data[:2, :2])
738        logger.debug("%s", info)
739
740    # 4. VALIDATE all arrays are 2D with the same shape
741    height, width = check_shapes(data_list)
742
743    # 5. list[maps] -> OBSERVATIONS
744    observations = np.column_stack([data.ravel() for data in data_list])
745
746    # 6. nodata -> feature scaling -> all scaling -> clustering
747    labels_reshaped, pipe1, pipe2 = pipelie(
748        observations,
749        info_list,
750        height,
751        width,
752        n_clusters=args.n_clusters,
753        distance_threshold=args.distance_threshold,
754    )  # insert more keyworded arguments to the clustering algorithm here!
755
756    # SIEVE
757    if args.sieve:
758        logger.info(f"Number of clusters before sieving: {len(np.unique(labels_reshaped))}")
759        labels_reshaped = sieve_filter(labels_reshaped, args.sieve)
760
761    logger.info(f"Final number of clusters: {len(np.unique(labels_reshaped))}")
762
763    # 7. debugging plots
764    if args.plots:
765        plot(labels_reshaped, pipe1, pipe2, info_list, **vars(args))
766
767    # 8. WRITE RASTER
768    if not args.no_write:
769        if not write(
770            labels_reshaped,
771            width,
772            height,
773            output_raster=args.output_raster,
774            output_poly=args.output_poly,
775            authid=args.authid,
776            geotransform=args.geotransform,
777        ):
778            logger.error("Error writing output raster")
779
780    # 9. SCRIPT MODE
781    if args.script:
782        return labels_reshaped, pipe1, pipe2
783
784    return 0
785
786
787if __name__ == "__main__":
788    sys.exit(main(sys.argv))
logger = <Logger fire2a.agglomerative_clustering (WARNING)>
def check_shapes(data_list):
116def check_shapes(data_list):
117    """Check if all data arrays have the same shape and are 2D.
118    Returns the shape of the data arrays if they are all equal.
119    """
120    from functools import reduce
121
122    def equal_or_error(x, y):
123        """Check if x and y are equal, returns x if equal else raises a ValueError."""
124        if x == y:
125            return x
126        else:
127            raise ValueError("All data arrays must have the same shape")
128
129    shape = reduce(equal_or_error, (data.shape for data in data_list))
130    if len(shape) != 2:
131        raise ValueError("All data arrays must be 2D")
132    height, width = shape
133    return height, width

Check if all data arrays have the same shape and are 2D. Returns the shape of the data arrays if they are all equal.

def get_map_neighbors(height, width, num_neighbors=8):
136def get_map_neighbors(height, width, num_neighbors=8):
137    """Get the neighbors of each cell in a 2D grid.
138    n_jobs=-1 uses all available cores.
139    """
140
141    grid_points = np.indices((height, width)).reshape(2, -1).T
142
143    nb4 = radius_neighbors_graph(grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1)
144    nb8 = radius_neighbors_graph(grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1)
145
146    # assert nb4.shape[0] == width * height
147    # assert nb8.shape[1] == width * height
148    # for n in range(width * height):
149    #     _, neighbors = np.nonzero(nb4[n])
150    #     assert 2<= len(neighbors) <= 4, f"{n=} {neighbors=}"
151    #     assert 3<= len(neighbors) <= 8, f"{n=} {neighbors=}"
152    return nb4, nb8

Get the neighbors of each cell in a 2D grid. n_jobs=-1 uses all available cores.

class NoDataImputer(sklearn.base.BaseEstimator, sklearn.base.TransformerMixin):
155class NoDataImputer(BaseEstimator, TransformerMixin):
156    """A custom Imputer that treats a specified nodata_value as np.nan and supports different strategies per column"""
157
158    def __init__(self, no_data_values, strategies, constants):
159        self.no_data_values = no_data_values
160        self.strategies = strategies
161        self.constants = constants
162        self.imputers = []
163        for no_data_value, strategy, constant in zip(no_data_values, strategies, constants):
164            if no_data_value:
165                self.imputers += [SimpleImputer(strategy=strategy, missing_values=no_data_value, fill_value=constant)]
166            else:
167                self.imputers += [SimpleImputer(strategy=strategy, fill_value=constant)]
168
169    def fit(self, X, y=None):
170        for i, imputer in enumerate(self.imputers):
171            imputer.fit(X[:, [i]], y)
172        return self
173
174    def transform(self, X):
175        for i, imputer in enumerate(self.imputers):
176            X[:, [i]] = imputer.transform(X[:, [i]])
177        self.output_data = X
178        return X

A custom Imputer that treats a specified nodata_value as np.nan and supports different strategies per column

NoDataImputer(no_data_values, strategies, constants)
158    def __init__(self, no_data_values, strategies, constants):
159        self.no_data_values = no_data_values
160        self.strategies = strategies
161        self.constants = constants
162        self.imputers = []
163        for no_data_value, strategy, constant in zip(no_data_values, strategies, constants):
164            if no_data_value:
165                self.imputers += [SimpleImputer(strategy=strategy, missing_values=no_data_value, fill_value=constant)]
166            else:
167                self.imputers += [SimpleImputer(strategy=strategy, fill_value=constant)]
no_data_values
strategies
constants
imputers
def fit(self, X, y=None):
169    def fit(self, X, y=None):
170        for i, imputer in enumerate(self.imputers):
171            imputer.fit(X[:, [i]], y)
172        return self
def transform(self, X):
174    def transform(self, X):
175        for i, imputer in enumerate(self.imputers):
176            X[:, [i]] = imputer.transform(X[:, [i]])
177        self.output_data = X
178        return X
class RescaleAllToCommonRange(sklearn.base.BaseEstimator, sklearn.base.TransformerMixin):
181class RescaleAllToCommonRange(BaseEstimator, TransformerMixin):
182    """A custom transformer that rescales all features to a common range [0, 1]"""
183
184    def __init__(self, weight_map):
185        self.weight_map = weight_map
186
187    def fit(self, X, y=None):
188        # Determine the combined range of all scaled features
189        self.min_val = [x.min() for x in X.T]
190        self.max_val = [x.max() for x in X.T]
191        return self
192
193    def transform(self, X):
194        # Rescale all features to match the common range
195        for i, (x, mi, ma) in enumerate(zip(X.T, self.min_val, self.max_val)):
196            if ma - mi == 0:
197                X.T[i] = x * self.weight_map[i]
198            else:
199                X.T[i] = (x - mi) / (ma - mi) * self.weight_map[i]
200        return X

A custom transformer that rescales all features to a common range [0, 1]

RescaleAllToCommonRange(weight_map)
184    def __init__(self, weight_map):
185        self.weight_map = weight_map
weight_map
def fit(self, X, y=None):
187    def fit(self, X, y=None):
188        # Determine the combined range of all scaled features
189        self.min_val = [x.min() for x in X.T]
190        self.max_val = [x.max() for x in X.T]
191        return self
def transform(self, X):
193    def transform(self, X):
194        # Rescale all features to match the common range
195        for i, (x, mi, ma) in enumerate(zip(X.T, self.min_val, self.max_val)):
196            if ma - mi == 0:
197                X.T[i] = x * self.weight_map[i]
198            else:
199                X.T[i] = (x - mi) / (ma - mi) * self.weight_map[i]
200        return X
class CustomAgglomerativeClustering(sklearn.base.BaseEstimator, sklearn.base.TransformerMixin):
203class CustomAgglomerativeClustering(BaseEstimator, TransformerMixin):
204    def __init__(self, height, width, neighbors=4, **kwargs):
205        self.height = height
206        self.width = width
207        self.neighbors = neighbors
208
209        self.grid_points = np.indices((height, width)).reshape(2, -1).T
210        if neighbors == 4:
211            connectivity = radius_neighbors_graph(
212                self.grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1
213            )
214        elif neighbors == 8:
215            connectivity = radius_neighbors_graph(
216                self.grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1
217            )
218
219        self.connectivity = connectivity
220        self.kwargs = kwargs
221        self.model = AgglomerativeClustering(connectivity=self.connectivity, **self.kwargs)
222
223    def fit(self, X, y=None):
224        logger.debug("not sure why, but this method is never called alas needed")
225        self.model.fit(X)
226        return self
227
228    def fit_predict(self, X, y=None):
229        self.input_data = X
230        return self.model.fit_predict(X)

Base class for all estimators in scikit-learn.

Inheriting from this class provides default implementations of:

  • setting and getting parameters used by GridSearchCV and friends;
  • textual and HTML representation displayed in terminals and IDEs;
  • estimator serialization;
  • parameters validation;
  • data validation;
  • feature names validation.

Read more in the :ref:User Guide <rolling_your_own_estimator>.

Notes

All estimators should specify all the parameters that can be set at the class level in their __init__ as explicit keyword arguments (no *args or **kwargs).

Examples

>>> import numpy as np
>>> from sklearn.base import BaseEstimator
>>> class MyEstimator(BaseEstimator):
...     def __init__(self, *, param=1):
...         self.param = param
...     def fit(self, X, y=None):
...         self.is_fitted_ = True
...         return self
...     def predict(self, X):
...         return np.full(shape=X.shape[0], fill_value=self.param)
>>> estimator = MyEstimator(param=2)
>>> estimator.get_params()
{'param': 2}
>>> X = np.array([[1, 2], [2, 3], [3, 4]])
>>> y = np.array([1, 0, 1])
>>> estimator.fit(X, y).predict(X)
array([2, 2, 2])
>>> estimator.set_params(param=3).fit(X, y).predict(X)
array([3, 3, 3])
CustomAgglomerativeClustering(height, width, neighbors=4, **kwargs)
204    def __init__(self, height, width, neighbors=4, **kwargs):
205        self.height = height
206        self.width = width
207        self.neighbors = neighbors
208
209        self.grid_points = np.indices((height, width)).reshape(2, -1).T
210        if neighbors == 4:
211            connectivity = radius_neighbors_graph(
212                self.grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1
213            )
214        elif neighbors == 8:
215            connectivity = radius_neighbors_graph(
216                self.grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1
217            )
218
219        self.connectivity = connectivity
220        self.kwargs = kwargs
221        self.model = AgglomerativeClustering(connectivity=self.connectivity, **self.kwargs)
height
width
neighbors
grid_points
connectivity
kwargs
model
def fit(self, X, y=None):
223    def fit(self, X, y=None):
224        logger.debug("not sure why, but this method is never called alas needed")
225        self.model.fit(X)
226        return self
def fit_predict(self, X, y=None):
228    def fit_predict(self, X, y=None):
229        self.input_data = X
230        return self.model.fit_predict(X)
def pipelie(observations, info_list, height, width, **kwargs):
233def pipelie(observations, info_list, height, width, **kwargs):
234    """A scipy pipeline to achieve Agglomerative Clustering with connectivity on 2d matrix
235    Steps are:
236    1. Impute missing values
237    2. Scale the features
238    3. Rescale all features to a common range
239    4. Cluster the data using Agglomerative Clustering with connectivity
240    5. Reshape the labels back to the original spatial map shape
241    6. Return the labels and the pipeline object
242
243    Args:
244        observations (np.ndarray): The input data to cluster (n_samples, n_features) shaped
245        info_list (list): A list of dictionaries containing information about each feature
246        height (int): The height of the spatial map
247        width (int): The width of the spatial map
248        kwargs: Additional keyword arguments for AgglomerativeClustering, at least one of n_clusters or distance_threshold
249
250    Returns:
251        np.ndarray: The labels of the clusters, reshaped to the original 2d spatial map shape
252        Pipeline: The pipeline object containing all the steps of the pipeline
253    """
254    # kwargs = {"n_clusters": args.n_clusters, "distance_threshold": args.distance_threshold}
255
256    # imputer strategies
257    no_data_values = [info["NoDataValue"] for info in info_list]
258    no_data_strategies = [info["no_data_strategy"] for info in info_list]
259    fill_values = [info["fill_value"] for info in info_list]
260    weights = [info["weight"] for info in info_list]
261    # scaling_strategies = [info["scaling_strategy"] for info in info_list]
262
263    # scaling strategies
264    index_map = {}
265    for strategy in ["robust", "standard", "onehot"]:
266        index_map[strategy] = [i for i, info in enumerate(info_list) if info["scaling_strategy"] == strategy]
267    # index_map
268    # !cat config.toml
269
270    # Create transformers for each type
271    robust_transformer = Pipeline(steps=[("robust_step", RobustScaler())])
272    standard_transformer = Pipeline(steps=[("standard_step", StandardScaler())])
273    onehot_transformer = Pipeline(steps=[("onehot_step", OneHotEncoder(sparse_output=False))])
274    # OneHotEncoder._n_features_outs):
275
276    # Combine transformers using ColumnTransformer
277    feature_scaler = ColumnTransformer(
278        transformers=[
279            ("robust", robust_transformer, index_map["robust"]),
280            ("standard", standard_transformer, index_map["standard"]),
281            ("onehot", onehot_transformer, index_map["onehot"]),
282        ]
283    )
284
285    # # Create a temporary directory for caching calculations
286    # # FOR ACCESING STEPS LATER ON VERY LARGE DATASETS
287    # import tempfile
288    # import joblib
289    # temp_dir = tempfile.mkdtemp()
290    # memory = joblib.Memory(location=temp_dir, verbose=0)
291
292    # Create and apply the pipeline
293    # part 1 until feature scaling
294    pipe1 = Pipeline(
295        # n_features_in_ : int
296        # feature_names_in_ : ndarray of shape (`n_features_in_`,)
297        steps=[
298            ("no_data_imputer", NoDataImputer(no_data_values, no_data_strategies, fill_values)),
299            ("feature_scaling", feature_scaler),
300        ],
301        # memory=memory,
302        verbose=True,
303    )
304    # map weights to new columns (onehot feature scaler creates one column per category)
305    obs1 = pipe1.fit_transform(observations)
306    cat_names = pipe1.named_steps["feature_scaling"]["onehot"].get_feature_names_out()
307    split_names = [name.split("_")[0] for name in cat_names]
308    cat_count = np.unique(split_names, return_counts=True)[1]
309    onehot_map = {}
310    for i, key in enumerate(index_map["onehot"]):
311        onehot_map[key] = cat_count[i]
312    # onehot_map = {key: cat_count[i] for i, key in enumerate(index_map["onehot"])}
313    weight_map = []
314    for name, idxs in index_map.items():
315        for idx in idxs:
316            if name == "onehot":
317                weight_map += [weights[idx]] * onehot_map[idx]
318                continue
319            weight_map += [weights[idx]]
320    # part 2 use weight_map and cluster
321    pipe2 = Pipeline(
322        steps=[
323            ("common_rescaling", RescaleAllToCommonRange(weight_map)),
324            ("agglomerative_clustering", CustomAgglomerativeClustering(height, width, neighbors=4, **kwargs)),
325        ],
326        # memory=memory,
327        verbose=True,
328    )
329
330    # apply pipeLIE
331    labels = pipe2.fit_predict(obs1)
332
333    # Reshape the labels back to the original spatial map shape
334    labels_reshaped = labels.reshape(height, width)
335    return labels_reshaped, pipe1, pipe2

A scipy pipeline to achieve Agglomerative Clustering with connectivity on 2d matrix Steps are:

  1. Impute missing values
  2. Scale the features
  3. Rescale all features to a common range
  4. Cluster the data using Agglomerative Clustering with connectivity
  5. Reshape the labels back to the original spatial map shape
  6. Return the labels and the pipeline object

Args: observations (np.ndarray): The input data to cluster (n_samples, n_features) shaped info_list (list): A list of dictionaries containing information about each feature height (int): The height of the spatial map width (int): The width of the spatial map kwargs: Additional keyword arguments for AgglomerativeClustering, at least one of n_clusters or distance_threshold

Returns: np.ndarray: The labels of the clusters, reshaped to the original 2d spatial map shape Pipeline: The pipeline object containing all the steps of the pipeline

def write( label_map, width, height, output_raster='', output_poly='output.shp', authid='EPSG:3857', geotransform=(0, 1, 0, 0, 0, 1), nodata=None, feedback=None):
338def write(
339    label_map,
340    width,
341    height,
342    output_raster="",
343    output_poly="output.shp",
344    authid="EPSG:3857",
345    geotransform=(0, 1, 0, 0, 0, 1),
346    nodata=None,
347    feedback=None,
348):
349    from osgeo import gdal, ogr, osr
350
351    from fire2a.processing_utils import get_output_raster_format, get_vector_driver_from_filename
352
353    # setup drivers for raster and polygon output formats
354    if output_raster == "":
355        raster_driver = "MEM"
356    else:
357        try:
358            raster_driver = get_output_raster_format(output_raster, feedback=feedback)
359        except Exception:
360            raster_driver = "GTiff"
361    try:
362        poly_driver = get_vector_driver_from_filename(output_poly)
363    except Exception:
364        poly_driver = "ESRI Shapefile"
365
366    # create raster output
367    src_ds = gdal.GetDriverByName(raster_driver).Create(output_raster, width, height, 1, gdal.GDT_Int64)
368    src_ds.SetGeoTransform(geotransform)  # != 0 ?
369    src_ds.SetProjection(authid)  # != 0 ?
370    #  src_band = src_ds.GetRasterBand(1)
371    #  if nodata:
372    #      src_band.SetNoDataValue(nodata)
373    #  src_band.WriteArray(label_map)
374
375    # create polygon output
376    drv = ogr.GetDriverByName(poly_driver)
377    dst_ds = drv.CreateDataSource(output_poly)
378    sp_ref = osr.SpatialReference()
379    sp_ref.SetFromUserInput(authid)  # != 0 ?
380    dst_lyr = dst_ds.CreateLayer("clusters", srs=sp_ref, geom_type=ogr.wkbPolygon)
381    dst_lyr.CreateField(ogr.FieldDefn("DN", ogr.OFTInteger64))  # != 0 ?
382    dst_lyr.CreateField(ogr.FieldDefn("pixel_count", ogr.OFTInteger64))
383    # dst_lyr.CreateField(ogr.FieldDefn("area", ogr.OFTInteger))
384    # dst_lyr.CreateField(ogr.FieldDefn("perimeter", ogr.OFTInteger))
385
386    # 0 != gdal.Polygonize( srcband, maskband, dst_layer, dst_field, options, callback = gdal.TermProgress)
387
388    # FAIL: All together it merges labels into a single polygon
389    #  src_band = src_ds.GetRasterBand(1)
390    #  if nodata:
391    #      src_band.SetNoDataValue(nodata)
392    #  src_band.WriteArray(label_map)
393    # gdal.Polygonize(src_band, None, dst_lyr, 0, callback=gdal.TermProgress)  # , ["8CONNECTED=8"])
394
    # B: separately
    # loop creating each label_map value into a different polygonized feature
    mem_drv = ogr.GetDriverByName("Memory")
    tmp_ds = mem_drv.CreateDataSource("tmp_ds")
    # itera = iter(np.unique(label_map))
    # cluster_id = next(itera)
    areas = []
    pixels = []
    data = np.zeros_like(label_map)
    for cluster_id, px_count in zip(*np.unique(label_map, return_counts=True)):
        # temporarily write band
        src_band = src_ds.GetRasterBand(1)
        src_band.SetNoDataValue(0)
        data[label_map == cluster_id] = 1
        src_band.WriteArray(data)
        # create feature
        tmp_lyr = tmp_ds.CreateLayer("", srs=sp_ref)
        gdal.Polygonize(src_band, src_band.GetMaskBand(), tmp_lyr, -1)
        # unset tmp data
        data[label_map == cluster_id] = 0
        # set polygon feat
        feat = tmp_lyr.GetNextFeature()
        geom = feat.GetGeometryRef()
        featureDefn = dst_lyr.GetLayerDefn()
        feature = ogr.Feature(featureDefn)
        feature.SetGeometry(geom)
        feature.SetField("DN", int(cluster_id))
        areas += [geom.GetArea()]
        pixels += [px_count]
        feature.SetField("pixel_count", int(px_count))
        # feature.SetField("area", int(geom.GetArea()))
        # feature.SetField("perimeter", int(geom.Boundary().Length()))
        dst_lyr.CreateFeature(feature)

    fprint(f"Polygon Areas: {min(areas)=} {max(areas)=}", level="info", feedback=feedback, logger=logger)
    fprint(f"Cluster PixelCounts: {min(pixels)=} {max(pixels)=}", level="info", feedback=feedback, logger=logger)
    # RESTART RASTER
    # src_ds = None
    # src_band = None
    # src_ds = gdal.GetDriverByName(raster_driver).Create(output_raster, width, height, 1, gdal.GDT_Int64)
    # src_ds.SetGeoTransform(geotransform)  # != 0 ?
    # src_ds.SetProjection(authid)  # != 0 ?
    src_band = src_ds.GetRasterBand(1)
    if nodata:
        src_band.SetNoDataValue(nodata)
    else:
        # useless paranoia ?
        src_band.SetNoDataValue(-1)
    src_band.WriteArray(label_map)
    # close datasets
    src_ds.FlushCache()
    src_ds = None
    dst_ds.FlushCache()
    dst_ds = None
    return True
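
Although write is normally called from main, it can be used on its own. A minimal sketch, assuming a small integer label map and the default pseudo geotransform; the file names are illustrative, not part of the module:

import numpy as np
from fire2a.agglomerative_clustering import write

# a tiny 3x3 label map with two clusters (illustrative data)
label_map = np.array([[1, 1, 2], [1, 2, 2], [2, 2, 2]])
height, width = label_map.shape

# writes the labels raster and one polygon per cluster (overwrites existing files)
ok = write(
    label_map,
    width,
    height,
    output_raster="labels.tif",
    output_poly="clusters.gpkg",
    authid="EPSG:3857",
    geotransform=(0, 1, 0, 0, 0, 1),
)
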
def plot(labels_reshaped, pipe1, pipe2, info_list, **kwargs):
    """Plot the observed values of the input data, the rescaled data, and the cluster size history and histogram.
    Args:
        labels_reshaped (np.ndarray): The reshaped labels of the clusters
        pipe1 (Pipeline): The first pipeline object containing imputer and feature scaling steps
        pipe2 (Pipeline): The second pipeline object containing the rescaling and clustering steps
        info_list (list): A list of dictionaries containing information about each feature
        **kwargs: Additional keyword arguments
            n_clusters (int): The number of clusters
            distance_threshold (float): The linkage distance threshold
            sieve (int): The number of pixels to use as a sieve filter
            block (bool): Block the execution until the plot window is closed
            filename (str): The filename to save the plot
    """
    from matplotlib import pyplot as plt

    no_data_imputed = pipe1.named_steps["no_data_imputer"].output_data
    pre_aggclu_data = pipe2.named_steps["agglomerative_clustering"].input_data

    # filter out onehot features
    num_onehots = sum([1 for i in info_list if i["scaling_strategy"] == "onehot"])
    num_no_onehots = len(info_list) - num_onehots
    pre_aggclu_data = pre_aggclu_data[:, :num_no_onehots]

    # indices without onehots
    nohots_idxs = [i for i, info in enumerate(info_list) if info["scaling_strategy"] != "onehot"]

    # filter onehots out of no_data_treated
    no_data_imputed = no_data_imputed[:, nohots_idxs]

    # reorder: robust first, then standard
    rob_std_idxs = [i for i, j in enumerate(nohots_idxs) if info_list[j]["scaling_strategy"] == "robust"]
    rob_std_idxs += [i for i, j in enumerate(nohots_idxs) if info_list[j]["scaling_strategy"] == "standard"]

    # apply the robust-then-standard ordering
    pre_aggclu_data = pre_aggclu_data[:, rob_std_idxs]

    names = [info_list[i]["fname"] for i in nohots_idxs]
    fgs = np.array(plt.rcParams["figure.figsize"]) * 5
    fig, axs = plt.subplots(3, 2, figsize=fgs)
    suptitle = ""
    if n_clusters := kwargs.get("n_clusters"):
        suptitle = f"n_clusters: {n_clusters}"
    if distance_threshold := kwargs.get("distance_threshold"):
        suptitle = f"distance_threshold: {distance_threshold}"
    if sieve := kwargs.get("sieve"):
        suptitle += f", sieve: {sieve}"
    if n_clusters or distance_threshold or sieve:
        suptitle += f", resulting clusters: {len(np.unique(labels_reshaped))}"
    suptitle += "\n(Not showing categorical data)"
    fig.suptitle(suptitle)

    # plot violin plot
    axs[0, 0].violinplot(no_data_imputed, showmeans=False, showmedians=True, showextrema=True)
    axs[0, 0].set_title("Violin Plot of NoData Imputed")
    axs[0, 0].yaxis.grid(True)
    axs[0, 0].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
    axs[0, 0].set_ylabel("Observed values")

    # plot boxplot
    axs[0, 1].boxplot(no_data_imputed)
    axs[0, 1].set_title("Box Plot of NoData Imputed")
    axs[0, 1].yaxis.grid(True)
    axs[0, 1].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
    axs[0, 1].set_ylabel("Observed values")

    # plot violin plot
    axs[1, 0].violinplot(pre_aggclu_data, showmeans=False, showmedians=True, showextrema=True)
    axs[1, 0].set_title("Violin Plot of Common Rescaled")
    axs[1, 0].yaxis.grid(True)
    axs[1, 0].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
    axs[1, 0].set_ylabel("Adjusted range")

    # plot boxplot
    axs[1, 1].boxplot(pre_aggclu_data)
    axs[1, 1].set_title("Box Plot of Common Rescaled")
    axs[1, 1].yaxis.grid(True)
    axs[1, 1].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
    axs[1, 1].set_ylabel("Adjusted range")

    # cluster history
    unique_labels, counts = np.unique(labels_reshaped, return_counts=True)
    axs[2, 0].plot(unique_labels, counts, marker="o", color="blue")
    axs[2, 0].set_title("Cluster history size (in pixels)")
    axs[2, 0].set_xlabel("Algorithm Step")
    axs[2, 0].set_ylabel("Size (in pixels)")

    # cluster histogram
    axs[2, 1].hist(counts, log=True)
    axs[2, 1].set_xlabel("Cluster Size (in pixels)")
    axs[2, 1].set_ylabel("Number of Clusters")
    axs[2, 1].set_title("Histogram of Cluster Sizes")

    plt.tight_layout()
    if filename := kwargs.get("filename"):
        logger.info(f"Saving plot to {filename}")
        plt.savefig(filename)
    else:
        if block := kwargs.get("block"):
            plt.show(block=block)
        else:
            plt.show()


def sieve_filter(data, threshold=2, connectedness=4, feedback=None):
    """Apply a sieve filter to the data to remove small clusters. The sieve filter is applied using the GDAL library. https://gdal.org/en/latest/programs/gdal_sieve.html#gdal-sieve
    Args:
        data (np.ndarray): The input data to filter
        threshold (int): Clusters with fewer pixels than this are merged into their largest neighbor
        connectedness (int): The pixel connectivity to consider when filtering, 4 or 8
        feedback (QgsTaskFeedback): A feedback object to report progress, for use inside QGIS plugins
    Returns:
        np.ndarray: The filtered data
    """
    logger.info("Applying sieve filter")
    from osgeo import gdal

    height, width = data.shape
    # fprint("before", np.sort(np.unique(data, return_counts=True)), len(np.unique(data)), level="info", feedback=feedback, logger=logger)
    num_clusters = len(np.unique(data))
    src_ds = gdal.GetDriverByName("MEM").Create("sieve", width, height, 1, gdal.GDT_Int64)
    src_band = src_ds.GetRasterBand(1)
    src_band.WriteArray(data)
    if 0 != gdal.SieveFilter(src_band, None, src_band, threshold, connectedness):
        fprint("Error applying sieve filter", level="error", feedback=feedback, logger=logger)
    else:
        sieved = src_band.ReadAsArray()
        src_band = None
        src_ds = None
        num_sieved = len(np.unique(sieved))
        # fprint("after", np.sort(np.unique(sieved, return_counts=True)), len(np.unique(sieved)), level="info", feedback=feedback, logger=logger)
        fprint(
            f"Reduced from {num_clusters} to {num_sieved} clusters, {num_clusters - num_sieved} fewer",
            level="info",
            feedback=feedback,
            logger=logger,
        )
        fprint(
            "Consider increasing distance_threshold or reducing n_clusters instead of sieving...",
            level="info",
            feedback=feedback,
            logger=logger,
        )
        # from matplotlib import pyplot as plt
        # fig, (ax1, ax2) = plt.subplots(1, 2)
        # ax1.imshow(data)
        # ax1.set_title("before sieve" + str(len(np.unique(data))))
        # ax2.imshow(sieved)
        # ax2.set_title("after sieve" + str(len(np.unique(sieved))))
        # plt.show()
        # data = sieved
        return sieved
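
A minimal sketch of calling sieve_filter directly, using an illustrative label map that contains one single-pixel cluster:

import numpy as np
from fire2a.agglomerative_clustering import sieve_filter

# illustrative label map: one big cluster and a lone pixel labeled 2
labels = np.ones((10, 10), dtype=np.int64)
labels[5, 5] = 2

# clusters smaller than 4 pixels get merged into their largest 8-connected neighbor
sieved = sieve_filter(labels, threshold=4, connectedness=8)
print(len(np.unique(labels)), "->", len(np.unique(sieved)))  # expected: 2 -> 1
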


def arg_parser(argv=None):
    """Parse command line arguments."""
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

    parser = ArgumentParser(
        description="Agglomerative Clustering with Connectivity for raster data",
        formatter_class=ArgumentDefaultsHelpFormatter,
        epilog="More at https://fire2a.github.io/fire2a-lib",
    )
    parser.add_argument(
        "config_file",
        nargs="?",
        type=Path,
        help="For each raster file, configure its preprocess: nodata & scaling methods",
        default="config.toml",
    )

    aggclu = parser.add_mutually_exclusive_group(required=True)
    aggclu.add_argument(
        "-d",
        "--distance_threshold",
        type=float,
        help="Distance threshold (a good starting point when scaling is 10; higher means fewer clusters, 0 could take a long time)",
    )
    aggclu.add_argument("-n", "--n_clusters", type=int, help="Number of clusters")

    parser.add_argument("-or", "--output_raster", help="Output raster file, warning: overwrites!", default="")
    parser.add_argument("-op", "--output_poly", help="Output polygons file, warning: overwrites!", default="output.gpkg")
    parser.add_argument("-a", "--authid", type=str, help="Output raster authid", default="EPSG:3857")
    parser.add_argument(
        "-g", "--geotransform", type=str, help="Output raster geotransform", default="(0, 1, 0, 0, 0, 1)"
    )
    parser.add_argument(
        "-nw",
        "--no_write",
        action="store_true",
        help="Do not write output raster nor polygons",
        default=False,
    )
    parser.add_argument(
        "-s",
        "--script",
        action="store_true",
        help="Run in script mode, returning the label_map and the pipeline object",
        default=False,
    )
    parser.add_argument(
        "--sieve",
        type=int,
        help="Use GDAL sieve filter to merge small clusters (number of pixels) into the biggest neighbor",
    )
    parser.add_argument("--verbose", "-v", action="count", default=0, help="WARNING:1, INFO:2, DEBUG:3")

    plot = parser.add_argument_group(
        "Plotting. Visually inspect input distributions: NoData-treated observations and rescaled data, with violin plots and boxplots. Also check output clustering size history and histograms"
    )
    plot.add_argument(
        "-p",
        "--plots",
        action="store_true",
        help="Activate the plotting routines",
    )
    plot.add_argument(
        "-b",
        "--block",
        action="store_false",
        default=True,
        help="Do not block the execution until the plot window is closed. Use for interactive ipykernels or QGIS",
    )
    plot.add_argument(
        "-f",
        "--filename",
        type=str,
        help="Filename to save the plot. If not provided, matplotlib will raise a window",
    )
    args = parser.parse_args(argv)
    args.geotransform = tuple(map(float, args.geotransform[1:-1].split(",")))
    if not Path(args.config_file).is_file():
        parser.error(f"File {args.config_file} not found")
    return args
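
For instance, parsing a minimal argument list (this assumes a config.toml exists in the working directory, since the parser errors out otherwise); note how the geotransform string is turned into a tuple of floats:

from fire2a.agglomerative_clustering import arg_parser

args = arg_parser(["-d", "10.0", "-g", "(0, 10, 0, 0, 0, 10)"])
print(args.distance_threshold)  # 10.0
print(args.geotransform)  # (0.0, 10.0, 0.0, 0.0, 0.0, 10.0)
print(args.output_poly)  # output.gpkg
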


def main(argv=None):
    """Run the full workflow: read config and rasters, cluster, optionally sieve, plot, and write outputs.

    Can also be called programmatically with an argv list, e.g.:
    args = arg_parser(["-d","10.0", "-g","(0, 10, 0, 0, 0, 10)", "config2.toml"])
    args = arg_parser(["-d","10.0"])
    args = arg_parser(["-d","10.0", "config2.toml"])
    args = arg_parser(["-n","10"])
    """
    if argv is sys.argv:
        argv = sys.argv[1:]
    args = arg_parser(argv)

    if args.verbose != 0:
        global logger
        from fire2a import setup_logger

        logger = setup_logger(verbosity=args.verbose)

    logger.info("args %s", args)

    # 2. READ CONFIG
    config = read_toml(args.config_file)
    # logger.debug(config)

    # 2.1 ADD DEFAULTS
    for filename, file_config in config.items():
        if "no_data_strategy" not in file_config:
            config[filename]["no_data_strategy"] = "mean"
        if "scaling_strategy" not in file_config:
            config[filename]["scaling_strategy"] = "robust"
        if "fill_value" not in file_config:
            config[filename]["fill_value"] = 0
        if "weight" not in file_config:
            config[filename]["weight"] = 1
    logger.debug(config)

    # 3. READ DATA
    from fire2a.raster import read_raster

    data_list, info_list = [], []
    for filename, file_config in config.items():
        data, info = read_raster(filename)
        info["fname"] = Path(filename).name
        info["no_data_strategy"] = file_config["no_data_strategy"]
        info["scaling_strategy"] = file_config["scaling_strategy"]
        info["fill_value"] = file_config["fill_value"]
        info["weight"] = file_config["weight"]
        data_list += [data]
        info_list += [info]
        logger.debug("%s", data[:2, :2])
        logger.debug("%s", info)

    # 4. VALIDATE: all 2d, same shape
    height, width = check_shapes(data_list)

    # 5. list[maps] -> OBSERVATIONS
    observations = np.column_stack([data.ravel() for data in data_list])

    # 6. nodata -> feature scaling -> all scaling -> clustering
    labels_reshaped, pipe1, pipe2 = pipelie(
        observations,
        info_list,
        height,
        width,
        n_clusters=args.n_clusters,
        distance_threshold=args.distance_threshold,
    )  # insert more keyword arguments to the clustering algorithm here!

    # SIEVE
    if args.sieve:
        logger.info(f"Number of clusters before sieving: {len(np.unique(labels_reshaped))}")
        labels_reshaped = sieve_filter(labels_reshaped, args.sieve)

    logger.info(f"Final number of clusters: {len(np.unique(labels_reshaped))}")

    # 7. debugging plots
    if args.plots:
        plot(labels_reshaped, pipe1, pipe2, info_list, **vars(args))

    # 8. WRITE RASTER
    if not args.no_write:
        if not write(
            labels_reshaped,
            width,
            height,
            output_raster=args.output_raster,
            output_poly=args.output_poly,
            authid=args.authid,
            geotransform=args.geotransform,
        ):
            logger.error("Error writing output raster")

    # 9. SCRIPT MODE
    if args.script:
        return labels_reshaped, pipe1, pipe2

    return 0
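
A sketch of script mode from another Python session, assuming config.toml and the rasters it lists are in the working directory; -s (--script) returns the label map and both pipelines instead of exit code 0, and -nw skips writing outputs:

from fire2a.agglomerative_clustering import main

labels_reshaped, pipe1, pipe2 = main(["-d", "10.0", "-s", "-nw"])
# labels_reshaped is a (height, width) array of cluster ids,
# pipe1 holds the imputing and per-feature scaling steps, pipe2 the common rescaling and clustering
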
