Error while running CNN Python notebook with CIFAR10 Images #3
Description
Colab_The_annotated_MNIST_image_classification_example_with_Flax_Linen_and_Optax.ipynb
In the notebook above, I ran the same code unchanged, but with the CIFAR10 image dataset from tfds. Loading the dataset itself worked fine; the problem appeared in `train_step`.
When `train_step` took `state` and `batch` (a dictionary) as arguments, I got the same error shown below, and not even the first epoch completed. I then changed `train_step` to take `(state, batch_labels, batch_images)`, all plain arrays. This time the first training epoch finished, but then the following error was raised:
```python
Training -epoch: 1, loss: 1.9994, accuracy: 0.29
Traceback (most recent call last):
  File "", line 4, in <cell line: 1>
    test_loss, test_accuracy = eval_model(state.params, test_ds)
  File "", line 2, in eval_model
    metrics = eval_step(model, test_ds)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 208, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 150, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/api.py", line 301, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 460, in common_infer_params
    avals.append(shaped_abstractify(a))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/api_util.py", line 563, in shaped_abstractify
    return _shaped_abstractify_handlers[type(x)](x)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/api_util.py", line 575, in _numpy_array_abstractify
    dtypes.check_valid_dtype(dtype)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 441, in check_valid_dtype
    raise TypeError(f"Dtype {dtype} is not a valid JAX array "
TypeError: Dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
```
Everything else is the same.
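A possible cause, which is an assumption and has not been verified against this notebook: unlike MNIST, the tfds CIFAR10 dataset also carries a string `id` feature. If `test_ds` is the full feature dictionary, that entry arrives as a NumPy array of `object` (bytes) dtype, and JAX rejects it in the same `check_valid_dtype` call that appears in the traceback. The sketch below uses only NumPy. The helper names `non_numeric_keys` and `keep_numeric` and the toy `test_ds` are hypothetical, made up for illustration. It lists the offending keys and drops them before the dictionary is handed to a jitted function such as `eval_step`.

```python
import numpy as np


def non_numeric_keys(batch):
    """Return the keys whose arrays JAX would reject (e.g. object or bytes dtypes)."""
    bad = []
    for key, value in batch.items():
        arr = np.asarray(value)
        # JAX accepts numeric (int/float/complex) and boolean arrays only.
        if not (np.issubdtype(arr.dtype, np.number) or arr.dtype == np.bool_):
            bad.append(key)
    return bad


def keep_numeric(batch):
    """Return a copy of the dict without the entries JAX cannot trace."""
    bad = set(non_numeric_keys(batch))
    return {k: v for k, v in batch.items() if k not in bad}


# Hypothetical stand-in for a CIFAR10 test split loaded as a dict of NumPy arrays
# (assumption: the string `id` feature comes back as an object array of bytes).
test_ds = {
    "image": np.zeros((4, 32, 32, 3), dtype=np.uint8),
    "label": np.array([0, 1, 2, 3], dtype=np.int64),
    "id": np.array([b"test_0", b"test_1", b"test_2", b"test_3"], dtype=object),
}

print(non_numeric_keys(test_ds))       # ['id']
print(sorted(keep_numeric(test_ds)))   # ['image', 'label']
```

If `non_numeric_keys(test_ds)` reports `id` (or any other key) on the real data, passing `keep_numeric(test_ds)` to `eval_step` instead of `test_ds` should avoid this particular `TypeError`; the same filtering would apply to the dictionary batches that `train_step` originally received. This is a workaround sketch, not a confirmed fix.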