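import jax
import jax.numpy as jnp


def generate_mat(N, k, key=jax.random.PRNGKey(0)):
    # Passing an int to jax.random.permutation shuffles arange(N * k), so every
    # value 0 .. N*k-1 appears exactly once; reshape lays them out row-major.
    # The default key is built once, at definition time, so every call that
    # omits `key` returns the same matrix.
    perm = jax.random.permutation(key, N * k)
    return perm.reshape(N, k)


print(generate_mat(50000, 20))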
Runtime error (0.12s, 26080KB)
stderr
Traceback (most recent call last):
  File "./prog.py", line 1, in <module>
ModuleNotFoundError: No module named 'jax'
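The run fails only because the runner has no `jax` module installed. The function itself is fine. For an environment without JAX, here is a minimal stdlib-only sketch of the same idea: a random permutation of 0 .. N*k-1 laid out as N rows of k. It swaps JAX's PRNG for Python's `random.Random`, so its output will not match `jax.random.permutation` for the same seed. The name `generate_mat_stdlib` and its `seed` parameter are hypothetical, introduced only for illustration.

```python
import random


def generate_mat_stdlib(N, k, seed=0):
    """Hypothetical stdlib stand-in for generate_mat (not the JAX API).

    Returns an N x k list of lists holding a random permutation of range(N * k).
    Uses random.Random, so values differ from jax.random.permutation.
    """
    rng = random.Random(seed)      # private generator, reproducible per seed
    flat = list(range(N * k))
    rng.shuffle(flat)              # in-place shuffle: each value still appears once
    # Row i is the slice [i*k, (i+1)*k), matching a row-major reshape(N, k).
    return [flat[i * k:(i + 1) * k] for i in range(N)]


mat = generate_mat_stdlib(5, 4)
print(len(mat), len(mat[0]))       # prints: 5 4
```

Like the JAX version, a fixed default seed means repeated calls return the same matrix. Pass a different `seed` to get a different one.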