Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 101 additions & 45 deletions prototype_v1/classifier.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""
Classifier avec Cross-Entropy
Classifier Binaire avec Symmetric BCE Loss

Transforme les embeddings contextualisés en prédictions par nœud.
Transforme les embeddings contextualisés en prédictions binaires par nœud.
La Symmetric BCE Loss gère automatiquement la symétrie des solutions (MaxCut).

Input: [batch, n_nodes, hidden_dim]
Output: [batch, n_nodes, num_classes] (logits)
Output: [batch, n_nodes] (probabilités entre 0 et 1)
"""

import torch
Expand All @@ -14,41 +15,38 @@

class Classifier(nn.Module):
"""
Classifier multi-classe pour les problèmes d'optimisation.
Classifier binaire pour les problèmes d'optimisation sur graphes.

Supporte:
- MaxCut, Vertex Cover, Independent Set: 2 classes
- Graph Coloring: k classes
Supporte: MaxCut, Vertex Cover, Independent Set (tous binaires).

Loss: Symmetric BCE
loss = min(BCE(pred, target), BCE(pred, 1-target))
→ Gère automatiquement la symétrie des solutions
"""

def __init__(self, hidden_dim=256, max_classes=10, dropout=0.1):
def __init__(self, hidden_dim=256, dropout=0.1):
super().__init__()

self.hidden_dim = hidden_dim
self.max_classes = max_classes

self.layers = nn.Sequential(
nn.LayerNorm(hidden_dim),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, max_classes)
nn.Linear(hidden_dim // 2, 1) # 1 seule sortie → binaire
)

def forward(self, x, num_classes=2):
def forward(self, x):
"""
Args:
x: [batch, n_nodes, hidden_dim] - embeddings contextualisés
num_classes: int - nombre de classes

