Compare commits

4 Commits

Author SHA1 Message Date
Julien Aldon
5c356f5802 fix header width order 2026-03-05 17:20:44 +01:00
Julien Aldon
ff19448991 add functionnal recap ready for tests 2026-03-05 17:17:23 +01:00
Julien Aldon
3cfa60507e [WIP] add styles 2026-03-03 17:58:33 +01:00
8c6b25ded8 WIP contract recap 2026-02-19 16:19:40 +01:00
13 changed files with 506 additions and 258 deletions

View File

@@ -1,21 +1,20 @@
import secrets
from typing import Annotated from typing import Annotated
from urllib.parse import urlencode from fastapi import APIRouter, Security, HTTPException, Depends, Request, Cookie
import jwt
import requests
import src.messages as messages
import src.users.service as service
from fastapi import (APIRouter, Cookie, Depends, HTTPException, Request,
Security)
from fastapi.responses import RedirectResponse, Response from fastapi.responses import RedirectResponse, Response
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jwt import PyJWKClient
from sqlmodel import Session, select from sqlmodel import Session, select
import jwt
from jwt import PyJWKClient
from src.settings import AUTH_URL, TOKEN_URL, JWKS_URL, ISSUER, LOGOUT_URL, settings
import src.users.service as service
from src.database import get_session from src.database import get_session
from src.models import User, UserCreate, UserPublic from src.models import UserCreate, User, UserPublic
from src.settings import (AUTH_URL, ISSUER, JWKS_URL, LOGOUT_URL, TOKEN_URL,
settings) import secrets
import requests
from urllib.parse import urlencode
import src.messages as messages
router = APIRouter(prefix='/auth') router = APIRouter(prefix='/auth')
@@ -99,7 +98,7 @@ def callback(code: str, session: Session = Depends(get_session)):
'client_secret': settings.keycloak_client_secret, 'client_secret': settings.keycloak_client_secret,
'refresh_token': token_data['refresh_token'], 'refresh_token': token_data['refresh_token'],
} }
requests.post(LOGOUT_URL, data=data) res = requests.post(LOGOUT_URL, data=data)
resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true') resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true')
return resp return resp
roles = resource_access.get(settings.keycloak_client_id) roles = resource_access.get(settings.keycloak_client_id)
@@ -109,7 +108,7 @@ def callback(code: str, session: Session = Depends(get_session)):
'client_secret': settings.keycloak_client_secret, 'client_secret': settings.keycloak_client_secret,
'refresh_token': token_data['refresh_token'], 'refresh_token': token_data['refresh_token'],
} }
requests.post(LOGOUT_URL, data=data) res = requests.post(LOGOUT_URL, data=data)
resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true') resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true')
return resp return resp
@@ -161,15 +160,12 @@ def verify_token(token: str):
) )
return decoded return decoded
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
raise HTTPException( raise HTTPException(status_code=401,
status_code=401, detail=messages.Messages.tokenexipired)
detail=messages.Messages.tokenexipired
)
except jwt.InvalidTokenError: except jwt.InvalidTokenError:
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail=messages.Messages.invalidtoken detail=messages.Messages.invalidtoken)
)
def get_current_user( def get_current_user(
@@ -177,30 +173,21 @@ def get_current_user(
session: Session = Depends(get_session)): session: Session = Depends(get_session)):
access_token = request.cookies.get('access_token') access_token = request.cookies.get('access_token')
if not access_token: if not access_token:
raise HTTPException( raise HTTPException(status_code=401,
status_code=401, detail=messages.Messages.notauthenticated)
detail=messages.Messages.notauthenticated
)
payload = verify_token(access_token) payload = verify_token(access_token)
if not payload: if not payload:
raise HTTPException( raise HTTPException(status_code=401, detail='aze')
status_code=401,
detail='aze'
)
email = payload.get('email') email = payload.get('email')
if not email: if not email:
raise HTTPException( raise HTTPException(status_code=401,
status_code=401, detail=messages.Messages.notauthenticated)
detail=messages.Messages.notauthenticated
)
user = session.exec(select(User).where(User.email == email)).first() user = session.exec(select(User).where(User.email == email)).first()
if not user: if not user:
raise HTTPException( raise HTTPException(status_code=401,
status_code=401, detail=messages.Messages.not_found('user'))
detail=messages.Messages.not_found('user')
)
return user return user
@@ -262,6 +249,6 @@ def me(user: UserPublic = Depends(get_current_user)):
'name': user.name, 'name': user.name,
'email': user.email, 'email': user.email,
'id': user.id, 'id': user.id,
'roles': user.roles 'roles': [role.name for role in user.roles]
} }
} }

