Content description:
In this post I'll describe how to test a ModelViewSet built on a Contract model.
I'll test the GET, POST, PATCH and DELETE methods.
I'll also check that the validators and permissions work correctly.
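I start by creating a Contract model.
The model consists of the following fields:
- contract status (default: open),
- contract number set by the client,
- client who placed the order,
- contractor who is to perform the order,
- date of the order,
- planned delivery date and time,
- number of planned pallets (from 1 to 30),
- warehouse to which the goods are to be delivered.
The delivery date has a validator: it cannot be earlier than the order date (in practice, it is compared with the current date at the moment of validation).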
from django.core import validators
from django.core.exceptions import ValidationError
from django.db import models
from django.utils import timezone

from users.models import ContractUser

from .warehouse import Warehouse
class Contract(models.Model):
    """A contract between a client and a contractor for a warehouse delivery.

    Contracts are ordered newest-first by ``date_of_order``.
    """

    class Meta:
        ordering = ("-date_of_order",)

    class DeliveryDateValidator:
        """Holder for the ``date_of_delivery`` field validator."""

        @staticmethod
        def validate(value):
            """Reject a delivery date that lies before today.

            ``value`` is compared with the date part of ``timezone.now()`` at
            validation time. For a newly created contract that day is also the
            order date (``date_of_order`` defaults to ``timezone.now``). If the
            field is validated again later (e.g. on an update), it is compared
            with that later day, not with the stored ``date_of_order``.

            NOTE(review): with ``USE_TZ = True`` ``timezone.now()`` is UTC, so
            this uses the UTC calendar date, which can differ from the local
            date around midnight — confirm this is intended
            (``timezone.localdate()`` would use the active time zone).

            Raises:
                ValidationError: if ``value`` is earlier than today.

            Returns:
                ``value`` unchanged.
            """
            if value < timezone.now().date():
                raise ValidationError(
                    "Date of delivery can't be earlier than date of order"
                )
            return value

    # Starts as "open"; max_length=9 is exactly long enough for "cancelled".
    status = models.CharField(max_length=9, default="open")
    contract_number = models.CharField(max_length=15, unique=True)
    client = models.ForeignKey(ContractUser, on_delete=models.CASCADE)
    # related_name="+" disables the reverse accessor so it does not clash with
    # the default reverse accessor of ``client`` (both point at ContractUser).
    contractor = models.ForeignKey(
        ContractUser, on_delete=models.CASCADE, related_name="+"
    )
    date_of_order = models.DateTimeField(default=timezone.now)
    date_of_delivery = models.DateField(
        validators=(DeliveryDateValidator.validate,)
    )
    time_of_delivery = models.TimeField()
    # Between 1 and 30 pallets per contract.
    pallets_planned = models.IntegerField(
        validators=(
            validators.MinValueValidator(1),
            validators.MaxValueValidator(30, "Max number of pallets is 30"),
        )
    )
    warehouse = models.ForeignKey(Warehouse, on_delete=models.CASCADE)

    def __str__(self):
        return self.contract_number
The model is serialized with a ModelSerializer, to which I add extra information about the warehouse (e.g. the warehouse address).
from rest_framework import serializers
from contracts.models import Contract
class ContractSerializer(serializers.ModelSerializer):
    """Serializer for ``Contract`` used by the contract API.

    Compared with the plain model fields:
    - ``client`` is rendered with ``str()`` and is read-only; on create it is
      set from the serializer context (see ``create``);
    - ``contractor`` and ``warehouse`` are accepted as primary keys but are
      rendered as the contractor's username and the warehouse name;
    - ``warehouse_info`` is added from the related warehouse;
    - ``status`` and ``date_of_order`` are read-only; ``date_of_order`` is
      rendered as ``YYYY-MM-DD`` and ``time_of_delivery`` as ``HH:MM``.
    """

    class Meta:
        model = Contract
        fields = [
            "id",
            "status",
            "contract_number",
            "client",
            "contractor",
            "date_of_order",
            "date_of_delivery",
            "time_of_delivery",
            "pallets_planned",
            "warehouse",
            "warehouse_info",
        ]
        extra_kwargs = {
            "status": {"read_only": True},
            "date_of_order": {
                "read_only": True,
                "format": "%Y-%m-%d",
            },
            "time_of_delivery": {"format": "%H:%M"},
        }

    client = serializers.StringRelatedField()
    warehouse_info = serializers.SerializerMethodField()

    def get_warehouse_info(self, obj):
        """Return the ``warehouse_info`` of the contract's warehouse."""
        return obj.warehouse.warehouse_info

    def to_representation(self, instance):
        """Replace the contractor and warehouse primary keys with names."""
        contract = super().to_representation(instance)
        contract["contractor"] = instance.contractor.username
        contract["warehouse"] = instance.warehouse.warehouse_name
        return contract

    def create(self, validated_data):
        """Create a contract whose client is the requesting user.

        ``client`` is read-only, so it is never part of ``validated_data``.
        It is taken from ``context["client"]`` when the view provides it.
        Otherwise it falls back to the authenticated ``request.user``,
        because DRF's default serializer context contains ``request`` but
        not ``client``.
        """
        client = self.context.get("client")
        if client is None:
            client = self.context["request"].user
        validated_data["client"] = client
        return Contract.objects.create(**validated_data)
In the view, which inherits from ModelViewSet, I override the destroy() method: a contract should not actually be deleted; instead, its status field is changed to "cancelled". Only an open contract can be cancelled — for any other status the view returns 400 Bad Request.
from rest_framework import viewsets, permissions, status
from rest_framework.response import Response
from contracts.models import Contract
from ..serializers import ContractSerializer
from ..permissions.contract import ContractWritePermission
class ContractViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for contracts; DELETE cancels instead of deleting.

    Requests must pass ``IsAuthenticated`` and ``ContractWritePermission``.
    The latter is defined elsewhere and its rules are not visible here.
    """

    # GenericAPIView requires a queryset (or a get_queryset() override). When
    # the router is registered without an explicit basename, it also derives
    # the "contract" basename (contract-list / contract-detail) from it.
    queryset = Contract.objects.all()
    serializer_class = ContractSerializer
    permission_classes = [
        permissions.IsAuthenticated,
        ContractWritePermission,
    ]

    def get_serializer_context(self):
        """Expose the requesting user as ``client`` for ContractSerializer.create."""
        context = super().get_serializer_context()
        context["client"] = self.request.user
        return context

    def destroy(self, request, *args, **kwargs):
        """Cancel an open contract instead of deleting it.

        Returns 200 with a ``msg`` when an ``"open"`` contract is cancelled.
        For any other status, returns 400 with an ``error`` naming the current
        status.
        """
        instance = self.get_object()
        if instance.status != "open":
            error_msg = f"Bad request. Contract status: '{instance.status}'"
            return Response({"error": error_msg}, status=status.HTTP_400_BAD_REQUEST)

        # ModelViewSet's default would be instance.delete() plus a 204 response;
        # here the row is kept and only its status changes.
        instance.status = "cancelled"
        # Only the status column changed, so write just that column.
        instance.save(update_fields=["status"])
        # Message text (including its spelling) is compared verbatim by the tests.
        return Response({"msg": "Contract succesfully cancelled"}, status=status.HTTP_200_OK)
I test all HTTP methods, as well as the permissions and validators.
Testing the GET request:
from rest_framework.test import APITestCase
from rest_framework import status
from django.urls import reverse
from django.utils import timezone
from datetime import timedelta, datetime
from users.models import ContractUser
from contracts.models import Warehouse, Contract
class ContractAPITestCase(APITestCase):
    """API tests for the contract list/detail endpoints."""

    def _create_user(self, name, **extra):
        """Create a ContractUser called *name* with the shared fixture password/email.

        NOTE(review): ``objects.create`` stores the password "123" unhashed.
        That is harmless here because requests use ``force_authenticate``.
        Switch to the manager's ``create_user`` (if ContractUser provides one)
        should a test ever log in with the password.
        """
        return ContractUser.objects.create(
            username=name,
            password="123",
            email=f"{name}@company.com",
            **extra,
        )

    def setUp(self):
        # One client and two users with the "contractor" profile.
        self.client1 = self._create_user("client1")
        self.contractor1 = self._create_user("contractor1", profile="contractor")
        self.contractor2 = self._create_user("contractor2", profile="contractor")

        self.warehouse = Warehouse.objects.create(
            client=self.client1,
            warehouse_name="Warehouse",
            warehouse_info="Warehouse info",
        )

        tomorrow = timezone.now() + timedelta(days=1)
        self.contract = Contract.objects.create(
            contract_number="A01",
            client=self.client1,
            contractor=self.contractor1,
            date_of_delivery=tomorrow,
            time_of_delivery="12:00",
            pallets_planned=1,
            warehouse=self.warehouse,
        )

        # Requests are made as client1 unless a test re-authenticates.
        self.client.force_authenticate(user=self.client1)

    def test_get(self):
        """An authenticated user can list contracts."""
        response = self.client.get(reverse("contract-list"), format="json")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
Testing that an unauthenticated user is rejected:
def test_is_authenticated(self):
    """An unauthenticated request to the list endpoint is rejected with 403."""
    # Clear the credentials installed in setUp.
    self.client.force_authenticate(user=None)
    response = self.client.get(reverse("contract-list"), format="json")
    self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
    self.assertEqual(
        response.data["detail"],
        "Authentication credentials were not provided.",
    )
Testing the POST request:
def test_create(self):
    """A client can create a contract; it is stored with that client as owner."""
    endpoint = reverse("contract-list")
    data = {
        "contract_number": "A02",
        "contractor": self.contractor1.id,
        "date_of_delivery": (timezone.now() + timedelta(days=1)).strftime("%Y-%m-%d"),
        "time_of_delivery": "12:00",
        "pallets_planned": 1,
        "warehouse": self.warehouse.id,
    }
    response = self.client.post(endpoint, data, format="json")
    self.assertEqual(response.status_code, status.HTTP_201_CREATED)

    # A 201 alone does not prove persistence. The payload carries no client,
    # so the stored client must be the authenticated user (client1), and the
    # status must be the model default.
    created = Contract.objects.get(contract_number="A02")
    self.assertEqual(created.client, self.client1)
    self.assertEqual(created.status, "open")
Testing that only a user with a client profile can create a contract (here a contractor tries and is refused):
def test_user_profile_is_client(self):
    """A user with the contractor profile is not allowed to create a contract."""
    self.client.force_authenticate(user=self.contractor1)
    endpoint = reverse("contract-list")
    data = {
        "contract_number": "A02",
        "contractor": self.contractor2.id,
        "date_of_delivery": (timezone.now() + timedelta(days=1)).strftime("%Y-%m-%d"),
        "time_of_delivery": "12:00",
        "pallets_planned": 1,
        "warehouse": self.warehouse.id,
    }
    response = self.client.post(endpoint, data, format="json")
    self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
    self.assertEqual(response.data["detail"], "No sufficient permissions")
    # A refused request must not leave a contract behind.
    self.assertFalse(Contract.objects.filter(contract_number="A02").exists())
Testing that the delivery date cannot be earlier than the order date:
def test_date_of_delivery_validator(self):
    """A delivery date in the past is rejected with a field-level error."""
    yesterday = (timezone.now() - timedelta(days=1)).strftime("%Y-%m-%d")
    payload = dict(
        contract_number="A02",
        contractor=self.contractor1.id,
        date_of_delivery=yesterday,
        time_of_delivery="12:00",
        pallets_planned=1,
        warehouse=self.warehouse.id,
    )
    response = self.client.post(reverse("contract-list"), payload, format="json")

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    first_error = response.data["date_of_delivery"][0]
    self.assertEqual(first_error, "Date of delivery can't be earlier than date of order")
Testing retrieval of a specific contract:
def test_retrieve(self):
    """The detail endpoint returns the requested contract."""
    expected_number = self.contract.contract_number
    url = reverse("contract-detail", args=(self.contract.id,))
    response = self.client.get(url, format="json")
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertEqual(response.data["contract_number"], expected_number)
Testing a partial update (PATCH) of a specific contract:
def test_update(self):
    """A PATCH of contract_number is returned and persisted."""
    data = {
        "contract_number": "A02",
    }
    endpoint = reverse("contract-detail", args=[self.contract.id])
    response = self.client.patch(endpoint, data, format="json")
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertEqual(response.data["contract_number"], "A02")
    # The response reflects the serializer; also confirm the database row changed.
    self.contract.refresh_from_db()
    self.assertEqual(self.contract.contract_number, "A02")
Testing contract cancellation:
def test_delete(self):
    """DELETE cancels an open contract instead of removing it."""
    endpoint = reverse("contract-detail", args=[self.contract.id])
    response = self.client.delete(endpoint)
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertEqual(response.data["msg"], "Contract succesfully cancelled")
    # The row must still exist (refresh_from_db raises DoesNotExist otherwise),
    # and only its status has changed.
    self.contract.refresh_from_db()
    self.assertEqual(self.contract.status, "cancelled")