# Setup the repo

In [30]:
try:
    !pip uninstall -qy geometricvocab geofractal
except:
    pass

!pip install -q git+https://github.com/AbstractEyes/geofractal.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for geofractal (pyproject.toml) ... [?25l[?25hdone
  Building wheel for geometricvocab (pyproject.toml) ... [?25l[?25hdone


# test the factory

In [2]:
# Cell 2: Imports and Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Import our factory system
from geofractal.router.factory import (
    PrototypeBuilder,
    StreamSpec,
    HeadSpec,
    FusionSpec,
    get_prototype_registry,
)

print("Factory imports successful!")

Device: cuda
Factory imports successful!


In [3]:
# Cell 3: Dataset Setup (CIFAR-10 for quick validation)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # CLIP expects 224x224
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
val_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

100%|██████████| 170M/170M [00:13<00:00, 12.2MB/s]


Train: 50000, Val: 10000


In [4]:
# Cell 4: Direct Component Test (bypass factory for now)
from geofractal.router.head import HeadBuilder, HeadConfig, build_standard_head
from geofractal.router.fusion import ConcatFusion, FusionConfig
from transformers import CLIPModel, CLIPProcessor

# Load CLIP
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip.eval()
for p in clip.parameters():
    p.requires_grad = False

# Build head
head_config = HeadConfig(feature_dim=512, fingerprint_dim=64, num_anchors=16, num_routes=4)
head = build_standard_head(head_config).to(device)

# Simple classifier
classifier = nn.Sequential(
    nn.LayerNorm(512),
    nn.Dropout(0.1),
    nn.Linear(512, 512),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(512, 10),
).to(device)

print(f"Head params: {sum(p.numel() for p in head.parameters()):,}")
print(f"Classifier params: {sum(p.numel() for p in classifier.parameters()):,}")

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

Head params: 3,764,476
Classifier params: 268,810


In [5]:
# Cell 5: Forward test
@torch.no_grad()
def extract_clip_features(images):
    outputs = clip.vision_model(images)
    return outputs.last_hidden_state  # [B, 50, 768] for base

# Test forward
images, labels = next(iter(train_loader))
images = images.to(device)

features = extract_clip_features(images)
print(f"CLIP features: {features.shape}")

# Project to 512 if needed
proj = nn.Linear(768, 512).to(device) if features.shape[-1] != 512 else nn.Identity()
features = proj(features)
print(f"Projected: {features.shape}")

# Through head
head_out = head(features)
print(f"Head output: {head_out.shape}")

# Pool and classify
pooled = head_out[:, 0]  # CLS token
logits = classifier(pooled)
print(f"Logits: {logits.shape}")

model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

CLIP features: torch.Size([64, 50, 768])
Projected: torch.Size([64, 50, 512])
Head output: torch.Size([64, 50, 512])
Logits: torch.Size([64, 10])


# one model one stream classification

In [6]:
# Cell 6: Training with Direct Components
from transformers import CLIPModel

# Setup
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip.eval()
for p in clip.parameters():
    p.requires_grad = False

# Projection (CLIP base is 768, head expects 512)
proj = nn.Linear(768, 512).to(device)

# Head
from geofractal.router.head import build_standard_head, HeadConfig
head_config = HeadConfig(feature_dim=512, fingerprint_dim=64, num_anchors=16, num_routes=4)
head = build_standard_head(head_config).to(device)

# Classifier
classifier = nn.Sequential(
    nn.LayerNorm(512),
    nn.Dropout(0.1),
    nn.Linear(512, 512),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(512, 10),
).to(device)

# Combine trainable params
trainable_params = list(proj.parameters()) + list(head.parameters()) + list(classifier.parameters())
optimizer = torch.optim.AdamW(trainable_params, lr=3e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

print(f"Trainable: {sum(p.numel() for p in trainable_params):,}")

Trainable: 4,427,014


In [7]:
# Cell 7: Training Loop
def forward_pass(images):
    with torch.no_grad():
        clip_out = clip.vision_model(images).last_hidden_state
    features = proj(clip_out)
    head_out = head(features)
    pooled = head_out[:, 0]
    return classifier(pooled)

EPOCHS = 5
for epoch in range(EPOCHS):
    # Train
    head.train()
    classifier.train()
    proj.train()

    total_loss, correct, total = 0, 0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = forward_pass(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (logits.argmax(-1) == labels).sum().item()
        total += labels.size(0)
        pbar.set_postfix({'loss': f'{loss.item():.3f}', 'acc': f'{correct/total:.1%}'})

    # Eval
    head.eval()
    classifier.eval()
    proj.eval()

    val_correct, val_total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            logits = forward_pass(images)
            val_correct += (logits.argmax(-1) == labels).sum().item()
            val_total += labels.size(0)

    print(f"Epoch {epoch+1}: Train {correct/total:.1%}, Val {val_correct/val_total:.1%}")

Epoch 1:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 1: Train 91.7%, Val 93.6%


Epoch 2:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 2: Train 93.8%, Val 93.3%


Epoch 3:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 3: Train 94.3%, Val 93.3%


Epoch 4:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 4: Train 94.5%, Val 94.1%


Epoch 5:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 5: Train 95.3%, Val 94.2%


# two-stream two models, two heads, fusion

In [9]:
# Cell: Setup - Fresh two-stream training with fixed gradients
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import CLIPModel
from geofractal.router.head import build_standard_head, HeadConfig
from tqdm.auto import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Data
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

Device: cuda
Train: 50000, Val: 10000


In [10]:
# Cell: Build two-stream system with fixed heads
# Stream A: CLIP B/32
clip_b32 = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_b32.eval()
for p in clip_b32.parameters():
    p.requires_grad = False

# Stream B: CLIP B/16
clip_b16 = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
clip_b16.eval()
for p in clip_b16.parameters():
    p.requires_grad = False

# Projections
proj_a = nn.Linear(768, 512).to(device)
proj_b = nn.Linear(768, 512).to(device)

# Heads with fixed gradients
head_config = HeadConfig(feature_dim=512, fingerprint_dim=64, num_anchors=16, num_routes=4)
head_a = build_standard_head(head_config).to(device)
head_b = build_standard_head(head_config).to(device)

# Fusion
fusion = nn.Sequential(
    nn.Linear(512 * 2, 512),
    nn.LayerNorm(512),
    nn.GELU(),
    nn.Dropout(0.1),
).to(device)

# Classifier
classifier = nn.Sequential(
    nn.LayerNorm(512),
    nn.Dropout(0.1),
    nn.Linear(512, 512),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(512, 10),
).to(device)

# Individual classifiers for emergence measurement
classifier_a = nn.Linear(512, 10).to(device)
classifier_b = nn.Linear(512, 10).to(device)

# All trainable params
trainable_params = (
    list(proj_a.parameters()) + list(proj_b.parameters()) +
    list(head_a.parameters()) + list(head_b.parameters()) +
    list(fusion.parameters()) + list(classifier.parameters()) +
    list(classifier_a.parameters()) + list(classifier_b.parameters())
)

optimizer = torch.optim.AdamW(trainable_params, lr=3e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

print(f"Trainable params: {sum(p.numel() for p in trainable_params):,}")
print(f"Head A fingerprint: {head_a.fingerprint[:5].tolist()}")
print(f"Head B fingerprint: {head_b.fingerprint[:5].tolist()}")

Trainable params: 9,121,302
Head A fingerprint: [0.0017648048233240843, -0.008485701866447926, 0.0073287528939545155, -0.032939378172159195, -0.013251084834337234]
Head B fingerprint: [-0.011380697600543499, 0.01520864013582468, -0.013518509455025196, 0.0704391598701477, 0.0063894083723425865]


In [11]:
# Cell: Training loop
def forward_all(images):
    """Forward through both streams, return collective and individual outputs."""
    with torch.no_grad():
        feat_a = clip_b32.vision_model(images).last_hidden_state
        feat_b = clip_b16.vision_model(images).last_hidden_state

    feat_a = proj_a(feat_a)
    feat_b = proj_b(feat_b)

    out_a = head_a(feat_a)[:, 0]  # CLS
    out_b = head_b(feat_b)[:, 0]  # CLS

    fused = fusion(torch.cat([out_a, out_b], dim=-1))
    collective_logits = classifier(fused)

    # Individual predictions (for emergence measurement)
    ind_a_logits = classifier_a(out_a)
    ind_b_logits = classifier_b(out_b)

    return collective_logits, ind_a_logits, ind_b_logits, out_a, out_b

EPOCHS = 5
history = []

for epoch in range(EPOCHS):
    # Train
    head_a.train(); head_b.train(); fusion.train(); classifier.train()
    proj_a.train(); proj_b.train(); classifier_a.train(); classifier_b.train()

    total_loss, correct, total = 0, 0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        coll_logits, ind_a, ind_b, _, _ = forward_all(images)

        # Joint loss: collective + individual guidance
        loss = criterion(coll_logits, labels)
        loss += 0.1 * criterion(ind_a, labels)  # Light individual supervision
        loss += 0.1 * criterion(ind_b, labels)

        loss.backward()
        optimizer.step()

        correct += (coll_logits.argmax(-1) == labels).sum().item()
        total += labels.size(0)
        pbar.set_postfix({'loss': f'{loss.item():.3f}', 'acc': f'{correct/total:.1%}'})

    # Eval
    head_a.eval(); head_b.eval(); fusion.eval(); classifier.eval()
    classifier_a.eval(); classifier_b.eval()

    coll_correct, ind_a_correct, ind_b_correct, val_total = 0, 0, 0, 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            coll_logits, ind_a, ind_b, _, _ = forward_all(images)

            coll_correct += (coll_logits.argmax(-1) == labels).sum().item()
            ind_a_correct += (ind_a.argmax(-1) == labels).sum().item()
            ind_b_correct += (ind_b.argmax(-1) == labels).sum().item()
            val_total += labels.size(0)

    coll_acc = coll_correct / val_total
    ind_a_acc = ind_a_correct / val_total
    ind_b_acc = ind_b_correct / val_total
    max_ind = max(ind_a_acc, ind_b_acc)
    rho = coll_acc / max_ind if max_ind > 0 else 0

    history.append({
        'epoch': epoch + 1,
        'collective': coll_acc,
        'ind_a': ind_a_acc,
        'ind_b': ind_b_acc,
        'rho': rho,
    })

    print(f"Epoch {epoch+1}: Collective {coll_acc:.1%}, A {ind_a_acc:.1%}, B {ind_b_acc:.1%}, ρ={rho:.3f}")

Epoch 1:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 1: Collective 94.5%, A 91.7%, B 94.2%, ρ=1.004


Epoch 2:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 2: Collective 95.4%, A 93.2%, B 94.5%, ρ=1.010


Epoch 3:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 3: Collective 95.5%, A 91.9%, B 94.6%, ρ=1.010


Epoch 4:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 4: Collective 94.3%, A 91.5%, B 92.6%, ρ=1.019


Epoch 5:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 5: Collective 95.7%, A 93.7%, B 94.7%, ρ=1.011


In [12]:
# Cell: Summary comparison
print("\n=== Training Summary ===\n")
print("| Epoch | Collective | Stream A | Stream B | ρ |")
print("|-------|------------|----------|----------|--------|")
for h in history:
    print(f"| {h['epoch']} | {h['collective']:.1%} | {h['ind_a']:.1%} | {h['ind_b']:.1%} | {h['rho']:.3f} |")

print(f"\nPeak ρ: {max(h['rho'] for h in history):.3f}")
print(f"Final collective: {history[-1]['collective']:.1%}")


=== Training Summary ===

| Epoch | Collective | Stream A | Stream B | ρ |
|-------|------------|----------|----------|--------|
| 1 | 94.5% | 91.7% | 94.2% | 1.004 |
| 2 | 95.4% | 93.2% | 94.5% | 1.010 |
| 3 | 95.5% | 91.9% | 94.6% | 1.010 |
| 4 | 94.3% | 91.5% | 92.6% | 1.019 |
| 5 | 95.7% | 93.7% | 94.7% | 1.011 |

Peak ρ: 1.019
Final collective: 95.7%


# entailment dual stream t5-base 2heads different opinions

In [16]:
# Cell: Load MNLI dataset
from datasets import load_dataset

# MNLI: premise, hypothesis → entailment/neutral/contradiction (0/1/2)
dataset = load_dataset("glue", "mnli")

print(f"Train: {len(dataset['train'])}")
print(f"Val (matched): {len(dataset['validation_matched'])}")
print(f"Labels: 0=entailment, 1=neutral, 2=contradiction")

# Sample
sample = dataset['train'][0]
print(f"\nPremise: {sample['premise']}")
print(f"Hypothesis: {sample['hypothesis']}")
print(f"Label: {sample['label']}")

README.md: 0.00B [00:00, ?B/s]

mnli/train-00000-of-00001.parquet:   0%|          | 0.00/52.2M [00:00<?, ?B/s]

mnli/validation_matched-00000-of-00001.p(…):   0%|          | 0.00/1.21M [00:00<?, ?B/s]

mnli/validation_mismatched-00000-of-0000(…):   0%|          | 0.00/1.25M [00:00<?, ?B/s]

mnli/test_matched-00000-of-00001.parquet:   0%|          | 0.00/1.22M [00:00<?, ?B/s]

mnli/test_mismatched-00000-of-00001.parq(…):   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/392702 [00:00<?, ? examples/s]

Generating validation_matched split:   0%|          | 0/9815 [00:00<?, ? examples/s]

Generating validation_mismatched split:   0%|          | 0/9832 [00:00<?, ? examples/s]

Generating test_matched split:   0%|          | 0/9796 [00:00<?, ? examples/s]

Generating test_mismatched split:   0%|          | 0/9847 [00:00<?, ? examples/s]

Train: 392702
Val (matched): 9815
Labels: 0=entailment, 1=neutral, 2=contradiction

Premise: Conceptually cream skimming has two basic dimensions - product and geography.
Hypothesis: Product and geography are what make cream skimming work. 
Label: 1


In [17]:
# Cell: Build T5 dual-head entailment system
from transformers import T5Tokenizer, T5EncoderModel
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from geofractal.router.head import build_standard_head, HeadConfig
from tqdm.auto import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# T5-base encoder
tokenizer = T5Tokenizer.from_pretrained("t5-base")
t5 = T5EncoderModel.from_pretrained("t5-base").to(device)
t5.eval()
for p in t5.parameters():
    p.requires_grad = False

# Two heads with different task framings
head_config = HeadConfig(feature_dim=768, fingerprint_dim=64, num_anchors=16, num_routes=4)
head_summarize = build_standard_head(head_config).to(device)  # "What's the gist?"
head_raw = build_standard_head(head_config).to(device)        # "What's literally there?"

# Fusion
fusion = nn.Sequential(
    nn.Linear(768 * 2, 768),
    nn.LayerNorm(768),
    nn.GELU(),
    nn.Dropout(0.1),
).to(device)

# Classifier for 3-way entailment
classifier = nn.Sequential(
    nn.Linear(768, 256),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(256, 3),
).to(device)

# Individual classifiers for emergence tracking
classifier_summarize = nn.Linear(768, 3).to(device)
classifier_raw = nn.Linear(768, 3).to(device)

trainable_params = (
    list(head_summarize.parameters()) + list(head_raw.parameters()) +
    list(fusion.parameters()) + list(classifier.parameters()) +
    list(classifier_summarize.parameters()) + list(classifier_raw.parameters())
)

optimizer = torch.optim.AdamW(trainable_params, lr=2e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

print(f"Trainable params: {sum(p.numel() for p in trainable_params):,}")

Trainable params: 18,171,649


In [18]:
# Cell: Data encoding
def encode_entailment_batch(batch, max_length=128):
    """
    Encode premise-hypothesis pairs with dual framing.

    Stream A (summarize): "summarize: premise hypothesis"
    Stream B (raw): "premise hypothesis"
    """
    # Combine premise + hypothesis
    combined = [f"{p} </s> {h}" for p, h in zip(batch['premise'], batch['hypothesis'])]

    # Stream A: summarization framing
    texts_summarize = [f"summarize: {t}" for t in combined]
    enc_a = tokenizer(texts_summarize, return_tensors="pt",
                      padding=True, truncation=True, max_length=max_length)

    # Stream B: raw
    enc_b = tokenizer(combined, return_tensors="pt",
                      padding=True, truncation=True, max_length=max_length)

    return enc_a, enc_b, torch.tensor(batch['label'])

def forward_batch(enc_a, enc_b):
    """Forward through both streams and fuse."""
    with torch.no_grad():
        hidden_a = t5(enc_a.input_ids.to(device),
                      attention_mask=enc_a.attention_mask.to(device)).last_hidden_state
        hidden_b = t5(enc_b.input_ids.to(device),
                      attention_mask=enc_b.attention_mask.to(device)).last_hidden_state

    # Through heads
    out_a = head_summarize(hidden_a)
    out_b = head_raw(hidden_b)

    # Mean pooling
    mask_a = enc_a.attention_mask.to(device).unsqueeze(-1)
    mask_b = enc_b.attention_mask.to(device).unsqueeze(-1)

    pooled_a = (out_a * mask_a).sum(1) / mask_a.sum(1)
    pooled_b = (out_b * mask_b).sum(1) / mask_b.sum(1)

    # Fuse
    fused = fusion(torch.cat([pooled_a, pooled_b], dim=-1))

    return fused, pooled_a, pooled_b

In [19]:
# Cell: Training loop
from torch.utils.data import DataLoader

# Subset for faster iteration (MNLI is huge)
train_subset = dataset['train'].shuffle(seed=42).select(range(50000))
val_subset = dataset['validation_matched'].shuffle(seed=42).select(range(5000))

BATCH_SIZE = 32
EPOCHS = 3

def collate_fn(examples):
    return {
        'premise': [ex['premise'] for ex in examples],
        'hypothesis': [ex['hypothesis'] for ex in examples],
        'label': [ex['label'] for ex in examples],
    }

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

history = []

for epoch in range(EPOCHS):
    # Train
    head_summarize.train(); head_raw.train(); fusion.train(); classifier.train()
    classifier_summarize.train(); classifier_raw.train()

    total_loss, correct, total = 0, 0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for batch in pbar:
        enc_a, enc_b, labels = encode_entailment_batch(batch)
        labels = labels.to(device)

        optimizer.zero_grad()

        fused, pooled_a, pooled_b = forward_batch(enc_a, enc_b)

        # Collective prediction
        logits = classifier(fused)
        loss = criterion(logits, labels)

        # Individual supervision (light)
        loss += 0.1 * criterion(classifier_summarize(pooled_a), labels)
        loss += 0.1 * criterion(classifier_raw(pooled_b), labels)

        loss.backward()
        optimizer.step()

        correct += (logits.argmax(-1) == labels).sum().item()
        total += labels.size(0)
        pbar.set_postfix({'loss': f'{loss.item():.3f}', 'acc': f'{correct/total:.1%}'})

    # Eval
    head_summarize.eval(); head_raw.eval(); fusion.eval(); classifier.eval()
    classifier_summarize.eval(); classifier_raw.eval()

    coll_correct, sum_correct, raw_correct, val_total = 0, 0, 0, 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            enc_a, enc_b, labels = encode_entailment_batch(batch)
            labels = labels.to(device)

            fused, pooled_a, pooled_b = forward_batch(enc_a, enc_b)

            coll_correct += (classifier(fused).argmax(-1) == labels).sum().item()
            sum_correct += (classifier_summarize(pooled_a).argmax(-1) == labels).sum().item()
            raw_correct += (classifier_raw(pooled_b).argmax(-1) == labels).sum().item()
            val_total += labels.size(0)

    coll_acc = coll_correct / val_total
    sum_acc = sum_correct / val_total
    raw_acc = raw_correct / val_total
    max_ind = max(sum_acc, raw_acc)
    rho = coll_acc / max_ind if max_ind > 0 else 0

    history.append({
        'epoch': epoch + 1,
        'collective': coll_acc,
        'summarize': sum_acc,
        'raw': raw_acc,
        'rho': rho,
    })

    print(f"Epoch {epoch+1}: Collective {coll_acc:.1%}, Summarize {sum_acc:.1%}, Raw {raw_acc:.1%}, ρ={rho:.3f}")

Epoch 1:   0%|          | 0/1563 [00:00<?, ?it/s]



Validating:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 1: Collective 74.5%, Summarize 71.7%, Raw 71.6%, ρ=1.040


Epoch 2:   0%|          | 0/1563 [00:00<?, ?it/s]

Validating:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 2: Collective 76.5%, Summarize 74.6%, Raw 73.8%, ρ=1.025


Epoch 3:   0%|          | 0/1563 [00:00<?, ?it/s]

Validating:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 3: Collective 77.5%, Summarize 74.6%, Raw 75.2%, ρ=1.031


In [20]:
# Cell: Summary
print("\n=== T5 Entailment Results ===\n")
print("| Epoch | Collective | Summarize | Raw | ρ |")
print("|-------|------------|-----------|-----|-------|")
for h in history:
    print(f"| {h['epoch']} | {h['collective']:.1%} | {h['summarize']:.1%} | {h['raw']:.1%} | {h['rho']:.3f} |")

print(f"\nTask framing divergence:")
print(f"  Head A: 'summarize: premise </s> hypothesis' → What's essential?")
print(f"  Head B: 'premise </s> hypothesis' → What's there?")


=== T5 Entailment Results ===

| Epoch | Collective | Summarize | Raw | ρ |
|-------|------------|-----------|-----|-------|
| 1 | 74.5% | 71.7% | 71.6% | 1.040 |
| 2 | 76.5% | 74.6% | 73.8% | 1.025 |
| 3 | 77.5% | 74.6% | 75.2% | 1.031 |

Task framing divergence:
  Head A: 'summarize: premise </s> hypothesis' → What's essential?
  Head B: 'premise </s> hypothesis' → What's there?


# entailment dual stream t5-base + bert

In [21]:
# Cell: T5 + BERT Dual Architecture System
from transformers import T5Tokenizer, T5EncoderModel, BertTokenizer, BertModel
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from geofractal.router.head import build_standard_head, HeadConfig
from tqdm.auto import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ============ FROZEN BACKBONES ============

# T5-base encoder
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
t5 = T5EncoderModel.from_pretrained("t5-base").to(device)
t5.eval()
for p in t5.parameters():
    p.requires_grad = False

# BERT-base encoder
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased").to(device)
bert.eval()
for p in bert.parameters():
    p.requires_grad = False

print(f"T5 hidden: {t5.config.d_model}, BERT hidden: {bert.config.hidden_size}")
print(f"T5 params: {sum(p.numel() for p in t5.parameters()):,} (frozen)")
print(f"BERT params: {sum(p.numel() for p in bert.parameters()):,} (frozen)")

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

T5 hidden: 768, BERT hidden: 768
T5 params: 109,628,544 (frozen)
BERT params: 109,482,240 (frozen)


In [22]:
# Cell: Build routing heads and fusion
# Two heads - different architectures see differently
head_config = HeadConfig(feature_dim=768, fingerprint_dim=64, num_anchors=16, num_routes=4)
head_t5 = build_standard_head(head_config).to(device)    # T5's perspective
head_bert = build_standard_head(head_config).to(device)  # BERT's perspective

# Fusion
fusion = nn.Sequential(
    nn.Linear(768 * 2, 768),
    nn.LayerNorm(768),
    nn.GELU(),
    nn.Dropout(0.1),
).to(device)

# Classifier for 3-way entailment
classifier = nn.Sequential(
    nn.Linear(768, 256),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(256, 3),
).to(device)

# Individual classifiers for emergence tracking
classifier_t5 = nn.Linear(768, 3).to(device)
classifier_bert = nn.Linear(768, 3).to(device)

trainable_params = (
    list(head_t5.parameters()) + list(head_bert.parameters()) +
    list(fusion.parameters()) + list(classifier.parameters()) +
    list(classifier_t5.parameters()) + list(classifier_bert.parameters())
)

optimizer = torch.optim.AdamW(trainable_params, lr=2e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

print(f"\nTrainable routing params: {sum(p.numel() for p in trainable_params):,}")
print(f"Frozen backbone params: {sum(p.numel() for p in t5.parameters()) + sum(p.numel() for p in bert.parameters()):,}")


Trainable routing params: 18,171,649
Frozen backbone params: 219,110,784


In [23]:
# Cell: Dual-architecture encoding
def encode_t5(premises, hypotheses, max_length=128):
    """T5 style: 'premise </s> hypothesis'"""
    texts = [f"{p} </s> {h}" for p, h in zip(premises, hypotheses)]
    enc = t5_tokenizer(texts, return_tensors="pt", padding=True,
                       truncation=True, max_length=max_length)
    return enc

def encode_bert(premises, hypotheses, max_length=128):
    """BERT style: '[CLS] premise [SEP] hypothesis [SEP]'"""
    enc = bert_tokenizer(premises, hypotheses, return_tensors="pt",
                         padding=True, truncation=True, max_length=max_length)
    return enc

def forward_dual_arch(premises, hypotheses):
    """Forward through both architectures."""
    # Encode
    enc_t5 = encode_t5(premises, hypotheses)
    enc_bert = encode_bert(premises, hypotheses)

    # Through frozen backbones
    with torch.no_grad():
        hidden_t5 = t5(
            enc_t5.input_ids.to(device),
            attention_mask=enc_t5.attention_mask.to(device)
        ).last_hidden_state

        hidden_bert = bert(
            enc_bert.input_ids.to(device),
            attention_mask=enc_bert.attention_mask.to(device)
        ).last_hidden_state

    # Through routing heads
    out_t5 = head_t5(hidden_t5)
    out_bert = head_bert(hidden_bert)

    # Pool (mean over sequence)
    mask_t5 = enc_t5.attention_mask.to(device).unsqueeze(-1)
    mask_bert = enc_bert.attention_mask.to(device).unsqueeze(-1)

    pooled_t5 = (out_t5 * mask_t5).sum(1) / mask_t5.sum(1)
    pooled_bert = (out_bert * mask_bert).sum(1) / mask_bert.sum(1)

    # Fuse
    fused = fusion(torch.cat([pooled_t5, pooled_bert], dim=-1))

    return fused, pooled_t5, pooled_bert

# Test forward
test_p = ["A man is playing guitar.", "The cat sat on the mat."]
test_h = ["Someone is making music.", "An animal is on furniture."]
fused, pt5, pbert = forward_dual_arch(test_p, test_h)
print(f"T5 pooled: {pt5.shape}, BERT pooled: {pbert.shape}, Fused: {fused.shape}")

T5 pooled: torch.Size([2, 768]), BERT pooled: torch.Size([2, 768]), Fused: torch.Size([2, 768])


In [24]:
# Cell: Load data and train
dataset = load_dataset("glue", "mnli")
train_subset = dataset['train'].shuffle(seed=42).select(range(50000))
val_subset = dataset['validation_matched'].shuffle(seed=42).select(range(5000))

BATCH_SIZE = 32
EPOCHS = 3

def collate_fn(examples):
    return {
        'premise': [ex['premise'] for ex in examples],
        'hypothesis': [ex['hypothesis'] for ex in examples],
        'label': [ex['label'] for ex in examples],
    }

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

history = []

for epoch in range(EPOCHS):
    # Train
    head_t5.train(); head_bert.train(); fusion.train(); classifier.train()
    classifier_t5.train(); classifier_bert.train()

    total_loss, correct, total = 0, 0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for batch in pbar:
        labels = torch.tensor(batch['label']).to(device)

        optimizer.zero_grad()

        fused, pooled_t5, pooled_bert = forward_dual_arch(batch['premise'], batch['hypothesis'])

        # Collective
        logits = classifier(fused)
        loss = criterion(logits, labels)

        # Individual supervision
        loss += 0.1 * criterion(classifier_t5(pooled_t5), labels)
        loss += 0.1 * criterion(classifier_bert(pooled_bert), labels)

        loss.backward()
        optimizer.step()

        correct += (logits.argmax(-1) == labels).sum().item()
        total += labels.size(0)
        pbar.set_postfix({'loss': f'{loss.item():.3f}', 'acc': f'{correct/total:.1%}'})

    # Eval
    head_t5.eval(); head_bert.eval(); fusion.eval(); classifier.eval()
    classifier_t5.eval(); classifier_bert.eval()

    coll_correct, t5_correct, bert_correct, val_total = 0, 0, 0, 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            labels = torch.tensor(batch['label']).to(device)

            fused, pooled_t5, pooled_bert = forward_dual_arch(batch['premise'], batch['hypothesis'])

            coll_correct += (classifier(fused).argmax(-1) == labels).sum().item()
            t5_correct += (classifier_t5(pooled_t5).argmax(-1) == labels).sum().item()
            bert_correct += (classifier_bert(pooled_bert).argmax(-1) == labels).sum().item()
            val_total += labels.size(0)

    coll_acc = coll_correct / val_total
    t5_acc = t5_correct / val_total
    bert_acc = bert_correct / val_total
    max_ind = max(t5_acc, bert_acc)
    rho = coll_acc / max_ind if max_ind > 0 else 0

    history.append({
        'epoch': epoch + 1,
        'collective': coll_acc,
        't5': t5_acc,
        'bert': bert_acc,
        'rho': rho,
    })

    print(f"Epoch {epoch+1}: Collective {coll_acc:.1%}, T5 {t5_acc:.1%}, BERT {bert_acc:.1%}, ρ={rho:.3f}")

Epoch 1:   0%|          | 0/1563 [00:00<?, ?it/s]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Validating:   0%|          | 0/157 [00:00<?, ?it/s]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Epoch 1: Collective 75.8%, T5 74.8%, BERT 63.7%, ρ=1.013


Epoch 2:   0%|          | 0/1563 [00:00<?, ?it/s]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Validating:   0%|          | 0/157 [00:00<?, ?it/s]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Epoch 2: Collective 77.5%, T5 76.7%, BERT 65.2%, ρ=1.010


Epoch 3:   0%|          | 0/1563 [00:00<?, ?it/s]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Validating:   0%|          | 0/157 [00:00<?, ?it/s]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Epoch 3: Collective 76.8%, T5 76.0%, BERT 65.4%, ρ=1.011


In [25]:
# Cell: Summary
print("\n=== T5 + BERT Cross-Architecture Results ===\n")
print("| Epoch | Collective | T5 | BERT | ρ |")
print("|-------|------------|-----|------|-------|")
for h in history:
    print(f"| {h['epoch']} | {h['collective']:.1%} | {h['t5']:.1%} | {h['bert']:.1%} | {h['rho']:.3f} |")

print(f"\nArchitectural divergence:")
print(f"  T5:   Encoder-decoder, SentencePiece, span corruption pretraining")
print(f"  BERT: Encoder-only, WordPiece, MLM+NSP pretraining")
print(f"  Same hidden dim (768), completely different representations")


=== T5 + BERT Cross-Architecture Results ===

| Epoch | Collective | T5 | BERT | ρ |
|-------|------------|-----|------|-------|
| 1 | 75.8% | 74.8% | 63.7% | 1.013 |
| 2 | 77.5% | 76.7% | 65.2% | 1.010 |
| 3 | 76.8% | 76.0% | 65.4% | 1.011 |

Architectural divergence:
  T5:   Encoder-decoder, SentencePiece, span corruption pretraining
  BERT: Encoder-only, WordPiece, MLM+NSP pretraining
  Same hidden dim (768), completely different representations


# maxx-vit dino -> supervised and unsupervised dual-stream dual head

In [28]:
!pip install -q git+https://github.com/mlfoundations/open_clip.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for open_clip_torch (pyproject.toml) ... [?25l[?25hdone


In [30]:
# Cell: Load MaxViT-CLIP and DINO
import torch
import torch.nn as nn
import open_clip
from transformers import Dinov2Model, AutoImageProcessor
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from geofractal.router.head import build_standard_head, HeadConfig
from tqdm.auto import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MaxViT CLIP (open_clip)
maxvit_clip, _, preprocess_clip = open_clip.create_model_and_transforms(
    'ViT-SO400M-14-SigLIP-384',  # Strong SigLIP variant
    pretrained='webli'
)
maxvit_clip = maxvit_clip.visual.to(device)
maxvit_clip.eval()
for p in maxvit_clip.parameters():
    p.requires_grad = False

# DINOv2
dino = Dinov2Model.from_pretrained("facebook/dinov2-base").to(device)
dino.eval()
for p in dino.parameters():
    p.requires_grad = False

# Check dimensions
print(f"MaxViT-CLIP output: {maxvit_clip.output_dim if hasattr(maxvit_clip, 'output_dim') else 'check'}")
print(f"DINOv2 hidden: {dino.config.hidden_size}")

open_clip_model.safetensors:   0%|          | 0.00/3.51G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/548 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

MaxViT-CLIP output: check
DINOv2 hidden: 768


In [31]:
# Cell: Test forward passes and get dimensions
from PIL import Image
import torchvision.transforms as T

# Test image
test_transform = T.Compose([
    T.Resize((384, 384)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Dummy input
x = torch.randn(2, 3, 384, 384).to(device)

with torch.no_grad():
    # MaxViT CLIP - get intermediate features if possible
    clip_out = maxvit_clip(x)
    print(f"CLIP output shape: {clip_out.shape}")

    # DINOv2 - get patch tokens
    dino_out = dino(x).last_hidden_state
    print(f"DINO output shape: {dino_out.shape}")

CLIP output shape: torch.Size([2, 1152])
DINO output shape: torch.Size([2, 730, 768])


In [32]:
# Cell: Build projections and routing heads
# Get actual dims from test above
CLIP_DIM = clip_out.shape[-1] if len(clip_out.shape) > 2 else clip_out.shape[-1]
DINO_DIM = dino_out.shape[-1]
ROUTE_DIM = 512

print(f"CLIP dim: {CLIP_DIM}, DINO dim: {DINO_DIM}")

# Projections to common space
proj_clip = nn.Linear(CLIP_DIM, ROUTE_DIM).to(device)
proj_dino = nn.Linear(DINO_DIM, ROUTE_DIM).to(device)

# Routing heads
head_config = HeadConfig(feature_dim=ROUTE_DIM, fingerprint_dim=64, num_anchors=16, num_routes=4)
head_clip = build_standard_head(head_config).to(device)  # "What is this called?"
head_dino = build_standard_head(head_config).to(device)  # "What does this look like?"

# Fusion
fusion = nn.Sequential(
    nn.Linear(ROUTE_DIM * 2, ROUTE_DIM),
    nn.LayerNorm(ROUTE_DIM),
    nn.GELU(),
    nn.Dropout(0.1),
).to(device)

# Classifier (CIFAR-10)
classifier = nn.Sequential(
    nn.Linear(ROUTE_DIM, 256),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(256, 10),
).to(device)

# Individual classifiers
classifier_clip = nn.Linear(ROUTE_DIM, 10).to(device)
classifier_dino = nn.Linear(ROUTE_DIM, 10).to(device)

trainable = (
    list(proj_clip.parameters()) + list(proj_dino.parameters()) +
    list(head_clip.parameters()) + list(head_dino.parameters()) +
    list(fusion.parameters()) + list(classifier.parameters()) +
    list(classifier_clip.parameters()) + list(classifier_dino.parameters())
)

optimizer = torch.optim.AdamW(trainable, lr=3e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

print(f"Trainable params: {sum(p.numel() for p in trainable):,}")

CLIP dim: 1152, DINO dim: 768
Trainable params: 9,182,998


In [33]:
# Cell: Data loaders
transform = T.Compose([
    T.Resize((384, 384)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

Train: 50000, Val: 10000


In [34]:
# Cell: Forward function
def forward_all(images):
    with torch.no_grad():
        # CLIP - may be [B, D] or [B, S, D]
        clip_feat = maxvit_clip(images)
        if len(clip_feat.shape) == 2:
            clip_feat = clip_feat.unsqueeze(1)  # [B, 1, D]

        # DINO - [B, S, D] with CLS + patches
        dino_feat = dino(images).last_hidden_state

    # Project
    clip_proj = proj_clip(clip_feat)  # [B, S, 512]
    dino_proj = proj_dino(dino_feat)  # [B, S, 512]

    # Route
    clip_routed = head_clip(clip_proj)
    dino_routed = head_dino(dino_proj)

    # Pool (CLS token or mean)
    clip_pooled = clip_routed[:, 0] if clip_routed.shape[1] > 1 else clip_routed.squeeze(1)
    dino_pooled = dino_routed[:, 0]  # CLS token

    # Fuse
    fused = fusion(torch.cat([clip_pooled, dino_pooled], dim=-1))

    # Classify
    logits = classifier(fused)
    logits_clip = classifier_clip(clip_pooled)
    logits_dino = classifier_dino(dino_pooled)

    return logits, logits_clip, logits_dino

In [None]:
# Cell: Train
EPOCHS = 5
history = []

for epoch in range(EPOCHS):
    # Train
    head_clip.train(); head_dino.train(); fusion.train(); classifier.train()
    proj_clip.train(); proj_dino.train()
    classifier_clip.train(); classifier_dino.train()

    correct, total = 0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits, logits_clip, logits_dino = forward_all(images)

        loss = criterion(logits, labels)
        loss += 0.1 * criterion(logits_clip, labels)
        loss += 0.1 * criterion(logits_dino, labels)

        loss.backward()
        optimizer.step()

        correct += (logits.argmax(-1) == labels).sum().item()
        total += labels.size(0)
        pbar.set_postfix({'acc': f'{correct/total:.1%}'})

    # Eval
    head_clip.eval(); head_dino.eval(); fusion.eval(); classifier.eval()
    classifier_clip.eval(); classifier_dino.eval()

    coll_correct, clip_correct, dino_correct, val_total = 0, 0, 0, 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            logits, logits_clip, logits_dino = forward_all(images)

            coll_correct += (logits.argmax(-1) == labels).sum().item()
            clip_correct += (logits_clip.argmax(-1) == labels).sum().item()
            dino_correct += (logits_dino.argmax(-1) == labels).sum().item()
            val_total += labels.size(0)

    coll_acc = coll_correct / val_total
    clip_acc = clip_correct / val_total
    dino_acc = dino_correct / val_total
    rho = coll_acc / max(clip_acc, dino_acc)

    history.append({'epoch': epoch+1, 'collective': coll_acc, 'clip': clip_acc, 'dino': dino_acc, 'rho': rho})
    print(f"Epoch {epoch+1}: Collective {coll_acc:.1%}, CLIP {clip_acc:.1%}, DINO {dino_acc:.1%}, ρ={rho:.3f}")

Epoch 1:   0%|          | 0/1563 [00:00<?, ?it/s]

In [None]:
# Cell: Summary
print("\n=== CLIP + DINO Cross-Architecture Results ===\n")
print("| Epoch | Collective | CLIP | DINO | ρ |")
print("|-------|------------|------|------|-------|")
for h in history:
    print(f"| {h['epoch']} | {h['collective']:.1%} | {h['clip']:.1%} | {h['dino']:.1%} | {h['rho']:.3f} |")

print(f"\nPhilosophical divergence:")
print(f"  CLIP: 'What is this called?' (text-image contrastive)")
print(f"  DINO: 'What does this look like?' (self-supervised structure)")

# after

#