[WIP] add styles

This commit is contained in:
Julien Aldon
2026-03-03 17:58:33 +01:00
125 changed files with 5762 additions and 622 deletions

View File

@@ -1,18 +1,27 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from src.database import get_session
from sqlmodel import Session
from src.contracts.generate_contract import generate_html_contract, generate_recap
from src.auth.auth import get_current_user
import src.models as models
import src.messages as messages
import src.contracts.service as service
import src.forms.service as form_service
"""Router for contract resource"""
import io
import zipfile
import src.contracts.service as service
import src.forms.service as form_service
import src.messages as messages
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user
from src.contracts.generate_contract import (generate_html_contract,
generate_recap)
from src.database import get_session
router = APIRouter(prefix='/contracts')
def compute_recurrent_prices(products_quantities: list[dict], nb_shipment: int):
def compute_recurrent_prices(
products_quantities: list[dict],
nb_shipment: int
):
"""Compute price for recurrent products"""
result = 0
for product_quantity in products_quantities:
product = product_quantity['product']
@@ -20,30 +29,50 @@ def compute_recurrent_prices(products_quantities: list[dict], nb_shipment: int):
result += compute_product_price(product, quantity, nb_shipment)
return result
def compute_occasional_prices(occasionals: list[dict]):
"""Compute prices for occassional products"""
result = 0
for occasional in occasionals:
result += occasional['price']
return result
def compute_product_price(product: models.Product, quantity: int, nb_shipment: int = 1):
product_quantity_unit = 1 if product.unit == models.Unit.KILO else 1000
final_quantity = quantity if product.price else quantity / product_quantity_unit
final_price = product.price if product.price else product.price_kg
return final_price * final_quantity * nb_shipment
def compute_product_price(
product: models.Product,
quantity: int,
nb_shipment: int = 1
):
"""Compute price for a product"""
product_quantity_unit = (
1 if product.unit == models.Unit.KILO else 1000
)
final_quantity = (
quantity if product.price else quantity / product_quantity_unit
)
final_price = (
product.price if product.price else product.price_kg
)
return final_price * final_quantity * nb_shipment
def find_dict_in_list(lst, key, value):
"""Find the index of a dictionnary in a list of dictionnaries given a key
and a value.
"""
for i, dic in enumerate(lst):
if dic[key].id == value:
return i
return -1
def create_occasional_dict(contract_products: list[models.ContractProduct]):
"""Create a dictionnary of occasional products"""
result = []
for contract_product in contract_products:
existing_id = find_dict_in_list(
result,
'shipment',
result,
'shipment',
contract_product.shipment.id
)
if existing_id < 0:
@@ -69,71 +98,183 @@ def create_occasional_dict(contract_products: list[models.ContractProduct]):
)
return result
@router.post('/')
@router.post('')
async def create_contract(
contract: models.ContractCreate,
session: Session = Depends(get_session),
):
"""Create contract route"""
new_contract = service.create_one(session, contract)
occasional_contract_products = list(filter(lambda contract_product: contract_product.product.type == models.ProductType.OCCASIONAL, new_contract.products))
occasional_contract_products = list(
filter(
lambda contract_product: (
contract_product.product.type == models.ProductType.OCCASIONAL
),
new_contract.products
)
)
occasionals = create_occasional_dict(occasional_contract_products)
recurrents = list(map(lambda x: {"product": x.product, "quantity": x.quantity}, filter(lambda contract_product: contract_product.product.type == models.ProductType.RECCURENT, new_contract.products)))
recurrent_price = compute_recurrent_prices(recurrents, len(new_contract.form.shipments))
recurrents = list(
map(
lambda x: {'product': x.product, 'quantity': x.quantity},
filter(
lambda contract_product: (
contract_product.product.type ==
models.ProductType.RECCURENT
),
new_contract.products
)
)
)
recurrent_price = compute_recurrent_prices(
recurrents,
len(new_contract.form.shipments)
)
price = recurrent_price + compute_occasional_prices(occasionals)
total_price = '{:10.2f}'.format(price)
cheques = list(map(lambda x: {"name": x.name, "value": x.value}, new_contract.cheques))
# TODO: send contract to referer
cheques = list(
map(
lambda x: {'name': x.name, 'value': x.value},
new_contract.cheques
)
)
try:
pdf_bytes = generate_html_contract(
new_contract,
cheques,
occasionals,
recurrents,
recurrent_price,
total_price
'{:10.2f}'.format(recurrent_price),
'{:10.2f}'.format(price)
)
pdf_file = io.BytesIO(pdf_bytes)
contract_id = f'{new_contract.firstname}_{new_contract.lastname}_{new_contract.form.productor.type}_{new_contract.form.season}'
contract_id = (
f'{new_contract.firstname}_'
f'{new_contract.lastname}_'
f'{new_contract.form.productor.type}_'
f'{new_contract.form.season}'
)
service.add_contract_file(session, new_contract.id, pdf_bytes, price)
except Exception as e:
print(e)
raise HTTPException(status_code=400, detail=messages.pdferror)
except Exception as error:
raise HTTPException(
status_code=400,
detail=messages.pdferror
) from error
return StreamingResponse(
pdf_file,
media_type='application/pdf',
headers={
'Content-Disposition': f'attachment; filename=contract_{contract_id}.pdf'
'Content-Disposition': (
f'attachment; filename=contract_{contract_id}.pdf'
)
}
)
@router.get('/', response_model=list[models.ContractPublic])
@router.get('/{form_id}/base')
async def get_base_contract_template(
form_id: int,
session: Session = Depends(get_session),
):
"""Get contract template route"""
form = form_service.get_one(session, form_id)
recurrents = [
{'product': product, 'quantity': None}
for product in form.productor.products
if product.type == models.ProductType.RECCURENT
]
occasionals = [{
'shipment': sh,
'price': None,
'products': [{'product': pr, 'quantity': None} for pr in sh.products]
} for sh in form.shipments]
empty_contract = models.ContractPublic(
firstname='',
form=form,
lastname='',
email='',
phone='',
products=[],
payment_method='cheque',
cheque_quantity=3,
total_price=0,
id=1
)
cheques = [
{'name': None, 'value': None},
{'name': None, 'value': None},
{'name': None, 'value': None}
]
try:
pdf_bytes = generate_html_contract(
empty_contract,
cheques,
occasionals,
recurrents,
)
pdf_file = io.BytesIO(pdf_bytes)
contract_id = (
f'{empty_contract.form.productor.type}_'
f'{empty_contract.form.season}'
)
except Exception as error:
raise HTTPException(
status_code=400,
detail=messages.pdferror
) from error
return StreamingResponse(
pdf_file,
media_type='application/pdf',
headers={
'Content-Disposition': (
f'attachment; filename=contract_{contract_id}.pdf'
)
}
)
@router.get('', response_model=list[models.ContractPublic])
def get_contracts(
forms: list[str] = Query([]),
session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
"""Get all contracts route"""
return service.get_all(session, user, forms)
@router.get('/{id}/file')
@router.get('/{_id}/file')
def get_contract_file(
id: int,
_id: int,
session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
if not service.is_allowed(session, user, id):
raise HTTPException(status_code=403, detail=messages.notallowed)
contract = service.get_one(session, id)
"""Get a contract file (in pdf) route"""
if not service.is_allowed(session, user, _id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract', 'get')
)
contract = service.get_one(session, _id)
if contract is None:
raise HTTPException(status_code=404, detail=messages.notfound)
filename = f'{contract.form.name.replace(' ', '_')}_{contract.form.season}_{contract.firstname}-{contract.lastname}'
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('contract')
)
filename = (
f'{contract.form.name.replace(' ', '_')}_'
f'{contract.form.season}_'
f'{contract.firstname}_'
f'{contract.lastname}'
)
return StreamingResponse(
io.BytesIO(contract.file),
media_type='application/pdf',
headers={
'Content-Disposition': f'attachment; filename={filename}.pdf'
}
)
)
@router.get('/{form_id}/files')
def get_contract_files(
@@ -141,17 +282,30 @@ def get_contract_files(
session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
"""Get all contract files for a given form"""
if not form_service.is_allowed(session, user, form_id):
raise HTTPException(status_code=403, detail=messages.notallowed)
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contracts', 'get')
)
form = form_service.get_one(session, form_id=form_id)
contracts = service.get_all(session, user, forms=[form.name])
zipped_contracts = io.BytesIO()
with zipfile.ZipFile(zipped_contracts, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
with zipfile.ZipFile(
zipped_contracts,
'a',
zipfile.ZIP_DEFLATED,
False
) as zip_file:
for contract in contracts:
contract_filename = f'{contract.form.name.replace(' ', '_')}_{contract.form.season}_{contract.firstname}-{contract.lastname}.pdf'
contract_filename = (
f'{contract.form.name.replace(' ', '_')}_'
f'{contract.form.season}_'
f'{contract.firstname}_'
f'{contract.lastname}'
)
zip_file.writestr(contract_filename, contract.file)
filename = f'{form.name.replace(" ", "_")}_{form.season}'
filename = f'{form.name.replace(' ', '_')}_{form.season}'
return StreamingResponse(
io.BytesIO(zipped_contracts.getvalue()),
media_type='application/zip',
@@ -160,39 +314,69 @@ def get_contract_files(
}
)
@router.get('/{form_id}/recap')
def get_contract_recap(
form_id: int,
session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
"""Get a contract recap for a given form"""
if not form_service.is_allowed(session, user, form_id):
raise HTTPException(status_code=403, detail=messages.notallowed)
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract recap', 'get')
)
form = form_service.get_one(session, form_id=form_id)
contracts = service.get_all(session, user, forms=[form.name])
return StreamingResponse(
io.BytesIO(generate_recap(contracts, form)),
media_type='application/zip',
headers={
'Content-Disposition': f'attachment; filename=filename.ods'
'Content-Disposition': (
'attachment; filename=filename.ods'
)
}
)
@router.get('/{id}', response_model=models.ContractPublic)
def get_contract(id: int, session: Session = Depends(get_session), user: models.User = Depends(get_current_user)):
if not service.is_allowed(session, user, id):
raise HTTPException(status_code=403, detail=messages.notallowed)
result = service.get_one(session, id)
@router.get('/{_id}', response_model=models.ContractPublic)
def get_contract(
_id: int,
session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
"""Get a contract route"""
if not service.is_allowed(session, user, _id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract', 'get')
)
result = service.get_one(session, _id)
if result is None:
raise HTTPException(status_code=404, detail=messages.notfound)
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('contract')
)
return result
@router.delete('/{id}', response_model=models.ContractPublic)
def delete_contract(id: int, session: Session = Depends(get_session), user: models.User = Depends(get_current_user)):
if not service.is_allowed(session, user, id):
raise HTTPException(status_code=403, detail=messages.notallowed)
result = service.delete_one(session, id)
@router.delete('/{_id}', response_model=models.ContractPublic)
def delete_contract(
_id: int,
session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
"""Delete contract route"""
if not service.is_allowed(session, user, _id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract', 'delete')
)
result = service.delete_one(session, _id)
if result is None:
raise HTTPException(status_code=404, detail=messages.notfound)
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('contract')
)
return result

View File

@@ -1,21 +1,28 @@
import html
import io
import pathlib
import jinja2
import src.models as models
import html
import odfdo
# from odfdo import Cell, Document, Row, Style, Table
from odfdo.element import Element
from src import models
from weasyprint import HTML
import io
def generate_html_contract(
contract: models.Contract,
cheques: list[dict],
occasionals: list[dict],
reccurents: list[dict],
recurrent_price: float,
total_price: float
):
template_dir = "./src/contracts/templates"
recurrent_price: float | None = None,
total_price: float | None = None
):
template_dir = pathlib.Path("./src/contracts/templates").resolve()
template_loader = jinja2.FileSystemLoader(searchpath=template_dir)
template_env = jinja2.Environment(loader=template_loader, autoescape=jinja2.select_autoescape(["html", "xml"]))
template_env = jinja2.Environment(
loader=template_loader, autoescape=jinja2.select_autoescape(["html", "xml"]))
template_file = "layout.html"
template = template_env.get_template(template_file)
output_text = template.render(
@@ -26,95 +33,212 @@ def generate_html_contract(
referer_email=contract.form.referer.email,
productor_name=contract.form.productor.name,
productor_address=contract.form.productor.address,
payment_methods_map={"cheque": "Ordre du chèque", "transfer": "virements"},
payment_methods_map={
"cheque": "Ordre du chèque",
"transfer": "virements"},
productor_payment_methods=contract.form.productor.payment_methods,
member_name=f'{html.escape(contract.firstname)} {html.escape(contract.lastname)}',
member_email=html.escape(contract.email),
member_phone=html.escape(contract.phone),
member_name=f'{
html.escape(
contract.firstname)} {
html.escape(
contract.lastname)}',
member_email=html.escape(
contract.email),
member_phone=html.escape(
contract.phone),
contract_start_date=contract.form.start,
contract_end_date=contract.form.end,
occasionals=occasionals,
recurrents=reccurents,
recurrent_price=recurrent_price,
total_price=total_price,
contract_payment_method={"cheque": "chèque", "transfer": "virements"}[contract.payment_method],
cheques=cheques
)
options = {
'page-size': 'Letter',
'margin-top': '0.5in',
'margin-right': '0.5in',
'margin-bottom': '0.5in',
'margin-left': '0.5in',
'encoding': "UTF-8",
'print-media-type': True,
"disable-javascript": True,
"disable-external-links": True,
'enable-local-file-access': False,
"disable-local-file-access": True,
"no-images": True,
}
contract_payment_method={
"cheque": "chèque",
"transfer": "virements"}[
contract.payment_method],
cheques=cheques)
return HTML(
string=output_text,
base_url=template_dir
base_url=template_dir,
).write_pdf()
def flatten(xss):
return [x for xs in xss for x in xs]
from odfdo import Document, Table, Row, Cell
from odfdo.element import Element
def create_column_style_width(size: str) -> odfdo.Style:
"""Create a table columm style for a given width.
Paramenters:
size(str): size of the style (format <number><unit>) unit can be in, cm... see odfdo documentation.
Returns:
odfdo.Style with the correct column-width attribute.
"""
return odfdo.Element.from_tag(
'<style:style style:name="product-table.A" style:family="table-column">'
f'<style:table-column-properties style:column-width="{size}"/>'
'</style:style>'
)
def create_row_style_height(size: str) -> odfdo.Style:
"""Create a table height style for a given height.
Paramenters:
size(str): size of the style (format <number><unit>) unit can be in, cm... see odfdo documentation.
Returns:
odfdo.Style with the correct column-height attribute.
"""
return odfdo.Element.from_tag(
'<style:style style:name="product-table.A" style:family="table-row">'
f'<style:table-row-properties style:row-height="{size}"/>'
'</style:style>'
)
def create_center_cell_style(name: str = "centered-cell") -> odfdo.Style:
return odfdo.Element.from_tag(
f'<style:style style:name="{name}" style:family="table-cell">'
'<style:table-cell-properties style:vertical-align="middle" fo:wrap-option="wrap"/>'
'<style:paragraph-properties fo:text-align="center"/>'
'</style:style>'
)
def create_cell_style_with_font(name: str = "font", font_size="14pt", bold: bool = False) -> odfdo.Style:
return odfdo.Element.from_tag(
f'<style:style style:name="{name}" style:family="table-cell" '
f'xmlns:fo="urn:oasis:names:tc:opendocument:xmlns:xsl-fo-compatible:1.0">'
'<style:table-cell-properties style:vertical-align="middle" fo:wrap-option="wrap"/>'
f'<style:paragraph-properties fo:text-align="center" fo:font-size="{font_size}" '
f'{"fo:font-weight=\"bold\"" if bold else ""}/>'
'</style:style>'
)
def apply_center_cell_style(document: odfdo.Document, row: odfdo.Row):
style = document.insert_style(
create_center_cell_style()
)
for cell in row.get_cells():
cell.style = style
def apply_column_height_style(document: odfdo.Document, row: odfdo.Row, height: str):
style = document.insert_style(
style=create_row_style_height(height), name=height, automatic=True
)
row.style = style
def apply_font_style(document: odfdo.Document, table: odfdo.Table, size: str = "14pt"):
style_header = document.insert_style(
style=create_cell_style_with_font(
'header_font', font_size=size, bold=True
)
)
style_body = document.insert_style(
style=create_cell_style_with_font(
'body_font', font_size=size, bold=False
)
)
for position in range(table.height):
row = table.get_row(position)
for cell in row.get_cells():
cell.style = style_header if position == 0 or position == 1 else style_body
for paragraph in cell.get_paragraphs():
paragraph.style = cell.style
def apply_column_width_style(document: odfdo.Document, table: odfdo.Table, widths: list[str]):
"""Apply column width style to a table.
Parameters:
document(odfdo.Document): Document where the table is located.
table(odfdo.Table): Table to apply columns widths.
widths(list[str]): list of width in format <number><unit> unit ca be in, cm... see odfdo documentation.
"""
styles = []
for w in widths:
styles.append(document.insert_style(
style=create_column_style_width(w), name=w, automatic=True))
for position in range(table.width):
col = table.get_column(position)
col.style = styles[position]
table.set_column(position, col)
def generate_recap(
contracts: list[models.Contract],
form: models.Form,
):
print(form.productor.products)
recurrents = [pr.name for pr in form.productor.products if pr.type == models.ProductType.RECCURENT]
recurrents = [pr.name for pr in form.productor.products if pr.type ==
models.ProductType.RECCURENT]
recurrents.sort()
occasionnals = [pr.name for pr in form.productor.products if pr.type == models.ProductType.OCCASIONAL]
occasionnals = [pr.name for pr in form.productor.products if pr.type ==
models.ProductType.OCCASIONAL]
occasionnals.sort()
shipments = form.shipments
occasionnals_header = [occ for shipment in shipments for occ in occasionnals]
shipment_header = flatten([[shipment.name] + ["" * len(occasionnals)] for shipment in shipments])
occasionnals_header = [
occ for shipment in shipments for occ in occasionnals]
shipment_header = flatten(
[[f'{shipment.name} - {shipment.date.strftime('%Y-%m-%d')}'] + ["" * len(occasionnals)] for shipment in shipments])
product_unit_map = {
"1": "g",
"2": "kg",
"3": "p"
}
header = (
["Nom", "Email"] +
["Tarif panier", "Total Paniers", "Total à payer"] +
["Cheque 1", "Cheque 2", "Cheque 3"] +
[f"Total {len(shipments)} livraisons + produits occasionnels"] +
recurrents +
occasionnals_header +
["Remarques", "Nom"]
)
data = [
["", ""] + ["" * len(recurrents)] + shipment_header,
["nom", "email"] + recurrents + occasionnals_header + ["remarques", "name"],
[""] * (9 + len(recurrents)) + shipment_header,
header,
*[
[
f'{contract.firstname} {contract.lastname}',
f'{contract.email}',
*[f'{pr.quantity} {product_unit_map[pr.product.unit]}' for pr in sorted(contract.products, key=lambda x: x.product.name) if pr.product.type == models.ProductType.RECCURENT],
*[f'{pr.quantity} {product_unit_map[pr.product.unit]}' for pr in sorted(contract.products, key=lambda x: x.product.name) if pr.product.type == models.ProductType.OCCASIONAL],
"",
f'{contract.firstname} {contract.lastname}',
f'{contract.firstname} {contract.lastname}',
f'{contract.email}',
*[f'{pr.quantity} {product_unit_map[pr.product.unit]}' for pr in sorted(
contract.products, key=lambda x: x.product.name) if pr.product.type == models.ProductType.RECCURENT],
*[f'{pr.quantity} {product_unit_map[pr.product.unit]}' for pr in sorted(
contract.products, key=lambda x: x.product.name) if pr.product.type == models.ProductType.OCCASIONAL],
"",
f'{contract.firstname} {contract.lastname}',
] for contract in contracts
]
]
doc = Document("spreadsheet")
sheet = Table(name="Recap")
doc = odfdo.Document("spreadsheet")
sheet = doc.body.get_sheet(0)
sheet.name = 'Recap'
sheet.set_values(data)
apply_column_width_style(doc, doc.body.get_table(0), ["4cm"] * len(header))
apply_column_height_style(
doc,
doc.body.get_table(0).get_rows((1, 1))[0],
"1.20cm"
)
apply_center_cell_style(doc, doc.body.get_table(0).get_rows((1, 1))[0])
apply_font_style(doc, doc.body.get_table(0))
index = 9 + len(recurrents)
for _ in enumerate(shipments):
startcol = index
endcol = index+len(occasionnals) - 1
sheet.set_span((startcol, 0, endcol, 0), merge=True)
index += len(occasionnals)
offset = 0
index = 2 + len(recurrents)
for i in range(len(shipments)):
index = index + offset
print(index, index+len(occasionnals) - 1)
sheet.set_span((index, 0, index+len(occasionnals) - 1, 0), merge=True)
offset += len(occasionnals)
doc.body.append(sheet)
buffer = io.BytesIO()
doc.save(buffer)
doc.save('test.ods')
return buffer.getvalue()

