Compare commits

8 Commits

Author SHA1 Message Date
Julien Aldon
e970bb683a fix a bug that could prevent user to selet their payment methods
All checks were successful
Deploy Amap / deploy (push) Successful in 1m52s
2026-03-06 11:59:02 +01:00
Julien Aldon
c27c7598b5 fix tests 2026-03-06 11:26:02 +01:00
b4b4fa7643 fix all pylint warnings, add tests (wip) fix recap 2026-03-06 00:00:01 +01:00
60812652cf Merge branch 'feat/permissions' of gitea.aldon.fr:Mop/amap into feature/export-recap 2026-03-05 20:58:05 +01:00
cb0235e19f fix contract recap 2026-03-05 20:58:00 +01:00
Julien Aldon
5c356f5802 fix header width order 2026-03-05 17:20:44 +01:00
Julien Aldon
ff19448991 add functionnal recap ready for tests 2026-03-05 17:17:23 +01:00
5e413b11e0 add permission check for form productor and product 2026-03-04 23:36:17 +01:00
44 changed files with 2032 additions and 732 deletions

View File

@@ -1,26 +0,0 @@
default_language_version:
python: python3.13
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-added-large-files
- id: trailing-whitespace
- id: check-ast
- id: check-builtin-literals
- id: check-docstring-first
- id: check-yaml
- id: check-toml
- id: mixed-line-ending
- id: end-of-file-fixer
- repo: local
hooks:
- id: check-pylint
name: check-pylint
entry: pylint -d R0801,R0903,W0511,W0603,C0103,R0902
language: system
types: [python]
pass_filenames: false
args:
- backend

View File

@@ -34,8 +34,6 @@ dependencies = [
"pytest", "pytest",
"pytest-cov", "pytest-cov",
"pytest-mock", "pytest-mock",
"autopep8",
"prek",
"pylint", "pylint",
] ]

View File

View File

@@ -1,20 +1,20 @@
from typing import Annotated
from fastapi import APIRouter, Security, HTTPException, Depends, Request, Cookie
from fastapi.responses import RedirectResponse, Response
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlmodel import Session, select
import jwt
from jwt import PyJWKClient
from src.settings import AUTH_URL, TOKEN_URL, JWKS_URL, ISSUER, LOGOUT_URL, settings
import src.users.service as service
from src.database import get_session
from src.models import UserCreate, User, UserPublic
import secrets import secrets
import requests from typing import Annotated
from urllib.parse import urlencode from urllib.parse import urlencode
import src.messages as messages
import jwt
import requests
import src.users.service as service
from fastapi import APIRouter, Cookie, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse, Response
from fastapi.security import HTTPBearer
from jwt import PyJWKClient
from sqlmodel import Session, select
from src import messages
from src.database import get_session
from src.models import User, UserCreate, UserPublic
from src.settings import (AUTH_URL, ISSUER, JWKS_URL, LOGOUT_URL, TOKEN_URL,
settings)
router = APIRouter(prefix='/auth') router = APIRouter(prefix='/auth')
@@ -77,7 +77,18 @@ def callback(code: str, session: Session = Depends(get_session)):
headers = { headers = {
'Content-Type': 'application/x-www-form-urlencoded' 'Content-Type': 'application/x-www-form-urlencoded'
} }
response = requests.post(TOKEN_URL, data=data, headers=headers) try:
response = requests.post(
TOKEN_URL,
data=data,
headers=headers,
timeout=10
)
except requests.exceptions.Timeout as error:
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('token')
) from error
if response.status_code != 200: if response.status_code != 200:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
@@ -98,7 +109,13 @@ def callback(code: str, session: Session = Depends(get_session)):
'client_secret': settings.keycloak_client_secret, 'client_secret': settings.keycloak_client_secret,
'refresh_token': token_data['refresh_token'], 'refresh_token': token_data['refresh_token'],
} }
res = requests.post(LOGOUT_URL, data=data) try:
requests.post(LOGOUT_URL, data=data, timeout=10)
except requests.exceptions.Timeout as error:
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('token')
) from error
resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true') resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true')
return resp return resp
roles = resource_access.get(settings.keycloak_client_id) roles = resource_access.get(settings.keycloak_client_id)
@@ -108,7 +125,13 @@ def callback(code: str, session: Session = Depends(get_session)):
'client_secret': settings.keycloak_client_secret, 'client_secret': settings.keycloak_client_secret,
'refresh_token': token_data['refresh_token'], 'refresh_token': token_data['refresh_token'],
} }
res = requests.post(LOGOUT_URL, data=data) try:
requests.post(LOGOUT_URL, data=data, timeout=10)
except requests.exceptions.Timeout as error:
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('token')
) from error
resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true') resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true')
return resp return resp
@@ -159,13 +182,16 @@ def verify_token(token: str):
leeway=60, leeway=60,
) )
return decoded return decoded
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError as error:
raise HTTPException(status_code=401,
detail=messages.Messages.tokenexipired)
except jwt.InvalidTokenError:
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail=messages.Messages.invalidtoken) detail=messages.Messages.tokenexipired
) from error
except jwt.InvalidTokenError as error:
raise HTTPException(
status_code=401,
detail=messages.Messages.invalidtoken
) from error
def get_current_user( def get_current_user(
@@ -173,26 +199,35 @@ def get_current_user(
session: Session = Depends(get_session)): session: Session = Depends(get_session)):
access_token = request.cookies.get('access_token') access_token = request.cookies.get('access_token')
if not access_token: if not access_token:
raise HTTPException(status_code=401, raise HTTPException(
detail=messages.Messages.notauthenticated) status_code=401,
detail=messages.Messages.notauthenticated
)
payload = verify_token(access_token) payload = verify_token(access_token)
if not payload: if not payload:
raise HTTPException(status_code=401, detail='aze') raise HTTPException(
status_code=401,
detail='aze'
)
email = payload.get('email') email = payload.get('email')
if not email: if not email:
raise HTTPException(status_code=401, raise HTTPException(
detail=messages.Messages.notauthenticated) status_code=401,
detail=messages.Messages.notauthenticated
)
user = session.exec(select(User).where(User.email == email)).first() user = session.exec(select(User).where(User.email == email)).first()
if not user: if not user:
raise HTTPException(status_code=401, raise HTTPException(
detail=messages.Messages.not_found('user')) status_code=401,
detail=messages.Messages.not_found('user')
)
return user return user
@router.post('/refresh') @router.post('/refresh')
def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None): def refresh_user_token(refresh_token: Annotated[str | None, Cookie()] = None):
refresh = refresh_token refresh = refresh_token
data = { data = {
'grant_type': 'refresh_token', 'grant_type': 'refresh_token',
@@ -203,7 +238,18 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
headers = { headers = {
'Content-Type': 'application/x-www-form-urlencoded' 'Content-Type': 'application/x-www-form-urlencoded'
} }
result = requests.post(TOKEN_URL, data=data, headers=headers) try:
result = requests.post(
TOKEN_URL,
data=data,
headers=headers,
timeout=10,
)
except requests.exceptions.Timeout as error:
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('token')
) from error
if result.status_code != 200: if result.status_code != 200:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
@@ -216,7 +262,7 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
key='access_token', key='access_token',
value=token_data['access_token'], value=token_data['access_token'],
httponly=True, httponly=True,
secure=True if settings.debug == False else True, secure=True if settings.debug is False else True,
samesite='strict', samesite='strict',
max_age=settings.max_age max_age=settings.max_age
) )
@@ -224,7 +270,7 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
key='refresh_token', key='refresh_token',
value=token_data['refresh_token'] or '', value=token_data['refresh_token'] or '',
httponly=True, httponly=True,
secure=True if settings.debug == False else True, secure=True if settings.debug is False else True,
samesite='strict', samesite='strict',
max_age=30 * 24 * settings.max_age max_age=30 * 24 * settings.max_age
) )
@@ -249,6 +295,6 @@ def me(user: UserPublic = Depends(get_current_user)):
'name': user.name, 'name': user.name,
'email': user.email, 'email': user.email,
'id': user.id, 'id': user.id,
'roles': [role.name for role in user.roles] 'roles': user.roles
} }
} }

View File

