Significant performance difference of NNX relative to equinox #4045
To add to this: the performance of Linen seems to be similar to NNX, although I am even less clear on how to profile there. Here was my implementation:

```python
import typing as tp
import jax
import jax.numpy as jnp
import flax.linen as linen
from flax.core import freeze, unfreeze
from flax.training import train_state
from flax.typing import Dtype, PrecisionLike
import optax
import time


class MLPLinen(linen.Module):
    in_features: int
    out_features: int
    width: int
    depth: int
    activation: tp.Callable
    use_bias: bool = True
    use_final_bias: bool = True
    final_activation: tp.Optional[tp.Callable] = None
    dtype: tp.Optional[Dtype] = None
    param_dtype: Dtype = jnp.float32
    precision: PrecisionLike = None

    @linen.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        x = linen.Dense(
            self.width,
            use_bias=self.use_bias,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
        )(x)
        for _ in range(self.depth - 1):
            x = self.activation(x)
            x = linen.Dense(
                self.width,
                use_bias=self.use_bias,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                precision=self.precision,
            )(x)
        x = linen.Dense(
            self.out_features,
            use_bias=self.use_final_bias,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
        )(x)
        if self.final_activation is not None:
            x = self.final_activation(x)
        return x


def create_train_state_linen(rng, model, learning_rate):
    params = model.init(rng, jnp.ones([1, model.in_features]))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)


def compute_loss_linen(params, batch, model_apply_fn):
    logits = model_apply_fn({'params': params}, batch)
    loss = jnp.mean(logits)
    return loss


@jax.jit
def train_step_linen(state, batch):
    grad_fn = jax.value_and_grad(compute_loss_linen)
    loss, grads = grad_fn(state.params, batch, state.apply_fn)
    state = state.apply_gradients(grads=grads)
    return state, loss


if __name__ == "__main__":
    rng = jax.random.PRNGKey(0)
    n_in = 64
    n_out = 1
    depth = 3
    width = 128
    activation = linen.relu
    model = MLPLinen(n_in, n_out, width=width, depth=depth, activation=activation)
    state = create_train_state_linen(rng, model, learning_rate=0.001)
    my_batch = jax.random.normal(rng, (20, n_in))

    # Time Linen
    state, loss_val = train_step_linen(state, my_batch)
    jax.block_until_ready(loss_val)
    start_time = time.time()
    state, loss_val = train_step_linen(state, my_batch)
    jax.block_until_ready(loss_val)
    end_time = time.time()
    print(f"Time taken Linen: {end_time - start_time:.5f} seconds")
```
Hey @jlperla, can you use `timeit`? That said, this is what I would expect:
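For reference, a minimal pattern for timing a jitted step with `timeit` might look like the sketch below. This is my illustration, not from the thread; `bench` and its arguments are hypothetical names.

```python
import timeit

import jax

# Hypothetical helper: time a jitted function `step` on fixed `args`.
# The first call compiles; the timed calls block on the result so device
# execution is included in the measured wall time.
def bench(step, *args, number=1000):
    jax.block_until_ready(step(*args))  # warm-up / compile outside the timed region
    total = timeit.timeit(
        lambda: jax.block_until_ready(step(*args)), number=number
    )
    return total / number  # mean seconds per call
```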
@jlperla Maybe useful to note here: for small MLPs you will likely be in the overhead regime. To overcome the framework overhead (in NNX or Equinox) you may use
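The concrete suggestion above was elided in this capture. One common way to escape the overhead regime (an assumption on my part, not necessarily what was meant) is to fold many training steps into a single jitted `jax.lax.scan`, so the Python and framework dispatch cost is paid once per scan rather than once per step:

```python
import jax
import jax.numpy as jnp

# Stand-in loss over a linear model; any pure loss function works here.
def loss_fn(params, batch):
    return jnp.mean(batch @ params)

@jax.jit
def train_many(params, batches, lr=1e-3):  # batches: [n_steps, B, D]
    # Each scan iteration consumes one batch and carries the parameters.
    def step(params, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        return params - lr * grads, loss
    return jax.lax.scan(step, params, batches)

params = jnp.zeros((64, 1))
batches = jax.random.normal(jax.random.PRNGKey(0), (100, 20, 64))
params, losses = train_many(params, batches)  # one dispatch, 100 steps
```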
@ASEM000 correct. Ideally we document how to overcome the overhead problem in the near future.
@cgarciae @ASEM000 Absolutely. But the issue is comparing the relative overhead of NNX vs. Equinox for the same pattern. I find `timeit` hard to use here, but I made sure things were compiled and retried multiple times. Why would the Equinox code be so much faster than NNX (which seems roughly similar to Flax Linen)? What overhead would be so much more significant there, using the same coding pattern? If you look at my code, I am isolating a single "value and grad" call, with no optimizer overhead or training loop, and precompiling it before timing. So either
@jlperla I do imagine the NNX overhead being greater than the Equinox overhead, as we do more bookkeeping and it's not optimized. If performance is critical you should just train using the split/merge pattern shown below:

```python
from functools import partial
import typing as tp
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx.nnx import rnglib
from flax.typing import Dtype, PrecisionLike
import equinox as eqx
import time


class MLP(nnx.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        width: int,
        depth: int,
        activation: tp.Callable,
        rngs: rnglib.Rngs,
        use_bias: bool = True,
        use_final_bias: bool = True,
        final_activation: tp.Optional[tp.Callable] = None,
        dtype: tp.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
    ):
        self.in_features = in_features
        self.out_features = out_features
        self.width = width
        self.depth = depth
        self.use_bias = use_bias
        self.use_final_bias = use_final_bias
        self.activation = activation
        self.final_activation = final_activation
        assert depth > 0  # skipping specialization of no hidden layers
        self.layers = []
        self.layers.append(
            nnx.Linear(
                in_features,
                width,
                use_bias=use_bias,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
        )
        for i in range(self.depth - 1):
            self.layers.append(
                nnx.Linear(
                    width,
                    width,
                    use_bias=self.use_bias,
                    dtype=dtype,
                    param_dtype=param_dtype,
                    precision=precision,
                    rngs=rngs,
                )
            )
            self.layers.append(self.activation)
        self.layers.append(
            nnx.Linear(
                width,
                out_features,
                use_bias=self.use_final_bias,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
        )
        if self.final_activation is not None:
            self.layers.append(self.final_activation)

    def __call__(self, x: jax.Array) -> jax.Array:
        for layer in self.layers:
            x = layer(x)
        return x


if __name__ == '__main__':
    rngs = nnx.Rngs(0)

    @jax.jit
    def my_test(batch, graphdef, state):
        model = nnx.merge(graphdef, state)

        def loss_closure(model):
            return jnp.mean(jax.vmap(model)(batch))

        loss_val, loss_grad = nnx.value_and_grad(loss_closure)(model)
        return loss_val

    n_in = 64
    n_out = 1
    depth = 3
    width = 128
    activation = nnx.relu
    model = MLP(
        n_in, n_out, width=width, depth=depth, activation=activation, rngs=rngs
    )
    my_batch = jax.random.normal(rngs(), (20, n_in))
    graphdef, state = nnx.split(model)

    # Time NNX
    out = my_test(my_batch, graphdef, state).block_until_ready()
    start_time = time.time()
    out = my_test(my_batch, graphdef, state).block_until_ready()
    end_time = time.time()
    print(f'Time taken NNX: {end_time - start_time:.5f} seconds')

    # -----------
    # Equinox
    # -----------
    @eqx.filter_jit
    def my_test_eqx(batch, treedef, leaves):
        model = jax.tree.unflatten(treedef, leaves)

        @eqx.filter_jit
        def loss_closure(f):
            return jnp.mean(jax.vmap(f)(batch))

        loss_val, loss_grad = eqx.filter_value_and_grad(loss_closure)(model)
        return loss_val

    equinox_model = eqx.nn.MLP(
        n_in,
        n_out,
        width_size=width,
        depth=depth,
        activation=activation,
        key=rngs(),
    )
    leaves, treedef = jax.tree.flatten(equinox_model)

    # Time Equinox
    out = my_test_eqx(my_batch, treedef, leaves)
    start_time = time.time()
    out = my_test_eqx(my_batch, treedef, leaves).block_until_ready()
    end_time = time.time()
    print(f'Time taken EQX: {end_time - start_time:.5f} seconds')
```

Output on my M1:
This version might still be suboptimal for Equinox because of the use of
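If the culprit is the nested `filter_jit` inside the outer jitted function (an assumption, since the specific cause was elided above), a variant that simply drops it would be:

```python
# Same benchmark body as above, minus the inner @eqx.filter_jit: once the
# outer function is traced, the nested jit only adds tracing/partitioning
# work without changing the compiled computation.
@eqx.filter_jit
def my_test_eqx(batch, treedef, leaves):
    model = jax.tree.unflatten(treedef, leaves)

    def loss_closure(f):
        return jnp.mean(jax.vmap(f)(batch))

    loss_val, loss_grad = eqx.filter_value_and_grad(loss_closure)(model)
    return loss_val
```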
We will add a guide on NNX transforms explaining how they work under the hood in the future.
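Until that guide lands, a rough mental model (a sketch based on the split/merge pattern already used in this thread, not the actual implementation) is that `nnx.jit` moves the module across the jit boundary as pytrees:

```python
import jax
from flax import nnx

# Sketch of what an nnx transform does at the function boundary: decompose
# the module into a static graphdef plus an array-only state, run a plain
# jax.jit function on those pytrees, rebuild the module inside, and
# propagate any in-place mutations back out. Function names are made up.
def manual_nnx_jit(f):
    @jax.jit
    def inner(graphdef, state, *args):
        model = nnx.merge(graphdef, state)   # rebuild module from pytrees
        out = f(model, *args)
        _, new_state = nnx.split(model)      # capture in-place updates
        return out, new_state

    def wrapper(model, *args):
        graphdef, state = nnx.split(model)   # bookkeeping on every call
        out, new_state = inner(graphdef, state, *args)
        nnx.update(model, new_state)         # write updates back to the module
        return out

    return wrapper
```

The per-call `split`/`update` bookkeeping in the wrapper is the overhead being discussed in this thread.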
Some documentation would be very useful; I also ran into this when profiling NNX vs. Linen.
Linen is already low-overhead. I'll try to add it to the benchmark.
@cgarciae thanks, this helps a lot. I don't feel like you need to compare to Equinox in your docs. My main concern was that it seemed to be 3x slower for the same task, but if you are doing more bookkeeping, then it isn't really the same task. And just to confirm: is my MLP implementation as high-performance as possible? If so, maybe it would be helpful to have in the docs for people to adapt.
I believe so. @jlperla do you want to contribute it as an NNX example?
@cgarciae OK, here is my attempt comparing your code. The summary is now:
At this point I am convinced that there is nothing fundamentally different between NNX and Equinox that holds back performance, even if there are probably some performance tweaks that may land in the future. I feel like you could close out this issue, and I could prepare a simple example for the docs (without the performance comparisons) if you are willing. Maybe a simple nonlinear regression with an example MLP? If you are interested, here was my code, which runs on my system as
```python
import typing as tp
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx.nnx import rnglib
from flax.typing import Dtype, PrecisionLike
import equinox as eqx
import time
import flax.linen as linen
from flax.core import freeze, unfreeze
from flax.training import train_state
from functools import partial
import optax


class MLPLinen(linen.Module):
    in_features: int
    out_features: int
    width: int
    depth: int
    activation: tp.Callable
    use_bias: bool = True
    use_final_bias: bool = True
    final_activation: tp.Optional[tp.Callable] = None
    dtype: tp.Optional[Dtype] = None
    param_dtype: Dtype = jnp.float32
    precision: PrecisionLike = None

    @linen.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        x = linen.Dense(
            self.width,
            use_bias=self.use_bias,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
        )(x)
        for _ in range(self.depth - 1):
            x = self.activation(x)
            x = linen.Dense(
                self.width,
                use_bias=self.use_bias,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                precision=self.precision,
            )(x)
        x = linen.Dense(
            self.out_features,
            use_bias=self.use_final_bias,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
        )(x)
        if self.final_activation is not None:
            x = self.final_activation(x)
        return x


class MLP(nnx.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        width: int,
        depth: int,
        activation: tp.Callable,
        rngs: rnglib.Rngs,
        use_bias: bool = True,
        use_final_bias: bool = True,
        final_activation: tp.Optional[tp.Callable] = None,
        dtype: tp.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
    ):
        self.in_features = in_features
        self.out_features = out_features
        self.width = width
        self.depth = depth
        self.use_bias = use_bias
        self.use_final_bias = use_final_bias
        self.activation = activation
        self.final_activation = final_activation
        assert depth > 0  # skipping specialization of no hidden layers
        self.layers = []
        self.layers.append(
            nnx.Linear(
                in_features,
                width,
                use_bias=use_bias,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
        )
        for i in range(self.depth - 1):
            self.layers.append(
                nnx.Linear(
                    width,
                    width,
                    use_bias=self.use_bias,
                    dtype=dtype,
                    param_dtype=param_dtype,
                    precision=precision,
                    rngs=rngs,
                )
            )
            self.layers.append(self.activation)
        self.layers.append(
            nnx.Linear(
                width,
                out_features,
                use_bias=self.use_final_bias,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
        )
        if self.final_activation is not None:
            self.layers.append(self.final_activation)

    def __call__(self, x: jax.Array) -> jax.Array:
        for layer in self.layers:
            x = layer(x)
        return x


if __name__ == "__main__":
    rngs = nnx.Rngs(0)

    @nnx.jit
    def my_test(batch, model):
        def loss_closure(f):
            return jnp.mean(jax.vmap(f)(batch))

        loss_val, loss_grad = nnx.value_and_grad(loss_closure)(model)
        return loss_val

    @jax.jit
    def my_test_split(batch, graphdef, state):
        model = nnx.merge(graphdef, state)

        def loss_closure(model):
            return jnp.mean(jax.vmap(model)(batch))

        loss_val, loss_grad = nnx.value_and_grad(loss_closure)(model)
        return loss_val

    n_in = 64
    n_out = 1
    depth = 1
    width = 128
    activation = nnx.relu
    model = MLP(n_in, n_out, width=width, depth=depth, activation=activation, rngs=rngs)
    my_batch = jax.random.normal(rngs(), (20, n_in))

    # Time NNX
    out = my_test(my_batch, model).block_until_ready()
    start_time = time.time()
    out = my_test(my_batch, model).block_until_ready()
    end_time = time.time()
    print(f"Time taken NNX: {end_time - start_time:.6f} seconds")

    graphdef, state = nnx.split(model)
    out = my_test_split(my_batch, graphdef, state).block_until_ready()
    start_time = time.time()
    out = my_test_split(my_batch, graphdef, state).block_until_ready()
    end_time = time.time()
    print(f"Time taken NNX Split: {end_time - start_time:.6f} seconds")

    @eqx.filter_jit
    def my_test_eqx(batch, model):
        @eqx.filter_jit
        def loss_closure(f):
            return jnp.mean(jax.vmap(f)(batch))

        loss_val, loss_grad = eqx.filter_value_and_grad(loss_closure)(model)
        return loss_val

    @partial(jax.jit, static_argnums=2)
    def my_test_eqx_split(batch, params, static):
        model = eqx.combine(params, static)

        def loss_closure(f):
            return jnp.mean(jax.vmap(f)(batch))

        loss_val, loss_grad = eqx.filter_value_and_grad(loss_closure)(model)
        return loss_val

    equinox_model = eqx.nn.MLP(n_in, n_out, width_size=width, depth=depth, activation=activation, key=rngs())

    # Time Equinox
    out = my_test_eqx(my_batch, equinox_model)
    start_time = time.time()
    out = my_test_eqx(my_batch, equinox_model).block_until_ready()
    end_time = time.time()
    print(f"Time taken EQX: {end_time - start_time:.6f} seconds")

    params, static = eqx.partition(equinox_model, eqx.is_array)
    out = my_test_eqx_split(my_batch, params, static)
    start_time = time.time()
    out = my_test_eqx_split(my_batch, params, static).block_until_ready()
    end_time = time.time()
    print(f"Time taken EQX Split: {end_time - start_time:.6f} seconds")

    # Time Linen
    def create_train_state_linen(rng, model, learning_rate):
        params = model.init(rng, jnp.ones([1, model.in_features]))['params']
        tx = optax.adam(learning_rate)
        return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

    def compute_loss_linen(params, batch, model_apply_fn):
        logits = model_apply_fn({'params': params}, batch)
        loss = jnp.mean(logits)
        return loss

    @jax.jit
    def train_step_linen(state, batch):
        grad_fn = jax.value_and_grad(compute_loss_linen)
        loss, grads = grad_fn(state.params, batch, state.apply_fn)
        state = state.apply_gradients(grads=grads)
        return state, loss

    model = MLPLinen(n_in, n_out, width=width, depth=depth, activation=linen.relu)
    state = create_train_state_linen(rngs(), model, learning_rate=0.001)
    my_batch = jax.random.normal(rngs(), (20, n_in))

    # Time Linen
    state, loss_val = train_step_linen(state, my_batch)
    jax.block_until_ready(loss_val)
    start_time = time.time()
    state, loss_val = train_step_linen(state, my_batch)
    jax.block_until_ready(loss_val)
    end_time = time.time()
    print(f"Time taken Linen: {end_time - start_time:.6f} seconds")
```
I decided to try NNX vs. Equinox for performance and am seeing significant differences (roughly 3x slower for NNX). It could be that I wrote a poor MLP implementation or made a colossal profiling mistake.
My apologies if the benchmarking itself is flawed or the MLP implementation is incorrect in some way. But if it is the latter, it shows that a documented MLP implementation for NNX to copy/paste might help.
System information
Problem you have encountered:
The performance of my test suite on my CPU is
And on the Colab T4 GPU runtime
Steps to reproduce:
Test suite:
On Colab you need to do

```python
!pip install equinox
```