Zeta is a modular PyTorch framework designed to simplify the development of AI models by providing reusable, high-performance building blocks. Think of it as a collection of LEGO blocks for AI: each component is carefully crafted, tested, and optimized, allowing you to quickly assemble state-of-the-art models without reinventing the wheel.
Zeta provides a comprehensive library of modular components commonly used in modern AI architectures, including:
- Attention Mechanisms: Multi-query attention, sigmoid attention, flash attention, and more
- Mixture of Experts (MoE): Efficient expert routing and gating mechanisms
- Neural Network Modules: Feedforward networks, activation functions, normalization layers
- Quantization: BitLinear, dynamic quantization, and other optimization techniques
- Architectures: Transformers, encoders, decoders, vision transformers, and complete model implementations
- Training Utilities: Optimization algorithms, logging, and performance monitoring
Each component is designed to be:
- Modular: Drop-in replacements that work seamlessly with PyTorch (see the sketch after this list)
- High-Performance: Optimized implementations with fused kernels where applicable
- Well-Tested: Comprehensive test coverage ensuring reliability
- Production-Ready: Used in hundreds of models across various domains
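To illustrate the modular, drop-in design, here is a minimal sketch (assuming the package is installed as shown below) that composes Zeta's FeedForward block, demonstrated later on this page, with ordinary PyTorch containers:

```python
import torch
from torch import nn
from zeta.nn import FeedForward  # Zeta block, also demonstrated later on this page

# A Zeta module behaves like any other nn.Module, so it composes directly
# with standard PyTorch containers.
model = nn.Sequential(
    nn.LayerNorm(256),
    FeedForward(256, 512, glu=True, post_act_ln=True, dropout=0.2),
)
x = torch.randn(8, 256)
print(model(x).shape)  # torch.Size([8, 512])
```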
Install the latest release from PyPI:

```bash
pip3 install -U zetascale
```

Multi-query attention reduces memory usage while maintaining model quality by sharing key and value projections across attention heads.
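The memory saving comes from projecting keys and values once and broadcasting them across all query heads. A rough conceptual sketch of the idea (not Zeta's internal implementation):

```python
import torch

# Multi-query attention, conceptually: per-head query projections,
# but a single shared key head and a single shared value head.
batch, seq, dim, heads = 2, 4, 512, 8
head_dim = dim // heads
x = torch.randn(batch, seq, dim)
w_q = torch.randn(dim, dim)        # heads * head_dim query projection
w_k = torch.randn(dim, head_dim)   # one shared key projection
w_v = torch.randn(dim, head_dim)   # one shared value projection

q = (x @ w_q).view(batch, seq, heads, head_dim).transpose(1, 2)  # (b, h, s, d)
k = (x @ w_k).unsqueeze(1)                                       # (b, 1, s, d), broadcast over heads
v = (x @ w_v).unsqueeze(1)
attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim**0.5, dim=-1)
out = (attn @ v).transpose(1, 2).reshape(batch, seq, dim)
print(out.shape)  # torch.Size([2, 4, 512])
```

Zeta's MultiQueryAttention module packages this pattern: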
```python
import torch
from zeta import MultiQueryAttention

# Initialize the model
model = MultiQueryAttention(
    dim=512,
    heads=8,
)

# Forward pass
text = torch.randn(2, 4, 512)
output, _, _ = model(text)
print(output.shape)  # torch.Size([2, 4, 512])
```

The SwiGLU activation function applies a gating mechanism to selectively pass information through the network.
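Conceptually, SwiGLU multiplies a SiLU-activated branch with a plain linear branch, so one branch gates the other (standard SwiGLU formulation; Zeta's SwiGLUStacked internals may differ in details):

```python
import torch
import torch.nn.functional as F

# SwiGLU(x) = SiLU(x @ W1) * (x @ W2): the activated branch gates the linear branch.
x = torch.randn(5, 10)
w1, w2 = torch.randn(10, 20), torch.randn(10, 20)
gated = F.silu(x @ w1) * (x @ w2)
print(gated.shape)  # torch.Size([5, 20])
```

Zeta provides this as a ready-made module: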
```python
import torch
from zeta.nn import SwiGLUStacked

x = torch.randn(5, 10)
swiglu = SwiGLUStacked(10, 20)
output = swiglu(x)
print(output.shape)  # torch.Size([5, 20])
```

Relative position bias quantizes the distance between positions into buckets and uses embeddings to provide position-aware attention biases.
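The core idea, sketched in plain PyTorch below, is that each (query, key) pair is assigned a bucket based on its relative distance, and each bucket owns a learned bias per attention head (this simplified version just clamps distances; Zeta's module uses a more elaborate bucketing scheme):

```python
import torch
from torch import nn

# Simplified relative position bias: clamp the signed distance into a bucket,
# then look up one learned bias value per head for each bucket.
num_buckets, heads, q_len, k_len = 32, 8, 10, 10
bias_emb = nn.Embedding(num_buckets, heads)
rel_pos = torch.arange(k_len)[None, :] - torch.arange(q_len)[:, None]        # (q_len, k_len)
buckets = rel_pos.clamp(-num_buckets // 2, num_buckets // 2 - 1) + num_buckets // 2
bias = bias_emb(buckets).permute(2, 0, 1)   # (heads, q_len, k_len), added to attention logits
print(bias.shape)  # torch.Size([8, 10, 10])
```

Zeta's RelativePositionBias module can be dropped into a custom attention layer: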
```python
import torch
from torch import nn
from zeta.nn import RelativePositionBias

# Initialize the module
rel_pos_bias = RelativePositionBias()

# Compute bias for attention mechanism
bias_matrix = rel_pos_bias(1, 10, 10)

# Use in custom attention
class CustomAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.rel_pos_bias = RelativePositionBias()

    def forward(self, queries, keys):
        bias = self.rel_pos_bias(queries.size(0), queries.size(1), keys.size(1))
        # Add the bias to your attention logits here; the attention computation
        # itself is omitted in this skeleton.
        return None
```

A flexible feedforward module with optional GLU activation and LayerNorm, commonly used in transformer architectures.
```python
import torch
from zeta.nn import FeedForward

model = FeedForward(256, 512, glu=True, post_act_ln=True, dropout=0.2)
x = torch.randn(1, 256)
output = model(x)
print(output.shape)  # torch.Size([1, 512])
```

BitLinear performs linear transformation with quantization and dequantization, reducing memory usage while maintaining performance. Based on BitNet: Scaling 1-bit Transformers for Large Language Models.
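The underlying idea from the BitNet paper is to keep only the sign of each centered weight plus a single per-tensor scale; a rough sketch of the concept (not BitLinear's actual implementation):

```python
import torch

# 1-bit weight quantization in the spirit of BitNet: signs plus one scale factor.
w = torch.randn(20, 10)
alpha = w.mean()
w_bin = torch.sign(w - alpha)           # weights become +1 / -1
scale = (w - alpha).abs().mean()        # per-tensor scale used at dequantization
w_dequant = w_bin * scale
print((w - w_dequant).abs().mean())     # average quantization error
```

Zeta's BitLinear is used like a regular nn.Linear: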
```python
import torch
from torch import nn
import zeta.quant as qt

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = qt.BitLinear(10, 20)

    def forward(self, x):
        return self.linear(x)

model = MyModel()
input = torch.randn(128, 10)
output = model(input)
print(output.size())  # torch.Size([128, 20])
```

A complete implementation of the PalmE multi-modal model architecture, combining a ViT image encoder with a transformer decoder for vision-language tasks.
```python
import torch
from zeta.structs import (
    AutoRegressiveWrapper,
    Decoder,
    Encoder,
    Transformer,
    ViTransformerWrapper,
)

class PalmE(torch.nn.Module):
    """
    PalmE is a transformer architecture that uses a ViT encoder and a transformer decoder.

    This implementation demonstrates how to combine Zeta's modular components to build
    a complete multi-modal model architecture.
    """

    def __init__(
        self,
        image_size=256,
        patch_size=32,
        encoder_dim=512,
        encoder_depth=6,
        encoder_heads=8,
        num_tokens=20000,
        max_seq_len=1024,
        decoder_dim=512,
        decoder_depth=6,
        decoder_heads=8,
        alibi_num_heads=4,
        attn_kv_heads=2,
        use_abs_pos_emb=False,
        cross_attend=True,
        alibi_pos_bias=True,
        rotary_xpos=True,
        attn_flash=True,
        qk_norm=True,
    ):
        super().__init__()

        # Vision encoder
        self.encoder = ViTransformerWrapper(
            image_size=image_size,
            patch_size=patch_size,
            attn_layers=Encoder(
                dim=encoder_dim,
                depth=encoder_depth,
                heads=encoder_heads,
            ),
        )

        # Language decoder
        self.decoder = Transformer(
            num_tokens=num_tokens,
            max_seq_len=max_seq_len,
            use_abs_pos_emb=use_abs_pos_emb,
            attn_layers=Decoder(
                dim=decoder_dim,
                depth=decoder_depth,
                heads=decoder_heads,
                cross_attend=cross_attend,
                alibi_pos_bias=alibi_pos_bias,
                alibi_num_heads=alibi_num_heads,
                rotary_xpos=rotary_xpos,
                attn_kv_heads=attn_kv_heads,
                attn_flash=attn_flash,
                qk_norm=qk_norm,
            ),
        )

        # Enable autoregressive generation
        self.decoder = AutoRegressiveWrapper(self.decoder)

    def forward(self, img: torch.Tensor, text: torch.Tensor):
        """Forward pass of the model."""
        encoded = self.encoder(img, return_embeddings=True)
        return self.decoder(text, context=encoded)

# Usage
img = torch.randn(1, 3, 256, 256)
text = torch.randint(0, 20000, (1, 1024))
model = PalmE()
output = model(img, text)
print(output.shape)
```

A complete U-Net implementation for image segmentation and generative tasks.
```python
import torch
from zeta.nn import Unet

model = Unet(n_channels=1, n_classes=2)
x = torch.randn(1, 1, 572, 572)
y = model(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
```

Convert images into patch embeddings suitable for transformer-based vision models.
```python
import torch
from zeta.nn import VisionEmbedding

vision_embedding = VisionEmbedding(
    img_size=224,
    patch_size=16,
    in_chans=3,
    embed_dim=768,
    contain_mask_token=True,
    prepend_cls_token=True,
)
input_image = torch.rand(1, 3, 224, 224)
output = vision_embedding(input_image)
print(output.shape)
```

Niva provides dynamic quantization for specific layer types, ideal for models with variable runtime activations.
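For context, the technique is the same idea as PyTorch's built-in dynamic quantization, where weights of selected layer types are stored in int8 and dequantized on the fly (a plain PyTorch sketch for comparison, independent of Niva's implementation):

```python
import torch
from torch import nn

# Plain PyTorch dynamic quantization of Linear layers, shown only to illustrate the concept.
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
print(quantized)
```

Niva wraps this workflow around loading pretrained weights and saving the quantized result: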
```python
import torch
from torch import nn
from zeta import niva

# Load a pre-trained model (placeholder for your own nn.Module class)
model = YourModelClass()

# Quantize the model dynamically
niva(
    model=model,
    model_path="path_to_pretrained_weights.pt",
    output_path="quantized_model.pt",
    quant_type="dynamic",
    quantize_layers=[nn.Linear, nn.Conv2d],
    dtype=torch.qint8,
)
```

Zeta includes several fused operations that combine multiple operations into single kernels for improved performance.
Fuses two dense operations with GELU activation for up to 2x speedup.
```python
import torch
from zeta.nn import FusedDenseGELUDense

x = torch.randn(1, 512)
model = FusedDenseGELUDense(512, 1024)
out = model(x)
print(out.shape)  # torch.Size([1, 1024])
```

Fuses dropout and layer normalization for faster feedforward networks.
```python
import torch
from zeta.nn import FusedDropoutLayerNorm

model = FusedDropoutLayerNorm(dim=512)
x = torch.randn(1, 512)
output = model(x)
print(output.shape)  # torch.Size([1, 512])
```

PyTorch implementation of the Mamba state space model architecture.
```python
import torch
from zeta.nn import MambaBlock

block = MambaBlock(dim=64, depth=1)
x = torch.randn(1, 10, 64)
y = block(x)
print(y.shape)  # torch.Size([1, 10, 64])
```

Feature-wise Linear Modulation for conditional feature transformation.
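FiLM conditions a network by predicting a per-feature scale and shift from a conditioning vector and applying them to hidden features, i.e. FiLM(h, c) = gamma(c) * h + beta(c). A minimal sketch of the idea (the projection layer below is hypothetical, not Zeta's internals):

```python
import torch
from torch import nn

# FiLM: a conditioning vector produces a per-feature scale (gamma) and shift (beta).
dim = 128
to_gamma_beta = nn.Linear(dim, 2 * dim)   # hypothetical conditioning projection
conditions = torch.randn(10, dim)
hiddens = torch.randn(10, 1, dim)
gamma, beta = to_gamma_beta(conditions).chunk(2, dim=-1)
out = gamma.unsqueeze(1) * hiddens + beta.unsqueeze(1)
print(out.shape)  # torch.Size([10, 1, 128])
```

Zeta's Film module provides this as a reusable layer: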
```python
import torch
from zeta.nn import Film

film_layer = Film(dim=128, hidden_dim=64, expanse_ratio=4)
conditions = torch.randn(10, 128)
hiddens = torch.randn(10, 1, 128)
modulated_features = film_layer(conditions, hiddens)
print(modulated_features.shape)  # torch.Size([10, 1, 128])
```

The `hyper_optimize` decorator provides a unified interface for multiple optimization techniques.
```python
import torch
from zeta.nn import hyper_optimize

@hyper_optimize(
    torch_fx=False,
    torch_script=False,
    torch_compile=True,
    quantize=True,
    mixed_precision=True,
    enable_metrics=True,
)
def model(x):
    return x @ x

out = model(torch.randn(1, 3, 32, 32))
print(out)
```

DPO (Direct Preference Optimization) implementation for reinforcement learning from human feedback (RLHF) applications.
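For intuition, the reference DPO objective from Rafailov et al. (2023) measures how much more the policy prefers the chosen sequence over the rejected one relative to a frozen reference model (shown here with made-up log-probabilities; how Zeta's DPO class obtains these quantities internally is not spelled out in this snippet):

```python
import torch
import torch.nn.functional as F

# Reference DPO loss: -log sigmoid(beta * ((logpi_w - logref_w) - (logpi_l - logref_l))).
beta = 0.1
policy_logp_chosen, policy_logp_rejected = torch.tensor(-12.0), torch.tensor(-15.0)
ref_logp_chosen, ref_logp_rejected = torch.tensor(-13.0), torch.tensor(-14.0)
loss = -F.logsigmoid(
    beta * ((policy_logp_chosen - ref_logp_chosen) - (policy_logp_rejected - ref_logp_rejected))
)
print(loss)
```

Zeta's DPO module wraps a policy model for this workflow: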
```python
import torch
from torch import nn
from zeta.rl import DPO

class PolicyModel(nn.Module):
    def __init__(self, dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(dim, output_dim)

    def forward(self, x):
        return self.fc(x)

dim = 10
output_dim = 5
policy_model = PolicyModel(dim, output_dim)
dpo_model = DPO(model=policy_model, beta=0.1)

preferred_seq = torch.randint(0, output_dim, (3, dim))
unpreferred_seq = torch.randint(0, output_dim, (3, dim))
loss = dpo_model(preferred_seq, unpreferred_seq)
print(loss)
```

A decorator for comprehensive model execution logging, including parameters, gradients, and memory usage.
```python
import torch
from torch import nn
from zeta.utils.verbose_execution import verbose_execution

@verbose_execution(log_params=True, log_gradients=True, log_memory=True)
class YourPyTorchModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(64 * 222 * 222, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

model = YourPyTorchModel()
input_tensor = torch.randn(1, 3, 224, 224)
output = model(input_tensor)

# Gradient information requires a backward pass
loss = output.sum()
loss.backward()
```

An attention mechanism that replaces softmax with sigmoid, providing up to 18% speedup while maintaining performance.
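The key difference is that sigmoid scores each query-key pair independently instead of normalizing scores across all keys, which removes the row-wise softmax reduction (a conceptual comparison in plain PyTorch, not Zeta's implementation):

```python
import torch

# Softmax attention weights sum to 1 across keys; sigmoid weights are independent per pair.
scores = torch.randn(2, 8, 128, 128)          # (batch, heads, query_len, key_len)
softmax_weights = scores.softmax(dim=-1)
sigmoid_weights = torch.sigmoid(scores)
print(softmax_weights.sum(-1)[0, 0, 0])       # 1.0
print(sigmoid_weights.sum(-1)[0, 0, 0])       # generally not 1.0
```

Zeta's SigmoidAttention is used like a standard attention layer: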
```python
import torch
from zeta import SigmoidAttention

batch_size = 32
seq_len = 128
dim = 512
heads = 8

x = torch.rand(batch_size, seq_len, dim)
mask = torch.ones(batch_size, seq_len, seq_len)

sigmoid_attn = SigmoidAttention(dim, heads, seq_len)
output = sigmoid_attn(x, mask)
print(output.shape)  # torch.Size([32, 128, 512])
```

Comprehensive documentation is available at zeta.apac.ai.
There are various examples that you can try out in the `examples/` folder.
To run the full test suite:
```bash
python3 -m pip install -e '.[testing]'  # Install extra dependencies for testing
python3 -m pytest tests/                # Run the entire test suite
```

For more details, refer to the CI workflow configuration.
Join our growing community for real-time support, ideas, and discussions on building better AI models.
| Platform | Link | Description |
|---|---|---|
| Docs | zeta.apac.ai | Official documentation |
| Discord | Join our Discord | Live chat & community |
| Twitter | @kyegomez | Follow for updates |
| LinkedIn | The Swarm Corporation | Connect professionally |
| YouTube | YouTube Channel | Watch our videos |
Zeta is an open-source project, and contributions are welcome! If you want to create new features, fix bugs, or improve the infrastructure, we'd love to have you contribute.
Getting Started:
- Pick any issue with the `good first issue` tag to get started
- Read our Contributing Guidelines
- Check out our contributing board for roadmap discussions
Report Issues:
Thank you to all of our contributors who have built this great framework 🙌
If you use Zeta in your research or projects, please cite it:
```bibtex
@misc{zetascale,
  title = {Zetascale Framework},
  author = {Kye Gomez},
  year = {2024},
  howpublished = {\url{https://github.com/kyegomez/zeta}},
}
```

Apache 2.0 License
