From b4b4fa764380d50fc894e3e4f1ef8715cf56ad69 Mon Sep 17 00:00:00 2001 From: JulienAldon Date: Fri, 6 Mar 2026 00:00:01 +0100 Subject: [PATCH] fix all pylint warnings, add tests (wip) fix recap --- backend/src/auth/__init__.py | 0 backend/src/auth/auth.py | 65 ++++-- backend/src/contracts/contracts.py | 3 +- backend/src/contracts/generate_contract.py | 121 ++++++++--- backend/src/forms/forms.py | 5 +- backend/src/forms/service.py | 11 +- backend/src/productors/productors.py | 26 ++- backend/src/productors/service.py | 22 +- backend/src/products/products.py | 68 ++++-- backend/src/products/service.py | 43 ++-- backend/src/shipments/service.py | 40 +++- backend/src/shipments/shipments.py | 24 ++- backend/src/templates/templates.py | 3 +- backend/src/users/service.py | 12 +- backend/src/users/users.py | 38 +++- backend/tests/conftest.py | 2 - backend/tests/factories/__init__.py | 3 + backend/tests/routers/__init__.py | 3 + backend/tests/routers/test_contracts.py | 75 ++++--- backend/tests/routers/test_forms.py | 151 ++++++++----- backend/tests/routers/test_productors.py | 202 ++++++++++++------ backend/tests/routers/test_products.py | 131 +++++++----- backend/tests/routers/test_shipments.py | 167 +++++++++------ backend/tests/services/__init__.py | 3 + .../src/components/Contracts/Modal/index.tsx | 3 + 25 files changed, 845 insertions(+), 376 deletions(-) create mode 100644 backend/src/auth/__init__.py create mode 100644 backend/tests/factories/__init__.py create mode 100644 backend/tests/routers/__init__.py create mode 100644 backend/tests/services/__init__.py diff --git a/backend/src/auth/__init__.py b/backend/src/auth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/src/auth/auth.py b/backend/src/auth/auth.py index 3015828..6c81279 100644 --- a/backend/src/auth/auth.py +++ b/backend/src/auth/auth.py @@ -4,14 +4,13 @@ from urllib.parse import urlencode import jwt import requests -import src.messages as messages import src.users.service as service -from fastapi import (APIRouter, Cookie, Depends, HTTPException, Request, - Security) +from fastapi import APIRouter, Cookie, Depends, HTTPException, Request from fastapi.responses import RedirectResponse, Response -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastapi.security import HTTPBearer from jwt import PyJWKClient from sqlmodel import Session, select +from src import messages from src.database import get_session from src.models import User, UserCreate, UserPublic from src.settings import (AUTH_URL, ISSUER, JWKS_URL, LOGOUT_URL, TOKEN_URL, @@ -78,7 +77,18 @@ def callback(code: str, session: Session = Depends(get_session)): headers = { 'Content-Type': 'application/x-www-form-urlencoded' } - response = requests.post(TOKEN_URL, data=data, headers=headers) + try: + response = requests.post( + TOKEN_URL, + data=data, + headers=headers, + timeout=10 + ) + except requests.exceptions.Timeout as error: + raise HTTPException( + status_code=404, + detail=messages.Messages.not_found('token') + ) from error if response.status_code != 200: raise HTTPException( status_code=404, @@ -99,7 +109,13 @@ def callback(code: str, session: Session = Depends(get_session)): 'client_secret': settings.keycloak_client_secret, 'refresh_token': token_data['refresh_token'], } - requests.post(LOGOUT_URL, data=data) + try: + requests.post(LOGOUT_URL, data=data, timeout=10) + except requests.exceptions.Timeout as error: + raise HTTPException( + status_code=404, + detail=messages.Messages.not_found('token') + ) from error resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true') return resp roles = resource_access.get(settings.keycloak_client_id) @@ -109,7 +125,13 @@ def callback(code: str, session: Session = Depends(get_session)): 'client_secret': settings.keycloak_client_secret, 'refresh_token': token_data['refresh_token'], } - requests.post(LOGOUT_URL, data=data) + try: + requests.post(LOGOUT_URL, data=data, timeout=10) + except requests.exceptions.Timeout as error: + raise HTTPException( + status_code=404, + detail=messages.Messages.not_found('token') + ) from error resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true') return resp @@ -160,16 +182,16 @@ def verify_token(token: str): leeway=60, ) return decoded - except jwt.ExpiredSignatureError: + except jwt.ExpiredSignatureError as error: raise HTTPException( status_code=401, detail=messages.Messages.tokenexipired - ) - except jwt.InvalidTokenError: + ) from error + except jwt.InvalidTokenError as error: raise HTTPException( status_code=401, detail=messages.Messages.invalidtoken - ) + ) from error def get_current_user( @@ -184,7 +206,7 @@ def get_current_user( payload = verify_token(access_token) if not payload: raise HTTPException( - status_code=401, + status_code=401, detail='aze' ) email = payload.get('email') @@ -205,7 +227,7 @@ def get_current_user( @router.post('/refresh') -def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None): +def refresh_user_token(refresh_token: Annotated[str | None, Cookie()] = None): refresh = refresh_token data = { 'grant_type': 'refresh_token', @@ -216,7 +238,18 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None): headers = { 'Content-Type': 'application/x-www-form-urlencoded' } - result = requests.post(TOKEN_URL, data=data, headers=headers) + try: + result = requests.post( + TOKEN_URL, + data=data, + headers=headers, + timeout=10, + ) + except requests.exceptions.Timeout as error: + raise HTTPException( + status_code=404, + detail=messages.Messages.not_found('token') + ) from error if result.status_code != 200: raise HTTPException( status_code=404, @@ -229,7 +262,7 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None): key='access_token', value=token_data['access_token'], httponly=True, - secure=True if settings.debug == False else True, + secure=True if settings.debug is False else True, samesite='strict', max_age=settings.max_age ) @@ -237,7 +270,7 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None): key='refresh_token', value=token_data['refresh_token'] or '', httponly=True, - secure=True if settings.debug == False else True, + secure=True if settings.debug is False else True, samesite='strict', max_age=30 * 24 * settings.max_age ) diff --git a/backend/src/contracts/contracts.py b/backend/src/contracts/contracts.py index 12f029c..d30911d 100644 --- a/backend/src/contracts/contracts.py +++ b/backend/src/contracts/contracts.py @@ -4,11 +4,10 @@ import zipfile import src.contracts.service as service import src.forms.service as form_service -import src.messages as messages from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.responses import StreamingResponse from sqlmodel import Session -from src import models +from src import messages, models from src.auth.auth import get_current_user from src.contracts.generate_contract import (generate_html_contract, generate_recap) diff --git a/backend/src/contracts/generate_contract.py b/backend/src/contracts/generate_contract.py index 2dbf7cb..af5f6ec 100644 --- a/backend/src/contracts/generate_contract.py +++ b/backend/src/contracts/generate_contract.py @@ -18,11 +18,24 @@ def generate_html_contract( reccurents: list[dict], recurrent_price: float | None = None, total_price: float | None = None -): +) -> bytes: + """Generate a html contract + Arguments: + contract(models.Contract): Contract source. + cheques(list[dict]): cheques formated in dict. + occasionals(list[dict]): occasional products. + reccurents(list[dict]): recurrent products. + recurrent_price(float | None = None): total price of recurent products. + total_price(float | None = Non): total price. + Return: + result(bytes): contract file in pdf as bytes. + """ template_dir = pathlib.Path("./src/contracts/templates").resolve() template_loader = jinja2.FileSystemLoader(searchpath=template_dir) template_env = jinja2.Environment( - loader=template_loader, autoescape=jinja2.select_autoescape(["html", "xml"])) + loader=template_loader, + autoescape=jinja2.select_autoescape(["html", "xml"]) + ) template_file = "layout.html" template = template_env.get_template(template_file) output_text = template.render( @@ -65,13 +78,16 @@ def generate_html_contract( def flatten(xss): + """flatten a list of list. + """ return [x for xs in xss for x in xs] def create_column_style_width(size: str) -> odfdo.Style: """Create a table columm style for a given width. Paramenters: - size(str): size of the style (format ) unit can be in, cm... see odfdo documentation. + size(str): size of the style (format ) + unit can be in, cm... see odfdo documentation. Returns: odfdo.Style with the correct column-width attribute. """ @@ -85,7 +101,8 @@ def create_column_style_width(size: str) -> odfdo.Style: def create_row_style_height(size: str) -> odfdo.Style: """Create a table height style for a given height. Paramenters: - size(str): size of the style (format ) unit can be in, cm... see odfdo documentation. + size(str): size of the style (format ) + unit can be in, cm... see odfdo documentation. Returns: odfdo.Style with the correct column-height attribute. """ @@ -97,10 +114,17 @@ def create_row_style_height(size: str) -> odfdo.Style: def create_currency_style(name:str = 'currency-euro'): + """Create a table currency style. + Paramenters: + name(str): name of the style (default to `currency-euro`). + Returns: + odfdo.Style with the correct column-height attribute. + """ return odfdo.Element.from_tag( f""" - + """ ) @@ -113,6 +137,18 @@ def create_cell_style( color: str = '#000000', currency: bool = False, ) -> odfdo.Style: + """Create a cell style + Paramenters: + name(str): name of the style (default to `centered-cell`). + font_size(str): font_size of the cell (default to `10pt`). + bold(str): is the text bold (default to `False`). + background_color(str): background_color of the cell + (default to `#FFFFFF`). + color(str): color of the text of the cell (default to `#000000`). + currency(str): is the cell a currency (default to `False`). + Returns: + odfdo.Style with the correct column-height attribute. + """ bold_attr = """ fo:font-weight="bold" style:font-weight-asian="bold" @@ -138,7 +174,13 @@ def create_cell_style( ) -def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols: list[int]): +def apply_cell_style( + document: odfdo.Document, + table: odfdo.Table, + currency_cols: list[int] +): + """Apply cell style + """ document.insert_style( style=create_currency_style(), ) @@ -151,7 +193,6 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols color="#FFF" ) ) - body_style_even = document.insert_style( create_cell_style( name="body-style-even", @@ -160,7 +201,6 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols color="#000000", ) ) - body_style_odd = document.insert_style( create_cell_style( name="body-style-odd", @@ -169,7 +209,6 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols color="#000000", ) ) - footer_style = document.insert_style( create_cell_style( name="footer-cells", @@ -177,7 +216,6 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols font_size='12pt', ) ) - body_style_even_currency = document.insert_style( create_cell_style( name="body-style-even-currency", @@ -187,7 +225,6 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols currency=True, ) ) - body_style_odd_currency = document.insert_style( create_cell_style( name="body-style-odd-currency", @@ -197,7 +234,6 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols currency=True, ) ) - footer_style_currency = document.insert_style( create_cell_style( name="footer-cells-currency", @@ -206,7 +242,6 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols currency=True, ) ) - for index, row in enumerate(table.get_rows()): style = body_style_even currency_style = body_style_even_currency @@ -228,7 +263,12 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols cell.style = style -def apply_column_height_style(document: odfdo.Document, table: odfdo.Table): +def apply_column_height_style( + document: odfdo.Document, + table: odfdo.Table +): + """Apply column height for a given table + """ header_style = document.insert_style( style=create_row_style_height('1.60cm'), name='1.60cm', automatic=True ) @@ -241,19 +281,29 @@ def apply_column_height_style(document: odfdo.Document, table: odfdo.Table): else: row.style = body_style -def apply_cell_style_by_column(table: odfdo.Table, style: odfdo.Style, col_index: int): + +def apply_cell_style_by_column( + table: odfdo.Table, + style: odfdo.Style, + col_index: int +): + """Apply cell style for a given table + """ for cell in table.get_column_cells(col_index): - print(cell.style) cell.style = style - print(cell.serialize()) -def apply_column_width_style(document: odfdo.Document, table: odfdo.Table, widths: list[str]): +def apply_column_width_style( + document: odfdo.Document, + table: odfdo.Table, + widths: list[str] +): """Apply column width style to a table. Parameters: document(odfdo.Document): Document where the table is located. table(odfdo.Table): Table to apply columns widths. - widths(list[str]): list of width in format unit ca be in, cm... see odfdo documentation. + widths(list[str]): list of width in format unit ca be + in, cm... see odfdo documentation. """ styles = [] for w in widths: @@ -268,6 +318,12 @@ def apply_column_width_style(document: odfdo.Document, table: odfdo.Table, width def generate_ods_letters(n: int): + """Generate letters following excel format. + Arguments: + n(int): `n` letters to generate. + Return: + result(list[str]): list of `n` letters that follow excel pattern. + """ letters = string.ascii_lowercase result = [] for i in range(n): @@ -282,6 +338,8 @@ def generate_ods_letters(n: int): def compute_contract_prices(contract: models.Contract) -> dict: + """Compute price for a give contract. + """ occasional_contract_products = list( filter( lambda contract_product: ( @@ -313,6 +371,8 @@ def compute_contract_prices(contract: models.Contract) -> dict: def transform_formula_cells(sheet: odfdo.Spreadsheet): + """Transform cell value to a formula using odfdo. + """ for row in sheet.get_rows(): for cell in row.get_cells(): if not cell.value or cell.get_attribute("office:value-type") == "float": @@ -330,6 +390,8 @@ def merge_shipment_cells( occasionnals: list[str], shipments: list[models.Shipment] ): + """Merge cells for shipment header. + """ index = len(prefix_header) + len(recurrents) + 1 for _ in enumerate(shipments): startcol = index @@ -341,19 +403,23 @@ def generate_recap( contracts: list[models.Contract], form: models.Form, ): + """Generate excel recap for a list of contracts. + """ product_unit_map = { '1': 'g', '2': 'Kg', '3': 'Piece' } recurrents = [ - f'{pr.name}{f' - {pr.quantity}{pr.quantity_unit}' if pr.quantity else ''} ({product_unit_map[pr.unit]})' + f'{pr.name}{f' - {pr.quantity}{pr.quantity_unit}' + if pr.quantity else ''} ({product_unit_map[pr.unit]})' for pr in form.productor.products if pr.type == models.ProductType.RECCURENT ] recurrents.sort() occasionnals = [ - f'{pr.name}{f' - {pr.quantity}{pr.quantity_unit}' if pr.quantity else ''} ({product_unit_map[pr.unit]})' + f'{pr.name}{f' - {pr.quantity}{pr.quantity_unit}' + if pr.quantity else ''} ({product_unit_map[pr.unit]})' for pr in form.productor.products if pr.type == models.ProductType.OCCASIONAL ] @@ -402,7 +468,8 @@ def generate_recap( len(info_header)+len(payment_formula_letters)+len(recurrents) + 1 ] occasionnals_formula_letters = letters[ - len(info_header)+len(payment_formula_letters)+len(recurent_formula_letters): + len(info_header)+len(payment_formula_letters)+ + len(recurent_formula_letters): len(info_header)+len(payment_formula_letters) + len(recurent_formula_letters)+len(occasionnals_header) + 1 ] @@ -421,11 +488,17 @@ def generate_recap( for index, contract in enumerate(contracts): prices = compute_contract_prices(contract) occasionnal_sorted = sorted( - [product for product in contract.products if product.product.type == models.ProductType.OCCASIONAL], + [ + product for product in contract.products + if product.product.type == models.ProductType.OCCASIONAL + ], key=lambda x: (x.shipment.name, x.product.name) ) recurrent_sorted = sorted( - [product for product in contract.products if product.product.type == models.ProductType.RECCURENT], + [ + product for product in contract.products + if product.product.type == models.ProductType.RECCURENT + ], key=lambda x: x.product.name ) diff --git a/backend/src/forms/forms.py b/backend/src/forms/forms.py index 23253c9..f2e68e4 100644 --- a/backend/src/forms/forms.py +++ b/backend/src/forms/forms.py @@ -1,9 +1,8 @@ import src.forms.exceptions as exceptions import src.forms.service as service -import src.messages as messages from fastapi import APIRouter, Depends, HTTPException, Query from sqlmodel import Session -from src import models +from src import messages, models from src.auth.auth import get_current_user from src.database import get_session @@ -33,7 +32,7 @@ async def get_forms_filtered( @router.get('/{_id}', response_model=models.FormPublic) async def get_form( - _id: int, + _id: int, session: Session = Depends(get_session) ): result = service.get_one(session, _id) diff --git a/backend/src/forms/service.py b/backend/src/forms/service.py index 1f46177..23b8a1e 100644 --- a/backend/src/forms/service.py +++ b/backend/src/forms/service.py @@ -1,8 +1,7 @@ import src.forms.exceptions as exceptions -import src.messages as messages from sqlalchemy import func from sqlmodel import Session, select -from src import models +from src import messages, models def get_all( @@ -109,11 +108,13 @@ def delete_one(session: Session, _id: int) -> models.FormPublic: def is_allowed( - session: Session, - user: models.User, - _id: int = None, + session: Session, + user: models.User, + _id: int = None, form: models.FormCreate = None ) -> bool: + if not _id and not form: + return False if not _id: statement = ( select(models.Productor) diff --git a/backend/src/productors/productors.py b/backend/src/productors/productors.py index b8a2b3d..50bca5d 100644 --- a/backend/src/productors/productors.py +++ b/backend/src/productors/productors.py @@ -1,11 +1,9 @@ -import src.messages as messages -import src.productors.exceptions as exceptions -import src.productors.service as service from fastapi import APIRouter, Depends, HTTPException, Query from sqlmodel import Session -from src import models +from src import messages, models from src.auth.auth import get_current_user from src.database import get_session +from src.productors import exceptions, service router = APIRouter(prefix='/productors') @@ -26,6 +24,11 @@ def get_productor( user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(session, user, _id=_id): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('productor', 'get') + ) result = service.get_one(session, _id) if result is None: raise HTTPException( @@ -41,6 +44,11 @@ def create_productor( user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(session, user, productor=productor): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('productor', 'create') + ) try: result = service.create_one(session, productor) except exceptions.ProductorCreateError as error: @@ -54,6 +62,11 @@ def update_productor( user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(session, user, _id=_id): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('productor', 'update') + ) try: result = service.update_one(session, _id, productor) except exceptions.ProductorNotFoundError as error: @@ -67,6 +80,11 @@ def delete_productor( user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(session, user, _id=_id): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('productor', 'delete') + ) try: result = service.delete_one(session, _id) except exceptions.ProductorNotFoundError as error: diff --git a/backend/src/productors/service.py b/backend/src/productors/service.py index 28d927c..9df58b1 100644 --- a/backend/src/productors/service.py +++ b/backend/src/productors/service.py @@ -1,7 +1,6 @@ -import src.messages as messages -import src.productors.exceptions as exceptions from sqlmodel import Session, select -from src import models +from src import messages, models +from src.productors import exceptions def get_all( @@ -50,9 +49,10 @@ def create_one( def update_one( session: Session, - id: int, - productor: models.ProductorUpdate) -> models.ProductorPublic: - statement = select(models.Productor).where(models.Productor.id == id) + _id: int, + productor: models.ProductorUpdate +) -> models.ProductorPublic: + statement = select(models.Productor).where(models.Productor.id == _id) result = session.exec(statement) new_productor = result.first() if not new_productor: @@ -94,11 +94,13 @@ def delete_one(session: Session, _id: int) -> models.ProductorPublic: return result def is_allowed( - session: Session, - user: models.User, - _id: int, - productor: models.ProductorCreate + session: Session, + user: models.User, + _id: int = None, + productor: models.ProductorCreate = None ) -> bool: + if not _id and not productor: + return False if not _id: return productor.type in [r.name for r in user.roles] statement = ( diff --git a/backend/src/products/products.py b/backend/src/products/products.py index 3ce5007..9566896 100644 --- a/backend/src/products/products.py +++ b/backend/src/products/products.py @@ -1,11 +1,10 @@ -import src.messages as messages -import src.products.exceptions as exceptions import src.products.service as service from fastapi import APIRouter, Depends, HTTPException, Query from sqlmodel import Session -from src import models +from src import messages, models from src.auth.auth import get_current_user from src.database import get_session +from src.products import exceptions router = APIRouter(prefix='/products') @@ -27,13 +26,18 @@ def get_products( ) -@router.get('/{id}', response_model=models.ProductPublic) +@router.get('/{_id}', response_model=models.ProductPublic) def get_product( - id: int, + _id: int, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): - result = service.get_one(session, id) + if not service.is_allowed(session, user, _id=_id): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('product', 'create') + ) + result = service.get_one(session, _id) if result is None: raise HTTPException(status_code=404, detail=messages.Messages.not_found('product')) @@ -46,38 +50,68 @@ def create_product( user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(session, user, product=product): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('product', 'create') + ) try: result = service.create_one(session, product) except exceptions.ProductCreateError as error: - raise HTTPException(status_code=400, detail=str(error)) + raise HTTPException( + status_code=400, + detail=str(error) + ) from error except exceptions.ProductorNotFoundError as error: - raise HTTPException(status_code=404, detail=str(error)) + raise HTTPException( + status_code=404, + detail=str(error) + ) from error return result -@router.put('/{id}', response_model=models.ProductPublic) +@router.put('/{_id}', response_model=models.ProductPublic) def update_product( - id: int, product: models.ProductUpdate, + _id: int, product: models.ProductUpdate, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(session, user, _id=_id): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('product', 'update') + ) try: - result = service.update_one(session, id, product) + result = service.update_one(session, _id, product) except exceptions.ProductNotFoundError as error: - raise HTTPException(status_code=404, detail=str(error)) + raise HTTPException( + status_code=404, + detail=str(error) + ) from error except exceptions.ProductorNotFoundError as error: - raise HTTPException(status_code=404, detail=str(error)) + raise HTTPException( + status_code=404, + detail=str(error) + ) from error return result -@router.delete('/{id}', response_model=models.ProductPublic) +@router.delete('/{_id}', response_model=models.ProductPublic) def delete_product( - id: int, + _id: int, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(session, user, _id=_id): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('product', 'delete') + ) try: - result = service.delete_one(session, id) + result = service.delete_one(session, _id) except exceptions.ProductNotFoundError as error: - raise HTTPException(status_code=404, detail=str(error)) + raise HTTPException( + status_code=404, + detail=str(error) + ) from error return result diff --git a/backend/src/products/service.py b/backend/src/products/service.py index f35d745..3ce4ca4 100644 --- a/backend/src/products/service.py +++ b/backend/src/products/service.py @@ -1,7 +1,6 @@ -import src.messages as messages -import src.products.exceptions as exceptions from sqlmodel import Session, select -from src import models +from src import messages, models +from src.products import exceptions def get_all( @@ -27,13 +26,17 @@ def get_all( return session.exec(statement.order_by(models.Product.name)).all() -def get_one(session: Session, product_id: int) -> models.ProductPublic: +def get_one( + session: Session, + product_id: int, +) -> models.ProductPublic: return session.get(models.Product, product_id) def create_one( - session: Session, - product: models.ProductCreate) -> models.ProductPublic: + session: Session, + product: models.ProductCreate, +) -> models.ProductPublic: if not product: raise exceptions.ProductCreateError( messages.Messages.invalid_input( @@ -50,10 +53,11 @@ def create_one( def update_one( - session: Session, - id: int, - product: models.ProductUpdate) -> models.ProductPublic: - statement = select(models.Product).where(models.Product.id == id) + session: Session, + _id: int, + product: models.ProductUpdate +) -> models.ProductPublic: + statement = select(models.Product).where(models.Product.id == _id) result = session.exec(statement) new_product = result.first() if not new_product: @@ -74,8 +78,11 @@ def update_one( return new_product -def delete_one(session: Session, id: int) -> models.ProductPublic: - statement = select(models.Product).where(models.Product.id == id) +def delete_one( + session: Session, + _id: int +) -> models.ProductPublic: + statement = select(models.Product).where(models.Product.id == _id) result = session.exec(statement) product = result.first() if not product: @@ -87,11 +94,13 @@ def delete_one(session: Session, id: int) -> models.ProductPublic: return result def is_allowed( - session: Session, - user: models.User, - _id: int, - product: models.ProductCreate + session: Session, + user: models.User, + _id: int = None, + product: models.ProductCreate = None, ) -> bool: + if not _id and not product: + return False if not _id: statement = ( select(models.Product) @@ -113,4 +122,4 @@ def is_allowed( .where(models.Productor.type.in_([r.name for r in user.roles])) .distinct() ) - return len(session.exec(statement).all()) > 0 \ No newline at end of file + return len(session.exec(statement).all()) > 0 diff --git a/backend/src/shipments/service.py b/backend/src/shipments/service.py index 02202e9..32792be 100644 --- a/backend/src/shipments/service.py +++ b/backend/src/shipments/service.py @@ -1,10 +1,9 @@ # pylint: disable=E1101 import datetime -import src.messages as messages import src.shipments.exceptions as exceptions from sqlmodel import Session, select -from src import models +from src import messages, models def get_all( @@ -127,3 +126,40 @@ def delete_one(session: Session, _id: int) -> models.ShipmentPublic: session.delete(shipment) session.commit() return result + + +def is_allowed( + session: Session, + user: models.User, + _id: int = None, + shipment: models.ShipmentCreate = None, +): + if not _id and not shipment: + return False + if not _id: + statement = ( + select(models.Shipment) + .join( + models.Form, + models.Shipment.form_id == models.Form.id + ) + .where(models.Form.id == shipment.form_id) + ) + form = session.exec(statement).first() + return form.productor.type in [r.name for r in user.roles] + statement = ( + select(models.Shipment) + .join( + models.Form, + models.Shipment.form_id == models.Form.id + ) + .join( + models.Productor, + models.Form.productor_id == models.Productor.id + ) + .where(models.Shipment.id == _id) + .where(models.Productor.type.in_([r.name for r in user.roles])) + .distinct() + ) + return len(session.exec(statement).all()) > 0 + diff --git a/backend/src/shipments/shipments.py b/backend/src/shipments/shipments.py index a2e7f40..886dd51 100644 --- a/backend/src/shipments/shipments.py +++ b/backend/src/shipments/shipments.py @@ -1,9 +1,8 @@ -import src.messages as messages import src.shipments.exceptions as exceptions import src.shipments.service as service from fastapi import APIRouter, Depends, HTTPException, Query from sqlmodel import Session -from src import models +from src import messages, models from src.auth.auth import get_current_user from src.database import get_session @@ -33,6 +32,11 @@ def get_shipment( user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(session, user, _id=_id): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('shipment', 'get') + ) result = service.get_one(session, _id) if result is None: raise HTTPException( @@ -48,6 +52,11 @@ def create_shipment( user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(session, user, shipment=shipment): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('shipment', 'create') + ) try: result = service.create_one(session, shipment) except exceptions.ShipmentCreateError as error: @@ -62,6 +71,11 @@ def update_shipment( user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(session, user, _id=_id): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('shipment', 'update') + ) try: result = service.update_one(session, _id, shipment) except exceptions.ShipmentNotFoundError as error: @@ -75,6 +89,12 @@ def delete_shipment( user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(session, user, _id=_id): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('shipment', 'delete') + ) + try: result = service.delete_one(session, _id) except exceptions.ShipmentNotFoundError as error: diff --git a/backend/src/templates/templates.py b/backend/src/templates/templates.py index 316120e..69a8002 100644 --- a/backend/src/templates/templates.py +++ b/backend/src/templates/templates.py @@ -1,8 +1,7 @@ -import src.messages as messages import src.templates.service as service from fastapi import APIRouter, Depends, HTTPException from sqlmodel import Session -from src import models +from src import messages, models from src.auth.auth import get_current_user from src.database import get_session diff --git a/backend/src/users/service.py b/backend/src/users/service.py index 991517e..27e59fe 100644 --- a/backend/src/users/service.py +++ b/backend/src/users/service.py @@ -1,7 +1,6 @@ -import src.messages as messages import src.users.exceptions as exceptions from sqlmodel import Session, select -from src import models +from src import messages, models def get_all( @@ -48,7 +47,8 @@ def get_or_create_user(session: Session, user_create: models.UserCreate): user = session.exec(statement).first() if user: user_role_names = [r.name for r in user.roles] - if user_role_names != user_create.role_names or user.name != user_create.name: + if (user_role_names != user_create.role_names or + user.name != user_create.name): user = update_one(session, user.id, user_create) return user user = create_one(session, user_create) @@ -119,3 +119,9 @@ def delete_one(session: Session, _id: int) -> models.UserPublic: session.delete(user) session.commit() return result + + +def is_allowed( + logged_user: models.User, +): + return len(logged_user.roles) >= 5 diff --git a/backend/src/users/users.py b/backend/src/users/users.py index 6e8308c..6904edb 100644 --- a/backend/src/users/users.py +++ b/backend/src/users/users.py @@ -1,9 +1,8 @@ -import src.messages as messages import src.users.exceptions as exceptions import src.users.service as service from fastapi import APIRouter, Depends, HTTPException, Query from sqlmodel import Session -from src import models +from src import messages, models from src.auth.auth import get_current_user from src.database import get_session @@ -13,7 +12,7 @@ router = APIRouter(prefix='/users') @router.get('', response_model=list[models.UserPublic]) def get_users( session: Session = Depends(get_session), - user: models.User = Depends(get_current_user), + _: models.User = Depends(get_current_user), names: list[str] = Query([]), emails: list[str] = Query([]), ): @@ -29,6 +28,11 @@ def get_roles( user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(user): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('roles', 'get all') + ) return service.get_roles(session) @@ -38,6 +42,11 @@ def get_user( user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(user): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('user', 'get') + ) result = service.get_one(session, _id) if result is None: raise HTTPException( @@ -53,11 +62,16 @@ def create_user( logged_user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(logged_user): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('user', 'create') + ) try: user = service.create_one(session, user) except exceptions.UserCreateError as error: raise HTTPException( - status_code=400, + status_code=400, detail=str(error) ) from error return user @@ -70,6 +84,11 @@ def update_user( logged_user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(logged_user): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('user', 'update') + ) try: result = service.update_one(session, _id, user) except exceptions.UserNotFoundError as error: @@ -80,14 +99,19 @@ def update_user( return result -@router.delete('/{id}', response_model=models.UserPublic) +@router.delete('/{_id}', response_model=models.UserPublic) def delete_user( - id: int, + _id: int, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): + if not service.is_allowed(user): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('user', 'delete') + ) try: - result = service.delete_one(session, id) + result = service.delete_one(session, _id) except exceptions.UserNotFoundError as error: raise HTTPException( status_code=404, diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 5096dc8..fd09e33 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -7,8 +7,6 @@ from src.auth.auth import get_current_user from src.database import get_session from src.main import app -from .fixtures import * - @pytest.fixture def mock_session(mocker): diff --git a/backend/tests/factories/__init__.py b/backend/tests/factories/__init__.py new file mode 100644 index 0000000..e9a63bc --- /dev/null +++ b/backend/tests/factories/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2026-present Julien Aldon +# +# SPDX-License-Identifier: MIT diff --git a/backend/tests/routers/__init__.py b/backend/tests/routers/__init__.py new file mode 100644 index 0000000..e9a63bc --- /dev/null +++ b/backend/tests/routers/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2026-present Julien Aldon +# +# SPDX-License-Identifier: MIT diff --git a/backend/tests/routers/test_contracts.py b/backend/tests/routers/test_contracts.py index 3ef65f9..dd7726e 100644 --- a/backend/tests/routers/test_contracts.py +++ b/backend/tests/routers/test_contracts.py @@ -1,15 +1,18 @@ import src.contracts.service as service -import tests.factories.contract_products as contract_products_factory import tests.factories.contracts as contract_factory -import tests.factories.forms as form_factory from fastapi.exceptions import HTTPException -from src import models from src.auth.auth import get_current_user from src.main import app class TestContracts: - def test_get_all(self, client, mocker, mock_session, mock_user): + def test_get_all( + self, + client, + mocker, + mock_session, + mock_user + ): mock_results = [ contract_factory.contract_public_factory(id=1), contract_factory.contract_public_factory(id=2), @@ -32,7 +35,13 @@ class TestContracts: [], ) - def test_get_all_filters(self, client, mocker, mock_session, mock_user): + def test_get_all_filters( + self, + client, + mocker, + mock_session, + mock_user + ): mock_results = [ contract_factory.contract_public_factory(id=2), ] @@ -54,11 +63,10 @@ class TestContracts: ) def test_get_all_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) @@ -72,7 +80,12 @@ class TestContracts: app.dependency_overrides.clear() - def test_get_one(self, client, mocker, mock_session, mock_user): + def test_get_one( + self, + client, + mocker, + mock_session, + ): mock_result = contract_factory.contract_public_factory(id=2) mock = mocker.patch.object( @@ -95,7 +108,12 @@ class TestContracts: 2 ) - def test_get_one_notfound(self, client, mocker, mock_session, mock_user): + def test_get_one_notfound( + self, + client, + mocker, + mock_session, + ): mock_result = None mock = mocker.patch.object( service, @@ -116,11 +134,10 @@ class TestContracts: ) def test_get_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) @@ -134,7 +151,12 @@ class TestContracts: app.dependency_overrides.clear() - def test_delete_one(self, client, mocker, mock_session, mock_user): + def test_delete_one( + self, + client, + mocker, + mock_session, + ): contract_result = contract_factory.contract_public_factory() mock = mocker.patch.object( @@ -158,11 +180,10 @@ class TestContracts: ) def test_delete_one_notfound( - self, - client, - mocker, - mock_session, - mock_user + self, + client, + mocker, + mock_session, ): contract_result = None @@ -187,11 +208,9 @@ class TestContracts: ) def test_delete_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user + self, + client, + mocker, ): def unauthorized(): raise HTTPException(status_code=401) diff --git a/backend/tests/routers/test_forms.py b/backend/tests/routers/test_forms.py index 75a42b7..fb29793 100644 --- a/backend/tests/routers/test_forms.py +++ b/backend/tests/routers/test_forms.py @@ -1,15 +1,20 @@ import src.forms.exceptions as forms_exceptions import src.forms.service as service -import src.messages as messages import tests.factories.forms as form_factory from fastapi.exceptions import HTTPException -from src import models +from src import messages from src.auth.auth import get_current_user from src.main import app class TestForms: - def test_get_all(self, client, mocker, mock_session, mock_user): + def test_get_all( + self, + client, + mocker, + mock_session, + mock_user, + ): mock_results = [ form_factory.form_public_factory(name="test 1", id=1), form_factory.form_public_factory(name="test 2", id=2), @@ -34,7 +39,13 @@ class TestForms: mock_user, ) - def test_get_all_filters(self, client, mocker, mock_session, mock_user): + def test_get_all_filters( + self, + client, + mocker, + mock_session, + mock_user, + ): mock_results = [ form_factory.form_public_factory(name="test 2", id=2), ] @@ -59,11 +70,10 @@ class TestForms: ) def test_get_all_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) @@ -77,7 +87,12 @@ class TestForms: app.dependency_overrides.clear() - def test_get_one(self, client, mocker, mock_session, mock_user): + def test_get_one( + self, + client, + mocker, + mock_session, + ): mock_result = form_factory.form_public_factory(name="test 2", id=2) mock = mocker.patch.object( @@ -96,7 +111,12 @@ class TestForms: 2 ) - def test_get_one_notfound(self, client, mocker, mock_session, mock_user): + def test_get_one_notfound( + self, + client, + mocker, + mock_session, + ): mock_result = None mock = mocker.patch.object( service, @@ -104,14 +124,18 @@ class TestForms: return_value=mock_result ) response = client.get('/api/forms/2') - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( mock_session, 2 ) - def test_create_one(self, client, mocker, mock_session, mock_user): + def test_create_one( + self, + client, + mocker, + mock_session, + ): form_body = form_factory.form_body_factory(name='test form create') form_create = form_factory.form_create_factory(name='test form create') form_result = form_factory.form_public_factory(name='test form create') @@ -133,7 +157,11 @@ class TestForms: ) def test_create_one_referer_notfound( - self, client, mocker, mock_session, mock_user): + self, + client, + mocker, + mock_session, + ): form_body = form_factory.form_body_factory( name='test form create', referer_id=12312) form_create = form_factory.form_create_factory( @@ -144,8 +172,6 @@ class TestForms: messages.Messages.not_found('referer'))) response = client.post('/api/forms', json=form_body) - response_data = response.json() - assert response.status_code == 404 mock.assert_called_once_with( mock_session, @@ -153,7 +179,11 @@ class TestForms: ) def test_create_one_productor_notfound( - self, client, mocker, mock_session, mock_user): + self, + client, + mocker, + mock_session, + ): form_body = form_factory.form_body_factory( name='test form create', productor_id=1231) form_create = form_factory.form_create_factory( @@ -164,7 +194,6 @@ class TestForms: messages.Messages.not_found('productor'))) response = client.post('/api/forms', json=form_body) - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( @@ -173,11 +202,10 @@ class TestForms: ) def test_create_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) form_body = form_factory.form_body_factory(name='test form create') @@ -192,7 +220,12 @@ class TestForms: app.dependency_overrides.clear() - def test_update_one(self, client, mocker, mock_session, mock_user): + def test_update_one( + self, + client, + mocker, + mock_session, + ): form_body = form_factory.form_body_factory(name='test form update') form_update = form_factory.form_update_factory(name='test form update') form_result = form_factory.form_public_factory(name='test form update') @@ -215,11 +248,11 @@ class TestForms: ) def test_update_one_notfound( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + mock_session, + ): form_body = form_factory.form_body_factory(name='test form update') form_update = form_factory.form_update_factory(name='test form update') @@ -228,7 +261,6 @@ class TestForms: messages.Messages.not_found('form'))) response = client.put('/api/forms/2', json=form_body) - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( @@ -238,7 +270,11 @@ class TestForms: ) def test_update_one_referer_notfound( - self, client, mocker, mock_session, mock_user): + self, + client, + mocker, + mock_session, + ): form_body = form_factory.form_body_factory(name='test form update') form_update = form_factory.form_update_factory(name='test form update') @@ -247,7 +283,6 @@ class TestForms: messages.Messages.not_found('referer'))) response = client.put('/api/forms/2', json=form_body) - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( @@ -257,7 +292,11 @@ class TestForms: ) def test_update_one_productor_notfound( - self, client, mocker, mock_session, mock_user): + self, + client, + mocker, + mock_session, + ): form_body = form_factory.form_body_factory(name='test form update') form_update = form_factory.form_update_factory(name='test form update') @@ -266,7 +305,6 @@ class TestForms: messages.Messages.not_found('productor'))) response = client.put('/api/forms/2', json=form_body) - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( @@ -276,11 +314,10 @@ class TestForms: ) def test_update_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) form_body = form_factory.form_body_factory(name='test form update') @@ -295,7 +332,12 @@ class TestForms: app.dependency_overrides.clear() - def test_delete_one(self, client, mocker, mock_session, mock_user): + def test_delete_one( + self, + client, + mocker, + mock_session, + ): form_result = form_factory.form_public_factory(name='test form delete') mock = mocker.patch.object( @@ -315,19 +357,19 @@ class TestForms: ) def test_delete_one_notfound( - self, - client, - mocker, - mock_session, - mock_user): - form_result = None - + self, + client, + mocker, + mock_session, + ): mock = mocker.patch.object( - service, 'delete_one', side_effect=forms_exceptions.FormNotFoundError( - messages.Messages.not_found('form'))) + service, + 'delete_one', + side_effect=forms_exceptions.FormNotFoundError( + messages.Messages.not_found('form')) + ) response = client.delete('/api/forms/2') - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( @@ -336,11 +378,10 @@ class TestForms: ) def test_delete_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) diff --git a/backend/tests/routers/test_productors.py b/backend/tests/routers/test_productors.py index 5ff9046..0065f19 100644 --- a/backend/tests/routers/test_productors.py +++ b/backend/tests/routers/test_productors.py @@ -1,15 +1,19 @@ -import src.messages as messages -import src.productors.exceptions as exceptions -import src.productors.service as service import tests.factories.productors as productor_factory from fastapi.exceptions import HTTPException -from src import models +from src import messages from src.auth.auth import get_current_user from src.main import app +from src.productors import exceptions, service class TestProductors: - def test_get_all(self, client, mocker, mock_session, mock_user): + def test_get_all( + self, + client, + mocker, + mock_session, + mock_user, + ): mock_results = [ productor_factory.productor_public_factory(name="test 1", id=1), productor_factory.productor_public_factory(name="test 2", id=2), @@ -33,7 +37,13 @@ class TestProductors: [], ) - def test_get_all_filters(self, client, mocker, mock_session, mock_user): + def test_get_all_filters( + self, + client, + mocker, + mock_session, + mock_user, + ): mock_results = [ productor_factory.productor_public_factory(name="test 2", id=2), ] @@ -57,11 +67,10 @@ class TestProductors: ) def test_get_all_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) @@ -75,10 +84,22 @@ class TestProductors: app.dependency_overrides.clear() - def test_get_one(self, client, mocker, mock_session, mock_user): + def test_get_one( + self, + client, + mocker, + mock_session, + mock_user, + ): mock_result = productor_factory.productor_public_factory( name="test 2", id=2) + mocker.patch.object( + service, + 'is_allowed', + return_value=True + ) + mock = mocker.patch.object( service, 'get_one', @@ -95,7 +116,18 @@ class TestProductors: 2 ) - def test_get_one_notfound(self, client, mocker, mock_session, mock_user): + def test_get_one_notfound( + self, + client, + mocker, + mock_session, + mock_user, + ): + mocker.patch.object( + service, + 'is_allowed', + return_value=True + ) mock_result = None mock = mocker.patch.object( service, @@ -103,7 +135,6 @@ class TestProductors: return_value=mock_result ) response = client.get('/api/productors/2') - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( mock_session, @@ -111,11 +142,10 @@ class TestProductors: ) def test_get_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) @@ -129,7 +159,13 @@ class TestProductors: app.dependency_overrides.clear() - def test_create_one(self, client, mocker, mock_session, mock_user): + def test_create_one( + self, + client, + mocker, + mock_session, + mock_user, + ): productor_body = productor_factory.productor_body_factory( name='test productor create') productor_create = productor_factory.productor_create_factory( @@ -137,6 +173,12 @@ class TestProductors: productor_result = productor_factory.productor_public_factory( name='test productor create') + mocker.patch.object( + service, + 'is_allowed', + return_value=True + ) + mock = mocker.patch.object( service, 'create_one', @@ -154,11 +196,10 @@ class TestProductors: ) def test_create_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) productor_body = productor_factory.productor_body_factory( @@ -174,7 +215,13 @@ class TestProductors: app.dependency_overrides.clear() - def test_update_one(self, client, mocker, mock_session, mock_user): + def test_update_one( + self, + client, + mocker, + mock_session, + mock_user, + ): productor_body = productor_factory.productor_body_factory( name='test productor update') productor_update = productor_factory.productor_update_factory( @@ -182,6 +229,12 @@ class TestProductors: productor_result = productor_factory.productor_public_factory( name='test productor update') + mocker.patch.object( + service, + 'is_allowed', + return_value=True + ) + mock = mocker.patch.object( service, 'update_one', @@ -200,23 +253,34 @@ class TestProductors: ) def test_update_one_notfound( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + mock_session, + mock_user, + ): productor_body = productor_factory.productor_body_factory( - name='test productor update') + name='test productor update', + ) productor_update = productor_factory.productor_update_factory( - name='test productor update') - productor_result = None + name='test productor update', + ) + + mocker.patch.object( + service, + 'is_allowed', + return_value=True + ) mock = mocker.patch.object( - service, 'update_one', side_effect=exceptions.ProductorNotFoundError( - messages.Messages.not_found('productor'))) + service, + 'update_one', + side_effect=exceptions.ProductorNotFoundError( + messages.Messages.not_found('productor') + ) + ) response = client.put('/api/productors/2', json=productor_body) - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( @@ -225,12 +289,12 @@ class TestProductors: productor_update ) + def test_update_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) productor_body = productor_factory.productor_body_factory( @@ -246,10 +310,23 @@ class TestProductors: app.dependency_overrides.clear() - def test_delete_one(self, client, mocker, mock_session, mock_user): + + def test_delete_one( + self, + client, + mocker, + mock_session, + mock_user, + ): productor_result = productor_factory.productor_public_factory( name='test productor delete') + mocker.patch.object( + service, + 'is_allowed', + return_value=True + ) + mock = mocker.patch.object( service, 'delete_one', @@ -267,19 +344,26 @@ class TestProductors: ) def test_delete_one_notfound( - self, - client, - mocker, - mock_session, - mock_user): - productor_result = None - + self, + client, + mocker, + mock_session, + mock_user, + ): + mocker.patch.object( + service, + 'is_allowed', + return_value=True + ) mock = mocker.patch.object( - service, 'delete_one', side_effect=exceptions.ProductorNotFoundError( - messages.Messages.not_found('productor'))) + service, + 'delete_one', + side_effect=exceptions.ProductorNotFoundError( + messages.Messages.not_found('productor') + ) + ) response = client.delete('/api/productors/2') - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( @@ -288,16 +372,12 @@ class TestProductors: ) def test_delete_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) - productor_body = productor_factory.productor_body_factory( - name='test productor delete') - app.dependency_overrides[get_current_user] = unauthorized mock = mocker.patch('src.productors.service.delete_one') diff --git a/backend/tests/routers/test_products.py b/backend/tests/routers/test_products.py index 01e79f4..6c7b493 100644 --- a/backend/tests/routers/test_products.py +++ b/backend/tests/routers/test_products.py @@ -1,14 +1,19 @@ -import src.products.exceptions as exceptions import src.products.service as service import tests.factories.products as product_factory from fastapi.exceptions import HTTPException -from src import models from src.auth.auth import get_current_user from src.main import app +from src.products import exceptions class TestProducts: - def test_get_all(self, client, mocker, mock_session, mock_user): + def test_get_all( + self, + client, + mocker, + mock_session, + mock_user + ): mock_results = [ product_factory.product_public_factory(name="test 1", id=1), product_factory.product_public_factory(name="test 2", id=2), @@ -33,7 +38,13 @@ class TestProducts: [] ) - def test_get_all_filters(self, client, mocker, mock_session, mock_user): + def test_get_all_filters( + self, + client, + mocker, + mock_session, + mock_user + ): mock_results = [ product_factory.product_public_factory(name="test 2", id=2), ] @@ -60,8 +71,7 @@ class TestProducts: self, client, mocker, - mock_session, - mock_user): + ): def unauthorized(): raise HTTPException(status_code=401) @@ -75,7 +85,12 @@ class TestProducts: app.dependency_overrides.clear() - def test_get_one(self, client, mocker, mock_session, mock_user): + def test_get_one( + self, + client, + mocker, + mock_session, + ): mock_result = product_factory.product_public_factory( name="test 2", id=2) @@ -95,7 +110,12 @@ class TestProducts: 2 ) - def test_get_one_notfound(self, client, mocker, mock_session, mock_user): + def test_get_one_notfound( + self, + client, + mocker, + mock_session, + ): mock_result = None mock = mocker.patch.object( service, @@ -103,7 +123,6 @@ class TestProducts: return_value=mock_result ) response = client.get('/api/products/2') - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( mock_session, @@ -111,11 +130,10 @@ class TestProducts: ) def test_get_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) @@ -129,7 +147,12 @@ class TestProducts: app.dependency_overrides.clear() - def test_create_one(self, client, mocker, mock_session, mock_user): + def test_create_one( + self, + client, + mocker, + mock_session, + ): product_body = product_factory.product_body_factory( name='test product create') product_create = product_factory.product_create_factory( @@ -154,11 +177,10 @@ class TestProducts: ) def test_create_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) product_body = product_factory.product_body_factory( @@ -174,13 +196,21 @@ class TestProducts: app.dependency_overrides.clear() - def test_update_one(self, client, mocker, mock_session, mock_user): + def test_update_one( + self, + client, + mocker, + mock_session, + ): product_body = product_factory.product_body_factory( - name='test product update') + name='test product update' + ) product_update = product_factory.product_update_factory( - name='test product update') + name='test product update' + ) product_result = product_factory.product_public_factory( - name='test product update') + name='test product update' + ) mock = mocker.patch.object( service, @@ -200,16 +230,17 @@ class TestProducts: ) def test_update_one_notfound( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + mock_session, + ): product_body = product_factory.product_body_factory( - name='test product update') + name='test product update' + ) product_update = product_factory.product_update_factory( - name='test product update') - product_result = None + name='test product update' + ) mock = mocker.patch.object( service, @@ -218,7 +249,6 @@ class TestProducts: ) response = client.put('/api/products/2', json=product_body) - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( @@ -228,11 +258,10 @@ class TestProducts: ) def test_update_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) product_body = product_factory.product_body_factory( @@ -248,7 +277,12 @@ class TestProducts: app.dependency_overrides.clear() - def test_delete_one(self, client, mocker, mock_session, mock_user): + def test_delete_one( + self, + client, + mocker, + mock_session, + ): product_result = product_factory.product_public_factory( name='test product delete') @@ -269,13 +303,11 @@ class TestProducts: ) def test_delete_one_notfound( - self, - client, - mocker, - mock_session, - mock_user): - product_result = None - + self, + client, + mocker, + mock_session, + ): mock = mocker.patch.object( service, 'delete_one', @@ -283,7 +315,6 @@ class TestProducts: ) response = client.delete('/api/products/2') - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( @@ -295,13 +326,9 @@ class TestProducts: self, client, mocker, - mock_session, - mock_user): + ): def unauthorized(): raise HTTPException(status_code=401) - product_body = product_factory.product_body_factory( - name='test product delete') - app.dependency_overrides[get_current_user] = unauthorized mock = mocker.patch('src.products.service.delete_one') diff --git a/backend/tests/routers/test_shipments.py b/backend/tests/routers/test_shipments.py index 3d60e0f..2fe1e16 100644 --- a/backend/tests/routers/test_shipments.py +++ b/backend/tests/routers/test_shipments.py @@ -1,15 +1,20 @@ -import src.messages as messages import src.shipments.exceptions as exceptions import src.shipments.service as service import tests.factories.shipments as shipment_factory from fastapi.exceptions import HTTPException -from src import models +from src import messages from src.auth.auth import get_current_user from src.main import app class TestShipments: - def test_get_all(self, client, mocker, mock_session, mock_user): + def test_get_all( + self, + client, + mocker, + mock_session, + mock_user, + ): mock_results = [ shipment_factory.shipment_public_factory(name="test 1", id=1), shipment_factory.shipment_public_factory(name="test 2", id=2), @@ -34,7 +39,13 @@ class TestShipments: [], ) - def test_get_all_filters(self, client, mocker, mock_session, mock_user): + def test_get_all_filters( + self, + client, + mocker, + mock_session, + mock_user, + ): mock_results = [ shipment_factory.shipment_public_factory(name="test 2", id=2), ] @@ -59,11 +70,10 @@ class TestShipments: ) def test_get_all_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) @@ -77,7 +87,12 @@ class TestShipments: app.dependency_overrides.clear() - def test_get_one(self, client, mocker, mock_session, mock_user): + def test_get_one( + self, + client, + mocker, + mock_session, + ): mock_result = shipment_factory.shipment_public_factory( name="test 2", id=2) @@ -97,7 +112,12 @@ class TestShipments: 2 ) - def test_get_one_notfound(self, client, mocker, mock_session, mock_user): + def test_get_one_notfound( + self, + client, + mocker, + mock_session, + ): mock_result = None mock = mocker.patch.object( service, @@ -105,7 +125,6 @@ class TestShipments: return_value=mock_result ) response = client.get('/api/shipments/2') - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( mock_session, @@ -113,11 +132,10 @@ class TestShipments: ) def test_get_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) @@ -131,13 +149,21 @@ class TestShipments: app.dependency_overrides.clear() - def test_create_one(self, client, mocker, mock_session, mock_user): + def test_create_one( + self, + client, + mocker, + mock_session, +): shipment_body = shipment_factory.shipment_body_factory( - name='test shipment create') + name='test shipment create' + ) shipment_create = shipment_factory.shipment_create_factory( - name='test shipment create') + name='test shipment create' + ) shipment_result = shipment_factory.shipment_public_factory( - name='test shipment create') + name='test shipment create' + ) mock = mocker.patch.object( service, @@ -156,15 +182,15 @@ class TestShipments: ) def test_create_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) shipment_body = shipment_factory.shipment_body_factory( - name='test shipment create') + name='test shipment create' + ) app.dependency_overrides[get_current_user] = unauthorized @@ -176,13 +202,21 @@ class TestShipments: app.dependency_overrides.clear() - def test_update_one(self, client, mocker, mock_session, mock_user): + def test_update_one( + self, + client, + mocker, + mock_session, + ): shipment_body = shipment_factory.shipment_body_factory( - name='test shipment update') + name='test shipment update' + ) shipment_update = shipment_factory.shipment_update_factory( - name='test shipment update') + name='test shipment update' + ) shipment_result = shipment_factory.shipment_public_factory( - name='test shipment update') + name='test shipment update' + ) mock = mocker.patch.object( service, @@ -202,22 +236,23 @@ class TestShipments: ) def test_update_one_notfound( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + mock_session, + ): shipment_body = shipment_factory.shipment_body_factory( - name='test shipment update') + name='test shipment update' + ) shipment_update = shipment_factory.shipment_update_factory( - name='test shipment update') + name='test shipment update' + ) mock = mocker.patch.object( service, 'update_one', side_effect=exceptions.ShipmentNotFoundError( messages.Messages.not_found('shipment'))) response = client.put('/api/shipments/2', json=shipment_body) - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( @@ -227,15 +262,15 @@ class TestShipments: ) def test_update_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) shipment_body = shipment_factory.shipment_body_factory( - name='test shipment update') + name='test shipment update' + ) app.dependency_overrides[get_current_user] = unauthorized @@ -247,9 +282,15 @@ class TestShipments: app.dependency_overrides.clear() - def test_delete_one(self, client, mocker, mock_session, mock_user): + def test_delete_one( + self, + client, + mocker, + mock_session, + ): shipment_result = shipment_factory.shipment_public_factory( - name='test shipment delete') + name='test shipment delete' + ) mock = mocker.patch.object( service, @@ -268,19 +309,20 @@ class TestShipments: ) def test_delete_one_notfound( - self, - client, - mocker, - mock_session, - mock_user): - shipment_result = None - + self, + client, + mocker, + mock_session, + ): mock = mocker.patch.object( - service, 'delete_one', side_effect=exceptions.ShipmentNotFoundError( - messages.Messages.not_found('shipment'))) + service, + 'delete_one', + side_effect=exceptions.ShipmentNotFoundError( + messages.Messages.not_found('shipment') + ) + ) response = client.delete('/api/shipments/2') - response_data = response.json() assert response.status_code == 404 mock.assert_called_once_with( @@ -289,15 +331,12 @@ class TestShipments: ) def test_delete_one_unauthorized( - self, - client, - mocker, - mock_session, - mock_user): + self, + client, + mocker, + ): def unauthorized(): raise HTTPException(status_code=401) - shipment_body = shipment_factory.shipment_body_factory( - name='test shipment delete') app.dependency_overrides[get_current_user] = unauthorized diff --git a/backend/tests/services/__init__.py b/backend/tests/services/__init__.py new file mode 100644 index 0000000..e9a63bc --- /dev/null +++ b/backend/tests/services/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2026-present Julien Aldon +# +# SPDX-License-Identifier: MIT diff --git a/frontend/src/components/Contracts/Modal/index.tsx b/frontend/src/components/Contracts/Modal/index.tsx index 8388ef4..f04dc8d 100644 --- a/frontend/src/components/Contracts/Modal/index.tsx +++ b/frontend/src/components/Contracts/Modal/index.tsx @@ -26,6 +26,9 @@ export function ContractModal({ opened, onClose, handleSubmit }: ContractModalPr }); const formSelect = useMemo(() => { + if (!allForms) { + return []; + } return allForms?.map((form) => ({ value: String(form.id), label: `${form.season} ${form.name}`,