"""Tests for the custom user model and the account/vendor API endpoints."""

from django.contrib.auth import get_user_model
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APITestCase

from .models import VendorProfile

User = get_user_model()

# Single password shared by every user these tests create or register.
TEST_PASSWORD = "Pass123456!"


class AccountModelTests(APITestCase):
    """Model-level tests for the custom user model."""

    def test_custom_user_creation_and_roles(self):
        """create_user stores the email, a verifiable password and the role flags."""
        user = User.objects.create_user(
            email="customer@example.com",
            password=TEST_PASSWORD,
            is_customer=True,
            is_vendor=False,
        )
        self.assertEqual(user.email, "customer@example.com")
        self.assertTrue(user.check_password(TEST_PASSWORD))
        self.assertTrue(user.is_customer)
        self.assertFalse(user.is_vendor)


class AccountApiTests(APITestCase):
    """API tests for vendor registration, token login and role-gated endpoints."""

    def _authenticate(self, email, password=TEST_PASSWORD):
        """Log in via ``token_obtain_pair`` and attach the access token to the client.

        The token endpoint's status is asserted before ``data['access']`` is
        read, so a failed login is reported as a status mismatch rather than a
        ``KeyError``.

        Args:
            email: Email of the user to log in as.
            password: Password for that user.

        Returns:
            The token endpoint response, so callers can inspect its payload.
        """
        token_res = self.client.post(
            reverse("token_obtain_pair"),
            {"email": email, "password": password},
            format="json",
        )
        self.assertEqual(token_res.status_code, status.HTTP_200_OK)
        self.assertIn("access", token_res.data)
        self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {token_res.data['access']}")
        return token_res

    def test_vendor_registration_and_jwt_flow(self):
        """Vendor registration creates user + profile; the issued token reaches ``me``."""
        register_payload = {
            "email": "vendor@example.com",
            "password": TEST_PASSWORD,
            "business_name": "Ocean Rentals",
        }
        register_res = self.client.post(reverse("register_vendor"), register_payload, format="json")
        self.assertEqual(register_res.status_code, status.HTTP_201_CREATED)
        self.assertTrue(User.objects.filter(email="vendor@example.com", is_vendor=True).exists())
        self.assertTrue(VendorProfile.objects.filter(user__email="vendor@example.com").exists())

        token_res = self._authenticate("vendor@example.com")
        self.assertIn("refresh", token_res.data)

        me_res = self.client.get(reverse("me"))
        self.assertEqual(me_res.status_code, status.HTTP_200_OK)
        self.assertEqual(me_res.data["email"], "vendor@example.com")

    def test_vendor_profile_endpoint_forbidden_for_customer(self):
        """An authenticated customer (non-vendor) gets 403 from ``vendor_profile_me``."""
        customer = User.objects.create_user(
            email="customer@example.com",
            password=TEST_PASSWORD,
            is_customer=True,
            is_vendor=False,
        )
        # Login must succeed first, so the 403 below reflects the role check
        # rather than a broken authentication setup.
        self._authenticate(customer.email)

        res = self.client.get(reverse("vendor_profile_me"))
        self.assertEqual(res.status_code, status.HTTP_403_FORBIDDEN)