Who needs fast autograd? Seemingly everyone these days! And once upon a time I needed an autograd that is actually fast. Leaving project details aside, here are the requirements:

- we test many computation graphs (the graph changes constantly)
- many, many scalar operations, with roughly 10k-100k nodes in each graph
- every graph should be compiled and run around 10k times, both forward and backward
- this should be done wicked fast, and with a convenient pythonic interface

The path that awaits us ahead:

- autograd in torch
- autograd in jax
- autograd in python
- autograd in rust
- autograd in C
- autograd in assembly

Plus a significant amount of sloppy code and timings on an M1 macbook.

## Let's autograd in pytorch

We start our journey with pytorch, the default autograd engine in research. We'll create a graph with many nodes, and to keep things simple our benchmark has only several kinds of operations: unary (softplus), binary (multiplication), n-ary (sum) and n-to-n (softmax). This keeps the set of operations small, but resembles a realistic load. All benchmarks in this post reimplement the same logic as below.

```python
import torch
import torch.nn.functional as F

def run_graph(initial_variables, n_operations: int):
    nodes = [*initial_variables]
    for op in range(n_operations):
        match op % 4:
            case 0:  # softplus
                nodes.append(F.softplus(nodes[-10]))
            case 1:  # sum
                nodes.append(sum(nodes[-30:-10:5]))
            case 2:  # prod
                nodes.append(nodes[-20] * nodes[-10])
            case 3:  # softmax
                softmaxes = F.softmax(torch.stack(nodes[-4:], dim=0), dim=0)
                nodes.extend(softmaxes)
    return nodes

def run_benchmark_pytorch(n_iterations, n_operations):
    init_vars = torch.arange(100, dtype=torch.float32, requires_grad=True)
    for _ in range(n_iterations):
        nodes = run_graph(
            initial_variables=init_vars,
            n_operations=n_operations,
        )
        nodes[-1].backward()
```

- Run-time for 10k ops x 100 iterations: 11.3 seconds
- Run-time for 10k ops x 10k iterations: 1130 seconds (estimate)

Given we created 100M python objects, it's actually quite fast. And yes, that's not going to deliver an interactive experience.

Let's also discuss torch.compile, a major innovation in pytorch 2.0. At 100 operations, compilation with torch.compile takes 4.5 seconds. Execution gets faster: for 100 operations and 10k iterations it takes 4.52 seconds with torch.compile and 10.4 seconds without, so compilation + execution still land in the same ballpark. For bigger graphs (1k operations) torch.compile crashes.
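For concreteness, here is roughly how torch.compile can be wrapped around the benchmark graph above. This is a minimal sketch with default compile settings assumed, not necessarily the exact setup behind the timings.

```python
import torch

# a minimal sketch: wrap the graph builder from above with torch.compile
# (default settings assumed; not necessarily the exact setup behind the timings)
compiled_run_graph = torch.compile(run_graph)

init_vars = torch.arange(100, dtype=torch.float32, requires_grad=True)
nodes = compiled_run_graph(init_vars, n_operations=100)  # first call triggers compilation
nodes[-1].backward()
```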
## Let's autograd in jax

Jax is the new cool kid... well, not that new anymore. But in some aspects it is very interesting. Jax's focus on JIT-compiling static graphs is very suitable for the problem at hand. The implementation for the benchmark is similar to pytorch:

```python
import jax
import numpy as np

def run_graph_jax(initial_variables):
    # n_operations is taken from the enclosing scope
    nodes = [*initial_variables]
    for op in range(n_operations):
        match op % 4:
            case 0:  # softplus
                nodes.append(jax.nn.softplus(nodes[-10]))
            case 1:  # sum
                nodes.append(sum(nodes[-30:-10:5]))
            case 2:  # prod
                nodes.append(nodes[-20] * nodes[-10])
            case 3:  # softmax
                softmaxes = jax.nn.softmax(jax.numpy.stack(nodes[-4:]), axis=0)
                nodes.extend(softmaxes)
    return nodes[-1]

run_graph_and_grad = jax.value_and_grad(run_graph_jax)
# or
run_graph_and_grad = jax.jit(jax.value_and_grad(run_graph_jax))
```

Without jit, computations are extremely slow:

- 1k ops x 10 iterations => 15.9 seconds
- 10k ops x 10k iterations => 159,000 seconds (estimate)

That's a bit longer than forever! But the whole point of jax is to JIT-compile stuff. So let's do it.

- jit: compilation of 1k ops = 47 seconds
- jit: run-time for 1k ops x 10k iterations = 0.66 seconds
- jit: 10k ops x 10k iterations (compilation + run-time) => 470 seconds (estimate)

The speed-up in execution time is more than impressive, but we spend >99% of the time compiling.

## Tensorflow

Someone will mention TF anyway. I'll leave this as an exercise for you, TF fans.

## Let's autograd in python

Done with baselines, time to see if we can speed things up. Let's create a simplistic pseudo-framework and see how it competes with the previous candidates. We'll implement a tape-based autograd where the order of operations is explicitly tracked on a tape.

The autograd engine in plain python:

```python
import math

class NaiveVar:
    def __init__(self, val):
        self.val = val
        self.grad = 0.

class NaiveTape:
    def __init__(self, input_values):
        self.ops = []

    def sum(self, *vars):
        res = NaiveVar(sum(v.val for v in vars))
        self.ops.append(('sum', vars, res))
        return res

    def prod(self, var1, var2):
        res = NaiveVar(var1.val * var2.val)
        self.ops.append(('prod', [var1, var2], res))
        return res

    def softmax(self, *vars):
        vals = [v.val for v in vars]
        maxval = max(vals)
        vals = [v - maxval for v in vals]
        denom = sum(math.exp(v) for v in vals)
        res = [NaiveVar(math.exp(v) / denom) for v in vals]
        self.ops.append(('softmax', vars, res))
        return res

    def softplus(self, var):
        res = NaiveVar(math.log1p(math.exp(var.val)))
        self.ops.append(('splus', var, res))
        return res

    def backward(self, var):
        assert var.grad == 0
        var.grad += 1
        for op, inputs, outputs in self.ops[::-1]:
            match op:
                case 'sum':
                    for v in inputs:
                        v.grad += outputs.grad
                case 'prod':
                    in1, in2 = inputs
                    in1.grad += in2.val * outputs.grad
                    in2.grad += in1.val * outputs.grad
                case 'splus':
                    inputs.grad += outputs.grad / (1 + math.exp(-inputs.val))
                case 'softmax':
                    pass  # skip for now
                case _:
                    raise NotImplementedError()
```

and reimplement the reference task using our new pseudo-framework:

```python
def run_graph_python_and_backward(initial_variables, n_operations):
    nodes = [NaiveVar(x) for x in initial_variables]
    tape = NaiveTape(nodes)
    for op in range(n_operations):
        match op % 4:
            case 0:  # softplus
                nodes.append(tape.softplus(nodes[-10]))
            case 1:  # sum
                nodes.append(tape.sum(*nodes[-30:-10:5]))
            case 2:  # prod
                nodes.append(tape.prod(nodes[-20], nodes[-10]))
            case 3:  # softmax
                nodes.extend(tape.softmax(*nodes[-4:]))
    tape.backward(nodes[-1])
    return tape
```

Run-time for 10k ops and 10k iterations: 312 seconds.

Expectedly not fast. But compared to the previous candidates, that's actually quite competitive!
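The NaiveTape backward above skips the softmax case. For completeness, here is the rule that the later implementations in this post use for it, written as a small standalone helper; the function name softmax_backward and the numeric example are mine, only the formula comes from the post.

```python
import math

def softmax_backward(out_vals, out_grads):
    # rule used later in the post: dL/dx_i = y_i * (dL/dy_i - sum_j y_j * dL/dy_j)
    avg_grad = sum(y * g for y, g in zip(out_vals, out_grads))
    return [y * (g - avg_grad) for y, g in zip(out_vals, out_grads)]

# illustrative check on a 3-element softmax
xs = [0.1, 0.5, -0.3]
m = max(xs)
exps = [math.exp(x - m) for x in xs]
ys = [e / sum(exps) for e in exps]
print(softmax_backward(ys, [1.0, 0.0, 0.0]))
```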
## Let's autograd in python, again

This time we move all values into the tape instead of keeping them in variables. Additionally, the tape keeps a 'static graph' of computations by recording the indices of the variables participating in every operation.

The code for this autograd in plain python:

```python
import numba
import math

class VarInd:
    def __init__(self, index):
        self.index = index  # a variable is just a unique index in the tape

class TapeInd:
    def __init__(self):
        self.ops = []
        self.vals = []   # flat memory with values
        self.grads = []  # flat memory with gradients

    def make_var(self, value):
        self.vals.append(value)
        self.grads.append(0.)
        return VarInd(len(self.vals) - 1)

    def val(self, v: VarInd):
        return self.vals[v.index]

    def add_op(self, kls, input_vars, output_vars):
        # translate variables to indices: self.ops keeps only indices
        self.ops.append((kls, [x.index for x in input_vars], [x.index for x in output_vars]))

    def sum(self, *vars):
        res = self.make_var(sum(self.val(v) for v in vars))
        self.add_op('sum', vars, [res])
        return res

    def prod(self, var1, var2):
        res = self.make_var(self.val(var1) * self.val(var2))
        self.add_op('prod', [var1, var2], [res])
        return res

    def softmax(self, *vars):
        vals = [self.val(v) for v in vars]
        maxval = max(vals)
        vals = [v - maxval for v in vals]
        denom = sum(math.exp(v) for v in vals)
        res = [self.make_var(math.exp(v) / denom) for v in vals]
        self.add_op('softmax', vars, res)
        return res

    def softplus(self, var):
        res = self.make_var(math.log1p(math.exp(self.val(var))))
        self.add_op('splus', [var], [res])
        return res

    def forward_backward(self, grad_var: VarInd):
        return forward_backward_external(self.vals, self.grads, self.ops, grad_var_index=grad_var.index)


def forward_backward_external(
    vals: list[float],
    grads: list[float],
    ops: list[tuple[str, list[int], list[int]]],
    grad_var_index: int,
):
    v: list[float] = vals
    g: list[float] = grads
    # forward pass
    for op, ins, outs in ops:
        match op:
            case 'sum':
                v[outs[0]] = sum(v[i] for i in ins)
            case 'prod':
                v[outs[0]] = v[ins[0]] * v[ins[1]]
            case 'splus':
                v[outs[0]] = math.log1p(math.exp(v[ins[0]]))
            case 'softmax':
                maximal = max(v[i] for i in ins)
                exps = [math.exp(v[i] - maximal) for i in ins]
                denom = sum(exps)
                for i, exp in zip(outs, exps):
                    v[i] = exp / denom

    g[grad_var_index] += 1

    # backward pass
    for op, ins, outs in ops[::-1]:
        match op:
            case 'sum':
                for i in ins:
                    g[i] += g[outs[0]]
            case 'prod':
                out: int = outs[0]
                in1, in2 = ins
                g[in1] += v[in2] * g[out]
                g[in2] += v[in1] * g[out]
            case 'splus':
                g[ins[0]] += g[outs[0]] / (1 + math.exp(-v[ins[0]]))
            case 'softmax':
                avg_grad = sum(v[j] * g[j] for j in outs)
                for i, j in zip(ins, outs):
                    g[i] += v[j] * (g[j] - avg_grad)
```

And the corresponding launching code:

```python
def run_graph_python_and_backward(n_operations, n_iterations):
    tape = TapeInd()
    nodes = [tape.make_var(float(x)) for x in range(100)]
    for op in range(n_operations):
        match op % 4:
            case 0:  # softplus
                nodes.append(tape.softplus(nodes[-10]))
            case 1:  # sum
                nodes.append(tape.sum(*nodes[-30:-10:5]))
            case 2:  # prod
                nodes.append(tape.prod(nodes[-20], nodes[-10]))
            case 3:  # softmax
                softmaxes = tape.softmax(*nodes[-4:])
                nodes.extend(softmaxes)
    for _ in range(n_iterations):
        tape.forward_backward(nodes[-1])
```

Run-time for 10k ops x 10k iterations: 94 seconds

As we see, moving all values into the tape and switching to operating on indices is quite an efficient strategy. We still use python, but are now ~5-10x faster than pytorch or jax.

At this point I want to mention one more experiment: the code above is organized to be numba-friendly. Numba is famous for speeding up number crunching in python with minimal changes by providing just-in-time compilation. The recent addition of numba.typed.List makes it possible to efficiently handle lists of lists.

Run-time with numba, 10k ops x 10k iterations: 41 seconds. At this point we're >10x faster than jax/pytorch (and still writing code in python).
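The numba variant itself isn't listed in the post, so below is a rough sketch, under my own assumptions, of how forward_backward_external could be adapted for numba: values and gradients live in numpy arrays, per-op indices are stored as int arrays inside a numba.typed.List, opcodes are plain integers dispatched with if/elif, and the name forward_backward_numba is illustrative.

```python
import math
import numba
import numpy as np
from numba.typed import List

@numba.njit
def forward_backward_numba(vals, grads, opcodes, ins_list, outs_list, grad_var_index):
    # forward pass; opcodes: 0=softplus, 1=sum, 2=prod, 3=softmax
    for k in range(len(opcodes)):
        op, ins, outs = opcodes[k], ins_list[k], outs_list[k]
        if op == 0:
            vals[outs[0]] = math.log1p(math.exp(vals[ins[0]]))
        elif op == 1:
            s = 0.0
            for i in ins:
                s += vals[i]
            vals[outs[0]] = s
        elif op == 2:
            vals[outs[0]] = vals[ins[0]] * vals[ins[1]]
        else:  # softmax
            m = vals[ins[0]]
            for i in ins:
                m = max(m, vals[i])
            denom = 0.0
            for n in range(len(ins)):
                vals[outs[n]] = math.exp(vals[ins[n]] - m)
                denom += vals[outs[n]]
            for n in range(len(outs)):
                vals[outs[n]] /= denom
    grads[grad_var_index] += 1.0
    # backward pass, in reverse order
    for k in range(len(opcodes) - 1, -1, -1):
        op, ins, outs = opcodes[k], ins_list[k], outs_list[k]
        if op == 0:
            grads[ins[0]] += grads[outs[0]] / (1.0 + math.exp(-vals[ins[0]]))
        elif op == 1:
            for i in ins:
                grads[i] += grads[outs[0]]
        elif op == 2:
            grads[ins[0]] += vals[ins[1]] * grads[outs[0]]
            grads[ins[1]] += vals[ins[0]] * grads[outs[0]]
        else:  # softmax
            avg_grad = 0.0
            for j in outs:
                avg_grad += vals[j] * grads[j]
            for n in range(len(ins)):
                j = outs[n]
                grads[ins[n]] += vals[j] * (grads[j] - avg_grad)

# building numba-friendly containers from a TapeInd built as in the launching code above
opcode_of = {'splus': 0, 'sum': 1, 'prod': 2, 'softmax': 3}
opcodes = np.array([opcode_of[name] for name, _, _ in tape.ops], dtype=np.int64)
ins_list, outs_list = List(), List()
for _, ins, outs in tape.ops:
    ins_list.append(np.array(ins, dtype=np.int64))
    outs_list.append(np.array(outs, dtype=np.int64))
vals = np.array(tape.vals, dtype=np.float64)
grads = np.zeros_like(vals)
forward_backward_numba(vals, grads, opcodes, ins_list, outs_list, nodes[-1].index)
```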
## Let's autograd in rust

Once we moved graph tracking to the tape, we can now use something fast to run the computations for us. For instance, rust.

For rust↔python interop I've used a small wrapper around rustimport. Rustimport allows conveniently 'importing' a single rust file without creating a full-fledged rust project (a sketch of the python side appears after the rust listing below).

Some optimization remarks:

- softmax was a bottleneck, so I switched to creating temporary arrays on the stack instead of Vecs, which required specializing on input sizes
- I followed a rust-y approach with iterators to reduce the number of bounds checks
- I wondered if a match with multiple options checked one-by-one is slow. In synthetic tests it seemed to be relatively fast, but I wish the jump table optimization was implemented here (e.g. it is supported for enums in rust, and clang uses this optimization in C for switch-case)

The rust code for the minimal autograd:

```rust
// rustimport:pyo3
use pyo3::prelude::*;

// slower softmax version for a larger number of inputs
fn softmax_varlength(vals: &mut Vec<f32>, ins: &[usize], outs: &[usize]) {
    let mut max = -1e20_f32;
    let loc_vals: Vec<f32> = ins
        .iter()
        .map(|i| {
            let x = vals[*i];
            max = max.max(x);
            x
        })
        .collect();
    let mut sum: f32 = 0.0_f32;
    let exps: Vec<f32> = loc_vals
        .iter()
        .map(|v| {
            let _exp = f32::exp(*v - max);
            sum += _exp;
            _exp
        })
        .collect();
    outs.iter().zip(exps.iter()).for_each(|(j, exp)| vals[*j] = exp / sum);
}

// vecs are slow! so allocate slices on the stack;
// explicit grouping of computations also helps
fn softmax<const N: usize>(vals: &mut Vec<f32>, ins: &[usize], outs: &[usize]) {
    let mut loc_vals: [f32; N] = [0_f32; N];
    let mut exps: [f32; N] = [0_f32; N];
    let mut max = -1e20_f32;
    let mut sum: f32 = 0.;
    for (n, i) in ins.iter().enumerate() {
        let v = vals[*i];
        loc_vals[n] = v;
        max = max.max(v);
    }
    for (n, _i) in ins.iter().enumerate() {
        let exp = f32::exp(loc_vals[n] - max);
        exps[n] = exp;
        sum += exp;
    }
    let invsum = 1.0_f32 / sum;
    for (n, j) in outs.iter().enumerate() {
        vals[*j] = exps[n] * invsum;
    }
}

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

#[pyfunction]
unsafe fn autograd(
    vals_input: Vec<f32>,
    ops: Vec<u8>,
    input_ids: Vec<Vec<usize>>,
    output_ids: Vec<Vec<usize>>,
    backward_node_id: usize,
    n_iteration: i32,
) -> (Vec<f32>, Vec<f32>) {
    let mut vals: Vec<f32> = vals_input.iter().map(|x| *x).collect();
    let mut grad: Vec<f32> = vals_input.into_iter().map(|_| 0.0_f32).collect();
    for _ in 0..n_iteration {
        // forward pass
        for (i_op, op) in ops.iter().enumerate() {
            let ins: &Vec<usize> = &input_ids[i_op];
            let outs: &Vec<usize> = &output_ids[i_op];
            match *op {
                0 => {
                    // softplus
                    let x = vals[ins[0]];
                    let max = f32::max(0., x);
                    let min = f32::min(0., x);
                    vals[outs[0]] = max + f32::ln_1p(f32::exp(min - max));
                }
                1 => {
                    // sum
                    vals[outs[0]] = ins.iter().map(|i| vals.get_unchecked(*i)).sum();
                }
                2 => {
                    // prod
                    vals[outs[0]] = vals[ins[0]] * vals[ins[1]];
                }
                3 => {
                    // softmax: we need switch-case resolution here for the most common cases
                    match ins.len() {
                        1 => softmax::<1>(&mut vals, ins, outs),
                        2 => softmax::<2>(&mut vals, ins, outs),
                        3 => softmax::<3>(&mut vals, ins, outs),
                        4 => softmax::<4>(&mut vals, ins, outs),
                        5 => softmax::<5>(&mut vals, ins, outs),
                        _ => softmax_varlength(&mut vals, ins, outs),
                    }
                }
                _ => {
                    panic!("");
                }
            }
        }
        // backward pass: traverse operations in reverse order
        grad[backward_node_id] = 1.;
        for (i_op, op) in ops.iter().enumerate().rev() {
            let ins: &Vec<usize> = &input_ids[i_op];
            let outs: &Vec<usize> = &output_ids[i_op];
            match *op {
                0 => {
                    // softplus
                    grad[ins[0]] += grad[outs[0]] * sigmoid(vals[ins[0]]);
                }
                1 => {
                    // sum
                    ins.iter().for_each(|i| grad[*i] += grad[outs[0]]);
                }
                2 => {
                    // prod
                    grad[ins[0]] += grad[outs[0]] * vals[ins[1]];
                    grad[ins[1]] += grad[outs[0]] * vals[ins[0]];
                }
                3 => {
                    // softmax
                    let avg_grad: f32 = outs.iter().map(|j| grad[*j] * vals[*j]).sum();
                    for (i, j) in ins.iter().zip(outs.iter()) {
                        grad[*i] += vals[*j] * (grad[*j] - avg_grad);
                    }
                }
                _ => {
                    panic!("");
                }
            }
        }
    }
    (vals, grad)
}
```
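The wrapper around rustimport is not shown in the post, so here is a minimal sketch of how the plain rustimport workflow might look; the file name fast_autograd.rs, the opcode mapping and the argument packing are my assumptions.

```python
# minimal sketch of calling the rust autograd via plain rustimport
# (the author's actual wrapper is not shown in the post);
# assumes the rust listing above is saved as fast_autograd.rs next to this script
import rustimport.import_hook  # compiles .rs files on import
import fast_autograd          # hypothetical module name for the file above

# assumes tape and nodes were built with TapeInd as in the launching code above;
# opcodes follow the convention of the rust code: 0=softplus, 1=sum, 2=prod, 3=softmax
op_to_code = {'splus': 0, 'sum': 1, 'prod': 2, 'softmax': 3}
opcodes = [op_to_code[name] for name, _, _ in tape.ops]
input_ids = [ins for _, ins, _ in tape.ops]
output_ids = [outs for _, _, outs in tape.ops]

vals, grads = fast_autograd.autograd(
    tape.vals, opcodes, input_ids, output_ids,
    backward_node_id=nodes[-1].index, n_iteration=10_000,
)
```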
Run-time for 10k ops x 10k iterations: 1.4 seconds

Success: we are in the realm of interactive experiences. Recall that we started from >1000 seconds. But should we stop here?

## Let's autograd in C

Time to implement the autograd logic in C. For interop with python I use python-cffi (a sketch of the cffi glue follows the listing).

I went bananas on optimization:

- I used the fact that output nodes are placed consecutively in memory, so we pass only the index of the first output
- the number of inputs is limited to 8, and those are baked into the struct as `int[8]`, not `int *`, to avoid jumps in memory
- dynamic stack allocations of variable size (compared to rust, those are straightforward in C)
- `-O3`, and unsafe math: `-ffast-math`. I even experimented with memory alignment and `restrict`-ing pointers, but no luck

Some code in C:

```c
#include <math.h>
#include <stdlib.h>

typedef struct {
    int opcode;
    size_t n_arguments;  // used for softmax and sum
    int ins[8];          // at most 8 inputs
    int out;             // points to the first output variable
} MyOperation;

MyOperation * allocate_memory(int n_elements) {
    return (MyOperation *) malloc(sizeof(MyOperation) * n_elements);
}

// stable implementation
double logaddexp(double x, double y) {
    if (x > y) {
        return x + log1p(exp(y - x));
    } else {
        return y + log1p(exp(x - y));
    }
}

double sigmoid(double x) { return 1.0 / (1.0 + exp(-x)); }

void run_multiple_passes(
    int n_operations,
    MyOperation *ops,
    double *values,
    double *grads,
    int n_iterations
) {
    for (int iteration = 0; iteration < n_iterations; iteration++) {
```
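The cffi glue itself is not shown in the post; below is a minimal sketch of how the C code might be compiled and called with python-cffi. The module name _c_autograd, the file name autograd.c and the argument marshalling are my assumptions; the `-O3`/`-ffast-math` flags come from the list above.

```python
# minimal sketch of python-cffi glue for the C code above
# (the author's actual build script is not shown in the post)
import cffi

ffi = cffi.FFI()
ffi.cdef("""
typedef struct {
    int opcode;
    size_t n_arguments;
    int ins[8];
    int out;
} MyOperation;

MyOperation * allocate_memory(int n_elements);
void run_multiple_passes(int n_operations, MyOperation *ops,
                         double *values, double *grads, int n_iterations);
""")
# assumes the C listing above is saved as autograd.c
ffi.set_source("_c_autograd", open("autograd.c").read(),
               extra_compile_args=["-O3", "-ffast-math"])
ffi.compile()

from _c_autograd import lib, ffi as c_ffi

# assumes tape was built with TapeInd as in the launching code earlier
n_ops = len(tape.ops)
ops = lib.allocate_memory(n_ops)
values = c_ffi.new("double[]", tape.vals)
grads = c_ffi.new("double[]", [0.0] * len(tape.vals))
# ... fill ops[i].opcode / n_arguments / ins / out from tape.ops, then:
lib.run_multiple_passes(n_ops, ops, values, grads, 10_000)
```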