Returns:
logits: [batch, n_nodes, num_classes]
probs: [batch, n_nodes, num_classes]
predictions: [batch, n_nodes]
probs: [batch, n_nodes] - probabilités entre 0 et 1
predictions: [batch, n_nodes] - 0 ou 1
"""
logits = self.layers(x)[:, :, :num_classes]
probs = F.softmax(logits, dim=-1)
predictions = torch.argmax(logits, dim=-1)
logits = self.layers(x).squeeze(-1) # [batch, n_nodes]
probs = torch.sigmoid(logits) # Sigmoid → [0, 1]
predictions = (probs > 0.5).long() # Seuil → 0 ou 1

return {
'logits': logits,
Expand All @@ -58,47 +56,105 @@ def forward(self, x, num_classes=2):

def compute_loss(self, logits, targets, mask=None):
    """
    Symmetric BCE loss for label-symmetric problems (e.g. MaxCut).

    For each sample, the BCE is computed against both the target and its
    complement (1 - target), and the smaller of the two is kept, so the two
    equivalent labelings of a partition incur the same loss.

    Fix vs. previous version: the min is taken PER SAMPLE before averaging.
    Taking min(mean_direct, mean_inverse) over the whole batch is wrong when
    the batch mixes the two symmetric labelings — a sample labeled with the
    flipped convention would be heavily penalized even though its partition
    is identical.

    Args:
        logits: [batch, n_nodes] - raw scores (before sigmoid)
        targets: [batch, n_nodes] - values in {0, 1}
        mask: [batch, n_nodes] - optional validity mask for padded graphs
              (1 = real node, 0 = padding)

    Returns:
        loss: scalar tensor (mean over the batch of per-sample symmetric BCE)
    """
    targets = targets.float()

    # Element-wise BCE against the target and against its complement.
    loss_direct = F.binary_cross_entropy_with_logits(
        logits, targets, reduction='none'
    )
    loss_inverse = F.binary_cross_entropy_with_logits(
        logits, 1.0 - targets, reduction='none'
    )

    if mask is not None:
        mask = mask.float()
        # Per-sample masked mean: each graph is weighted equally,
        # regardless of how many real (non-padded) nodes it has.
        denom = mask.sum(dim=-1).clamp(min=1)
        loss_direct = (loss_direct * mask).sum(dim=-1) / denom
        loss_inverse = (loss_inverse * mask).sum(dim=-1) / denom
    else:
        loss_direct = loss_direct.mean(dim=-1)
        loss_inverse = loss_inverse.mean(dim=-1)

    # Per-sample symmetric min, then batch average.
    return torch.min(loss_direct, loss_inverse).mean()

def compute_similarity(self, predictions, targets):
    """
    Fraction of nodes on which predictions agree with targets,
    accounting for label symmetry.

    Agreement is measured against both the targets and their complement
    (1 - targets); the better of the two is returned, since for problems
    like MaxCut both labelings describe the same partition.

    Args:
        predictions: [batch, n_nodes] - values in {0, 1}
        targets: [batch, n_nodes] - values in {0, 1}

    Returns:
        similarity: float in [0, 1] (1 = perfect match up to symmetry)
    """
    preds = predictions.float()
    labels = targets.float()

    agree = torch.eq(preds, labels).float().mean()
    agree_flipped = torch.eq(preds, 1.0 - labels).float().mean()

    return torch.max(agree, agree_flipped).item()


if __name__ == "__main__":
print("=== Test Classifier ===")
print("=== Test Classifier (Binaire + Symmetric BCE) ===\n")

x = torch.randn(4, 6, 256) # [batch, n_nodes, hidden_dim]
classifier = Classifier(hidden_dim=256, max_classes=10)
classifier = Classifier(hidden_dim=256)

# Test 2 classes (MaxCut)
output = classifier(x, num_classes=2)
print(f"Logits (2 classes): {output['logits'].shape}")
# Forward
output = classifier(x)
print(f"Logits: {output['logits'].shape}")
print(f"Probs: {output['probs'].shape}")
print(f"Predictions: {output['predictions'].shape}")
print(f"Exemple probs: {output['probs'][0].tolist()}")
print(f"Exemple preds: {output['predictions'][0].tolist()}")

# Test 5 classes (Graph Coloring)
output = classifier(x, num_classes=5)
print(f"Logits (5 classes): {output['logits'].shape}")
# Test Symmetric Loss
targets = torch.tensor([[1, 0, 1, 0, 1, 0]] * 4).float()

# Test loss
targets = torch.randint(0, 2, (4, 6))
output = classifier(x, num_classes=2)
loss = classifier.compute_loss(output['logits'], targets)
print(f"Loss: {loss.item():.4f}")

print(f"Params: {sum(p.numel() for p in classifier.parameters()):,}")
print(f"\nSymmetric BCE Loss: {loss.item():.4f}")

# Test symétrie : target et 1-target doivent donner la même loss
loss_normal = classifier.compute_loss(output['logits'], targets)
loss_inverted = classifier.compute_loss(output['logits'], 1 - targets)
print(f"Loss (target normal): {loss_normal.item():.4f}")
print(f"Loss (target inversé): {loss_inverted.item():.4f}")
print(f"Égales ? {'✅ OUI' if abs(loss_normal.item() - loss_inverted.item()) < 1e-6 else '❌ NON'}")

# Test similarité
pred = torch.tensor([[0, 1, 0, 1, 0, 1]])
target = torch.tensor([[1, 0, 1, 0, 1, 0]])
sim = classifier.compute_similarity(pred, target)
print(f"\nSimilarité [0,1,0,1,0,1] vs [1,0,1,0,1,0]: {sim:.0%}")

print(f"\nParams: {sum(p.numel() for p in classifier.parameters()):,}")
print("✅ OK")
54 changes: 54 additions & 0 deletions prototype_v1/explication_bitstring_comparaison.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
Format QAOA (Qiskit)

result = solver.solve(problem)
print(result.x) # numpy array: [0, 1, 0, 1]
print(type(result.x)) # <class 'numpy.ndarray'>
Format Notre Modèle

output = model(x, edge_index, problem_id=0)
print(output['predictions']) # tensor([[0, 1, 0, 1]])
print(type(output['predictions'])) # <class 'torch.Tensor'>


Sont-ils comparables ?
|               | QAOA                     | Notre Modèle             |
|---------------|--------------------------|--------------------------|
| Type          | numpy.ndarray            | torch.Tensor             |
| Shape         | [n_nodes]                | [batch, n_nodes]         |
| Valeurs       | {0, 1}                   | {0, 1}                   |
| Signification | Nœud i dans set 0 ou 1   | Nœud i dans set 0 ou 1   |
OUI mais il faut convertir :


# QAOA → Tensor pour comparaison
qaoa_target = torch.tensor(result.x) # [0, 1, 0, 1]

# Notre modèle → squeeze pour enlever batch dim
model_pred = output['predictions'].squeeze(0) # [0, 1, 0, 1]

# Maintenant comparable !
⚠️ ATTENTION : Symétrie du problème !
Pour MaxCut, il y a une subtilité :


Solution [0, 1, 0, 1] = Set A: {0, 2}, Set B: {1, 3}
Solution [1, 0, 1, 0] = Set A: {1, 3}, Set B: {0, 2}

CE SONT LES MÊMES SOLUTIONS ! (juste inversées)
Donc si :


QAOA dit: [0, 1, 0, 1]
Modèle dit: [1, 0, 1, 0]

→ C'est CORRECT ! Même partition, juste labels inversés.
Solutions pour gérer la symétrie
Option 1 : Normaliser (forcer à commencer par 0)

def normalize_bitstring(bits):
if bits[0] == 1:
return 1 - bits # Inverser
return bits

target = normalize_bitstring(qaoa_result)
pred = normalize_bitstring(model_pred)
56 changes: 31 additions & 25 deletions prototype_v1/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Graph → GNN Encoder → E_local, E_global
problem_id → Lookup Table → E_prob
Concat [E_global || E_local || E_prob] → Transformer → embeddings contextualisés
Classifier → logits → Cross-Entropy Loss → Backpropagation
Classifier binaire → probs → Symmetric BCE Loss → Backpropagation
"""