@@ -4,11 +4,10 @@ import zipfile
import src.contracts.service as service import src.contracts.service as service
import src.forms.service as form_service import src.forms.service as form_service
import src.messages as messages
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlmodel import Session from sqlmodel import Session
from src import models from src import messages, models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.contracts.generate_contract import (generate_html_contract, from src.contracts.generate_contract import (generate_html_contract,
generate_recap) generate_recap)
@@ -17,88 +16,6 @@ from src.database import get_session
router = APIRouter(prefix='/contracts') router = APIRouter(prefix='/contracts')
def compute_recurrent_prices(
products_quantities: list[dict],
nb_shipment: int
):
"""Compute price for recurrent products"""
result = 0
for product_quantity in products_quantities:
product = product_quantity['product']
quantity = product_quantity['quantity']
result += compute_product_price(product, quantity, nb_shipment)
return result
def compute_occasional_prices(occasionals: list[dict]):
"""Compute prices for occassional products"""
result = 0
for occasional in occasionals:
result += occasional['price']
return result
def compute_product_price(
product: models.Product,
quantity: int,
nb_shipment: int = 1
):
"""Compute price for a product"""
product_quantity_unit = (
1 if product.unit == models.Unit.KILO else 1000
)
final_quantity = (
quantity if product.price else quantity / product_quantity_unit
)
final_price = (
product.price if product.price else product.price_kg
)
return final_price * final_quantity * nb_shipment
def find_dict_in_list(lst, key, value):
"""Find the index of a dictionnary in a list of dictionnaries given a key
and a value.
"""
for i, dic in enumerate(lst):
if dic[key].id == value:
return i
return -1
def create_occasional_dict(contract_products: list[models.ContractProduct]):
"""Create a dictionnary of occasional products"""
result = []
for contract_product in contract_products:
existing_id = find_dict_in_list(
result,
'shipment',
contract_product.shipment.id
)
if existing_id < 0:
result.append({
'shipment': contract_product.shipment,
'price': compute_product_price(
contract_product.product,
contract_product.quantity
),
'products': [{
'product': contract_product.product,
'quantity': contract_product.quantity
}]
})
else:
result[existing_id]['products'].append({
'product': contract_product.product,
'quantity': contract_product.quantity
})
result[existing_id]['price'] += compute_product_price(
contract_product.product,
contract_product.quantity
)
return result
@router.post('') @router.post('')
async def create_contract( async def create_contract(
contract: models.ContractCreate, contract: models.ContractCreate,
@@ -114,7 +31,7 @@ async def create_contract(
new_contract.products new_contract.products
) )
) )
occasionals = create_occasional_dict(occasional_contract_products) occasionals = service.create_occasional_dict(occasional_contract_products)
recurrents = list( recurrents = list(
map( map(
lambda x: {'product': x.product, 'quantity': x.quantity}, lambda x: {'product': x.product, 'quantity': x.quantity},
@@ -127,11 +44,13 @@ async def create_contract(
) )
) )
) )
recurrent_price = compute_recurrent_prices( prices = service.generate_products_prices(
occasionals,
recurrents, recurrents,
len(new_contract.form.shipments) new_contract.form.shipments
) )
price = recurrent_price + compute_occasional_prices(occasionals) recurrent_price = prices['recurrent']
total_price = prices['total']
cheques = list( cheques = list(
map( map(
lambda x: {'name': x.name, 'value': x.value}, lambda x: {'name': x.name, 'value': x.value},
@@ -145,7 +64,7 @@ async def create_contract(
occasionals, occasionals,
recurrents, recurrents,
'{:10.2f}'.format(recurrent_price), '{:10.2f}'.format(recurrent_price),
'{:10.2f}'.format(price) '{:10.2f}'.format(total_price)
) )
pdf_file = io.BytesIO(pdf_bytes) pdf_file = io.BytesIO(pdf_bytes)
contract_id = ( contract_id = (
@@ -154,7 +73,8 @@ async def create_contract(
f'{new_contract.form.productor.type}_' f'{new_contract.form.productor.type}_'
f'{new_contract.form.season}' f'{new_contract.form.season}'
) )
service.add_contract_file(session, new_contract.id, pdf_bytes, price) service.add_contract_file(
session, new_contract.id, pdf_bytes, total_price)
except Exception as error: except Exception as error:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
@@ -283,7 +203,7 @@ def get_contract_files(
user: models.User = Depends(get_current_user) user: models.User = Depends(get_current_user)
): ):
"""Get all contract files for a given form""" """Get all contract files for a given form"""
if not form_service.is_allowed(session, user, form_id): if not service.is_allowed(session, user, form_id):
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
detail=messages.Messages.not_allowed('contracts', 'get') detail=messages.Messages.not_allowed('contracts', 'get')
@@ -329,12 +249,13 @@ def get_contract_recap(
) )
form = form_service.get_one(session, form_id=form_id) form = form_service.get_one(session, form_id=form_id)
contracts = service.get_all(session, user, forms=[form.name]) contracts = service.get_all(session, user, forms=[form.name])
filename = f'{form.name}_recapitulatif_contrats.ods'
return StreamingResponse( return StreamingResponse(
io.BytesIO(generate_recap(contracts, form)), io.BytesIO(generate_recap(contracts, form)),
media_type='application/zip', media_type='application/vnd.oasis.opendocument.spreadsheet',
headers={ headers={
'Content-Disposition': ( 'Content-Disposition': (
'attachment; filename=filename.ods' f'attachment; filename={filename}'
) )
} }
) )

View File

@@ -2,12 +2,12 @@
import html import html
import io import io
import pathlib import pathlib
import string
import jinja2 import jinja2
import odfdo import odfdo
# from odfdo import Cell, Document, Row, Style, Table
from odfdo.element import Element
from src import models from src import models
from src.contracts import service
from weasyprint import HTML from weasyprint import HTML
@@ -18,11 +18,24 @@ def generate_html_contract(
reccurents: list[dict], reccurents: list[dict],
recurrent_price: float | None = None, recurrent_price: float | None = None,
total_price: float | None = None total_price: float | None = None
): ) -> bytes:
"""Generate a html contract
Arguments:
contract(models.Contract): Contract source.
cheques(list[dict]): cheques formated in dict.
occasionals(list[dict]): occasional products.
reccurents(list[dict]): recurrent products.
recurrent_price(float | None = None): total price of recurent products.
total_price(float | None = Non): total price.
Return:
result(bytes): contract file in pdf as bytes.
"""
template_dir = pathlib.Path("./src/contracts/templates").resolve() template_dir = pathlib.Path("./src/contracts/templates").resolve()
template_loader = jinja2.FileSystemLoader(searchpath=template_dir) template_loader = jinja2.FileSystemLoader(searchpath=template_dir)
template_env = jinja2.Environment( template_env = jinja2.Environment(
loader=template_loader, autoescape=jinja2.select_autoescape(["html", "xml"])) loader=template_loader,
autoescape=jinja2.select_autoescape(["html", "xml"])
)
template_file = "layout.html" template_file = "layout.html"
template = template_env.get_template(template_file) template = template_env.get_template(template_file)
output_text = template.render( output_text = template.render(
@@ -65,13 +78,16 @@ def generate_html_contract(
def flatten(xss): def flatten(xss):
"""flatten a list of list.
"""
return [x for xs in xss for x in xs] return [x for xs in xss for x in xs]
def create_column_style_width(size: str) -> odfdo.Style: def create_column_style_width(size: str) -> odfdo.Style:
"""Create a table columm style for a given width. """Create a table columm style for a given width.
Paramenters: Paramenters:
size(str): size of the style (format <number><unit>) unit can be in, cm... see odfdo documentation. size(str): size of the style (format <number><unit>)
unit can be in, cm... see odfdo documentation.
Returns: Returns:
odfdo.Style with the correct column-width attribute. odfdo.Style with the correct column-width attribute.
""" """
@@ -85,7 +101,8 @@ def create_column_style_width(size: str) -> odfdo.Style:
def create_row_style_height(size: str) -> odfdo.Style: def create_row_style_height(size: str) -> odfdo.Style:
"""Create a table height style for a given height. """Create a table height style for a given height.
Paramenters: Paramenters:
size(str): size of the style (format <number><unit>) unit can be in, cm... see odfdo documentation. size(str): size of the style (format <number><unit>)
unit can be in, cm... see odfdo documentation.
Returns: Returns:
odfdo.Style with the correct column-height attribute. odfdo.Style with the correct column-height attribute.
""" """
@@ -96,73 +113,204 @@ def create_row_style_height(size: str) -> odfdo.Style:
) )
def create_center_cell_style(name: str = "centered-cell") -> odfdo.Style: def create_currency_style(name: str = 'currency-euro'):
"""Create a table currency style.
Paramenters:
name(str): name of the style (default to `currency-euro`).
Returns:
odfdo.Style with the correct column-height attribute.
"""
return odfdo.Element.from_tag( return odfdo.Element.from_tag(
f'<style:style style:name="{name}" style:family="table-cell">' f"""
'<style:table-cell-properties style:vertical-align="middle" fo:wrap-option="wrap"/>' <number:currency-style style:name="{name}">
'<style:paragraph-properties fo:text-align="center"/>' <number:number number:min-integer-digits="1"
'</style:style>' number:decimal-places="2"/>
<number:text> €</number:text>
</number:currency-style>"""
) )
def create_cell_style_with_font(name: str = "font", font_size="14pt", bold: bool = False) -> odfdo.Style: def create_cell_style(
name: str = "centered-cell",
font_size: str = '10pt',
bold: bool = False,
background_color: str = '#FFFFFF',
color: str = '#000000',
currency: bool = False,
) -> odfdo.Style:
"""Create a cell style
Paramenters:
name(str): name of the style (default to `centered-cell`).
font_size(str): font_size of the cell (default to `10pt`).
bold(str): is the text bold (default to `False`).
background_color(str): background_color of the cell
(default to `#FFFFFF`).
color(str): color of the text of the cell (default to `#000000`).
currency(str): is the cell a currency (default to `False`).
Returns:
odfdo.Style with the correct column-height attribute.
"""
bold_attr = """
fo:font-weight="bold"
style:font-weight-asian="bold"
style:font-weight-complex="bold"
""" if bold else ''
currency_attr = """
style:data-style-name="currency-euro">
""" if currency else ''
return odfdo.Element.from_tag( return odfdo.Element.from_tag(
f'<style:style style:name="{name}" style:family="table-cell" ' f"""<style:style style:name="{name}" style:family="table-cell"
f'xmlns:fo="urn:oasis:names:tc:opendocument:xmlns:xsl-fo-compatible:1.0">' {currency_attr}>
'<style:table-cell-properties style:vertical-align="middle" fo:wrap-option="wrap"/>' <style:table-cell-properties
f'<style:paragraph-properties fo:text-align="center" fo:font-size="{font_size}" ' fo:border="0.75pt solid #000000"
f'{"fo:font-weight=\"bold\"" if bold else ""}/>' style:vertical-align="middle"
'</style:style>' fo:wrap-option="wrap"
fo:background-color="{background_color}"/>
<style:paragraph-properties fo:text-align="center"/>
<style:text-properties
{bold_attr}
fo:font-size="{font_size}"
fo:color="{color}"/>
</style:style>"""
) )
def apply_center_cell_style(document: odfdo.Document, row: odfdo.Row): def apply_cell_style(
style = document.insert_style( document: odfdo.Document,
create_center_cell_style() table: odfdo.Table,
currency_cols: list[int]
):
"""Apply cell style
"""
document.insert_style(
style=create_currency_style(),
) )
for cell in row.get_cells(): header_style = document.insert_style(
create_cell_style(
name="header-cells",
bold=True,
font_size='12pt',
background_color="#3480eb",
color="#FFF"
)
)
body_style_even = document.insert_style(
create_cell_style(
name="body-style-even",
bold=False,
background_color="#e8eaed",
color="#000000",
)
)
body_style_odd = document.insert_style(
create_cell_style(
name="body-style-odd",
bold=False,
background_color="#FFFFFF",
color="#000000",
)
)
footer_style = document.insert_style(
create_cell_style(
name="footer-cells",
bold=True,
font_size='12pt',
)
)
body_style_even_currency = document.insert_style(
create_cell_style(
name="body-style-even-currency",
bold=False,
background_color="#e8eaed",
color="#000000",
currency=True,
)
)
body_style_odd_currency = document.insert_style(
create_cell_style(
name="body-style-odd-currency",
bold=False,
background_color="#FFFFFF",
color="#000000",
currency=True,
)
)
footer_style_currency = document.insert_style(
create_cell_style(
name="footer-cells-currency",
bold=True,
font_size='12pt',
currency=True,
)
)
for index, row in enumerate(table.get_rows()):
style = body_style_even
currency_style = body_style_even_currency
if index == 0 or index == 1:
style = header_style
elif index == len(table.get_rows()) - 1:
style = footer_style
currency_style = footer_style_currency
elif index % 2 == 0:
style = body_style_even
currency_style = body_style_even_currency
else:
style = body_style_odd
currency_style = body_style_odd_currency
for cell_index, cell in enumerate(row.get_cells()):
if cell_index in currency_cols and not (index == 0 or index == 1):
cell.style = currency_style
else:
cell.style = style cell.style = style
def apply_column_height_style(document: odfdo.Document, row: odfdo.Row, height: str): def apply_column_height_style(
style = document.insert_style( document: odfdo.Document,
style=create_row_style_height(height), name=height, automatic=True table: odfdo.Table
):
"""Apply column height for a given table
"""
header_style = document.insert_style(
style=create_row_style_height('1.60cm'), name='1.60cm', automatic=True
) )
row.style = style body_style = document.insert_style(
style=create_row_style_height('0.90cm'), name='0.90cm', automatic=True
def apply_font_style(document: odfdo.Document, table: odfdo.Table, size: str = "14pt"):
style_header = document.insert_style(
style=create_cell_style_with_font(
'header_font', font_size=size, bold=True
) )
) for index, row in enumerate(table.get_rows()):
if index == 1:
style_body = document.insert_style( row.style = header_style
style=create_cell_style_with_font( else:
'body_font', font_size=size, bold=False row.style = body_style
)
)
for position in range(table.height):
row = table.get_row(position)
for cell in row.get_cells():
cell.style = style_header if position == 0 or position == 1 else style_body
for paragraph in cell.get_paragraphs():
paragraph.style = cell.style
def apply_column_width_style(document: odfdo.Document, table: odfdo.Table, widths: list[str]): def apply_cell_style_by_column(
table: odfdo.Table,
style: odfdo.Style,
col_index: int
):
"""Apply cell style for a given table
"""
for cell in table.get_column_cells(col_index):
cell.style = style
def apply_column_width_style(
document: odfdo.Document,
table: odfdo.Table,
widths: list[str]
):
"""Apply column width style to a table. """Apply column width style to a table.
Parameters: Parameters:
document(odfdo.Document): Document where the table is located. document(odfdo.Document): Document where the table is located.
table(odfdo.Table): Table to apply columns widths. table(odfdo.Table): Table to apply columns widths.
widths(list[str]): list of width in format <number><unit> unit ca be in, cm... see odfdo documentation. widths(list[str]): list of width in format <number><unit> unit ca be
in, cm... see odfdo documentation.
""" """
styles = [] styles = []
for w in widths: for w in widths:
styles.append(document.insert_style( styles.append(document.insert_style(
style=create_column_style_width(w), name=w, automatic=True)) style=create_column_style_width(w), name=w, automatic=True)
)
for position in range(table.width): for position in range(table.width):
col = table.get_column(position) col = table.get_column(position)
@@ -170,75 +318,260 @@ def apply_column_width_style(document: odfdo.Document, table: odfdo.Table, width
table.set_column(position, col) table.set_column(position, col)
def generate_recap( def generate_ods_letters(n: int):
contracts: list[models.Contract], """Generate letters following excel format.
form: models.Form, Arguments:
n(int): `n` letters to generate.
Return:
result(list[str]): list of `n` letters that follow excel pattern.
"""
letters = string.ascii_lowercase
result = []
for i in range(n):
if i > len(letters) - 1:
letter = f'{letters[int(i / len(letters)) - 1]}'
letter += f'{letters[i % len(letters)]}'
result.append(letter)
continue
letter = letters[i]
result.append(letters[i])
return result
def compute_contract_prices(contract: models.Contract) -> dict:
"""Compute price for a give contract.
"""
occasional_contract_products = list(
filter(
lambda contract_product: (
contract_product.product.type == models.ProductType.OCCASIONAL
),
contract.products
)
)
occasionals_dict = service.create_occasional_dict(
occasional_contract_products)
recurrents_dict = list(
map(
lambda x: {'product': x.product, 'quantity': x.quantity},
filter(
lambda contract_product: (
contract_product.product.type ==
models.ProductType.RECCURENT
),
contract.products
)
)
)
prices = service.generate_products_prices(
occasionals_dict,
recurrents_dict,
contract.form.shipments
)
return prices
def transform_formula_cells(sheet: odfdo.Spreadsheet):
"""Transform cell value to a formula using odfdo.
"""
for row in sheet.get_rows():
for cell in row.get_cells():
if not cell.value or cell.get_attribute("office:value-type") == "float":
continue
if '=' in cell.value:
formula = cell.value
cell.clear()
cell.formula = formula
def merge_shipment_cells(
sheet: odfdo.Spreadsheet,
prefix_header: list[str],
recurrents: list[str],
occasionnals: list[str],
shipments: list[models.Shipment]
): ):
recurrents = [pr.name for pr in form.productor.products if pr.type == """Merge cells for shipment header.
models.ProductType.RECCURENT] """
recurrents.sort() index = len(prefix_header) + len(recurrents) + 1
occasionnals = [pr.name for pr in form.productor.products if pr.type ==
models.ProductType.OCCASIONAL]
occasionnals.sort()
shipments = form.shipments
occasionnals_header = [
occ for shipment in shipments for occ in occasionnals]
shipment_header = flatten(
[[f'{shipment.name} - {shipment.date.strftime('%Y-%m-%d')}'] + ["" * len(occasionnals)] for shipment in shipments])
product_unit_map = {
"1": "g",
"2": "kg",
"3": "p"
}
header = (
["Nom", "Email"] +
["Tarif panier", "Total Paniers", "Total à payer"] +
["Cheque 1", "Cheque 2", "Cheque 3"] +
[f"Total {len(shipments)} livraisons + produits occasionnels"] +
recurrents +
occasionnals_header +
["Remarques", "Nom"]
)
data = [
[""] * (9 + len(recurrents)) + shipment_header,
header,
*[
[
f'{contract.firstname} {contract.lastname}',
f'{contract.email}',
*[f'{pr.quantity} {product_unit_map[pr.product.unit]}' for pr in sorted(
contract.products, key=lambda x: x.product.name) if pr.product.type == models.ProductType.RECCURENT],
*[f'{pr.quantity} {product_unit_map[pr.product.unit]}' for pr in sorted(
contract.products, key=lambda x: x.product.name) if pr.product.type == models.ProductType.OCCASIONAL],
"",
f'{contract.firstname} {contract.lastname}',
] for contract in contracts
]
]
doc = odfdo.Document("spreadsheet")
sheet = doc.body.get_sheet(0)
sheet.name = 'Recap'
sheet.set_values(data)
apply_column_width_style(doc, doc.body.get_table(0), ["4cm"] * len(header))
apply_column_height_style(
doc,
doc.body.get_table(0).get_rows((1, 1))[0],
"1.20cm"
)
apply_center_cell_style(doc, doc.body.get_table(0).get_rows((1, 1))[0])
apply_font_style(doc, doc.body.get_table(0))
index = 9 + len(recurrents)
for _ in enumerate(shipments): for _ in enumerate(shipments):
startcol = index startcol = index
endcol = index+len(occasionnals) - 1 endcol = index+len(occasionnals) - 1
sheet.set_span((startcol, 0, endcol, 0), merge=True) sheet.set_span((startcol, 0, endcol, 0), merge=True)
index += len(occasionnals) index += len(occasionnals)
doc.body.append(sheet)
def generate_recap(
contracts: list[models.Contract],
form: models.Form,
):
"""Generate excel recap for a list of contracts.
"""
product_unit_map = {
'1': 'g',
'2': 'Kg',
'3': 'Piece'
}
recurrents = [
f'{pr.name}{f' - {pr.quantity}{pr.quantity_unit}'
if pr.quantity else ''} ({product_unit_map[pr.unit]})'
for pr in form.productor.products
if pr.type == models.ProductType.RECCURENT
]
recurrents.sort()
occasionnals = [
f'{pr.name}{f' - {pr.quantity}{pr.quantity_unit}'
if pr.quantity else ''} ({product_unit_map[pr.unit]})'
for pr in form.productor.products
if pr.type == models.ProductType.OCCASIONAL
]
occasionnals.sort()
shipments = form.shipments
occasionnals_header = [
occ for shipment in shipments for occ in occasionnals
]
info_header: list[str] = ['', 'Nom', 'Email']
cheque_header: list[str] = ['Cheque 1', 'Cheque 2', 'Cheque 3']
payment_header = (
cheque_header +
[f'Total {len(shipments)} livraisons + produits occasionnels']
)
prefix_header: list[str] = (
info_header +
payment_header
)
suffix_header: list[str] = [
'Total produits occasionnels',
'Remarques',
'Nom'
]
shipment_header = flatten([
[f'{shipment.name} - {shipment.date.strftime('%Y-%m-%d')}'] +
['' * len(occasionnals)] for shipment in shipments] +
[''] * len(suffix_header)
)
header: list[str] = (
prefix_header +
recurrents +
['Total produits récurrents'] +
occasionnals_header +
suffix_header
)
letters = generate_ods_letters(len(header))
payment_formula_letters = letters[
len(info_header):len(info_header) + len(payment_header)
]
recurent_formula_letters = letters[
len(info_header)+len(payment_formula_letters):
len(info_header)+len(payment_formula_letters)+len(recurrents) + 1
]
occasionnals_formula_letters = letters[
len(info_header)+len(payment_formula_letters) +
len(recurent_formula_letters):
len(info_header)+len(payment_formula_letters) +
len(recurent_formula_letters)+len(occasionnals_header) + 1
]
footer = (
['', 'Total contrats', ''] +
[f'=SUM({letter}3:{letter}{2+len(contracts)})'
for letter in payment_formula_letters] +
[f'=SUM({letter}3:{letter}{2+len(contracts)})'
for letter in recurent_formula_letters] +
[f'=SUM({letter}3:{letter}{2+len(contracts)})'
for letter in occasionnals_formula_letters]
)
main_data = []
for index, contract in enumerate(contracts):
prices = compute_contract_prices(contract)
occasionnal_sorted = sorted(
[
product for product in contract.products
if product.product.type == models.ProductType.OCCASIONAL
],
key=lambda x: (x.shipment.name, x.product.name)
)
recurrent_sorted = sorted(
[
product for product in contract.products
if product.product.type == models.ProductType.RECCURENT
],
key=lambda x: x.product.name
)
main_data.append([
f'{index + 1}',
f'{contract.firstname} {contract.lastname}',
f'{contract.email}',
*[float(contract.cheques[i].value)
if len(contract.cheques) > i
else ''
for i in range(3)],
prices['total'],
*[pr.quantity for pr in recurrent_sorted],
prices['recurrent'],
*[pr.quantity for pr in occasionnal_sorted],
prices['occasionnal'],
'',
f'{contract.firstname} {contract.lastname}',
])
data = [
[''] * (len(prefix_header) + len(recurrents) + 1) + shipment_header,
header,
*main_data,
footer
]
doc = odfdo.Document('spreadsheet')
sheet = doc.body.get_sheet(0)
sheet.name = 'Recap'
sheet.set_values(data)
if len(occasionnals) > 0:
merge_shipment_cells(
sheet,
prefix_header,
recurrents,
occasionnals,
shipments
)
transform_formula_cells(sheet)
apply_column_width_style(
doc,
doc.body.get_table(0),
['2cm'] +
['6cm'] * 2 +
['2.40cm'] * (len(payment_header) - 1) +
['4cm'] * len(recurrents) +
['4cm'] +
['4cm'] * (len(occasionnals_header) + 1) +
['4cm', '8cm', '6cm']
)
apply_column_height_style(
doc,
doc.body.get_table(0),
)
apply_cell_style(
doc,
doc.body.get_table(0),
[
3,
4,
5,
6,
len(info_header) + len(payment_header),
len(info_header) + len(payment_header) + 1 + len(occasionnals),
]
)
doc.body.append(sheet)
buffer = io.BytesIO() buffer = io.BytesIO()
doc.save('test.ods') doc.save(buffer)
return buffer.getvalue() return buffer.getvalue()

View File

@@ -166,3 +166,103 @@ def is_allowed(
.distinct() .distinct()
) )
return len(session.exec(statement).all()) > 0 return len(session.exec(statement).all()) > 0
def compute_recurrent_prices(
products_quantities: list[dict],
nb_shipment: int
):
"""Compute price for recurrent products"""
result = 0
for product_quantity in products_quantities:
product = product_quantity['product']
quantity = product_quantity['quantity']
result += compute_product_price(product, quantity, nb_shipment)
return result
def compute_occasional_prices(occasionals: list[dict]):
"""Compute prices for occassional products"""
result = 0
for occasional in occasionals:
result += occasional['price']
return result
def compute_product_price(
product: models.Product,
quantity: int,
nb_shipment: int = 1
):
"""Compute price for a product"""
product_quantity_unit = (
1 if product.unit == models.Unit.KILO else 1000
)
final_quantity = (
quantity if product.price else quantity / product_quantity_unit
)
final_price = (
product.price if product.price else product.price_kg
)
return final_price * final_quantity * nb_shipment
def find_dict_in_list(lst, key, value):
"""Find the index of a dictionnary in a list of dictionnaries given a key
and a value.
"""
for i, dic in enumerate(lst):
if dic[key].id == value:
return i
return -1
def create_occasional_dict(contract_products: list[models.ContractProduct]):
"""Create a dictionnary of occasional products"""
result = []
for contract_product in contract_products:
existing_id = find_dict_in_list(
result,
'shipment',
contract_product.shipment.id
)
if existing_id < 0:
result.append({
'shipment': contract_product.shipment,
'price': compute_product_price(
contract_product.product,
contract_product.quantity
),
'products': [{
'product': contract_product.product,
'quantity': contract_product.quantity
}]
})
else:
result[existing_id]['products'].append({
'product': contract_product.product,
'quantity': contract_product.quantity
})
result[existing_id]['price'] += compute_product_price(
contract_product.product,
contract_product.quantity
)
return result
def generate_products_prices(
occasionals: list[dict],
recurrents: list[dict],
shipments: list[models.ShipmentPublic]
):
recurrent_price = compute_recurrent_prices(
recurrents,
len(shipments)
)
occasional_price = compute_occasional_prices(occasionals)
price = recurrent_price + occasional_price
return {
'total': price,
'recurrent': recurrent_price,
'occasionnal': occasional_price
}

View File

@@ -1,9 +1,8 @@
import src.forms.exceptions as exceptions import src.forms.exceptions as exceptions
import src.forms.service as service import src.forms.service as service
import src.messages as messages
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session from sqlmodel import Session
from src import models from src import messages, models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session from src.database import get_session
@@ -32,7 +31,10 @@ async def get_forms_filtered(
@router.get('/{_id}', response_model=models.FormPublic) @router.get('/{_id}', response_model=models.FormPublic)
async def get_form(_id: int, session: Session = Depends(get_session)): async def get_form(
_id: int,
session: Session = Depends(get_session)
):
result = service.get_one(session, _id) result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException( raise HTTPException(
@@ -48,6 +50,11 @@ async def create_form(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, form=form):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('forms', 'update')
)
try: try:
form = service.create_one(session, form) form = service.create_one(session, form)
except exceptions.ProductorNotFoundError as error: except exceptions.ProductorNotFoundError as error:
@@ -61,10 +68,16 @@ async def create_form(
@router.put('/{_id}', response_model=models.FormPublic) @router.put('/{_id}', response_model=models.FormPublic)
async def update_form( async def update_form(
_id: int, form: models.FormUpdate, _id: int,
form: models.FormUpdate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('forms', 'update')
)
try: try:
result = service.update_one(session, _id, form) result = service.update_one(session, _id, form)
except exceptions.FormNotFoundError as error: except exceptions.FormNotFoundError as error:
@@ -82,6 +95,11 @@ async def delete_form(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('forms', 'delete')
)
try: try:
result = service.delete_one(session, _id) result = service.delete_one(session, _id)
except exceptions.FormNotFoundError as error: except exceptions.FormNotFoundError as error:

View File

@@ -1,8 +1,7 @@
import src.forms.exceptions as exceptions import src.forms.exceptions as exceptions
import src.messages as messages
from sqlalchemy import func from sqlalchemy import func
from sqlmodel import Session, select from sqlmodel import Session, select
from src import models from src import messages, models
def get_all( def get_all(
@@ -108,12 +107,27 @@ def delete_one(session: Session, _id: int) -> models.FormPublic:
return result return result
def is_allowed(session: Session, user: models.User, _id: int) -> bool: def is_allowed(
session: Session,
user: models.User,
_id: int = None,
form: models.FormCreate = None
) -> bool:
if not _id and not form:
return False
if not _id:
statement = (
select(models.Productor)
.where(models.Productor.id == form.productor_id)
)
productor = session.exec(statement).first()
return productor.type in [r.name for r in user.roles]
statement = ( statement = (
select(models.Form) select(models.Form)
.join( .join(
models.Productor, models.Productor,
models.Form.productor_id == models.Productor.id) models.Form.productor_id == models.Productor.id
)
.where(models.Form.id == _id) .where(models.Form.id == _id)
.where( .where(
models.Productor.type.in_( models.Productor.type.in_(

View File

@@ -1,11 +1,9 @@
import src.messages as messages
import src.productors.exceptions as exceptions
import src.productors.service as service
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session from sqlmodel import Session
from src import models from src import messages, models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session from src.database import get_session
from src.productors import exceptions, service
router = APIRouter(prefix='/productors') router = APIRouter(prefix='/productors')
@@ -26,6 +24,11 @@ def get_productor(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('productor', 'get')
)
result = service.get_one(session, _id) result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException( raise HTTPException(
@@ -41,6 +44,11 @@ def create_productor(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, productor=productor):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('productor', 'create')
)
try: try:
result = service.create_one(session, productor) result = service.create_one(session, productor)
except exceptions.ProductorCreateError as error: except exceptions.ProductorCreateError as error:
@@ -54,6 +62,11 @@ def update_productor(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('productor', 'update')
)
try: try:
result = service.update_one(session, _id, productor) result = service.update_one(session, _id, productor)
except exceptions.ProductorNotFoundError as error: except exceptions.ProductorNotFoundError as error:
@@ -67,6 +80,11 @@ def delete_productor(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('productor', 'delete')
)
try: try:
result = service.delete_one(session, _id) result = service.delete_one(session, _id)
except exceptions.ProductorNotFoundError as error: except exceptions.ProductorNotFoundError as error:

View File

@@ -1,7 +1,6 @@
import src.messages as messages
import src.productors.exceptions as exceptions
from sqlmodel import Session, select from sqlmodel import Session, select
from src import models from src import messages, models
from src.productors import exceptions
def get_all( def get_all(
@@ -50,9 +49,10 @@ def create_one(
def update_one( def update_one(
session: Session, session: Session,
id: int, _id: int,
productor: models.ProductorUpdate) -> models.ProductorPublic: productor: models.ProductorUpdate
statement = select(models.Productor).where(models.Productor.id == id) ) -> models.ProductorPublic:
statement = select(models.Productor).where(models.Productor.id == _id)
result = session.exec(statement) result = session.exec(statement)
new_productor = result.first() new_productor = result.first()
if not new_productor: if not new_productor:
@@ -81,8 +81,8 @@ def update_one(
return new_productor return new_productor
def delete_one(session: Session, id: int) -> models.ProductorPublic: def delete_one(session: Session, _id: int) -> models.ProductorPublic:
statement = select(models.Productor).where(models.Productor.id == id) statement = select(models.Productor).where(models.Productor.id == _id)
result = session.exec(statement) result = session.exec(statement)
productor = result.first() productor = result.first()
if not productor: if not productor:
@@ -92,3 +92,21 @@ def delete_one(session: Session, id: int) -> models.ProductorPublic:
session.delete(productor) session.delete(productor)
session.commit() session.commit()
return result return result
def is_allowed(
session: Session,
user: models.User,
_id: int = None,
productor: models.ProductorCreate = None
) -> bool:
if not _id and not productor:
return False
if not _id:
return productor.type in [r.name for r in user.roles]
statement = (
select(models.Productor)
.where(models.Productor.id == _id)
.where(models.Productor.type.in_([r.name for r in user.roles]))
.distinct()
)
return len(session.exec(statement).all()) > 0

View File

@@ -1,11 +1,10 @@
import src.messages as messages
import src.products.exceptions as exceptions
import src.products.service as service import src.products.service as service
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session from sqlmodel import Session
from src import models from src import messages, models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session from src.database import get_session
from src.products import exceptions
router = APIRouter(prefix='/products') router = APIRouter(prefix='/products')
@@ -27,13 +26,18 @@ def get_products(
) )
@router.get('/{id}', response_model=models.ProductPublic) @router.get('/{_id}', response_model=models.ProductPublic)
def get_product( def get_product(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.get_one(session, id) if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('product', 'create')
)
result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException(status_code=404, raise HTTPException(status_code=404,
detail=messages.Messages.not_found('product')) detail=messages.Messages.not_found('product'))
@@ -46,38 +50,68 @@ def create_product(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, product=product):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('product', 'create')
)
try: try:
result = service.create_one(session, product) result = service.create_one(session, product)
except exceptions.ProductCreateError as error: except exceptions.ProductCreateError as error:
raise HTTPException(status_code=400, detail=str(error)) raise HTTPException(
status_code=400,
detail=str(error)
) from error
except exceptions.ProductorNotFoundError as error: except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(
status_code=404,
detail=str(error)
) from error
return result return result
@router.put('/{id}', response_model=models.ProductPublic) @router.put('/{_id}', response_model=models.ProductPublic)
def update_product( def update_product(
id: int, product: models.ProductUpdate, _id: int, product: models.ProductUpdate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('product', 'update')
)
try: try:
result = service.update_one(session, id, product) result = service.update_one(session, _id, product)
except exceptions.ProductNotFoundError as error: except exceptions.ProductNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(
status_code=404,
detail=str(error)
) from error
except exceptions.ProductorNotFoundError as error: except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(
status_code=404,
detail=str(error)
) from error
return result return result
@router.delete('/{id}', response_model=models.ProductPublic) @router.delete('/{_id}', response_model=models.ProductPublic)
def delete_product( def delete_product(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('product', 'delete')
)
try: try:
result = service.delete_one(session, id) result = service.delete_one(session, _id)
except exceptions.ProductNotFoundError as error: except exceptions.ProductNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(
status_code=404,
detail=str(error)
) from error
return result return result

View File

@@ -1,7 +1,6 @@
import src.messages as messages
import src.products.exceptions as exceptions
from sqlmodel import Session, select from sqlmodel import Session, select
from src import models from src import messages, models
from src.products import exceptions
def get_all( def get_all(
@@ -27,13 +26,17 @@ def get_all(
return session.exec(statement.order_by(models.Product.name)).all() return session.exec(statement.order_by(models.Product.name)).all()
def get_one(session: Session, product_id: int) -> models.ProductPublic: def get_one(
session: Session,
product_id: int,
) -> models.ProductPublic:
return session.get(models.Product, product_id) return session.get(models.Product, product_id)
def create_one( def create_one(
session: Session, session: Session,
product: models.ProductCreate) -> models.ProductPublic: product: models.ProductCreate,
) -> models.ProductPublic:
if not product: if not product:
raise exceptions.ProductCreateError( raise exceptions.ProductCreateError(
messages.Messages.invalid_input( messages.Messages.invalid_input(
@@ -51,9 +54,10 @@ def create_one(
def update_one( def update_one(
session: Session, session: Session,
id: int, _id: int,
product: models.ProductUpdate) -> models.ProductPublic: product: models.ProductUpdate
statement = select(models.Product).where(models.Product.id == id) ) -> models.ProductPublic:
statement = select(models.Product).where(models.Product.id == _id)
result = session.exec(statement) result = session.exec(statement)
new_product = result.first() new_product = result.first()
if not new_product: if not new_product:
@@ -74,8 +78,11 @@ def update_one(
return new_product return new_product
def delete_one(session: Session, id: int) -> models.ProductPublic: def delete_one(
statement = select(models.Product).where(models.Product.id == id) session: Session,
_id: int
) -> models.ProductPublic:
statement = select(models.Product).where(models.Product.id == _id)
result = session.exec(statement) result = session.exec(statement)
product = result.first() product = result.first()
if not product: if not product:
@@ -85,3 +92,34 @@ def delete_one(session: Session, id: int) -> models.ProductPublic:
session.delete(product) session.delete(product)
session.commit() session.commit()
return result return result
def is_allowed(
session: Session,
user: models.User,
_id: int = None,
product: models.ProductCreate = None,
) -> bool:
if not _id and not product:
return False
if not _id:
statement = (
select(models.Product)
.join(
models.Productor,
models.Product.productor_id == models.Productor.id
)
.where(models.Product.id == product.productor_id)
)
productor = session.exec(statement).first()
return productor.type in [r.name for r in user.roles]
statement = (
select(models.Product)
.join(
models.Productor,
models.Product.productor_id == models.Productor.id
)
.where(models.Product.id == _id)
.where(models.Productor.type.in_([r.name for r in user.roles]))
.distinct()
)
return len(session.exec(statement).all()) > 0

View File

@@ -1,10 +1,9 @@
# pylint: disable=E1101 # pylint: disable=E1101
import datetime import datetime
import src.messages as messages
import src.shipments.exceptions as exceptions import src.shipments.exceptions as exceptions
from sqlmodel import Session, select from sqlmodel import Session, select
from src import models from src import messages, models
def get_all( def get_all(
@@ -127,3 +126,40 @@ def delete_one(session: Session, _id: int) -> models.ShipmentPublic:
session.delete(shipment) session.delete(shipment)
session.commit() session.commit()
return result return result
def is_allowed(
session: Session,
user: models.User,
_id: int = None,
shipment: models.ShipmentCreate = None,
):
if not _id and not shipment:
return False
if not _id:
statement = (
select(models.Shipment)
.join(
models.Form,
models.Shipment.form_id == models.Form.id
)
.where(models.Form.id == shipment.form_id)
)
form = session.exec(statement).first()
return form.productor.type in [r.name for r in user.roles]
statement = (
select(models.Shipment)
.join(
models.Form,
models.Shipment.form_id == models.Form.id
)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id
)
.where(models.Shipment.id == _id)
.where(models.Productor.type.in_([r.name for r in user.roles]))
.distinct()
)
return len(session.exec(statement).all()) > 0

View File

@@ -1,9 +1,8 @@
import src.messages as messages
import src.shipments.exceptions as exceptions import src.shipments.exceptions as exceptions
import src.shipments.service as service import src.shipments.service as service
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session from sqlmodel import Session
from src import models from src import messages, models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session from src.database import get_session
@@ -33,6 +32,11 @@ def get_shipment(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('shipment', 'get')
)
result = service.get_one(session, _id) result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException( raise HTTPException(
@@ -48,6 +52,11 @@ def create_shipment(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, shipment=shipment):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('shipment', 'create')
)
try: try:
result = service.create_one(session, shipment) result = service.create_one(session, shipment)
except exceptions.ShipmentCreateError as error: except exceptions.ShipmentCreateError as error:
@@ -62,6 +71,11 @@ def update_shipment(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('shipment', 'update')
)
try: try:
result = service.update_one(session, _id, shipment) result = service.update_one(session, _id, shipment)
except exceptions.ShipmentNotFoundError as error: except exceptions.ShipmentNotFoundError as error:
@@ -75,6 +89,12 @@ def delete_shipment(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('shipment', 'delete')
)
try: try:
result = service.delete_one(session, _id) result = service.delete_one(session, _id)
except exceptions.ShipmentNotFoundError as error: except exceptions.ShipmentNotFoundError as error:

View File

@@ -1,8 +1,7 @@
import src.messages as messages
import src.templates.service as service import src.templates.service as service
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session from sqlmodel import Session
from src import models from src import messages, models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session from src.database import get_session

View File

@@ -1,7 +1,6 @@
import src.messages as messages
import src.users.exceptions as exceptions import src.users.exceptions as exceptions
from sqlmodel import Session, select from sqlmodel import Session, select
from src import models from src import messages, models
def get_all( def get_all(
@@ -48,7 +47,8 @@ def get_or_create_user(session: Session, user_create: models.UserCreate):
user = session.exec(statement).first() user = session.exec(statement).first()
if user: if user:
user_role_names = [r.name for r in user.roles] user_role_names = [r.name for r in user.roles]
if user_role_names != user_create.role_names or user.name != user_create.name: if (user_role_names != user_create.role_names or
user.name != user_create.name):
user = update_one(session, user.id, user_create) user = update_one(session, user.id, user_create)
return user return user
user = create_one(session, user_create) user = create_one(session, user_create)
@@ -56,7 +56,9 @@ def get_or_create_user(session: Session, user_create: models.UserCreate):
def get_roles(session: Session): def get_roles(session: Session):
statement = select(models.ContractType) statement = (
select(models.ContractType)
)
return session.exec(statement.order_by(models.ContractType.name)).all() return session.exec(statement.order_by(models.ContractType.name)).all()
@@ -64,7 +66,9 @@ def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
if user is None: if user is None:
raise exceptions.UserCreateError( raise exceptions.UserCreateError(
messages.Messages.invalid_input( messages.Messages.invalid_input(
'user', 'input cannot be None')) 'user', 'input cannot be None'
)
)
new_user = models.User( new_user = models.User(
name=user.name, name=user.name,
email=user.email email=user.email
@@ -81,17 +85,19 @@ def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
def update_one( def update_one(
session: Session, session: Session,
id: int, _id: int,
user: models.UserCreate) -> models.UserPublic: user: models.UserCreate) -> models.UserPublic:
if user is None: if user is None:
raise exceptions.UserCreateError( raise exceptions.UserCreateError(
messages.s.invalid_input( messages.Messages.invalid_input(
'user', 'input cannot be None')) 'user', 'input cannot be None'
statement = select(models.User).where(models.User.id == id) )
)
statement = select(models.User).where(models.User.id == _id)
result = session.exec(statement) result = session.exec(statement)
new_user = result.first() new_user = result.first()
if not new_user: if not new_user:
raise exceptions.UserNotFoundError(f'User {id} not found') raise exceptions.UserNotFoundError(f'User {_id} not found')
new_user.email = user.email new_user.email = user.email
new_user.name = user.name new_user.name = user.name
@@ -103,13 +109,19 @@ def update_one(
return new_user return new_user
def delete_one(session: Session, id: int) -> models.UserPublic: def delete_one(session: Session, _id: int) -> models.UserPublic:
statement = select(models.User).where(models.User.id == id) statement = select(models.User).where(models.User.id == _id)
result = session.exec(statement) result = session.exec(statement)
user = result.first() user = result.first()
if not user: if not user:
raise exceptions.UserNotFoundError(f'User {id} not found') raise exceptions.UserNotFoundError(f'User {_id} not found')
result = models.UserPublic.model_validate(user) result = models.UserPublic.model_validate(user)
session.delete(user) session.delete(user)
session.commit() session.commit()
return result return result
def is_allowed(
logged_user: models.User,
):
return len(logged_user.roles) >= 5

View File

@@ -1,9 +1,8 @@
import src.messages as messages
import src.users.exceptions as exceptions import src.users.exceptions as exceptions
import src.users.service as service import src.users.service as service
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session from sqlmodel import Session
from src import models from src import messages, models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session from src.database import get_session
@@ -13,7 +12,7 @@ router = APIRouter(prefix='/users')
@router.get('', response_model=list[models.UserPublic]) @router.get('', response_model=list[models.UserPublic])
def get_users( def get_users(
session: Session = Depends(get_session), session: Session = Depends(get_session),
user: models.User = Depends(get_current_user), _: models.User = Depends(get_current_user),
names: list[str] = Query([]), names: list[str] = Query([]),
emails: list[str] = Query([]), emails: list[str] = Query([]),
): ):
@@ -29,19 +28,31 @@ def get_roles(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(user):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('roles', 'get all')
)
return service.get_roles(session) return service.get_roles(session)
@router.get('/{id}', response_model=models.UserPublic) @router.get('/{_id}', response_model=models.UserPublic)
def get_users( def get_user(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.get_one(session, id) if not service.is_allowed(user):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('user', 'get')
)
result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException(status_code=404, raise HTTPException(
detail=messages.Messages.not_found('user')) status_code=404,
detail=messages.Messages.not_found('user')
)
return result return result
@@ -51,37 +62,59 @@ def create_user(
logged_user: models.User = Depends(get_current_user), logged_user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(logged_user):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('user', 'create')
)
try: try:
user = service.create_one(session, user) user = service.create_one(session, user)
except exceptions.UserCreateError as error: except exceptions.UserCreateError as error:
raise HTTPException(status_code=400, detail=str(error)) raise HTTPException(
status_code=400,
detail=str(error)
) from error
return user return user
@router.put('/{id}', response_model=models.UserPublic) @router.put('/{_id}', response_model=models.UserPublic)
def update_user( def update_user(
id: int, _id: int,
user: models.UserUpdate, user: models.UserUpdate,
logged_user: models.User = Depends(get_current_user), logged_user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(logged_user):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('user', 'update')
)
try: try:
result = service.update_one(session, id, user) result = service.update_one(session, _id, user)
except exceptions.UserNotFoundError as error: except exceptions.UserNotFoundError as error:
raise HTTPException(status_code=404, raise HTTPException(
detail=messages.Messages.not_found('user')) status_code=404,
detail=messages.Messages.not_found('user')
) from error
return result return result
@router.delete('/{id}', response_model=models.UserPublic) @router.delete('/{_id}', response_model=models.UserPublic)
def delete_user( def delete_user(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(user):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('user', 'delete')
)
try: try:
result = service.delete_one(session, id) result = service.delete_one(session, _id)
except exceptions.UserNotFoundError as error: except exceptions.UserNotFoundError as error:
raise HTTPException(status_code=404, raise HTTPException(
detail=messages.Messages.not_found('user')) status_code=404,
detail=messages.Messages.not_found('user')
) from error
return result return result

Binary file not shown.

View File

@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
#
# SPDX-License-Identifier: MIT

View File

@@ -17,19 +17,19 @@ from src import models
@pytest.fixture @pytest.fixture
def productor(session: Session) -> models.ProductorPublic: def productor(session: Session) -> models.ProductorPublic:
productor = productors_service.create_one( result = productors_service.create_one(
session, session,
productors_factory.productor_create_factory( productors_factory.productor_create_factory(
name='test productor', name='test productor',
type='Légumineuses', type='Légumineuses',
) )
) )
return productor return result
@pytest.fixture @pytest.fixture
def productors(session: Session) -> models.ProductorPublic: def productors(session: Session) -> models.ProductorPublic:
productors = [ result = [
productors_service.create_one( productors_service.create_one(
session, session,
productors_factory.productor_create_factory( productors_factory.productor_create_factory(
@@ -45,13 +45,15 @@ def productors(session: Session) -> models.ProductorPublic:
) )
) )
] ]
return productors return result
@pytest.fixture @pytest.fixture
def products(session: Session, def products(
productor: models.ProductorPublic) -> list[models.ProductPublic]: session: Session,
products = [ productor: models.ProductorPublic
) -> list[models.ProductPublic]:
result = [
products_service.create_one( products_service.create_one(
session, session,
products_factory.product_create_factory( products_factory.product_create_factory(
@@ -69,7 +71,7 @@ def products(session: Session,
) )
), ),
] ]
return products return result
@pytest.fixture @pytest.fixture
@@ -87,7 +89,7 @@ def user(session: Session) -> models.UserPublic:
@pytest.fixture @pytest.fixture
def users(session: Session) -> list[models.UserPublic]: def users(session: Session) -> list[models.UserPublic]:
users = [ result = [
users_service.create_one( users_service.create_one(
session, session,
users_factory.user_create_factory( users_factory.user_create_factory(
@@ -112,12 +114,12 @@ def users(session: Session) -> list[models.UserPublic]:
name='test user 3', name='test user 3',
email='test3@test.com', email='test3@test.com',
role_names=['Porc-Agneau']))] role_names=['Porc-Agneau']))]
return users return result
@pytest.fixture @pytest.fixture
def referer(session: Session) -> models.UserPublic: def referer(session: Session) -> models.UserPublic:
referer = users_service.create_one( result = users_service.create_one(
session, session,
users_factory.user_create_factory( users_factory.user_create_factory(
name='test referer', name='test referer',
@@ -125,14 +127,16 @@ def referer(session: Session) -> models.UserPublic:
role_names=['Légumineuses'], role_names=['Légumineuses'],
) )
) )
return referer return result
@pytest.fixture @pytest.fixture
def shipments(session: Session, def shipments(
session: Session,
forms: list[models.FormPublic], forms: list[models.FormPublic],
products: list[models.ProductPublic]): products: list[models.ProductPublic]
shipments = [ ):
result = [
shipments_service.create_one( shipments_service.create_one(
session, session,
shipments_factory.shipment_create_factory( shipments_factory.shipment_create_factory(
@@ -152,7 +156,7 @@ def shipments(session: Session,
) )
), ),
] ]
return shipments return result
@pytest.fixture @pytest.fixture
@@ -161,7 +165,7 @@ def forms(
productor: models.ProductorPublic, productor: models.ProductorPublic,
referer: models.UserPublic referer: models.UserPublic
) -> list[models.FormPublic]: ) -> list[models.FormPublic]:
forms = [ result = [
forms_service.create_one( forms_service.create_one(
session, session,
forms_factory.form_create_factory( forms_factory.form_create_factory(
@@ -181,4 +185,4 @@ def forms(
) )
) )
] ]
return forms return result

View File

@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
#
# SPDX-License-Identifier: MIT

View File

@@ -1,15 +1,18 @@
import src.contracts.service as service import src.contracts.service as service
import tests.factories.contract_products as contract_products_factory
import tests.factories.contracts as contract_factory import tests.factories.contracts as contract_factory
import tests.factories.forms as form_factory
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.main import app from src.main import app
class TestContracts: class TestContracts:
def test_get_all(self, client, mocker, mock_session, mock_user): def test_get_all(
self,
client,
mocker,
mock_session,
mock_user
):
mock_results = [ mock_results = [
contract_factory.contract_public_factory(id=1), contract_factory.contract_public_factory(id=1),
contract_factory.contract_public_factory(id=2), contract_factory.contract_public_factory(id=2),
@@ -32,7 +35,13 @@ class TestContracts:
[], [],
) )
def test_get_all_filters(self, client, mocker, mock_session, mock_user): def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user
):
mock_results = [ mock_results = [
contract_factory.contract_public_factory(id=2), contract_factory.contract_public_factory(id=2),
] ]
@@ -57,8 +66,7 @@ class TestContracts:
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
@@ -72,7 +80,13 @@ class TestContracts:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user): def test_get_one(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = contract_factory.contract_public_factory(id=2) mock_result = contract_factory.contract_public_factory(id=2)
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -80,7 +94,7 @@ class TestContracts:
'get_one', 'get_one',
return_value=mock_result return_value=mock_result
) )
mocker.patch.object( mock_is_allowed = mocker.patch.object(
service, service,
'is_allowed', 'is_allowed',
return_value=True return_value=True
@@ -94,33 +108,48 @@ class TestContracts:
mock_session, mock_session,
2 2
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
2
)
def test_get_one_notfound(self, client, mocker, mock_session, mock_user): def test_get_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = None mock_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'get_one', 'get_one',
return_value=mock_result return_value=mock_result
) )
mock_is_allowed = mocker.patch.object(
mocker.patch.object(
service, service,
'is_allowed', 'is_allowed',
return_value=True return_value=True
) )
response = client.get('/api/contracts/2') response = client.get('/api/contracts/2')
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
mock_session, mock_session,
2 2
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
2
)
def test_get_one_unauthorized( def test_get_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
@@ -134,7 +163,13 @@ class TestContracts:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user): def test_delete_one(
self,
client,
mocker,
mock_session,
mock_user,
):
contract_result = contract_factory.contract_public_factory() contract_result = contract_factory.contract_public_factory()
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -142,8 +177,7 @@ class TestContracts:
'delete_one', 'delete_one',
return_value=contract_result return_value=contract_result
) )
mock_is_allowed = mocker.patch.object(
mocker.patch.object(
service, service,
'is_allowed', 'is_allowed',
return_value=True return_value=True
@@ -156,13 +190,18 @@ class TestContracts:
mock_session, mock_session,
2, 2,
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
2
)
def test_delete_one_notfound( def test_delete_one_notfound(
self, self,
client, client,
mocker, mocker,
mock_session, mock_session,
mock_user mock_user,
): ):
contract_result = None contract_result = None
@@ -171,8 +210,7 @@ class TestContracts:
'delete_one', 'delete_one',
return_value=contract_result return_value=contract_result
) )
mock_is_allowed = mocker.patch.object(
mocker.patch.object(
service, service,
'is_allowed', 'is_allowed',
return_value=True return_value=True
@@ -185,13 +223,16 @@ class TestContracts:
mock_session, mock_session,
2, 2,
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
2
)
def test_delete_one_unauthorized( def test_delete_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session,
mock_user
): ):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)

View File

@@ -1,15 +1,20 @@
import src.forms.exceptions as forms_exceptions import src.forms.exceptions as forms_exceptions
import src.forms.service as service import src.forms.service as service
import src.messages as messages
import tests.factories.forms as form_factory import tests.factories.forms as form_factory
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from src import models from src import messages
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.main import app from src.main import app
class TestForms: class TestForms:
def test_get_all(self, client, mocker, mock_session, mock_user): def test_get_all(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [ mock_results = [
form_factory.form_public_factory(name="test 1", id=1), form_factory.form_public_factory(name="test 1", id=1),
form_factory.form_public_factory(name="test 2", id=2), form_factory.form_public_factory(name="test 2", id=2),
@@ -34,7 +39,13 @@ class TestForms:
mock_user, mock_user,
) )
def test_get_all_filters(self, client, mocker, mock_session, mock_user): def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [ mock_results = [
form_factory.form_public_factory(name="test 2", id=2), form_factory.form_public_factory(name="test 2", id=2),
] ]
@@ -62,8 +73,7 @@ class TestForms:
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
@@ -77,7 +87,13 @@ class TestForms:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user): def test_get_one(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = form_factory.form_public_factory(name="test 2", id=2) mock_result = form_factory.form_public_factory(name="test 2", id=2)
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -85,7 +101,6 @@ class TestForms:
'get_one', 'get_one',
return_value=mock_result return_value=mock_result
) )
response = client.get('/api/forms/2') response = client.get('/api/forms/2')
response_data = response.json() response_data = response.json()
@@ -95,8 +110,14 @@ class TestForms:
mock_session, mock_session,
2 2
) )
assert mock_user
def test_get_one_notfound(self, client, mocker, mock_session, mock_user): def test_get_one_notfound(
self,
client,
mocker,
mock_session,
):
mock_result = None mock_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
@@ -104,14 +125,19 @@ class TestForms:
return_value=mock_result return_value=mock_result
) )
response = client.get('/api/forms/2') response = client.get('/api/forms/2')
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
mock_session, mock_session,
2 2
) )
def test_create_one(self, client, mocker, mock_session, mock_user): def test_create_one(
self,
client,
mocker,
mock_session,
mock_user,
):
form_body = form_factory.form_body_factory(name='test form create') form_body = form_factory.form_body_factory(name='test form create')
form_create = form_factory.form_create_factory(name='test form create') form_create = form_factory.form_create_factory(name='test form create')
form_result = form_factory.form_public_factory(name='test form create') form_result = form_factory.form_public_factory(name='test form create')
@@ -121,6 +147,11 @@ class TestForms:
'create_one', 'create_one',
return_value=form_result return_value=form_result
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.post('/api/forms', json=form_body) response = client.post('/api/forms', json=form_body)
response_data = response.json() response_data = response.json()
@@ -131,53 +162,95 @@ class TestForms:
mock_session, mock_session,
form_create form_create
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
form=form_create
)
def test_create_one_referer_notfound( def test_create_one_referer_notfound(
self, client, mocker, mock_session, mock_user): self,
client,
mocker,
mock_session,
mock_user,
):
form_body = form_factory.form_body_factory( form_body = form_factory.form_body_factory(
name='test form create', referer_id=12312) name='test form create', referer_id=12312
)
form_create = form_factory.form_create_factory( form_create = form_factory.form_create_factory(
name='test form create', referer_id=12312) name='test form create', referer_id=12312
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, 'create_one', side_effect=forms_exceptions.UserNotFoundError( service,
messages.Messages.not_found('referer'))) 'create_one',
side_effect=forms_exceptions.UserNotFoundError(
messages.Messages.not_found('referer')
)
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.post('/api/forms', json=form_body)
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
form_create
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
form=form_create
)
def test_create_one_productor_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
form_body = form_factory.form_body_factory(
name='test form create', productor_id=1231
)
form_create = form_factory.form_create_factory(
name='test form create', productor_id=1231
)
mock = mocker.patch.object(
service,
'create_one',
side_effect=forms_exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')
)
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.post('/api/forms', json=form_body) response = client.post('/api/forms', json=form_body)
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
mock_session, mock_session,
form_create form_create
) )
mock_is_allowed.assert_called_once_with(
def test_create_one_productor_notfound(
self, client, mocker, mock_session, mock_user):
form_body = form_factory.form_body_factory(
name='test form create', productor_id=1231)
form_create = form_factory.form_create_factory(
name='test form create', productor_id=1231)
mock = mocker.patch.object(
service, 'create_one', side_effect=forms_exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')))
response = client.post('/api/forms', json=form_body)
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
mock_session, mock_session,
form_create mock_user,
form=form_create
) )
def test_create_one_unauthorized( def test_create_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
form_body = form_factory.form_body_factory(name='test form create') form_body = form_factory.form_body_factory(name='test form create')
@@ -192,7 +265,13 @@ class TestForms:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user): def test_update_one(
self,
client,
mocker,
mock_session,
mock_user,
):
form_body = form_factory.form_body_factory(name='test form update') form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update') form_update = form_factory.form_update_factory(name='test form update')
form_result = form_factory.form_public_factory(name='test form update') form_result = form_factory.form_public_factory(name='test form update')
@@ -202,6 +281,11 @@ class TestForms:
'update_one', 'update_one',
return_value=form_result return_value=form_result
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/forms/2', json=form_body) response = client.put('/api/forms/2', json=form_body)
response_data = response.json() response_data = response.json()
@@ -213,22 +297,36 @@ class TestForms:
2, 2,
form_update form_update
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_notfound( def test_update_one_notfound(
self, self,
client, client,
mocker, mocker,
mock_session, mock_session,
mock_user): mock_user,
):
form_body = form_factory.form_body_factory(name='test form update') form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update') form_update = form_factory.form_update_factory(name='test form update')
mock = mocker.patch.object( mock = mocker.patch.object(
service, 'update_one', side_effect=forms_exceptions.FormNotFoundError( service,
messages.Messages.not_found('form'))) 'update_one',
side_effect=forms_exceptions.FormNotFoundError(
messages.Messages.not_found('form')
)
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/forms/2', json=form_body) response = client.put('/api/forms/2', json=form_body)
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
@@ -236,18 +334,34 @@ class TestForms:
2, 2,
form_update form_update
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_referer_notfound( def test_update_one_referer_notfound(
self, client, mocker, mock_session, mock_user): self,
client,
mocker,
mock_session,
mock_user,
):
form_body = form_factory.form_body_factory(name='test form update') form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update') form_update = form_factory.form_update_factory(name='test form update')
mock = mocker.patch.object( mock = mocker.patch.object(
service, 'update_one', side_effect=forms_exceptions.UserNotFoundError( service, 'update_one', side_effect=forms_exceptions.UserNotFoundError(
messages.Messages.not_found('referer'))) messages.Messages.not_found('referer')
)
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/forms/2', json=form_body) response = client.put('/api/forms/2', json=form_body)
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
@@ -255,18 +369,36 @@ class TestForms:
2, 2,
form_update form_update
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_productor_notfound( def test_update_one_productor_notfound(
self, client, mocker, mock_session, mock_user): self,
client,
mocker,
mock_session,
mock_user,
):
form_body = form_factory.form_body_factory(name='test form update') form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update') form_update = form_factory.form_update_factory(name='test form update')
mock = mocker.patch.object( mock = mocker.patch.object(
service, 'update_one', side_effect=forms_exceptions.ProductorNotFoundError( service,
messages.Messages.not_found('productor'))) 'update_one',
side_effect=forms_exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')
)
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/forms/2', json=form_body) response = client.put('/api/forms/2', json=form_body)
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
@@ -274,13 +406,17 @@ class TestForms:
2, 2,
form_update form_update
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_unauthorized( def test_update_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
form_body = form_factory.form_body_factory(name='test form update') form_body = form_factory.form_body_factory(name='test form update')
@@ -295,7 +431,13 @@ class TestForms:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user): def test_delete_one(
self,
client,
mocker,
mock_session,
mock_user,
):
form_result = form_factory.form_public_factory(name='test form delete') form_result = form_factory.form_public_factory(name='test form delete')
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -303,6 +445,11 @@ class TestForms:
'delete_one', 'delete_one',
return_value=form_result return_value=form_result
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/forms/2') response = client.delete('/api/forms/2')
response_data = response.json() response_data = response.json()
@@ -313,34 +460,49 @@ class TestForms:
mock_session, mock_session,
2, 2,
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_notfound( def test_delete_one_notfound(
self, self,
client, client,
mocker, mocker,
mock_session, mock_session,
mock_user): mock_user,
form_result = None ):
mock = mocker.patch.object( mock = mocker.patch.object(
service, 'delete_one', side_effect=forms_exceptions.FormNotFoundError( service,
messages.Messages.not_found('form'))) 'delete_one',
side_effect=forms_exceptions.FormNotFoundError(
messages.Messages.not_found('form'))
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/forms/2') response = client.delete('/api/forms/2')
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
mock_session, mock_session,
2, 2,
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_unauthorized( def test_delete_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)

View File

@@ -1,15 +1,19 @@
import src.messages as messages
import src.productors.exceptions as exceptions
import src.productors.service as service
import tests.factories.productors as productor_factory import tests.factories.productors as productor_factory
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from src import models from src import messages
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.main import app from src.main import app
from src.productors import exceptions, service
class TestProductors: class TestProductors:
def test_get_all(self, client, mocker, mock_session, mock_user): def test_get_all(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [ mock_results = [
productor_factory.productor_public_factory(name="test 1", id=1), productor_factory.productor_public_factory(name="test 1", id=1),
productor_factory.productor_public_factory(name="test 2", id=2), productor_factory.productor_public_factory(name="test 2", id=2),
@@ -33,7 +37,13 @@ class TestProductors:
[], [],
) )
def test_get_all_filters(self, client, mocker, mock_session, mock_user): def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [ mock_results = [
productor_factory.productor_public_factory(name="test 2", id=2), productor_factory.productor_public_factory(name="test 2", id=2),
] ]
@@ -60,8 +70,7 @@ class TestProductors:
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
@@ -75,10 +84,22 @@ class TestProductors:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user): def test_get_one(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = productor_factory.productor_public_factory( mock_result = productor_factory.productor_public_factory(
name="test 2", id=2) name="test 2", id=2)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'get_one', 'get_one',
@@ -95,27 +116,49 @@ class TestProductors:
2 2
) )
def test_get_one_notfound(self, client, mocker, mock_session, mock_user): mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_get_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock_result = None mock_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'get_one', 'get_one',
return_value=mock_result return_value=mock_result
) )
response = client.get('/api/productors/2') response = client.get('/api/productors/2')
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
mock_session, mock_session,
2 2
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_get_one_unauthorized( def test_get_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
@@ -129,7 +172,13 @@ class TestProductors:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user): def test_create_one(
self,
client,
mocker,
mock_session,
mock_user,
):
productor_body = productor_factory.productor_body_factory( productor_body = productor_factory.productor_body_factory(
name='test productor create') name='test productor create')
productor_create = productor_factory.productor_create_factory( productor_create = productor_factory.productor_create_factory(
@@ -137,6 +186,12 @@ class TestProductors:
productor_result = productor_factory.productor_public_factory( productor_result = productor_factory.productor_public_factory(
name='test productor create') name='test productor create')
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'create_one', 'create_one',
@@ -152,13 +207,17 @@ class TestProductors:
mock_session, mock_session,
productor_create productor_create
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
productor=productor_create
)
def test_create_one_unauthorized( def test_create_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
productor_body = productor_factory.productor_body_factory( productor_body = productor_factory.productor_body_factory(
@@ -174,7 +233,13 @@ class TestProductors:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user): def test_update_one(
self,
client,
mocker,
mock_session,
mock_user,
):
productor_body = productor_factory.productor_body_factory( productor_body = productor_factory.productor_body_factory(
name='test productor update') name='test productor update')
productor_update = productor_factory.productor_update_factory( productor_update = productor_factory.productor_update_factory(
@@ -182,6 +247,12 @@ class TestProductors:
productor_result = productor_factory.productor_public_factory( productor_result = productor_factory.productor_public_factory(
name='test productor update') name='test productor update')
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'update_one', 'update_one',
@@ -199,24 +270,41 @@ class TestProductors:
productor_update productor_update
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_notfound( def test_update_one_notfound(
self, self,
client, client,
mocker, mocker,
mock_session, mock_session,
mock_user): mock_user,
):
productor_body = productor_factory.productor_body_factory( productor_body = productor_factory.productor_body_factory(
name='test productor update') name='test productor update',
)
productor_update = productor_factory.productor_update_factory( productor_update = productor_factory.productor_update_factory(
name='test productor update') name='test productor update',
productor_result = None )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, 'update_one', side_effect=exceptions.ProductorNotFoundError( service,
messages.Messages.not_found('productor'))) 'update_one',
side_effect=exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')
)
)
response = client.put('/api/productors/2', json=productor_body) response = client.put('/api/productors/2', json=productor_body)
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
@@ -225,12 +313,17 @@ class TestProductors:
productor_update productor_update
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_unauthorized( def test_update_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
productor_body = productor_factory.productor_body_factory( productor_body = productor_factory.productor_body_factory(
@@ -246,10 +339,22 @@ class TestProductors:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user): def test_delete_one(
self,
client,
mocker,
mock_session,
mock_user,
):
productor_result = productor_factory.productor_public_factory( productor_result = productor_factory.productor_public_factory(
name='test productor delete') name='test productor delete')
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'delete_one', 'delete_one',
@@ -265,21 +370,34 @@ class TestProductors:
mock_session, mock_session,
2, 2,
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_notfound( def test_delete_one_notfound(
self, self,
client, client,
mocker, mocker,
mock_session, mock_session,
mock_user): mock_user,
productor_result = None ):
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, 'delete_one', side_effect=exceptions.ProductorNotFoundError( service,
messages.Messages.not_found('productor'))) 'delete_one',
side_effect=exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')
)
)
response = client.delete('/api/productors/2') response = client.delete('/api/productors/2')
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
@@ -287,17 +405,19 @@ class TestProductors:
2, 2,
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_unauthorized( def test_delete_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
productor_body = productor_factory.productor_body_factory(
name='test productor delete')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.productors.service.delete_one') mock = mocker.patch('src.productors.service.delete_one')

View File

@@ -1,14 +1,19 @@
import src.products.exceptions as exceptions
import src.products.service as service import src.products.service as service
import tests.factories.products as product_factory import tests.factories.products as product_factory
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.main import app from src.main import app
from src.products import exceptions
class TestProducts: class TestProducts:
def test_get_all(self, client, mocker, mock_session, mock_user): def test_get_all(
self,
client,
mocker,
mock_session,
mock_user
):
mock_results = [ mock_results = [
product_factory.product_public_factory(name="test 1", id=1), product_factory.product_public_factory(name="test 1", id=1),
product_factory.product_public_factory(name="test 2", id=2), product_factory.product_public_factory(name="test 2", id=2),
@@ -33,7 +38,13 @@ class TestProducts:
[] []
) )
def test_get_all_filters(self, client, mocker, mock_session, mock_user): def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user
):
mock_results = [ mock_results = [
product_factory.product_public_factory(name="test 2", id=2), product_factory.product_public_factory(name="test 2", id=2),
] ]
@@ -60,8 +71,7 @@ class TestProducts:
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
@@ -75,10 +85,22 @@ class TestProducts:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user): def test_get_one(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = product_factory.product_public_factory( mock_result = product_factory.product_public_factory(
name="test 2", id=2) name="test 2", id=2)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'get_one', 'get_one',
@@ -94,28 +116,47 @@ class TestProducts:
mock_session, mock_session,
2 2
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_get_one_notfound(self, client, mocker, mock_session, mock_user): def test_get_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = None mock_result = None
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'get_one', 'get_one',
return_value=mock_result return_value=mock_result
) )
response = client.get('/api/products/2') response = client.get('/api/products/2')
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
mock_session, mock_session,
2 2
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_get_one_unauthorized( def test_get_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
@@ -129,7 +170,13 @@ class TestProducts:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user): def test_create_one(
self,
client,
mocker,
mock_session,
mock_user,
):
product_body = product_factory.product_body_factory( product_body = product_factory.product_body_factory(
name='test product create') name='test product create')
product_create = product_factory.product_create_factory( product_create = product_factory.product_create_factory(
@@ -137,6 +184,11 @@ class TestProducts:
product_result = product_factory.product_public_factory( product_result = product_factory.product_public_factory(
name='test product create') name='test product create')
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'create_one', 'create_one',
@@ -152,13 +204,17 @@ class TestProducts:
mock_session, mock_session,
product_create product_create
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
product=product_create
)
def test_create_one_unauthorized( def test_create_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
product_body = product_factory.product_body_factory( product_body = product_factory.product_body_factory(
@@ -174,14 +230,28 @@ class TestProducts:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user): def test_update_one(
self,
client,
mocker,
mock_session,
mock_user,
):
product_body = product_factory.product_body_factory( product_body = product_factory.product_body_factory(
name='test product update') name='test product update'
)
product_update = product_factory.product_update_factory( product_update = product_factory.product_update_factory(
name='test product update') name='test product update'
)
product_result = product_factory.product_public_factory( product_result = product_factory.product_public_factory(
name='test product update') name='test product update'
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'update_one', 'update_one',
@@ -199,18 +269,31 @@ class TestProducts:
product_update product_update
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_notfound( def test_update_one_notfound(
self, self,
client, client,
mocker, mocker,
mock_session, mock_session,
mock_user): mock_user,
):
product_body = product_factory.product_body_factory( product_body = product_factory.product_body_factory(
name='test product update') name='test product update'
)
product_update = product_factory.product_update_factory( product_update = product_factory.product_update_factory(
name='test product update') name='test product update'
product_result = None )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'update_one', 'update_one',
@@ -218,7 +301,6 @@ class TestProducts:
) )
response = client.put('/api/products/2', json=product_body) response = client.put('/api/products/2', json=product_body)
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
@@ -226,13 +308,17 @@ class TestProducts:
2, 2,
product_update product_update
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_unauthorized( def test_update_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
product_body = product_factory.product_body_factory( product_body = product_factory.product_body_factory(
@@ -248,7 +334,13 @@ class TestProducts:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user): def test_delete_one(
self,
client,
mocker,
mock_session,
mock_user,
):
product_result = product_factory.product_public_factory( product_result = product_factory.product_public_factory(
name='test product delete') name='test product delete')
@@ -258,6 +350,11 @@ class TestProducts:
return_value=product_result return_value=product_result
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/products/2') response = client.delete('/api/products/2')
response_data = response.json() response_data = response.json()
@@ -267,23 +364,31 @@ class TestProducts:
mock_session, mock_session,
2, 2,
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_notfound( def test_delete_one_notfound(
self, self,
client, client,
mocker, mocker,
mock_session, mock_session,
mock_user): mock_user,
product_result = None ):
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'delete_one', 'delete_one',
side_effect=exceptions.ProductNotFoundError('Product not found') side_effect=exceptions.ProductNotFoundError('Product not found')
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/products/2') response = client.delete('/api/products/2')
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
@@ -291,17 +396,19 @@ class TestProducts:
2, 2,
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_unauthorized( def test_delete_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
product_body = product_factory.product_body_factory(
name='test product delete')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.products.service.delete_one') mock = mocker.patch('src.products.service.delete_one')

View File

@@ -1,15 +1,20 @@
import src.messages as messages
import src.shipments.exceptions as exceptions import src.shipments.exceptions as exceptions
import src.shipments.service as service import src.shipments.service as service
import tests.factories.shipments as shipment_factory import tests.factories.shipments as shipment_factory
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from src import models from src import messages
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.main import app from src.main import app
class TestShipments: class TestShipments:
def test_get_all(self, client, mocker, mock_session, mock_user): def test_get_all(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [ mock_results = [
shipment_factory.shipment_public_factory(name="test 1", id=1), shipment_factory.shipment_public_factory(name="test 1", id=1),
shipment_factory.shipment_public_factory(name="test 2", id=2), shipment_factory.shipment_public_factory(name="test 2", id=2),
@@ -34,7 +39,13 @@ class TestShipments:
[], [],
) )
def test_get_all_filters(self, client, mocker, mock_session, mock_user): def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [ mock_results = [
shipment_factory.shipment_public_factory(name="test 2", id=2), shipment_factory.shipment_public_factory(name="test 2", id=2),
] ]
@@ -62,8 +73,7 @@ class TestShipments:
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
@@ -77,7 +87,13 @@ class TestShipments:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user): def test_get_one(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = shipment_factory.shipment_public_factory( mock_result = shipment_factory.shipment_public_factory(
name="test 2", id=2) name="test 2", id=2)
@@ -86,6 +102,11 @@ class TestShipments:
'get_one', 'get_one',
return_value=mock_result return_value=mock_result
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.get('/api/shipments/2') response = client.get('/api/shipments/2')
response_data = response.json() response_data = response.json()
@@ -96,28 +117,47 @@ class TestShipments:
mock_session, mock_session,
2 2
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_get_one_notfound(self, client, mocker, mock_session, mock_user): def test_get_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = None mock_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'get_one', 'get_one',
return_value=mock_result return_value=mock_result
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.get('/api/shipments/2') response = client.get('/api/shipments/2')
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
mock_session, mock_session,
2 2
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_get_one_unauthorized( def test_get_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
@@ -131,19 +171,33 @@ class TestShipments:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user): def test_create_one(
self,
client,
mocker,
mock_session,
mock_user,
):
shipment_body = shipment_factory.shipment_body_factory( shipment_body = shipment_factory.shipment_body_factory(
name='test shipment create') name='test shipment create'
)
shipment_create = shipment_factory.shipment_create_factory( shipment_create = shipment_factory.shipment_create_factory(
name='test shipment create') name='test shipment create'
)
shipment_result = shipment_factory.shipment_public_factory( shipment_result = shipment_factory.shipment_public_factory(
name='test shipment create') name='test shipment create'
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'create_one', 'create_one',
return_value=shipment_result return_value=shipment_result
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.post('/api/shipments', json=shipment_body) response = client.post('/api/shipments', json=shipment_body)
response_data = response.json() response_data = response.json()
@@ -154,17 +208,22 @@ class TestShipments:
mock_session, mock_session,
shipment_create shipment_create
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
shipment=shipment_create
)
def test_create_one_unauthorized( def test_create_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
shipment_body = shipment_factory.shipment_body_factory( shipment_body = shipment_factory.shipment_body_factory(
name='test shipment create') name='test shipment create'
)
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
@@ -176,19 +235,33 @@ class TestShipments:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user): def test_update_one(
self,
client,
mocker,
mock_session,
mock_user,
):
shipment_body = shipment_factory.shipment_body_factory( shipment_body = shipment_factory.shipment_body_factory(
name='test shipment update') name='test shipment update'
)
shipment_update = shipment_factory.shipment_update_factory( shipment_update = shipment_factory.shipment_update_factory(
name='test shipment update') name='test shipment update'
)
shipment_result = shipment_factory.shipment_public_factory( shipment_result = shipment_factory.shipment_public_factory(
name='test shipment update') name='test shipment update'
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'update_one', 'update_one',
return_value=shipment_result return_value=shipment_result
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/shipments/2', json=shipment_body) response = client.put('/api/shipments/2', json=shipment_body)
response_data = response.json() response_data = response.json()
@@ -200,24 +273,40 @@ class TestShipments:
2, 2,
shipment_update shipment_update
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_notfound( def test_update_one_notfound(
self, self,
client, client,
mocker, mocker,
mock_session, mock_session,
mock_user): mock_user,
):
shipment_body = shipment_factory.shipment_body_factory( shipment_body = shipment_factory.shipment_body_factory(
name='test shipment update') name='test shipment update'
)
shipment_update = shipment_factory.shipment_update_factory( shipment_update = shipment_factory.shipment_update_factory(
name='test shipment update') name='test shipment update'
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, 'update_one', side_effect=exceptions.ShipmentNotFoundError( service,
messages.Messages.not_found('shipment'))) 'update_one',
side_effect=exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment')
)
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/shipments/2', json=shipment_body) response = client.put('/api/shipments/2', json=shipment_body)
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
@@ -225,17 +314,22 @@ class TestShipments:
2, 2,
shipment_update shipment_update
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_unauthorized( def test_update_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
shipment_body = shipment_factory.shipment_body_factory( shipment_body = shipment_factory.shipment_body_factory(
name='test shipment update') name='test shipment update'
)
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
@@ -247,15 +341,27 @@ class TestShipments:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user): def test_delete_one(
self,
client,
mocker,
mock_session,
mock_user,
):
shipment_result = shipment_factory.shipment_public_factory( shipment_result = shipment_factory.shipment_public_factory(
name='test shipment delete') name='test shipment delete'
)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'delete_one', 'delete_one',
return_value=shipment_result return_value=shipment_result
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/shipments/2') response = client.delete('/api/shipments/2')
response_data = response.json() response_data = response.json()
@@ -266,38 +372,52 @@ class TestShipments:
mock_session, mock_session,
2, 2,
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_notfound( def test_delete_one_notfound(
self, self,
client, client,
mocker, mocker,
mock_session, mock_session,
mock_user): mock_user,
shipment_result = None ):
mock = mocker.patch.object( mock = mocker.patch.object(
service, 'delete_one', side_effect=exceptions.ShipmentNotFoundError( service,
messages.Messages.not_found('shipment'))) 'delete_one',
side_effect=exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment')
)
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/shipments/2') response = client.delete('/api/shipments/2')
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
mock_session, mock_session,
2, 2,
) )
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_unauthorized( def test_delete_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment delete')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized

View File

@@ -2,13 +2,18 @@ import src.users.exceptions as exceptions
import src.users.service as service import src.users.service as service
import tests.factories.users as user_factory import tests.factories.users as user_factory
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.main import app from src.main import app
class TestUsers: class TestUsers:
def test_get_all(self, client, mocker, mock_session, mock_user): def test_get_all(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [ mock_results = [
user_factory.user_public_factory(name="test 1", id=1), user_factory.user_public_factory(name="test 1", id=1),
user_factory.user_public_factory(name="test 2", id=2), user_factory.user_public_factory(name="test 2", id=2),
@@ -30,8 +35,15 @@ class TestUsers:
[], [],
[], [],
) )
assert mock_user
def test_get_all_filters(self, client, mocker, mock_session, mock_user): def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [ mock_results = [
user_factory.user_public_factory(name="test 2", id=2), user_factory.user_public_factory(name="test 2", id=2),
] ]
@@ -40,7 +52,6 @@ class TestUsers:
'get_all', 'get_all',
return_value=mock_results return_value=mock_results
) )
response = client.get('/api/users?emails=test@test.test&names=test 2') response = client.get('/api/users?emails=test@test.test&names=test 2')
response_data = response.json() response_data = response.json()
assert response.status_code == 200 assert response.status_code == 200
@@ -51,13 +62,13 @@ class TestUsers:
['test 2'], ['test 2'],
['test@test.test'], ['test@test.test'],
) )
assert mock_user
def test_get_all_unauthorized( def test_get_all_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
@@ -71,7 +82,13 @@ class TestUsers:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user): def test_get_one(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = user_factory.user_public_factory(name="test 2", id=2) mock_result = user_factory.user_public_factory(name="test 2", id=2)
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -79,6 +96,11 @@ class TestUsers:
'get_one', 'get_one',
return_value=mock_result return_value=mock_result
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.get('/api/users/2') response = client.get('/api/users/2')
response_data = response.json() response_data = response.json()
@@ -89,28 +111,43 @@ class TestUsers:
mock_session, mock_session,
2 2
) )
mock_is_allowed.assert_called_once_with(
mock_user
)
def test_get_one_notfound(self, client, mocker, mock_session, mock_user): def test_get_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = None mock_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'get_one', 'get_one',
return_value=mock_result return_value=mock_result
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.get('/api/users/2') response = client.get('/api/users/2')
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
mock_session, mock_session,
2 2
) )
mock_is_allowed.assert_called_once_with(
mock_user
)
def test_get_one_unauthorized( def test_get_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
@@ -124,7 +161,13 @@ class TestUsers:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user): def test_create_one(
self,
client,
mocker,
mock_session,
mock_user,
):
user_body = user_factory.user_body_factory(name='test user create') user_body = user_factory.user_body_factory(name='test user create')
user_create = user_factory.user_create_factory(name='test user create') user_create = user_factory.user_create_factory(name='test user create')
user_result = user_factory.user_public_factory(name='test user create') user_result = user_factory.user_public_factory(name='test user create')
@@ -134,6 +177,11 @@ class TestUsers:
'create_one', 'create_one',
return_value=user_result return_value=user_result
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.post('/api/users', json=user_body) response = client.post('/api/users', json=user_body)
response_data = response.json() response_data = response.json()
@@ -144,13 +192,15 @@ class TestUsers:
mock_session, mock_session,
user_create user_create
) )
mock_is_allowed.assert_called_once_with(
mock_user
)
def test_create_one_unauthorized( def test_create_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
user_body = user_factory.user_body_factory(name='test user create') user_body = user_factory.user_body_factory(name='test user create')
@@ -165,7 +215,13 @@ class TestUsers:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user): def test_update_one(
self,
client,
mocker,
mock_session,
mock_user,
):
user_body = user_factory.user_body_factory(name='test user update') user_body = user_factory.user_body_factory(name='test user update')
user_update = user_factory.user_update_factory(name='test user update') user_update = user_factory.user_update_factory(name='test user update')
user_result = user_factory.user_public_factory(name='test user update') user_result = user_factory.user_public_factory(name='test user update')
@@ -175,6 +231,11 @@ class TestUsers:
'update_one', 'update_one',
return_value=user_result return_value=user_result
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/users/2', json=user_body) response = client.put('/api/users/2', json=user_body)
response_data = response.json() response_data = response.json()
@@ -186,25 +247,32 @@ class TestUsers:
2, 2,
user_update user_update
) )
mock_is_allowed.assert_called_once_with(
mock_user
)
def test_update_one_notfound( def test_update_one_notfound(
self, self,
client, client,
mocker, mocker,
mock_session, mock_session,
mock_user): mock_user,
):
user_body = user_factory.user_body_factory(name='test user update') user_body = user_factory.user_body_factory(name='test user update')
user_update = user_factory.user_update_factory(name='test user update') user_update = user_factory.user_update_factory(name='test user update')
user_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'update_one', 'update_one',
side_effect=exceptions.UserNotFoundError('User 2 not found') side_effect=exceptions.UserNotFoundError('User 2 not found')
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/users/2', json=user_body) response = client.put('/api/users/2', json=user_body)
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
@@ -212,13 +280,15 @@ class TestUsers:
2, 2,
user_update user_update
) )
mock_is_allowed.assert_called_once_with(
mock_user
)
def test_update_one_unauthorized( def test_update_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
user_body = user_factory.user_body_factory(name='test user update') user_body = user_factory.user_body_factory(name='test user update')
@@ -233,7 +303,13 @@ class TestUsers:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user): def test_delete_one(
self,
client,
mocker,
mock_session,
mock_user,
):
user_result = user_factory.user_public_factory(name='test user delete') user_result = user_factory.user_public_factory(name='test user delete')
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -241,6 +317,11 @@ class TestUsers:
'delete_one', 'delete_one',
return_value=user_result return_value=user_result
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/users/2') response = client.delete('/api/users/2')
response_data = response.json() response_data = response.json()
@@ -251,40 +332,46 @@ class TestUsers:
mock_session, mock_session,
2, 2,
) )
mock_is_allowed.assert_called_once_with(
mock_user
)
def test_delete_one_notfound( def test_delete_one_notfound(
self, self,
client, client,
mocker, mocker,
mock_session, mock_session,
mock_user): mock_user,
user_result = None ):
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
'delete_one', 'delete_one',
side_effect=exceptions.UserNotFoundError('User 2 not found') side_effect=exceptions.UserNotFoundError('User 2 not found')
) )
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/users/2') response = client.delete('/api/users/2')
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
mock_session, mock_session,
2, 2,
) )
mock_is_allowed.assert_called_once_with(
mock_user
)
def test_delete_one_unauthorized( def test_delete_one_unauthorized(
self, self,
client, client,
mocker, mocker,
mock_session, ):
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
user_body = user_factory.user_body_factory(name='test user delete')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.users.service.delete_one') mock = mocker.patch('src.users.service.delete_one')

View File

@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
#
# SPDX-License-Identifier: MIT

View File

@@ -160,6 +160,7 @@
"and/or": "and/or", "and/or": "and/or",
"form name recommandation": "recommendation: Contract <contract-type> (Example: Pork-Lamb Contract)", "form name recommandation": "recommendation: Contract <contract-type> (Example: Pork-Lamb Contract)",
"submit contract": "submit contract", "submit contract": "submit contract",
"submit": "submit",
"example in user forms": "example in user contract form", "example in user forms": "example in user contract form",
"occasional product": "occasional product", "occasional product": "occasional product",
"recurrent product": "recurrent product", "recurrent product": "recurrent product",

View File

@@ -160,6 +160,7 @@
"and/or": "et/ou", "and/or": "et/ou",
"form name recommandation": "recommandation : Contrat <contract-type> (Exemple : Contrat Porc-Agneau)", "form name recommandation": "recommandation : Contrat <contract-type> (Exemple : Contrat Porc-Agneau)",
"submit contract": "envoyer le contrat", "submit contract": "envoyer le contrat",
"submit": "envoyer",
"example in user forms": "exemple dans le formulaire à destination des amapiens", "example in user forms": "exemple dans le formulaire à destination des amapiens",
"occasional product": "produit occasionnel", "occasional product": "produit occasionnel",
"recurrent product": "produit récurrent", "recurrent product": "produit récurrent",

View File

@@ -26,6 +26,9 @@ export function ContractModal({ opened, onClose, handleSubmit }: ContractModalPr
}); });
const formSelect = useMemo(() => { const formSelect = useMemo(() => {
if (!allForms) {
return [];
}
return allForms?.map((form) => ({ return allForms?.map((form) => ({
value: String(form.id), value: String(form.id),
label: `${form.season} ${form.name}`, label: `${form.season} ${form.name}`,

View File

@@ -53,6 +53,8 @@ export default function FormModal({ opened, onClose, currentForm, handleSubmit }
}); });
const usersSelect = useMemo(() => { const usersSelect = useMemo(() => {
if (!users)
return [];
return users?.map((user) => ({ return users?.map((user) => ({
value: String(user.id), value: String(user.id),
label: `${user.name}`, label: `${user.name}`,
@@ -60,6 +62,8 @@ export default function FormModal({ opened, onClose, currentForm, handleSubmit }
}, [users]); }, [users]);
const productorsSelect = useMemo(() => { const productorsSelect = useMemo(() => {
if (!productors)
return [];
return productors?.map((prod) => ({ return productors?.map((prod) => ({
value: String(prod.id), value: String(prod.id),
label: `${prod.name}`, label: `${prod.name}`,

View File

@@ -4,9 +4,12 @@ import "./index.css";
import { Group, Loader } from "@mantine/core"; import { Group, Loader } from "@mantine/core";
import { Config } from "@/config/config"; import { Config } from "@/config/config";
import { useAuth } from "@/services/auth/AuthProvider"; import { useAuth } from "@/services/auth/AuthProvider";
import { useMediaQuery } from "@mantine/hooks";
import { IconHome, IconLogin, IconLogout, IconSettings } from "@tabler/icons-react";
export function Navbar() { export function Navbar() {
const { loggedUser: user, isLoading } = useAuth(); const { loggedUser: user, isLoading } = useAuth();
const isPhone = useMediaQuery("(max-width: 760px");
if (!user && isLoading) { if (!user && isLoading) {
return ( return (
@@ -20,11 +23,11 @@ export function Navbar() {
<nav> <nav>
<Group> <Group>
<NavLink className={"navLink"} aria-label={t("home")} to="/"> <NavLink className={"navLink"} aria-label={t("home")} to="/">
{t("home", { capfirst: true })} {isPhone ? <IconHome/> : t("home", { capfirst: true })}
</NavLink> </NavLink>
{user?.logged ? ( {user?.logged ? (
<NavLink className={"navLink"} aria-label={t("dashboard")} to="/dashboard/help"> <NavLink className={"navLink"} aria-label={t("dashboard")} to="/dashboard/help">
{t("dashboard", { capfirst: true })} {isPhone ? <IconSettings/> : t("dashboard", { capfirst: true })}
</NavLink> </NavLink>
) : null} ) : null}
</Group> </Group>
@@ -34,7 +37,7 @@ export function Navbar() {
className={"navLink"} className={"navLink"}
aria-label={t("login with keycloak", { capfirst: true })} aria-label={t("login with keycloak", { capfirst: true })}
> >
{t("login with keycloak", { capfirst: true })} {isPhone ? <IconLogin/> : t("login with keycloak", { capfirst: true })}
</a> </a>
) : ( ) : (
<a <a
@@ -42,7 +45,7 @@ export function Navbar() {
className={"navLink"} className={"navLink"}
aria-label={t("logout", { capfirst: true })} aria-label={t("logout", { capfirst: true })}
> >
{t("logout", { capfirst: true })} {isPhone ? <IconLogout/> : t("logout", { capfirst: true })}
</a> </a>
)} )}
</nav> </nav>

View File

@@ -19,7 +19,7 @@ import {
type ProductorInputs, type ProductorInputs,
} from "@/services/resources/productors"; } from "@/services/resources/productors";
import { useMemo } from "react"; import { useMemo } from "react";
import { useGetRoles } from "@/services/api"; import { useAuth } from "@/services/auth/AuthProvider";
export type ProductorModalProps = ModalBaseProps & { export type ProductorModalProps = ModalBaseProps & {
currentProductor?: Productor; currentProductor?: Productor;
@@ -32,7 +32,7 @@ export function ProductorModal({
currentProductor, currentProductor,
handleSubmit, handleSubmit,
}: ProductorModalProps) { }: ProductorModalProps) {
const { data: allRoles } = useGetRoles(); const { loggedUser } = useAuth();
const form = useForm<ProductorInputs>({ const form = useForm<ProductorInputs>({
initialValues: { initialValues: {
@@ -58,8 +58,8 @@ export function ProductorModal({
}); });
const roleSelect = useMemo(() => { const roleSelect = useMemo(() => {
return allRoles?.map((role) => ({ value: String(role.name), label: role.name })); return loggedUser?.user?.roles?.map((role) => ({ value: String(role.name), label: role.name }));
}, [allRoles]); }, [loggedUser?.user?.roles]);
return ( return (
<Modal opened={opened} onClose={onClose} title={t("create productor", { capfirst: true })}> <Modal opened={opened} onClose={onClose} title={t("create productor", { capfirst: true })}>

View File

@@ -59,6 +59,8 @@ export function ProductModal({ opened, onClose, currentProduct, handleSubmit }:
}); });
const productorsSelect = useMemo(() => { const productorsSelect = useMemo(() => {
if (!productors)
return [];
return productors?.map((productor) => ({ return productors?.map((productor) => ({
value: String(productor.id), value: String(productor.id),
label: `${productor.name}`, label: `${productor.name}`,

View File

@@ -48,6 +48,8 @@ export default function ShipmentModal({
const { data: allProductors } = useGetProductors(); const { data: allProductors } = useGetProductors();
const formsSelect = useMemo(() => { const formsSelect = useMemo(() => {
if (!allForms)
return [];
return allForms?.map((currentForm) => ({ return allForms?.map((currentForm) => ({
value: String(currentForm.id), value: String(currentForm.id),
label: `${currentForm.name} ${currentForm.season}`, label: `${currentForm.name} ${currentForm.season}`,
@@ -55,7 +57,7 @@ export default function ShipmentModal({
}, [allForms]); }, [allForms]);
const productsSelect = useMemo(() => { const productsSelect = useMemo(() => {
if (!allProducts || !allProductors) return; if (!allProducts || !allProductors) return [];
return allProductors?.map((productor) => { return allProductors?.map((productor) => {
return { return {
group: productor.name, group: productor.name,

View File

@@ -36,6 +36,8 @@ export function UserModal({ opened, onClose, currentUser, handleSubmit }: UserMo
}); });
const roleSelect = useMemo(() => { const roleSelect = useMemo(() => {
if (!allRoles)
return [];
return allRoles?.map((role) => ({ value: String(role.name), label: role.name })); return allRoles?.map((role) => ({ value: String(role.name), label: role.name }));
}, [allRoles]); }, [allRoles]);

View File

@@ -165,7 +165,7 @@ export function Contract() {
); );
return ( return (
<Stack w={{ base: "100%", md: "80%", lg: "50%" }}> <Stack w={{ base: "100%", md: "80%", lg: "50%" }} p={{base: 'xs'}}>
<Title order={2}>{form.name}</Title> <Title order={2}>{form.name}</Title>
<Title order={3}>{t("informations", { capfirst: true })}</Title> <Title order={3}>{t("informations", { capfirst: true })}</Title>
<Text size="sm"> <Text size="sm">
@@ -283,6 +283,10 @@ export function Contract() {
ref={(el) => { ref={(el) => {
inputRefs.current.payment_method = el; inputRefs.current.payment_method = el;
}} }}
comboboxProps={{
withinPortal: false,
position: "bottom-start",
}}
/> />
{inputForm.values.payment_method === "cheque" ? ( {inputForm.values.payment_method === "cheque" ? (
<ContractCheque <ContractCheque
@@ -319,7 +323,7 @@ export function Contract() {
<Button <Button
leftSection={<IconDownload/>} leftSection={<IconDownload/>}
aria-label={t("submit contracts")} onClick={handleSubmit}> aria-label={t("submit contracts")} onClick={handleSubmit}>
{t("submit contract", {capfirst: true})} {t("submit", {capfirst: true})}
</Button> </Button>
</Overlay> </Overlay>
</Stack> </Stack>

View File

@@ -27,6 +27,8 @@ export default function Contracts() {
const { data: allContracts } = useGetContracts(); const { data: allContracts } = useGetContracts();
const forms = useMemo(() => { const forms = useMemo(() => {
if (!allContracts)
return [];
return allContracts return allContracts
?.map((contract: Contract) => contract.form.name) ?.map((contract: Contract) => contract.form.name)
.filter((contract, index, array) => array.indexOf(contract) === index); .filter((contract, index, array) => array.indexOf(contract) === index);

View File

@@ -40,12 +40,16 @@ export default function Productors() {
}, [navigate, searchParams]); }, [navigate, searchParams]);
const names = useMemo(() => { const names = useMemo(() => {
if (!allProductors)
return [];
return allProductors return allProductors
?.map((productor: Productor) => productor.name) ?.map((productor: Productor) => productor.name)
.filter((season, index, array) => array.indexOf(season) === index); .filter((season, index, array) => array.indexOf(season) === index);
}, [allProductors]); }, [allProductors]);
const types = useMemo(() => { const types = useMemo(() => {
if (!allProductors)
return [];
return allProductors return allProductors
?.map((productor: Productor) => productor.type) ?.map((productor: Productor) => productor.type)
.filter((productor, index, array) => array.indexOf(productor) === index); .filter((productor, index, array) => array.indexOf(productor) === index);

View File

@@ -38,12 +38,16 @@ export default function Products() {
const { data: allProducts } = useGetProducts(); const { data: allProducts } = useGetProducts();
const names = useMemo(() => { const names = useMemo(() => {
if (!allProducts)
return [];
return allProducts return allProducts
?.map((product: Product) => product.name) ?.map((product: Product) => product.name)
.filter((season, index, array) => array.indexOf(season) === index); .filter((season, index, array) => array.indexOf(season) === index);
}, [allProducts]); }, [allProducts]);
const productors = useMemo(() => { const productors = useMemo(() => {
if (!allProducts)
return [];
return allProducts return allProducts
?.map((product: Product) => product.productor.name) ?.map((product: Product) => product.productor.name)
.filter((productor, index, array) => array.indexOf(productor) === index); .filter((productor, index, array) => array.indexOf(productor) === index);

View File

@@ -44,12 +44,16 @@ export default function Shipments() {
const { data: allShipments } = useGetShipments(); const { data: allShipments } = useGetShipments();
const names = useMemo(() => { const names = useMemo(() => {
if (!allShipments)
return [];
return allShipments return allShipments
?.map((shipment: Shipment) => shipment.name) ?.map((shipment: Shipment) => shipment.name)
.filter((season, index, array) => array.indexOf(season) === index); .filter((season, index, array) => array.indexOf(season) === index);
}, [allShipments]); }, [allShipments]);
const forms = useMemo(() => { const forms = useMemo(() => {
if (!allShipments)
return [];
return allShipments return allShipments
?.map((shipment: Shipment) => shipment.form.name) ?.map((shipment: Shipment) => shipment.form.name)
.filter((season, index, array) => array.indexOf(season) === index); .filter((season, index, array) => array.indexOf(season) === index);

View File

@@ -36,6 +36,8 @@ export default function Users() {
const { data: allUsers } = useGetUsers(); const { data: allUsers } = useGetUsers();
const names = useMemo(() => { const names = useMemo(() => {
if (!allUsers)
return [];
return allUsers return allUsers
?.map((user: User) => user.name) ?.map((user: User) => user.name)
.filter((season, index, array) => array.indexOf(season) === index); .filter((season, index, array) => array.indexOf(season) === index);