| """ |
| Hybrid ASPP-Attention Architecture (Asterisk Model) |
| Combines Adjacency-Structured Parallel Propagation (ASPP) with standard attention mechanisms |
| to enhance model expressiveness while maintaining efficiency. |
| |
| Architecture Design: |
| - Hybrid layers: Standard attention + ASPP operator in parallel |
| - Gate mechanism for dynamic fusion |
| - Knowledge distillation from SmolLM2-135M base model |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel |
| from transformers.models.llama.modeling_llama import ( |
| LlamaAttention, |
| LlamaDecoderLayer, |
| LlamaRMSNorm, |
| LlamaMLP, |
| ) |
| from transformers import AutoConfig, AutoModelForCausalLM |
| from typing import Optional, Tuple, List |
|
|
|
|
class AsteriskConfig(LlamaConfig):
    """
    Configuration for the Asterisk model.

    Extends ``LlamaConfig`` with the ASPP-branch and pi-flow hyperparameters;
    the custom ``model_type`` lets the class be registered with the HF
    ``Auto*`` factories.
    """
    model_type = "asterisk"

    def __init__(
        self,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        aspp_num_neighbors: int = 1,
        pi_flow: bool = False,
        pi_flow_steps: int = 1,
        pi_flow_scale: float = 0.2,
        pi_flow_use_gate: bool = True,
        **kwargs
    ):
        # All remaining keyword arguments are standard Llama settings.
        super().__init__(**kwargs)

        # Record every Asterisk-specific knob on the config instance so it
        # round-trips through to_dict()/from_dict() serialization.
        extra_settings = {
            # ASPP branch
            "hybrid_layer_indices": hybrid_layer_indices,
            "aspp_hidden_dim": aspp_hidden_dim,
            "aspp_num_steps": aspp_num_steps,
            "aspp_dropout": aspp_dropout,
            "aspp_num_neighbors": aspp_num_neighbors,
            # pi-flow refinement
            "pi_flow": pi_flow,
            "pi_flow_steps": pi_flow_steps,
            "pi_flow_scale": pi_flow_scale,
            "pi_flow_use_gate": pi_flow_use_gate,
        }
        for attr_name, attr_value in extra_settings.items():
            setattr(self, attr_name, attr_value)
|
|
|
|
class ASPPOperator(nn.Module):
    """
    Asterisk Operator (ASPP) - Union-Find Graph Propagation.

    Each position i keeps a single parent pointer arranged as a linear chain
    (parent[i] = i - 1, with position 0 as its own root). At every evolution
    step, each position's state is updated from the pair
    (self features, parent features) through a small MLP, followed by a
    scaled residual connection and LayerNorm:

        h_i^(t+1) = LayerNorm(h_i^(t) + s * phi(h_i^(t), h_parent[i]^(t)))

    Complexity is O(n) per step: parent gathering is plain advanced indexing.

    Args:
        hidden_size: Dimension of the input/output hidden states.
        aspp_hidden_dim: Internal ASPP dimension; ``None`` means use
            ``hidden_size`` (no down/up projection).
        num_steps: Maximum number of evolution steps K (default: 2). The
            effective step count is modulated by the learnable ``k_logit``.
        dropout: Dropout rate used in the projections and message MLP.
        num_neighbors: Accepted for interface compatibility but always fixed
            at 1 — the Union-Find structure propagates from the parent only.
    """

    def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1, num_neighbors: int = 1):
        super().__init__()
        self.hidden_size = hidden_size
        self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
        self.num_steps = num_steps
        # One parent per node by construction; the num_neighbors argument is
        # intentionally ignored (kept for signature compatibility).
        self.num_neighbors = 1

        # Optional bottleneck projections when the internal dim differs.
        self.use_projection = (self.aspp_hidden_dim != hidden_size)
        if self.use_projection:
            self.down_proj = nn.Linear(hidden_size, self.aspp_hidden_dim)
            self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
            self.proj_dropout = nn.Dropout(dropout)

        # phi: maps concatenated (self, parent) features to an update message.
        self.message_net = nn.Sequential(
            nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim * 2),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
            nn.Dropout(dropout),
        )

        # Learnable logit controlling the effective number of steps:
        # k_steps = max(1, int(sigmoid(k_logit) * num_steps)).
        self.k_logit = nn.Parameter(torch.tensor(1.0))

        # Small residual scale keeps the operator near-identity at init.
        self.residual_scale = nn.Parameter(torch.tensor(0.1))

        self.norm = nn.LayerNorm(self.aspp_hidden_dim, eps=1e-5)

    def compute_parent_indices(self, seq_len: int, device) -> torch.Tensor:
        """
        Compute the parent index for each position (Union-Find linear chain).

        parent[i] = i - 1 for i > 0; position 0 is its own root. The clamp
        to [0, seq_len - 1] maps the -1 produced at position 0 back to 0,
        so no separate in-place fix-up is needed.

        Can be extended with dynamic union operations based on semantic
        similarity, positional heuristics, or learned grouping.

        Returns:
            [seq_len] long tensor of parent indices.
        """
        return torch.clamp(torch.arange(seq_len, device=device) - 1, 0, seq_len - 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: [batch_size, seq_len, hidden_size]
        Returns:
            evolved_states: [batch_size, seq_len, hidden_size]
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Project into the internal ASPP dimension if configured.
        if self.use_projection:
            h_t = self.down_proj(hidden_states)
            h_t = self.proj_dropout(h_t)
        else:
            h_t = hidden_states

        # Effective number of evolution steps (at least 1).
        k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))

        # The chain topology is static, so the parent indices are
        # loop-invariant: compute them once, not once per step.
        parent_indices = self.compute_parent_indices(seq_len, h_t.device)

        for _ in range(k_steps):
            # Gather each position's parent features: [B, S, D].
            parent_features = h_t[:, parent_indices, :]

            # Message from (self, parent) pair.
            message_input = torch.cat([h_t, parent_features], dim=-1)
            h_t_next = self.message_net(message_input)

            # Scaled residual update followed by normalization.
            h_t = h_t + self.residual_scale * h_t_next
            h_t = self.norm(h_t)

        # Project back to the model dimension if configured.
        if self.use_projection:
            h_t = self.up_proj(h_t)
            h_t = self.proj_dropout(h_t)

        return h_t
|
|
|
|
class HybridASPPAttentionLayer(LlamaDecoderLayer):
    """
    Hybrid layer combining an ASPP operator with standard attention.
    Inherits from LlamaDecoderLayer to stay checkpoint-compatible.

    Architecture:
        1. Parallel branches on the pre-norm hidden states:
           - ASPP operator for local structured reasoning
           - Standard Llama self-attention for global context
        2. Gated (convex) fusion of both branch outputs
        3. Optional per-layer pi-flow refinement steps
        4. Standard post-attention MLP block

    Args:
        config: Llama/Asterisk config; pi-flow attributes are read from it.
        layer_idx: Index of this layer in the decoder stack.
        aspp_hidden_dim: Internal ASPP dimension (None = model hidden size).
        aspp_num_steps: Evolution steps K for the ASPP operator.
        aspp_dropout: Dropout for the ASPP operator and fusion gate.
        aspp_num_neighbors: Forwarded to ASPPOperator (fixed at 1 there).
    """

    def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 1):
        super().__init__(config, layer_idx)

        # Branch 1: ASPP graph-propagation operator.
        self.aspp_operator = ASPPOperator(
            hidden_size=config.hidden_size,
            aspp_hidden_dim=aspp_hidden_dim,
            num_steps=aspp_num_steps,
            dropout=aspp_dropout,
            num_neighbors=aspp_num_neighbors
        )

        # Per-feature gate deciding the ASPP/attention mix.
        self.fusion_gate = nn.Sequential(
            nn.Linear(config.hidden_size * 2, config.hidden_size),
            nn.Dropout(aspp_dropout),
            nn.Sigmoid()
        )

        # Zero bias => sigmoid(0) = 0.5: start from an even 50/50 blend.
        with torch.no_grad():
            self.fusion_gate[0].bias.fill_(0.0)

        # Optional pi-flow refinement sub-module.
        if getattr(config, 'pi_flow', False):
            # Capture the step count at construction time. (The previous
            # implementation tried to look it up from self.config inside
            # forward(), which silently fell back to 1 when the decoder
            # layer did not expose a `config` attribute.)
            self.pi_flow_steps = getattr(config, 'pi_flow_steps', 1)

            self.pi_flow_aspp = ASPPOperator(
                hidden_size=config.hidden_size,
                aspp_hidden_dim=aspp_hidden_dim,
                num_steps=aspp_num_steps,
                dropout=aspp_dropout,
                num_neighbors=aspp_num_neighbors
            )

            # Learnable global step size for the flow updates.
            self.pi_flow_scale = nn.Parameter(
                torch.tensor(getattr(config, 'pi_flow_scale', 0.2))
            )

            # Optional token-wise gate modulating the step size per position.
            if getattr(config, 'pi_flow_use_gate', True):
                self.pi_flow_gate = nn.Sequential(
                    nn.Linear(config.hidden_size, config.hidden_size // 4),
                    nn.SiLU(),
                    nn.Dropout(aspp_dropout),
                    nn.Linear(config.hidden_size // 4, 1),
                    nn.Sigmoid()
                )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Override LlamaDecoderLayer.forward to add the ASPP branch and
        optional pi-flow refinement.

        Returns a single hidden-states tensor [batch, seq, hidden].
        NOTE(review): some transformers versions expect decoder layers to
        return a tuple — confirm against the pinned transformers release.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Branch 1: ASPP structured propagation on the pre-norm states.
        aspp_output = self.aspp_operator(hidden_states)

        # Branch 2: standard causal self-attention for global context.
        attn_output, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

        # Gated convex fusion: gate -> ASPP, (1 - gate) -> attention.
        fusion_input = torch.cat([aspp_output, attn_output], dim=-1)
        fusion_gate = self.fusion_gate(fusion_input)
        fused_output = fusion_gate * aspp_output + (1 - fusion_gate) * attn_output

        hidden_states = residual + fused_output

        # Optional pi-flow refinement: small (gated) Euler-style updates
        # h <- h + alpha * v(h), repeated pi_flow_steps times.
        if hasattr(self, 'pi_flow_aspp'):
            for _ in range(self.pi_flow_steps):
                v = self.pi_flow_aspp(hidden_states)

                if hasattr(self, 'pi_flow_gate'):
                    flow_gate = self.pi_flow_gate(hidden_states)
                    alpha = self.pi_flow_scale * flow_gate
                else:
                    alpha = self.pi_flow_scale

                hidden_states = hidden_states + alpha * v

        # Standard post-attention feed-forward block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
|
|
|
|
class AsteriskLlamaModel(LlamaModel):
    """
    Asterisk-Llama model with the hybrid ASPP-Attention architecture.

    When ``hybrid_layer_indices`` is None, every decoder layer is replaced
    with a HybridASPPAttentionLayer for maximum expressiveness.

    Args:
        config: Llama/Asterisk configuration.
        hybrid_layer_indices: Layer indices to replace (None = all layers).
        aspp_hidden_dim: Internal ASPP dimension (None = model hidden size).
        aspp_num_steps: Evolution steps K for each ASPP operator.
        aspp_dropout: Dropout for the ASPP branches.
        aspp_num_neighbors: Forwarded to ASPPOperator; defaults to 1 for
            consistency with AsteriskConfig and ASPPOperator (which fixes
            it at 1 regardless — the previous default of 2 was misleading).
    """

    def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 1):
        super().__init__(config)

        # Default: make every layer hybrid.
        if hybrid_layer_indices is None:
            num_layers = config.num_hidden_layers
            hybrid_layer_indices = list(range(num_layers))

        self.hybrid_layer_indices = hybrid_layer_indices

        # Swap the selected decoder layers for hybrid ones. Out-of-range
        # indices are silently skipped (best-effort behavior preserved).
        for idx in hybrid_layer_indices:
            if idx < len(self.layers):
                self.layers[idx] = HybridASPPAttentionLayer(
                    config,
                    layer_idx=idx,
                    aspp_hidden_dim=aspp_hidden_dim,
                    aspp_num_steps=aspp_num_steps,
                    aspp_dropout=aspp_dropout,
                    aspp_num_neighbors=aspp_num_neighbors
                )

        # Re-run weight init so the freshly created hybrid layers are
        # initialized consistently with the rest of the model.
        self.post_init()
|
|
|
|
class AsteriskForCausalLM(LlamaForCausalLM):
    """
    Asterisk Causal LM with the Hybrid ASPP-Attention architecture.

    Registered with the HF Auto* factories as model_type "asterisk".
    Explicit constructor arguments take precedence over values stored on
    the config; defaults follow AsteriskConfig (aspp_num_neighbors = 1).
    """

    config_class = AsteriskConfig

    def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 1):
        # Fall back to config-stored values for anything not passed
        # explicitly (keeps from_pretrained round-trips working).
        if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'):
            hybrid_layer_indices = config.hybrid_layer_indices
        if aspp_hidden_dim is None and hasattr(config, 'aspp_hidden_dim'):
            aspp_hidden_dim = config.aspp_hidden_dim
        if hasattr(config, 'aspp_num_steps'):
            aspp_num_steps = config.aspp_num_steps
        if hasattr(config, 'aspp_dropout'):
            aspp_dropout = config.aspp_dropout
        if hasattr(config, 'aspp_num_neighbors'):
            aspp_num_neighbors = config.aspp_num_neighbors

        super().__init__(config)

        # Replace the base LlamaModel with the hybrid variant.
        self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)

        # Persist the *resolved* indices (AsteriskLlamaModel expands None
        # to "all layers"), so a saved config reflects the actual
        # architecture instead of a possibly-None placeholder.
        self.config.hybrid_layer_indices = self.model.hybrid_layer_indices

        self.post_init()

    @classmethod
    def from_pretrained_base(
        cls,
        base_model_path: str,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        aspp_num_neighbors: int = 1,
        pi_flow: bool = False,
        pi_flow_steps: int = 1,
        pi_flow_scale: float = 0.2,
        pi_flow_use_gate: bool = True,
        **kwargs
    ):
        """
        Load a base Llama-family model and convert it to the Asterisk
        architecture, copying over all shared weights.

        Args:
            base_model_path: Path to base SmolLM2 model
            hybrid_layer_indices: Which layers to make hybrid (None for all)
            aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size)
            aspp_num_steps: Number of evolution steps K for ASPP (default: 2)
            aspp_dropout: Dropout rate for ASPP regularization (default: 0.1)
            aspp_num_neighbors: Number of neighbors for Union-Find (fixed at 1: only parent)
            pi_flow: Enable pi-flow refinement step (default: False)
            pi_flow_steps: Number of flow refinement steps (default: 1)
            pi_flow_scale: Initial flow scale parameter (default: 0.2)
            pi_flow_use_gate: Use token-wise adaptive gating (default: True)

        Returns:
            (asterisk_model, base_model) tuple; the base model is returned
            too so it can serve as a distillation teacher.
        """
        base_model = LlamaForCausalLM.from_pretrained(base_model_path, **kwargs)
        base_config = base_model.config

        # Build an Asterisk config on top of the base model's settings.
        asterisk_config = AsteriskConfig(
            **base_config.to_dict(),
            hybrid_layer_indices=hybrid_layer_indices,
            aspp_hidden_dim=aspp_hidden_dim,
            aspp_num_steps=aspp_num_steps,
            aspp_dropout=aspp_dropout,
            aspp_num_neighbors=aspp_num_neighbors,
            pi_flow=pi_flow,
            pi_flow_steps=pi_flow_steps,
            pi_flow_scale=pi_flow_scale,
            pi_flow_use_gate=pi_flow_use_gate,
        )

        asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)

        # strict=False: Asterisk-only modules (ASPP, gates) have no
        # counterpart in the base checkpoint and keep their fresh init.
        asterisk_model.load_state_dict(base_model.state_dict(), strict=False)

        print(f"✓ Converted base model to Asterisk architecture with Graph Propagation")
        print(f" Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
        aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
        print(f" ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}, neighbors={aspp_num_neighbors}")
        if pi_flow:
            print(f" π-flow enabled: steps={pi_flow_steps}, scale={pi_flow_scale}, gate={pi_flow_use_gate}")

        return asterisk_model, base_model
|
|
|
|
| |
# Register the Asterisk architecture with the HF Auto* factories so that
# AutoConfig / AutoModelForCausalLM resolve model_type == "asterisk".
AutoConfig.register("asterisk", AsteriskConfig)
AutoModelForCausalLM.register(AsteriskConfig, AsteriskForCausalLM)
|
|
|
|
def get_model_info(model):
    """Print model architecture information"""
    # Accumulate total and trainable counts in a single pass.
    n_total = 0
    n_trainable = 0
    for param in model.parameters():
        count = param.numel()
        n_total += count
        if param.requires_grad:
            n_trainable += count

    print(f" • Total parameters: {n_total:,}")
    print(f" • Trainable parameters: {n_trainable:,}")
    print(f" • Model size: {n_total * 4 / 1024**2:.2f} MB (fp32)")

    # Asterisk models additionally report their hybrid-layer layout.
    if isinstance(model, AsteriskForCausalLM):
        hybrid_indices = model.model.hybrid_layer_indices
        print(f" • Hybrid layer indices: {hybrid_indices}")
        print(f" • Number of hybrid layers: {len(hybrid_indices)}")
|
|