We run a streaming ASR server that handles dozens of concurrent WebSocket connections on a single GPU. Audio flows in at 80ms chunks, passes through VAD, a Conformer encoder, and an RNN-T beam search decoder, all in one async inference loop.
We were chasing latency. Every millisecond in this loop matters. It's the difference between a voice agent that feels instant and one that feels broken. So we profiled everything, phase by phase, to find where the time was actually going.
What we found led us to write custom Triton kernels in two places and to leave the rest alone. Along the way, torch.compile made things worse.
The Pipeline
Our inference loop processes all concurrent streams in a single while loop:
```
Audio (WebSocket) → VAD → Encoder → Beam Search → Text (WebSocket)
                     ↑                     ↑
                Silero LSTM       Joint + Softmax + TopK
            (speech boundaries)    (candidate selection)
```

Each cycle:
- VAD phase: Decode PCM audio, run batched Silero VAD, inject silence at speech boundaries
- Encoder phase: Batch-encode audio frames through Conformer
- Decoder phase: Run beam search steps (joint network, softmax, top-K, beam update)
We profiled each phase to find where custom kernels would pay off.
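The loop structure can be sketched as below. This is a toy illustration, not the production code: the phase helpers and their string "work" are hypothetical stand-ins for the real batched VAD, Conformer forward, and beam search stages.

```python
import asyncio

# Illustrative stubs: the real phases run Silero VAD, the Conformer encoder,
# and RNN-T beam search on the GPU. These names and bodies are hypothetical.
def vad_phase(chunks):      return [c for c in chunks if c]      # drop "silence"
def encoder_phase(windows): return [w.upper() for w in windows]  # "encode"
def decoder_phase(encoded): return ["".join(encoded)]            # "decode"

async def inference_loop(pending, cycles):
    """One async loop drives every stream; each cycle batches all pending work."""
    results = []
    for _ in range(cycles):
        chunks = pending.pop(0) if pending else []  # drain queued audio
        if chunks:
            windows = vad_phase(chunks)
            encoded = encoder_phase(windows)
            results.extend(decoder_phase(encoded))
        await asyncio.sleep(0)  # yield so WebSocket handlers can enqueue audio
    return results

out = asyncio.run(inference_loop([["hi", "", "there"]], cycles=2))
# out == ["HITHERE"]
```

The point of the single loop is that every phase sees all active streams at once, so each phase can run as one batched GPU call instead of N per-stream calls.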
Fusing the Batched VAD Pipeline
Every inference cycle, we process audio from all connected clients. The original code did this in scattered steps: numpy PCM decode on CPU, copy to GPU for VAD, back to Python for the speech state machine, then write results to the pipeline. Data bounced between CPU and GPU, between numpy and torch, between Python loops and CUDA kernels.
We fused this into a single GPU-resident path: raw int16 bytes go up to the GPU once, PCM decode + windowing happens in a Triton kernel, Silero VAD runs on the already-resident data, and the post-VAD state machine runs as vectorized tensor ops. No round-trips.
The original code, per-stream numpy on CPU, N separate GPU transfers:

```python
# Per-stream: numpy decode on CPU, then transfer each to GPU separately
for chunk_count, audio_bytes in chunks:
    pcm = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
    pcm = np.concatenate([pcm_buf, pcm])
    # ... slice into 512-sample windows on CPU ...

# Later: stack and transfer to GPU for VAD
batch_tensor = torch.from_numpy(np.stack(round_windows)).to(self.device)
```

The fused path, one transfer, everything stays on GPU:
```python
# Fused: concat bytes, single transfer, decode + window on GPU
all_bytes = b"".join(audio_bytes for _, audio_bytes in all_chunks)
pcm_int16 = torch.frombuffer(all_bytes, dtype=torch.int16).to(device)

# Triton kernel: int16→float32 + windowing in one pass
pcm_windows = fused_pcm_decode_and_window(pcm_int16, stream_offsets, window_size=512)

# Silero runs directly on GPU-resident windows, no transfer needed
probs = silero_model(pcm_windows, sample_rate)

# Vectorized state machine, parallel across all streams
speech_end_mask = vad_state_step(probs, triggered, temp_end, current_sample)
```
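For clarity on what the decode-and-window kernel computes, here is a minimal CPU reference of its semantics, assuming standard 16-bit PCM scaling. This is a sketch: the real kernel also handles per-stream offsets and carried-over samples, which are omitted here.

```python
import numpy as np

def pcm_decode_and_window_reference(raw: bytes, window_size: int = 512) -> np.ndarray:
    """CPU reference for the fused kernel's math: int16 PCM -> float32 in
    [-1, 1), sliced into fixed-size windows (trailing partial window dropped)."""
    pcm = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
    n_windows = len(pcm) // window_size
    return pcm[: n_windows * window_size].reshape(n_windows, window_size)

# Two windows of 4 samples: full-scale negative, zero, half-scale, etc.
raw = np.array([-32768, 0, 16384, -16384, 1, 2, 3, 4], dtype=np.int16).tobytes()
wins = pcm_decode_and_window_reference(raw, window_size=4)
# wins.shape == (2, 4); wins[0, 0] == -1.0; wins[0, 2] == 0.5
```

A reference like this is also what we validate the Triton kernel against in tests.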
Results, PCM decode + GPU transfer (A100):
| Streams | Per-stream numpy→GPU | Fused single transfer | Speedup |
|---|---|---|---|
| 10 | 0.345ms | 0.083ms | 4.2x |
| 50 | 1.698ms | 0.114ms | 14.9x |
| 100 | 3.038ms | 0.158ms | 19.2x |
| 1000 | 32.178ms | 0.787ms | 40.9x |
The key insight: the win isn't from faster math. It's from eliminating N separate CPU→GPU transfers in favor of one. At 1000 streams, that's over 31ms of overhead eliminated, and the core of the fix is essentially one line: concatenate the byte buffers and transfer once.
Fused Log-Softmax + Top-K for Beam Search
In each beam search step, the joint network produces logits over the full vocabulary. We then need to apply log-softmax, extract the non-blank tokens, take the top-K, and append the blank score. The standard PyTorch implementation uses 6 separate CUDA kernels, each reading and writing the full [N_streams, vocab_size] tensor in global memory.
This runs up to 20 times per inference cycle (one per beam step).
The current code, 6 separate CUDA kernels:

```python
log_probs = torch.log_softmax(ytu[:, 0, 0, :] / temperature, dim=-1)  # kernels 1-2
non_blank = log_probs[:, ids_t]                                       # kernel 3
topk_vals, topk_ids = non_blank.topk(beam_k, dim=-1)                  # kernel 4
blank_probs = log_probs[:, blank_idx:blank_idx+1]                     # kernel 5
topk_vals = torch.cat([topk_vals, blank_probs], dim=-1)               # kernel 6
```
The fused Triton kernel, one program per stream, one read, one write:
```python
@triton.jit
def _fused_log_softmax_topk_kernel(
    logits_ptr, topk_vals_ptr, topk_ids_ptr,
    V: tl.constexpr, K: tl.constexpr,
    blank_idx: tl.constexpr, inv_temperature: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_V)
    mask = offs < V

    # Single load from global memory
    logits = tl.load(logits_ptr + row * V + offs, mask=mask, other=float('-inf'))
    logits = logits * inv_temperature

    # Log-softmax in registers (no intermediate writes)
    max_val = tl.max(logits, axis=0)
    shifted = logits - max_val
    log_sum_exp = tl.log(tl.sum(tl.exp(shifted), axis=0))
    log_probs = shifted - log_sum_exp

    # Extract blank, mask it from candidates
    blank_logit = tl.load(logits_ptr + row * V + blank_idx)
    blank_log_prob = blank_logit * inv_temperature - max_val - log_sum_exp
    candidates = tl.where(offs == blank_idx, float('-inf'), log_probs)

    # Iterative top-K (K is small, 2-8, fully unrolled)
    out_base = row * (K + 1)
    for k in tl.static_range(K):
        best_val = tl.max(candidates, axis=0)
        is_best = (candidates == best_val) & mask
        best_idx = tl.min(tl.where(is_best, offs, V), axis=0)
        tl.store(topk_vals_ptr + out_base + k, best_val)
        tl.store(topk_ids_ptr + out_base + k, best_idx)
        candidates = tl.where(offs == best_idx, float('-inf'), candidates)

    # Append blank
    tl.store(topk_vals_ptr + out_base + K, blank_log_prob)
    tl.store(topk_ids_ptr + out_base + K, blank_idx)
```

The entire vocabulary row is loaded once into registers. Temperature scaling, log-softmax, and top-K selection all happen in-register. No intermediate tensors, no global memory writes until the final K+1 results.
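A CPU reference of the same semantics is useful for checking the kernel's output. The sketch below assumes a single [N, V] logits matrix and mirrors the kernel's output layout (K non-blank candidates, then blank); tie-breaking between equal logits may differ from the kernel's lowest-index rule.

```python
import numpy as np

def fused_log_softmax_topk_reference(logits, k, blank_idx, temperature=1.0):
    """Reference semantics: temperature-scaled log-softmax per row,
    top-K over non-blank tokens, blank appended as the (K+1)-th entry."""
    x = logits / temperature
    x = x - x.max(axis=-1, keepdims=True)
    log_probs = x - np.log(np.exp(x).sum(axis=-1, keepdims=True))
    candidates = log_probs.copy()
    candidates[:, blank_idx] = -np.inf                 # mask blank from top-K
    order = np.argsort(-candidates, axis=-1)[:, :k]
    vals = np.take_along_axis(candidates, order, axis=-1)
    vals = np.concatenate([vals, log_probs[:, blank_idx:blank_idx + 1]], axis=-1)
    ids = np.concatenate([order, np.full((len(logits), 1), blank_idx)], axis=-1)
    return vals, ids

logits = np.array([[1.0, 3.0, 2.0, 0.0]])
vals, ids = fused_log_softmax_topk_reference(logits, k=2, blank_idx=0)
# ids[0] == [1, 2, 0]: the two best non-blank tokens, then blank
```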
Results (vocab_size=1025, beam_k=4, A100):
| Streams | Naive (loop) | Batched PyTorch | torch.compile | Triton Fused |
|---|---|---|---|---|
| 1 | 0.241ms | 0.172ms | 0.271ms | 0.078ms |
| 8 | 1.344ms | 0.168ms | 0.288ms | 0.077ms |
| 32 | 5.091ms | 0.169ms | 0.396ms | 0.098ms |
| 64 | 10.713ms | 0.196ms | 0.294ms | 0.078ms |
| 128 | 20.939ms | 0.229ms | 0.322ms | 0.087ms |
2.2x faster than batched PyTorch. 3.7x faster than torch.compile. 241x faster than naive.
The Triton kernel stays nearly flat (~0.08ms) from 1 to 128 streams because each row is an independent program instance. The GPU processes them all in parallel.
The torch.compile Surprise
We expected torch.compile(mode="reduce-overhead") to help. It's designed to fuse operations and reduce kernel launch overhead. Instead, it was consistently 1.5-2x slower than eager PyTorch.
Why? It uses CUDA graphs under the hood, which record a fixed execution plan. But our batch size (N = number of active streams) changes every cycle. Each new N triggers a re-recording:
```
[__cudagraphs] CUDAGraph supports dynamic shapes by recording a new graph for
each distinct input size. Recording too many CUDAGraphs may lead to extra
overhead. We have observed 9 distinct sizes.
```
torch.compile isn't free. If your batch dimension is dynamic, which it always is in streaming inference, the CUDA graph overhead can negate the fusion benefits.
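One common workaround (our suggestion, not something this pipeline currently does) is to pad the dynamic batch up to a small, fixed set of bucket sizes, so a shape-specialized backend only ever sees a handful of distinct shapes:

```python
def bucket_batch_size(n: int, buckets=(1, 2, 4, 8, 16, 32, 64, 128)) -> int:
    """Round the active-stream count up to a fixed bucket so CUDA graphs /
    torch.compile record each shape only once. Padding rows waste some
    compute, but re-recording stops after every bucket has been seen."""
    for b in buckets:
        if n <= b:
            return b
    return n  # beyond the largest bucket, fall back to the exact size

# 9 active streams -> pad the batch to 16; only 8 shapes ever reach the graph
padded = bucket_batch_size(9)
# padded == 16
```

Whether the wasted compute on padding rows beats the re-recording overhead depends on how often N changes; for us, the simpler answer was to skip torch.compile for this path entirely.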
What We Didn't Touch
The encoder (Conformer attention + convolution + FFN): This is a full transformer. Writing it in Triton means reimplementing multi-head attention, convolution modules, and feed-forward layers. Use TensorRT or torch.compile with static shapes instead.
The decoder LSTM (prediction network): Sequential by nature. Each token depends on the previous hidden state. Triton excels at data-parallel work, not sequential recurrence.
These components dominate cycle time (10-50ms), but the right tools for them are TensorRT export and batched inference, not hand-written kernels.
Takeaways
- Eliminate data movement first. The biggest win (40x) came from batching CPU→GPU transfers. No custom kernels needed. Just stop doing N separate copies.
- Fusion helps when memory bandwidth is the bottleneck. The joint scoring kernel wins because it reads the vocab tensor once instead of six times, not because it does less math.
- torch.compile has hidden costs with dynamic shapes. In streaming inference where batch size varies every cycle, CUDA graph re-recording makes it slower than eager.
- Know when to stop. The encoder and decoder are better served by existing tools (TensorRT, ONNX) than hand-written kernels. Engineering judgment is knowing what not to optimize.
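The bandwidth argument behind the fusion takeaway is easy to make concrete. The arithmetic below uses the benchmark's shapes (128 streams, vocab 1025, fp32) and the simplifying assumption that each of the six kernels makes one full pass over the logits tensor; it ignores the (much smaller) intermediate writes.

```python
streams, vocab, bytes_per_el = 128, 1025, 4   # fp32 logits, benchmark shapes
row_bytes = streams * vocab * bytes_per_el    # one full pass over [N, V]
unfused = 6 * row_bytes                       # six kernels, six global reads
fused = row_bytes                             # one load into registers
print(row_bytes, unfused // fused)            # ~0.5 MB per pass, 6x less traffic
```

At ~0.5 MB per pass, the tensor is far too small to saturate an A100; the cost is dominated by launching and synchronizing six kernels instead of one, which is exactly what fusion removes.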
All benchmarks run on NVIDIA A100 (PG509-210), PyTorch 2.1.2, Triton 3.x.