Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions coldfront/config/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@
# ------------------------------------------------------------------------------
PROJECT_ENABLE_PROJECT_REVIEW = ENV.bool("PROJECT_ENABLE_PROJECT_REVIEW", default=True)

# ------------------------------------------------------------------------------
# Maximum number of projects per PI
# ------------------------------------------------------------------------------
MAX_PROJECTS_PER_PI = ENV.int("MAX_PROJECTS_PER_PI", default=3)

# ------------------------------------------------------------------------------
# Allocation related
# ------------------------------------------------------------------------------
Expand Down
51 changes: 51 additions & 0 deletions coldfront/core/project/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
ProjectAttributeType,
ProjectUser,
ProjectPermission,
ProjectStatusChoice,
)
from django.urls import reverse

logging.disable(logging.CRITICAL)

Expand All @@ -32,10 +34,16 @@ class Data:
def __init__(self):
user = UserFactory(username="cgray")
user.userprofile.is_pi = True
user.userprofile.save()

school = SchoolFactory(description="Tandon School of Engineering")
status = ProjectStatusChoiceFactory(name="Active")

# Ensure the status of "New" expects exists
ProjectStatusChoiceFactory(name="New")
ProjectUserRoleChoiceFactory(name="Manager")
ProjectUserStatusChoiceFactory(name="Active")

self.initial_fields = {
"pi": user,
"title": "Angular momentum in QGP holography",
Expand All @@ -47,9 +55,52 @@ def __init__(self):

self.unsaved_object = Project(**self.initial_fields)

# POST payload for the CreateView (fields = title, description, school)
self.create_post_data = {
"title": "P4 attempt",
"description": "desc",
"school": school.pk,
}

self.user = user
self.school = school

def setUp(self):
self.data = self.Data()

def test_pi_cannot_create_more_than_three_projects(self):
"""Test that a PI cannot create more than MAX_PROJECTS_PER_PI projects."""
user = self.data.initial_fields["pi"]

# Create 3 existing projects for this PI
for i in range(3):
Project.objects.create(
pi=user,
title=f"Existing {i}",
description="d",
school=self.data.initial_fields["school"],
status=ProjectStatusChoice.objects.get(name="New"),
)

self.assertEqual(3, Project.objects.filter(pi=user).count())

self.client.force_login(user)

post_data = {
"title": "P4 attempt",
"description": "desc",
"school": self.data.initial_fields["school"].pk,
}

resp = self.client.post(reverse("project-create"), data=post_data)

# You are redirecting on failure (Location should show /project/)
self.assertEqual(302, resp.status_code)
self.assertEqual("/project/", resp["Location"])

# Core requirement: still only 3 projects
self.assertEqual(3, Project.objects.filter(pi=user).count())

def test_fields_generic(self):
"""Test that generic project fields save correctly"""
self.assertEqual(0, len(Project.objects.all()))
Expand Down
45 changes: 35 additions & 10 deletions coldfront/core/project/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from django.views.generic import CreateView, DetailView, ListView, UpdateView
from django.views.generic.base import TemplateView
from django.views.generic.edit import FormView

from django.db import transaction
from coldfront.core.allocation.models import (
Allocation,
AllocationStatusChoice,
Expand Down Expand Up @@ -592,10 +592,11 @@ def post(self, request, *args, **kwargs):
return redirect(reverse("project-detail", kwargs={"pk": project.pk}))


MAX_PROJECTS_PER_PI = import_from_settings("MAX_PROJECTS_PER_PI")

class ProjectCreateView(LoginRequiredMixin, UserPassesTestMixin, CreateView):
model = Project
template_name_suffix = "_create_form"
# Add one more field here
fields = [
"title",
"description",
Expand All @@ -609,16 +610,40 @@ def test_func(self):

if self.request.user.userprofile.is_pi:
return True
return False

def dispatch(self, request, *args, **kwargs):
# Block early for nicer UX (still keep the check in form_valid for safety)
user = request.user
if (
user.is_authenticated
and not user.is_superuser
and hasattr(user, "userprofile")
and user.userprofile.is_pi
and Project.objects.filter(pi=user).count() >= MAX_PROJECTS_PER_PI
):
messages.error(request, f"You can only create up to {MAX_PROJECTS_PER_PI} projects.")
return redirect("project-list") # change to wherever you want to send them
return super().dispatch(request, *args, **kwargs)

def form_valid(self, form):
project_obj = form.save(commit=False)
form.instance.pi = self.request.user
form.instance.status = ProjectStatusChoice.objects.get(name="New")
project_obj.save()
self.object = project_obj

project_user_obj = ProjectUser.objects.create(
user=self.request.user,
user = self.request.user

with transaction.atomic():
current_count = Project.objects.select_for_update().filter(pi=user).count()
if current_count >= MAX_PROJECTS_PER_PI:
form.add_error(None, f"You can only create up to {MAX_PROJECTS_PER_PI} projects.")
return self.form_invalid(form)

project_obj = form.save(commit=False)
project_obj.pi = user
project_obj.status = ProjectStatusChoice.objects.get(name="New")
project_obj.save()
self.object = project_obj


ProjectUser.objects.create(
user=user,
project=project_obj,
role=ProjectUserRoleChoice.objects.get(name="Manager"),
status=ProjectUserStatusChoice.objects.get(name="Active"),
Expand Down
Loading