Skip to content

Commit 8d5fcfd

Browse files
committed
Update tests to use correct auth
1 parent ac88392 commit 8d5fcfd

4 files changed

Lines changed: 54 additions & 10 deletions

File tree

bats_ai/core/tests/conftest.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from django.contrib.auth.models import User
22
from django.test import Client
3+
from oauth2_provider.models import AccessToken
34
import pytest
45

56
from bats_ai.core.models import VettingDetails
67

7-
from .factories import SuperuserFactory, UserFactory, VettingDetailsFactory
8+
from .factories import AccessTokenFactory, SuperuserFactory, UserFactory, VettingDetailsFactory
89

910

1011
@pytest.fixture
@@ -18,21 +19,31 @@ def user() -> User:
1819

1920

2021
@pytest.fixture
21-
def superuser() -> User:
22-
return SuperuserFactory()
22+
def user_token(user) -> AccessToken:
23+
return AccessTokenFactory(user=user)
2324

2425

2526
@pytest.fixture
26-
def authenticated_client(user: User) -> Client:
27+
def authenticated_client(user: User, user_token: AccessToken) -> Client:
2728
client = Client()
28-
client.force_login(user=user)
29+
client.defaults['HTTP_AUTHORIZATION'] = f'Bearer {user_token.token}'
2930
return client
3031

3132

3233
@pytest.fixture
33-
def authorized_client(superuser: User) -> Client:
34+
def superuser() -> User:
35+
return SuperuserFactory()
36+
37+
38+
@pytest.fixture
39+
def superuser_token(superuser) -> AccessToken:
40+
return AccessTokenFactory(user=superuser)
41+
42+
43+
@pytest.fixture
44+
def authorized_client(superuser: User, superuser_token: AccessToken) -> Client:
3445
client = Client()
35-
client.force_login(user=superuser)
46+
client.defaults['HTTP_AUTHORIZATION'] = f'Bearer {superuser_token.token}'
3647
return client
3748

3849

bats_ai/core/tests/factories.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
from datetime import timedelta
2+
13
from django.contrib.auth.models import User
4+
from django.utils import timezone
25
import factory.django
6+
from oauth2_provider.models import AccessToken
37

48
from bats_ai.core.models import VettingDetails
59

@@ -39,3 +43,13 @@ class Meta:
3943

4044
user = factory.SubFactory(UserFactory)
4145
reference_materials = factory.Faker('paragraph', nb_sentences=3)
46+
47+
48+
class AccessTokenFactory(factory.django.DjangoModelFactory[AccessToken]):
49+
class Meta:
50+
model = AccessToken
51+
52+
user = factory.SubFactory(UserFactory)
53+
token = factory.Faker('uuid4')
54+
scope = 'read write'
55+
expires = factory.LazyFunction(lambda: timezone.now() + timedelta(hours=1))

bats_ai/core/tests/test_admin.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,20 @@ def test_is_admin(client_fixture, status_code, is_admin, request):
1616
assert resp.status_code == status_code
1717
if is_admin is not None:
1818
assert resp.json()['is_admin'] == is_admin
19+
20+
21+
@pytest.mark.parametrize(
22+
'client_fixture,status_code',
23+
[
24+
('authenticated_client', 200),
25+
('client', 401),
26+
],
27+
)
28+
@pytest.mark.django_db
29+
def test_get_current_user(client_fixture, status_code, user, request):
30+
api_client = request.getfixturevalue(client_fixture)
31+
resp = api_client.get('/api/v1/configuration/me')
32+
assert resp.status_code == status_code
33+
if status_code == 200:
34+
assert resp.json()['name'] == user.username
35+
assert resp.json()['email'] == user.email

bats_ai/core/tests/test_vetting_details.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from .factories import UserFactory, VettingDetailsFactory
3+
from .factories import AccessTokenFactory, UserFactory, VettingDetailsFactory
44

55

66
@pytest.mark.parametrize(
@@ -33,7 +33,8 @@ def test_create_vetting_details(client):
3333
test_text = 'foo'
3434
data = {'reference_materials': test_text}
3535
test_user = UserFactory()
36-
client.force_login(user=test_user)
36+
test_token = AccessTokenFactory(user=test_user)
37+
client.defaults['HTTP_AUTHORIZATION'] = f'Bearer {test_token.token}'
3738
resp = client.post(
3839
f'/api/v1/vetting/user/{test_user.id}', data=data, content_type='application/json'
3940
)
@@ -67,8 +68,9 @@ def test_update_vetting_details(client):
6768
test_text = 'bar'
6869
data = {'reference_materials': 'bar'}
6970
test_user = UserFactory()
71+
test_token = AccessTokenFactory(user=test_user)
7072
VettingDetailsFactory(user=test_user, reference_materials='foo')
71-
client.force_login(test_user)
73+
client.defaults['HTTP_AUTHORIZATION'] = f'Bearer {test_token.token}'
7274

7375
initial_resp = client.get(f'/api/v1/vetting/user/{test_user.id}')
7476
assert initial_resp.status_code == 200

0 commit comments

Comments
 (0)