Building a High-Performance ML Inference Server in Go with ONNX Runtime
How I Built GoServe: A Go-Based ML Server That's 53x Faster Than FastAPI
When I set out to build GoServe, a lightweight ML model server in Go, I knew I wanted to integrate ONNX Runtime for real-world machine learning inference. What I didn't expect was the journey of troubleshooting CGO, library mismatches, and API quirks that would ultimately lead to a server that's 53 times faster than FastAPI.
This guide documents every step, every issue, and every solution I encountered while integrating ONNX Runtime with Go. If you're trying to do the same, this will save you hours of debugging.
Table of Contents
- Why Go + ONNX?
- Prerequisites
- Part 1: Setting Up the Development Environment
- Part 2: Installing ONNX Runtime Go Bindings
- Part 3: The CGO Requirement
- Part 4: Version Mismatch Hell
- Part 5: Building the ONNX Session Wrapper
- Part 6: API Changes Between Versions
- Part 7: Path Resolution Issues
- Part 8: Testing and Validation
- Part 9: Building the REST API
- Part 10: Benchmarking Results
- Lessons Learned
- Complete Code Examples
Why Go + ONNX?
The Problem: Python-based ML servers (Flask, FastAPI) are slow, memory-heavy, and have 2-3 second cold starts. This kills serverless deployments and drives up costs.
The Solution: Go provides:
- Sub-second cold starts
- Low memory footprint (50-100 MB vs 300+ MB for Python)
- Native concurrency
- Static binaries (no dependency hell)
ONNX Runtime is the bridge that lets Go serve ML models exported from any framework (PyTorch, TensorFlow, scikit-learn, XGBoost).
Prerequisites
- Go 1.25+ installed
- Windows 10/11 (guide uses Windows, but concepts apply to Linux/macOS)
- Basic knowledge of Go and REST APIs
- A trained ML model (I used credit card fraud detection with XGBoost)
Part 1: Setting Up the Development Environment
Project Structure
goserve-project/
├── cmd/
│ └── server/
│ └── main.go
├── internal/
│ ├── onnx/ # ONNX Runtime wrapper
│ ├── models/ # Model registry
│ └── server/ # HTTP server
├── models/
│ └── fraud_detector.onnx
├── lib/
│ ├── onnxruntime/ # ONNX Runtime DLL
│ └── mingw64/ # MinGW-w64 compiler
└── examples/
└── fraud-detection/
├── train_model.py
├── fastapi_server.py
└── benchmark.py
Training a Model (Python)
First, I created a fraud detection model using XGBoost and exported it to ONNX:
# examples/fraud-detection/train_model.py (simplified)
from xgboost import XGBClassifier
import onnxmltools
from onnxmltools.convert.common.data_types import FloatTensorType
# Train model
model = XGBClassifier(n_estimators=100, max_depth=6)
model.fit(X_train, y_train)
# Export to ONNX
initial_type = [('input', FloatTensorType([None, 30]))]
onnx_model = onnxmltools.convert_xgboost(model, initial_types=initial_type)
onnxmltools.utils.save_model(onnx_model, 'models/fraud_detector.onnx')
Result: 154 KB ONNX model with 99.87% accuracy on fraud detection.
Part 2: Installing ONNX Runtime Go Bindings
Step 1: Add the Dependency
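The bindings live at github.com/yalue/onnxruntime_go, so the first step is presumably a single command:
go get github.com/yalue/onnxruntime_go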
Easy, right? Not so fast.
Part 3: The CGO Requirement
Issue #1: Build Constraints Exclude All Go Files
When I first tried to build, go build rejected the onnxruntime_go package with "build constraints exclude all Go files" errors.
Problem: ONNX Runtime Go bindings use CGO (C-Go interop), which requires a C compiler. Without one installed, CGO is effectively disabled on Windows, so every Go file in the bindings is excluded from the build.
Solution: Install MinGW-w64
Option A: Chocolatey (requires admin)
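If you can run an elevated shell, the usual package is mingw (I couldn't run this without admin rights, so treat it as the expected command rather than one I verified):
choco install mingw -y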
Option B: Manual Download (what I did)
When Chocolatey failed due to permissions, I downloaded MinGW-w64 directly:
# Download from WinLibs
curl -L -o mingw-w64.zip "https://github.com/brechtsanders/winlibs_mingw/releases/download/15.2.0posix-13.0.0-msvcrt-r1/winlibs-x86_64-posix-seh-gcc-15.2.0-mingw-w64msvcrt-13.0.0-r1.zip"
# Extract
unzip mingw-w64.zip -d lib/
# Add to PATH (for current session)
export PATH="$(pwd)/lib/mingw64/bin:$PATH"
# Verify
gcc --version
The output showed the MinGW-w64 GCC banner, confirming the compiler was on the PATH.
Enable CGO and Test
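With gcc on the PATH, turn CGO on and rerun the small test program (the same test_onnx.go used later in this guide):
export CGO_ENABLED=1
go run test_onnx.go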
This time the build got further before failing with a new error. Progress! Go can now compile with CGO, but it still can't find the ONNX Runtime library.
Part 4: Version Mismatch Hell
Issue #2: Missing ONNX Runtime Shared Library
ONNX Runtime Go bindings are just wrappers around the native ONNX Runtime C library. We need to download it separately.
Step 1: Download ONNX Runtime
I downloaded version 1.20.1 (matching our Python environment):
curl -L -o onnxruntime-win-x64-1.20.1.zip \
https://github.com/microsoft/onnxruntime/releases/download/v1.20.1/onnxruntime-win-x64-1.20.1.zip
unzip onnxruntime-win-x64-1.20.1.zip -d lib/onnxruntime/
Contents:
lib/onnxruntime/onnxruntime-win-x64-1.20.1/
├── lib/
│ ├── onnxruntime.dll (11 MB)
│ └── onnxruntime_providers_shared.dll
└── include/
└── onnxruntime_c_api.h (223 KB)
Step 2: Set Library Path in Code
import (
"path/filepath"
onnxruntime "github.com/yalue/onnxruntime_go"
)
func init() {
dllPath, _ := filepath.Abs("lib/onnxruntime/onnxruntime-win-x64-1.20.1/lib/onnxruntime.dll")
onnxruntime.SetSharedLibraryPath(dllPath)
}
Issue #3: API Version Mismatch
$ CGO_ENABLED=1 go run test_onnx.go
The requested API version [23] is not available,
only API versions [1, 20] are supported in this build.
Problem: I installed github.com/yalue/onnxruntime_go@v1.25.0 (latest), but ONNX Runtime library was v1.20.1. The API versions didn't match.
Solution: Downgrade Go Bindings
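Pinning the bindings to the older release is a one-liner:
go get github.com/yalue/onnxruntime_go@v1.12.0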
Why v1.12.0? It's the closest version that matches ONNX Runtime 1.20.1's API version (20).
Success!
$ CGO_ENABLED=1 go run test_onnx.go
Setting ONNX Runtime library path: C:\Users\...\onnxruntime.dll
ONNX Runtime initialized successfully!
ONNX Runtime version: 1.20.1
Lesson: Always match your Go binding version with the ONNX Runtime library version. Check the API version compatibility in the library's release notes.
Part 5: Building the ONNX Session Wrapper
Now that ONNX Runtime was working, I built a clean Go wrapper around it.
Design Goals
- Thread-safe: Multiple goroutines can use the same session
- Clean API: Hide ONNX Runtime complexity
- Batch inference: Support multiple predictions at once
- Automatic cleanup: Proper resource management
Core Structure
// internal/onnx/session.go
package onnx
import (
"fmt"
"sync"
onnxruntime "github.com/yalue/onnxruntime_go"
)
// Guard so the ONNX Runtime environment is initialized exactly once, process-wide.
var (
	initOnce sync.Once
	initErr  error
)
type Session struct {
modelPath string
inputName string
outputName string
mu sync.RWMutex
}
func NewSession(modelPath string) (*Session, error) {
// Initialize ONNX Runtime environment (once globally)
initOnce.Do(func() {
initErr = onnxruntime.InitializeEnvironment()
})
if initErr != nil {
return nil, fmt.Errorf("failed to initialize: %w", initErr)
}
return &Session{
modelPath: modelPath,
inputName: "input",
outputName: "probabilities",
}, nil
}
Part 6: API Changes Between Versions
Issue #4: AdvancedSession API Changed
The ONNX Runtime Go v1.25.0 documentation showed one API (sessions that can be created once and reused across inferences), but v1.12.0 (what I needed) had a different signature.
v1.12.0 API (actual):
session, err := onnxruntime.NewAdvancedSession(
modelPath,
[]string{inputName}, // Input names
[]string{outputName}, // Output names
[]onnxruntime.Value{inputTensor}, // Input tensors
[]onnxruntime.Value{outputTensor}, // Output tensors
nil) // Session options
Key Difference: v1.12.0 requires pre-allocated input/output tensors passed during session creation, and the session is single-use (destroyed after each inference).
Working Implementation
func (s *Session) Run(input [][]float32) ([][]float32, error) {
batchSize := len(input)
featureCount := len(input[0])
// Flatten input to 1D array
flatInput := make([]float32, batchSize*featureCount)
for i, sample := range input {
copy(flatInput[i*featureCount:], sample)
}
// Create input tensor
inputShape := onnxruntime.NewShape(int64(batchSize), int64(featureCount))
inputTensor, err := onnxruntime.NewTensor(inputShape, flatInput)
if err != nil {
return nil, fmt.Errorf("failed to create input tensor: %w", err)
}
defer inputTensor.Destroy()
// Create output tensor (fraud detection: batchSize x 2)
outputShape := onnxruntime.NewShape(int64(batchSize), 2)
outputTensor, err := onnxruntime.NewEmptyTensor[float32](outputShape)
if err != nil {
return nil, fmt.Errorf("failed to create output tensor: %w", err)
}
defer outputTensor.Destroy()
// Create session (single-use in v1.12.0)
session, err := onnxruntime.NewAdvancedSession(
s.modelPath,
[]string{s.inputName},
[]string{s.outputName},
[]onnxruntime.Value{inputTensor},
[]onnxruntime.Value{outputTensor},
nil,
)
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}
defer session.Destroy()
// Run inference
if err := session.Run(); err != nil {
return nil, fmt.Errorf("inference failed: %w", err)
}
// Extract results
outputData := outputTensor.GetData()
// Reshape to [batchSize][2] for binary classification
result := make([][]float32, batchSize)
for i := 0; i < batchSize; i++ {
result[i] = make([]float32, 2)
result[i][0] = outputData[i*2] // legitimate probability
result[i][1] = outputData[i*2+1] // fraud probability
}
return result, nil
}
Performance Note: Creating a new session per inference isn't ideal, but v1.12.0's API requires it. In production, consider session pooling or upgrading to newer ONNX Runtime versions.
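If you later move to a bindings release whose sessions are reusable, a small fixed-size pool is one way to amortize session creation across requests. This is only a sketch against a hypothetical inferenceRunner interface, not the v1.12.0 API used above:
// A hypothetical pool of reusable sessions (not part of GoServe as described above).
type inferenceRunner interface {
	Run(input [][]float32) ([][]float32, error)
}

type sessionPool struct {
	sessions chan inferenceRunner
}

// newSessionPool pre-creates size sessions using the supplied constructor.
func newSessionPool(size int, newSession func() (inferenceRunner, error)) (*sessionPool, error) {
	p := &sessionPool{sessions: make(chan inferenceRunner, size)}
	for i := 0; i < size; i++ {
		s, err := newSession()
		if err != nil {
			return nil, err
		}
		p.sessions <- s
	}
	return p, nil
}

// Run borrows a session, runs inference, and returns the session to the pool.
func (p *sessionPool) Run(input [][]float32) ([][]float32, error) {
	s := <-p.sessions
	defer func() { p.sessions <- s }()
	return s.Run(input)
}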
Part 7: Path Resolution Issues
Issue #5: DLL Not Found During Tests
When running tests from internal/onnx/:
$ cd internal/onnx
$ go test -v
Failed to create session: failed to initialize ONNX Runtime:
Error loading ONNX shared library: The specified module could not be found.
Problem: The DLL path in init() was relative (lib/onnxruntime/...), which broke when running tests from subdirectories.
Solution: Multi-Path Fallback
func init() {
// Try multiple locations to find the DLL
possiblePaths := []string{
"lib/onnxruntime/onnxruntime-win-x64-1.20.1/lib/onnxruntime.dll",
"../../lib/onnxruntime/onnxruntime-win-x64-1.20.1/lib/onnxruntime.dll",
"../../../lib/onnxruntime/onnxruntime-win-x64-1.20.1/lib/onnxruntime.dll",
}
for _, p := range possiblePaths {
if absPath, err := filepath.Abs(p); err == nil {
if _, err := os.Stat(absPath); err == nil {
onnxruntime.SetSharedLibraryPath(absPath)
return
}
}
}
}
Better Solution: Use an environment variable:
func init() {
// Check environment variable first
if dllPath := os.Getenv("ONNX_RUNTIME_DLL"); dllPath != "" {
onnxruntime.SetSharedLibraryPath(dllPath)
return
}
// Fallback to relative paths
// ...
}
Set it in your shell:
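For example, pointing it at the DLL downloaded earlier (the same path the build script at the end of this post uses):
export ONNX_RUNTIME_DLL="$(pwd)/lib/onnxruntime/onnxruntime-win-x64-1.20.1/lib/onnxruntime.dll"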
Part 8: Testing and Validation
Test Cases
// internal/onnx/session_test.go
func TestSessionInference(t *testing.T) {
session, err := NewSession("../../models/fraud_detector.onnx")
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
defer session.Close()
// Normal transaction (30 features: Time, Amount, V1-V28)
normalTransaction := [][]float32{{
0.5, 125.50, // Time, Amount
-0.8, 1.2, -0.3, 0.5, 1.1, -0.2, 0.9, -1.2, 0.3, -0.7, // V1-V10
1.5, -0.4, 0.6, -1.1, 0.8, 0.2, -0.9, 1.3, -0.5, 0.4, // V11-V20
-1.0, 0.7, -0.3, 0.9, -0.6, 1.2, -0.8, 0.5, // V21-V28
}}
result, err := session.Run(normalTransaction)
if err != nil {
t.Fatalf("Inference failed: %v", err)
}
fraudProb := result[0][1]
t.Logf("Fraud probability: %.4f%%", fraudProb*100)
// Should be low for normal transaction
if fraudProb > 0.5 {
t.Errorf("Expected low fraud probability, got %.4f", fraudProb)
}
}
Running Tests
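From the package directory, with CGO enabled (mirroring the earlier build steps):
cd internal/onnx
CGO_ENABLED=1 go test -v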
Output:
=== RUN TestNewSession
session_test.go:23: Model inputs: [{Name:input Shape:[-1 30] Type:float32}]
session_test.go:31: Model outputs: [{Name:probabilities Shape:[-1 2] Type:float32}]
--- PASS: TestNewSession (0.03s)
=== RUN TestSessionInference
session_test.go:67: Fraud probability: 0.0031%
--- PASS: TestSessionInference (0.03s)
PASS
ok github.com/placeholder/goserve/internal/onnx 0.440s
Success! Predictions match the Python ONNX Runtime exactly.
Part 9: Building the REST API
Model Registry
// internal/models/registry.go
type Model struct {
Name string
Path string
Format string
Session *onnx.Session
InputInfo []onnx.TensorInfo
OutputInfo []onnx.TensorInfo
LoadedAt time.Time
}
type Registry struct {
mu sync.RWMutex
models map[string]*Model
}
func (r *Registry) LoadModel(name string, path string) error {
r.mu.Lock()
defer r.mu.Unlock()
session, err := onnx.NewSession(path)
if err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
r.models[name] = &Model{
Name: name,
Path: path,
Format: "onnx",
Session: session,
InputInfo: session.GetInputInfo(),
OutputInfo: session.GetOutputInfo(),
LoadedAt: time.Now(),
}
return nil
}
func (r *Registry) Infer(modelName string, input [][]float32) ([][]float32, error) {
model, err := r.GetModel(modelName)
if err != nil {
return nil, err
}
return model.Session.Run(input)
}
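Infer relies on a GetModel helper that isn't shown above; a minimal read-locked lookup would be something like:
// GetModel returns a loaded model by name (sketch; the original helper isn't shown).
func (r *Registry) GetModel(name string) (*Model, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	model, ok := r.models[name]
	if !ok {
		return nil, fmt.Errorf("model %q is not loaded", name)
	}
	return model, nil
}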
HTTP Handlers
// internal/server/model_handlers.go
func (s *Server) handleModelInfer(w http.ResponseWriter, r *http.Request) {
// Extract model name from URL: /v1/models/{model}/infer
	pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
	if len(pathParts) < 4 {
		http.Error(w, "invalid path, expected /v1/models/{model}/infer", http.StatusBadRequest)
		return
	}
	modelName := pathParts[2]
	// Parse the request body
	var req InferenceRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "invalid JSON body", http.StatusBadRequest)
		return
	}
// Validate input (fraud detection needs 30 features)
for i, features := range req.Inputs {
if len(features) != 30 {
http.Error(w,
fmt.Sprintf("Transaction %d has %d features, expected 30", i, len(features)),
http.StatusBadRequest)
return
}
}
// Run inference
probabilities, err := s.registry.Infer(modelName, req.Inputs)
if err != nil {
http.Error(w, fmt.Sprintf("Inference failed: %v", err),
http.StatusInternalServerError)
return
}
// Build response
predictions := make([]int, len(probabilities))
isFraud := make([]bool, len(probabilities))
confidence := make([]float32, len(probabilities))
for i, probs := range probabilities {
fraudProb := probs[1]
pred := 0
if fraudProb > 0.5 {
pred = 1
}
predictions[i] = pred
isFraud[i] = (pred == 1)
confidence[i] = fraudProb
}
response := InferenceResponse{
ModelName: modelName,
Predictions: predictions,
Probabilities: probabilities,
IsFraud: isFraud,
Confidence: confidence,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
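The handler uses InferenceRequest and InferenceResponse types that aren't shown; judging by the JSON in the curl examples below, they presumably look something like:
// internal/server/types.go (sketch; field names inferred from the JSON examples)
type InferenceRequest struct {
	Inputs [][]float32 `json:"inputs"`
}

type InferenceResponse struct {
	ModelName     string      `json:"model_name"`
	Predictions   []int       `json:"predictions"`
	Probabilities [][]float32 `json:"probabilities"`
	IsFraud       []bool      `json:"is_fraud"`
	Confidence    []float32   `json:"confidence"`
}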
Building and Running
# Build
export PATH="$(pwd)/lib/mingw64/bin:$PATH"
CGO_ENABLED=1 go build -o goserve.exe ./cmd/server
# Run
./goserve.exe
Output:
{"time":"2026-01-13T15:36:26Z","level":"INFO","msg":"Starting GoServe"}
{"time":"2026-01-13T15:36:26Z","level":"INFO","msg":"HTTP server listening","address":":8080"}
Testing the API
# Load model
curl -X POST http://localhost:8080/v1/models \
-H "Content-Type: application/json" \
-d '{"name":"fraud_detector","path":"models/fraud_detector.onnx"}'
# Response
{"status":"loaded","name":"fraud_detector","message":"Model loaded successfully"}
# Run inference
curl -X POST http://localhost:8080/v1/models/fraud_detector/infer \
-H "Content-Type: application/json" \
-d '{
"inputs": [[0.5, 125.50, -0.8, 1.2, -0.3, 0.5, 1.1, -0.2, 0.9, -1.2,
0.3, -0.7, 1.5, -0.4, 0.6, -1.1, 0.8, 0.2, -0.9, 1.3,
-0.5, 0.4, -1.0, 0.7, -0.3, 0.9, -0.6, 1.2, -0.8, 0.5]]
}'
# Response
{
"model_name":"fraud_detector",
"predictions":[0],
"probabilities":[[0.99996924,0.000030756]],
"is_fraud":[false],
"confidence":[0.000030756]
}
Perfect! 0.003% fraud probability for a normal transaction.
Part 10: Benchmarking Results
I compared GoServe against a FastAPI server using the identical ONNX model.
FastAPI Server (Baseline)
# examples/fraud-detection/fastapi_server.py
from fastapi import FastAPI
import onnxruntime as ort
import numpy as np
app = FastAPI()
session = ort.InferenceSession("models/fraud_detector.onnx")
@app.post("/v1/models/fraud_detector/infer")
async def predict(request: dict):
inputs = np.array(request["inputs"], dtype=np.float32)
outputs = session.run(None, {"input": inputs})
# ... (return probabilities)
Benchmark Script
# examples/fraud-detection/benchmark.py (simplified)
from concurrent.futures import ThreadPoolExecutor
import time

import requests
from numpy import median, percentile  # or equivalent helpers

# send_request() and sample_data are defined elsewhere in the full script
def benchmark_server(url, num_requests=100, workers=10):
    start = time.time()
    with ThreadPoolExecutor(max_workers=workers) as executor:
        futures = [
            executor.submit(send_request, url, sample_data)
            for _ in range(num_requests)
        ]
        latencies = [f.result() for f in futures]
    total_time = time.time() - start
    return {
        "throughput": num_requests / total_time,
        "p50_latency": median(latencies),
        "p95_latency": percentile(latencies, 95),
    }
Results
Running 100 requests with 10 concurrent workers:
============================================================
GoServe vs FastAPI - Performance Comparison
============================================================
Metric GoServe FastAPI Improvement
------------------------------------------------------------------------
Throughput (req/s) 259.95 4.91 52.96x faster
P50 Latency (ms) 31.45 2034.78 64.70x faster
P95 Latency (ms) 68.35 2052.92 30.04x faster
P99 Latency (ms) 87.42 2059.46 23.56x faster
Avg Latency (ms) 36.36 2034.75 55.96x faster
============================================================
Summary:
============================================================
✓ GoServe is 53.0x faster than FastAPI
✓ GoServe has 64.7x better latency
Real-world impact:
- Handle 5,196% more transactions per second
- Reduce response time by 98%
- Potentially save ~$11,773/year on serverless compute
Why Is GoServe So Much Faster?
- Native Compilation: Go compiles to machine code; Python is interpreted
- No GIL: Go has true parallelism; Python's GIL limits concurrency
- Lower Overhead: Go's HTTP server is more efficient than uvicorn/FastAPI
- Memory Management: Go's GC is faster than Python's for short-lived requests
- Single Binary: No import overhead, no module loading
Lessons Learned
1. CGO Is Required for ONNX Runtime Go
You can't avoid it. Install a C compiler (MinGW-w64 on Windows, gcc on Linux).
2. Version Matching Is Critical
Always match your Go binding version with the ONNX Runtime library version:
| Go Binding Version | ONNX Runtime Version | API Version |
|---|---|---|
| v1.25.0 | 1.25.0 | 23 |
| v1.12.0 | 1.20.1 | 20 |
| v1.10.0 | 1.18.0 | 18 |
Check the ONNX Runtime releases and Go bindings before installing.
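To see which binding version your module currently resolves to:
go list -m github.com/yalue/onnxruntime_go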
3. API Changes Between Versions
Don't trust documentation for older versions. The v1.12.0 API is very different from v1.25.0. Always check the actual source code or examples.
4. Path Resolution Matters
Use absolute paths or environment variables for the ONNX Runtime DLL. Relative paths break when running tests from subdirectories.
5. Session Lifecycle Varies
In v1.12.0, sessions are single-use. In newer versions, sessions can be reused. Plan your architecture accordingly.
6. Performance Is Worth It
The extra setup complexity pays off: 53x faster throughput and 65x better latency compared to Python.
Complete Code Examples
Minimal Working Example
package main
import (
"fmt"
"path/filepath"
onnxruntime "github.com/yalue/onnxruntime_go"
)
func main() {
// 1. Set DLL path
dllPath, _ := filepath.Abs("lib/onnxruntime/onnxruntime-win-x64-1.20.1/lib/onnxruntime.dll")
onnxruntime.SetSharedLibraryPath(dllPath)
// 2. Initialize environment
if err := onnxruntime.InitializeEnvironment(); err != nil {
panic(err)
}
defer onnxruntime.DestroyEnvironment()
// 3. Prepare input
inputData := []float32{0.5, 125.50, /* ... 28 more features ... */}
inputShape := onnxruntime.NewShape(1, 30)
inputTensor, _ := onnxruntime.NewTensor(inputShape, inputData)
defer inputTensor.Destroy()
// 4. Prepare output
outputShape := onnxruntime.NewShape(1, 2)
outputTensor, _ := onnxruntime.NewEmptyTensor[float32](outputShape)
defer outputTensor.Destroy()
// 5. Create session
session, _ := onnxruntime.NewAdvancedSession(
"models/fraud_detector.onnx",
[]string{"input"},
[]string{"probabilities"},
[]onnxruntime.Value{inputTensor},
[]onnxruntime.Value{outputTensor},
nil,
)
defer session.Destroy()
// 6. Run inference
session.Run()
// 7. Get results
outputData := outputTensor.GetData()
fmt.Printf("Fraud probability: %.4f%%\n", outputData[1]*100)
}
Build Script
Create a build.sh:
#!/bin/bash
# Set paths
export PATH="$(pwd)/lib/mingw64/bin:$PATH"
export ONNX_RUNTIME_DLL="$(pwd)/lib/onnxruntime/onnxruntime-win-x64-1.20.1/lib/onnxruntime.dll"
# Enable CGO
export CGO_ENABLED=1
# Build
echo "Building GoServe..."
go build -o goserve.exe ./cmd/server
echo "Done! Run with: ./goserve.exe"
Conclusion
Integrating ONNX Runtime with Go isn't trivial, but it's achievable with the right approach. The key challenges are:
- Setting up CGO with a C compiler
- Matching library and binding versions
- Understanding API differences between versions
- Managing DLL paths correctly
Once set up, the performance gains are massive: 53x faster than FastAPI for real-world ML inference.
When to Use Go + ONNX
✅ Use Go + ONNX when:
- You need high throughput (100s-1000s of requests/second)
- Low latency matters (<50ms)
- Running in serverless/Kubernetes (fast cold starts)
- Want to reduce infrastructure costs

❌ Stick with Python when:
- Rapid prototyping (Python is faster to develop)
- Complex preprocessing (NumPy/Pandas are easier)
- Model training (Go doesn't have ML frameworks)
- Team lacks Go expertise
Next Steps
- Production Hardening: Add metrics, tracing, and health checks
- GPU Support: Use CUDA execution provider for GPU acceleration
- Model Versioning: Support A/B testing with multiple model versions
- Caching: Add response caching for common inputs
- Batch Optimization: Implement dynamic batching for higher throughput
Resources
- ONNX Runtime Releases
- ONNX Runtime Go Bindings
- WinLibs MinGW-w64
- GoServe Repository (replace with your actual repo)
Questions? Issues? Drop a comment or open an issue on GitHub!
Found this helpful? Share it with others building high-performance ML infrastructure!
Built with Go 1.25, ONNX Runtime 1.20.1, and too much coffee. ☕