@@ -218,14 +218,33 @@ def create_invite_tokens(
218218 _check_election_is_not_ended (get_election (db , election_ref ))
219219 now = datetime .now ()
220220 params = {"date_created" : now , "date_modified" : now , "election_ref" : election_ref }
221- db_votes = [models .Vote (** params ) for _ in range (num_voters * num_candidates )]
222- db .bulk_save_objects (db_votes , return_defaults = True )
223- db .commit ()
221+
222+ try :
223+ db_ballots = [models .Ballot (election_ref = election_ref ) for _ in range (num_voters )]
224+ db .bulk_save_objects (db_ballots , return_defaults = True )
225+
226+ db_votes = []
227+
228+ for ballot in db_ballots :
229+ for _ in range (num_candidates ):
230+ db_votes .append (models .Vote (** params , ballot_id = ballot .id ))
231+
232+ db .bulk_save_objects (db_votes , return_defaults = True )
233+ db .commit ()
234+ except Exception as e :
235+ db .rollback ()
236+ raise e
237+
238+ tokens = []
224239 vote_ids = [int (str (v .id )) for v in db_votes ]
225- tokens = [
226- create_ballot_token (vote_ids [i ::num_voters ], election_ref )
227- for i in range (num_voters )
228- ]
240+
241+ for i , ballot in enumerate (db_ballots ):
242+ start = i * num_candidates
243+ end = start + num_candidates
244+ tokens .append (
245+ create_ballot_token (vote_ids [start :end ], election_ref , int (str (ballot .id )))
246+ )
247+
229248 return tokens
230249
231250
@@ -405,18 +424,29 @@ def create_ballot(db: Session, ballot: schemas.BallotCreate) -> schemas.BallotGe
405424 )
406425 _check_ballot_is_consistent (election , ballot )
407426
408- # Ideally, we would use RETURNING but it does not work yet for SQLite
409- db_votes = [
410- models .Vote (** v .model_dump (), election_ref = ballot .election_ref ) for v in ballot .votes
411- ]
412- db .add_all (db_votes )
413- db .commit ()
414- for v in db_votes :
415- db .refresh (v )
427+ try :
428+ db_ballot = models .Ballot (election_ref = ballot .election_ref )
429+ db .add (db_ballot )
430+ db .flush ()
431+
432+ # Create votes and associate them with the ballot
433+ db_votes = [
434+ models .Vote (** v .model_dump (), election_ref = ballot .election_ref , ballot_id = db_ballot .id )
435+ for v in ballot .votes
436+ ]
437+ db .add_all (db_votes )
438+ db .commit ()
439+ db .refresh (db_ballot )
440+
441+ for v in db_votes :
442+ db .refresh (v )
443+ except Exception as e :
444+ db .rollback ()
445+ raise e
416446
417447 votes_get = [schemas .VoteGet .model_validate (v ) for v in db_votes ]
418448 vote_ids = [v .id for v in votes_get ]
419- token = create_ballot_token (vote_ids , ballot .election_ref )
449+ token = create_ballot_token (vote_ids , ballot .election_ref , db_ballot . id )
420450 return schemas .BallotGet (votes = votes_get , token = token , election = election )
421451
422452
@@ -511,6 +541,13 @@ def update_ballot(
511541 if len (db_votes ) != len (vote_ids ):
512542 raise errors .NotFoundError ("votes" )
513543
544+ # Verify all votes belong to the same ballot
545+ ballot_ids = {v .ballot_id for v in db_votes if v .ballot_id is not None }
546+
547+ if len (ballot_ids ) > 1 :
548+ print ("ballots_ids:" , sorted (ballot_ids ))
549+ raise errors .ForbiddenError ("All votes must belong to the same ballot" )
550+
514551 election = schemas .ElectionGet .model_validate (db_votes [0 ].election )
515552
516553 for vote , db_vote in zip (ballot .votes , db_votes ):
@@ -521,7 +558,7 @@ def update_ballot(
521558 db .commit ()
522559
523560 votes_get = [schemas .VoteGet .model_validate (v ) for v in db_votes ]
524- token = create_ballot_token (vote_ids , election_ref )
561+ token = create_ballot_token (vote_ids , election_ref , db_votes [ 0 ]. ballot_id )
525562 return schemas .BallotGet (votes = votes_get , token = token , election = election )
526563
527564
0 commit comments