fix all pylint warnings, add tests (wip) fix recap

This commit is contained in:
2026-03-06 00:00:01 +01:00
parent 60812652cf
commit b4b4fa7643
25 changed files with 845 additions and 376 deletions

View File

View File

@@ -4,14 +4,13 @@ from urllib.parse import urlencode
import jwt
import requests
import src.messages as messages
import src.users.service as service
from fastapi import (APIRouter, Cookie, Depends, HTTPException, Request,
Security)
from fastapi import APIRouter, Cookie, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse, Response
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security import HTTPBearer
from jwt import PyJWKClient
from sqlmodel import Session, select
from src import messages
from src.database import get_session
from src.models import User, UserCreate, UserPublic
from src.settings import (AUTH_URL, ISSUER, JWKS_URL, LOGOUT_URL, TOKEN_URL,
@@ -78,7 +77,18 @@ def callback(code: str, session: Session = Depends(get_session)):
headers = {
'Content-Type': 'application/x-www-form-urlencoded'
}
response = requests.post(TOKEN_URL, data=data, headers=headers)
try:
response = requests.post(
TOKEN_URL,
data=data,
headers=headers,
timeout=10
)
except requests.exceptions.Timeout as error:
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('token')
) from error
if response.status_code != 200:
raise HTTPException(
status_code=404,
@@ -99,7 +109,13 @@ def callback(code: str, session: Session = Depends(get_session)):
'client_secret': settings.keycloak_client_secret,
'refresh_token': token_data['refresh_token'],
}
requests.post(LOGOUT_URL, data=data)
try:
requests.post(LOGOUT_URL, data=data, timeout=10)
except requests.exceptions.Timeout as error:
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('token')
) from error
resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true')
return resp
roles = resource_access.get(settings.keycloak_client_id)
@@ -109,7 +125,13 @@ def callback(code: str, session: Session = Depends(get_session)):
'client_secret': settings.keycloak_client_secret,
'refresh_token': token_data['refresh_token'],
}
requests.post(LOGOUT_URL, data=data)
try:
requests.post(LOGOUT_URL, data=data, timeout=10)
except requests.exceptions.Timeout as error:
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('token')
) from error
resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true')
return resp
@@ -160,16 +182,16 @@ def verify_token(token: str):
leeway=60,
)
return decoded
except jwt.ExpiredSignatureError:
except jwt.ExpiredSignatureError as error:
raise HTTPException(
status_code=401,
detail=messages.Messages.tokenexipired
)
except jwt.InvalidTokenError:
) from error
except jwt.InvalidTokenError as error:
raise HTTPException(
status_code=401,
detail=messages.Messages.invalidtoken
)
) from error
def get_current_user(
@@ -184,7 +206,7 @@ def get_current_user(
payload = verify_token(access_token)
if not payload:
raise HTTPException(
status_code=401,
status_code=401,
detail='aze'
)
email = payload.get('email')
@@ -205,7 +227,7 @@ def get_current_user(
@router.post('/refresh')
def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
def refresh_user_token(refresh_token: Annotated[str | None, Cookie()] = None):
refresh = refresh_token
data = {
'grant_type': 'refresh_token',
@@ -216,7 +238,18 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
headers = {
'Content-Type': 'application/x-www-form-urlencoded'
}
result = requests.post(TOKEN_URL, data=data, headers=headers)
try:
result = requests.post(
TOKEN_URL,
data=data,
headers=headers,
timeout=10,
)
except requests.exceptions.Timeout as error:
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('token')
) from error
if result.status_code != 200:
raise HTTPException(
status_code=404,
@@ -229,7 +262,7 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
key='access_token',
value=token_data['access_token'],
httponly=True,
secure=True if settings.debug == False else True,
secure=True if settings.debug is False else True,
samesite='strict',
max_age=settings.max_age
)
@@ -237,7 +270,7 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
key='refresh_token',
value=token_data['refresh_token'] or '',
httponly=True,
secure=True if settings.debug == False else True,
secure=True if settings.debug is False else True,
samesite='strict',
max_age=30 * 24 * settings.max_age
)

View File

@@ -4,11 +4,10 @@ import zipfile
import src.contracts.service as service
import src.forms.service as form_service
import src.messages as messages
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlmodel import Session
from src import models
from src import messages, models
from src.auth.auth import get_current_user
from src.contracts.generate_contract import (generate_html_contract,
generate_recap)

View File

@@ -18,11 +18,24 @@ def generate_html_contract(
reccurents: list[dict],
recurrent_price: float | None = None,
total_price: float | None = None
):
) -> bytes:
"""Generate a html contract
Arguments:
contract(models.Contract): Contract source.
cheques(list[dict]): cheques formated in dict.
occasionals(list[dict]): occasional products.
reccurents(list[dict]): recurrent products.
recurrent_price(float | None = None): total price of recurent products.
total_price(float | None = Non): total price.
Return:
result(bytes): contract file in pdf as bytes.
"""
template_dir = pathlib.Path("./src/contracts/templates").resolve()
template_loader = jinja2.FileSystemLoader(searchpath=template_dir)
template_env = jinja2.Environment(
loader=template_loader, autoescape=jinja2.select_autoescape(["html", "xml"]))
loader=template_loader,
autoescape=jinja2.select_autoescape(["html", "xml"])
)
template_file = "layout.html"
template = template_env.get_template(template_file)
output_text = template.render(
@@ -65,13 +78,16 @@ def generate_html_contract(
def flatten(xss):
"""flatten a list of list.
"""
return [x for xs in xss for x in xs]
def create_column_style_width(size: str) -> odfdo.Style:
"""Create a table columm style for a given width.
Paramenters:
size(str): size of the style (format <number><unit>) unit can be in, cm... see odfdo documentation.
size(str): size of the style (format <number><unit>)
unit can be in, cm... see odfdo documentation.
Returns:
odfdo.Style with the correct column-width attribute.
"""
@@ -85,7 +101,8 @@ def create_column_style_width(size: str) -> odfdo.Style:
def create_row_style_height(size: str) -> odfdo.Style:
"""Create a table height style for a given height.
Paramenters:
size(str): size of the style (format <number><unit>) unit can be in, cm... see odfdo documentation.
size(str): size of the style (format <number><unit>)
unit can be in, cm... see odfdo documentation.
Returns:
odfdo.Style with the correct column-height attribute.
"""
@@ -97,10 +114,17 @@ def create_row_style_height(size: str) -> odfdo.Style:
def create_currency_style(name:str = 'currency-euro'):
"""Create a table currency style.
Paramenters:
name(str): name of the style (default to `currency-euro`).
Returns:
odfdo.Style with the correct column-height attribute.
"""
return odfdo.Element.from_tag(
f"""
<number:currency-style style:name="{name}">
<number:number number:min-integer-digits="1" number:decimal-places="2"/>
<number:number number:min-integer-digits="1"
number:decimal-places="2"/>
<number:text> €</number:text>
</number:currency-style>"""
)
@@ -113,6 +137,18 @@ def create_cell_style(
color: str = '#000000',
currency: bool = False,
) -> odfdo.Style:
"""Create a cell style
Paramenters:
name(str): name of the style (default to `centered-cell`).
font_size(str): font_size of the cell (default to `10pt`).
bold(str): is the text bold (default to `False`).
background_color(str): background_color of the cell
(default to `#FFFFFF`).
color(str): color of the text of the cell (default to `#000000`).
currency(str): is the cell a currency (default to `False`).
Returns:
odfdo.Style with the correct column-height attribute.
"""
bold_attr = """
fo:font-weight="bold"
style:font-weight-asian="bold"
@@ -138,7 +174,13 @@ def create_cell_style(
)
def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols: list[int]):
def apply_cell_style(
document: odfdo.Document,
table: odfdo.Table,
currency_cols: list[int]
):
"""Apply cell style
"""
document.insert_style(
style=create_currency_style(),
)
@@ -151,7 +193,6 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols
color="#FFF"
)
)
body_style_even = document.insert_style(
create_cell_style(
name="body-style-even",
@@ -160,7 +201,6 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols
color="#000000",
)
)
body_style_odd = document.insert_style(
create_cell_style(
name="body-style-odd",
@@ -169,7 +209,6 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols
color="#000000",
)
)
footer_style = document.insert_style(
create_cell_style(
name="footer-cells",
@@ -177,7 +216,6 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols
font_size='12pt',
)
)
body_style_even_currency = document.insert_style(
create_cell_style(
name="body-style-even-currency",
@@ -187,7 +225,6 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols
currency=True,
)
)
body_style_odd_currency = document.insert_style(
create_cell_style(
name="body-style-odd-currency",
@@ -197,7 +234,6 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols
currency=True,
)
)
footer_style_currency = document.insert_style(
create_cell_style(
name="footer-cells-currency",
@@ -206,7 +242,6 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols
currency=True,
)
)
for index, row in enumerate(table.get_rows()):
style = body_style_even
currency_style = body_style_even_currency
@@ -228,7 +263,12 @@ def apply_cell_style(document: odfdo.Document, table: odfdo.Table, currency_cols
cell.style = style
def apply_column_height_style(document: odfdo.Document, table: odfdo.Table):
def apply_column_height_style(
document: odfdo.Document,
table: odfdo.Table
):
"""Apply column height for a given table
"""
header_style = document.insert_style(
style=create_row_style_height('1.60cm'), name='1.60cm', automatic=True
)
@@ -241,19 +281,29 @@ def apply_column_height_style(document: odfdo.Document, table: odfdo.Table):
else:
row.style = body_style
def apply_cell_style_by_column(table: odfdo.Table, style: odfdo.Style, col_index: int):
def apply_cell_style_by_column(
table: odfdo.Table,
style: odfdo.Style,
col_index: int
):
"""Apply cell style for a given table
"""
for cell in table.get_column_cells(col_index):
print(cell.style)
cell.style = style
print(cell.serialize())
def apply_column_width_style(document: odfdo.Document, table: odfdo.Table, widths: list[str]):
def apply_column_width_style(
document: odfdo.Document,
table: odfdo.Table,
widths: list[str]
):
"""Apply column width style to a table.
Parameters:
document(odfdo.Document): Document where the table is located.
table(odfdo.Table): Table to apply columns widths.
widths(list[str]): list of width in format <number><unit> unit ca be in, cm... see odfdo documentation.
widths(list[str]): list of width in format <number><unit> unit ca be
in, cm... see odfdo documentation.
"""
styles = []
for w in widths:
@@ -268,6 +318,12 @@ def apply_column_width_style(document: odfdo.Document, table: odfdo.Table, width
def generate_ods_letters(n: int):
"""Generate letters following excel format.
Arguments:
n(int): `n` letters to generate.
Return:
result(list[str]): list of `n` letters that follow excel pattern.
"""
letters = string.ascii_lowercase
result = []
for i in range(n):
@@ -282,6 +338,8 @@ def generate_ods_letters(n: int):
def compute_contract_prices(contract: models.Contract) -> dict:
"""Compute price for a give contract.
"""
occasional_contract_products = list(
filter(
lambda contract_product: (
@@ -313,6 +371,8 @@ def compute_contract_prices(contract: models.Contract) -> dict:
def transform_formula_cells(sheet: odfdo.Spreadsheet):
"""Transform cell value to a formula using odfdo.
"""
for row in sheet.get_rows():
for cell in row.get_cells():
if not cell.value or cell.get_attribute("office:value-type") == "float":
@@ -330,6 +390,8 @@ def merge_shipment_cells(
occasionnals: list[str],
shipments: list[models.Shipment]
):
"""Merge cells for shipment header.
"""
index = len(prefix_header) + len(recurrents) + 1
for _ in enumerate(shipments):
startcol = index
@@ -341,19 +403,23 @@ def generate_recap(
contracts: list[models.Contract],
form: models.Form,
):
"""Generate excel recap for a list of contracts.
"""
product_unit_map = {
'1': 'g',
'2': 'Kg',
'3': 'Piece'
}
recurrents = [
f'{pr.name}{f' - {pr.quantity}{pr.quantity_unit}' if pr.quantity else ''} ({product_unit_map[pr.unit]})'
f'{pr.name}{f' - {pr.quantity}{pr.quantity_unit}'
if pr.quantity else ''} ({product_unit_map[pr.unit]})'
for pr in form.productor.products
if pr.type == models.ProductType.RECCURENT
]
recurrents.sort()
occasionnals = [
f'{pr.name}{f' - {pr.quantity}{pr.quantity_unit}' if pr.quantity else ''} ({product_unit_map[pr.unit]})'
f'{pr.name}{f' - {pr.quantity}{pr.quantity_unit}'
if pr.quantity else ''} ({product_unit_map[pr.unit]})'
for pr in form.productor.products
if pr.type == models.ProductType.OCCASIONAL
]
@@ -402,7 +468,8 @@ def generate_recap(
len(info_header)+len(payment_formula_letters)+len(recurrents) + 1
]
occasionnals_formula_letters = letters[
len(info_header)+len(payment_formula_letters)+len(recurent_formula_letters):
len(info_header)+len(payment_formula_letters)+
len(recurent_formula_letters):
len(info_header)+len(payment_formula_letters) +
len(recurent_formula_letters)+len(occasionnals_header) + 1
]
@@ -421,11 +488,17 @@ def generate_recap(
for index, contract in enumerate(contracts):
prices = compute_contract_prices(contract)
occasionnal_sorted = sorted(
[product for product in contract.products if product.product.type == models.ProductType.OCCASIONAL],
[
product for product in contract.products
if product.product.type == models.ProductType.OCCASIONAL
],
key=lambda x: (x.shipment.name, x.product.name)
)
recurrent_sorted = sorted(
[product for product in contract.products if product.product.type == models.ProductType.RECCURENT],
[
product for product in contract.products
if product.product.type == models.ProductType.RECCURENT
],
key=lambda x: x.product.name
)

View File

@@ -1,9 +1,8 @@
import src.forms.exceptions as exceptions
import src.forms.service as service
import src.messages as messages
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src import messages, models
from src.auth.auth import get_current_user
from src.database import get_session
@@ -33,7 +32,7 @@ async def get_forms_filtered(
@router.get('/{_id}', response_model=models.FormPublic)
async def get_form(
_id: int,
_id: int,
session: Session = Depends(get_session)
):
result = service.get_one(session, _id)

View File

@@ -1,8 +1,7 @@
import src.forms.exceptions as exceptions
import src.messages as messages
from sqlalchemy import func
from sqlmodel import Session, select
from src import models
from src import messages, models
def get_all(
@@ -109,11 +108,13 @@ def delete_one(session: Session, _id: int) -> models.FormPublic:
def is_allowed(
session: Session,
user: models.User,
_id: int = None,
session: Session,
user: models.User,
_id: int = None,
form: models.FormCreate = None
) -> bool:
if not _id and not form:
return False
if not _id:
statement = (
select(models.Productor)

View File

@@ -1,11 +1,9 @@
import src.messages as messages
import src.productors.exceptions as exceptions
import src.productors.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src import messages, models
from src.auth.auth import get_current_user
from src.database import get_session
from src.productors import exceptions, service
router = APIRouter(prefix='/productors')
@@ -26,6 +24,11 @@ def get_productor(
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('productor', 'get')
)
result = service.get_one(session, _id)
if result is None:
raise HTTPException(
@@ -41,6 +44,11 @@ def create_productor(
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(session, user, productor=productor):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('productor', 'create')
)
try:
result = service.create_one(session, productor)
except exceptions.ProductorCreateError as error:
@@ -54,6 +62,11 @@ def update_productor(
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('productor', 'update')
)
try:
result = service.update_one(session, _id, productor)
except exceptions.ProductorNotFoundError as error:
@@ -67,6 +80,11 @@ def delete_productor(
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('productor', 'delete')
)
try:
result = service.delete_one(session, _id)
except exceptions.ProductorNotFoundError as error:

View File

@@ -1,7 +1,6 @@
import src.messages as messages
import src.productors.exceptions as exceptions
from sqlmodel import Session, select
from src import models
from src import messages, models
from src.productors import exceptions
def get_all(
@@ -50,9 +49,10 @@ def create_one(
def update_one(
session: Session,
id: int,
productor: models.ProductorUpdate) -> models.ProductorPublic:
statement = select(models.Productor).where(models.Productor.id == id)
_id: int,
productor: models.ProductorUpdate
) -> models.ProductorPublic:
statement = select(models.Productor).where(models.Productor.id == _id)
result = session.exec(statement)
new_productor = result.first()
if not new_productor:
@@ -94,11 +94,13 @@ def delete_one(session: Session, _id: int) -> models.ProductorPublic:
return result
def is_allowed(
session: Session,
user: models.User,
_id: int,
productor: models.ProductorCreate
session: Session,
user: models.User,
_id: int = None,
productor: models.ProductorCreate = None
) -> bool:
if not _id and not productor:
return False
if not _id:
return productor.type in [r.name for r in user.roles]
statement = (

View File

@@ -1,11 +1,10 @@
import src.messages as messages
import src.products.exceptions as exceptions
import src.products.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src import messages, models
from src.auth.auth import get_current_user
from src.database import get_session
from src.products import exceptions
router = APIRouter(prefix='/products')
@@ -27,13 +26,18 @@ def get_products(
)
@router.get('/{id}', response_model=models.ProductPublic)
@router.get('/{_id}', response_model=models.ProductPublic)
def get_product(
id: int,
_id: int,
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
result = service.get_one(session, id)
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('product', 'create')
)
result = service.get_one(session, _id)
if result is None:
raise HTTPException(status_code=404,
detail=messages.Messages.not_found('product'))
@@ -46,38 +50,68 @@ def create_product(
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(session, user, product=product):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('product', 'create')
)
try:
result = service.create_one(session, product)
except exceptions.ProductCreateError as error:
raise HTTPException(status_code=400, detail=str(error))
raise HTTPException(
status_code=400,
detail=str(error)
) from error
except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error))
raise HTTPException(
status_code=404,
detail=str(error)
) from error
return result
@router.put('/{id}', response_model=models.ProductPublic)
@router.put('/{_id}', response_model=models.ProductPublic)
def update_product(
id: int, product: models.ProductUpdate,
_id: int, product: models.ProductUpdate,
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('product', 'update')
)
try:
result = service.update_one(session, id, product)
result = service.update_one(session, _id, product)
except exceptions.ProductNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error))
raise HTTPException(
status_code=404,
detail=str(error)
) from error
except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error))
raise HTTPException(
status_code=404,
detail=str(error)
) from error
return result
@router.delete('/{id}', response_model=models.ProductPublic)
@router.delete('/{_id}', response_model=models.ProductPublic)
def delete_product(
id: int,
_id: int,
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('product', 'delete')
)
try:
result = service.delete_one(session, id)
result = service.delete_one(session, _id)
except exceptions.ProductNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error))
raise HTTPException(
status_code=404,
detail=str(error)
) from error
return result

View File

@@ -1,7 +1,6 @@
import src.messages as messages
import src.products.exceptions as exceptions
from sqlmodel import Session, select
from src import models
from src import messages, models
from src.products import exceptions
def get_all(
@@ -27,13 +26,17 @@ def get_all(
return session.exec(statement.order_by(models.Product.name)).all()
def get_one(session: Session, product_id: int) -> models.ProductPublic:
def get_one(
session: Session,
product_id: int,
) -> models.ProductPublic:
return session.get(models.Product, product_id)
def create_one(
session: Session,
product: models.ProductCreate) -> models.ProductPublic:
session: Session,
product: models.ProductCreate,
) -> models.ProductPublic:
if not product:
raise exceptions.ProductCreateError(
messages.Messages.invalid_input(
@@ -50,10 +53,11 @@ def create_one(
def update_one(
session: Session,
id: int,
product: models.ProductUpdate) -> models.ProductPublic:
statement = select(models.Product).where(models.Product.id == id)
session: Session,
_id: int,
product: models.ProductUpdate
) -> models.ProductPublic:
statement = select(models.Product).where(models.Product.id == _id)
result = session.exec(statement)
new_product = result.first()
if not new_product:
@@ -74,8 +78,11 @@ def update_one(
return new_product
def delete_one(session: Session, id: int) -> models.ProductPublic:
statement = select(models.Product).where(models.Product.id == id)
def delete_one(
session: Session,
_id: int
) -> models.ProductPublic:
statement = select(models.Product).where(models.Product.id == _id)
result = session.exec(statement)
product = result.first()
if not product:
@@ -87,11 +94,13 @@ def delete_one(session: Session, id: int) -> models.ProductPublic:
return result
def is_allowed(
session: Session,
user: models.User,
_id: int,
product: models.ProductCreate
session: Session,
user: models.User,
_id: int = None,
product: models.ProductCreate = None,
) -> bool:
if not _id and not product:
return False
if not _id:
statement = (
select(models.Product)
@@ -113,4 +122,4 @@ def is_allowed(
.where(models.Productor.type.in_([r.name for r in user.roles]))
.distinct()
)
return len(session.exec(statement).all()) > 0
return len(session.exec(statement).all()) > 0

View File

@@ -1,10 +1,9 @@
# pylint: disable=E1101
import datetime
import src.messages as messages
import src.shipments.exceptions as exceptions
from sqlmodel import Session, select
from src import models
from src import messages, models
def get_all(
@@ -127,3 +126,40 @@ def delete_one(session: Session, _id: int) -> models.ShipmentPublic:
session.delete(shipment)
session.commit()
return result
def is_allowed(
session: Session,
user: models.User,
_id: int = None,
shipment: models.ShipmentCreate = None,
):
if not _id and not shipment:
return False
if not _id:
statement = (
select(models.Shipment)
.join(
models.Form,
models.Shipment.form_id == models.Form.id
)
.where(models.Form.id == shipment.form_id)
)
form = session.exec(statement).first()
return form.productor.type in [r.name for r in user.roles]
statement = (
select(models.Shipment)
.join(
models.Form,
models.Shipment.form_id == models.Form.id
)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id
)
.where(models.Shipment.id == _id)
.where(models.Productor.type.in_([r.name for r in user.roles]))
.distinct()
)
return len(session.exec(statement).all()) > 0

View File

@@ -1,9 +1,8 @@
import src.messages as messages
import src.shipments.exceptions as exceptions
import src.shipments.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src import messages, models
from src.auth.auth import get_current_user
from src.database import get_session
@@ -33,6 +32,11 @@ def get_shipment(
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('shipment', 'get')
)
result = service.get_one(session, _id)
if result is None:
raise HTTPException(
@@ -48,6 +52,11 @@ def create_shipment(
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(session, user, shipment=shipment):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('shipment', 'create')
)
try:
result = service.create_one(session, shipment)
except exceptions.ShipmentCreateError as error:
@@ -62,6 +71,11 @@ def update_shipment(
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('shipment', 'update')
)
try:
result = service.update_one(session, _id, shipment)
except exceptions.ShipmentNotFoundError as error:
@@ -75,6 +89,12 @@ def delete_shipment(
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('shipment', 'delete')
)
try:
result = service.delete_one(session, _id)
except exceptions.ShipmentNotFoundError as error:

View File

@@ -1,8 +1,7 @@
import src.messages as messages
import src.templates.service as service
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session
from src import models
from src import messages, models
from src.auth.auth import get_current_user
from src.database import get_session

View File

@@ -1,7 +1,6 @@
import src.messages as messages
import src.users.exceptions as exceptions
from sqlmodel import Session, select
from src import models
from src import messages, models
def get_all(
@@ -48,7 +47,8 @@ def get_or_create_user(session: Session, user_create: models.UserCreate):
user = session.exec(statement).first()
if user:
user_role_names = [r.name for r in user.roles]
if user_role_names != user_create.role_names or user.name != user_create.name:
if (user_role_names != user_create.role_names or
user.name != user_create.name):
user = update_one(session, user.id, user_create)
return user
user = create_one(session, user_create)
@@ -119,3 +119,9 @@ def delete_one(session: Session, _id: int) -> models.UserPublic:
session.delete(user)
session.commit()
return result
def is_allowed(
logged_user: models.User,
):
return len(logged_user.roles) >= 5

View File

@@ -1,9 +1,8 @@
import src.messages as messages
import src.users.exceptions as exceptions
import src.users.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src import messages, models
from src.auth.auth import get_current_user
from src.database import get_session
@@ -13,7 +12,7 @@ router = APIRouter(prefix='/users')
@router.get('', response_model=list[models.UserPublic])
def get_users(
session: Session = Depends(get_session),
user: models.User = Depends(get_current_user),
_: models.User = Depends(get_current_user),
names: list[str] = Query([]),
emails: list[str] = Query([]),
):
@@ -29,6 +28,11 @@ def get_roles(
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(user):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('roles', 'get all')
)
return service.get_roles(session)
@@ -38,6 +42,11 @@ def get_user(
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(user):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('user', 'get')
)
result = service.get_one(session, _id)
if result is None:
raise HTTPException(
@@ -53,11 +62,16 @@ def create_user(
logged_user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(logged_user):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('user', 'create')
)
try:
user = service.create_one(session, user)
except exceptions.UserCreateError as error:
raise HTTPException(
status_code=400,
status_code=400,
detail=str(error)
) from error
return user
@@ -70,6 +84,11 @@ def update_user(
logged_user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(logged_user):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('user', 'update')
)
try:
result = service.update_one(session, _id, user)
except exceptions.UserNotFoundError as error:
@@ -80,14 +99,19 @@ def update_user(
return result
@router.delete('/{id}', response_model=models.UserPublic)
@router.delete('/{_id}', response_model=models.UserPublic)
def delete_user(
id: int,
_id: int,
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(user):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('user', 'delete')
)
try:
result = service.delete_one(session, id)
result = service.delete_one(session, _id)
except exceptions.UserNotFoundError as error:
raise HTTPException(
status_code=404,

View File

@@ -7,8 +7,6 @@ from src.auth.auth import get_current_user
from src.database import get_session
from src.main import app
from .fixtures import *
@pytest.fixture
def mock_session(mocker):

View File

@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
#
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
#
# SPDX-License-Identifier: MIT

View File

@@ -1,15 +1,18 @@
import src.contracts.service as service
import tests.factories.contract_products as contract_products_factory
import tests.factories.contracts as contract_factory
import tests.factories.forms as form_factory
from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user
from src.main import app
class TestContracts:
def test_get_all(self, client, mocker, mock_session, mock_user):
def test_get_all(
self,
client,
mocker,
mock_session,
mock_user
):
mock_results = [
contract_factory.contract_public_factory(id=1),
contract_factory.contract_public_factory(id=2),
@@ -32,7 +35,13 @@ class TestContracts:
[],
)
def test_get_all_filters(self, client, mocker, mock_session, mock_user):
def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user
):
mock_results = [
contract_factory.contract_public_factory(id=2),
]
@@ -54,11 +63,10 @@ class TestContracts:
)
def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
@@ -72,7 +80,12 @@ class TestContracts:
app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user):
def test_get_one(
self,
client,
mocker,
mock_session,
):
mock_result = contract_factory.contract_public_factory(id=2)
mock = mocker.patch.object(
@@ -95,7 +108,12 @@ class TestContracts:
2
)
def test_get_one_notfound(self, client, mocker, mock_session, mock_user):
def test_get_one_notfound(
self,
client,
mocker,
mock_session,
):
mock_result = None
mock = mocker.patch.object(
service,
@@ -116,11 +134,10 @@ class TestContracts:
)
def test_get_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
@@ -134,7 +151,12 @@ class TestContracts:
app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user):
def test_delete_one(
self,
client,
mocker,
mock_session,
):
contract_result = contract_factory.contract_public_factory()
mock = mocker.patch.object(
@@ -158,11 +180,10 @@ class TestContracts:
)
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user
self,
client,
mocker,
mock_session,
):
contract_result = None
@@ -187,11 +208,9 @@ class TestContracts:
)
def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)

View File

@@ -1,15 +1,20 @@
import src.forms.exceptions as forms_exceptions
import src.forms.service as service
import src.messages as messages
import tests.factories.forms as form_factory
from fastapi.exceptions import HTTPException
from src import models
from src import messages
from src.auth.auth import get_current_user
from src.main import app
class TestForms:
def test_get_all(self, client, mocker, mock_session, mock_user):
def test_get_all(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [
form_factory.form_public_factory(name="test 1", id=1),
form_factory.form_public_factory(name="test 2", id=2),
@@ -34,7 +39,13 @@ class TestForms:
mock_user,
)
def test_get_all_filters(self, client, mocker, mock_session, mock_user):
def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [
form_factory.form_public_factory(name="test 2", id=2),
]
@@ -59,11 +70,10 @@ class TestForms:
)
def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
@@ -77,7 +87,12 @@ class TestForms:
app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user):
def test_get_one(
self,
client,
mocker,
mock_session,
):
mock_result = form_factory.form_public_factory(name="test 2", id=2)
mock = mocker.patch.object(
@@ -96,7 +111,12 @@ class TestForms:
2
)
def test_get_one_notfound(self, client, mocker, mock_session, mock_user):
def test_get_one_notfound(
self,
client,
mocker,
mock_session,
):
mock_result = None
mock = mocker.patch.object(
service,
@@ -104,14 +124,18 @@ class TestForms:
return_value=mock_result
)
response = client.get('/api/forms/2')
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2
)
def test_create_one(self, client, mocker, mock_session, mock_user):
def test_create_one(
self,
client,
mocker,
mock_session,
):
form_body = form_factory.form_body_factory(name='test form create')
form_create = form_factory.form_create_factory(name='test form create')
form_result = form_factory.form_public_factory(name='test form create')
@@ -133,7 +157,11 @@ class TestForms:
)
def test_create_one_referer_notfound(
self, client, mocker, mock_session, mock_user):
self,
client,
mocker,
mock_session,
):
form_body = form_factory.form_body_factory(
name='test form create', referer_id=12312)
form_create = form_factory.form_create_factory(
@@ -144,8 +172,6 @@ class TestForms:
messages.Messages.not_found('referer')))
response = client.post('/api/forms', json=form_body)
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
@@ -153,7 +179,11 @@ class TestForms:
)
def test_create_one_productor_notfound(
self, client, mocker, mock_session, mock_user):
self,
client,
mocker,
mock_session,
):
form_body = form_factory.form_body_factory(
name='test form create', productor_id=1231)
form_create = form_factory.form_create_factory(
@@ -164,7 +194,6 @@ class TestForms:
messages.Messages.not_found('productor')))
response = client.post('/api/forms', json=form_body)
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
@@ -173,11 +202,10 @@ class TestForms:
)
def test_create_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
form_body = form_factory.form_body_factory(name='test form create')
@@ -192,7 +220,12 @@ class TestForms:
app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user):
def test_update_one(
self,
client,
mocker,
mock_session,
):
form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update')
form_result = form_factory.form_public_factory(name='test form update')
@@ -215,11 +248,11 @@ class TestForms:
)
def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
mock_session,
):
form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update')
@@ -228,7 +261,6 @@ class TestForms:
messages.Messages.not_found('form')))
response = client.put('/api/forms/2', json=form_body)
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
@@ -238,7 +270,11 @@ class TestForms:
)
def test_update_one_referer_notfound(
self, client, mocker, mock_session, mock_user):
self,
client,
mocker,
mock_session,
):
form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update')
@@ -247,7 +283,6 @@ class TestForms:
messages.Messages.not_found('referer')))
response = client.put('/api/forms/2', json=form_body)
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
@@ -257,7 +292,11 @@ class TestForms:
)
def test_update_one_productor_notfound(
self, client, mocker, mock_session, mock_user):
self,
client,
mocker,
mock_session,
):
form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update')
@@ -266,7 +305,6 @@ class TestForms:
messages.Messages.not_found('productor')))
response = client.put('/api/forms/2', json=form_body)
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
@@ -276,11 +314,10 @@ class TestForms:
)
def test_update_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
form_body = form_factory.form_body_factory(name='test form update')
@@ -295,7 +332,12 @@ class TestForms:
app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user):
def test_delete_one(
self,
client,
mocker,
mock_session,
):
form_result = form_factory.form_public_factory(name='test form delete')
mock = mocker.patch.object(
@@ -315,19 +357,19 @@ class TestForms:
)
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
form_result = None
self,
client,
mocker,
mock_session,
):
mock = mocker.patch.object(
service, 'delete_one', side_effect=forms_exceptions.FormNotFoundError(
messages.Messages.not_found('form')))
service,
'delete_one',
side_effect=forms_exceptions.FormNotFoundError(
messages.Messages.not_found('form'))
)
response = client.delete('/api/forms/2')
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
@@ -336,11 +378,10 @@ class TestForms:
)
def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)

View File

@@ -1,15 +1,19 @@
import src.messages as messages
import src.productors.exceptions as exceptions
import src.productors.service as service
import tests.factories.productors as productor_factory
from fastapi.exceptions import HTTPException
from src import models
from src import messages
from src.auth.auth import get_current_user
from src.main import app
from src.productors import exceptions, service
class TestProductors:
def test_get_all(self, client, mocker, mock_session, mock_user):
def test_get_all(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [
productor_factory.productor_public_factory(name="test 1", id=1),
productor_factory.productor_public_factory(name="test 2", id=2),
@@ -33,7 +37,13 @@ class TestProductors:
[],
)
def test_get_all_filters(self, client, mocker, mock_session, mock_user):
def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [
productor_factory.productor_public_factory(name="test 2", id=2),
]
@@ -57,11 +67,10 @@ class TestProductors:
)
def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
@@ -75,10 +84,22 @@ class TestProductors:
app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user):
def test_get_one(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = productor_factory.productor_public_factory(
name="test 2", id=2)
mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'get_one',
@@ -95,7 +116,18 @@ class TestProductors:
2
)
def test_get_one_notfound(self, client, mocker, mock_session, mock_user):
def test_get_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock_result = None
mock = mocker.patch.object(
service,
@@ -103,7 +135,6 @@ class TestProductors:
return_value=mock_result
)
response = client.get('/api/productors/2')
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
@@ -111,11 +142,10 @@ class TestProductors:
)
def test_get_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
@@ -129,7 +159,13 @@ class TestProductors:
app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user):
def test_create_one(
self,
client,
mocker,
mock_session,
mock_user,
):
productor_body = productor_factory.productor_body_factory(
name='test productor create')
productor_create = productor_factory.productor_create_factory(
@@ -137,6 +173,12 @@ class TestProductors:
productor_result = productor_factory.productor_public_factory(
name='test productor create')
mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'create_one',
@@ -154,11 +196,10 @@ class TestProductors:
)
def test_create_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
productor_body = productor_factory.productor_body_factory(
@@ -174,7 +215,13 @@ class TestProductors:
app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user):
def test_update_one(
self,
client,
mocker,
mock_session,
mock_user,
):
productor_body = productor_factory.productor_body_factory(
name='test productor update')
productor_update = productor_factory.productor_update_factory(
@@ -182,6 +229,12 @@ class TestProductors:
productor_result = productor_factory.productor_public_factory(
name='test productor update')
mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'update_one',
@@ -200,23 +253,34 @@ class TestProductors:
)
def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
mock_session,
mock_user,
):
productor_body = productor_factory.productor_body_factory(
name='test productor update')
name='test productor update',
)
productor_update = productor_factory.productor_update_factory(
name='test productor update')
productor_result = None
name='test productor update',
)
mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service, 'update_one', side_effect=exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')))
service,
'update_one',
side_effect=exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')
)
)
response = client.put('/api/productors/2', json=productor_body)
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
@@ -225,12 +289,12 @@ class TestProductors:
productor_update
)
def test_update_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
productor_body = productor_factory.productor_body_factory(
@@ -246,10 +310,23 @@ class TestProductors:
app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user):
def test_delete_one(
self,
client,
mocker,
mock_session,
mock_user,
):
productor_result = productor_factory.productor_public_factory(
name='test productor delete')
mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'delete_one',
@@ -267,19 +344,26 @@ class TestProductors:
)
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
productor_result = None
self,
client,
mocker,
mock_session,
mock_user,
):
mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service, 'delete_one', side_effect=exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')))
service,
'delete_one',
side_effect=exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')
)
)
response = client.delete('/api/productors/2')
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
@@ -288,16 +372,12 @@ class TestProductors:
)
def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
productor_body = productor_factory.productor_body_factory(
name='test productor delete')
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.productors.service.delete_one')

View File

@@ -1,14 +1,19 @@
import src.products.exceptions as exceptions
import src.products.service as service
import tests.factories.products as product_factory
from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user
from src.main import app
from src.products import exceptions
class TestProducts:
def test_get_all(self, client, mocker, mock_session, mock_user):
def test_get_all(
self,
client,
mocker,
mock_session,
mock_user
):
mock_results = [
product_factory.product_public_factory(name="test 1", id=1),
product_factory.product_public_factory(name="test 2", id=2),
@@ -33,7 +38,13 @@ class TestProducts:
[]
)
def test_get_all_filters(self, client, mocker, mock_session, mock_user):
def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user
):
mock_results = [
product_factory.product_public_factory(name="test 2", id=2),
]
@@ -60,8 +71,7 @@ class TestProducts:
self,
client,
mocker,
mock_session,
mock_user):
):
def unauthorized():
raise HTTPException(status_code=401)
@@ -75,7 +85,12 @@ class TestProducts:
app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user):
def test_get_one(
self,
client,
mocker,
mock_session,
):
mock_result = product_factory.product_public_factory(
name="test 2", id=2)
@@ -95,7 +110,12 @@ class TestProducts:
2
)
def test_get_one_notfound(self, client, mocker, mock_session, mock_user):
def test_get_one_notfound(
self,
client,
mocker,
mock_session,
):
mock_result = None
mock = mocker.patch.object(
service,
@@ -103,7 +123,6 @@ class TestProducts:
return_value=mock_result
)
response = client.get('/api/products/2')
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
@@ -111,11 +130,10 @@ class TestProducts:
)
def test_get_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
@@ -129,7 +147,12 @@ class TestProducts:
app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user):
def test_create_one(
self,
client,
mocker,
mock_session,
):
product_body = product_factory.product_body_factory(
name='test product create')
product_create = product_factory.product_create_factory(
@@ -154,11 +177,10 @@ class TestProducts:
)
def test_create_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
product_body = product_factory.product_body_factory(
@@ -174,13 +196,21 @@ class TestProducts:
app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user):
def test_update_one(
self,
client,
mocker,
mock_session,
):
product_body = product_factory.product_body_factory(
name='test product update')
name='test product update'
)
product_update = product_factory.product_update_factory(
name='test product update')
name='test product update'
)
product_result = product_factory.product_public_factory(
name='test product update')
name='test product update'
)
mock = mocker.patch.object(
service,
@@ -200,16 +230,17 @@ class TestProducts:
)
def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
mock_session,
):
product_body = product_factory.product_body_factory(
name='test product update')
name='test product update'
)
product_update = product_factory.product_update_factory(
name='test product update')
product_result = None
name='test product update'
)
mock = mocker.patch.object(
service,
@@ -218,7 +249,6 @@ class TestProducts:
)
response = client.put('/api/products/2', json=product_body)
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
@@ -228,11 +258,10 @@ class TestProducts:
)
def test_update_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
product_body = product_factory.product_body_factory(
@@ -248,7 +277,12 @@ class TestProducts:
app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user):
def test_delete_one(
self,
client,
mocker,
mock_session,
):
product_result = product_factory.product_public_factory(
name='test product delete')
@@ -269,13 +303,11 @@ class TestProducts:
)
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
product_result = None
self,
client,
mocker,
mock_session,
):
mock = mocker.patch.object(
service,
'delete_one',
@@ -283,7 +315,6 @@ class TestProducts:
)
response = client.delete('/api/products/2')
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
@@ -295,13 +326,9 @@ class TestProducts:
self,
client,
mocker,
mock_session,
mock_user):
):
def unauthorized():
raise HTTPException(status_code=401)
product_body = product_factory.product_body_factory(
name='test product delete')
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.products.service.delete_one')

View File

@@ -1,15 +1,20 @@
import src.messages as messages
import src.shipments.exceptions as exceptions
import src.shipments.service as service
import tests.factories.shipments as shipment_factory
from fastapi.exceptions import HTTPException
from src import models
from src import messages
from src.auth.auth import get_current_user
from src.main import app
class TestShipments:
def test_get_all(self, client, mocker, mock_session, mock_user):
def test_get_all(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [
shipment_factory.shipment_public_factory(name="test 1", id=1),
shipment_factory.shipment_public_factory(name="test 2", id=2),
@@ -34,7 +39,13 @@ class TestShipments:
[],
)
def test_get_all_filters(self, client, mocker, mock_session, mock_user):
def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [
shipment_factory.shipment_public_factory(name="test 2", id=2),
]
@@ -59,11 +70,10 @@ class TestShipments:
)
def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
@@ -77,7 +87,12 @@ class TestShipments:
app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user):
def test_get_one(
self,
client,
mocker,
mock_session,
):
mock_result = shipment_factory.shipment_public_factory(
name="test 2", id=2)
@@ -97,7 +112,12 @@ class TestShipments:
2
)
def test_get_one_notfound(self, client, mocker, mock_session, mock_user):
def test_get_one_notfound(
self,
client,
mocker,
mock_session,
):
mock_result = None
mock = mocker.patch.object(
service,
@@ -105,7 +125,6 @@ class TestShipments:
return_value=mock_result
)
response = client.get('/api/shipments/2')
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
@@ -113,11 +132,10 @@ class TestShipments:
)
def test_get_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
@@ -131,13 +149,21 @@ class TestShipments:
app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user):
def test_create_one(
self,
client,
mocker,
mock_session,
):
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment create')
name='test shipment create'
)
shipment_create = shipment_factory.shipment_create_factory(
name='test shipment create')
name='test shipment create'
)
shipment_result = shipment_factory.shipment_public_factory(
name='test shipment create')
name='test shipment create'
)
mock = mocker.patch.object(
service,
@@ -156,15 +182,15 @@ class TestShipments:
)
def test_create_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment create')
name='test shipment create'
)
app.dependency_overrides[get_current_user] = unauthorized
@@ -176,13 +202,21 @@ class TestShipments:
app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user):
def test_update_one(
self,
client,
mocker,
mock_session,
):
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment update')
name='test shipment update'
)
shipment_update = shipment_factory.shipment_update_factory(
name='test shipment update')
name='test shipment update'
)
shipment_result = shipment_factory.shipment_public_factory(
name='test shipment update')
name='test shipment update'
)
mock = mocker.patch.object(
service,
@@ -202,22 +236,23 @@ class TestShipments:
)
def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
mock_session,
):
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment update')
name='test shipment update'
)
shipment_update = shipment_factory.shipment_update_factory(
name='test shipment update')
name='test shipment update'
)
mock = mocker.patch.object(
service, 'update_one', side_effect=exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment')))
response = client.put('/api/shipments/2', json=shipment_body)
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
@@ -227,15 +262,15 @@ class TestShipments:
)
def test_update_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment update')
name='test shipment update'
)
app.dependency_overrides[get_current_user] = unauthorized
@@ -247,9 +282,15 @@ class TestShipments:
app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user):
def test_delete_one(
self,
client,
mocker,
mock_session,
):
shipment_result = shipment_factory.shipment_public_factory(
name='test shipment delete')
name='test shipment delete'
)
mock = mocker.patch.object(
service,
@@ -268,19 +309,20 @@ class TestShipments:
)
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
shipment_result = None
self,
client,
mocker,
mock_session,
):
mock = mocker.patch.object(
service, 'delete_one', side_effect=exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment')))
service,
'delete_one',
side_effect=exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment')
)
)
response = client.delete('/api/shipments/2')
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
@@ -289,15 +331,12 @@ class TestShipments:
)
def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment delete')
app.dependency_overrides[get_current_user] = unauthorized

View File

@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
#
# SPDX-License-Identifier: MIT

View File

@@ -26,6 +26,9 @@ export function ContractModal({ opened, onClose, handleSubmit }: ContractModalPr
});
const formSelect = useMemo(() => {
if (!allForms) {
return [];
}
return allForms?.map((form) => ({
value: String(form.id),
label: `${form.season} ${form.name}`,