GoMLX: ML in Go without Python
How ML models are implemented
- Typically written in Python, using frameworks like TensorFlow, JAX or PyTorch that take care of:
  - An expressive way to describe the model architecture, including auto-differentiation for training.
  - Efficient implementation of computational primitives on common HW: CPUs, GPUs and TPUs.
- The frameworks provide high-level primitives to define ML models and translate them to a common interchange format called StableHLO (HLO stands for High-Level Operations).
- The OpenXLA system, which includes two major components: the XLA compiler, which translates HLO to HW-specific machine code, and PJRT, the runtime component responsible for managing HW devices, moving data (tensors) between the host CPU and these devices, executing tasks, sharding and so on.
- HW that executes these models efficiently. All the C/C++ complexity is hidden inside these layers.
GoMLX
- Wraps XLA, giving Go access to the same building blocks TF and JAX use
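To give a feel for the API, here is a minimal sketch of building and running a computation graph with GoMLX. The import paths and helper names follow GoMLX's own examples and may shift between versions; treat this as an illustration, not a definitive reference.

package main

import (
	"fmt"

	"github.com/gomlx/gomlx/backends"
	_ "github.com/gomlx/gomlx/backends/xla" // register the XLA/PJRT backend
	"github.com/gomlx/gomlx/graph"
)

func main() {
	backend := backends.New()
	// JIT-compile a tiny computation graph: f(x) = x*x + 1.
	exec := graph.NewExec(backend, func(x *graph.Node) *graph.Node {
		return graph.AddScalar(graph.Mul(x, x), 1)
	})
	// The graph is compiled once for this input shape/dtype, then reused
	// on subsequent calls.
	y := exec.Call(float32(3.0))[0]
	fmt.Println(y.Value()) // -> 10
}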
Examples
- A CNN (convolutional neural network) without any Python, trained on CIFAR-10.
  As expected, the Go code is longer and more explicit than the Python equivalent.
// define the model graph
func C10ConvModel(mlxctx *mlxcontext.Context, spec any, inputs []*graph.Node) []*graph.Node {
	batchedImages := inputs[0]
	g := batchedImages.Graph()
	dtype := batchedImages.DType()
	batchSize := batchedImages.Shape().Dimensions[0]
	logits := batchedImages

	layerIdx := 0
	nextCtx := func(name string) *mlxcontext.Context {
		newCtx := mlxctx.Inf("%03d_%s", layerIdx, name)
		layerIdx++
		return newCtx
	}

	// Convolution / activation layers
	logits = layers.Convolution(nextCtx("conv"), logits).Filters(32).KernelSize(3).PadSame().Done()
	logits.AssertDims(batchSize, 32, 32, 32)
	logits = activations.Relu(logits)
	logits = layers.Convolution(nextCtx("conv"), logits).Filters(32).KernelSize(3).PadSame().Done()
	logits = activations.Relu(logits)
	logits = graph.MaxPool(logits).Window(2).Done()
	logits = layers.DropoutNormalize(nextCtx("dropout"), logits, graph.Scalar(g, dtype, 0.3), true)
	logits.AssertDims(batchSize, 16, 16, 32)

	logits = layers.Convolution(nextCtx("conv"), logits).Filters(64).KernelSize(3).PadSame().Done()
	logits.AssertDims(batchSize, 16, 16, 64)
	logits = activations.Relu(logits)
	logits = layers.Convolution(nextCtx("conv"), logits).Filters(64).KernelSize(3).PadSame().Done()
	logits.AssertDims(batchSize, 16, 16, 64)
	logits = activations.Relu(logits)
	logits = graph.MaxPool(logits).Window(2).Done()
	logits = layers.DropoutNormalize(nextCtx("dropout"), logits, graph.Scalar(g, dtype, 0.5), true)
	logits.AssertDims(batchSize, 8, 8, 64)

	logits = layers.Convolution(nextCtx("conv"), logits).Filters(128).KernelSize(3).PadSame().Done()
	logits.AssertDims(batchSize, 8, 8, 128)
	logits = activations.Relu(logits)
	logits = layers.Convolution(nextCtx("conv"), logits).Filters(128).KernelSize(3).PadSame().Done()
	logits.AssertDims(batchSize, 8, 8, 128)
	logits = activations.Relu(logits)
	logits = graph.MaxPool(logits).Window(2).Done()
	logits = layers.DropoutNormalize(nextCtx("dropout"), logits, graph.Scalar(g, dtype, 0.5), true)
	logits.AssertDims(batchSize, 4, 4, 128)

	// Flatten logits, and apply dense layer
	logits = graph.Reshape(logits, batchSize, -1)
	logits = layers.Dense(nextCtx("dense"), logits, true, 128)
	logits = activations.Relu(logits)
	logits = layers.DropoutNormalize(nextCtx("dropout"), logits, graph.Scalar(g, dtype, 0.5), true)
	numClasses := 10
	logits = layers.Dense(nextCtx("dense"), logits, true, numClasses)
	return []*graph.Node{logits}
}
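The listing above only defines the model graph; the actual CIFAR-10 training is driven by GoMLX's train package. A rough, unverified sketch of such a loop, with names taken from GoMLX's example code (trainDS is a stand-in for a CIFAR-10 dataset; NewTrainer's exact signature may differ across versions):

// Hedged sketch of a GoMLX training loop for the model above.
trainer := train.NewTrainer(backend, mlxctx, cnnmodel.C10ConvModel,
	losses.SparseCategoricalCrossEntropyLogits, // loss for integer class labels
	optimizers.Adam().Done(),                   // optimizer
	nil, nil)                                   // extra train/eval metrics (none here)
loop := train.NewLoop(trainer)
// trainDS would be a train.Dataset yielding batches of CIFAR-10 images and labels.
if _, err := loop.RunSteps(trainDS, 3000); err != nil {
	panic(err)
}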
// the classifier
func main() {
	flagCheckpoint := flag.String("checkpoint", "", "Directory to load checkpoint from")
	flag.Parse()

	mlxctx := mlxcontext.New()
	backend := backends.New()

	_, err := checkpoints.Load(mlxctx).Dir(*flagCheckpoint).Done()
	if err != nil {
		panic(err)
	}
	mlxctx = mlxctx.Reuse() // helps sanity check the loaded context
	exec := mlxcontext.NewExec(backend, mlxctx.In("model"), func(mlxctx *mlxcontext.Context, image *graph.Node) *graph.Node {
		// Convert our image to a tensor with batch dimension of size 1, and pass
		// it to the C10ConvModel graph.
		image = graph.ExpandAxes(image, 0) // Create a batch dimension of size 1.
		logits := cnnmodel.C10ConvModel(mlxctx, nil, []*graph.Node{image})[0]
		// Take the class with the highest logit value, then remove the batch dimension.
		choice := graph.ArgMax(logits, -1, dtypes.Int32)
		return graph.Reshape(choice)
	})

	// classify takes a 32x32 image and returns a CIFAR-10 classification according
	// to the model. Use C10Labels to convert the returned class to a string
	// name. The returned class is from 0 to 9.
	classify := func(img image.Image) int32 {
		input := images.ToTensor(dtypes.Float32).Single(img)
		outputs := exec.Call(input)
		classID := tensors.ToScalar[int32](outputs[0])
		return classID
	}
	// ...
}
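A hypothetical call site for classify, decoding a 32x32 PNG from disk. The path is made up, and I'm assuming C10Labels (referenced in the comment above) lives in the cnnmodel package:

// Continuing inside main: classify a PNG and print its predicted label.
// (Requires "image/png" and "os" in the imports.)
f, err := os.Open("testdata/plane.png") // made-up example path
if err != nil {
	panic(err)
}
defer f.Close()
img, err := png.Decode(f)
if err != nil {
	panic(err)
}
fmt.Println(cnnmodel.C10Labels[classify(img)])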
- A Gemma2 inference example, with weights downloaded from Kaggle
var (
	flagDataDir   = flag.String("data", "", "dir with converted weights")
	flagVocabFile = flag.String("vocab", "", "tokenizer vocabulary file")
)

func main() {
	flag.Parse()
	ctx := context.New()

	// Load model weights from the checkpoint downloaded from Kaggle.
	err := kaggle.ReadConvertedWeights(ctx, *flagDataDir)
	if err != nil {
		log.Fatal(err)
	}

	// Load tokenizer vocabulary.
	vocab, err := sentencepiece.NewFromPath(*flagVocabFile)
	if err != nil {
		log.Fatal(err)
	}

	// Create a Gemma sampler and start sampling tokens.
	sampler, err := samplers.New(backends.New(), ctx, vocab, 256)
	if err != nil {
		log.Fatalf("%+v", err)
	}

	start := time.Now()
	output, err := sampler.Sample([]string{
		"Are bees and wasps similar?",
	})
	if err != nil {
		log.Fatalf("%+v", err)
	}
	fmt.Printf("\tElapsed time: %s\n", time.Since(start))
	fmt.Printf("Generated text:\n%s\n", strings.Join(output, "\n\n"))
}
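A plausible invocation with placeholder paths, using the flags defined above (converting the Kaggle checkpoint to GoMLX's format is a separate step described in the post):

$ go run . -data /path/to/converted-gemma-weights -vocab /path/to/tokenizer.model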
Conclusion
- GoMLX makes it possible to implement ML training and inference in Go, without Python.
- Since it's a relatively new project, it may be a little risky for production use for now.
https://eli.thegreenplace.net/2024/gomlx-ml-in-go-without-python