|
| 1 | +import pygad.cnn |
| 2 | +import pygad.gacnn |
| 3 | +import pygad |
| 4 | +import numpy |
| 5 | + |
| 6 | +def test_gacnn_evolution(): |
| 7 | + """Test pygad.gacnn with pygad.GA.""" |
| 8 | + # Small dummy data |
| 9 | + data_inputs = numpy.random.uniform(0, 1, (4, 10, 10, 3)) |
| 10 | + data_outputs = numpy.array([0, 1, 1, 0]) |
| 11 | + |
| 12 | + input_layer = pygad.cnn.Input2D(input_shape=(10, 10, 3)) |
| 13 | + conv_layer = pygad.cnn.Conv2D(num_filters=2, |
| 14 | + kernel_size=3, |
| 15 | + previous_layer=input_layer, |
| 16 | + activation_function="relu") |
| 17 | + flatten_layer = pygad.cnn.Flatten(previous_layer=conv_layer) |
| 18 | + dense_layer = pygad.cnn.Dense(num_neurons=2, |
| 19 | + previous_layer=flatten_layer, |
| 20 | + activation_function="softmax") |
| 21 | + |
| 22 | + model = pygad.cnn.Model(last_layer=dense_layer, |
| 23 | + epochs=1, |
| 24 | + learning_rate=0.01) |
| 25 | + |
| 26 | + gacnn_instance = pygad.gacnn.GACNN(model=model, |
| 27 | + num_solutions=4) |
| 28 | + |
| 29 | + def fitness_func(ga_instance, solution, sol_idx): |
| 30 | + predictions = gacnn_instance.population_networks[sol_idx].predict(data_inputs=data_inputs) |
| 31 | + correct_predictions = numpy.where(predictions == data_outputs)[0].size |
| 32 | + solution_fitness = (correct_predictions/data_outputs.size)*100 |
| 33 | + return solution_fitness |
| 34 | + |
| 35 | + def callback_generation(ga_instance): |
| 36 | + population_matrices = pygad.gacnn.population_as_matrices(population_networks=gacnn_instance.population_networks, |
| 37 | + population_vectors=ga_instance.population) |
| 38 | + gacnn_instance.update_population_trained_weights(population_trained_weights=population_matrices) |
| 39 | + |
| 40 | + initial_population = pygad.gacnn.population_as_vectors(population_networks=gacnn_instance.population_networks) |
| 41 | + |
| 42 | + ga_instance = pygad.GA(num_generations=2, |
| 43 | + num_parents_mating=2, |
| 44 | + initial_population=initial_population, |
| 45 | + fitness_func=fitness_func, |
| 46 | + on_generation=callback_generation, |
| 47 | + suppress_warnings=True) |
| 48 | + |
| 49 | + ga_instance.run() |
| 50 | + assert ga_instance.run_completed |
| 51 | + assert ga_instance.generations_completed == 2 |
| 52 | + |
| 53 | + print("test_gacnn_evolution passed.") |
| 54 | + |
| 55 | +if __name__ == "__main__": |
| 56 | + test_gacnn_evolution() |
| 57 | + print("\nAll GACNN tests passed!") |
0 commit comments