This repository was archived by the owner on Sep 20, 2024. It is now read-only.

Error while running CNN Python notebook with CIFAR10 Images #3

@AbhiDu96

Description


Colab_The_annotated_MNIST_image_classification_example_with_Flax_Linen_and_Optax.ipynb

In the notebook above, I ran the same code unchanged, but with the CIFAR10 image dataset from tfds. Loading the dataset itself worked fine; the problem appears in train_step, as shown below.

I got the same error as below when train_step took `state` and `batch` (a dictionary) as its arguments, and in that case not even the first epoch completed. I then changed train_step to take `(state, batch_labels, batch_images)`, all as arrays. With that change the first training epoch finished, but then the following error was raised.

```python
Training -epoch: 1, loss: 1.9994, accuracy: 0.29
Traceback (most recent call last):
  File "", line 4, in <cell line: 1>
    test_loss, test_accuracy = eval_model(state.params, test_ds)
  File "", line 2, in eval_model
    metrics = eval_step(model, test_ds)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 208, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 150, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/api.py", line 301, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 460, in common_infer_params
    avals.append(shaped_abstractify(a))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/api_util.py", line 563, in shaped_abstractify
    return _shaped_abstractify_handlers[type(x)](x)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/api_util.py", line 575, in _numpy_array_abstractify
    dtypes.check_valid_dtype(dtype)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 441, in check_valid_dtype
    raise TypeError(f"Dtype {dtype} is not a valid JAX array "
TypeError: Dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
```

Everything else is the same.
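A possible cause, offered as an assumption rather than a confirmed diagnosis: unlike MNIST, the tfds CIFAR10 splits likely carry a string `id` feature alongside `image` and `label`. After `tfds.as_numpy`, a string feature becomes a NumPy array with `dtype=object`, and passing a dict that contains it to a jitted function would raise exactly the `Dtype object is not a valid JAX array type` error above. That would also fit the observation that passing the image and label arrays separately got past the first training epoch, while the whole `test_ds` dict still failed in `eval_step`. The minimal NumPy-only sketch below lists the offending fields and drops them; the helper names `non_numeric_fields` and `numeric_only` and the sample `test_ds` are hypothetical.

```python
import numpy as np

# NumPy dtype kinds that JAX accepts: bool, int, uint, float, complex.
JAX_NUMERIC_KINDS = set("biufc")


def non_numeric_fields(ds):
    """Return {field: dtype} for every field a jitted function would reject."""
    return {k: np.asarray(v).dtype for k, v in ds.items()
            if np.asarray(v).dtype.kind not in JAX_NUMERIC_KINDS}


def numeric_only(ds):
    """Return a copy of `ds` without non-numeric fields (e.g. a string `id`)."""
    bad = non_numeric_fields(ds)
    return {k: v for k, v in ds.items() if k not in bad}


# Hypothetical stand-in for the CIFAR10 test split after tfds.as_numpy.
# Assumption: the real split has a string `id` feature; these ids are made up.
test_ds = {
    "image": np.zeros((4, 32, 32, 3), dtype=np.float32),
    "label": np.arange(4),
    "id": np.array([b"test_0", b"test_1", b"test_2", b"test_3"], dtype=object),
}

print(non_numeric_fields(test_ds))  # {'id': dtype('O')}
clean_ds = numeric_only(test_ds)
print(sorted(clean_ds))             # ['image', 'label']
# clean_ds keeps only numeric arrays, so it could be passed to eval_step.
```

If this assumption holds, applying the same filtering to the training dict before batching should also let the original `train_step(state, batch)` signature work.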