View File

@@ -1,28 +1,57 @@
"""Contract service responsible for read, create, update and delete contracts"""
from sqlalchemy.orm import selectinload
from sqlmodel import Session, select
import src.models as models
from src import models
def get_all(
session: Session,
user: models.User,
forms: list[str] = [],
forms: list[str] | None = None,
form_id: int | None = None,
) -> list[models.ContractPublic]:
statement = select(models.Contract)\
.join(models.Form, models.Contract.form_id == models.Form.id)\
.join(models.Productor, models.Form.productor_id == models.Productor.id)\
.where(models.Productor.type.in_([r.name for r in user.roles]))\
"""Get all contracts"""
statement = (
select(models.Contract)
.join(
models.Form,
models.Contract.form_id == models.Form.id
)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id
)
.where(
models.Productor.type.in_(
[r.name for r in user.roles]
)
)
.distinct()
if len(forms) > 0:
)
if forms:
statement = statement.where(models.Form.name.in_(forms))
if form_id:
statement = statement.where(models.Form.id == form_id)
return session.exec(statement.order_by(models.Contract.id)).all()
def get_one(session: Session, contract_id: int) -> models.ContractPublic:
def get_one(
session: Session,
contract_id: int
) -> models.ContractPublic:
"""Get one contract"""
return session.get(models.Contract, contract_id)
def create_one(session: Session, contract: models.ContractCreate) -> models.ContractPublic:
contract_create = contract.model_dump(exclude_unset=True, exclude=["products", "cheques"])
def create_one(
session: Session,
contract: models.ContractCreate
) -> models.ContractPublic:
"""Create one contract"""
contract_create = contract.model_dump(
exclude_unset=True,
exclude=["products", "cheques"]
)
new_contract = models.Contract(**contract_create)
new_contract.cheques = [
@@ -45,10 +74,27 @@ def create_one(session: Session, contract: models.ContractCreate) -> models.Cont
session.add(new_contract)
session.commit()
session.refresh(new_contract)
return new_contract
def add_contract_file(session: Session, id: int, file: bytes, price: float):
statement = select(models.Contract).where(models.Contract.id == id)
statement = (
select(models.Contract)
.where(models.Contract.id == new_contract.id)
.options(
selectinload(models.Contract.form)
.selectinload(models.Form.productor)
)
)
return session.exec(statement).one()
def add_contract_file(
session: Session,
_id: int,
file: bytes,
price: float
):
"""Add a file to an existing contract"""
statement = select(models.Contract).where(models.Contract.id == _id)
result = session.exec(statement)
contract = result.first()
contract.total_price = price
@@ -58,8 +104,14 @@ def add_contract_file(session: Session, id: int, file: bytes, price: float):
session.refresh(contract)
return contract
def update_one(session: Session, id: int, contract: models.ContractUpdate) -> models.ContractPublic:
statement = select(models.Contract).where(models.Contract.id == id)
def update_one(
session: Session,
_id: int,
contract: models.ContractUpdate
) -> models.ContractPublic:
"""Update one contract"""
statement = select(models.Contract).where(models.Contract.id == _id)
result = session.exec(statement)
new_contract = result.first()
if not new_contract:
@@ -72,8 +124,13 @@ def update_one(session: Session, id: int, contract: models.ContractUpdate) -> mo
session.refresh(new_contract)
return new_contract
def delete_one(session: Session, id: int) -> models.ContractPublic:
statement = select(models.Contract).where(models.Contract.id == id)
def delete_one(
session: Session,
_id: int
) -> models.ContractPublic:
"""Delete one contract"""
statement = select(models.Contract).where(models.Contract.id == _id)
result = session.exec(statement)
contract = result.first()
if not contract:
@@ -83,11 +140,29 @@ def delete_one(session: Session, id: int) -> models.ContractPublic:
session.commit()
return result
def is_allowed(session: Session, user: models.User, id: int) -> bool:
statement = select(models.Contract)\
.join(models.Form, models.Contract.form_id == models.Form.id)\
.join(models.Productor, models.Form.productor_id == models.Productor.id)\
.where(models.Contract.id == id)\
.where(models.Productor.type.in_([r.name for r in user.roles]))\
def is_allowed(
session: Session,
user: models.User,
_id: int
) -> bool:
"""Determine if a user is allowed to access a contract by id"""
statement = (
select(models.Contract)
.join(
models.Form,
models.Contract.form_id == models.Form.id
)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id
)
.where(models.Contract.id == _id)
.where(
models.Productor.type.in_(
[r.name for r in user.roles]
)
)
.distinct()
return len(session.exec(statement).all()) > 0
)
return len(session.exec(statement).all()) > 0

