diff --git a/coldfront/config/core.py b/coldfront/config/core.py index 84fbd517c..48fd9a726 100644 --- a/coldfront/config/core.py +++ b/coldfront/config/core.py @@ -30,6 +30,11 @@ # ------------------------------------------------------------------------------ PROJECT_ENABLE_PROJECT_REVIEW = ENV.bool("PROJECT_ENABLE_PROJECT_REVIEW", default=True) +# ------------------------------------------------------------------------------ +# Maximum number of projects per PI +# ------------------------------------------------------------------------------ +MAX_PROJECTS_PER_PI = ENV.int("MAX_PROJECTS_PER_PI", default=3) + # ------------------------------------------------------------------------------ # Allocation related # ------------------------------------------------------------------------------ diff --git a/coldfront/core/project/tests/tests.py b/coldfront/core/project/tests/tests.py index 91ffac777..2b9fbed25 100644 --- a/coldfront/core/project/tests/tests.py +++ b/coldfront/core/project/tests/tests.py @@ -20,7 +20,9 @@ ProjectAttributeType, ProjectUser, ProjectPermission, + ProjectStatusChoice, ) +from django.urls import reverse logging.disable(logging.CRITICAL) @@ -32,10 +34,16 @@ class Data: def __init__(self): user = UserFactory(username="cgray") user.userprofile.is_pi = True + user.userprofile.save() school = SchoolFactory(description="Tandon School of Engineering") status = ProjectStatusChoiceFactory(name="Active") + # Ensure the status of "New" expects exists + ProjectStatusChoiceFactory(name="New") + ProjectUserRoleChoiceFactory(name="Manager") + ProjectUserStatusChoiceFactory(name="Active") + self.initial_fields = { "pi": user, "title": "Angular momentum in QGP holography", @@ -47,9 +55,52 @@ def __init__(self): self.unsaved_object = Project(**self.initial_fields) + # POST payload for the CreateView (fields = title, description, school) + self.create_post_data = { + "title": "P4 attempt", + "description": "desc", + "school": school.pk, + } + + self.user = user + self.school = school + def setUp(self): self.data = self.Data() + def test_pi_cannot_create_more_than_three_projects(self): + """Test that a PI cannot create more than MAX_PROJECTS_PER_PI projects.""" + user = self.data.initial_fields["pi"] + + # Create 3 existing projects for this PI + for i in range(3): + Project.objects.create( + pi=user, + title=f"Existing {i}", + description="d", + school=self.data.initial_fields["school"], + status=ProjectStatusChoice.objects.get(name="New"), + ) + + self.assertEqual(3, Project.objects.filter(pi=user).count()) + + self.client.force_login(user) + + post_data = { + "title": "P4 attempt", + "description": "desc", + "school": self.data.initial_fields["school"].pk, + } + + resp = self.client.post(reverse("project-create"), data=post_data) + + # You are redirecting on failure (Location should show /project/) + self.assertEqual(302, resp.status_code) + self.assertEqual("/project/", resp["Location"]) + + # Core requirement: still only 3 projects + self.assertEqual(3, Project.objects.filter(pi=user).count()) + def test_fields_generic(self): """Test that generic project fields save correctly""" self.assertEqual(0, len(Project.objects.all())) diff --git a/coldfront/core/project/views.py b/coldfront/core/project/views.py index def15489b..bd5f385a2 100644 --- a/coldfront/core/project/views.py +++ b/coldfront/core/project/views.py @@ -20,7 +20,7 @@ from django.views.generic import CreateView, DetailView, ListView, UpdateView from django.views.generic.base import TemplateView from django.views.generic.edit import FormView - +from django.db import transaction from coldfront.core.allocation.models import ( Allocation, AllocationStatusChoice, @@ -592,10 +592,11 @@ def post(self, request, *args, **kwargs): return redirect(reverse("project-detail", kwargs={"pk": project.pk})) +MAX_PROJECTS_PER_PI = import_from_settings("MAX_PROJECTS_PER_PI") + class ProjectCreateView(LoginRequiredMixin, UserPassesTestMixin, CreateView): model = Project template_name_suffix = "_create_form" - # Add one more field here fields = [ "title", "description", @@ -609,16 +610,40 @@ def test_func(self): if self.request.user.userprofile.is_pi: return True + return False + + def dispatch(self, request, *args, **kwargs): + # Block early for nicer UX (still keep the check in form_valid for safety) + user = request.user + if ( + user.is_authenticated + and not user.is_superuser + and hasattr(user, "userprofile") + and user.userprofile.is_pi + and Project.objects.filter(pi=user).count() >= MAX_PROJECTS_PER_PI + ): + messages.error(request, f"You can only create up to {MAX_PROJECTS_PER_PI} projects.") + return redirect("project-list") # change to wherever you want to send them + return super().dispatch(request, *args, **kwargs) def form_valid(self, form): - project_obj = form.save(commit=False) - form.instance.pi = self.request.user - form.instance.status = ProjectStatusChoice.objects.get(name="New") - project_obj.save() - self.object = project_obj - - project_user_obj = ProjectUser.objects.create( - user=self.request.user, + user = self.request.user + + with transaction.atomic(): + current_count = Project.objects.select_for_update().filter(pi=user).count() + if current_count >= MAX_PROJECTS_PER_PI: + form.add_error(None, f"You can only create up to {MAX_PROJECTS_PER_PI} projects.") + return self.form_invalid(form) + + project_obj = form.save(commit=False) + project_obj.pi = user + project_obj.status = ProjectStatusChoice.objects.get(name="New") + project_obj.save() + self.object = project_obj + + + ProjectUser.objects.create( + user=user, project=project_obj, role=ProjectUserRoleChoice.objects.get(name="Manager"), status=ProjectUserStatusChoice.objects.get(name="Active"),