add logout logic and wip recap
This commit is contained in:
@@ -29,6 +29,7 @@ dependencies = [
|
||||
"cryptography",
|
||||
"requests",
|
||||
"weasyprint",
|
||||
"odfdo"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -13,7 +13,7 @@ from src.models import UserCreate, User, UserPublic
|
||||
|
||||
import secrets
|
||||
import requests
|
||||
|
||||
from urllib.parse import urlencode
|
||||
import src.messages as messages
|
||||
|
||||
router = APIRouter(prefix='/auth')
|
||||
@@ -23,24 +23,13 @@ security = HTTPBearer()
|
||||
|
||||
@router.get('/logout')
|
||||
def logout(
|
||||
id_token: Annotated[str | None, Cookie()] = None,
|
||||
refresh_token: Annotated[str | None, Cookie()] = None,
|
||||
):
|
||||
if refresh_token:
|
||||
print("invalidate tokens")
|
||||
requests.post(LOGOUT_URL, data={
|
||||
"client_id": settings.keycloak_client_id,
|
||||
"client_secret": settings.keycloak_client_secret,
|
||||
"refresh_token": refresh_token
|
||||
})
|
||||
|
||||
if id_token:
|
||||
print("redirect keycloak")
|
||||
response = RedirectResponse(f'{LOGOUT_URL}?post_logout_redirect_uri={settings.origins}&id_token_hint={id_token}')
|
||||
else:
|
||||
response = RedirectResponse(settings.origins)
|
||||
|
||||
print("clear cookies")
|
||||
params = {
|
||||
'client_id': settings.keycloak_client_id,
|
||||
'post_logout_redirect_uri': settings.origins,
|
||||
}
|
||||
response = RedirectResponse(f'{LOGOUT_URL}?{urlencode(params)}')
|
||||
response.delete_cookie(
|
||||
key='access_token',
|
||||
path='/',
|
||||
@@ -59,6 +48,12 @@ def logout(
|
||||
secure=not settings.debug,
|
||||
samesite='lax',
|
||||
)
|
||||
# if refresh_token:
|
||||
# requests.post(LOGOUT_URL, data={
|
||||
# 'client_id': settings.keycloak_client_id,
|
||||
# 'client_secret': settings.keycloak_client_secret,
|
||||
# 'refresh_token': refresh_token
|
||||
# })
|
||||
return response
|
||||
|
||||
|
||||
@@ -107,9 +102,9 @@ def callback(code: str, session: Session = Depends(get_session)):
|
||||
'refresh_token': token_data['refresh_token'],
|
||||
}
|
||||
res = requests.post(LOGOUT_URL, data=data)
|
||||
resp = RedirectResponse(settings.origins)
|
||||
resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true')
|
||||
return resp
|
||||
resource_access.get(settings.keycloak_client_id)
|
||||
roles = resource_access.get(settings.keycloak_client_id)
|
||||
if not roles:
|
||||
data = {
|
||||
'client_id': settings.keycloak_client_id,
|
||||
@@ -117,7 +112,7 @@ def callback(code: str, session: Session = Depends(get_session)):
|
||||
'refresh_token': token_data['refresh_token'],
|
||||
}
|
||||
res = requests.post(LOGOUT_URL, data=data)
|
||||
resp = RedirectResponse(settings.origins)
|
||||
resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true')
|
||||
return resp
|
||||
|
||||
user_create = UserCreate(
|
||||
|
||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from src.database import get_session
|
||||
from sqlmodel import Session
|
||||
from src.contracts.generate_contract import generate_html_contract
|
||||
from src.contracts.generate_contract import generate_html_contract, generate_recap
|
||||
from src.auth.auth import get_current_user
|
||||
import src.models as models
|
||||
import src.messages as messages
|
||||
@@ -79,7 +79,8 @@ async def create_contract(
|
||||
occasionals = create_occasional_dict(occasional_contract_products)
|
||||
recurrents = list(map(lambda x: {"product": x.product, "quantity": x.quantity}, filter(lambda contract_product: contract_product.product.type == models.ProductType.RECCURENT, new_contract.products)))
|
||||
recurrent_price = compute_recurrent_prices(recurrents, len(new_contract.form.shipments))
|
||||
total_price = '{:10.2f}'.format(recurrent_price + compute_occasional_prices(occasionals))
|
||||
price = recurrent_price + compute_occasional_prices(occasionals)
|
||||
total_price = '{:10.2f}'.format(price)
|
||||
cheques = list(map(lambda x: {"name": x.name, "value": x.value}, new_contract.cheques))
|
||||
# TODO: send contract to referer
|
||||
|
||||
@@ -94,7 +95,7 @@ async def create_contract(
|
||||
)
|
||||
pdf_file = io.BytesIO(pdf_bytes)
|
||||
contract_id = f'{new_contract.firstname}_{new_contract.lastname}_{new_contract.form.productor.type}_{new_contract.form.season}'
|
||||
service.add_contract_file(session, new_contract.id, pdf_bytes)
|
||||
service.add_contract_file(session, new_contract.id, pdf_bytes, price)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise HTTPException(status_code=400, detail=messages.pdferror)
|
||||
@@ -112,7 +113,7 @@ def get_contracts(
|
||||
session: Session = Depends(get_session),
|
||||
user: models.User = Depends(get_current_user)
|
||||
):
|
||||
return service.get_all(session, forms)
|
||||
return service.get_all(session, user, forms)
|
||||
|
||||
@router.get('/{id}/file')
|
||||
def get_contract_file(
|
||||
@@ -120,6 +121,8 @@ def get_contract_file(
|
||||
session: Session = Depends(get_session),
|
||||
user: models.User = Depends(get_current_user)
|
||||
):
|
||||
if not service.is_allowed(session, user, id):
|
||||
raise HTTPException(status_code=403, detail=messages.notallowed)
|
||||
contract = service.get_one(session, id)
|
||||
if contract is None:
|
||||
raise HTTPException(status_code=404, detail=messages.notfound)
|
||||
@@ -138,8 +141,10 @@ def get_contract_files(
|
||||
session: Session = Depends(get_session),
|
||||
user: models.User = Depends(get_current_user)
|
||||
):
|
||||
if not form_service.is_allowed(session, user, form_id):
|
||||
raise HTTPException(status_code=403, detail=messages.notallowed)
|
||||
form = form_service.get_one(session, form_id=form_id)
|
||||
contracts = service.get_all(session, [form.name])
|
||||
contracts = service.get_all(session, user, forms=[form.name])
|
||||
zipped_contracts = io.BytesIO()
|
||||
with zipfile.ZipFile(zipped_contracts, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
|
||||
for contract in contracts:
|
||||
@@ -155,9 +160,29 @@ def get_contract_files(
|
||||
}
|
||||
)
|
||||
|
||||
@router.get('/{form_id}/recap')
|
||||
def get_contract_recap(
|
||||
form_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
user: models.User = Depends(get_current_user)
|
||||
):
|
||||
if not form_service.is_allowed(session, user, form_id):
|
||||
raise HTTPException(status_code=403, detail=messages.notallowed)
|
||||
form = form_service.get_one(session, form_id=form_id)
|
||||
contracts = service.get_all(session, user, forms=[form.name])
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(generate_recap(contracts, form)),
|
||||
media_type='application/zip',
|
||||
headers={
|
||||
'Content-Disposition': f'attachment; filename=filename.ods'
|
||||
}
|
||||
)
|
||||
|
||||
@router.get('/{id}', response_model=models.ContractPublic)
|
||||
def get_contract(id: int, session: Session = Depends(get_session), user: models.User = Depends(get_current_user)):
|
||||
if not service.is_allowed(session, user, id):
|
||||
raise HTTPException(status_code=403, detail=messages.notallowed)
|
||||
result = service.get_one(session, id)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail=messages.notfound)
|
||||
@@ -165,6 +190,8 @@ def get_contract(id: int, session: Session = Depends(get_session), user: models.
|
||||
|
||||
@router.delete('/{id}', response_model=models.ContractPublic)
|
||||
def delete_contract(id: int, session: Session = Depends(get_session), user: models.User = Depends(get_current_user)):
|
||||
if not service.is_allowed(session, user, id):
|
||||
raise HTTPException(status_code=403, detail=messages.notallowed)
|
||||
result = service.delete_one(session, id)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail=messages.notfound)
|
||||
|
||||
@@ -3,6 +3,7 @@ import jinja2
|
||||
import src.models as models
|
||||
import html
|
||||
from weasyprint import HTML
|
||||
import io
|
||||
|
||||
def generate_html_contract(
|
||||
contract: models.Contract,
|
||||
@@ -57,4 +58,25 @@ def generate_html_contract(
|
||||
return HTML(
|
||||
string=output_text,
|
||||
base_url=template_dir
|
||||
).write_pdf()
|
||||
).write_pdf()
|
||||
|
||||
from odfdo import Document, Table, Row, Cell
|
||||
|
||||
def generate_recap(
|
||||
contracts: list[models.Contract],
|
||||
form: models.Form,
|
||||
):
|
||||
data = [
|
||||
["nom", "email"],
|
||||
]
|
||||
doc = Document("spreadsheet")
|
||||
sheet = Table(name="Recap")
|
||||
sheet.set_values(data)
|
||||
|
||||
doc.body.append(sheet)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
doc.save(buffer)
|
||||
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
@@ -3,14 +3,19 @@ import src.models as models
|
||||
|
||||
def get_all(
|
||||
session: Session,
|
||||
user: models.User,
|
||||
forms: list[str] = [],
|
||||
form_id: int | None = None,
|
||||
form_id: int | None = None,
|
||||
) -> list[models.ContractPublic]:
|
||||
statement = select(models.Contract)
|
||||
if form_id:
|
||||
statement = statement.join(models.Form).where(models.Form.id == form_id)
|
||||
statement = select(models.Contract)\
|
||||
.join(models.Form, models.Contract.form_id == models.Form.id)\
|
||||
.join(models.Productor, models.Form.productor_id == models.Productor.id)\
|
||||
.where(models.Productor.type.in_([r.name for r in user.roles]))\
|
||||
.distinct()
|
||||
if len(forms) > 0:
|
||||
statement = statement.join(models.Form).where(models.Form.name.in_(forms))
|
||||
statement = statement.where(models.Form.name.in_(forms))
|
||||
if form_id:
|
||||
statement = statement.where(models.Form.id == form_id)
|
||||
return session.exec(statement.order_by(models.Contract.id)).all()
|
||||
|
||||
def get_one(session: Session, contract_id: int) -> models.ContractPublic:
|
||||
@@ -42,11 +47,11 @@ def create_one(session: Session, contract: models.ContractCreate) -> models.Cont
|
||||
session.refresh(new_contract)
|
||||
return new_contract
|
||||
|
||||
def add_contract_file(session: Session, id: int, file: bytes):
|
||||
def add_contract_file(session: Session, id: int, file: bytes, price: float):
|
||||
statement = select(models.Contract).where(models.Contract.id == id)
|
||||
result = session.exec(statement)
|
||||
contract = result.first()
|
||||
|
||||
contract.total_price = price
|
||||
contract.file = file
|
||||
session.add(contract)
|
||||
session.commit()
|
||||
@@ -77,3 +82,12 @@ def delete_one(session: Session, id: int) -> models.ContractPublic:
|
||||
session.delete(contract)
|
||||
session.commit()
|
||||
return result
|
||||
|
||||
def is_allowed(session: Session, user: models.User, id: int) -> bool:
|
||||
statement = select(models.Contract)\
|
||||
.join(models.Form, models.Contract.form_id == models.Form.id)\
|
||||
.join(models.Productor, models.Form.productor_id == models.Productor.id)\
|
||||
.where(models.Contract.id == id)\
|
||||
.where(models.Productor.type.in_([r.name for r in user.roles]))\
|
||||
.distinct()
|
||||
return len(session.exec(statement).all()) > 0
|
||||
@@ -12,9 +12,10 @@ router = APIRouter(prefix='/forms')
|
||||
async def get_forms(
|
||||
seasons: list[str] = Query([]),
|
||||
productors: list[str] = Query([]),
|
||||
current_season: bool = False,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
return service.get_all(session, seasons, productors)
|
||||
return service.get_all(session, seasons, productors, current_season)
|
||||
|
||||
@router.get('/{id}', response_model=models.FormPublic)
|
||||
async def get_form(id: int, session: Session = Depends(get_session)):
|
||||
|
||||
@@ -1,16 +1,35 @@
|
||||
from sqlmodel import Session, select
|
||||
import src.models as models
|
||||
from sqlalchemy import func
|
||||
|
||||
def get_all(
|
||||
session: Session,
|
||||
seasons: list[str],
|
||||
productors: list[str]
|
||||
productors: list[str],
|
||||
current_season: bool,
|
||||
) -> list[models.FormPublic]:
|
||||
statement = select(models.Form)
|
||||
if len(seasons) > 0:
|
||||
statement = statement.where(models.Form.season.in_(seasons))
|
||||
if len(productors) > 0:
|
||||
statement = statement.join(models.Productor).where(models.Productor.name.in_(productors))
|
||||
if current_season:
|
||||
subquery = (
|
||||
select(
|
||||
models.Productor.type,
|
||||
func.max(models.Form.start).label("max_start")
|
||||
)
|
||||
.join(models.Form)\
|
||||
.group_by(models.Productor.type)\
|
||||
.subquery()
|
||||
)
|
||||
statement = select(models.Form)\
|
||||
.join(models.Productor)\
|
||||
.join(subquery,
|
||||
(models.Productor.type == subquery.c.type) &
|
||||
(models.Form.start == subquery.c.max_start)
|
||||
)
|
||||
return session.exec(statement.order_by(models.Form.name)).all()
|
||||
return session.exec(statement.order_by(models.Form.name)).all()
|
||||
|
||||
def get_one(session: Session, form_id: int) -> models.FormPublic:
|
||||
@@ -48,3 +67,11 @@ def delete_one(session: Session, id: int) -> models.FormPublic:
|
||||
session.delete(form)
|
||||
session.commit()
|
||||
return result
|
||||
|
||||
def is_allowed(session: Session, user: models.User, id: int) -> bool:
|
||||
statement = select(models.Form)\
|
||||
.join(models.Productor, models.Form.productor_id == models.Productor.id)\
|
||||
.where(models.Form.id == id)\
|
||||
.where(models.Productor.type.in_([r.name for r in user.roles]))\
|
||||
.distinct()
|
||||
return len(session.exec(statement).all()) > 0
|
||||
@@ -6,4 +6,5 @@ notauthenticated = "Not authenticated"
|
||||
usernotfound = "User not found"
|
||||
userloggedout = "User logged out"
|
||||
failtogettoken = "Failed to get token"
|
||||
unauthorized = "Unauthorized"
|
||||
unauthorized = "Unauthorized"
|
||||
notallowed = "Not Allowed"
|
||||
@@ -222,6 +222,7 @@ class Contract(ContractBase, table=True):
|
||||
cascade_delete=True
|
||||
)
|
||||
file: bytes = Field(sa_column=Column(LargeBinary))
|
||||
total_price: float | None
|
||||
|
||||
class ContractCreate(ContractBase):
|
||||
products: list["ContractProductCreate"] = []
|
||||
@@ -235,6 +236,7 @@ class ContractPublic(ContractBase):
|
||||
id: int
|
||||
products: list["ContractProduct"] = []
|
||||
form: Form
|
||||
total_price: float | None
|
||||
# file: bytes
|
||||
|
||||
class ContractProductBase(SQLModel):
|
||||
|
||||
@@ -16,7 +16,7 @@ def get_all(
|
||||
def get_one(session: Session, user_id: int) -> models.UserPublic:
|
||||
return session.get(models.User, user_id)
|
||||
|
||||
def get_or_create_roles(session: Session, role_names) -> list[models.ContractType]:
|
||||
def get_or_create_roles(session: Session, role_names: list[str]) -> list[models.ContractType]:
|
||||
statement = select(models.ContractType).where(models.ContractType.name.in_(role_names))
|
||||
existing = session.exec(statement).all()
|
||||
existing_roles = {role.name for role in existing}
|
||||
@@ -37,6 +37,9 @@ def get_or_create_user(session: Session, user_create: models.UserCreate):
|
||||
statement = select(models.User).where(models.User.email == user_create.email)
|
||||
user = session.exec(statement).first()
|
||||
if user:
|
||||
user_role_names = [r.name for r in user.roles]
|
||||
if user_role_names != user_create.role_names or user.name != user_create.name:
|
||||
user = update_one(session, user.id, user_create)
|
||||
return user
|
||||
user = create_one(session, user_create)
|
||||
return user
|
||||
@@ -46,7 +49,6 @@ def get_roles(session: Session):
|
||||
return session.exec(statement.order_by(models.ContractType.name)).all()
|
||||
|
||||
def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
|
||||
print("USER CREATE", user)
|
||||
new_user = models.User(
|
||||
name=user.name,
|
||||
email=user.email
|
||||
@@ -60,15 +62,20 @@ def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
|
||||
session.refresh(new_user)
|
||||
return new_user
|
||||
|
||||
def update_one(session: Session, id: int, user: models.UserUpdate) -> models.UserPublic:
|
||||
def update_one(session: Session, id: int, user: models.UserCreate) -> models.UserPublic:
|
||||
statement = select(models.User).where(models.User.id == id)
|
||||
result = session.exec(statement)
|
||||
new_user = result.first()
|
||||
if not new_user:
|
||||
return None
|
||||
user_updates = user.model_dump(exclude_unset=True)
|
||||
|
||||
user_updates = user.model_dump(exclude="role_names")
|
||||
for key, value in user_updates.items():
|
||||
setattr(new_user, key, value)
|
||||
|
||||
roles = get_or_create_roles(session, user.role_names)
|
||||
new_user.roles = roles
|
||||
|
||||
session.add(new_user)
|
||||
session.commit()
|
||||
session.refresh(new_user)
|
||||
@@ -83,4 +90,4 @@ def delete_one(session: Session, id: int) -> models.UserPublic:
|
||||
result = models.UserPublic.model_validate(user)
|
||||
session.delete(user)
|
||||
session.commit()
|
||||
return result
|
||||
return result
|
||||
Reference in New Issue
Block a user