¿Cómo usar Scikit-Learn Label Propagation en datos estructurados de gráficos?

11

Como parte de mi investigación, estoy interesado en realizar la propagación de etiquetas en un gráfico. Estoy especialmente interesado en esos dos métodos:

Vi que scikit-learn ofrece un modelo para hacer eso. Sin embargo, se supone que este modelo se aplica a datos estructurados de vectores ( es decir , puntos de datos).

El modelo construye una matriz de afinidad a partir de los puntos de datos utilizando un núcleo, y luego ejecuta el algoritmo en la matriz construida. Me gustaría poder ingresar directamente la matriz de adyacencia de mi gráfico en lugar de la matriz de similitud.

¿Alguna idea sobre cómo lograr eso? ¿O conoce alguna biblioteca de Python que permita ejecutar la propagación de etiquetas directamente en datos estructurados de gráficos para los dos métodos antes mencionados?

¡Gracias de antemano por tu ayuda!

Thibaud Martinez
fuente
¿Ha verificado el código fuente de Scikit-learn para ver qué hace después de calcular la matriz de afinidad? Tal vez podría "copiar" el código después de esa parte para aplicarlo directamente a su matriz de adyacencia.
Tasos
¡Gracias por tu comentario! Entonces, en realidad, esto es lo que estoy haciendo actualmente, pero algunas partes del código que necesito modificar para satisfacer mis necesidades son algo crípticas. Me temo que reescribir esas partes provocará errores. Esperaba que existiera un método más directo.
Thibaud Martinez
1
El código fuente en github.com/scikit-learn/scikit-learn/blob/7389dba/sklearn/… - dice que las implementaciones deberían anular el método _build_graph. Entonces, de forma nativa, debe intentar crear una clase derivada que acepte una matriz calculada previamente.
mikalai

Respuestas:

2

Respondo mi propia pregunta aquí, ya que espero que sea útil para algunos lectores.

Scikit-learn está diseñado principalmente para manejar datos estructurados de vectores. Por lo tanto, si desea realizar propagación de etiquetas / difusión de etiquetas en datos estructurados con gráficos, probablemente sea mejor que vuelva a implementar el método usted mismo en lugar de usar la interfaz Scikit.

Aquí hay una implementación de Propagación de etiquetas y Distribución de etiquetas en PyTorch.

Los dos métodos en general siguen los mismos pasos algorítmicos, con variaciones sobre cómo se normaliza la matriz de adyacencia y cómo se propagan las etiquetas en cada paso. Por lo tanto, creemos una clase base para nuestros dos modelos.

from abc import abstractmethod
import torch

class BaseLabelPropagation:
    """Base class for label propagation models.

    Parameters
    ----------
    adj_matrix: torch.FloatTensor
        Adjacency matrix of the graph.
    """
    def __init__(self, adj_matrix):
        self.norm_adj_matrix = self._normalize(adj_matrix)
        self.n_nodes = adj_matrix.size(0)
        self.one_hot_labels = None 
        self.n_classes = None
        self.labeled_mask = None
        self.predictions = None

    @staticmethod
    @abstractmethod
    def _normalize(adj_matrix):
        raise NotImplementedError("_normalize must be implemented")

    @abstractmethod
    def _propagate(self):
        raise NotImplementedError("_propagate must be implemented")

    def _one_hot_encode(self, labels):
        # Get the number of classes
        classes = torch.unique(labels)
        classes = classes[classes != -1]
        self.n_classes = classes.size(0)

        # One-hot encode labeled data instances and zero rows corresponding to unlabeled instances
        unlabeled_mask = (labels == -1)
        labels = labels.clone()  # defensive copying
        labels[unlabeled_mask] = 0
        self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
        self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)
        self.one_hot_labels[unlabeled_mask, 0] = 0

        self.labeled_mask = ~unlabeled_mask

    def fit(self, labels, max_iter, tol):
        """Fits a semi-supervised learning label propagation model.

        labels: torch.LongTensor
            Tensor of size n_nodes indicating the class number of each node.
            Unlabeled nodes are denoted with -1.
        max_iter: int
            Maximum number of iterations allowed.
        tol: float
            Convergence tolerance: threshold to consider the system at steady state.
        """
        self._one_hot_encode(labels)

        self.predictions = self.one_hot_labels.clone()
        prev_predictions = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)

        for i in range(max_iter):
            # Stop iterations if the system is considered at a steady state
            variation = torch.abs(self.predictions - prev_predictions).sum().item()

            if variation < tol:
                print(f"The method stopped after {i} iterations, variation={variation:.4f}.")
                break

            prev_predictions = self.predictions
            self._propagate()

    def predict(self):
        return self.predictions

    def predict_classes(self):
        return self.predictions.max(dim=1).indices

El modelo toma como entrada la matriz de adyacencia del gráfico, así como las etiquetas de los nodos. Las etiquetas tienen la forma de un vector de un número entero que indica el número de clase de cada nodo con un -1 en la posición de los nodos sin etiqueta.

El algoritmo de propagación de etiquetas se presenta a continuación.

W: matriz de adyacencia del gráfico Calcule la matriz de grados diagonales re por reyoyojWyoj Inicializar Y^(0 0)(y1,...,yl,0 0,0 0,...,0 0) Iterar  1) Y^(t+1)re-1WY^(t) 2) Y^l(t+1)Yl hasta la convergencia a Y^() Punto de etiqueta Xyo por el signo de y^yo()

