import jax
import jax.numpy as jnp


def generate_mat(N, k, key=None):
    """Return an ``(N, k)`` matrix holding a random permutation of ``0 .. N*k - 1``.

    The integers ``0, 1, ..., N*k - 1`` are shuffled with
    ``jax.random.permutation`` and laid out row-major into ``N`` rows of
    ``k`` columns. Because the source is a permutation, every entry of the
    matrix is distinct, both within a row and across rows.

    Args:
        N: Number of rows.
        k: Number of columns.
        key: JAX PRNG key driving the shuffle. When omitted (``None``), a
            fresh ``jax.random.PRNGKey(0)`` is used. Repeated calls without
            an explicit key therefore return the same matrix.

    Returns:
        A JAX array of shape ``(N, k)``.
    """
    if key is None:
        # Build the default key per call rather than in the signature, where
        # it would be evaluated once at import time (allocating a device
        # array as a side effect of merely importing the module).
        key = jax.random.PRNGKey(0)
    perm = jax.random.permutation(key, N * k)
    return perm.reshape(N, k)


if __name__ == "__main__":
    # Demo only: guarded so importing this module does not shuffle and print
    # a 1,000,000-element array.
    print(generate_mat(50000, 20))
Standard input is empty
Standard output is empty
Traceback (most recent call last): File "./prog.py", line 1, in <module> ModuleNotFoundError: No module named 'jax'