View File

@@ -17,88 +17,6 @@ from src.database import get_session
router = APIRouter(prefix='/contracts') router = APIRouter(prefix='/contracts')
def compute_recurrent_prices(
products_quantities: list[dict],
nb_shipment: int
):
"""Compute price for recurrent products"""
result = 0
for product_quantity in products_quantities:
product = product_quantity['product']
quantity = product_quantity['quantity']
result += compute_product_price(product, quantity, nb_shipment)
return result
def compute_occasional_prices(occasionals: list[dict]):
"""Compute prices for occassional products"""
result = 0
for occasional in occasionals:
result += occasional['price']
return result
def compute_product_price(
product: models.Product,
quantity: int,
nb_shipment: int = 1
):
"""Compute price for a product"""
product_quantity_unit = (
1 if product.unit == models.Unit.KILO else 1000
)
final_quantity = (
quantity if product.price else quantity / product_quantity_unit
)
final_price = (
product.price if product.price else product.price_kg
)
return final_price * final_quantity * nb_shipment
def find_dict_in_list(lst, key, value):
"""Find the index of a dictionnary in a list of dictionnaries given a key
and a value.
"""
for i, dic in enumerate(lst):
if dic[key].id == value:
return i
return -1
def create_occasional_dict(contract_products: list[models.ContractProduct]):
"""Create a dictionnary of occasional products"""
result = []
for contract_product in contract_products:
existing_id = find_dict_in_list(
result,
'shipment',
contract_product.shipment.id
)
if existing_id < 0:
result.append({
'shipment': contract_product.shipment,
'price': compute_product_price(
contract_product.product,
contract_product.quantity
),
'products': [{
'product': contract_product.product,
'quantity': contract_product.quantity
}]
})
else:
result[existing_id]['products'].append({
'product': contract_product.product,
'quantity': contract_product.quantity
})
result[existing_id]['price'] += compute_product_price(
contract_product.product,
contract_product.quantity
)
return result
@router.post('') @router.post('')
async def create_contract( async def create_contract(
contract: models.ContractCreate, contract: models.ContractCreate,
@@ -114,7 +32,7 @@ async def create_contract(
new_contract.products new_contract.products
) )
) )
occasionals = create_occasional_dict(occasional_contract_products) occasionals = service.create_occasional_dict(occasional_contract_products)
recurrents = list( recurrents = list(
map( map(
lambda x: {'product': x.product, 'quantity': x.quantity}, lambda x: {'product': x.product, 'quantity': x.quantity},
@@ -127,11 +45,13 @@ async def create_contract(
) )
) )
) )
recurrent_price = compute_recurrent_prices( prices = service.generate_products_prices(
occasionals,
recurrents, recurrents,
len(new_contract.form.shipments) new_contract.form.shipments
) )
price = recurrent_price + compute_occasional_prices(occasionals) recurrent_price = prices['recurrent']
total_price = prices['total']
cheques = list( cheques = list(
map( map(
lambda x: {'name': x.name, 'value': x.value}, lambda x: {'name': x.name, 'value': x.value},
@@ -145,7 +65,7 @@ async def create_contract(
occasionals, occasionals,
recurrents, recurrents,
'{:10.2f}'.format(recurrent_price), '{:10.2f}'.format(recurrent_price),
'{:10.2f}'.format(price) '{:10.2f}'.format(total_price)
) )
pdf_file = io.BytesIO(pdf_bytes) pdf_file = io.BytesIO(pdf_bytes)
contract_id = ( contract_id = (
@@ -154,7 +74,8 @@ async def create_contract(
f'{new_contract.form.productor.type}_' f'{new_contract.form.productor.type}_'
f'{new_contract.form.season}' f'{new_contract.form.season}'
) )
service.add_contract_file(session, new_contract.id, pdf_bytes, price) service.add_contract_file(
session, new_contract.id, pdf_bytes, total_price)
except Exception as error: except Exception as error:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,

View File

@@ -1,11 +1,16 @@
import html import html
import io import io
import math
import pathlib import pathlib
import string
import jinja2 import jinja2
from odfdo import Cell, Document, Row, Table import odfdo
# from odfdo import Cell, Document, Row, Style, Table
from odfdo.element import Element
from src import models from src import models
from src.contracts import service
from weasyprint import HTML from weasyprint import HTML
@@ -62,20 +67,345 @@ def generate_html_contract(
).write_pdf() ).write_pdf()
def flatten(xss):
return [x for xs in xss for x in xs]
def create_column_style_width(size: str) -> odfdo.Style:
"""Create a table columm style for a given width.
Paramenters:
size(str): size of the style (format <number><unit>) unit can be in, cm... see odfdo documentation.
Returns:
odfdo.Style with the correct column-width attribute.
"""
return odfdo.Element.from_tag(
'<style:style style:name="product-table.A" style:family="table-column">'
f'<style:table-column-properties style:column-width="{size}"/>'
'</style:style>'
)
def create_row_style_height(size: str) -> odfdo.Style:
"""Create a table height style for a given height.
Paramenters:
size(str): size of the style (format <number><unit>) unit can be in, cm... see odfdo documentation.
Returns:
odfdo.Style with the correct column-height attribute.
"""
return odfdo.Element.from_tag(
'<style:style style:name="product-table.A" style:family="table-row">'
f'<style:table-row-properties style:row-height="{size}"/>'
'</style:style>'
)
def create_cell_style(
name: str = "centered-cell",
font_size: str = '10pt',
bold: bool = False,
background_color: str = '#FFFFFF',
color: str = '#000000'
) -> odfdo.Style:
bold_attr = """
fo:font-weight="bold"
style:font-weight-asian="bold"
style:font-weight-complex="bold"
""" if bold else ''
return odfdo.Element.from_tag(
f"""<style:style style:name="{name}" style:family="table-cell">
<style:table-cell-properties
fo:border="0.75pt solid #000000"
style:vertical-align="middle"
fo:wrap-option="wrap"
fo:background-color="{background_color}"/>
<style:paragraph-properties fo:text-align="center"/>
<style:text-properties
{bold_attr}
fo:font-size="{font_size}"
fo:color="{color}"/>
</style:style>"""
)
def apply_cell_style(document: odfdo.Document, table: odfdo.Table):
header_style = document.insert_style(
create_cell_style(
name="header-cells",
bold=True,
font_size='12pt',
background_color="#3480eb",
color="#FFF"
)
)
body_style_even = document.insert_style(
create_cell_style(
name="body-style-even",
bold=False,
background_color="#e8eaed",
color="#000000"
)
)
body_style_odd = document.insert_style(
create_cell_style(
name="body-style-odd",
bold=False,
background_color="#FFFFFF",
color="#000000"
)
)
footer_style = document.insert_style(
create_cell_style(
name="footer-cells",
bold=True,
font_size='12pt',
)
)
for index, row in enumerate(table.get_rows()):
style = body_style_even
if index == 0 or index == 1:
style = header_style
elif index % 2 == 0:
style = body_style_even
elif index == len(table.get_rows()) - 1:
style = footer_style
else:
style = body_style_odd
for cell in row.get_cells():
cell.style = style
def apply_column_height_style(document: odfdo.Document, table: odfdo.Table):
header_style = document.insert_style(
style=create_row_style_height('1.60cm'), name='1.60cm', automatic=True
)
body_style = document.insert_style(
style=create_row_style_height('0.90cm'), name='0.90cm', automatic=True
)
for index, row in enumerate(table.get_rows()):
if index == 1:
row.style = header_style
else:
row.style = body_style
def apply_column_width_style(document: odfdo.Document, table: odfdo.Table, widths: list[str]):
"""Apply column width style to a table.
Parameters:
document(odfdo.Document): Document where the table is located.
table(odfdo.Table): Table to apply columns widths.
widths(list[str]): list of width in format <number><unit> unit ca be in, cm... see odfdo documentation.
"""
styles = []
for w in widths:
styles.append(document.insert_style(
style=create_column_style_width(w), name=w, automatic=True)
)
for position in range(table.width):
col = table.get_column(position)
col.style = styles[position]
table.set_column(position, col)
def generate_ods_letters(n: int):
letters = string.ascii_lowercase
result = []
for i in range(n):
if i > len(letters) - 1:
letter = f'{letters[int(i / len(letters)) - 1]}'
letter += f'{letters[i % len(letters)]}'
result.append(letter)
continue
letter = letters[i]
result.append(letters[i])
return result
def compute_contract_prices(contract: models.Contract) -> dict:
occasional_contract_products = list(
filter(
lambda contract_product: (
contract_product.product.type == models.ProductType.OCCASIONAL
),
contract.products
)
)
occasionals_dict = service.create_occasional_dict(
occasional_contract_products)
recurrents_dict = list(
map(
lambda x: {'product': x.product, 'quantity': x.quantity},
filter(
lambda contract_product: (
contract_product.product.type ==
models.ProductType.RECCURENT
),
contract.products
)
)
)
prices = service.generate_products_prices(
occasionals_dict,
recurrents_dict,
contract.form.shipments
)
return prices
def generate_recap( def generate_recap(
contracts: list[models.Contract], contracts: list[models.Contract],
form: models.Form, form: models.Form,
): ):
data = [ product_unit_map = {
["nom", "email"], '1': 'g',
'2': 'Kg',
'3': 'Piece'
}
recurrents = [
f'{pr.name}({product_unit_map[pr.unit]})'
for pr in form.productor.products
if pr.type == models.ProductType.RECCURENT
] ]
doc = Document("spreadsheet") recurrents.sort()
sheet = Table(name="Recap") occasionnals = [
f'{pr.name}({product_unit_map[pr.unit]})'
for pr in form.productor.products
if pr.type == models.ProductType.OCCASIONAL
]
occasionnals.sort()
shipments = form.shipments
occasionnals_header = [
occ for shipment in shipments for occ in occasionnals
]
info_header: list[str] = ['', 'Nom', 'Email']
cheque_header: list[str] = ['Cheque 1', 'Cheque 2', 'Cheque 3']
payment_header = (
cheque_header +
[f'Total {len(shipments)} livraisons + produits occasionnels']
)
prefix_header: list[str] = (
info_header +
payment_header
)
suffix_header: list[str] = [
'Total produits occasionnels',
'Remarques',
'Nom'
]
shipment_header = flatten([
[f'{shipment.name} - {shipment.date.strftime('%Y-%m-%d')}'] +
['' * len(occasionnals)] for shipment in shipments] +
[''] * len(suffix_header)
)
header: list[str] = (
prefix_header +
recurrents +
['Total produits récurrents'] +
occasionnals_header +
suffix_header
)
letters = generate_ods_letters(len(header))
payment_formula_letters = letters[
len(info_header):len(info_header) + len(payment_header)
]
recurent_formula_letters = letters[
len(info_header)+len(payment_formula_letters):
len(info_header)+len(payment_formula_letters)+len(recurrents) + 1
]
occasionnals_formula_letters = letters[
len(info_header)+len(payment_formula_letters)+len(recurent_formula_letters):
len(info_header)+len(payment_formula_letters) +
len(recurent_formula_letters)+len(occasionnals_header) + 1
]
print(payment_formula_letters)
print(recurent_formula_letters)
print(occasionnals_formula_letters)
footer = (
['', 'Total contrats', ''] +
[f'=SUM({letter}3:{letter}{2+len(contracts)})'
for letter in payment_formula_letters] +
[f'=SUM({letter}3:{letter}{2+len(contracts)})'
for letter in recurent_formula_letters] +
[f'=SUM({letter}3:{letter}{2+len(contracts)})'
for letter in occasionnals_formula_letters]
)
data = [
[''] * (len(prefix_header) + len(recurrents) + 1) + shipment_header,
header,
*[
[
f'{index + 1}',
f'{contract.firstname} {contract.lastname}',
f'{contract.email}',
*[float(contract.cheques[i].value) if len(
contract.cheques) > i else '' for i in range(3)],
compute_contract_prices(contract)['total'],
*[pr.quantity for pr in sorted(
contract.products, key=lambda x: x.product.name)
if pr.product.type == models.ProductType.RECCURENT],
compute_contract_prices(contract)['recurrent'],
*[pr.quantity for pr in sorted(
contract.products, key=lambda x: x.product.name)
if pr.product.type == models.ProductType.OCCASIONAL],
compute_contract_prices(contract)['occasionnal'],
'',
f'{contract.firstname} {contract.lastname}',
] for index, contract in enumerate(contracts)
],
footer
]
doc = odfdo.Document('spreadsheet')
sheet = doc.body.get_sheet(0)
sheet.name = 'Recap'
sheet.set_values(data) sheet.set_values(data)
index = len(prefix_header) + len(recurrents) + 1
for _ in enumerate(shipments):
startcol = index
endcol = index+len(occasionnals) - 1
sheet.set_span((startcol, 0, endcol, 0), merge=True)
index += len(occasionnals)
for row in sheet.get_rows():
for cell in row.get_cells():
if not cell.value or cell.get_attribute("office:value-type") == "float":
continue
if '=' in cell.value:
formula = cell.value
cell.clear()
cell.formula = formula
apply_column_width_style(
doc,
doc.body.get_table(0),
['2cm'] +
['4cm'] * 2 +
['2.40cm'] * (len(payment_header) - 1) +
['4cm'] * len(recurrents) +
['4cm'] +
['4cm'] * (len(occasionnals_header) + 1) +
['4cm', '8cm', '4cm']
)
apply_column_height_style(
doc,
doc.body.get_table(0),
)
apply_cell_style(doc, doc.body.get_table(0))
doc.body.append(sheet) doc.body.append(sheet)
buffer = io.BytesIO() buffer = io.BytesIO()
doc.save(buffer) doc.save(buffer)
# doc.save('test.ods')
return buffer.getvalue() return buffer.getvalue()

View File

@@ -166,3 +166,103 @@ def is_allowed(
.distinct() .distinct()
) )
return len(session.exec(statement).all()) > 0 return len(session.exec(statement).all()) > 0
def compute_recurrent_prices(
products_quantities: list[dict],
nb_shipment: int
):
"""Compute price for recurrent products"""
result = 0
for product_quantity in products_quantities:
product = product_quantity['product']
quantity = product_quantity['quantity']
result += compute_product_price(product, quantity, nb_shipment)
return result
def compute_occasional_prices(occasionals: list[dict]):
"""Compute prices for occassional products"""
result = 0
for occasional in occasionals:
result += occasional['price']
return result
def compute_product_price(
product: models.Product,
quantity: int,
nb_shipment: int = 1
):
"""Compute price for a product"""
product_quantity_unit = (
1 if product.unit == models.Unit.KILO else 1000
)
final_quantity = (
quantity if product.price else quantity / product_quantity_unit
)
final_price = (
product.price if product.price else product.price_kg
)
return final_price * final_quantity * nb_shipment
def find_dict_in_list(lst, key, value):
"""Find the index of a dictionnary in a list of dictionnaries given a key
and a value.
"""
for i, dic in enumerate(lst):
if dic[key].id == value:
return i
return -1
def create_occasional_dict(contract_products: list[models.ContractProduct]):
"""Create a dictionnary of occasional products"""
result = []
for contract_product in contract_products:
existing_id = find_dict_in_list(
result,
'shipment',
contract_product.shipment.id
)
if existing_id < 0:
result.append({
'shipment': contract_product.shipment,
'price': compute_product_price(
contract_product.product,
contract_product.quantity
),
'products': [{
'product': contract_product.product,
'quantity': contract_product.quantity
}]
})
else:
result[existing_id]['products'].append({
'product': contract_product.product,
'quantity': contract_product.quantity
})
result[existing_id]['price'] += compute_product_price(
contract_product.product,
contract_product.quantity
)
return result
def generate_products_prices(
occasionals: list[dict],
recurrents: list[dict],
shipments: list[models.ShipmentPublic]
):
recurrent_price = compute_recurrent_prices(
recurrents,
len(shipments)
)
occasional_price = compute_occasional_prices(occasionals)
price = recurrent_price + occasional_price
return {
'total': price,
'recurrent': recurrent_price,
'occasionnal': occasional_price
}

View File

@@ -32,10 +32,7 @@ async def get_forms_filtered(
@router.get('/{_id}', response_model=models.FormPublic) @router.get('/{_id}', response_model=models.FormPublic)
async def get_form( async def get_form(_id: int, session: Session = Depends(get_session)):
_id: int,
session: Session = Depends(get_session)
):
result = service.get_one(session, _id) result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException( raise HTTPException(
@@ -51,11 +48,6 @@ async def create_form(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, form=form):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('forms', 'update')
)
try: try:
form = service.create_one(session, form) form = service.create_one(session, form)
except exceptions.ProductorNotFoundError as error: except exceptions.ProductorNotFoundError as error:
@@ -69,16 +61,10 @@ async def create_form(
@router.put('/{_id}', response_model=models.FormPublic) @router.put('/{_id}', response_model=models.FormPublic)
async def update_form( async def update_form(
_id: int, _id: int, form: models.FormUpdate,
form: models.FormUpdate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('forms', 'update')
)
try: try:
result = service.update_one(session, _id, form) result = service.update_one(session, _id, form)
except exceptions.FormNotFoundError as error: except exceptions.FormNotFoundError as error:
@@ -96,11 +82,6 @@ async def delete_form(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('forms', 'delete')
)
try: try:
result = service.delete_one(session, _id) result = service.delete_one(session, _id)
except exceptions.FormNotFoundError as error: except exceptions.FormNotFoundError as error:

View File

@@ -108,25 +108,12 @@ def delete_one(session: Session, _id: int) -> models.FormPublic:
return result return result
def is_allowed( def is_allowed(session: Session, user: models.User, _id: int) -> bool:
session: Session,
user: models.User,
_id: int = None,
form: models.FormCreate = None
) -> bool:
if not _id:
statement = (
select(models.Productor)
.where(models.Productor.id == form.productor_id)
)
productor = session.exec(statement).first()
return productor.type in [r.name for r in user.roles]
statement = ( statement = (
select(models.Form) select(models.Form)
.join( .join(
models.Productor, models.Productor,
models.Form.productor_id == models.Productor.id models.Form.productor_id == models.Productor.id)
)
.where(models.Form.id == _id) .where(models.Form.id == _id)
.where( .where(
models.Productor.type.in_( models.Productor.type.in_(

View File

@@ -92,19 +92,3 @@ def delete_one(session: Session, id: int) -> models.ProductorPublic:
session.delete(productor) session.delete(productor)
session.commit() session.commit()
return result return result
def is_allowed(
session: Session,
user: models.User,
_id: int,
productor: models.ProductorCreate
) -> bool:
if not _id:
return productor.type in [r.name for r in user.roles]
statement = (
select(models.Productor)
.where(models.Productor.id == _id)
.where(models.Productor.type.in_([r.name for r in user.roles]))
.distinct()
)
return len(session.exec(statement).all()) > 0

View File

@@ -85,32 +85,3 @@ def delete_one(session: Session, id: int) -> models.ProductPublic:
session.delete(product) session.delete(product)
session.commit() session.commit()
return result return result
def is_allowed(
session: Session,
user: models.User,
_id: int,
product: models.ProductCreate
) -> bool:
if not _id:
statement = (
select(models.Product)
.join(
models.Productor,
models.Product.productor_id == models.Productor.id
)
.where(models.Product.id == product.productor_id)
)
productor = session.exec(statement).first()
return productor.type in [r.name for r in user.roles]
statement = (
select(models.Product)
.join(
models.Productor,
models.Product.productor_id == models.Productor.id
)
.where(models.Product.id == _id)
.where(models.Productor.type.in_([r.name for r in user.roles]))
.distinct()
)
return len(session.exec(statement).all()) > 0

View File

@@ -56,9 +56,7 @@ def get_or_create_user(session: Session, user_create: models.UserCreate):
def get_roles(session: Session): def get_roles(session: Session):
statement = ( statement = select(models.ContractType)
select(models.ContractType)
)
return session.exec(statement.order_by(models.ContractType.name)).all() return session.exec(statement.order_by(models.ContractType.name)).all()
@@ -66,9 +64,7 @@ def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
if user is None: if user is None:
raise exceptions.UserCreateError( raise exceptions.UserCreateError(
messages.Messages.invalid_input( messages.Messages.invalid_input(
'user', 'input cannot be None' 'user', 'input cannot be None'))
)
)
new_user = models.User( new_user = models.User(
name=user.name, name=user.name,
email=user.email email=user.email
@@ -85,19 +81,17 @@ def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
def update_one( def update_one(
session: Session, session: Session,
_id: int, id: int,
user: models.UserCreate) -> models.UserPublic: user: models.UserCreate) -> models.UserPublic:
if user is None: if user is None:
raise exceptions.UserCreateError( raise exceptions.UserCreateError(
messages.Messages.invalid_input( messages.s.invalid_input(
'user', 'input cannot be None' 'user', 'input cannot be None'))
) statement = select(models.User).where(models.User.id == id)
)
statement = select(models.User).where(models.User.id == _id)
result = session.exec(statement) result = session.exec(statement)
new_user = result.first() new_user = result.first()
if not new_user: if not new_user:
raise exceptions.UserNotFoundError(f'User {_id} not found') raise exceptions.UserNotFoundError(f'User {id} not found')
new_user.email = user.email new_user.email = user.email
new_user.name = user.name new_user.name = user.name
@@ -109,12 +103,12 @@ def update_one(
return new_user return new_user
def delete_one(session: Session, _id: int) -> models.UserPublic: def delete_one(session: Session, id: int) -> models.UserPublic:
statement = select(models.User).where(models.User.id == _id) statement = select(models.User).where(models.User.id == id)
result = session.exec(statement) result = session.exec(statement)
user = result.first() user = result.first()
if not user: if not user:
raise exceptions.UserNotFoundError(f'User {_id} not found') raise exceptions.UserNotFoundError(f'User {id} not found')
result = models.UserPublic.model_validate(user) result = models.UserPublic.model_validate(user)
session.delete(user) session.delete(user)
session.commit() session.commit()

View File

@@ -32,18 +32,16 @@ def get_roles(
return service.get_roles(session) return service.get_roles(session)
@router.get('/{_id}', response_model=models.UserPublic) @router.get('/{id}', response_model=models.UserPublic)
def get_user( def get_users(
_id: int, id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.get_one(session, _id) result = service.get_one(session, id)
if result is None: if result is None:
raise HTTPException( raise HTTPException(status_code=404,
status_code=404, detail=messages.Messages.not_found('user'))
detail=messages.Messages.not_found('user')
)
return result return result
@@ -56,27 +54,22 @@ def create_user(
try: try:
user = service.create_one(session, user) user = service.create_one(session, user)
except exceptions.UserCreateError as error: except exceptions.UserCreateError as error:
raise HTTPException( raise HTTPException(status_code=400, detail=str(error))
status_code=400,
detail=str(error)
) from error
return user return user
@router.put('/{_id}', response_model=models.UserPublic) @router.put('/{id}', response_model=models.UserPublic)
def update_user( def update_user(
_id: int, id: int,
user: models.UserUpdate, user: models.UserUpdate,
logged_user: models.User = Depends(get_current_user), logged_user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
try: try:
result = service.update_one(session, _id, user) result = service.update_one(session, id, user)
except exceptions.UserNotFoundError as error: except exceptions.UserNotFoundError as error:
raise HTTPException( raise HTTPException(status_code=404,
status_code=404, detail=messages.Messages.not_found('user'))
detail=messages.Messages.not_found('user')
) from error
return result return result
@@ -89,8 +82,6 @@ def delete_user(
try: try:
result = service.delete_one(session, id) result = service.delete_one(session, id)
except exceptions.UserNotFoundError as error: except exceptions.UserNotFoundError as error:
raise HTTPException( raise HTTPException(status_code=404,
status_code=404, detail=messages.Messages.not_found('user'))
detail=messages.Messages.not_found('user')
) from error
return result return result

Binary file not shown.

View File

@@ -19,7 +19,7 @@ import {
type ProductorInputs, type ProductorInputs,
} from "@/services/resources/productors"; } from "@/services/resources/productors";
import { useMemo } from "react"; import { useMemo } from "react";
import { useAuth } from "@/services/auth/AuthProvider"; import { useGetRoles } from "@/services/api";
export type ProductorModalProps = ModalBaseProps & { export type ProductorModalProps = ModalBaseProps & {
currentProductor?: Productor; currentProductor?: Productor;
@@ -32,7 +32,7 @@ export function ProductorModal({
currentProductor, currentProductor,
handleSubmit, handleSubmit,
}: ProductorModalProps) { }: ProductorModalProps) {
const { loggedUser } = useAuth(); const { data: allRoles } = useGetRoles();
const form = useForm<ProductorInputs>({ const form = useForm<ProductorInputs>({
initialValues: { initialValues: {
@@ -58,8 +58,8 @@ export function ProductorModal({
}); });
const roleSelect = useMemo(() => { const roleSelect = useMemo(() => {
return loggedUser?.user?.roles?.map((role) => ({ value: String(role.name), label: role.name })); return allRoles?.map((role) => ({ value: String(role.name), label: role.name }));
}, [loggedUser?.user?.roles]); }, [allRoles]);
return ( return (
<Modal opened={opened} onClose={onClose} title={t("create productor", { capfirst: true })}> <Modal opened={opened} onClose={onClose} title={t("create productor", { capfirst: true })}>

View File

@@ -27,6 +27,8 @@ export default function Contracts() {
const { data: allContracts } = useGetContracts(); const { data: allContracts } = useGetContracts();
const forms = useMemo(() => { const forms = useMemo(() => {
if (!allContracts)
return [];
return allContracts return allContracts
?.map((contract: Contract) => contract.form.name) ?.map((contract: Contract) => contract.form.name)
.filter((contract, index, array) => array.indexOf(contract) === index); .filter((contract, index, array) => array.indexOf(contract) === index);
@@ -89,7 +91,7 @@ export default function Contracts() {
label={t("download recap", { capfirst: true })} label={t("download recap", { capfirst: true })}
> >
<ActionIcon <ActionIcon
disabled={true} disabled={false}
onClick={(e) => { onClick={(e) => {
e.stopPropagation(); e.stopPropagation();
navigate( navigate(