Zerfoo is a machine learning framework built from the ground up in Go. It is designed for performance, scalability, and developer experience, enabling everything from practical deep learning tasks to large-scale AGI experimentation.
By building on Go's strengths (simplicity, strong typing, and first-class concurrency primitives), Zerfoo provides a robust and maintainable foundation for building production-ready ML systems.
Status: Pre-release — actively in development.
Define, train, and run a simple model in just a few lines of idiomatic Go.
```go
package main

import (
	"fmt"

	"github.com/zerfoo/zerfoo/compute"
	"github.com/zerfoo/zerfoo/graph"
	"github.com/zerfoo/zerfoo/layers/activations"
	"github.com/zerfoo/zerfoo/layers/core"
	"github.com/zerfoo/zerfoo/tensor"
	"github.com/zerfoo/zerfoo/training"
)

func main() {
	// 1. Create a compute engine.
	engine := compute.NewCPUEngine()

	// 2. Define the model architecture using a graph builder.
	builder := graph.NewBuilder[float32](engine)
	input := builder.Input([]int{1, 10})
	dense1 := builder.AddNode(core.NewDense(10, 32), input)
	act1 := builder.AddNode(activations.NewReLU(), dense1)
	output := builder.AddNode(core.NewDense(32, 1), act1)

	// 3. Build the computational graph; the graph is validated here.
	forward, backward, err := builder.Build(output)
	if err != nil {
		panic(err)
	}

	// 4. Create an Adam optimizer with a learning rate of 0.01.
	optimizer := training.NewAdamOptimizer[float32](0.01)

	// 5. Allocate dummy input and target tensors.
	inputTensor, err := tensor.NewTensor(engine, []int{1, 10})
	if err != nil {
		panic(err)
	}
	targetTensor, err := tensor.NewTensor(engine, []int{1, 1})
	if err != nil {
		panic(err)
	}
	inputs := map[graph.NodeHandle]*tensor.Tensor[float32]{input: inputTensor}

	// 6. Run the training loop.
	for i := 0; i < 100; i++ {
		// Forward pass.
		predTensor := forward(inputs)

		// Squared-error loss L = diff^2, so the gradient dL/dpred is 2*diff.
		diff := predTensor.Data()[0] - targetTensor.Data()[0]
		grad := tensor.NewScalar(engine, 2*diff)

		// Backward pass to compute gradients.
		backward(grad, inputs)

		// Update weights.
		optimizer.Step(builder.Parameters())
	}

	fmt.Println("Training complete!")
}
```
Zerfoo is designed to address the limitations of existing ML frameworks by embracing Go's philosophy.
- ✅ Idiomatic and Simple: Build models using clean, readable Go. We favor composition over inheritance and explicit interfaces over magic.
- 🚀 High-Performance by Design: A static graph execution model, pluggable compute engines (CPU, GPU planned), and minimal Cgo overhead ensure your code runs fast.
- ⛓️ Robust and Type-Safe: Leverage Go's strong type system to catch type errors at compile time rather than at runtime. Shape mismatches and configuration issues are caught when the graph is built, before training begins.
- 🌐 Scalable from the Start: With first-class support for distributed training, Zerfoo is architected to scale from a single laptop to massive compute clusters.
- 🧩 Modular and Extensible: A clean, layered architecture allows you to extend any part of the framework—from custom layers to new hardware backends—by implementing well-defined interfaces.
- Declarative Graph Construction: Define models programmatically with the `Builder` API or declaratively using Go structs and tags.
- Static Execution Graph: The graph is built and validated once, so structural errors surface before any forward or backward pass runs, and the fixed structure enables significant performance optimizations.
- Pluggable Compute Engines: A hardware abstraction layer allows Zerfoo to target different backends. The default is a pure Go engine, with BLAS and GPU (CUDA) engines planned.
- Automatic Differentiation: Gradients are computed efficiently using reverse-mode AD (backpropagation).
- First-Class Distributed Training: A `DistributedStrategy` interface abstracts away the complexity of multi-node training, with support for patterns like All-Reduce and Parameter Server.
- Multi-Precision Support: Native support for `float32` and `float64`, with `float16` and `float8` for cutting-edge, low-precision training.
- ONNX Interoperability: Export models to the Open Neural Network Exchange (ONNX) format for deployment in any compatible environment.
- Data Package: Comprehensive data loading with native Parquet support for efficient dataset handling
- Feature Transformers: Built-in transformers for lagged features, rolling statistics, and FFT-based frequency domain features
- Normalization: Automatic feature normalization (z-score) for stable training
- Hierarchical Recurrent Modules (HRM): Dual-timescale recurrent architecture with H-Module (high-level) and L-Module (low-level) for complex temporal reasoning
- Spectral Fingerprint Layers: Advanced signal processing layers using Koopman operator theory and Fourier analysis for regime detection
- Feature-wise Linear Modulation (FiLM): Conditional normalization for dynamic model behavior adaptation
- Grouped-Query Attention (GQA): Memory-efficient attention mechanism with optional Rotary Positional Embeddings (RoPE)
- Advanced Metrics: Comprehensive evaluation metrics including Pearson/Spearman correlation, MSE, RMSE, and MAE
- Flexible Training Loops: Generic trainer with pluggable loss functions, optimizers, and gradient strategies
Zerfoo is built on a clean, layered architecture that separates concerns, ensuring the framework is both powerful and maintainable. A core tenet of this architecture is the use of the Zerfoo Model Format (ZMF) as the universal intermediate representation for models. This enables a strict decoupling from model converters like `zonnx`, ensuring that `zerfoo` focuses solely on efficient model execution without any ONNX-specific dependencies.
- Composition Layer: Define models as a Directed Acyclic Graph (DAG) of nodes (layers, activations). This layer is hardware-agnostic.
- Execution Layer: A pluggable `Engine` performs the actual tensor computations on specific hardware (CPU, GPU). This allows the same model to run on different devices without code changes.
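The split between composition and execution can be sketched with a deliberately tiny, hypothetical `Engine` interface that exposes only `Add`; Zerfoo's real engine interface is much larger and its method names may differ. Model code such as `Residual` depends only on the interface, so a sequential engine and a goroutine-based engine can be swapped without touching it.

```go
package main

import (
	"fmt"
	"sync"
)

// Number is the element constraint used by this sketch.
type Number interface{ ~float32 | ~float64 }

// Engine is a hypothetical, drastically simplified backend interface.
type Engine[T Number] interface {
	Name() string
	Add(a, b []T) ([]T, error)
}

// SerialEngine adds element by element on the calling goroutine.
type SerialEngine[T Number] struct{}

func (SerialEngine[T]) Name() string { return "serial" }

func (SerialEngine[T]) Add(a, b []T) ([]T, error) {
	if len(a) != len(b) {
		return nil, fmt.Errorf("length mismatch: %d vs %d", len(a), len(b))
	}
	out := make([]T, len(a))
	for i := range a {
		out[i] = a[i] + b[i]
	}
	return out, nil
}

// ParallelEngine splits the work into contiguous chunks, one goroutine each.
type ParallelEngine[T Number] struct{ Workers int }

func (ParallelEngine[T]) Name() string { return "parallel" }

func (p ParallelEngine[T]) Add(a, b []T) ([]T, error) {
	if len(a) != len(b) {
		return nil, fmt.Errorf("length mismatch: %d vs %d", len(a), len(b))
	}
	out := make([]T, len(a))
	workers := p.Workers
	if workers < 1 {
		workers = 1
	}
	chunk := (len(a) + workers - 1) / workers
	var wg sync.WaitGroup
	for lo := 0; lo < len(a); lo += chunk {
		hi := lo + chunk
		if hi > len(a) {
			hi = len(a)
		}
		wg.Add(1)
		go func(lo, hi int) {
			defer wg.Done()
			for i := lo; i < hi; i++ {
				out[i] = a[i] + b[i] // chunks are disjoint, so no locking is needed
			}
		}(lo, hi)
	}
	wg.Wait()
	return out, nil
}

// Residual is "model code": it sees only the Engine interface.
func Residual[T Number](e Engine[T], x, fx []T) ([]T, error) {
	return e.Add(x, fx)
}

func main() {
	x := []float32{1, 2, 3, 4, 5}
	fx := []float32{0.5, 0.5, 0.5, 0.5, 0.5}
	for _, e := range []Engine[float32]{SerialEngine[float32]{}, ParallelEngine[float32]{Workers: 2}} {
		out, err := Residual(e, x, fx)
		fmt.Println(e.Name(), out, err)
	}
}
```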
For a deep dive into the design philosophy, core interfaces, and technical roadmap, please read our Architectural Design Document.
This project is in a pre-release state. The API is not yet stable.
To install Zerfoo, use `go get`:

```sh
go get github.com/zerfoo/zerfoo
```
```go
import (
	"github.com/zerfoo/zerfoo/graph"
	"github.com/zerfoo/zerfoo/layers/activations"
	"github.com/zerfoo/zerfoo/layers/attention"
	"github.com/zerfoo/zerfoo/layers/core"
)

// Build a model with multi-head attention and a residual connection.
// engine, ops, and the dimension variables are assumed to be defined elsewhere.
builder := graph.NewBuilder[float32](engine)
input := builder.Input([]int{batchSize, seqLen, inputDim})

// Project the input into the hidden dimension.
linear := builder.AddNode(core.NewLinear("input_proj", engine, ops, inputDim, hiddenDim), input)

// Add multi-head attention (named attn so it does not shadow the attention package).
attn := builder.AddNode(attention.NewGlobalAttention(engine, ops, hiddenDim, numHeads, headSize), linear)

// Add a residual connection followed by a tanh activation.
residual := builder.AddNode(core.NewAdd(engine), linear, attn)
activated := builder.AddNode(activations.NewTanh(engine, ops), residual)

// Project to the output dimension.
output := builder.AddNode(core.NewLinear("output_proj", engine, ops, hiddenDim, outputDim), activated)
```
```go
import (
	"log"

	"github.com/zerfoo/zerfoo/graph"
	"github.com/zerfoo/zerfoo/layers/core"
	"github.com/zerfoo/zerfoo/layers/hrm"
)

// Build a hierarchical model: the spectral fingerprint feeds the H-Module,
// whose output conditions both the L-Module and the FiLM layer.
builder := graph.NewBuilder[float32](engine)
spectralInput := builder.Input([]int{windowSize})
featureInput := builder.Input([]int{numStocks, numFeatures})

// Add the spectral fingerprint layer.
spectral := builder.AddNode(core.NewSpectralFingerprint(windowSize, fingerprintDim), spectralInput)

// Add the hierarchical recurrent modules.
hModule := builder.AddNode(hrm.NewHModule(hDim), spectral)
lModule := builder.AddNode(hrm.NewLModule(lDim), featureInput, hModule)

// Add FiLM conditioning, then project to a single output.
film := builder.AddNode(core.NewFiLM(lDim), lModule, hModule)
output := builder.AddNode(core.NewDense(lDim, 1), film)

// model is the forward pass; the backward function is discarded here.
model, _, err := builder.Build(output)
if err != nil {
	log.Fatalf("building graph: %v", err)
}
```
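FiLM conditions one signal on another by computing a per-feature scale `gamma` and shift `beta` from the conditioning input, then applying `gamma*x + beta`. The sketch below assumes `gamma` and `beta` come from simple affine projections of the conditioning vector; the `FiLM` struct and its fields are hypothetical and are not Zerfoo's `core.NewFiLM`.

```go
package main

import "fmt"

// FiLM is a hypothetical feature-wise linear modulation layer. For feature i:
//
//	gamma_i = Bg[i] + sum_j Wg[i][j]*cond[j]
//	beta_i  = Bb[i] + sum_j Wb[i][j]*cond[j]
//	out_i   = gamma_i*x[i] + beta_i
type FiLM struct {
	Wg, Wb [][]float64 // [features][condDim] projection weights
	Bg, Bb []float64   // [features] biases
}

// Apply modulates x using scale and shift vectors computed from cond.
func (f FiLM) Apply(x, cond []float64) []float64 {
	out := make([]float64, len(x))
	for i := range x {
		gamma, beta := f.Bg[i], f.Bb[i]
		for j, c := range cond {
			gamma += f.Wg[i][j] * c
			beta += f.Wb[i][j] * c
		}
		out[i] = gamma*x[i] + beta
	}
	return out
}

func main() {
	f := FiLM{
		Wg: [][]float64{{1}, {0}},
		Wb: [][]float64{{0}, {1}},
		Bg: []float64{0, 1},
		Bb: []float64{0, 0},
	}
	fmt.Println(f.Apply([]float64{1, 2}, []float64{3})) // [3 5]
}
```

With zero projection weights, `Bg = 1`, and `Bb = 0`, the layer reduces to the identity, so conditioning only changes the features to the extent the projections learn to.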
```go
import (
	"log"

	"github.com/zerfoo/zerfoo/metrics"
	"github.com/zerfoo/zerfoo/training"
)

// Set up a trainer; g is the model graph (named so it does not shadow the graph package).
trainer := training.NewDefaultTrainer(g, lossNode, optimizer, strategy)

// Training loop with evaluation metrics.
for epoch := 0; epoch < epochs; epoch++ {
	loss, err := trainer.TrainBatch(batchData)
	if err != nil {
		log.Printf("epoch %d: training error: %v", epoch, err)
		continue
	}

	// predictions and targets are assumed to come from a forward pass over evaluation data.
	evalMetrics := metrics.CalculateMetrics(predictions, targets)
	log.Printf("Epoch %d: Loss=%.4f, Pearson=%.4f, MSE=%.4f",
		epoch, loss, evalMetrics.PearsonCorrelation, evalMetrics.MSE)
}
```
Zerfoo is an ambitious project, and we welcome contributions from the community! Whether you're an expert in machine learning, compilers, or distributed systems, or a Go developer passionate about AI, there are many ways to get involved.
Please read our Architectural Design Document to understand the project's vision and technical foundations.
Zerfoo is licensed under the Apache 2.0 License.