De Xiaojin Zhu y Zoubin Ghahramani. Aprendizaje de datos etiquetados y no etiquetados con propagación de etiquetas Informe técnico CMU-CALD-02-107, Universidad Carnegie Mellon, 2002

Obtenemos la siguiente implementación.

class LabelPropagation(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1 * W"""
        degs = adj_matrix.sum(dim=1)
        degs[degs == 0] = 1  # avoid division by 0 error
        return adj_matrix / degs[:, None]

    def _propagate(self):
        self.predictions = torch.matmul(self.norm_adj_matrix, self.predictions)

        # Put back already known labels
        self.predictions[self.labeled_mask] = self.one_hot_labels[self.labeled_mask]

    def fit(self, labels, max_iter=1000, tol=1e-3):
        super().fit(labels, max_iter, tol)

El algoritmo de difusión de etiquetas es:

W: matriz de adyacencia del gráfico Calcule la matriz de grados diagonales re por reyoyojWyoj Calcule el gráfico normalizado laplaciano Lre-1/ /2Wre-1/ /2 Inicializar Y^(0 0)(y1,...,yl,0 0,0 0,...,0 0) Elige un parámetro α[0 0,1) Iterar Y^(t+1)αLY^(t)+(1-α)Y^(0 0) hasta la convergencia a Y^() Punto de etiqueta Xyo por el signo de y^yo()

De Dengyong Zhou, Olivier Bousquet, Thomas Navin Lal, Jason Weston, Bernhard Schoelkopf. Aprendizaje con consistencia local y global (2004)

La implementación es, por lo tanto, la siguiente.

class LabelSpreading(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)
        self.alpha = None

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1/2 * W * D^-1/2"""
        degs = adj_matrix.sum(dim=1)
        norm = torch.pow(degs, -0.5)
        norm[torch.isinf(norm)] = 1
        return adj_matrix * norm[:, None] * norm[None, :]

    def _propagate(self):
        self.predictions = (
            self.alpha * torch.matmul(self.norm_adj_matrix, self.predictions)
            + (1 - self.alpha) * self.one_hot_labels
        )

    def fit(self, labels, max_iter=1000, tol=1e-3, alpha=0.5):
        """
        Parameters
        ----------
        alpha: float
            Clamping factor.
        """
        self.alpha = alpha
        super().fit(labels, max_iter, tol)

Probemos ahora nuestros modelos de propagación en datos sintéticos. Para hacerlo, elegimos usar un gráfico de hombre de las cavernas .

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

# Create caveman graph
n_cliques = 4
size_cliques = 10
caveman_graph = nx.connected_caveman_graph(n_cliques, size_cliques)
adj_matrix = nx.adjacency_matrix(caveman_graph).toarray()

# Create labels
labels = np.full(n_cliques * size_cliques, -1.)

# Only one node per clique is labeled. Each clique belongs to a different class.
labels[0] = 0
labels[size_cliques] = 1
labels[size_cliques * 2] = 2
labels[size_cliques * 3] = 3

# Create input tensors
adj_matrix_t = torch.FloatTensor(adj_matrix)
labels_t = torch.LongTensor(labels)

# Learn with Label Propagation
label_propagation = LabelPropagation(adj_matrix_t)
label_propagation.fit(labels_t)
label_propagation_output_labels = label_propagation.predict_classes()

# Learn with Label Spreading
label_spreading = LabelSpreading(adj_matrix_t)
label_spreading.fit(labels_t, alpha=0.8)
label_spreading_output_labels = label_spreading.predict_classes()

# Plot graphs
color_map = {-1: "orange", 0: "blue", 1: "green", 2: "red", 3: "cyan"}
input_labels_colors = [color_map[l] for l in labels]
lprop_labels_colors = [color_map[l] for l in label_propagation_output_labels.numpy()]
lspread_labels_colors = [color_map[l] for l in label_spreading_output_labels.numpy()]

plt.figure(figsize=(14, 6))
ax1 = plt.subplot(1, 4, 1)
ax2 = plt.subplot(1, 4, 2)
ax3 = plt.subplot(1, 4, 3)

ax1.title.set_text("Raw data (4 classes)")
ax2.title.set_text("Label Propagation")
ax3.title.set_text("Label Spreading")

pos = nx.spring_layout(caveman_graph)
nx.draw(caveman_graph, ax=ax1, pos=pos, node_color=input_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax2, pos=pos, node_color=lprop_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax3, pos=pos, node_color=lspread_labels_colors, node_size=50)

# Legend
ax4 = plt.subplot(1, 4, 4)
ax4.axis("off")
legend_colors = ["orange", "blue", "green", "red", "cyan"]
legend_labels = ["unlabeled", "class 0", "class 1", "class 2", "class 3"]
dummy_legend = [ax4.plot([], [], ls='-', c=c)[0] for c in legend_colors]
plt.legend(dummy_legend, legend_labels)

plt.show()

Los modelos implementados funcionan correctamente y permiten detectar las comunidades en el gráfico.

Propagación de etiquetas e implementaciones de difusión de etiquetas probadas en un gráfico cavernícola

Nota: Los métodos de propagación presentados están destinados a ser utilizados en gráficos no dirigidos.

El código está disponible como un cuaderno interactivo de Jupyter aquí .

Aglutinante

Thibaud M
fuente