import torch
Expand All @@ -20,11 +20,12 @@ class QuantumGraphModel(nn.Module):
"""
Modèle complet pour résoudre des problèmes d'optimisation sur graphes.

Supporte:
Supporte (tous binaires):
- MaxCut (2 classes)
- Vertex Cover (2 classes)
- Independent Set (2 classes)
- Graph Coloring (k classes)

Loss: Symmetric BCE (gère la symétrie des solutions)
"""

def __init__(
Expand All @@ -36,7 +37,6 @@ def __init__(
transformer_layers=4,
num_heads=8,
num_problems=10,
max_classes=10,
dropout=0.1
):
super().__init__()
Expand Down Expand Up @@ -67,24 +67,22 @@ def __init__(
dropout=dropout
)

# 4. Classifier
# 4. Classifier (binaire)
self.classifier = Classifier(
hidden_dim=hidden_dim,
max_classes=max_classes,
dropout=dropout
)

def forward(self, x, edge_index, problem_id, batch=None, num_classes=2):
def forward(self, x, edge_index, problem_id, batch=None):
"""
Args:
x: [n_nodes, node_input_dim]
edge_index: [2, n_edges]
problem_id: int ou [batch_size]
batch: [n_nodes] (optionnel)
num_classes: int

Returns:
dict avec logits, probs, predictions
dict avec logits, probs, predictions (tous [batch, n_nodes])
"""
# 1. GNN Encoder → E_local, E_global
e_local, e_global = self.encoder(x, edge_index, batch)
Expand All @@ -105,8 +103,8 @@ def forward(self, x, edge_index, problem_id, batch=None, num_classes=2):
# 3. Transformer → embeddings contextualisés
contextualized = self.transformer(e_local, e_global, e_prob)

# 4. Classifier → logits, probs, predictions
output = self.classifier(contextualized, num_classes=num_classes)
# 4. Classifier binaire → probs, predictions
output = self.classifier(contextualized)

return output

Expand All @@ -126,14 +124,19 @@ def _batch_node_embeddings(self, e_local, batch, batch_size):
return out

def compute_loss(self, logits, targets, mask=None):
"""Cross-Entropy Loss"""
"""Symmetric BCE Loss"""
return self.classifier.compute_loss(logits, targets, mask)

def forward_with_loss(self, x, edge_index, problem_id, targets, batch=None, num_classes=2):
def compute_similarity(self, predictions, targets):
"""Pourcentage de ressemblance (avec symétrie)"""
return self.classifier.compute_similarity(predictions, targets)

def forward_with_loss(self, x, edge_index, problem_id, targets, batch=None):
"""Forward + Loss en une seule passe"""
output = self.forward(x, edge_index, problem_id, batch, num_classes)
output = self.forward(x, edge_index, problem_id, batch)
loss = self.compute_loss(output['logits'], targets)
return output, loss
similarity = self.compute_similarity(output['predictions'], targets)
return output, loss, similarity


if __name__ == "__main__":
Expand All @@ -159,24 +162,27 @@ def forward_with_loss(self, x, edge_index, problem_id, targets, batch=None, num_
print(f"Paramètres: {sum(p.numel() for p in model.parameters()):,}")

# Forward (MaxCut)
output = model(x, edge_index, problem_id=0, num_classes=2)
print(f"\nMaxCut (2 classes):")
print(f" Logits: {output['logits'].shape}")
output = model(x, edge_index, problem_id=0)
print(f"\nMaxCut:")
print(f" Probs: {output['probs'].shape}")
print(f" Predictions: {output['predictions']}")

# Loss
# Symmetric Loss
targets = torch.tensor([[1, 0, 1, 0, 1, 0]])
loss = model.compute_loss(output['logits'], targets)
print(f" Loss: {loss.item():.4f}")

# Test symétrie
loss_inv = model.compute_loss(output['logits'], 1 - targets)
print(f" Loss inversée: {loss_inv.item():.4f}")
print(f" Symétrie OK ? {'✅' if abs(loss.item() - loss_inv.item()) < 1e-6 else '❌'}")

# Similarité
sim = model.compute_similarity(output['predictions'], targets)
print(f" Similarité: {sim:.0%}")

# Backprop
loss.backward()
print(" Backprop OK")

# Graph Coloring
output = model(x, edge_index, problem_id=3, num_classes=5)
print(f"\nGraph Coloring (5 classes):")
print(f" Logits: {output['logits'].shape}")
print(f" Predictions: {output['predictions']}")

print("\n✅ Tous les tests passés!")
Loading
Loading