Technical Report

Inference Optimization for Large-Scale Model Architectures

JAX & Accelerated Serving

Performance cost-improvements for popular large-model architectures using JAX/XLA, quantisation and sharding.

Problem Statement

Large-scale transformer models — dense, MoE or sparsely-activated — face major serving challenges: high latency per token, large hardware/memory cost, limited throughput for many users. Our aim: serve these models at scale — many users, long contexts, variable prompts — while keeping latency low and cost per query reasonable.

timer

High Latency

Unoptimized models suffer from redundant compute and memory transfers.

attach_money

Hardware Cost

Serving 70B+ models on GPUs costs orders of magnitude more than necessary.

group

Limited Throughput

Naive serving can't sustain high concurrency without latency blowup.

Our Approach

Four technical pillars that combine to deliver production-ready JAX inference on GPU and TPU at dramatically lower cost.

bolt

JAX + XLA Kernel Fusion

Fuse kernels, reduce overhead, and optimise execution graphs for large-scale transformer inference. Compilation is backend-aware — TPU and GPU paths are optimized independently.

compress

Quantisation (int8/FP8)

Reduce KV-cache memory and bandwidth usage to achieve higher efficiency and longer context handling — with less than 1% precision loss on financial benchmarks.

device_hub

Multi-Device Sharding

Shard model parameters, activations, and KV caches across GPU/TPU pods using jax.pmap and shard_map for expert and tensor parallelism.

smart_toy

Agent-Based Optimisation

Autonomous agents automatically convert, analyze, and re-structure model execution code — bridging PyTorch/TensorFlow definitions to optimized JAX graph execution.

Why This Matters

speed

Lower latency

Smoother interactive experience for end users.

trending_up

Higher throughput

More users per dollar, less idle hardware.

memory

Lower cost

Proven hardware specific cost optimizations per model.

deployed_code

Production-ready

JAX stack deployable on GPUs and TPUs today.

Measured Efficiency Gains: TPU vs GPU

Real benchmark results from vLLM with 1,000 concurrent prompts and identical model checkpoints — no post-processing.

Llama-3 1B

TPU v5e-1 vs T4 GPU

Similar req/s · GPU cheaper at 1B scale
Request Throughput1.0× higher
TPU v5e-1
19.36 req/s
T4 GPU
18.82 req/s
Total Tokens/sec0.5× higher
TPU v5e-1
2.48k tok/s
T4 GPU
4.84k tok/s
Cost / 1M tokens (TPU v5e-1)
$0.11
vs T4 GPU
$0.04

Llama-3 3B

TPU v5e-1 vs T4 GPU

GPU cost-efficient at 3B scale
Request Throughput1.1× higher
TPU v5e-1
6.70 req/s
T4 GPU
6.13 req/s
Total Tokens/sec0.5× higher
TPU v5e-1
857.34 tok/s
T4 GPU
1.68k tok/s
Cost / 1M tokens (TPU v5e-1)
$0.93
vs T4 GPU
$0.27

Llama-3.1-8B

TPU v6e-1 vs A100

2.1× lower TPOT · 2× cheaper
Request Throughput1.6× higher
TPU v6e-1
13.52 req/s
A100
8.38 req/s
Total Tokens/sec1.6× higher
TPU v6e-1
15.57k tok/s
A100
9.64k tok/s
Mean TTFT1.7× lower
TPU v6e-1
34.8s ms
A100
57.7s ms
Mean TPOT2.1× lower
TPU v6e-1
47.30 ms
A100
100.39 ms
Cost / 1M tokens (TPU v6e-1)
$0.06
vs A100
$0.13

Llama-3.3-70B

TPU v6e-8 vs 2×H200

2.6× lower TPOT · 4× cheaper
Request Throughput2.0× higher
TPU v6e-8
11.09 req/s
2×H200
5.56 req/s
Total Tokens/sec2.0× higher
TPU v6e-8
12.77k tok/s
2×H200
6.39k tok/s
Mean TTFT2.0× lower
TPU v6e-8
42.1s ms
2×H200
83.8s ms
Mean TPOT2.6× lower
TPU v6e-8
141.51 ms
2×H200
374.66 ms
Cost / 1M tokens (TPU v6e-8)
$0.22
vs 2×H200
$0.83

Conclusion

By leveraging AI agents purpose-built for code translation and optimization, we unlock the full performance potential of JAX on TPU — automatically. Through intelligent kernel fusion (JAX/XLA), advanced quantization, and dynamic sharding of models and KV caches, these agents eliminate manual bottlenecks and continuously optimize execution paths in real time.

3–4×
Lower latency on average
through agent-driven kernel optimization
2.5–3×
Higher throughput per TPU
via adaptive execution strategies
60–75%
Lower cost per query
through automated optimization and resource utilization

Ready to Cut Your Inference Costs?

See how Siaivo automatically migrates your LLM serving stack to JAX on TPU — no manual porting required.

Request a Demo