
AGNES

toyml.clustering.agnes.AGNES dataclass

AGNES(
    n_cluster: int,
    linkage: Literal[
        "single", "complete", "average"
    ] = "single",
    distance_metric: Literal["euclidean"] = "euclidean",
    distance_matrix_: list[list[float]] = list(),
    clusters_: list[ClusterTree] = list(),
    labels_: list[int] = list(),
    cluster_tree_: Optional[ClusterTree] = None,
    linkage_matrix: list[list[float]] = list(),
    _cluster_index: int = 0,
)

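Agglomerative clustering algorithm (bottom-up hierarchical clustering). Every sample starts as its own cluster, and the two closest clusters, as measured by the chosen linkage, are merged repeatedly until n_cluster clusters remain.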

Examples:

>>> from toyml.clustering import AGNES
>>> dataset = [[1, 0], [1, 1], [1, 2], [10, 0], [10, 1], [10, 2]]
>>> agnes = AGNES(n_cluster=2).fit(dataset)
>>> print(agnes.labels_)
[0, 0, 0, 1, 1, 1]
>>> # Using fit_predict method
>>> labels = agnes.fit_predict(dataset)
>>> print(labels)
[0, 0, 0, 1, 1, 1]
>>> # Using different linkage methods
>>> agnes = AGNES(n_cluster=2, linkage="complete").fit(dataset)
>>> print(agnes.labels_)
[0, 0, 0, 1, 1, 1]
>>> # Plotting dendrogram
>>> agnes = AGNES(n_cluster=1).fit(dataset)
>>> agnes.plot_dendrogram(show=True)

Figure: The AGNES dendrogram plot (titled "AGNES Dendrogram").

References
  1. Zhou Zhihua
  2. Tan

n_cluster instance-attribute

n_cluster: int

The number of clusters, specified by the user.

linkage class-attribute instance-attribute

linkage: Literal["single", "complete", "average"] = "single"

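The linkage method used to measure the distance between two clusters: "single" (minimum pairwise distance), "complete" (maximum pairwise distance), or "average" (mean pairwise distance).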
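For reference, here is a minimal sketch of the three linkage criteria, assuming Euclidean distance between points (the only metric the class currently supports). `cluster_distance` is a hypothetical helper written for illustration; it is not part of toyml.

```python
import math
from typing import Literal

Point = list[float]


def cluster_distance(
    a: list[Point],
    b: list[Point],
    linkage: Literal["single", "complete", "average"] = "single",
) -> float:
    """Distance between two clusters of points under the given linkage (illustrative)."""
    # All Euclidean distances between one point of `a` and one point of `b`
    pairwise = [math.dist(p, q) for p in a for q in b]
    if linkage == "single":
        return min(pairwise)  # closest pair of points
    if linkage == "complete":
        return max(pairwise)  # farthest pair of points
    if linkage == "average":
        return sum(pairwise) / len(pairwise)  # mean over all pairs
    raise ValueError(f"unknown linkage: {linkage!r}")


a = [[1, 0], [1, 1]]
b = [[1, 2], [1, 3]]
print(cluster_distance(a, b, "single"))    # 1.0
print(cluster_distance(a, b, "complete"))  # 3.0
print(cluster_distance(a, b, "average"))   # 2.0
```

Single linkage tends to chain elongated clusters together, while complete linkage favors compact clusters; average linkage sits between the two.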

distance_metric class-attribute instance-attribute

distance_metric: Literal["euclidean"] = "euclidean"

The distance metric to use. Currently only "euclidean" is supported.

distance_matrix_ class-attribute instance-attribute

distance_matrix_: list[list[float]] = field(
    default_factory=list
)

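The distance matrix between the current clusters. It is initialized from the pairwise sample distances and updated after every merge.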
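Before any merge, every sample is its own cluster, so the cluster distance matrix is just the sample-to-sample Euclidean distance matrix. The sketch below shows one way to build it; `init_distance_matrix` is a hypothetical stand-in for the private `_get_init_distance_matrix` helper, whose actual code is not shown on this page.

```python
import math


def init_distance_matrix(dataset: list[list[float]]) -> list[list[float]]:
    """Pairwise Euclidean distances between samples (symmetric, zero diagonal)."""
    n = len(dataset)
    matrix = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            d = math.dist(dataset[i], dataset[j])
            matrix[i][j] = matrix[j][i] = d  # fill both halves of the symmetric matrix
    return matrix


for row in init_distance_matrix([[0, 0], [3, 4], [0, 1]]):
    print(row)
```

Only the upper triangle is computed, since Euclidean distance is symmetric.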

clusters_ class-attribute instance-attribute

clusters_: list[ClusterTree] = field(default_factory=list)

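The current clusters, as ClusterTree nodes. fit starts with one cluster per sample and merges them until n_cluster clusters remain.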

labels_ class-attribute instance-attribute

labels_: list[int] = field(default_factory=list)

The cluster label of each sample, assigned by fit.
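Once merging stops, each sample's label is the index of the final cluster that contains it. A minimal sketch, assuming each cluster is given as a list of sample indices; `labels_from_clusters` is hypothetical, and the numbering toyml's own `_get_labels` produces may differ.

```python
def labels_from_clusters(clusters: list[list[int]], n_samples: int) -> list[int]:
    """Give every sample the index of the cluster that contains it (illustrative)."""
    labels = [-1] * n_samples  # -1 marks a sample not yet assigned
    for label, sample_indices in enumerate(clusters):
        for i in sample_indices:
            labels[i] = label
    return labels


print(labels_from_clusters([[0, 1, 2], [3, 4, 5]], 6))  # [0, 0, 0, 1, 1, 1]
```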

fit

fit(dataset: list[list[float]]) -> AGNES

Fit the model.

Source code in toyml/clustering/agnes.py
def fit(self, dataset: list[list[float]]) -> AGNES:
    """
    Fit the model.
    """
    self._validate(dataset)
    n = len(dataset)
    self.clusters_ = [ClusterTree(cluster_index=i, sample_indices=[i]) for i in range(n)]
    self._cluster_index = n
    self.distance_matrix_ = self._get_init_distance_matrix(dataset)
    while len(self.clusters_) > self.n_cluster:
        (i, j), cluster_ij_distance = self._get_closest_clusters()
        # merge cluster_i and cluster_j
        self._merge_clusters(i, j, cluster_ij_distance)
        # update distance matrix
        self._update_distance_matrix(dataset, i, j)
    # build cluster_tree_
    self.cluster_tree_ = self._build_cluster_tree(n)
    # assign dataset labels
    self._get_labels(len(dataset))
    return self
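As a hedged illustration of the loop in fit, here is a self-contained naive version that recomputes linkage distances from scratch every round instead of maintaining a distance matrix. The function name `agnes` and the labeling convention (clusters numbered in order of their smallest sample index) are assumptions made for this sketch; it is not toyml's code.

```python
import math
from typing import Literal


def agnes(
    dataset: list[list[float]],
    n_cluster: int,
    linkage: Literal["single", "complete", "average"] = "single",
) -> list[int]:
    """Naive bottom-up clustering; returns one label per sample (illustrative)."""
    if n_cluster < 1:
        raise ValueError("n_cluster must be at least 1")
    # Start with every sample in its own cluster
    clusters: list[list[int]] = [[i] for i in range(len(dataset))]

    def distance(a: list[int], b: list[int]) -> float:
        pairwise = [math.dist(dataset[p], dataset[q]) for p in a for q in b]
        if linkage == "single":
            return min(pairwise)
        if linkage == "complete":
            return max(pairwise)
        return sum(pairwise) / len(pairwise)  # average linkage

    while len(clusters) > n_cluster:
        # Scan every pair of clusters for the closest one (ties keep the first found)
        best_i, best_j, best_d = 0, 1, math.inf
        for i in range(len(clusters)):
            for j in range(i + 1, len(clusters)):
                d = distance(clusters[i], clusters[j])
                if d < best_d:
                    best_i, best_j, best_d = i, j, d
        # Merge cluster j into cluster i, then drop j (j > i, so index i stays valid)
        clusters[best_i] = clusters[best_i] + clusters[best_j]
        del clusters[best_j]

    # Number the final clusters in order of their smallest sample index
    clusters.sort(key=min)
    labels = [0] * len(dataset)
    for label, members in enumerate(clusters):
        for i in members:
            labels[i] = label
    return labels


dataset = [[1, 0], [1, 1], [1, 2], [10, 0], [10, 1], [10, 2]]
print(agnes(dataset, 2))              # [0, 0, 0, 1, 1, 1]
print(agnes(dataset, 2, "complete"))  # [0, 0, 0, 1, 1, 1]
```

Recomputing every pairwise linkage each round is only practical for toy datasets; fit instead keeps distances in distance_matrix_ and updates it after each merge.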

fit_predict

fit_predict(dataset: list[list[float]]) -> list[int]

Fit the model and return the labels of each sample.

Source code in toyml/clustering/agnes.py
def fit_predict(self, dataset: list[list[float]]) -> list[int]:
    """
    Fit the model and return the labels of each sample.
    """
    self.fit(dataset)
    return self.labels_

plot_dendrogram

plot_dendrogram(
    figure_name: str = "agnes_dendrogram.png",
    show: bool = False,
) -> None

Plot the dendrogram of the clustering result.

This method visualizes the hierarchical structure of the clustering using a dendrogram. It requires n_cluster to be set to 1 during initialization, so that fit merges every sample into a single tree and the complete merge history is available to plot.

PARAMETER DESCRIPTION
figure_name

The filename for saving the plot. Defaults to "agnes_dendrogram.png".

TYPE: str DEFAULT: 'agnes_dendrogram.png'

show

If True, displays the plot. Defaults to False.

TYPE: bool DEFAULT: False

RAISES DESCRIPTION
ValueError

If the number of clusters is not 1.

Note

This method requires matplotlib and scipy to be installed.
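plot_dendrogram passes linkage_matrix straight to SciPy's dendrogram function, so the matrix has to follow SciPy's linkage-matrix layout: one row per merge holding the two merged cluster ids, the merge distance, and the number of samples in the new cluster. Leaves have ids 0 to n-1, and the k-th merge (counting from 0) creates cluster id n + k. The sketch below builds such a matrix for single linkage in pure Python; `single_linkage_matrix` is hypothetical and not toyml's implementation.

```python
import math


def single_linkage_matrix(dataset: list[list[float]]) -> list[list[float]]:
    """Merge history in SciPy's linkage layout: [id_a, id_b, distance, size] (illustrative)."""
    n = len(dataset)
    # Active clusters: cluster id -> sample indices. Leaves use ids 0..n-1.
    active: dict[int, list[int]] = {i: [i] for i in range(n)}
    next_id = n  # the k-th merge creates cluster id n + k
    rows: list[list[float]] = []
    while len(active) > 1:
        ids = sorted(active)
        best = None
        for x in range(len(ids)):
            for y in range(x + 1, len(ids)):
                a, b = ids[x], ids[y]
                # Single linkage: closest pair of samples across the two clusters
                d = min(
                    math.dist(dataset[p], dataset[q])
                    for p in active[a]
                    for q in active[b]
                )
                if best is None or d < best[2]:
                    best = (a, b, d)
        a, b, d = best
        merged = active.pop(a) + active.pop(b)
        rows.append([float(a), float(b), d, float(len(merged))])
        active[next_id] = merged
        next_id += 1
    return rows


print(single_linkage_matrix([[0], [1], [5]]))
# [[0.0, 1.0, 1.0, 2.0], [2.0, 3.0, 4.0, 3.0]]
```

With SciPy installed, a matrix in this layout can be wrapped in numpy.array and passed to scipy.cluster.hierarchy.dendrogram (for example with no_plot=True to inspect the result without drawing). A full hierarchy over n samples always has n - 1 rows, which is why the method insists on n_cluster being 1.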

Source code in toyml/clustering/agnes.py
def plot_dendrogram(
    self,
    figure_name: str = "agnes_dendrogram.png",
    show: bool = False,
) -> None:
    """
    Plot the dendrogram of the clustering result.

    This method visualizes the hierarchical structure of the clustering
    using a dendrogram. It requires the number of clusters to be set to 1
    during initialization.

    Args:
        figure_name: The filename for saving the plot.
            Defaults to "agnes_dendrogram.png".
        show: If True, displays the plot. Defaults to False.

    Raises:
        ValueError: If the number of clusters is not 1.

    Note:
        This method requires matplotlib and scipy to be installed.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    from scipy.cluster.hierarchy import dendrogram

    if self.n_cluster != 1:
        raise ValueError("The number of clusters should be 1 to plot dendrogram")
    # Plot the dendrogram
    plt.figure(figsize=(10, 7))
    dendrogram(np.array(self.linkage_matrix))
    plt.title("AGNES Dendrogram")
    plt.xlabel("Sample Index")
    plt.ylabel("Distance")
    plt.savefig(figure_name, dpi=300, bbox_inches="tight")
    if show:
        plt.show()

toyml.clustering.agnes.ClusterTree dataclass

ClusterTree(
    cluster_index: int,
    parent: Optional[ClusterTree] = None,
    children: list[ClusterTree] = list(),
    sample_indices: list[int] = list(),
    children_cluster_distance: Optional[float] = None,
)

Represents a node in the hierarchical clustering tree.

Each node is a cluster containing sample indices. Leaf nodes represent individual samples, while internal nodes represent merged clusters. The root node contains all samples.
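A leaf holds a single sample index, and each merge creates a parent node whose sample indices are the union of its two children's. The sketch below uses a hypothetical Node dataclass that mirrors the fields in the ClusterTree signature; merge and leaves are illustrative helpers, not toyml API.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    """Hypothetical stand-in for ClusterTree."""

    cluster_index: int
    # Excluded from repr/eq so comparisons do not walk the child -> parent -> child cycle
    parent: Optional[Node] = field(default=None, repr=False, compare=False)
    children: list[Node] = field(default_factory=list)
    sample_indices: list[int] = field(default_factory=list)
    children_cluster_distance: Optional[float] = None


def merge(a: Node, b: Node, cluster_index: int, distance: float) -> Node:
    """Create the internal node formed by merging clusters a and b."""
    parent = Node(
        cluster_index=cluster_index,
        children=[a, b],
        sample_indices=a.sample_indices + b.sample_indices,
        children_cluster_distance=distance,
    )
    a.parent = parent
    b.parent = parent
    return parent


def leaves(node: Node) -> list[int]:
    """Collect the sample indices stored in the leaves below node."""
    if not node.children:
        return list(node.sample_indices)
    return [i for child in node.children for i in leaves(child)]


leaf_nodes = [Node(cluster_index=i, sample_indices=[i]) for i in range(3)]
inner = merge(leaf_nodes[0], leaf_nodes[1], cluster_index=3, distance=1.0)
root = merge(inner, leaf_nodes[2], cluster_index=4, distance=4.0)
print(root.sample_indices)  # [0, 1, 2]
print(leaves(root))         # [0, 1, 2]
```

Storing sample_indices on every internal node duplicates data, but it lets a cluster's members be read without walking the subtree.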

parent class-attribute instance-attribute

parent: Optional[ClusterTree] = None

The parent node (None for the root).

children class-attribute instance-attribute

children: list[ClusterTree] = field(default_factory=list)

The child nodes: the two clusters merged to form this node (empty for a leaf).

sample_indices class-attribute instance-attribute

sample_indices: list[int] = field(default_factory=list)

The indices of the dataset samples contained in this cluster.