View File

@@ -151,10 +151,6 @@
<th>Saison du contrat</th>
<td>{{contract_season}}</td>
</tr>
<tr>
<th>Type de contrat</th>
<td>{{contract_type}}</td>
</tr>
<tr>
<th>Référent·e</th>
<td>{{referer_name}}</td>
@@ -278,14 +274,14 @@
else ""}}
</td>
<td>
{{rec.quantity}}{{"g" if rec.product.unit == "1" else "kg" if
{{rec.quantity if rec.quantity != None else ""}}{{"g" if rec.product.unit == "1" else "kg" if
rec.product.unit == "2" else "p" }}
</td>
</tr>
{% endfor %}
<tr>
<th scope="row" colspan="4">Total</th>
<td>{{recurrent_price}}€</td>
<td>{{recurrent_price if recurrent_price else ""}}€</td>
</tr>
</tbody>
</table>
@@ -321,14 +317,15 @@
product.product.quantity_unit != None else ""}}
</td>
<td>
{{product.quantity}}{{"g" if product.product.unit == "1" else
{{product.quantity if product.quantity != None
else ""}}{{"g" if product.product.unit == "1" else
"kg" if product.product.unit == "2" else "p" }}
</td>
</tr>
{% endfor%}
<tr>
<th scope="row" colspan="4">Total</th>
<td>{{occasional.price}}€</td>
<td>{{occasional.price if occasional.price else ""}}€</td>
</tr>
</tbody>
</table>
@@ -337,7 +334,7 @@
{% endif %}
<div class="total-box">
<div class="total-label">Prix Total :</div>
<div class="total-price">{{total_price}}€</div>
<div class="total-price">{{total_price if total_price else ""}}€</div>
</div>
<h4>Paiement par {{contract_payment_method}}</h4>
{% if contract_payment_method == "chèque" %}
@@ -346,14 +343,14 @@
<thead>
<tr>
{% for cheque in cheques %}
<th>Cheque n°{{cheque.name}}</th>
<th>Cheque n°{{cheque.name if cheque.name else ""}}</th>
{% endfor %}
</tr>
</thead>
<tbody>
<tr>
{% for cheque in cheques %}
<td>{{cheque.value}}€</td>
<td>{{cheque.value if cheque.value else ""}}€</td>
{% endfor %}
</tr>
</tbody>