Compare commits

...

2 Commits

Author SHA1 Message Date
Julien Aldon
8352097ffb fix pylint errors
Some checks failed
Deploy Amap / deploy (push) Failing after 16s
2026-03-03 11:08:08 +01:00
Julien Aldon
0e48d1bbaa fix test format 2026-03-02 16:33:52 +01:00
60 changed files with 2040 additions and 1049 deletions

View File

@@ -0,0 +1,26 @@
default_language_version:
python: python3.13
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-added-large-files
- id: trailing-whitespace
- id: check-ast
- id: check-builtin-literals
- id: check-docstring-first
- id: check-yaml
- id: check-toml
- id: mixed-line-ending
- id: end-of-file-fixer
- repo: local
hooks:
- id: check-pylint
name: check-pylint
entry: pylint -d R0801,R0903,W0511,W0603,C0103,R0902
language: system
types: [python]
pass_filenames: false
args:
- backend

View File

@@ -35,6 +35,12 @@ hatch run pytest
hatch run pytest --cov=src -vv hatch run pytest --cov=src -vv
``` ```
## Autoformat
```console
find -type f -name '*.py' ! -path 'alembic/*' -exec autopep8 --in-place --aggressive --aggressive '{}' \;
pylint -d R0801,R0903,W0511,W0603,C0103,R0902 .
```
## License ## License
`backend` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license. `backend` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license.

View File

@@ -25,7 +25,8 @@ target_metadata = SQLModel.metadata
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,
# can be acquired: # can be acquired:
config.set_main_option("sqlalchemy.url", f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}') config.set_main_option(
"sqlalchemy.url", f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}')
# ... etc. # ... etc.

View File

@@ -22,7 +22,12 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
"""Upgrade schema.""" """Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.add_column('paymentmethod', sa.Column('max', sa.Integer(), nullable=True)) op.add_column(
'paymentmethod',
sa.Column(
'max',
sa.Integer(),
nullable=True))
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -1,7 +1,7 @@
"""Initial repository """Initial repository
Revision ID: c0b1073a8394 Revision ID: c0b1073a8394
Revises: Revises:
Create Date: 2026-02-20 00:09:35.920486 Create Date: 2026-02-20 00:09:35.920486
""" """
@@ -22,117 +22,121 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
"""Upgrade schema.""" """Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table('contracttype', op.create_table(
sa.Column('id', sa.Integer(), nullable=False), 'contracttype',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
sa.PrimaryKeyConstraint('id') 'id',
) sa.Integer(),
op.create_table('productor', nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
sa.Column('address', sqlmodel.sql.sqltypes.AutoString(), nullable=False), 'name',
sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sqlmodel.sql.sqltypes.AutoString(),
sa.Column('id', sa.Integer(), nullable=False), nullable=False),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id'))
) op.create_table(
'productor', sa.Column(
'name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
'address', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
'type', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
'id', sa.Integer(), nullable=False), sa.PrimaryKeyConstraint('id'))
op.create_table('template', op.create_table('template',
sa.Column('id', sa.Integer(), nullable=False), sa.Column('id', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id')
) )
op.create_table('user', op.create_table(
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), 'user', sa.Column(
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False), 'name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
sa.Column('id', sa.Integer(), nullable=False), 'email', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
sa.PrimaryKeyConstraint('id') 'id', sa.Integer(), nullable=False), sa.PrimaryKeyConstraint('id'))
)
op.create_table('form', op.create_table('form',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('productor_id', sa.Integer(), nullable=True), sa.Column('productor_id', sa.Integer(), nullable=True),
sa.Column('referer_id', sa.Integer(), nullable=True), sa.Column('referer_id', sa.Integer(), nullable=True),
sa.Column('season', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('season', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('start', sa.Date(), nullable=False), sa.Column('start', sa.Date(), nullable=False),
sa.Column('end', sa.Date(), nullable=False), sa.Column('end', sa.Date(), nullable=False),
sa.Column('minimum_shipment_value', sa.Float(), nullable=True), sa.Column('minimum_shipment_value', sa.Float(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False), sa.Column('id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ), sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ),
sa.ForeignKeyConstraint(['referer_id'], ['user.id'], ), sa.ForeignKeyConstraint(['referer_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id')
) )
op.create_table('paymentmethod', op.create_table('paymentmethod',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('details', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('details', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False), sa.Column('id', sa.Integer(), nullable=False),
sa.Column('productor_id', sa.Integer(), nullable=False), sa.Column('productor_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ondelete='CASCADE'), sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id')
) )
op.create_table('product', op.create_table('product',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('unit', sa.Enum('GRAMS', 'KILO', 'PIECE', name='unit'), nullable=False), sa.Column('unit', sa.Enum('GRAMS', 'KILO', 'PIECE', name='unit'), nullable=False),
sa.Column('price', sa.Float(), nullable=True), sa.Column('price', sa.Float(), nullable=True),
sa.Column('price_kg', sa.Float(), nullable=True), sa.Column('price_kg', sa.Float(), nullable=True),
sa.Column('quantity', sa.Float(), nullable=True), sa.Column('quantity', sa.Float(), nullable=True),
sa.Column('quantity_unit', sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column('quantity_unit', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('type', sa.Enum('OCCASIONAL', 'RECCURENT', name='producttype'), nullable=False), sa.Column('type', sa.Enum('OCCASIONAL', 'RECCURENT', name='producttype'), nullable=False),
sa.Column('productor_id', sa.Integer(), nullable=True), sa.Column('productor_id', sa.Integer(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False), sa.Column('id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ), sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id')
) )
op.create_table('usercontracttypelink', op.create_table(
sa.Column('user_id', sa.Integer(), nullable=False), 'usercontracttypelink', sa.Column(
sa.Column('contract_type_id', sa.Integer(), nullable=False), 'user_id', sa.Integer(), nullable=False), sa.Column(
sa.ForeignKeyConstraint(['contract_type_id'], ['contracttype.id'], ), 'contract_type_id', sa.Integer(), nullable=False), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ), ['contract_type_id'], ['contracttype.id'], ), sa.ForeignKeyConstraint(
sa.PrimaryKeyConstraint('user_id', 'contract_type_id') ['user_id'], ['user.id'], ), sa.PrimaryKeyConstraint(
) 'user_id', 'contract_type_id'))
op.create_table('contract', op.create_table('contract',
sa.Column('firstname', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('firstname', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('lastname', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('lastname', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('phone', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('phone', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('payment_method', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('payment_method', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('cheque_quantity', sa.Integer(), nullable=False), sa.Column('cheque_quantity', sa.Integer(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False), sa.Column('id', sa.Integer(), nullable=False),
sa.Column('form_id', sa.Integer(), nullable=False), sa.Column('form_id', sa.Integer(), nullable=False),
sa.Column('file', sa.LargeBinary(), nullable=True), sa.Column('file', sa.LargeBinary(), nullable=True),
sa.Column('total_price', sa.Float(), nullable=True), sa.Column('total_price', sa.Float(), nullable=True),
sa.ForeignKeyConstraint(['form_id'], ['form.id'], ondelete='CASCADE'), sa.ForeignKeyConstraint(['form_id'], ['form.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id')
) )
op.create_table('shipment', op.create_table('shipment',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('date', sa.Date(), nullable=False), sa.Column('date', sa.Date(), nullable=False),
sa.Column('form_id', sa.Integer(), nullable=True), sa.Column('form_id', sa.Integer(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False), sa.Column('id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['form_id'], ['form.id'], ondelete='CASCADE'), sa.ForeignKeyConstraint(['form_id'], ['form.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id')
) )
op.create_table('cheque', op.create_table('cheque',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('value', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('value', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False), sa.Column('id', sa.Integer(), nullable=False),
sa.Column('contract_id', sa.Integer(), nullable=False), sa.Column('contract_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['contract_id'], ['contract.id'], ondelete='CASCADE'), sa.ForeignKeyConstraint(['contract_id'], ['contract.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id')
) )
op.create_table('contractproduct', op.create_table('contractproduct',
sa.Column('product_id', sa.Integer(), nullable=False), sa.Column('product_id', sa.Integer(), nullable=False),
sa.Column('shipment_id', sa.Integer(), nullable=True), sa.Column('shipment_id', sa.Integer(), nullable=True),
sa.Column('quantity', sa.Float(), nullable=False), sa.Column('quantity', sa.Float(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False), sa.Column('id', sa.Integer(), nullable=False),
sa.Column('contract_id', sa.Integer(), nullable=False), sa.Column('contract_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['contract_id'], ['contract.id'], ondelete='CASCADE'), sa.ForeignKeyConstraint(['contract_id'], ['contract.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['product_id'], ['product.id'], ondelete='CASCADE'), sa.ForeignKeyConstraint(['product_id'], ['product.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['shipment_id'], ['shipment.id'], ondelete='CASCADE'), sa.ForeignKeyConstraint(['shipment_id'], ['shipment.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id')
) )
op.create_table('shipmentproductlink', op.create_table('shipmentproductlink',
sa.Column('shipment_id', sa.Integer(), nullable=False), sa.Column('shipment_id', sa.Integer(), nullable=False),
sa.Column('product_id', sa.Integer(), nullable=False), sa.Column('product_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['product_id'], ['product.id'], ), sa.ForeignKeyConstraint(['product_id'], ['product.id'], ),
sa.ForeignKeyConstraint(['shipment_id'], ['shipment.id'], ), sa.ForeignKeyConstraint(['shipment_id'], ['shipment.id'], ),
sa.PrimaryKeyConstraint('shipment_id', 'product_id') sa.PrimaryKeyConstraint('shipment_id', 'product_id')
) )
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -22,7 +22,14 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
"""Upgrade schema.""" """Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.add_column('form', sa.Column('visible', sa.Boolean(), nullable=False, default=False, server_default="False")) op.add_column(
'form',
sa.Column(
'visible',
sa.Boolean(),
nullable=False,
default=False,
server_default="False"))
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -34,6 +34,9 @@ dependencies = [
"pytest", "pytest",
"pytest-cov", "pytest-cov",
"pytest-mock", "pytest-mock",
"autopep8",
"prek",
"pylint",
] ]
[project.urls] [project.urls]

View File

@@ -0,0 +1,84 @@
alembic==1.18.4
annotated-doc==0.0.4
annotated-types==0.7.0
anyio==4.12.1
astroid==4.0.4
autopep8==2.3.2
brotli==1.2.0
certifi==2026.2.25
cffi==2.0.0
charset-normalizer==3.4.4
click==8.3.1
coverage==7.13.4
cryptography==46.0.5
cssselect2==0.9.0
dill==0.4.1
dnspython==2.8.0
email-validator==2.3.0
fastapi==0.135.1
fastapi-cli==0.0.24
fastapi-cloud-cli==0.14.0
fastar==0.8.0
fonttools==4.61.1
greenlet==3.3.2
h11==0.16.0
httpcore==1.0.9
httptools==0.7.1
httpx==0.28.1
idna==3.11
iniconfig==2.3.0
isort==8.0.1
Jinja2==3.1.6
lxml==6.0.2
Mako==1.3.10
markdown-it-py==4.0.0
MarkupSafe==3.0.3
mccabe==0.7.0
mdurl==0.1.2
odfdo==3.21.0
packaging==26.0
pillow==12.1.1
platformdirs==4.9.2
pluggy==1.6.0
prek==0.3.4
psycopg2-binary==2.9.11
pycodestyle==2.14.0
pycparser==3.0
pydantic==2.12.5
pydantic-extra-types==2.11.0
pydantic-settings==2.13.1
pydantic_core==2.41.5
pydyf==0.12.1
Pygments==2.19.2
PyJWT==2.11.0
pylint==4.0.5
pyphen==0.17.2
pytest==9.0.2
pytest-cov==7.0.0
pytest-mock==3.15.1
python-dotenv==1.2.2
python-multipart==0.0.22
PyYAML==6.0.3
requests==2.32.5
rich==14.3.3
rich-toolkit==0.19.7
rignore==0.7.6
sentry-sdk==2.53.0
shellingham==1.5.4
SQLAlchemy==2.0.47
sqlmodel==0.0.37
starlette==0.52.1
tinycss2==1.5.1
tinyhtml5==2.0.0
tomlkit==0.14.0
typer==0.24.1
typing-inspection==0.4.2
typing_extensions==4.15.0
urllib3==2.6.3
uvicorn==0.41.0
uvloop==0.22.1
watchfiles==1.1.1
weasyprint==68.1
webencodings==0.5.1
websockets==16.0
zopfli==0.4.1

View File

@@ -1,3 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr> # SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT

View File

@@ -21,6 +21,7 @@ router = APIRouter(prefix='/auth')
jwk_client = PyJWKClient(JWKS_URL) jwk_client = PyJWKClient(JWKS_URL)
security = HTTPBearer() security = HTTPBearer()
@router.get('/logout') @router.get('/logout')
def logout(): def logout():
params = { params = {
@@ -59,9 +60,11 @@ def login():
'redirect_uri': settings.keycloak_redirect_uri, 'redirect_uri': settings.keycloak_redirect_uri,
'state': state, 'state': state,
} }
request_url = requests.Request('GET', AUTH_URL, params=params).prepare().url request_url = requests.Request(
'GET', AUTH_URL, params=params).prepare().url
return RedirectResponse(request_url) return RedirectResponse(request_url)
@router.get('/callback') @router.get('/callback')
def callback(code: str, session: Session = Depends(get_session)): def callback(code: str, session: Session = Depends(get_session)):
data = { data = {
@@ -82,10 +85,12 @@ def callback(code: str, session: Session = Depends(get_session)):
) )
token_data = response.json() token_data = response.json()
id_token = token_data['id_token'] id_token = token_data['id_token']
decoded_token = jwt.decode(id_token, options={'verify_signature': False}) decoded_token = jwt.decode(id_token, options={'verify_signature': False})
decoded_access_token = jwt.decode(token_data['access_token'], options={'verify_signature': False}) decoded_access_token = jwt.decode(
token_data['access_token'], options={
'verify_signature': False})
resource_access = decoded_access_token.get('resource_access') resource_access = decoded_access_token.get('resource_access')
if not resource_access: if not resource_access:
data = { data = {
@@ -141,6 +146,7 @@ def callback(code: str, session: Session = Depends(get_session)):
return response return response
def verify_token(token: str): def verify_token(token: str):
try: try:
signing_key = jwk_client.get_signing_key_from_jwt(token) signing_key = jwk_client.get_signing_key_from_jwt(token)
@@ -154,28 +160,37 @@ def verify_token(token: str):
) )
return decoded return decoded
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail=messages.Messages.tokenexipired) raise HTTPException(status_code=401,
detail=messages.Messages.tokenexipired)
except jwt.InvalidTokenError: except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail=messages.Messages.invalidtoken) raise HTTPException(
status_code=401,
detail=messages.Messages.invalidtoken)
def get_current_user(request: Request, session: Session = Depends(get_session)): def get_current_user(
request: Request,
session: Session = Depends(get_session)):
access_token = request.cookies.get('access_token') access_token = request.cookies.get('access_token')
if not access_token: if not access_token:
raise HTTPException(status_code=401, detail=messages.Messages.notauthenticated) raise HTTPException(status_code=401,
detail=messages.Messages.notauthenticated)
payload = verify_token(access_token) payload = verify_token(access_token)
if not payload: if not payload:
raise HTTPException(status_code=401, detail='aze') raise HTTPException(status_code=401, detail='aze')
email = payload.get('email') email = payload.get('email')
if not email: if not email:
raise HTTPException(status_code=401, detail=messages.Messages.notauthenticated) raise HTTPException(status_code=401,
detail=messages.Messages.notauthenticated)
user = session.exec(select(User).where(User.email == email)).first() user = session.exec(select(User).where(User.email == email)).first()
if not user: if not user:
raise HTTPException(status_code=401, detail=messages.Messages.not_found('user')) raise HTTPException(status_code=401,
detail=messages.Messages.not_found('user'))
return user return user
@router.post('/refresh') @router.post('/refresh')
def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None): def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
refresh = refresh_token refresh = refresh_token
@@ -223,6 +238,7 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
) )
return response return response
@router.get('/user/me') @router.get('/user/me')
def me(user: UserPublic = Depends(get_current_user)): def me(user: UserPublic = Depends(get_current_user)):
if not user: if not user:
@@ -235,4 +251,4 @@ def me(user: UserPublic = Depends(get_current_user)):
'id': user.id, 'id': user.id,
'roles': [role.name for role in user.roles] 'roles': [role.name for role in user.roles]
} }
} }

View File

@@ -1,18 +1,27 @@
from fastapi import APIRouter, Depends, HTTPException, Query """Router for contract resource"""
from fastapi.responses import StreamingResponse
from src.database import get_session
from sqlmodel import Session
from src.contracts.generate_contract import generate_html_contract, generate_recap
from src.auth.auth import get_current_user
import src.models as models
import src.messages as messages
import src.contracts.service as service
import src.forms.service as form_service
import io import io
import zipfile import zipfile
import src.contracts.service as service
import src.forms.service as form_service
import src.messages as messages
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user
from src.contracts.generate_contract import (generate_html_contract,
generate_recap)
from src.database import get_session
router = APIRouter(prefix='/contracts') router = APIRouter(prefix='/contracts')
def compute_recurrent_prices(products_quantities: list[dict], nb_shipment: int):
def compute_recurrent_prices(
products_quantities: list[dict],
nb_shipment: int
):
"""Compute price for recurrent products"""
result = 0 result = 0
for product_quantity in products_quantities: for product_quantity in products_quantities:
product = product_quantity['product'] product = product_quantity['product']
@@ -20,30 +29,50 @@ def compute_recurrent_prices(products_quantities: list[dict], nb_shipment: int):
result += compute_product_price(product, quantity, nb_shipment) result += compute_product_price(product, quantity, nb_shipment)
return result return result
def compute_occasional_prices(occasionals: list[dict]): def compute_occasional_prices(occasionals: list[dict]):
"""Compute prices for occassional products"""
result = 0 result = 0
for occasional in occasionals: for occasional in occasionals:
result += occasional['price'] result += occasional['price']
return result return result
def compute_product_price(product: models.Product, quantity: int, nb_shipment: int = 1):
product_quantity_unit = 1 if product.unit == models.Unit.KILO else 1000 def compute_product_price(
final_quantity = quantity if product.price else quantity / product_quantity_unit product: models.Product,
final_price = product.price if product.price else product.price_kg quantity: int,
return final_price * final_quantity * nb_shipment nb_shipment: int = 1
):
"""Compute price for a product"""
product_quantity_unit = (
1 if product.unit == models.Unit.KILO else 1000
)
final_quantity = (
quantity if product.price else quantity / product_quantity_unit
)
final_price = (
product.price if product.price else product.price_kg
)
return final_price * final_quantity * nb_shipment
def find_dict_in_list(lst, key, value): def find_dict_in_list(lst, key, value):
"""Find the index of a dictionnary in a list of dictionnaries given a key
and a value.
"""
for i, dic in enumerate(lst): for i, dic in enumerate(lst):
if dic[key].id == value: if dic[key].id == value:
return i return i
return -1 return -1
def create_occasional_dict(contract_products: list[models.ContractProduct]): def create_occasional_dict(contract_products: list[models.ContractProduct]):
"""Create a dictionnary of occasional products"""
result = [] result = []
for contract_product in contract_products: for contract_product in contract_products:
existing_id = find_dict_in_list( existing_id = find_dict_in_list(
result, result,
'shipment', 'shipment',
contract_product.shipment.id contract_product.shipment.id
) )
if existing_id < 0: if existing_id < 0:
@@ -69,18 +98,46 @@ def create_occasional_dict(contract_products: list[models.ContractProduct]):
) )
return result return result
@router.post('') @router.post('')
async def create_contract( async def create_contract(
contract: models.ContractCreate, contract: models.ContractCreate,
session: Session = Depends(get_session), session: Session = Depends(get_session),
): ):
"""Create contract route"""
new_contract = service.create_one(session, contract) new_contract = service.create_one(session, contract)
occasional_contract_products = list(filter(lambda contract_product: contract_product.product.type == models.ProductType.OCCASIONAL, new_contract.products)) occasional_contract_products = list(
filter(
lambda contract_product: (
contract_product.product.type == models.ProductType.OCCASIONAL
),
new_contract.products
)
)
occasionals = create_occasional_dict(occasional_contract_products) occasionals = create_occasional_dict(occasional_contract_products)
recurrents = list(map(lambda x: {"product": x.product, "quantity": x.quantity}, filter(lambda contract_product: contract_product.product.type == models.ProductType.RECCURENT, new_contract.products))) recurrents = list(
recurrent_price = compute_recurrent_prices(recurrents, len(new_contract.form.shipments)) map(
lambda x: {'product': x.product, 'quantity': x.quantity},
filter(
lambda contract_product: (
contract_product.product.type ==
models.ProductType.RECCURENT
),
new_contract.products
)
)
)
recurrent_price = compute_recurrent_prices(
recurrents,
len(new_contract.form.shipments)
)
price = recurrent_price + compute_occasional_prices(occasionals) price = recurrent_price + compute_occasional_prices(occasionals)
cheques = list(map(lambda x: {"name": x.name, "value": x.value}, new_contract.cheques)) cheques = list(
map(
lambda x: {'name': x.name, 'value': x.value},
new_contract.cheques
)
)
try: try:
pdf_bytes = generate_html_contract( pdf_bytes = generate_html_contract(
new_contract, new_contract,
@@ -91,43 +148,63 @@ async def create_contract(
'{:10.2f}'.format(price) '{:10.2f}'.format(price)
) )
pdf_file = io.BytesIO(pdf_bytes) pdf_file = io.BytesIO(pdf_bytes)
contract_id = f'{new_contract.firstname}_{new_contract.lastname}_{new_contract.form.productor.type}_{new_contract.form.season}' contract_id = (
f'{new_contract.firstname}_'
f'{new_contract.lastname}_'
f'{new_contract.form.productor.type}_'
f'{new_contract.form.season}'
)
service.add_contract_file(session, new_contract.id, pdf_bytes, price) service.add_contract_file(session, new_contract.id, pdf_bytes, price)
except Exception: except Exception as error:
raise HTTPException(status_code=400, detail=messages.pdferror) raise HTTPException(
status_code=400,
detail=messages.pdferror
) from error
return StreamingResponse( return StreamingResponse(
pdf_file, pdf_file,
media_type='application/pdf', media_type='application/pdf',
headers={ headers={
'Content-Disposition': f'attachment; filename=contract_{contract_id}.pdf' 'Content-Disposition': (
f'attachment; filename=contract_{contract_id}.pdf'
)
} }
) )
@router.get('/{form_id}/base') @router.get('/{form_id}/base')
async def get_base_contract_template( async def get_base_contract_template(
form_id: int, form_id: int,
session: Session = Depends(get_session), session: Session = Depends(get_session),
): ):
"""Get contract template route"""
form = form_service.get_one(session, form_id) form = form_service.get_one(session, form_id)
recurrents = list(map(lambda x: {"product": x, "quantity": None}, filter(lambda product: product.type == models.ProductType.RECCURENT, form.productor.products))) recurrents = [
{'product': product, 'quantity': None}
for product in form.productor.products
if product.type == models.ProductType.RECCURENT
]
occasionals = [{ occasionals = [{
'shipment': sh, 'shipment': sh,
'price': None, 'price': None,
'products': [{'product': pr, 'quantity': None} for pr in sh.products] 'products': [{'product': pr, 'quantity': None} for pr in sh.products]
} for sh in form.shipments] } for sh in form.shipments]
empty_contract = models.ContractPublic( empty_contract = models.ContractPublic(
firstname="", firstname='',
form=form, form=form,
lastname="", lastname='',
email="", email='',
phone="", phone='',
products=[], products=[],
payment_method="cheque", payment_method='cheque',
cheque_quantity=3, cheque_quantity=3,
total_price=0, total_price=0,
id=1 id=1
) )
cheques = [{"name": None, "value": None}, {"name": None, "value": None}, {"name": None, "value": None}] cheques = [
{'name': None, 'value': None},
{'name': None, 'value': None},
{'name': None, 'value': None}
]
try: try:
pdf_bytes = generate_html_contract( pdf_bytes = generate_html_contract(
empty_contract, empty_contract,
@@ -136,45 +213,68 @@ async def get_base_contract_template(
recurrents, recurrents,
) )
pdf_file = io.BytesIO(pdf_bytes) pdf_file = io.BytesIO(pdf_bytes)
contract_id = f'{empty_contract.form.productor.type}_{empty_contract.form.season}' contract_id = (
except Exception as e: f'{empty_contract.form.productor.type}_'
print(e) f'{empty_contract.form.season}'
raise HTTPException(status_code=400, detail=messages.pdferror) )
except Exception as error:
raise HTTPException(
status_code=400,
detail=messages.pdferror
) from error
return StreamingResponse( return StreamingResponse(
pdf_file, pdf_file,
media_type='application/pdf', media_type='application/pdf',
headers={ headers={
'Content-Disposition': f'attachment; filename=contract_{contract_id}.pdf' 'Content-Disposition': (
f'attachment; filename=contract_{contract_id}.pdf'
)
} }
) )
@router.get('', response_model=list[models.ContractPublic]) @router.get('', response_model=list[models.ContractPublic])
def get_contracts( def get_contracts(
forms: list[str] = Query([]), forms: list[str] = Query([]),
session: Session = Depends(get_session), session: Session = Depends(get_session),
user: models.User = Depends(get_current_user) user: models.User = Depends(get_current_user)
): ):
"""Get all contracts route"""
return service.get_all(session, user, forms) return service.get_all(session, user, forms)
@router.get('/{id}/file')
@router.get('/{_id}/file')
def get_contract_file( def get_contract_file(
id: int, _id: int,
session: Session = Depends(get_session), session: Session = Depends(get_session),
user: models.User = Depends(get_current_user) user: models.User = Depends(get_current_user)
): ):
if not service.is_allowed(session, user, id): """Get a contract file (in pdf) route"""
raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contract', 'get')) if not service.is_allowed(session, user, _id):
contract = service.get_one(session, id) raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract', 'get')
)
contract = service.get_one(session, _id)
if contract is None: if contract is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('contract')) raise HTTPException(
filename = f'{contract.form.name.replace(' ', '_')}_{contract.form.season}_{contract.firstname}-{contract.lastname}' status_code=404,
detail=messages.Messages.not_found('contract')
)
filename = (
f'{contract.form.name.replace(' ', '_')}_'
f'{contract.form.season}_'
f'{contract.firstname}_'
f'{contract.lastname}'
)
return StreamingResponse( return StreamingResponse(
io.BytesIO(contract.file), io.BytesIO(contract.file),
media_type='application/pdf', media_type='application/pdf',
headers={ headers={
'Content-Disposition': f'attachment; filename={filename}.pdf' 'Content-Disposition': f'attachment; filename={filename}.pdf'
} }
) )
@router.get('/{form_id}/files') @router.get('/{form_id}/files')
def get_contract_files( def get_contract_files(
@@ -182,17 +282,30 @@ def get_contract_files(
session: Session = Depends(get_session), session: Session = Depends(get_session),
user: models.User = Depends(get_current_user) user: models.User = Depends(get_current_user)
): ):
"""Get all contract files for a given form"""
if not form_service.is_allowed(session, user, form_id): if not form_service.is_allowed(session, user, form_id):
raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contracts', 'get')) raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contracts', 'get')
)
form = form_service.get_one(session, form_id=form_id) form = form_service.get_one(session, form_id=form_id)
contracts = service.get_all(session, user, forms=[form.name]) contracts = service.get_all(session, user, forms=[form.name])
zipped_contracts = io.BytesIO() zipped_contracts = io.BytesIO()
with zipfile.ZipFile(zipped_contracts, "a", zipfile.ZIP_DEFLATED, False) as zip_file: with zipfile.ZipFile(
zipped_contracts,
'a',
zipfile.ZIP_DEFLATED,
False
) as zip_file:
for contract in contracts: for contract in contracts:
contract_filename = f'{contract.form.name.replace(' ', '_')}_{contract.form.season}_{contract.firstname}-{contract.lastname}.pdf' contract_filename = (
f'{contract.form.name.replace(' ', '_')}_'
f'{contract.form.season}_'
f'{contract.firstname}_'
f'{contract.lastname}'
)
zip_file.writestr(contract_filename, contract.file) zip_file.writestr(contract_filename, contract.file)
filename = f'{form.name.replace(' ', '_')}_{form.season}'
filename = f'{form.name.replace(" ", "_")}_{form.season}'
return StreamingResponse( return StreamingResponse(
io.BytesIO(zipped_contracts.getvalue()), io.BytesIO(zipped_contracts.getvalue()),
media_type='application/zip', media_type='application/zip',
@@ -201,39 +314,69 @@ def get_contract_files(
} }
) )
@router.get('/{form_id}/recap') @router.get('/{form_id}/recap')
def get_contract_recap( def get_contract_recap(
form_id: int, form_id: int,
session: Session = Depends(get_session), session: Session = Depends(get_session),
user: models.User = Depends(get_current_user) user: models.User = Depends(get_current_user)
): ):
"""Get a contract recap for a given form"""
if not form_service.is_allowed(session, user, form_id): if not form_service.is_allowed(session, user, form_id):
raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contract recap', 'get')) raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract recap', 'get')
)
form = form_service.get_one(session, form_id=form_id) form = form_service.get_one(session, form_id=form_id)
contracts = service.get_all(session, user, forms=[form.name]) contracts = service.get_all(session, user, forms=[form.name])
return StreamingResponse( return StreamingResponse(
io.BytesIO(generate_recap(contracts, form)), io.BytesIO(generate_recap(contracts, form)),
media_type='application/zip', media_type='application/zip',
headers={ headers={
'Content-Disposition': f'attachment; filename=filename.ods' 'Content-Disposition': (
'attachment; filename=filename.ods'
)
} }
) )
@router.get('/{id}', response_model=models.ContractPublic)
def get_contract(id: int, session: Session = Depends(get_session), user: models.User = Depends(get_current_user)): @router.get('/{_id}', response_model=models.ContractPublic)
if not service.is_allowed(session, user, id): def get_contract(
raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contract', 'get')) _id: int,
result = service.get_one(session, id) session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
"""Get a contract route"""
if not service.is_allowed(session, user, _id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract', 'get')
)
result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('contract')) raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('contract')
)
return result return result
@router.delete('/{id}', response_model=models.ContractPublic)
def delete_contract(id: int, session: Session = Depends(get_session), user: models.User = Depends(get_current_user)): @router.delete('/{_id}', response_model=models.ContractPublic)
if not service.is_allowed(session, user, id): def delete_contract(
raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contract', 'delete')) _id: int,
result = service.delete_one(session, id) session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
"""Delete contract route"""
if not service.is_allowed(session, user, _id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract', 'delete')
)
result = service.delete_one(session, _id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('contract')) raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('contract')
)
return result return result

View File

@@ -1,11 +1,13 @@
import html
import io
import pathlib
import jinja2 import jinja2
import src.models as models from odfdo import Cell, Document, Row, Table
import html from src import models
from weasyprint import HTML from weasyprint import HTML
import io
import pathlib
def generate_html_contract( def generate_html_contract(
contract: models.Contract, contract: models.Contract,
@@ -14,10 +16,11 @@ def generate_html_contract(
reccurents: list[dict], reccurents: list[dict],
recurrent_price: float | None = None, recurrent_price: float | None = None,
total_price: float | None = None total_price: float | None = None
): ):
template_dir = pathlib.Path("./src/contracts/templates").resolve() template_dir = pathlib.Path("./src/contracts/templates").resolve()
template_loader = jinja2.FileSystemLoader(searchpath=template_dir) template_loader = jinja2.FileSystemLoader(searchpath=template_dir)
template_env = jinja2.Environment(loader=template_loader, autoescape=jinja2.select_autoescape(["html", "xml"])) template_env = jinja2.Environment(
loader=template_loader, autoescape=jinja2.select_autoescape(["html", "xml"]))
template_file = "layout.html" template_file = "layout.html"
template = template_env.get_template(template_file) template = template_env.get_template(template_file)
output_text = template.render( output_text = template.render(
@@ -28,41 +31,36 @@ def generate_html_contract(
referer_email=contract.form.referer.email, referer_email=contract.form.referer.email,
productor_name=contract.form.productor.name, productor_name=contract.form.productor.name,
productor_address=contract.form.productor.address, productor_address=contract.form.productor.address,
payment_methods_map={"cheque": "Ordre du chèque", "transfer": "virements"}, payment_methods_map={
"cheque": "Ordre du chèque",
"transfer": "virements"},
productor_payment_methods=contract.form.productor.payment_methods, productor_payment_methods=contract.form.productor.payment_methods,
member_name=f'{html.escape(contract.firstname)} {html.escape(contract.lastname)}', member_name=f'{
member_email=html.escape(contract.email), html.escape(
member_phone=html.escape(contract.phone), contract.firstname)} {
html.escape(
contract.lastname)}',
member_email=html.escape(
contract.email),
member_phone=html.escape(
contract.phone),
contract_start_date=contract.form.start, contract_start_date=contract.form.start,
contract_end_date=contract.form.end, contract_end_date=contract.form.end,
occasionals=occasionals, occasionals=occasionals,
recurrents=reccurents, recurrents=reccurents,
recurrent_price=recurrent_price, recurrent_price=recurrent_price,
total_price=total_price, total_price=total_price,
contract_payment_method={"cheque": "chèque", "transfer": "virements"}[contract.payment_method], contract_payment_method={
cheques=cheques "cheque": "chèque",
) "transfer": "virements"}[
# options = { contract.payment_method],
# 'page-size': 'Letter', cheques=cheques)
# 'margin-top': '0.5in',
# 'margin-right': '0.5in',
# 'margin-bottom': '0.5in',
# 'margin-left': '0.5in',
# 'encoding': "UTF-8",
# 'print-media-type': True,
# "disable-javascript": True,
# "disable-external-links": True,
# 'enable-local-file-access': False,
# "disable-local-file-access": True,
# "no-images": True,
# }
return HTML( return HTML(
string=output_text, string=output_text,
base_url=template_dir, base_url=template_dir,
).write_pdf() ).write_pdf()
from odfdo import Document, Table, Row, Cell
def generate_recap( def generate_recap(
contracts: list[models.Contract], contracts: list[models.Contract],
@@ -76,9 +74,8 @@ def generate_recap(
sheet.set_values(data) sheet.set_values(data)
doc.body.append(sheet) doc.body.append(sheet)
buffer = io.BytesIO() buffer = io.BytesIO()
doc.save(buffer) doc.save(buffer)
return buffer.getvalue() return buffer.getvalue()

View File

@@ -1,28 +1,57 @@
"""Contract service responsible for read, create, update and delete contracts"""
from sqlalchemy.orm import selectinload
from sqlmodel import Session, select from sqlmodel import Session, select
import src.models as models from src import models
def get_all( def get_all(
session: Session, session: Session,
user: models.User, user: models.User,
forms: list[str] = [], forms: list[str] | None = None,
form_id: int | None = None, form_id: int | None = None,
) -> list[models.ContractPublic]: ) -> list[models.ContractPublic]:
statement = select(models.Contract)\ """Get all contracts"""
.join(models.Form, models.Contract.form_id == models.Form.id)\ statement = (
.join(models.Productor, models.Form.productor_id == models.Productor.id)\ select(models.Contract)
.where(models.Productor.type.in_([r.name for r in user.roles]))\ .join(
models.Form,
models.Contract.form_id == models.Form.id
)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id
)
.where(
models.Productor.type.in_(
[r.name for r in user.roles]
)
)
.distinct() .distinct()
if len(forms) > 0: )
if forms:
statement = statement.where(models.Form.name.in_(forms)) statement = statement.where(models.Form.name.in_(forms))
if form_id: if form_id:
statement = statement.where(models.Form.id == form_id) statement = statement.where(models.Form.id == form_id)
return session.exec(statement.order_by(models.Contract.id)).all() return session.exec(statement.order_by(models.Contract.id)).all()
def get_one(session: Session, contract_id: int) -> models.ContractPublic:
def get_one(
session: Session,
contract_id: int
) -> models.ContractPublic:
"""Get one contract"""
return session.get(models.Contract, contract_id) return session.get(models.Contract, contract_id)
def create_one(session: Session, contract: models.ContractCreate) -> models.ContractPublic:
contract_create = contract.model_dump(exclude_unset=True, exclude=["products", "cheques"]) def create_one(
session: Session,
contract: models.ContractCreate
) -> models.ContractPublic:
"""Create one contract"""
contract_create = contract.model_dump(
exclude_unset=True,
exclude=["products", "cheques"]
)
new_contract = models.Contract(**contract_create) new_contract = models.Contract(**contract_create)
new_contract.cheques = [ new_contract.cheques = [
@@ -45,10 +74,27 @@ def create_one(session: Session, contract: models.ContractCreate) -> models.Cont
session.add(new_contract) session.add(new_contract)
session.commit() session.commit()
session.refresh(new_contract) session.refresh(new_contract)
return new_contract
def add_contract_file(session: Session, id: int, file: bytes, price: float): statement = (
statement = select(models.Contract).where(models.Contract.id == id) select(models.Contract)
.where(models.Contract.id == new_contract.id)
.options(
selectinload(models.Contract.form)
.selectinload(models.Form.productor)
)
)
return session.exec(statement).one()
def add_contract_file(
session: Session,
_id: int,
file: bytes,
price: float
):
"""Add a file to an existing contract"""
statement = select(models.Contract).where(models.Contract.id == _id)
result = session.exec(statement) result = session.exec(statement)
contract = result.first() contract = result.first()
contract.total_price = price contract.total_price = price
@@ -58,8 +104,14 @@ def add_contract_file(session: Session, id: int, file: bytes, price: float):
session.refresh(contract) session.refresh(contract)
return contract return contract
def update_one(session: Session, id: int, contract: models.ContractUpdate) -> models.ContractPublic:
statement = select(models.Contract).where(models.Contract.id == id) def update_one(
session: Session,
_id: int,
contract: models.ContractUpdate
) -> models.ContractPublic:
"""Update one contract"""
statement = select(models.Contract).where(models.Contract.id == _id)
result = session.exec(statement) result = session.exec(statement)
new_contract = result.first() new_contract = result.first()
if not new_contract: if not new_contract:
@@ -72,8 +124,13 @@ def update_one(session: Session, id: int, contract: models.ContractUpdate) -> mo
session.refresh(new_contract) session.refresh(new_contract)
return new_contract return new_contract
def delete_one(session: Session, id: int) -> models.ContractPublic:
statement = select(models.Contract).where(models.Contract.id == id) def delete_one(
session: Session,
_id: int
) -> models.ContractPublic:
"""Delete one contract"""
statement = select(models.Contract).where(models.Contract.id == _id)
result = session.exec(statement) result = session.exec(statement)
contract = result.first() contract = result.first()
if not contract: if not contract:
@@ -83,11 +140,29 @@ def delete_one(session: Session, id: int) -> models.ContractPublic:
session.commit() session.commit()
return result return result
def is_allowed(session: Session, user: models.User, id: int) -> bool:
statement = select(models.Contract)\ def is_allowed(
.join(models.Form, models.Contract.form_id == models.Form.id)\ session: Session,
.join(models.Productor, models.Form.productor_id == models.Productor.id)\ user: models.User,
.where(models.Contract.id == id)\ _id: int
.where(models.Productor.type.in_([r.name for r in user.roles]))\ ) -> bool:
"""Determine if a user is allowed to access a contract by id"""
statement = (
select(models.Contract)
.join(
models.Form,
models.Contract.form_id == models.Form.id
)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id
)
.where(models.Contract.id == _id)
.where(
models.Productor.type.in_(
[r.name for r in user.roles]
)
)
.distinct() .distinct()
return len(session.exec(statement).all()) > 0 )
return len(session.exec(statement).all()) > 0

View File

@@ -1,11 +1,14 @@
from sqlmodel import create_engine, SQLModel, Session from sqlmodel import Session, SQLModel, create_engine
from src.settings import settings from src.settings import settings
engine = create_engine(f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}') engine = create_engine(
f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}')
def get_session(): def get_session():
with Session(engine) as session: with Session(engine) as session:
yield session yield session
def create_all_tables(): def create_all_tables():
SQLModel.metadata.create_all(engine) SQLModel.metadata.create_all(engine)

View File

@@ -1,17 +1,26 @@
"""Forms module exceptions"""
import logging
class FormServiceError(Exception): class FormServiceError(Exception):
"""Form service exception"""
def __init__(self, message: str): def __init__(self, message: str):
super().__init__(message) super().__init__(message)
logging.error('FormService : %s', message)
class UserNotFoundError(FormServiceError): class UserNotFoundError(FormServiceError):
pass pass
class ProductorNotFoundError(FormServiceError): class ProductorNotFoundError(FormServiceError):
pass pass
class FormNotFoundError(FormServiceError): class FormNotFoundError(FormServiceError):
pass pass
class FormCreateError(FormServiceError): class FormCreateError(FormServiceError):
def __init__(self, message: str, field: str | None = None): def __init__(self, message: str, field: str | None = None):
super().__init__(message) super().__init__(message)
self.field = field self.field = field

View File

@@ -1,14 +1,15 @@
from fastapi import APIRouter, HTTPException, Depends, Query
import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.forms.service as service
import src.forms.exceptions as exceptions import src.forms.exceptions as exceptions
import src.forms.service as service
import src.messages as messages
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/forms') router = APIRouter(prefix='/forms')
@router.get('', response_model=list[models.FormPublic]) @router.get('', response_model=list[models.FormPublic])
async def get_forms( async def get_forms(
seasons: list[str] = Query([]), seasons: list[str] = Query([]),
@@ -18,6 +19,7 @@ async def get_forms(
): ):
return service.get_all(session, seasons, productors, current_season) return service.get_all(session, seasons, productors, current_season)
@router.get('/referents', response_model=list[models.FormPublic]) @router.get('/referents', response_model=list[models.FormPublic])
async def get_forms_filtered( async def get_forms_filtered(
seasons: list[str] = Query([]), seasons: list[str] = Query([]),
@@ -28,53 +30,60 @@ async def get_forms_filtered(
): ):
return service.get_all(session, seasons, productors, current_season, user) return service.get_all(session, seasons, productors, current_season, user)
@router.get('/{id}', response_model=models.FormPublic)
async def get_form(id: int, session: Session = Depends(get_session)): @router.get('/{_id}', response_model=models.FormPublic)
result = service.get_one(session, id) async def get_form(_id: int, session: Session = Depends(get_session)):
result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('form')) raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('form')
)
return result return result
@router.post('', response_model=models.FormPublic) @router.post('', response_model=models.FormPublic)
async def create_form( async def create_form(
form: models.FormCreate, form: models.FormCreate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
try: try:
form = service.create_one(session, form) form = service.create_one(session, form)
except exceptions.ProductorNotFoundError as error: except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(status_code=404, detail=str(error)) from error
except exceptions.UserNotFoundError as error: except exceptions.UserNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(status_code=404, detail=str(error)) from error
except exceptions.FormCreateError as error: except exceptions.FormCreateError as error:
raise HTTPException(status_code=400, detail=str(error)) raise HTTPException(status_code=400, detail=str(error)) from error
return form return form
@router.put('/{id}', response_model=models.FormPublic)
@router.put('/{_id}', response_model=models.FormPublic)
async def update_form( async def update_form(
id: int, form: models.FormUpdate, _id: int, form: models.FormUpdate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
try: try:
result = service.update_one(session, id, form) result = service.update_one(session, _id, form)
except exceptions.FormNotFoundError as error: except exceptions.FormNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(status_code=404, detail=str(error)) from error
except exceptions.ProductorNotFoundError as error: except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(status_code=404, detail=str(error)) from error
except exceptions.UserNotFoundError as error: except exceptions.UserNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(status_code=404, detail=str(error)) from error
return result return result
@router.delete('/{id}', response_model=models.FormPublic)
@router.delete('/{_id}', response_model=models.FormPublic)
async def delete_form( async def delete_form(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
try: try:
result = service.delete_one(session, id) result = service.delete_one(session, _id)
except exceptions.FormNotFoundError as error: except exceptions.FormNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(status_code=404, detail=str(error)) from error
return result return result

View File

@@ -1,12 +1,12 @@
from sqlmodel import Session, select
from sqlalchemy import func
import src.models as models
import src.forms.exceptions as exceptions import src.forms.exceptions as exceptions
import src.messages as messages import src.messages as messages
from sqlalchemy import func
from sqlmodel import Session, select
from src import models
def get_all( def get_all(
session: Session, session: Session,
seasons: list[str], seasons: list[str],
productors: list[str], productors: list[str],
current_season: bool, current_season: bool,
@@ -14,45 +14,54 @@ def get_all(
) -> list[models.FormPublic]: ) -> list[models.FormPublic]:
statement = select(models.Form) statement = select(models.Form)
if user: if user:
statement = statement\ statement = statement .join(
.join(models.Productor, models.Form.productor_id == models.Productor.id)\ models.Productor,
.where(models.Productor.type.in_([r.name for r in user.roles]))\ models.Form.productor_id == models.Productor.id) .where(
.distinct() models.Productor.type.in_(
[
r.name for r in user.roles])) .distinct()
if len(seasons) > 0: if len(seasons) > 0:
statement = statement.where(models.Form.season.in_(seasons)) statement = statement.where(models.Form.season.in_(seasons))
if len(productors) > 0: if len(productors) > 0:
statement = statement.join(models.Productor).where(models.Productor.name.in_(productors)) statement = statement.join(
models.Productor).where(
models.Productor.name.in_(productors))
if not user: if not user:
statement = statement.where(models.Form.visible == True) statement = statement.where(models.Form.visible)
if current_season: if current_season:
subquery = ( subquery = (
select( select(
models.Productor.type, models.Productor.type,
func.max(models.Form.start).label("max_start") func.max(models.Form.start).label("max_start")
) )
.join(models.Form)\ .join(models.Form)
.group_by(models.Productor.type)\ .group_by(models.Productor.type)
.subquery() .subquery()
) )
statement = select(models.Form)\ statement = select(models.Form)\
.join(models.Productor)\ .join(models.Productor)\
.join(subquery, .join(subquery,
(models.Productor.type == subquery.c.type) & (models.Productor.type == subquery.c.type) &
(models.Form.start == subquery.c.max_start) (models.Form.start == subquery.c.max_start)
) )
if not user: if not user:
statement = statement.where(models.Form.visible == True) statement = statement.where(models.Form.visible)
return session.exec(statement.order_by(models.Form.name)).all() return session.exec(statement.order_by(models.Form.name)).all()
return session.exec(statement.order_by(models.Form.name)).all() return session.exec(statement.order_by(models.Form.name)).all()
def get_one(session: Session, form_id: int) -> models.FormPublic: def get_one(session: Session, form_id: int) -> models.FormPublic:
return session.get(models.Form, form_id) return session.get(models.Form, form_id)
def create_one(session: Session, form: models.FormCreate) -> models.FormPublic: def create_one(session: Session, form: models.FormCreate) -> models.FormPublic:
if not form: if not form:
raise exceptions.FormCreateError(messages.Messages.invalid_input('form', 'input cannot be None')) raise exceptions.FormCreateError(
messages.Messages.invalid_input(
'form', 'input cannot be None'))
if not session.get(models.Productor, form.productor_id): if not session.get(models.Productor, form.productor_id):
raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor')) raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
if not session.get(models.User, form.referer_id): if not session.get(models.User, form.referer_id):
raise exceptions.UserNotFoundError(messages.Messages.not_found('user')) raise exceptions.UserNotFoundError(messages.Messages.not_found('user'))
form_create = form.model_dump(exclude_unset=True) form_create = form.model_dump(exclude_unset=True)
@@ -62,14 +71,20 @@ def create_one(session: Session, form: models.FormCreate) -> models.FormPublic:
session.refresh(new_form) session.refresh(new_form)
return new_form return new_form
def update_one(session: Session, id: int, form: models.FormUpdate) -> models.FormPublic:
statement = select(models.Form).where(models.Form.id == id) def update_one(
session: Session,
_id: int,
form: models.FormUpdate) -> models.FormPublic:
statement = select(models.Form).where(models.Form.id == _id)
result = session.exec(statement) result = session.exec(statement)
new_form = result.first() new_form = result.first()
if not new_form: if not new_form:
raise exceptions.FormNotFoundError(messages.Messages.not_found('form')) raise exceptions.FormNotFoundError(messages.Messages.not_found('form'))
if form.productor_id and not session.get(models.Productor, form.productor_id): if form.productor_id and not session.get(
raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor')) models.Productor, form.productor_id):
raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
if form.referer_id and not session.get(models.User, form.referer_id): if form.referer_id and not session.get(models.User, form.referer_id):
raise exceptions.UserNotFoundError(messages.Messages.not_found('user')) raise exceptions.UserNotFoundError(messages.Messages.not_found('user'))
form_updates = form.model_dump(exclude_unset=True) form_updates = form.model_dump(exclude_unset=True)
@@ -80,8 +95,9 @@ def update_one(session: Session, id: int, form: models.FormUpdate) -> models.For
session.refresh(new_form) session.refresh(new_form)
return new_form return new_form
def delete_one(session: Session, id: int) -> models.FormPublic:
statement = select(models.Form).where(models.Form.id == id) def delete_one(session: Session, _id: int) -> models.FormPublic:
statement = select(models.Form).where(models.Form.id == _id)
result = session.exec(statement) result = session.exec(statement)
form = result.first() form = result.first()
if not form: if not form:
@@ -91,10 +107,19 @@ def delete_one(session: Session, id: int) -> models.FormPublic:
session.commit() session.commit()
return result return result
def is_allowed(session: Session, user: models.User, id: int) -> bool:
statement = select(models.Form)\ def is_allowed(session: Session, user: models.User, _id: int) -> bool:
.join(models.Productor, models.Form.productor_id == models.Productor.id)\ statement = (
.where(models.Form.id == id)\ select(models.Form)
.where(models.Productor.type.in_([r.name for r in user.roles]))\ .join(
models.Productor,
models.Form.productor_id == models.Productor.id)
.where(models.Form.id == _id)
.where(
models.Productor.type.in_(
[r.name for r in user.roles]
)
)
.distinct() .distinct()
return len(session.exec(statement).all()) > 0 )
return len(session.exec(statement).all()) > 0

View File

@@ -1,18 +1,15 @@
from sqlmodel import SQLModel
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from src.auth.auth import router as auth_router
from src.templates.templates import router as template_router
from src.contracts.contracts import router as contracts_router from src.contracts.contracts import router as contracts_router
from src.forms.forms import router as forms_router from src.forms.forms import router as forms_router
from src.productors.productors import router as productors_router from src.productors.productors import router as productors_router
from src.products.products import router as products_router from src.products.products import router as products_router
from src.users.users import router as users_router
from src.auth.auth import router as auth_router
from src.shipments.shipments import router as shipment_router
from src.settings import settings from src.settings import settings
from src.database import engine, create_all_tables from src.shipments.shipments import router as shipment_router
from src.templates.templates import router as template_router
from src.users.users import router as users_router
app = FastAPI() app = FastAPI()
@@ -34,4 +31,4 @@ app.include_router(productors_router, prefix="/api")
app.include_router(products_router, prefix="/api") app.include_router(products_router, prefix="/api")
app.include_router(users_router, prefix="/api") app.include_router(users_router, prefix="/api")
app.include_router(auth_router, prefix="/api") app.include_router(auth_router, prefix="/api")
app.include_router(shipment_router, prefix="/api") app.include_router(shipment_router, prefix="/api")

View File

@@ -1,19 +1,20 @@
pdferror = 'An error occured during PDF generation please contact administrator' pdferror = 'An error occured during PDF generation please contact administrator'
class Messages: class Messages:
unauthorized = 'User is Unauthorized' unauthorized = 'User is Unauthorized'
notauthenticated = 'User is not authenticated' notauthenticated = 'User is not authenticated'
tokenexipired = 'Token has expired' tokenexipired = 'Token has expired'
invalidtoken = 'Token is invalid' invalidtoken = 'Token is invalid'
@staticmethod @staticmethod
def not_found(resource: str) -> str: def not_found(resource: str) -> str:
return f'{resource.capitalize()} not found' return f'{resource.capitalize()} not found'
@staticmethod @staticmethod
def invalid_input(resource: str, reason: str = "") -> str: def invalid_input(resource: str, reason: str = "") -> str:
return f'Invalid {resource} input {':' if reason else ""} {reason}' return f'Invalid {resource} input {':' if reason else ""} {reason}'
@staticmethod @staticmethod
def not_allowed(resource: str, action: str) -> str: def not_allowed(resource: str, action: str) -> str:
return f'User is not allowed to {action} this {resource}' return f'User is not allowed to {action} this {resource}'

View File

@@ -1,99 +1,136 @@
from sqlmodel import Field, SQLModel, Relationship, Column, LargeBinary import datetime
from enum import StrEnum from enum import StrEnum
from typing import Optional from typing import Optional
import datetime
from sqlmodel import Column, Field, LargeBinary, Relationship, SQLModel
class ContractType(SQLModel, table=True): class ContractType(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(
default=None,
primary_key=True
)
name: str name: str
class UserContractTypeLink(SQLModel, table=True): class UserContractTypeLink(SQLModel, table=True):
user_id: int = Field(foreign_key="user.id", primary_key=True) user_id: int = Field(
contract_type_id: int = Field(foreign_key="contracttype.id", primary_key=True) foreign_key='user.id',
primary_key=True
)
contract_type_id: int = Field(
foreign_key='contracttype.id',
primary_key=True
)
class UserBase(SQLModel): class UserBase(SQLModel):
name: str name: str
email: str email: str
class UserPublic(UserBase): class UserPublic(UserBase):
id: int id: int
roles: list[ContractType] roles: list[ContractType]
class User(UserBase, table=True): class User(UserBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
roles: list[ContractType] = Relationship( roles: list[ContractType] = Relationship(
link_model=UserContractTypeLink link_model=UserContractTypeLink
) )
class UserUpdate(SQLModel): class UserUpdate(SQLModel):
name: str | None name: str | None
email: str | None email: str | None
role_names: list[str] | None role_names: list[str] | None
class UserCreate(UserBase): class UserCreate(UserBase):
role_names: list[str] | None role_names: list[str] | None
class PaymentMethodBase(SQLModel): class PaymentMethodBase(SQLModel):
name: str name: str
details: str details: str
max: int | None max: int | None
class PaymentMethod(PaymentMethodBase, table=True): class PaymentMethod(PaymentMethodBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
productor_id: int = Field(foreign_key="productor.id", ondelete="CASCADE") productor_id: int = Field(foreign_key='productor.id', ondelete='CASCADE')
productor: Optional["Productor"] = Relationship( productor: Optional['Productor'] = Relationship(
back_populates="payment_methods", back_populates='payment_methods',
) )
class PaymentMethodPublic(PaymentMethodBase): class PaymentMethodPublic(PaymentMethodBase):
id: int id: int
productor: Optional["Productor"] productor: Optional['Productor']
class ProductorBase(SQLModel): class ProductorBase(SQLModel):
name: str name: str
address: str address: str
type: str type: str
class ProductorPublic(ProductorBase): class ProductorPublic(ProductorBase):
id: int id: int
products: list["Product"] = [] products: list['Product'] = Field(default_factory=list)
payment_methods: list["PaymentMethod"] = [] payment_methods: list['PaymentMethod'] = Field(default_factory=list)
class Productor(ProductorBase, table=True): class Productor(ProductorBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
products: list["Product"] = Relationship( products: list['Product'] = Relationship(
back_populates='productor', back_populates='productor',
sa_relationship_kwargs={ sa_relationship_kwargs={
"order_by": "Product.name" 'order_by': 'Product.name'
}, },
) )
payment_methods: list["PaymentMethod"] = Relationship( payment_methods: list['PaymentMethod'] = Relationship(
back_populates="productor", back_populates='productor',
cascade_delete=True cascade_delete=True
) )
class ProductorUpdate(SQLModel): class ProductorUpdate(SQLModel):
name: str | None name: str | None
address: str | None address: str | None
payment_methods: list["PaymentMethod"] = [] payment_methods: list['PaymentMethod'] = Field(default_factory=list)
type: str | None type: str | None
class ProductorCreate(ProductorBase): class ProductorCreate(ProductorBase):
payment_methods: list["PaymentMethod"] = [] payment_methods: list['PaymentMethod'] = Field(default_factory=list)
class Unit(StrEnum): class Unit(StrEnum):
GRAMS = "1" GRAMS = '1'
KILO = "2" KILO = '2'
PIECE = "3" PIECE = '3'
class ProductType(StrEnum): class ProductType(StrEnum):
OCCASIONAL = "1" OCCASIONAL = '1'
RECCURENT = "2" RECCURENT = '2'
class ShipmentProductLink(SQLModel, table=True): class ShipmentProductLink(SQLModel, table=True):
shipment_id: Optional[int] = Field(default=None, foreign_key="shipment.id", primary_key=True) shipment_id: Optional[int] = Field(
product_id: Optional[int] = Field(default=None, foreign_key="product.id", primary_key=True) default=None,
foreign_key='shipment.id',
primary_key=True
)
product_id: Optional[int] = Field(
default=None,
foreign_key='product.id',
primary_key=True
)
class ProductBase(SQLModel): class ProductBase(SQLModel):
name: str name: str
@@ -103,17 +140,31 @@ class ProductBase(SQLModel):
quantity: float | None quantity: float | None
quantity_unit: str | None quantity_unit: str | None
type: ProductType type: ProductType
productor_id: int | None = Field(default=None, foreign_key="productor.id") productor_id: int | None = Field(
default=None,
foreign_key='productor.id'
)
class ProductPublic(ProductBase): class ProductPublic(ProductBase):
id: int id: int
productor: Productor | None productor: Productor | None
shipments: list["Shipment"] | None shipments: list['Shipment'] | None
class Product(ProductBase, table=True): class Product(ProductBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(
shipments: list["Shipment"] = Relationship(back_populates="products", link_model=ShipmentProductLink) default=None,
productor: Optional[Productor] = Relationship(back_populates="products") primary_key=True
)
shipments: list['Shipment'] = Relationship(
back_populates='products',
link_model=ShipmentProductLink
)
productor: Optional[Productor] = Relationship(
back_populates='products'
)
class ProductUpdate(SQLModel): class ProductUpdate(SQLModel):
name: str | None name: str | None
@@ -125,41 +176,46 @@ class ProductUpdate(SQLModel):
productor_id: int | None productor_id: int | None
type: ProductType | None type: ProductType | None
class ProductCreate(ProductBase): class ProductCreate(ProductBase):
pass pass
class FormBase(SQLModel): class FormBase(SQLModel):
name: str name: str
productor_id: int | None = Field(default=None, foreign_key="productor.id") productor_id: int | None = Field(default=None, foreign_key='productor.id')
referer_id: int | None = Field(default=None, foreign_key="user.id") referer_id: int | None = Field(default=None, foreign_key='user.id')
season: str season: str
start: datetime.date start: datetime.date
end: datetime.date end: datetime.date
minimum_shipment_value: float | None minimum_shipment_value: float | None
visible: bool visible: bool
class FormPublic(FormBase): class FormPublic(FormBase):
id: int id: int
productor: ProductorPublic | None productor: ProductorPublic | None
referer: User | None referer: User | None
shipments: list["ShipmentPublic"] = [] shipments: list['ShipmentPublic'] = Field(default_factory=list)
class Form(FormBase, table=True): class Form(FormBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
productor: Optional['Productor'] = Relationship() productor: Optional['Productor'] = Relationship()
referer: Optional['User'] = Relationship() referer: Optional['User'] = Relationship()
shipments: list["Shipment"] = Relationship( shipments: list['Shipment'] = Relationship(
back_populates="form", back_populates='form',
cascade_delete=True, cascade_delete=True,
sa_relationship_kwargs={ sa_relationship_kwargs={
"order_by": "Shipment.name" 'order_by': 'Shipment.name'
}, },
) )
contracts: list["Contract"] = Relationship( contracts: list['Contract'] = Relationship(
back_populates="form", back_populates='form',
cascade_delete=True cascade_delete=True
) )
class FormUpdate(SQLModel): class FormUpdate(SQLModel):
name: str | None name: str | None
productor_id: int | None productor_id: int | None
@@ -170,35 +226,44 @@ class FormUpdate(SQLModel):
minimum_shipment_value: float | None minimum_shipment_value: float | None
visible: bool | None visible: bool | None
class FormCreate(FormBase): class FormCreate(FormBase):
pass pass
class TemplateBase(SQLModel): class TemplateBase(SQLModel):
pass pass
class TemplatePublic(TemplateBase): class TemplatePublic(TemplateBase):
id: int id: int
class Template(TemplateBase, table=True): class Template(TemplateBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
class TemplateUpdate(SQLModel): class TemplateUpdate(SQLModel):
pass pass
class TemplateCreate(TemplateBase): class TemplateCreate(TemplateBase):
pass pass
class ChequeBase(SQLModel): class ChequeBase(SQLModel):
name: str name: str
value: str value: str
class Cheque(ChequeBase, table=True): class Cheque(ChequeBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
contract_id: int = Field(foreign_key="contract.id", ondelete="CASCADE") contract_id: int = Field(foreign_key='contract.id', ondelete='CASCADE')
contract: Optional["Contract"] = Relationship( contract: Optional['Contract'] = Relationship(
back_populates="cheques", back_populates='cheques',
) )
class ContractBase(SQLModel): class ContractBase(SQLModel):
firstname: str firstname: str
lastname: str lastname: str
@@ -207,105 +272,122 @@ class ContractBase(SQLModel):
payment_method: str payment_method: str
cheque_quantity: int cheque_quantity: int
class Contract(ContractBase, table=True): class Contract(ContractBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
form_id: int = Field( form_id: int = Field(
foreign_key="form.id", foreign_key='form.id',
nullable=False, nullable=False,
ondelete="CASCADE" ondelete='CASCADE'
) )
products: list["ContractProduct"] = Relationship( products: list['ContractProduct'] = Relationship(
back_populates="contract", back_populates='contract',
cascade_delete=True cascade_delete=True
) )
form: Optional[Form] = Relationship(back_populates="contracts") form: Form = Relationship(back_populates='contracts')
cheques: list[Cheque] = Relationship( cheques: list[Cheque] = Relationship(
back_populates="contract", back_populates='contract',
cascade_delete=True cascade_delete=True
) )
file: bytes = Field(sa_column=Column(LargeBinary)) file: bytes = Field(sa_column=Column(LargeBinary))
total_price: float | None total_price: float | None
class ContractCreate(ContractBase): class ContractCreate(ContractBase):
products: list["ContractProductCreate"] = [] products: list['ContractProductCreate'] = Field(default_factory=list)
cheques: list["Cheque"] = [] cheques: list['Cheque'] = Field(default_factory=list)
form_id: int form_id: int
class ContractUpdate(SQLModel): class ContractUpdate(SQLModel):
file: bytes file: bytes
class ContractPublic(ContractBase): class ContractPublic(ContractBase):
id: int id: int
products: list["ContractProduct"] = [] products: list['ContractProduct'] = Field(default_factory=list)
form: Form form: Form
total_price: float | None total_price: float | None
# file: bytes # file: bytes
class ContractProductBase(SQLModel): class ContractProductBase(SQLModel):
product_id: int = Field( product_id: int = Field(
foreign_key="product.id", foreign_key='product.id',
nullable=False, nullable=False,
ondelete="CASCADE" ondelete='CASCADE'
) )
shipment_id: int | None = Field( shipment_id: int | None = Field(
default=None, default=None,
foreign_key="shipment.id", foreign_key='shipment.id',
nullable=True, nullable=True,
ondelete="CASCADE" ondelete='CASCADE'
) )
quantity: float quantity: float
class ContractProduct(ContractProductBase, table=True): class ContractProduct(ContractProductBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
contract_id: int = Field( contract_id: int = Field(
foreign_key="contract.id", foreign_key='contract.id',
nullable=False, nullable=False,
ondelete="CASCADE" ondelete='CASCADE'
) )
contract: Optional["Contract"] = Relationship(back_populates="products") contract: Optional['Contract'] = Relationship(back_populates='products')
product: Optional["Product"] = Relationship() product: Optional['Product'] = Relationship()
shipment: Optional["Shipment"] = Relationship() shipment: Optional['Shipment'] = Relationship()
class ContractProductPublic(ContractProductBase): class ContractProductPublic(ContractProductBase):
id: int id: int
quantity: float quantity: float
contract: Contract contract: Contract
product: Product product: Product
shipment: Optional["Shipment"] shipment: Optional['Shipment']
class ContractProductCreate(ContractProductBase): class ContractProductCreate(ContractProductBase):
pass pass
class ContractProductUpdate(ContractProductBase): class ContractProductUpdate(ContractProductBase):
pass pass
class ShipmentBase(SQLModel): class ShipmentBase(SQLModel):
name: str name: str
date: datetime.date date: datetime.date
form_id: int | None = Field(default=None, foreign_key="form.id", ondelete="CASCADE") form_id: int | None = Field(
default=None,
foreign_key='form.id',
ondelete='CASCADE')
class ShipmentPublic(ShipmentBase): class ShipmentPublic(ShipmentBase):
id: int id: int
products: list[Product] = [] products: list[Product] = Field(default_factory=list)
form: Form | None form: Form | None
class Shipment(ShipmentBase, table=True): class Shipment(ShipmentBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
products: list[Product] = Relationship( products: list[Product] = Relationship(
back_populates="shipments", back_populates='shipments',
link_model=ShipmentProductLink, link_model=ShipmentProductLink,
sa_relationship_kwargs={ sa_relationship_kwargs={
"order_by": "Product.name" 'order_by': 'Product.name'
}, },
) )
form: Optional[Form] = Relationship(back_populates="shipments") form: Optional[Form] = Relationship(back_populates='shipments')
class ShipmentUpdate(SQLModel): class ShipmentUpdate(SQLModel):
name: str | None name: str | None
date: datetime.date | None date: datetime.date | None
product_ids: list[int] | None = [] product_ids: list[int] | None = Field(default_factory=list)
class ShipmentCreate(ShipmentBase): class ShipmentCreate(ShipmentBase):
product_ids: list[int] = [] product_ids: list[int] = Field(default_factory=list)
form_id: int form_id: int

View File

@@ -1,3 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr> # SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT

View File

@@ -1,11 +1,17 @@
import logging
class ProductorServiceError(Exception): class ProductorServiceError(Exception):
def __init__(self, message: str): def __init__(self, message: str):
super().__init__(message) super().__init__(message)
logging.error('ProductorService : %s', message)
class ProductorNotFoundError(ProductorServiceError): class ProductorNotFoundError(ProductorServiceError):
pass pass
class ProductorCreateError(ProductorServiceError): class ProductorCreateError(ProductorServiceError):
def __init__(self, message: str, field: str | None = None): def __init__(self, message: str, field: str | None = None):
super().__init__(message) super().__init__(message)
self.field = field self.field = field

View File

@@ -1,14 +1,15 @@
from fastapi import APIRouter, HTTPException, Depends, Query
import src.messages as messages import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.productors.service as service
import src.productors.exceptions as exceptions import src.productors.exceptions as exceptions
import src.productors.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/productors') router = APIRouter(prefix='/productors')
@router.get('', response_model=list[models.ProductorPublic]) @router.get('', response_model=list[models.ProductorPublic])
def get_productors( def get_productors(
names: list[str] = Query([]), names: list[str] = Query([]),
@@ -18,49 +19,56 @@ def get_productors(
): ):
return service.get_all(session, user, names, types) return service.get_all(session, user, names, types)
@router.get('/{id}', response_model=models.ProductorPublic)
@router.get('/{_id}', response_model=models.ProductorPublic)
def get_productor( def get_productor(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.get_one(session, id) result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('productor')) raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('productor')
)
return result return result
@router.post('', response_model=models.ProductorPublic) @router.post('', response_model=models.ProductorPublic)
def create_productor( def create_productor(
productor: models.ProductorCreate, productor: models.ProductorCreate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
try: try:
result = service.create_one(session, productor) result = service.create_one(session, productor)
except exceptions.ProductorCreateError as error: except exceptions.ProductorCreateError as error:
raise HTTPException(status_code=400, detail=str(error)) raise HTTPException(status_code=400, detail=str(error)) from error
return result return result
@router.put('/{id}', response_model=models.ProductorPublic)
@router.put('/{_id}', response_model=models.ProductorPublic)
def update_productor( def update_productor(
id: int, productor: models.ProductorUpdate, _id: int, productor: models.ProductorUpdate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
try: try:
result = service.update_one(session, id, productor) result = service.update_one(session, _id, productor)
except exceptions.ProductorNotFoundError as error: except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(status_code=404, detail=str(error)) from error
return result return result
@router.delete('/{id}', response_model=models.ProductorPublic)
@router.delete('/{_id}', response_model=models.ProductorPublic)
def delete_productor( def delete_productor(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
try: try:
result = service.delete_one(session, id) result = service.delete_one(session, _id)
except exceptions.ProductorNotFoundError as error: except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(status_code=404, detail=str(error)) from error
return result return result

View File

@@ -1,12 +1,13 @@
from sqlmodel import Session, select
import src.models as models
import src.productors.exceptions as exceptions
import src.messages as messages import src.messages as messages
import src.productors.exceptions as exceptions
from sqlmodel import Session, select
from src import models
def get_all( def get_all(
session: Session, session: Session,
user: models.User, user: models.User,
names: list[str], names: list[str],
types: list[str] types: list[str]
) -> list[models.ProductorPublic]: ) -> list[models.ProductorPublic]:
statement = select(models.Productor)\ statement = select(models.Productor)\
@@ -18,13 +19,20 @@ def get_all(
statement = statement.where(models.Productor.type.in_(types)) statement = statement.where(models.Productor.type.in_(types))
return session.exec(statement.order_by(models.Productor.name)).all() return session.exec(statement.order_by(models.Productor.name)).all()
def get_one(session: Session, productor_id: int) -> models.ProductorPublic: def get_one(session: Session, productor_id: int) -> models.ProductorPublic:
return session.get(models.Productor, productor_id) return session.get(models.Productor, productor_id)
def create_one(session: Session, productor: models.ProductorCreate) -> models.ProductorPublic:
def create_one(
session: Session,
productor: models.ProductorCreate) -> models.ProductorPublic:
if not productor: if not productor:
raise exceptions.ProductorCreateError(messages.Messages.invalid_input('productor', 'input cannot be None')) raise exceptions.ProductorCreateError(
productor_create = productor.model_dump(exclude_unset=True, exclude='payment_methods') messages.Messages.invalid_input(
'productor', 'input cannot be None'))
productor_create = productor.model_dump(
exclude_unset=True, exclude='payment_methods')
new_productor = models.Productor(**productor_create) new_productor = models.Productor(**productor_create)
new_productor.payment_methods = [ new_productor.payment_methods = [
@@ -39,13 +47,18 @@ def create_one(session: Session, productor: models.ProductorCreate) -> models.Pr
session.refresh(new_productor) session.refresh(new_productor)
return new_productor return new_productor
def update_one(session: Session, id: int, productor: models.ProductorUpdate) -> models.ProductorPublic:
def update_one(
session: Session,
id: int,
productor: models.ProductorUpdate) -> models.ProductorPublic:
statement = select(models.Productor).where(models.Productor.id == id) statement = select(models.Productor).where(models.Productor.id == id)
result = session.exec(statement) result = session.exec(statement)
new_productor = result.first() new_productor = result.first()
if not new_productor: if not new_productor:
raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor')) raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
productor_updates = productor.model_dump(exclude_unset=True) productor_updates = productor.model_dump(exclude_unset=True)
if 'payment_methods' in productor_updates: if 'payment_methods' in productor_updates:
new_productor.payment_methods.clear() new_productor.payment_methods.clear()
@@ -67,12 +80,14 @@ def update_one(session: Session, id: int, productor: models.ProductorUpdate) ->
session.refresh(new_productor) session.refresh(new_productor)
return new_productor return new_productor
def delete_one(session: Session, id: int) -> models.ProductorPublic: def delete_one(session: Session, id: int) -> models.ProductorPublic:
statement = select(models.Productor).where(models.Productor.id == id) statement = select(models.Productor).where(models.Productor.id == id)
result = session.exec(statement) result = session.exec(statement)
productor = result.first() productor = result.first()
if not productor: if not productor:
raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor')) raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
result = models.ProductorPublic.model_validate(productor) result = models.ProductorPublic.model_validate(productor)
session.delete(productor) session.delete(productor)
session.commit() session.commit()

View File

@@ -1,3 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr> # SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT

View File

@@ -2,13 +2,16 @@ class ProductServiceError(Exception):
def __init__(self, message: str): def __init__(self, message: str):
super().__init__(message) super().__init__(message)
class ProductorNotFoundError(ProductServiceError): class ProductorNotFoundError(ProductServiceError):
pass pass
class ProductNotFoundError(ProductServiceError): class ProductNotFoundError(ProductServiceError):
pass pass
class ProductCreateError(ProductServiceError): class ProductCreateError(ProductServiceError):
def __init__(self, message: str, field: str | None = None): def __init__(self, message: str, field: str | None = None):
super().__init__(message) super().__init__(message)
self.field = field self.field = field

View File

@@ -1,18 +1,19 @@
from fastapi import APIRouter, HTTPException, Depends, Query
import src.messages as messages import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.products.service as service
import src.products.exceptions as exceptions import src.products.exceptions as exceptions
import src.products.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/products') router = APIRouter(prefix='/products')
@router.get('', response_model=list[models.ProductPublic], ) @router.get('', response_model=list[models.ProductPublic], )
def get_products( def get_products(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session), session: Session = Depends(get_session),
names: list[str] = Query([]), names: list[str] = Query([]),
types: list[str] = Query([]), types: list[str] = Query([]),
productors: list[str] = Query([]), productors: list[str] = Query([]),
@@ -20,25 +21,28 @@ def get_products(
return service.get_all( return service.get_all(
session, session,
user, user,
names, names,
productors, productors,
types, types,
) )
@router.get('/{id}', response_model=models.ProductPublic) @router.get('/{id}', response_model=models.ProductPublic)
def get_product( def get_product(
id: int, id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.get_one(session, id) result = service.get_one(session, id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('product')) raise HTTPException(status_code=404,
detail=messages.Messages.not_found('product'))
return result return result
@router.post('', response_model=models.ProductPublic) @router.post('', response_model=models.ProductPublic)
def create_product( def create_product(
product: models.ProductCreate, product: models.ProductCreate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
@@ -50,9 +54,10 @@ def create_product(
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(status_code=404, detail=str(error))
return result return result
@router.put('/{id}', response_model=models.ProductPublic) @router.put('/{id}', response_model=models.ProductPublic)
def update_product( def update_product(
id: int, product: models.ProductUpdate, id: int, product: models.ProductUpdate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
@@ -64,9 +69,10 @@ def update_product(
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(status_code=404, detail=str(error))
return result return result
@router.delete('/{id}', response_model=models.ProductPublic) @router.delete('/{id}', response_model=models.ProductPublic)
def delete_product( def delete_product(
id: int, id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):

View File

@@ -1,19 +1,23 @@
from sqlmodel import Session, select
import src.models as models
import src.products.exceptions as exceptions
import src.messages as messages import src.messages as messages
import src.products.exceptions as exceptions
from sqlmodel import Session, select
from src import models
def get_all( def get_all(
session: Session, session: Session,
user: models.User, user: models.User,
names: list[str], names: list[str],
productors: list[str], productors: list[str],
types: list[str], types: list[str],
) -> list[models.ProductPublic]: ) -> list[models.ProductPublic]:
statement = select(models.Product)\ statement = select(
.join(models.Productor, models.Product.productor_id == models.Productor.id)\ models.Product) .join(
.where(models.Productor.type.in_([r.name for r in user.roles]))\ models.Productor,
.distinct() models.Product.productor_id == models.Productor.id) .where(
models.Productor.type.in_(
[
r.name for r in user.roles])) .distinct()
if len(names) > 0: if len(names) > 0:
statement = statement.where(models.Product.name.in_(names)) statement = statement.where(models.Product.name.in_(names))
if len(productors) > 0: if len(productors) > 0:
@@ -22,14 +26,21 @@ def get_all(
statement = statement.where(models.Product.type.in_(types)) statement = statement.where(models.Product.type.in_(types))
return session.exec(statement.order_by(models.Product.name)).all() return session.exec(statement.order_by(models.Product.name)).all()
def get_one(session: Session, product_id: int) -> models.ProductPublic: def get_one(session: Session, product_id: int) -> models.ProductPublic:
return session.get(models.Product, product_id) return session.get(models.Product, product_id)
def create_one(session: Session, product: models.ProductCreate) -> models.ProductPublic:
def create_one(
session: Session,
product: models.ProductCreate) -> models.ProductPublic:
if not product: if not product:
raise exceptions.ProductCreateError(messages.Messages.invalid_input('product', 'input cannot be None')) raise exceptions.ProductCreateError(
messages.Messages.invalid_input(
'product', 'input cannot be None'))
if not session.get(models.Productor, product.productor_id): if not session.get(models.Productor, product.productor_id):
raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor')) raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
product_create = product.model_dump(exclude_unset=True) product_create = product.model_dump(exclude_unset=True)
new_product = models.Product(**product_create) new_product = models.Product(**product_create)
session.add(new_product) session.add(new_product)
@@ -37,14 +48,21 @@ def create_one(session: Session, product: models.ProductCreate) -> models.Produc
session.refresh(new_product) session.refresh(new_product)
return new_product return new_product
def update_one(session: Session, id: int, product: models.ProductUpdate) -> models.ProductPublic:
def update_one(
session: Session,
id: int,
product: models.ProductUpdate) -> models.ProductPublic:
statement = select(models.Product).where(models.Product.id == id) statement = select(models.Product).where(models.Product.id == id)
result = session.exec(statement) result = session.exec(statement)
new_product = result.first() new_product = result.first()
if not new_product: if not new_product:
raise exceptions.ProductNotFoundError(messages.Messages.not_found('product')) raise exceptions.ProductNotFoundError(
if product.productor_id and not session.get(models.Productor, product.productor_id): messages.Messages.not_found('product'))
raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor')) if product.productor_id and not session.get(
models.Productor, product.productor_id):
raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
product_updates = product.model_dump(exclude_unset=True) product_updates = product.model_dump(exclude_unset=True)
for key, value in product_updates.items(): for key, value in product_updates.items():
@@ -55,12 +73,14 @@ def update_one(session: Session, id: int, product: models.ProductUpdate) -> mode
session.refresh(new_product) session.refresh(new_product)
return new_product return new_product
def delete_one(session: Session, id: int) -> models.ProductPublic: def delete_one(session: Session, id: int) -> models.ProductPublic:
statement = select(models.Product).where(models.Product.id == id) statement = select(models.Product).where(models.Product.id == id)
result = session.exec(statement) result = session.exec(statement)
product = result.first() product = result.first()
if not product: if not product:
raise exceptions.ProductNotFoundError(messages.Messages.not_found('product')) raise exceptions.ProductNotFoundError(
messages.Messages.not_found('product'))
result = models.ProductPublic.model_validate(product) result = models.ProductPublic.model_validate(product)
session.delete(product) session.delete(product)
session.commit() session.commit()

View File

@@ -1,5 +1,6 @@
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings): class Settings(BaseSettings):
origins: str origins: str
db_host: str db_host: str
@@ -20,10 +21,21 @@ class Settings(BaseSettings):
env_file='../.env' env_file='../.env'
) )
settings = Settings() settings = Settings()
AUTH_URL = f"{settings.keycloak_server}/realms/{settings.keycloak_realm}/protocol/openid-connect/auth" AUTH_URL = (
TOKEN_URL = f"{settings.keycloak_server}/realms/{settings.keycloak_realm}/protocol/openid-connect/token" f'{settings.keycloak_server}/realms/'
ISSUER = f"{settings.keycloak_server}/realms/{settings.keycloak_realm}" f'{settings.keycloak_realm}/protocol/openid-connect/auth'
JWKS_URL = f"{ISSUER}/protocol/openid-connect/certs" )
LOGOUT_URL = f'{settings.keycloak_server}/realms/{settings.keycloak_realm}/protocol/openid-connect/logout' TOKEN_URL = (
f'{settings.keycloak_server}/realms/'
f'{settings.keycloak_realm}/protocol/openid-connect/token'
)
ISSUER = f'{settings.keycloak_server}/realms/{settings.keycloak_realm}'
JWKS_URL = f'{ISSUER}/protocol/openid-connect/certs'
LOGOUT_URL = (
f'{settings.keycloak_server}/realms/'
f'{settings.keycloak_realm}/protocol/openid-connect/logout'
)

View File

@@ -1,3 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr> # SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT

View File

@@ -1,11 +1,17 @@
import logging
class ShipmentServiceError(Exception): class ShipmentServiceError(Exception):
def __init__(self, message: str): def __init__(self, message: str):
super().__init__(message) super().__init__(message)
logging.error('ShipmentService : %s', message)
class ShipmentNotFoundError(ShipmentServiceError): class ShipmentNotFoundError(ShipmentServiceError):
pass pass
class ShipmentCreateError(ShipmentServiceError): class ShipmentCreateError(ShipmentServiceError):
def __init__(self, message: str, field: str | None = None): def __init__(self, message: str, field: str | None = None):
super().__init__(message) super().__init__(message)
self.field = field self.field = field

View File

@@ -1,58 +1,111 @@
from sqlmodel import Session, select # pylint: disable=E1101
import src.models as models
import src.shipments.exceptions as exceptions
import src.messages as messages
import datetime import datetime
import src.messages as messages
import src.shipments.exceptions as exceptions
from sqlmodel import Session, select
from src import models
def get_all( def get_all(
session: Session, session: Session,
user: models.User, user: models.User,
names: list[str], names: list[str] = None,
dates: list[str], dates: list[str] = None,
forms: list[str] forms: list[str] = None
) -> list[models.ShipmentPublic]: ) -> list[models.ShipmentPublic]:
statement = select(models.Shipment)\ statement = (
.join(models.Form, models.Shipment.form_id == models.Form.id)\ select(models.Shipment)
.join(models.Productor, models.Form.productor_id == models.Productor.id)\ .join(
.where(models.Productor.type.in_([r.name for r in user.roles]))\ models.Form,
models.Shipment.form_id == models.Form.id)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id)
.where(
models.Productor.type.in_(
[r.name for r in user.roles]
)
)
.distinct() .distinct()
if len(names) > 0: )
if names and len(names) > 0:
statement = statement.where(models.Shipment.name.in_(names)) statement = statement.where(models.Shipment.name.in_(names))
if len(dates) > 0: if dates and len(dates) > 0:
statement = statement.where(models.Shipment.date.in_(list(map(lambda x: datetime.datetime.strptime(x, '%Y-%m-%d').date(), dates)))) statement = statement.where(
if len(forms) > 0: models.Shipment.date.in_(
list(map(
lambda x: datetime.datetime.strptime(
x, '%Y-%m-%d').date(),
dates
))
)
)
if forms and len(forms) > 0:
statement = statement.where(models.Form.name.in_(forms)) statement = statement.where(models.Form.name.in_(forms))
return session.exec(statement.order_by(models.Shipment.name)).all() return session.exec(statement.order_by(models.Shipment.name)).all()
def get_one(session: Session, shipment_id: int) -> models.ShipmentPublic: def get_one(session: Session, shipment_id: int) -> models.ShipmentPublic:
return session.get(models.Shipment, shipment_id) return session.get(models.Shipment, shipment_id)
def create_one(session: Session, shipment: models.ShipmentCreate) -> models.ShipmentPublic:
def create_one(
session: Session,
shipment: models.ShipmentCreate) -> models.ShipmentPublic:
if shipment is None: if shipment is None:
raise exceptions.ShipmentCreateError(messages.Messages.invalid_input('shipment', 'input cannot be None')) raise exceptions.ShipmentCreateError(
products = session.exec(select(models.Product).where(models.Product.id.in_(shipment.product_ids))).all() messages.Messages.invalid_input(
shipment_create = shipment.model_dump(exclude_unset=True, exclude={'product_ids'}) 'shipment', 'input cannot be None'))
products = session.exec(
select(models.Product)
.where(
models.Product.id.in_(
shipment.product_ids
)
)
).all()
shipment_create = shipment.model_dump(
exclude_unset=True, exclude={'product_ids'}
)
new_shipment = models.Shipment(**shipment_create, products=products) new_shipment = models.Shipment(**shipment_create, products=products)
session.add(new_shipment) session.add(new_shipment)
session.commit() session.commit()
session.refresh(new_shipment) session.refresh(new_shipment)
return new_shipment return new_shipment
def update_one(session: Session, id: int, shipment: models.ShipmentUpdate) -> models.ShipmentPublic:
def update_one(
session: Session,
_id: int,
shipment: models.ShipmentUpdate) -> models.ShipmentPublic:
if shipment is None: if shipment is None:
raise exceptions.ShipmentCreateError(messages.Messages.invalid_input('shipment', 'input cannot be None')) raise exceptions.ShipmentCreateError(
statement = select(models.Shipment).where(models.Shipment.id == id) messages.Messages.invalid_input(
'shipment', 'input cannot be None'))
statement = select(models.Shipment).where(models.Shipment.id == _id)
result = session.exec(statement) result = session.exec(statement)
new_shipment = result.first() new_shipment = result.first()
if not new_shipment: if not new_shipment:
raise exceptions.ShipmentNotFoundError(messages.Messages.not_found('shipment')) raise exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment'))
products_to_add = session.exec(select(models.Product).where(models.Product.id.in_(shipment.product_ids))).all() products_to_add = session.exec(
select(
models.Product
).where(
models.Product.id.in_(
shipment.product_ids
)
)
).all()
new_shipment.products.clear() new_shipment.products.clear()
for add in products_to_add: for add in products_to_add:
new_shipment.products.append(add) new_shipment.products.append(add)
shipment_updates = shipment.model_dump(exclude_unset=True, exclude={"product_ids"}) shipment_updates = shipment.model_dump(
exclude_unset=True, exclude={"product_ids"}
)
for key, value in shipment_updates.items(): for key, value in shipment_updates.items():
setattr(new_shipment, key, value) setattr(new_shipment, key, value)
@@ -61,14 +114,16 @@ def update_one(session: Session, id: int, shipment: models.ShipmentUpdate) -> mo
session.refresh(new_shipment) session.refresh(new_shipment)
return new_shipment return new_shipment
def delete_one(session: Session, id: int) -> models.ShipmentPublic:
statement = select(models.Shipment).where(models.Shipment.id == id) def delete_one(session: Session, _id: int) -> models.ShipmentPublic:
statement = select(models.Shipment).where(models.Shipment.id == _id)
result = session.exec(statement) result = session.exec(statement)
shipment = result.first() shipment = result.first()
if not shipment: if not shipment:
raise exceptions.ShipmentNotFoundError(messages.Messages.not_found('shipment')) raise exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment'))
result = models.ShipmentPublic.model_validate(shipment) result = models.ShipmentPublic.model_validate(shipment)
session.delete(shipment) session.delete(shipment)
session.commit() session.commit()
return result return result

View File

@@ -1,14 +1,15 @@
from fastapi import APIRouter, HTTPException, Depends, Query
import src.messages as messages import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.shipments.service as service
import src.shipments.exceptions as exceptions import src.shipments.exceptions as exceptions
import src.shipments.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/shipments') router = APIRouter(prefix='/shipments')
@router.get('', response_model=list[models.ShipmentPublic], ) @router.get('', response_model=list[models.ShipmentPublic], )
def get_shipments( def get_shipments(
session: Session = Depends(get_session), session: Session = Depends(get_session),
@@ -25,17 +26,22 @@ def get_shipments(
forms, forms,
) )
@router.get('/{id}', response_model=models.ShipmentPublic)
@router.get('/{_id}', response_model=models.ShipmentPublic)
def get_shipment( def get_shipment(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.get_one(session, id) result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('shipment')) raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('shipment')
)
return result return result
@router.post('', response_model=models.ShipmentPublic) @router.post('', response_model=models.ShipmentPublic)
def create_shipment( def create_shipment(
shipment: models.ShipmentCreate, shipment: models.ShipmentCreate,
@@ -45,30 +51,32 @@ def create_shipment(
try: try:
result = service.create_one(session, shipment) result = service.create_one(session, shipment)
except exceptions.ShipmentCreateError as error: except exceptions.ShipmentCreateError as error:
raise HTTPException(status_code=400, detail=str(error)) raise HTTPException(status_code=400, detail=str(error)) from error
return result return result
@router.put('/{id}', response_model=models.ShipmentPublic)
@router.put('/{_id}', response_model=models.ShipmentPublic)
def update_shipment( def update_shipment(
id: int, _id: int,
shipment: models.ShipmentUpdate, shipment: models.ShipmentUpdate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
try: try:
result = service.update_one(session, id, shipment) result = service.update_one(session, _id, shipment)
except exceptions.ShipmentNotFoundError as error: except exceptions.ShipmentNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(status_code=404, detail=str(error)) from error
return result return result
@router.delete('/{id}', response_model=models.ShipmentPublic)
@router.delete('/{_id}', response_model=models.ShipmentPublic)
def delete_shipment( def delete_shipment(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
try: try:
result = service.delete_one(session, id) result = service.delete_one(session, _id)
except exceptions.ShipmentNotFoundError as error: except exceptions.ShipmentNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) raise HTTPException(status_code=404, detail=str(error)) from error
return result return result

View File

@@ -1,3 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr> # SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT

View File

@@ -1,14 +1,19 @@
from sqlmodel import Session, select from sqlmodel import Session, select
import src.models as models from src import models
def get_all(session: Session) -> list[models.TemplatePublic]: def get_all(session: Session) -> list[models.TemplatePublic]:
statement = select(models.Template) statement = select(models.Template)
return session.exec(statement.order_by(models.Template.name)).all() return session.exec(statement.order_by(models.Template.name)).all()
def get_one(session: Session, template_id: int) -> models.TemplatePublic: def get_one(session: Session, template_id: int) -> models.TemplatePublic:
return session.get(models.Template, template_id) return session.get(models.Template, template_id)
def create_one(session: Session, template: models.TemplateCreate) -> models.TemplatePublic:
def create_one(
session: Session,
template: models.TemplateCreate) -> models.TemplatePublic:
template_create = template.model_dump(exclude_unset=True) template_create = template.model_dump(exclude_unset=True)
new_template = models.Template(**template_create) new_template = models.Template(**template_create)
session.add(new_template) session.add(new_template)
@@ -16,7 +21,11 @@ def create_one(session: Session, template: models.TemplateCreate) -> models.Temp
session.refresh(new_template) session.refresh(new_template)
return new_template return new_template
def update_one(session: Session, id: int, template: models.TemplateUpdate) -> models.TemplatePublic:
def update_one(
session: Session,
id: int,
template: models.TemplateUpdate) -> models.TemplatePublic:
statement = select(models.Template).where(models.Template.id == id) statement = select(models.Template).where(models.Template.id == id)
result = session.exec(statement) result = session.exec(statement)
new_template = result.first() new_template = result.first()
@@ -30,6 +39,7 @@ def update_one(session: Session, id: int, template: models.TemplateUpdate) -> mo
session.refresh(new_template) session.refresh(new_template)
return new_template return new_template
def delete_one(session: Session, id: int) -> models.TemplatePublic: def delete_one(session: Session, id: int) -> models.TemplatePublic:
statement = select(models.Template).where(models.Template.id == id) statement = select(models.Template).where(models.Template.id == id)
result = session.exec(statement) result = session.exec(statement)

View File

@@ -1,13 +1,14 @@
from fastapi import APIRouter, HTTPException, Depends
import src.messages as messages import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.templates.service as service import src.templates.service as service
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/templates') router = APIRouter(prefix='/templates')
@router.get('', response_model=list[models.TemplatePublic]) @router.get('', response_model=list[models.TemplatePublic])
def get_templates( def get_templates(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
@@ -15,43 +16,50 @@ def get_templates(
): ):
return service.get_all(session) return service.get_all(session)
@router.get('/{id}', response_model=models.TemplatePublic) @router.get('/{id}', response_model=models.TemplatePublic)
def get_template( def get_template(
id: int, id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.get_one(session, id) result = service.get_one(session, id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('template')) raise HTTPException(status_code=404,
detail=messages.Messages.not_found('template'))
return result return result
@router.post('', response_model=models.TemplatePublic) @router.post('', response_model=models.TemplatePublic)
def create_template( def create_template(
template: models.TemplateCreate, template: models.TemplateCreate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
return service.create_one(session, template) return service.create_one(session, template)
@router.put('/{id}', response_model=models.TemplatePublic) @router.put('/{id}', response_model=models.TemplatePublic)
def update_template( def update_template(
id: int, template: models.TemplateUpdate, id: int, template: models.TemplateUpdate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.update_one(session, id, template) result = service.update_one(session, id, template)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('template')) raise HTTPException(status_code=404,
detail=messages.Messages.not_found('template'))
return result return result
@router.delete('/{id}', response_model=models.TemplatePublic) @router.delete('/{id}', response_model=models.TemplatePublic)
def delete_template( def delete_template(
id: int, id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.delete_one(session, id) result = service.delete_one(session, id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('template')) raise HTTPException(status_code=404,
detail=messages.Messages.not_found('template'))
return result return result

View File

@@ -1,3 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr> # SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT

View File

@@ -1,11 +1,17 @@
import logging
class UserServiceError(Exception): class UserServiceError(Exception):
def __init__(self, message: str): def __init__(self, message: str):
super().__init__(message) super().__init__(message)
logging.error('UserService : %s', message)
class UserNotFoundError(UserServiceError): class UserNotFoundError(UserServiceError):
pass pass
class UserCreateError(UserServiceError): class UserCreateError(UserServiceError):
def __init__(self, message: str, field: str | None = None): def __init__(self, message: str, field: str | None = None):
super().__init__(message) super().__init__(message)
self.field = field self.field = field

View File

@@ -1,9 +1,8 @@
from sqlmodel import Session, select
import src.models as models
import src.messages as messages import src.messages as messages
import src.users.exceptions as exceptions import src.users.exceptions as exceptions
from sqlmodel import Session, select
from src import models
def get_all( def get_all(
session: Session, session: Session,
@@ -17,11 +16,15 @@ def get_all(
statement = statement.where(models.User.email.in_(emails)) statement = statement.where(models.User.email.in_(emails))
return session.exec(statement.order_by(models.User.name)).all() return session.exec(statement.order_by(models.User.name)).all()
def get_one(session: Session, user_id: int) -> models.UserPublic: def get_one(session: Session, user_id: int) -> models.UserPublic:
return session.get(models.User, user_id) return session.get(models.User, user_id)
def get_or_create_roles(session: Session, role_names: list[str]) -> list[models.ContractType]:
statement = select(models.ContractType).where(models.ContractType.name.in_(role_names)) def get_or_create_roles(session: Session,
role_names: list[str]) -> list[models.ContractType]:
statement = select(models.ContractType).where(
models.ContractType.name.in_(role_names))
existing = session.exec(statement).all() existing = session.exec(statement).all()
existing_roles = {role.name for role in existing} existing_roles = {role.name for role in existing}
missing_role = set(role_names) - existing_roles missing_role = set(role_names) - existing_roles
@@ -37,8 +40,11 @@ def get_or_create_roles(session: Session, role_names: list[str]) -> list[models.
session.refresh(role) session.refresh(role)
return existing + new_roles return existing + new_roles
def get_or_create_user(session: Session, user_create: models.UserCreate): def get_or_create_user(session: Session, user_create: models.UserCreate):
statement = select(models.User).where(models.User.email == user_create.email) statement = select(
models.User).where(
models.User.email == user_create.email)
user = session.exec(statement).first() user = session.exec(statement).first()
if user: if user:
user_role_names = [r.name for r in user.roles] user_role_names = [r.name for r in user.roles]
@@ -48,13 +54,17 @@ def get_or_create_user(session: Session, user_create: models.UserCreate):
user = create_one(session, user_create) user = create_one(session, user_create)
return user return user
def get_roles(session: Session): def get_roles(session: Session):
statement = select(models.ContractType) statement = select(models.ContractType)
return session.exec(statement.order_by(models.ContractType.name)).all() return session.exec(statement.order_by(models.ContractType.name)).all()
def create_one(session: Session, user: models.UserCreate) -> models.UserPublic: def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
if user is None: if user is None:
raise exceptions.UserCreateError(messages.Messages.invalid_input('user', 'input cannot be None')) raise exceptions.UserCreateError(
messages.Messages.invalid_input(
'user', 'input cannot be None'))
new_user = models.User( new_user = models.User(
name=user.name, name=user.name,
email=user.email email=user.email
@@ -68,9 +78,15 @@ def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
session.refresh(new_user) session.refresh(new_user)
return new_user return new_user
def update_one(session: Session, id: int, user: models.UserCreate) -> models.UserPublic:
def update_one(
session: Session,
id: int,
user: models.UserCreate) -> models.UserPublic:
if user is None: if user is None:
raise exceptions.UserCreateError(messages.s.invalid_input('user', 'input cannot be None')) raise exceptions.UserCreateError(
messages.s.invalid_input(
'user', 'input cannot be None'))
statement = select(models.User).where(models.User.id == id) statement = select(models.User).where(models.User.id == id)
result = session.exec(statement) result = session.exec(statement)
new_user = result.first() new_user = result.first()
@@ -86,6 +102,7 @@ def update_one(session: Session, id: int, user: models.UserCreate) -> models.Use
session.refresh(new_user) session.refresh(new_user)
return new_user return new_user
def delete_one(session: Session, id: int) -> models.UserPublic: def delete_one(session: Session, id: int) -> models.UserPublic:
statement = select(models.User).where(models.User.id == id) statement = select(models.User).where(models.User.id == id)
result = session.exec(statement) result = session.exec(statement)
@@ -95,4 +112,4 @@ def delete_one(session: Session, id: int) -> models.UserPublic:
result = models.UserPublic.model_validate(user) result = models.UserPublic.model_validate(user)
session.delete(user) session.delete(user)
session.commit() session.commit()
return result return result

View File

@@ -1,14 +1,15 @@
from fastapi import APIRouter, HTTPException, Depends, Query
import src.messages as messages import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.users.service as service
from src.auth.auth import get_current_user
import src.users.exceptions as exceptions import src.users.exceptions as exceptions
import src.users.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/users') router = APIRouter(prefix='/users')
@router.get('', response_model=list[models.UserPublic]) @router.get('', response_model=list[models.UserPublic])
def get_users( def get_users(
session: Session = Depends(get_session), session: Session = Depends(get_session),
@@ -22,6 +23,7 @@ def get_users(
emails, emails,
) )
@router.get('/roles', response_model=list[models.ContractType]) @router.get('/roles', response_model=list[models.ContractType])
def get_roles( def get_roles(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
@@ -29,20 +31,23 @@ def get_roles(
): ):
return service.get_roles(session) return service.get_roles(session)
@router.get('/{id}', response_model=models.UserPublic) @router.get('/{id}', response_model=models.UserPublic)
def get_users( def get_users(
id: int, id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.get_one(session, id) result = service.get_one(session, id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('user')) raise HTTPException(status_code=404,
detail=messages.Messages.not_found('user'))
return result return result
@router.post('', response_model=models.UserPublic) @router.post('', response_model=models.UserPublic)
def create_user( def create_user(
user: models.UserCreate, user: models.UserCreate,
logged_user: models.User = Depends(get_current_user), logged_user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
@@ -52,27 +57,31 @@ def create_user(
raise HTTPException(status_code=400, detail=str(error)) raise HTTPException(status_code=400, detail=str(error))
return user return user
@router.put('/{id}', response_model=models.UserPublic) @router.put('/{id}', response_model=models.UserPublic)
def update_user( def update_user(
id: int, id: int,
user: models.UserUpdate, user: models.UserUpdate,
logged_user: models.User = Depends(get_current_user), logged_user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
try: try:
result = service.update_one(session, id, user) result = service.update_one(session, id, user)
except exceptions.UserNotFoundError as error: except exceptions.UserNotFoundError as error:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('user')) raise HTTPException(status_code=404,
detail=messages.Messages.not_found('user'))
return result return result
@router.delete('/{id}', response_model=models.UserPublic) @router.delete('/{id}', response_model=models.UserPublic)
def delete_user( def delete_user(
id: int, id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
try: try:
result = service.delete_one(session, id) result = service.delete_one(session, id)
except exceptions.UserNotFoundError as error: except exceptions.UserNotFoundError as error:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('user')) raise HTTPException(status_code=404,
detail=messages.Messages.not_found('user'))
return result return result

View File

@@ -1,13 +1,14 @@
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlmodel import SQLModel, Session, create_engine
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src import models
from src.auth.auth import get_current_user
from src.database import get_session
from src.main import app
from .fixtures import * from .fixtures import *
from src.main import app
import src.models as models
from src.database import get_session
from src.auth.auth import get_current_user
@pytest.fixture @pytest.fixture
def mock_session(mocker): def mock_session(mocker):
@@ -15,26 +16,29 @@ def mock_session(mocker):
def override(): def override():
return session return session
app.dependency_overrides[get_session] = override app.dependency_overrides[get_session] = override
yield session yield session
app.dependency_overrides.clear() app.dependency_overrides.clear()
@pytest.fixture @pytest.fixture
def mock_user(): def mock_user():
user = models.User(id=1, name='test user', email='test@user.com') user = models.User(id=1, name='test user', email='test@user.com')
def override(): def override():
return user return user
app.dependency_overrides[get_current_user] = override app.dependency_overrides[get_current_user] = override
yield user yield user
app.dependency_overrides.clear() app.dependency_overrides.clear()
@pytest.fixture @pytest.fixture
def client(): def client():
return TestClient(app) return TestClient(app)
@pytest.fixture(name='session') @pytest.fixture(name='session')
def session_fixture(): def session_fixture():
engine = create_engine( engine = create_engine(
@@ -55,4 +59,4 @@ def session_fixture():
transaction.rollback() transaction.rollback()
session.close() session.close()
connection.close() connection.close()
engine.dispose() engine.dispose()

View File

@@ -1,10 +1,12 @@
import src.models as models
import tests.factories.contracts as contract_factory import tests.factories.contracts as contract_factory
import tests.factories.products as product_factory import tests.factories.products as product_factory
from src import models
def contract_product_factory(**kwargs): def contract_product_factory(**kwargs):
contract = contract_factory.contract_factory(id=1) contract = contract_factory.contract_factory(id=1)
product = product_factory.product_public_factory(id=1, type=models.ProductType.RECCURENT) product = product_factory.product_public_factory(
id=1, type=models.ProductType.RECCURENT)
data = dict( data = dict(
product_id=1, product_id=1,
shipment_id=1, shipment_id=1,
@@ -16,6 +18,7 @@ def contract_product_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ContractProduct(**data) return models.ContractProduct(**data)
def contract_product_public_factory(**kwargs): def contract_product_public_factory(**kwargs):
contract = contract_factory.contract_factory(id=1) contract = contract_factory.contract_factory(id=1)
product = product_factory.product_public_factory(id=1) product = product_factory.product_public_factory(id=1)
@@ -31,6 +34,7 @@ def contract_product_public_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ContractProductPublic(**data) return models.ContractProductPublic(**data)
def contract_product_create_factory(**kwargs): def contract_product_create_factory(**kwargs):
data = dict( data = dict(
product_id=1, product_id=1,
@@ -40,6 +44,7 @@ def contract_product_create_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ContractProductCreate(**data) return models.ContractProductCreate(**data)
def contract_product_update_factory(**kwargs): def contract_product_update_factory(**kwargs):
data = dict( data = dict(
product_id=1, product_id=1,
@@ -49,6 +54,7 @@ def contract_product_update_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ContractProductUpdate(**data) return models.ContractProductUpdate(**data)
def contract_product_body_factory(**kwargs): def contract_product_body_factory(**kwargs):
data = dict( data = dict(
product_id=1, product_id=1,
@@ -56,4 +62,4 @@ def contract_product_body_factory(**kwargs):
quantity=1, quantity=1,
) )
data.update(kwargs) data.update(kwargs)
return data return data

View File

@@ -1,6 +1,8 @@
import src.models as models from src import models
from .forms import form_factory from .forms import form_factory
def contract_factory(**kwargs): def contract_factory(**kwargs):
data = dict( data = dict(
id=1, id=1,
@@ -17,6 +19,7 @@ def contract_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.Contract(**data) return models.Contract(**data)
def contract_public_factory(**kwargs): def contract_public_factory(**kwargs):
data = dict( data = dict(
id=1, id=1,
@@ -33,6 +36,7 @@ def contract_public_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ContractPublic(**data) return models.ContractPublic(**data)
def contract_create_factory(**kwargs): def contract_create_factory(**kwargs):
data = dict( data = dict(
firstname="test", firstname="test",
@@ -48,6 +52,7 @@ def contract_create_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ContractCreate(**data) return models.ContractCreate(**data)
def contract_update_factory(**kwargs): def contract_update_factory(**kwargs):
data = dict( data = dict(
firstname="test", firstname="test",
@@ -60,6 +65,7 @@ def contract_update_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ContractUpdate(**data) return models.ContractUpdate(**data)
def contract_body_factory(**kwargs): def contract_body_factory(**kwargs):
data = dict( data = dict(
firstname="test", firstname="test",
@@ -73,4 +79,4 @@ def contract_body_factory(**kwargs):
form_id=1 form_id=1
) )
data.update(kwargs) data.update(kwargs)
return data return data

View File

@@ -1,9 +1,11 @@
import src.models as models
from .productors import productor_public_factory
from .shipments import shipment_public_factory
from .users import user_factory
import datetime import datetime
from src import models
from .productors import productor_public_factory
from .users import user_factory
def form_factory(**kwargs): def form_factory(**kwargs):
data = dict( data = dict(
id=1, id=1,
@@ -21,7 +23,7 @@ def form_factory(**kwargs):
) )
data.update(kwargs) data.update(kwargs)
return models.Form(**data) return models.Form(**data)
def form_body_factory(**kwargs): def form_body_factory(**kwargs):
data = dict( data = dict(
@@ -37,6 +39,7 @@ def form_body_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return data return data
def form_create_factory(**kwargs): def form_create_factory(**kwargs):
data = dict( data = dict(
name='form 1', name='form 1',
@@ -51,6 +54,7 @@ def form_create_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.FormCreate(**data) return models.FormCreate(**data)
def form_update_factory(**kwargs): def form_update_factory(**kwargs):
data = dict( data = dict(
name='form 1', name='form 1',
@@ -65,7 +69,8 @@ def form_update_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.FormUpdate(**data) return models.FormUpdate(**data)
def form_public_factory(form=None, shipments=[],**kwargs):
def form_public_factory(form=None, shipments=[], **kwargs):
data = dict( data = dict(
id=1, id=1,
name='form 1', name='form 1',
@@ -81,4 +86,4 @@ def form_public_factory(form=None, shipments=[],**kwargs):
productor=productor_public_factory(), productor=productor_public_factory(),
) )
data.update(kwargs) data.update(kwargs)
return models.FormPublic(**data) return models.FormPublic(**data)

View File

@@ -1,4 +1,5 @@
import src.models as models from src import models
def productor_factory(**kwargs): def productor_factory(**kwargs):
data = dict( data = dict(
@@ -10,6 +11,7 @@ def productor_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.Productor(**data) return models.Productor(**data)
def productor_public_factory(**kwargs): def productor_public_factory(**kwargs):
data = dict( data = dict(
id=1, id=1,
@@ -22,6 +24,7 @@ def productor_public_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ProductorPublic(**data) return models.ProductorPublic(**data)
def productor_create_factory(**kwargs): def productor_create_factory(**kwargs):
data = dict( data = dict(
id=1, id=1,
@@ -34,6 +37,7 @@ def productor_create_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ProductorCreate(**data) return models.ProductorCreate(**data)
def productor_update_factory(**kwargs): def productor_update_factory(**kwargs):
data = dict( data = dict(
id=1, id=1,
@@ -46,6 +50,7 @@ def productor_update_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ProductorUpdate(**data) return models.ProductorUpdate(**data)
def productor_body_factory(**kwargs): def productor_body_factory(**kwargs):
data = dict( data = dict(
id=1, id=1,
@@ -56,4 +61,4 @@ def productor_body_factory(**kwargs):
payment_methods=[], payment_methods=[],
) )
data.update(kwargs) data.update(kwargs)
return data return data

View File

@@ -1,6 +1,7 @@
import src.models as models from src import models
from .productors import productor_factory from .productors import productor_factory
from .shipments import shipment_factory
def product_body_factory(**kwargs): def product_body_factory(**kwargs):
data = dict( data = dict(
@@ -16,6 +17,7 @@ def product_body_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return data return data
def product_create_factory(**kwargs): def product_create_factory(**kwargs):
data = dict( data = dict(
name='product test 1', name='product test 1',
@@ -30,6 +32,7 @@ def product_create_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ProductCreate(**data) return models.ProductCreate(**data)
def product_update_factory(**kwargs): def product_update_factory(**kwargs):
data = dict( data = dict(
name='product test 1', name='product test 1',
@@ -44,7 +47,8 @@ def product_update_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ProductUpdate(**data) return models.ProductUpdate(**data)
def product_public_factory(productor=None, shipments=[],**kwargs):
def product_public_factory(productor=None, shipments=[], **kwargs):
if productor is None: if productor is None:
productor = productor_factory() productor = productor_factory()
data = dict( data = dict(
@@ -61,4 +65,4 @@ def product_public_factory(productor=None, shipments=[],**kwargs):
shipments=shipments, shipments=shipments,
) )
data.update(kwargs) data.update(kwargs)
return models.ProductPublic(**data) return models.ProductPublic(**data)

View File

@@ -1,6 +1,8 @@
import src.models as models
import datetime import datetime
from src import models
def shipment_factory(**kwargs): def shipment_factory(**kwargs):
data = dict( data = dict(
id=1, id=1,
@@ -11,6 +13,7 @@ def shipment_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.Shipment(**data) return models.Shipment(**data)
def shipment_public_factory(**kwargs): def shipment_public_factory(**kwargs):
data = dict( data = dict(
id=1, id=1,
@@ -23,6 +26,7 @@ def shipment_public_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ShipmentPublic(**data) return models.ShipmentPublic(**data)
def shipment_create_factory(**kwargs): def shipment_create_factory(**kwargs):
data = dict( data = dict(
name="test shipment", name="test shipment",
@@ -33,6 +37,7 @@ def shipment_create_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ShipmentCreate(**data) return models.ShipmentCreate(**data)
def shipment_update_factory(**kwargs): def shipment_update_factory(**kwargs):
data = dict( data = dict(
name="test shipment", name="test shipment",
@@ -43,6 +48,7 @@ def shipment_update_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.ShipmentUpdate(**data) return models.ShipmentUpdate(**data)
def shipment_body_factory(**kwargs): def shipment_body_factory(**kwargs):
data = dict( data = dict(
name="test shipment", name="test shipment",
@@ -50,4 +56,4 @@ def shipment_body_factory(**kwargs):
date="2025-10-10", date="2025-10-10",
) )
data.update(kwargs) data.update(kwargs)
return data return data

View File

@@ -1,4 +1,5 @@
import src.models as models from src import models
def user_factory(**kwargs): def user_factory(**kwargs):
data = dict( data = dict(
@@ -10,6 +11,7 @@ def user_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.User(**data) return models.User(**data)
def user_public_factory(**kwargs): def user_public_factory(**kwargs):
data = dict( data = dict(
id=1, id=1,
@@ -20,6 +22,7 @@ def user_public_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.UserPublic(**data) return models.UserPublic(**data)
def user_create_factory(**kwargs): def user_create_factory(**kwargs):
data = dict( data = dict(
name="test user", name="test user",
@@ -29,6 +32,7 @@ def user_create_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.UserCreate(**data) return models.UserCreate(**data)
def user_update_factory(**kwargs): def user_update_factory(**kwargs):
data = dict( data = dict(
name="test user", name="test user",
@@ -38,6 +42,7 @@ def user_update_factory(**kwargs):
data.update(kwargs) data.update(kwargs)
return models.UserUpdate(**data) return models.UserUpdate(**data)
def user_body_factory(**kwargs): def user_body_factory(**kwargs):
data = dict( data = dict(
name="test user", name="test user",
@@ -45,4 +50,4 @@ def user_body_factory(**kwargs):
role_names=[], role_names=[],
) )
data.update(kwargs) data.update(kwargs)
return data return data

View File

@@ -1,18 +1,19 @@
import pytest
import datetime import datetime
from sqlmodel import Session
import src.models as models import pytest
import src.forms.service as forms_service import src.forms.service as forms_service
import src.shipments.service as shipments_service
import src.productors.service as productors_service import src.productors.service as productors_service
import src.products.service as products_service import src.products.service as products_service
import src.shipments.service as shipments_service
import src.users.service as users_service import src.users.service as users_service
import tests.factories.forms as forms_factory import tests.factories.forms as forms_factory
import tests.factories.shipments as shipments_factory
import tests.factories.productors as productors_factory import tests.factories.productors as productors_factory
import tests.factories.products as products_factory import tests.factories.products as products_factory
import tests.factories.shipments as shipments_factory
import tests.factories.users as users_factory import tests.factories.users as users_factory
from sqlmodel import Session
from src import models
@pytest.fixture @pytest.fixture
def productor(session: Session) -> models.ProductorPublic: def productor(session: Session) -> models.ProductorPublic:
@@ -46,8 +47,10 @@ def productors(session: Session) -> models.ProductorPublic:
] ]
return productors return productors
@pytest.fixture @pytest.fixture
def products(session: Session, productor: models.ProductorPublic) -> list[models.ProductPublic]: def products(session: Session,
productor: models.ProductorPublic) -> list[models.ProductPublic]:
products = [ products = [
products_service.create_one( products_service.create_one(
session, session,
@@ -68,65 +71,70 @@ def products(session: Session, productor: models.ProductorPublic) -> list[models
] ]
return products return products
@pytest.fixture @pytest.fixture
def user(session: Session) -> models.UserPublic: def user(session: Session) -> models.UserPublic:
user = users_service.create_one( user = users_service.create_one(
session, session,
users_factory.user_create_factory( users_factory.user_create_factory(
name='test user', name='test user',
email='test@test.com', email='test@test.com',
role_names=['Légumineuses'] role_names=['Légumineuses']
) )
) )
return user return user
@pytest.fixture @pytest.fixture
def users(session: Session) -> list[models.UserPublic]: def users(session: Session) -> list[models.UserPublic]:
users = [ users = [
users_service.create_one( users_service.create_one(
session, session,
users_factory.user_create_factory( users_factory.user_create_factory(
name='test user 1 (admin)', name='test user 1 (admin)',
email='test1@test.com', email='test1@test.com',
role_names=['Légumineuses', 'Légumes', 'Oeufs', 'Porc-Agneau', 'Vin', 'Fruits'] role_names=[
) 'Légumineuses',
), 'Légumes',
'Oeufs',
'Porc-Agneau',
'Vin',
'Fruits'])),
users_service.create_one( users_service.create_one(
session, session,
users_factory.user_create_factory( users_factory.user_create_factory(
name='test user 2', name='test user 2',
email='test2@test.com', email='test2@test.com',
role_names=['Légumineuses'] role_names=['Légumineuses'])),
)
),
users_service.create_one( users_service.create_one(
session, session,
users_factory.user_create_factory( users_factory.user_create_factory(
name='test user 3', name='test user 3',
email='test3@test.com', email='test3@test.com',
role_names=['Porc-Agneau'] role_names=['Porc-Agneau']))]
)
)
]
return users return users
@pytest.fixture @pytest.fixture
def referer(session: Session) -> models.UserPublic: def referer(session: Session) -> models.UserPublic:
referer = users_service.create_one( referer = users_service.create_one(
session, session,
users_factory.user_create_factory( users_factory.user_create_factory(
name='test referer', name='test referer',
email='test@test.com', email='test@test.com',
role_names=['Légumineuses'], role_names=['Légumineuses'],
) )
) )
return referer return referer
@pytest.fixture @pytest.fixture
def shipments(session: Session, forms: list[models.FormPublic], products: list[models.ProductPublic]): def shipments(session: Session,
forms: list[models.FormPublic],
products: list[models.ProductPublic]):
shipments = [ shipments = [
shipments_service.create_one( shipments_service.create_one(
session, session,
shipments_factory.shipment_create_factory( shipments_factory.shipment_create_factory(
name='test shipment 1', name='test shipment 1',
date=datetime.date(2025, 10, 10), date=datetime.date(2025, 10, 10),
@@ -135,7 +143,7 @@ def shipments(session: Session, forms: list[models.FormPublic], products: list[m
) )
), ),
shipments_service.create_one( shipments_service.create_one(
session, session,
shipments_factory.shipment_create_factory( shipments_factory.shipment_create_factory(
name='test shipment 2', name='test shipment 2',
date=datetime.date(2025, 11, 10), date=datetime.date(2025, 11, 10),
@@ -146,15 +154,16 @@ def shipments(session: Session, forms: list[models.FormPublic], products: list[m
] ]
return shipments return shipments
@pytest.fixture @pytest.fixture
def forms( def forms(
session: Session, session: Session,
productor: models.ProductorPublic, productor: models.ProductorPublic,
referer: models.UserPublic referer: models.UserPublic
) -> list[models.FormPublic]: ) -> list[models.FormPublic]:
forms = [ forms = [
forms_service.create_one( forms_service.create_one(
session, session,
forms_factory.form_create_factory( forms_factory.form_create_factory(
name='test form 1', name='test form 1',
productor_id=productor.id, productor_id=productor.id,
@@ -163,7 +172,7 @@ def forms(
) )
), ),
forms_service.create_one( forms_service.create_one(
session, session,
forms_factory.form_create_factory( forms_factory.form_create_factory(
name='test form 2', name='test form 2',
productor_id=productor.id, productor_id=productor.id,
@@ -173,4 +182,3 @@ def forms(
) )
] ]
return forms return forms

View File

@@ -1,12 +1,12 @@
import src.contracts.service as service import src.contracts.service as service
import src.models as models import tests.factories.contract_products as contract_products_factory
from src.main import app
from src.auth.auth import get_current_user
import tests.factories.contracts as contract_factory import tests.factories.contracts as contract_factory
import tests.factories.forms as form_factory import tests.factories.forms as form_factory
import tests.factories.contract_products as contract_products_factory
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user
from src.main import app
class TestContracts: class TestContracts:
def test_get_all(self, client, mocker, mock_session, mock_user): def test_get_all(self, client, mocker, mock_session, mock_user):
@@ -31,6 +31,7 @@ class TestContracts:
mock_user, mock_user,
[], [],
) )
def test_get_all_filters(self, client, mocker, mock_session, mock_user): def test_get_all_filters(self, client, mocker, mock_session, mock_user):
mock_results = [ mock_results = [
contract_factory.contract_public_factory(id=2), contract_factory.contract_public_factory(id=2),
@@ -52,10 +53,15 @@ class TestContracts:
['form test'], ['form test'],
) )
def test_get_all_unauthorized(self, client, mocker, mock_session, mock_user): def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.contracts.service.get_all') mock = mocker.patch('src.contracts.service.get_all')
@@ -74,7 +80,7 @@ class TestContracts:
'get_one', 'get_one',
return_value=mock_result return_value=mock_result
) )
mock_is_allowed = mocker.patch.object( mocker.patch.object(
service, service,
'is_allowed', 'is_allowed',
return_value=True return_value=True
@@ -88,7 +94,7 @@ class TestContracts:
mock_session, mock_session,
2 2
) )
def test_get_one_notfound(self, client, mocker, mock_session, mock_user): def test_get_one_notfound(self, client, mocker, mock_session, mock_user):
mock_result = None mock_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -97,23 +103,27 @@ class TestContracts:
return_value=mock_result return_value=mock_result
) )
mock_is_allowed = mocker.patch.object( mocker.patch.object(
service, service,
'is_allowed', 'is_allowed',
return_value=True return_value=True
) )
response = client.get('/api/contracts/2') response = client.get('/api/contracts/2')
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
mock_session, mock_session,
2 2
) )
def test_get_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_get_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.contracts.service.get_one') mock = mocker.patch('src.contracts.service.get_one')
@@ -127,37 +137,53 @@ class TestContracts:
def test_create_one(self, client, mocker, mock_session, mock_user): def test_create_one(self, client, mocker, mock_session, mock_user):
contract_body = contract_factory.contract_body_factory( contract_body = contract_factory.contract_body_factory(
products=[ products=[
contract_products_factory.contract_product_body_factory(product_id=1), contract_products_factory.contract_product_body_factory(
contract_products_factory.contract_product_body_factory(product_id=2), product_id=1
contract_products_factory.contract_product_body_factory(product_id=3) ),
contract_products_factory.contract_product_body_factory(
product_id=2
),
contract_products_factory.contract_product_body_factory(
product_id=3
)
], ],
cheques=[{'name': '123123', 'value': '100'}] cheques=[{'name': '123123', 'value': '100'}]
) )
contract_result = contract_factory.contract_factory( contract_result = contract_factory.contract_factory(
products=[ products=[
contract_products_factory.contract_product_factory(product_id=1), contract_products_factory.contract_product_factory(
contract_products_factory.contract_product_factory(product_id=2), product_id=1
contract_products_factory.contract_product_factory(product_id=3) ),
contract_products_factory.contract_product_factory(
product_id=2
),
contract_products_factory.contract_product_factory(
product_id=3
)
], ],
form=form_factory.form_factory(), form=form_factory.form_factory(),
cheques=[models.Cheque(name='123123', value='100')] cheques=[models.Cheque(name='123123', value='100')]
) )
mock_create_one = mocker.patch.object( mocker.patch.object(
service, service,
'create_one', 'create_one',
return_value=contract_result return_value=contract_result
) )
mock_add_contract_file = mocker.patch.object( mocker.patch.object(
service, service,
'add_contract_file', 'add_contract_file',
return_value=True return_value=True
) )
mock_generate_html_contract = mocker.patch('src.contracts.generate_contract.generate_html_contract') mocker.patch(
'src.contracts.generate_contract.generate_html_contract')
response = client.post('/api/contracts', json=contract_body) response = client.post('/api/contracts', json=contract_body)
assert response.status_code == 200 assert response.status_code == 200
contract_id = 'test_test_test type_hiver-2026' contract_id = 'test_test_test type_hiver-2026'
assert response.headers['Content-Disposition'] == f'attachment; filename=contract_{contract_id}.pdf' assert response.headers[
'Content-Disposition'] == (
f'attachment; filename=contract_{contract_id}.pdf'
)
def test_delete_one(self, client, mocker, mock_session, mock_user): def test_delete_one(self, client, mocker, mock_session, mock_user):
contract_result = contract_factory.contract_public_factory() contract_result = contract_factory.contract_public_factory()
@@ -168,14 +194,13 @@ class TestContracts:
return_value=contract_result return_value=contract_result
) )
mock_is_allowed = mocker.patch.object( mocker.patch.object(
service, service,
'is_allowed', 'is_allowed',
return_value=True return_value=True
) )
response = client.delete('/api/contracts/2') response = client.delete('/api/contracts/2')
response_data = response.json()
assert response.status_code == 200 assert response.status_code == 200
mock.assert_called_once_with( mock.assert_called_once_with(
@@ -183,7 +208,12 @@ class TestContracts:
2, 2,
) )
def test_delete_one_notfound(self, client, mocker, mock_session, mock_user): def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
contract_result = None contract_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -192,14 +222,13 @@ class TestContracts:
return_value=contract_result return_value=contract_result
) )
mock_is_allowed = mocker.patch.object( mocker.patch.object(
service, service,
'is_allowed', 'is_allowed',
return_value=True return_value=True
) )
response = client.delete('/api/contracts/2') response = client.delete('/api/contracts/2')
response_data = response.json()
assert response.status_code == 404 assert response.status_code == 404
mock.assert_called_once_with( mock.assert_called_once_with(
@@ -207,11 +236,15 @@ class TestContracts:
2, 2,
) )
def test_delete_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
contract_body = contract_factory.contract_body_factory()
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.contracts.service.delete_one') mock = mocker.patch('src.contracts.service.delete_one')
@@ -220,4 +253,4 @@ class TestContracts:
assert response.status_code == 401 assert response.status_code == 401
mock.assert_not_called() mock.assert_not_called()
app.dependency_overrides.clear() app.dependency_overrides.clear()

View File

@@ -1,11 +1,12 @@
import src.forms.service as service
import src.forms.exceptions as forms_exceptions import src.forms.exceptions as forms_exceptions
import src.models as models import src.forms.service as service
from src.main import app import src.messages as messages
from src.auth.auth import get_current_user
import tests.factories.forms as form_factory import tests.factories.forms as form_factory
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
import src.messages as messages from src import models
from src.auth.auth import get_current_user
from src.main import app
class TestForms: class TestForms:
def test_get_all(self, client, mocker, mock_session, mock_user): def test_get_all(self, client, mocker, mock_session, mock_user):
@@ -32,6 +33,7 @@ class TestForms:
False, False,
mock_user, mock_user,
) )
def test_get_all_filters(self, client, mocker, mock_session, mock_user): def test_get_all_filters(self, client, mocker, mock_session, mock_user):
mock_results = [ mock_results = [
form_factory.form_public_factory(name="test 2", id=2), form_factory.form_public_factory(name="test 2", id=2),
@@ -42,7 +44,8 @@ class TestForms:
return_value=mock_results return_value=mock_results
) )
response = client.get('/api/forms/referents?current_season=true&seasons=hiver-2025&productors=test productor') response = client.get(
'/api/forms/referents?current_season=true&seasons=hiver-2025&productors=test productor')
response_data = response.json() response_data = response.json()
assert response.status_code == 200 assert response.status_code == 200
assert response_data[0]['id'] == 2 assert response_data[0]['id'] == 2
@@ -55,10 +58,15 @@ class TestForms:
mock_user, mock_user,
) )
def test_get_all_unauthorized(self, client, mocker, mock_session, mock_user): def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.forms.service.get_all') mock = mocker.patch('src.forms.service.get_all')
@@ -87,7 +95,7 @@ class TestForms:
mock_session, mock_session,
2 2
) )
def test_get_one_notfound(self, client, mocker, mock_session, mock_user): def test_get_one_notfound(self, client, mocker, mock_session, mock_user):
mock_result = None mock_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -102,8 +110,7 @@ class TestForms:
mock_session, mock_session,
2 2
) )
def test_create_one(self, client, mocker, mock_session, mock_user): def test_create_one(self, client, mocker, mock_session, mock_user):
form_body = form_factory.form_body_factory(name='test form create') form_body = form_factory.form_body_factory(name='test form create')
form_create = form_factory.form_create_factory(name='test form create') form_create = form_factory.form_create_factory(name='test form create')
@@ -124,16 +131,17 @@ class TestForms:
mock_session, mock_session,
form_create form_create
) )
def test_create_one_referer_notfound(self, client, mocker, mock_session, mock_user): def test_create_one_referer_notfound(
form_body = form_factory.form_body_factory(name='test form create', referer_id=12312) self, client, mocker, mock_session, mock_user):
form_create = form_factory.form_create_factory(name='test form create', referer_id=12312) form_body = form_factory.form_body_factory(
name='test form create', referer_id=12312)
form_create = form_factory.form_create_factory(
name='test form create', referer_id=12312)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service, 'create_one', side_effect=forms_exceptions.UserNotFoundError(
'create_one', messages.Messages.not_found('referer')))
side_effect=forms_exceptions.UserNotFoundError(messages.Messages.not_found('referer'))
)
response = client.post('/api/forms', json=form_body) response = client.post('/api/forms', json=form_body)
response_data = response.json() response_data = response.json()
@@ -144,15 +152,16 @@ class TestForms:
form_create form_create
) )
def test_create_one_productor_notfound(self, client, mocker, mock_session, mock_user): def test_create_one_productor_notfound(
form_body = form_factory.form_body_factory(name='test form create', productor_id=1231) self, client, mocker, mock_session, mock_user):
form_create = form_factory.form_create_factory(name='test form create', productor_id=1231) form_body = form_factory.form_body_factory(
name='test form create', productor_id=1231)
form_create = form_factory.form_create_factory(
name='test form create', productor_id=1231)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service, 'create_one', side_effect=forms_exceptions.ProductorNotFoundError(
'create_one', messages.Messages.not_found('productor')))
side_effect=forms_exceptions.ProductorNotFoundError(messages.Messages.not_found('productor'))
)
response = client.post('/api/forms', json=form_body) response = client.post('/api/forms', json=form_body)
response_data = response.json() response_data = response.json()
@@ -163,11 +172,16 @@ class TestForms:
form_create form_create
) )
def test_create_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_create_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
form_body = form_factory.form_body_factory(name='test form create') form_body = form_factory.form_body_factory(name='test form create')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.forms.service.create_one') mock = mocker.patch('src.forms.service.create_one')
@@ -200,35 +214,18 @@ class TestForms:
form_update form_update
) )
def test_update_one_notfound(self, client, mocker, mock_session, mock_user): def test_update_one_notfound(
form_body = form_factory.form_body_factory(name='test form update') self,
form_update = form_factory.form_update_factory(name='test form update') client,
mocker,
mock = mocker.patch.object(
service,
'update_one',
side_effect=forms_exceptions.FormNotFoundError(messages.Messages.not_found('form'))
)
response = client.put('/api/forms/2', json=form_body)
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
mock_session, mock_session,
2, mock_user):
form_update
)
def test_update_one_referer_notfound(self, client, mocker, mock_session, mock_user):
form_body = form_factory.form_body_factory(name='test form update') form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update') form_update = form_factory.form_update_factory(name='test form update')
mock = mocker.patch.object( mock = mocker.patch.object(
service, service, 'update_one', side_effect=forms_exceptions.FormNotFoundError(
'update_one', messages.Messages.not_found('form')))
side_effect=forms_exceptions.UserNotFoundError(messages.Messages.not_found('referer'))
)
response = client.put('/api/forms/2', json=form_body) response = client.put('/api/forms/2', json=form_body)
response_data = response.json() response_data = response.json()
@@ -240,15 +237,14 @@ class TestForms:
form_update form_update
) )
def test_update_one_productor_notfound(self, client, mocker, mock_session, mock_user): def test_update_one_referer_notfound(
self, client, mocker, mock_session, mock_user):
form_body = form_factory.form_body_factory(name='test form update') form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update') form_update = form_factory.form_update_factory(name='test form update')
mock = mocker.patch.object( mock = mocker.patch.object(
service, service, 'update_one', side_effect=forms_exceptions.UserNotFoundError(
'update_one', messages.Messages.not_found('referer')))
side_effect=forms_exceptions.ProductorNotFoundError(messages.Messages.not_found('productor'))
)
response = client.put('/api/forms/2', json=form_body) response = client.put('/api/forms/2', json=form_body)
response_data = response.json() response_data = response.json()
@@ -260,11 +256,35 @@ class TestForms:
form_update form_update
) )
def test_update_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_update_one_productor_notfound(
self, client, mocker, mock_session, mock_user):
form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update')
mock = mocker.patch.object(
service, 'update_one', side_effect=forms_exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')))
response = client.put('/api/forms/2', json=form_body)
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2,
form_update
)
def test_update_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
form_body = form_factory.form_body_factory(name='test form update') form_body = form_factory.form_body_factory(name='test form update')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.forms.service.update_one') mock = mocker.patch('src.forms.service.update_one')
@@ -294,14 +314,17 @@ class TestForms:
2, 2,
) )
def test_delete_one_notfound(self, client, mocker, mock_session, mock_user): def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
form_result = None form_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
service, service, 'delete_one', side_effect=forms_exceptions.FormNotFoundError(
'delete_one', messages.Messages.not_found('form')))
side_effect=forms_exceptions.FormNotFoundError(messages.Messages.not_found('form'))
)
response = client.delete('/api/forms/2') response = client.delete('/api/forms/2')
response_data = response.json() response_data = response.json()
@@ -312,10 +335,15 @@ class TestForms:
2, 2,
) )
def test_delete_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.forms.service.delete_one') mock = mocker.patch('src.forms.service.delete_one')
@@ -324,4 +352,4 @@ class TestForms:
assert response.status_code == 401 assert response.status_code == 401
mock.assert_not_called() mock.assert_not_called()
app.dependency_overrides.clear() app.dependency_overrides.clear()

View File

@@ -1,14 +1,12 @@
from fastapi.exceptions import HTTPException
from src.main import app
import src.models as models
import src.messages as messages import src.messages as messages
from src.auth.auth import get_current_user
import src.productors.service as service
import src.productors.exceptions as exceptions import src.productors.exceptions as exceptions
import src.productors.service as service
import tests.factories.productors as productor_factory import tests.factories.productors as productor_factory
from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user
from src.main import app
class TestProductors: class TestProductors:
def test_get_all(self, client, mocker, mock_session, mock_user): def test_get_all(self, client, mocker, mock_session, mock_user):
@@ -34,6 +32,7 @@ class TestProductors:
[], [],
[], [],
) )
def test_get_all_filters(self, client, mocker, mock_session, mock_user): def test_get_all_filters(self, client, mocker, mock_session, mock_user):
mock_results = [ mock_results = [
productor_factory.productor_public_factory(name="test 2", id=2), productor_factory.productor_public_factory(name="test 2", id=2),
@@ -44,7 +43,8 @@ class TestProductors:
return_value=mock_results return_value=mock_results
) )
response = client.get('/api/productors?types=Légumineuses&names=test 2') response = client.get(
'/api/productors?types=Légumineuses&names=test 2')
response_data = response.json() response_data = response.json()
assert response.status_code == 200 assert response.status_code == 200
assert response_data[0]['id'] == 2 assert response_data[0]['id'] == 2
@@ -56,10 +56,15 @@ class TestProductors:
['Légumineuses'], ['Légumineuses'],
) )
def test_get_all_unauthorized(self, client, mocker, mock_session, mock_user): def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.productors.service.get_all') mock = mocker.patch('src.productors.service.get_all')
@@ -71,7 +76,8 @@ class TestProductors:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user): def test_get_one(self, client, mocker, mock_session, mock_user):
mock_result = productor_factory.productor_public_factory(name="test 2", id=2) mock_result = productor_factory.productor_public_factory(
name="test 2", id=2)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
@@ -88,7 +94,7 @@ class TestProductors:
mock_session, mock_session,
2 2
) )
def test_get_one_notfound(self, client, mocker, mock_session, mock_user): def test_get_one_notfound(self, client, mocker, mock_session, mock_user):
mock_result = None mock_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -103,11 +109,16 @@ class TestProductors:
mock_session, mock_session,
2 2
) )
def test_get_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_get_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.productors.service.get_one') mock = mocker.patch('src.productors.service.get_one')
@@ -117,11 +128,14 @@ class TestProductors:
mock.assert_not_called() mock.assert_not_called()
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user): def test_create_one(self, client, mocker, mock_session, mock_user):
productor_body = productor_factory.productor_body_factory(name='test productor create') productor_body = productor_factory.productor_body_factory(
productor_create = productor_factory.productor_create_factory(name='test productor create') name='test productor create')
productor_result = productor_factory.productor_public_factory(name='test productor create') productor_create = productor_factory.productor_create_factory(
name='test productor create')
productor_result = productor_factory.productor_public_factory(
name='test productor create')
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
@@ -139,11 +153,17 @@ class TestProductors:
productor_create productor_create
) )
def test_create_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_create_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
productor_body = productor_factory.productor_body_factory(name='test productor create') productor_body = productor_factory.productor_body_factory(
name='test productor create')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.productors.service.create_one') mock = mocker.patch('src.productors.service.create_one')
@@ -155,9 +175,12 @@ class TestProductors:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user): def test_update_one(self, client, mocker, mock_session, mock_user):
productor_body = productor_factory.productor_body_factory(name='test productor update') productor_body = productor_factory.productor_body_factory(
productor_update = productor_factory.productor_update_factory(name='test productor update') name='test productor update')
productor_result = productor_factory.productor_public_factory(name='test productor update') productor_update = productor_factory.productor_update_factory(
name='test productor update')
productor_result = productor_factory.productor_public_factory(
name='test productor update')
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
@@ -176,16 +199,21 @@ class TestProductors:
productor_update productor_update
) )
def test_update_one_notfound(self, client, mocker, mock_session, mock_user): def test_update_one_notfound(
productor_body = productor_factory.productor_body_factory(name='test productor update') self,
productor_update = productor_factory.productor_update_factory(name='test productor update') client,
mocker,
mock_session,
mock_user):
productor_body = productor_factory.productor_body_factory(
name='test productor update')
productor_update = productor_factory.productor_update_factory(
name='test productor update')
productor_result = None productor_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
service, service, 'update_one', side_effect=exceptions.ProductorNotFoundError(
'update_one', messages.Messages.not_found('productor')))
side_effect=exceptions.ProductorNotFoundError(messages.Messages.not_found('productor'))
)
response = client.put('/api/productors/2', json=productor_body) response = client.put('/api/productors/2', json=productor_body)
response_data = response.json() response_data = response.json()
@@ -197,11 +225,17 @@ class TestProductors:
productor_update productor_update
) )
def test_update_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_update_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
productor_body = productor_factory.productor_body_factory(name='test productor update') productor_body = productor_factory.productor_body_factory(
name='test productor update')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.productors.service.update_one') mock = mocker.patch('src.productors.service.update_one')
@@ -213,7 +247,8 @@ class TestProductors:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user): def test_delete_one(self, client, mocker, mock_session, mock_user):
productor_result = productor_factory.productor_public_factory(name='test productor delete') productor_result = productor_factory.productor_public_factory(
name='test productor delete')
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
@@ -231,14 +266,17 @@ class TestProductors:
2, 2,
) )
def test_delete_one_notfound(self, client, mocker, mock_session, mock_user): def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
productor_result = None productor_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
service, service, 'delete_one', side_effect=exceptions.ProductorNotFoundError(
'delete_one', messages.Messages.not_found('productor')))
side_effect=exceptions.ProductorNotFoundError(messages.Messages.not_found('productor'))
)
response = client.delete('/api/productors/2') response = client.delete('/api/productors/2')
response_data = response.json() response_data = response.json()
@@ -249,11 +287,17 @@ class TestProductors:
2, 2,
) )
def test_delete_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
productor_body = productor_factory.productor_body_factory(name='test productor delete') productor_body = productor_factory.productor_body_factory(
name='test productor delete')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.productors.service.delete_one') mock = mocker.patch('src.productors.service.delete_one')
@@ -262,4 +306,4 @@ class TestProductors:
assert response.status_code == 401 assert response.status_code == 401
mock.assert_not_called() mock.assert_not_called()
app.dependency_overrides.clear() app.dependency_overrides.clear()

View File

@@ -1,11 +1,11 @@
import src.products.service as service
import src.products.exceptions as exceptions import src.products.exceptions as exceptions
import src.models as models import src.products.service as service
from src.main import app
from src.auth.auth import get_current_user
import tests.factories.products as product_factory import tests.factories.products as product_factory
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user
from src.main import app
class TestProducts: class TestProducts:
def test_get_all(self, client, mocker, mock_session, mock_user): def test_get_all(self, client, mocker, mock_session, mock_user):
@@ -32,6 +32,7 @@ class TestProducts:
[], [],
[] []
) )
def test_get_all_filters(self, client, mocker, mock_session, mock_user): def test_get_all_filters(self, client, mocker, mock_session, mock_user):
mock_results = [ mock_results = [
product_factory.product_public_factory(name="test 2", id=2), product_factory.product_public_factory(name="test 2", id=2),
@@ -55,10 +56,15 @@ class TestProducts:
['1'], ['1'],
) )
def test_get_all_unauthorized(self, client, mocker, mock_session, mock_user): def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.products.service.get_all') mock = mocker.patch('src.products.service.get_all')
@@ -70,7 +76,8 @@ class TestProducts:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user): def test_get_one(self, client, mocker, mock_session, mock_user):
mock_result = product_factory.product_public_factory(name="test 2", id=2) mock_result = product_factory.product_public_factory(
name="test 2", id=2)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
@@ -87,7 +94,7 @@ class TestProducts:
mock_session, mock_session,
2 2
) )
def test_get_one_notfound(self, client, mocker, mock_session, mock_user): def test_get_one_notfound(self, client, mocker, mock_session, mock_user):
mock_result = None mock_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -102,11 +109,16 @@ class TestProducts:
mock_session, mock_session,
2 2
) )
def test_get_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_get_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.products.service.get_one') mock = mocker.patch('src.products.service.get_one')
@@ -116,11 +128,14 @@ class TestProducts:
mock.assert_not_called() mock.assert_not_called()
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user): def test_create_one(self, client, mocker, mock_session, mock_user):
product_body = product_factory.product_body_factory(name='test product create') product_body = product_factory.product_body_factory(
product_create = product_factory.product_create_factory(name='test product create') name='test product create')
product_result = product_factory.product_public_factory(name='test product create') product_create = product_factory.product_create_factory(
name='test product create')
product_result = product_factory.product_public_factory(
name='test product create')
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
@@ -138,11 +153,17 @@ class TestProducts:
product_create product_create
) )
def test_create_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_create_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
product_body = product_factory.product_body_factory(name='test product create') product_body = product_factory.product_body_factory(
name='test product create')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.products.service.create_one') mock = mocker.patch('src.products.service.create_one')
@@ -154,9 +175,12 @@ class TestProducts:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user): def test_update_one(self, client, mocker, mock_session, mock_user):
product_body = product_factory.product_body_factory(name='test product update') product_body = product_factory.product_body_factory(
product_update = product_factory.product_update_factory(name='test product update') name='test product update')
product_result = product_factory.product_public_factory(name='test product update') product_update = product_factory.product_update_factory(
name='test product update')
product_result = product_factory.product_public_factory(
name='test product update')
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
@@ -175,9 +199,16 @@ class TestProducts:
product_update product_update
) )
def test_update_one_notfound(self, client, mocker, mock_session, mock_user): def test_update_one_notfound(
product_body = product_factory.product_body_factory(name='test product update') self,
product_update = product_factory.product_update_factory(name='test product update') client,
mocker,
mock_session,
mock_user):
product_body = product_factory.product_body_factory(
name='test product update')
product_update = product_factory.product_update_factory(
name='test product update')
product_result = None product_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -196,11 +227,17 @@ class TestProducts:
product_update product_update
) )
def test_update_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_update_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
product_body = product_factory.product_body_factory(name='test product update') product_body = product_factory.product_body_factory(
name='test product update')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.products.service.update_one') mock = mocker.patch('src.products.service.update_one')
@@ -212,7 +249,8 @@ class TestProducts:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user): def test_delete_one(self, client, mocker, mock_session, mock_user):
product_result = product_factory.product_public_factory(name='test product delete') product_result = product_factory.product_public_factory(
name='test product delete')
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
@@ -230,7 +268,12 @@ class TestProducts:
2, 2,
) )
def test_delete_one_notfound(self, client, mocker, mock_session, mock_user): def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
product_result = None product_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -248,11 +291,17 @@ class TestProducts:
2, 2,
) )
def test_delete_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
product_body = product_factory.product_body_factory(name='test product delete') product_body = product_factory.product_body_factory(
name='test product delete')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.products.service.delete_one') mock = mocker.patch('src.products.service.delete_one')
@@ -261,4 +310,4 @@ class TestProducts:
assert response.status_code == 401 assert response.status_code == 401
mock.assert_not_called() mock.assert_not_called()
app.dependency_overrides.clear() app.dependency_overrides.clear()

View File

@@ -1,12 +1,12 @@
import src.shipments.service as service
import src.models as models
from src.main import app
import src.messages as messages import src.messages as messages
import src.shipments.exceptions as exceptions import src.shipments.exceptions as exceptions
from src.auth.auth import get_current_user import src.shipments.service as service
import tests.factories.shipments as shipment_factory import tests.factories.shipments as shipment_factory
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user
from src.main import app
class TestShipments: class TestShipments:
def test_get_all(self, client, mocker, mock_session, mock_user): def test_get_all(self, client, mocker, mock_session, mock_user):
@@ -33,6 +33,7 @@ class TestShipments:
[], [],
[], [],
) )
def test_get_all_filters(self, client, mocker, mock_session, mock_user): def test_get_all_filters(self, client, mocker, mock_session, mock_user):
mock_results = [ mock_results = [
shipment_factory.shipment_public_factory(name="test 2", id=2), shipment_factory.shipment_public_factory(name="test 2", id=2),
@@ -43,7 +44,8 @@ class TestShipments:
return_value=mock_results return_value=mock_results
) )
response = client.get('/api/shipments?dates=2025-10-10&names=test 2&forms=contract form 1') response = client.get(
'/api/shipments?dates=2025-10-10&names=test 2&forms=contract form 1')
response_data = response.json() response_data = response.json()
assert response.status_code == 200 assert response.status_code == 200
assert response_data[0]['id'] == 2 assert response_data[0]['id'] == 2
@@ -56,10 +58,15 @@ class TestShipments:
['contract form 1'], ['contract form 1'],
) )
def test_get_all_unauthorized(self, client, mocker, mock_session, mock_user): def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.shipments.service.get_all') mock = mocker.patch('src.shipments.service.get_all')
@@ -71,7 +78,8 @@ class TestShipments:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user): def test_get_one(self, client, mocker, mock_session, mock_user):
mock_result = shipment_factory.shipment_public_factory(name="test 2", id=2) mock_result = shipment_factory.shipment_public_factory(
name="test 2", id=2)
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
@@ -88,7 +96,7 @@ class TestShipments:
mock_session, mock_session,
2 2
) )
def test_get_one_notfound(self, client, mocker, mock_session, mock_user): def test_get_one_notfound(self, client, mocker, mock_session, mock_user):
mock_result = None mock_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -103,11 +111,16 @@ class TestShipments:
mock_session, mock_session,
2 2
) )
def test_get_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_get_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.shipments.service.get_one') mock = mocker.patch('src.shipments.service.get_one')
@@ -117,11 +130,14 @@ class TestShipments:
mock.assert_not_called() mock.assert_not_called()
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user): def test_create_one(self, client, mocker, mock_session, mock_user):
shipment_body = shipment_factory.shipment_body_factory(name='test shipment create') shipment_body = shipment_factory.shipment_body_factory(
shipment_create = shipment_factory.shipment_create_factory(name='test shipment create') name='test shipment create')
shipment_result = shipment_factory.shipment_public_factory(name='test shipment create') shipment_create = shipment_factory.shipment_create_factory(
name='test shipment create')
shipment_result = shipment_factory.shipment_public_factory(
name='test shipment create')
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
@@ -139,11 +155,17 @@ class TestShipments:
shipment_create shipment_create
) )
def test_create_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_create_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
shipment_body = shipment_factory.shipment_body_factory(name='test shipment create') shipment_body = shipment_factory.shipment_body_factory(
name='test shipment create')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.shipments.service.create_one') mock = mocker.patch('src.shipments.service.create_one')
@@ -155,9 +177,12 @@ class TestShipments:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user): def test_update_one(self, client, mocker, mock_session, mock_user):
shipment_body = shipment_factory.shipment_body_factory(name='test shipment update') shipment_body = shipment_factory.shipment_body_factory(
shipment_update = shipment_factory.shipment_update_factory(name='test shipment update') name='test shipment update')
shipment_result = shipment_factory.shipment_public_factory(name='test shipment update') shipment_update = shipment_factory.shipment_update_factory(
name='test shipment update')
shipment_result = shipment_factory.shipment_public_factory(
name='test shipment update')
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
@@ -176,15 +201,20 @@ class TestShipments:
shipment_update shipment_update
) )
def test_update_one_notfound(self, client, mocker, mock_session, mock_user): def test_update_one_notfound(
shipment_body = shipment_factory.shipment_body_factory(name='test shipment update') self,
shipment_update = shipment_factory.shipment_update_factory(name='test shipment update') client,
mocker,
mock_session,
mock_user):
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment update')
shipment_update = shipment_factory.shipment_update_factory(
name='test shipment update')
mock = mocker.patch.object( mock = mocker.patch.object(
service, service, 'update_one', side_effect=exceptions.ShipmentNotFoundError(
'update_one', messages.Messages.not_found('shipment')))
side_effect=exceptions.ShipmentNotFoundError(messages.Messages.not_found('shipment'))
)
response = client.put('/api/shipments/2', json=shipment_body) response = client.put('/api/shipments/2', json=shipment_body)
response_data = response.json() response_data = response.json()
@@ -196,11 +226,17 @@ class TestShipments:
shipment_update shipment_update
) )
def test_update_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_update_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
shipment_body = shipment_factory.shipment_body_factory(name='test shipment update') shipment_body = shipment_factory.shipment_body_factory(
name='test shipment update')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.shipments.service.update_one') mock = mocker.patch('src.shipments.service.update_one')
@@ -212,7 +248,8 @@ class TestShipments:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user): def test_delete_one(self, client, mocker, mock_session, mock_user):
shipment_result = shipment_factory.shipment_public_factory(name='test shipment delete') shipment_result = shipment_factory.shipment_public_factory(
name='test shipment delete')
mock = mocker.patch.object( mock = mocker.patch.object(
service, service,
@@ -230,14 +267,17 @@ class TestShipments:
2, 2,
) )
def test_delete_one_notfound(self, client, mocker, mock_session, mock_user): def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
shipment_result = None shipment_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
service, service, 'delete_one', side_effect=exceptions.ShipmentNotFoundError(
'delete_one', messages.Messages.not_found('shipment')))
side_effect=exceptions.ShipmentNotFoundError(messages.Messages.not_found('shipment'))
)
response = client.delete('/api/shipments/2') response = client.delete('/api/shipments/2')
response_data = response.json() response_data = response.json()
@@ -248,11 +288,17 @@ class TestShipments:
2, 2,
) )
def test_delete_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
shipment_body = shipment_factory.shipment_body_factory(name='test shipment delete') shipment_body = shipment_factory.shipment_body_factory(
name='test shipment delete')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.shipments.service.delete_one') mock = mocker.patch('src.shipments.service.delete_one')
@@ -261,4 +307,4 @@ class TestShipments:
assert response.status_code == 401 assert response.status_code == 401
mock.assert_not_called() mock.assert_not_called()
app.dependency_overrides.clear() app.dependency_overrides.clear()

View File

@@ -1,11 +1,11 @@
import src.users.service as service
import src.models as models
from src.main import app
from src.auth.auth import get_current_user
import tests.factories.users as user_factory
import src.users.exceptions as exceptions import src.users.exceptions as exceptions
import src.users.service as service
import tests.factories.users as user_factory
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user
from src.main import app
class TestUsers: class TestUsers:
def test_get_all(self, client, mocker, mock_session, mock_user): def test_get_all(self, client, mocker, mock_session, mock_user):
@@ -30,6 +30,7 @@ class TestUsers:
[], [],
[], [],
) )
def test_get_all_filters(self, client, mocker, mock_session, mock_user): def test_get_all_filters(self, client, mocker, mock_session, mock_user):
mock_results = [ mock_results = [
user_factory.user_public_factory(name="test 2", id=2), user_factory.user_public_factory(name="test 2", id=2),
@@ -51,10 +52,15 @@ class TestUsers:
['test@test.test'], ['test@test.test'],
) )
def test_get_all_unauthorized(self, client, mocker, mock_session, mock_user): def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.users.service.get_all') mock = mocker.patch('src.users.service.get_all')
@@ -83,7 +89,7 @@ class TestUsers:
mock_session, mock_session,
2 2
) )
def test_get_one_notfound(self, client, mocker, mock_session, mock_user): def test_get_one_notfound(self, client, mocker, mock_session, mock_user):
mock_result = None mock_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -98,11 +104,16 @@ class TestUsers:
mock_session, mock_session,
2 2
) )
def test_get_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_get_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.users.service.get_one') mock = mocker.patch('src.users.service.get_one')
@@ -112,7 +123,7 @@ class TestUsers:
mock.assert_not_called() mock.assert_not_called()
app.dependency_overrides.clear() app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user): def test_create_one(self, client, mocker, mock_session, mock_user):
user_body = user_factory.user_body_factory(name='test user create') user_body = user_factory.user_body_factory(name='test user create')
user_create = user_factory.user_create_factory(name='test user create') user_create = user_factory.user_create_factory(name='test user create')
@@ -134,11 +145,16 @@ class TestUsers:
user_create user_create
) )
def test_create_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_create_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
user_body = user_factory.user_body_factory(name='test user create') user_body = user_factory.user_body_factory(name='test user create')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.users.service.create_one') mock = mocker.patch('src.users.service.create_one')
@@ -171,7 +187,12 @@ class TestUsers:
user_update user_update
) )
def test_update_one_notfound(self, client, mocker, mock_session, mock_user): def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
user_body = user_factory.user_body_factory(name='test user update') user_body = user_factory.user_body_factory(name='test user update')
user_update = user_factory.user_update_factory(name='test user update') user_update = user_factory.user_update_factory(name='test user update')
user_result = None user_result = None
@@ -192,11 +213,16 @@ class TestUsers:
user_update user_update
) )
def test_update_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_update_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
user_body = user_factory.user_body_factory(name='test user update') user_body = user_factory.user_body_factory(name='test user update')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.users.service.update_one') mock = mocker.patch('src.users.service.update_one')
@@ -226,7 +252,12 @@ class TestUsers:
2, 2,
) )
def test_delete_one_notfound(self, client, mocker, mock_session, mock_user): def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
user_result = None user_result = None
mock = mocker.patch.object( mock = mocker.patch.object(
@@ -244,11 +275,16 @@ class TestUsers:
2, 2,
) )
def test_delete_one_unauthorized(self, client, mocker, mock_session, mock_user): def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized(): def unauthorized():
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
user_body = user_factory.user_body_factory(name='test user delete') user_body = user_factory.user_body_factory(name='test user delete')
app.dependency_overrides[get_current_user] = unauthorized app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.users.service.delete_one') mock = mocker.patch('src.users.service.delete_one')
@@ -257,4 +293,4 @@ class TestUsers:
assert response.status_code == 401 assert response.status_code == 401
mock.assert_not_called() mock.assert_not_called()
app.dependency_overrides.clear() app.dependency_overrides.clear()

View File

@@ -1,35 +1,41 @@
import pytest import pytest
from sqlmodel import Session
import src.models as models
import src.forms.service as forms_service
import src.forms.exceptions as forms_exceptions import src.forms.exceptions as forms_exceptions
import src.forms.service as forms_service
import tests.factories.forms as forms_factory import tests.factories.forms as forms_factory
from sqlmodel import Session
from src import models
class TestFormsService: class TestFormsService:
def test_get_all_forms(self, session: Session, forms: list[models.FormPublic]): def test_get_all_forms(self, session: Session,
forms: list[models.FormPublic]):
result = forms_service.get_all(session, [], [], False) result = forms_service.get_all(session, [], [], False)
assert len(result) == 2 assert len(result) == 2
assert result == forms assert result == forms
def test_get_all_forms_filter_productors(self, session: Session, forms: list[models.FormPublic]): def test_get_all_forms_filter_productors(
self, session: Session, forms: list[models.FormPublic]):
result = forms_service.get_all(session, [], ['test productor'], False) result = forms_service.get_all(session, [], ['test productor'], False)
assert len(result) == 2 assert len(result) == 2
assert result == forms assert result == forms
def test_get_all_forms_filter_season(self, session: Session, forms: list[models.FormPublic]): def test_get_all_forms_filter_season(
self, session: Session, forms: list[models.FormPublic]):
result = forms_service.get_all(session, ['test season 1'], [], False) result = forms_service.get_all(session, ['test season 1'], [], False)
assert len(result) == 1 assert len(result) == 1
def test_get_all_forms_all_filters(self, session: Session, forms: list[models.FormPublic]): def test_get_all_forms_all_filters(
result = forms_service.get_all(session, ['test season 1'], ['test productor'], True) self, session: Session, forms: list[models.FormPublic]):
result = forms_service.get_all(
session, ['test season 1'], ['test productor'], True)
assert result == forms assert result == forms
def test_get_one_form(self, session: Session, forms: list[models.FormPublic]): def test_get_one_form(self, session: Session,
forms: list[models.FormPublic]):
result = forms_service.get_one(session, forms[0].id) result = forms_service.get_one(session, forms[0].id)
assert result == forms[0] assert result == forms[0]
@@ -37,10 +43,10 @@ class TestFormsService:
def test_get_one_form_notfound(self, session: Session): def test_get_one_form_notfound(self, session: Session):
result = forms_service.get_one(session, 122) result = forms_service.get_one(session, 122)
assert result == None assert result is None
def test_create_form( def test_create_form(
self, self,
session: Session, session: Session,
productor: models.ProductorPublic, productor: models.ProductorPublic,
referer: models.ProductorPublic referer: models.ProductorPublic
@@ -56,10 +62,10 @@ class TestFormsService:
assert result.id is not None assert result.id is not None
assert result.name == "new test form" assert result.name == "new test form"
assert result.productor.name == "test productor" assert result.productor.name == "test productor"
def test_create_form_invalidinput( def test_create_form_invalidinput(
self, self,
session: Session, session: Session,
productor: models.Productor productor: models.Productor
): ):
form_create = None form_create = None
@@ -69,17 +75,16 @@ class TestFormsService:
form_create = forms_factory.form_create_factory(productor_id=123) form_create = forms_factory.form_create_factory(productor_id=123)
with pytest.raises(forms_exceptions.ProductorNotFoundError): with pytest.raises(forms_exceptions.ProductorNotFoundError):
result = forms_service.create_one(session, form_create) result = forms_service.create_one(session, form_create)
form_create = forms_factory.form_create_factory( form_create = forms_factory.form_create_factory(
productor_id=productor.id, productor_id=productor.id,
referer_id=123 referer_id=123
) )
with pytest.raises(forms_exceptions.UserNotFoundError): with pytest.raises(forms_exceptions.UserNotFoundError):
result = forms_service.create_one(session, form_create) result = forms_service.create_one(session, form_create)
def test_update_form( def test_update_form(
self, self,
session: Session, session: Session,
productor: models.ProductorPublic, productor: models.ProductorPublic,
referer: models.ProductorPublic, referer: models.ProductorPublic,
@@ -97,9 +102,9 @@ class TestFormsService:
assert result.id == form_id assert result.id == form_id
assert result.name == 'updated test form' assert result.name == 'updated test form'
assert result.season == 'updated test season' assert result.season == 'updated test season'
def test_update_form_notfound( def test_update_form_notfound(
self, self,
session: Session, session: Session,
productor: models.ProductorPublic, productor: models.ProductorPublic,
referer: models.ProductorPublic, referer: models.ProductorPublic,
@@ -113,42 +118,41 @@ class TestFormsService:
form_id = 123 form_id = 123
with pytest.raises(forms_exceptions.FormNotFoundError): with pytest.raises(forms_exceptions.FormNotFoundError):
result = forms_service.update_one(session, form_id, form_update) result = forms_service.update_one(session, form_id, form_update)
def test_update_form_invalidinput( def test_update_form_invalidinput(
self, self,
session: Session, session: Session,
productor: models.ProductorPublic, productor: models.ProductorPublic,
forms: list[models.FormPublic] forms: list[models.FormPublic]
): ):
form_id = forms[0].id form_id = forms[0].id
form_update = forms_factory.form_update_factory(productor_id=123) form_update = forms_factory.form_update_factory(productor_id=123)
with pytest.raises(forms_exceptions.ProductorNotFoundError): with pytest.raises(forms_exceptions.ProductorNotFoundError):
result = forms_service.update_one(session, form_id, form_update) result = forms_service.update_one(session, form_id, form_update)
form_update = forms_factory.form_update_factory( form_update = forms_factory.form_update_factory(
productor_id=productor.id, productor_id=productor.id,
referer_id=123 referer_id=123
) )
with pytest.raises(forms_exceptions.UserNotFoundError): with pytest.raises(forms_exceptions.UserNotFoundError):
result = forms_service.update_one(session, form_id, form_update) result = forms_service.update_one(session, form_id, form_update)
def test_delete_form( def test_delete_form(
self, self,
session: Session, session: Session,
forms: list[models.FormPublic] forms: list[models.FormPublic]
): ):
form_id = forms[0].id form_id = forms[0].id
result = forms_service.delete_one(session, form_id) result = forms_service.delete_one(session, form_id)
check = forms_service.get_one(session, form_id) check = forms_service.get_one(session, form_id)
assert check == None assert check is None
def test_delete_form_notfound( def test_delete_form_notfound(
self, self,
session: Session, session: Session,
forms: list[models.FormPublic] forms: list[models.FormPublic]
): ):
form_id = 123 form_id = 123
with pytest.raises(forms_exceptions.FormNotFoundError): with pytest.raises(forms_exceptions.FormNotFoundError):
result = forms_service.delete_one(session, form_id) result = forms_service.delete_one(session, form_id)

View File

@@ -1,15 +1,15 @@
import pytest import pytest
from sqlmodel import Session
import src.models as models
import src.productors.service as productors_service
import src.productors.exceptions as productors_exceptions import src.productors.exceptions as productors_exceptions
import src.productors.service as productors_service
import tests.factories.productors as productors_factory import tests.factories.productors as productors_factory
from sqlmodel import Session
from src import models
class TestProductorsService: class TestProductorsService:
def test_get_all_productors( def test_get_all_productors(
self, self,
session: Session, session: Session,
productors: list[models.ProductorPublic], productors: list[models.ProductorPublic],
user: models.UserPublic user: models.UserPublic
): ):
@@ -19,51 +19,53 @@ class TestProductorsService:
assert result == [productors[0]] assert result == [productors[0]]
def test_get_all_productors_filter_names( def test_get_all_productors_filter_names(
self, self,
session: Session, session: Session,
productors: list[models.ProductorPublic], productors: list[models.ProductorPublic],
user: models.UserPublic user: models.UserPublic
): ):
result = productors_service.get_all( result = productors_service.get_all(
session, session,
user, user,
['test productor 1'], ['test productor 1'],
[] []
) )
assert len(result) == 1 assert len(result) == 1
def test_get_all_productors_filter_types( def test_get_all_productors_filter_types(
self, self,
session: Session, session: Session,
productors: list[models.ProductorPublic], productors: list[models.ProductorPublic],
user: models.UserPublic user: models.UserPublic
): ):
result = productors_service.get_all( result = productors_service.get_all(
session, session,
user, user,
[], [],
['Légumineuses'], ['Légumineuses'],
) )
assert len(result) == 1 assert len(result) == 1
def test_get_all_productors_all_filters( def test_get_all_productors_all_filters(
self, self,
session: Session, session: Session,
productors: list[models.ProductorPublic], productors: list[models.ProductorPublic],
user: models.UserPublic user: models.UserPublic
): ):
result = productors_service.get_all( result = productors_service.get_all(
session, session,
user, user,
['test productor 1'], ['test productor 1'],
['Légumineuses'], ['Légumineuses'],
) )
assert len(result) == 1 assert len(result) == 1
def test_get_one_productor(self, session: Session, productors: list[models.ProductorPublic]): def test_get_one_productor(self,
session: Session,
productors: list[models.ProductorPublic]):
result = productors_service.get_one(session, productors[0].id) result = productors_service.get_one(session, productors[0].id)
assert result == productors[0] assert result == productors[0]
@@ -71,10 +73,10 @@ class TestProductorsService:
def test_get_one_productor_notfound(self, session: Session): def test_get_one_productor_notfound(self, session: Session):
result = productors_service.get_one(session, 122) result = productors_service.get_one(session, 122)
assert result == None assert result is None
def test_create_productor( def test_create_productor(
self, self,
session: Session, session: Session,
referer: models.ProductorPublic referer: models.ProductorPublic
): ):
@@ -85,17 +87,17 @@ class TestProductorsService:
assert result.id is not None assert result.id is not None
assert result.name == "new test productor" assert result.name == "new test productor"
def test_create_productor_invalidinput( def test_create_productor_invalidinput(
self, self,
session: Session, session: Session,
): ):
productor_create = None productor_create = None
with pytest.raises(productors_exceptions.ProductorCreateError): with pytest.raises(productors_exceptions.ProductorCreateError):
result = productors_service.create_one(session, productor_create) result = productors_service.create_one(session, productor_create)
def test_update_productor( def test_update_productor(
self, self,
session: Session, session: Session,
referer: models.ProductorPublic, referer: models.ProductorPublic,
productors: list[models.ProductorPublic] productors: list[models.ProductorPublic]
@@ -104,13 +106,14 @@ class TestProductorsService:
name='updated test productor', name='updated test productor',
) )
productor_id = productors[0].id productor_id = productors[0].id
result = productors_service.update_one(session, productor_id, productor_update) result = productors_service.update_one(
session, productor_id, productor_update)
assert result.id == productor_id assert result.id == productor_id
assert result.name == 'updated test productor' assert result.name == 'updated test productor'
def test_update_productor_notfound( def test_update_productor_notfound(
self, self,
session: Session, session: Session,
referer: models.ProductorPublic, referer: models.ProductorPublic,
): ):
@@ -119,25 +122,25 @@ class TestProductorsService:
) )
productor_id = 123 productor_id = 123
with pytest.raises(productors_exceptions.ProductorNotFoundError): with pytest.raises(productors_exceptions.ProductorNotFoundError):
result = productors_service.update_one(session, productor_id, productor_update) result = productors_service.update_one(
session, productor_id, productor_update)
def test_delete_productor( def test_delete_productor(
self, self,
session: Session, session: Session,
productors: list[models.ProductorPublic] productors: list[models.ProductorPublic]
): ):
productor_id = productors[0].id productor_id = productors[0].id
result = productors_service.delete_one(session, productor_id) result = productors_service.delete_one(session, productor_id)
check = productors_service.get_one(session, productor_id) check = productors_service.get_one(session, productor_id)
assert check == None assert check is None
def test_delete_productor_notfound( def test_delete_productor_notfound(
self, self,
session: Session, session: Session,
productors: list[models.ProductorPublic] productors: list[models.ProductorPublic]
): ):
productor_id = 123 productor_id = 123
with pytest.raises(productors_exceptions.ProductorNotFoundError): with pytest.raises(productors_exceptions.ProductorNotFoundError):
result = productors_service.delete_one(session, productor_id) result = productors_service.delete_one(session, productor_id)

View File

@@ -1,15 +1,15 @@
import pytest import pytest
from sqlmodel import Session
import src.models as models
import src.products.service as products_service
import src.products.exceptions as products_exceptions import src.products.exceptions as products_exceptions
import src.products.service as products_service
import tests.factories.products as products_factory import tests.factories.products as products_factory
from sqlmodel import Session
from src import models
class TestProductsService: class TestProductsService:
def test_get_all_products( def test_get_all_products(
self, self,
session: Session, session: Session,
products: list[models.ProductPublic], products: list[models.ProductPublic],
user: models.UserPublic user: models.UserPublic
): ):
@@ -19,16 +19,16 @@ class TestProductsService:
assert result == products assert result == products
def test_get_all_products_filter_productors( def test_get_all_products_filter_productors(
self, self,
session: Session, session: Session,
products: list[models.ProductPublic], products: list[models.ProductPublic],
user: models.UserPublic user: models.UserPublic
): ):
result = products_service.get_all( result = products_service.get_all(
session, session,
user, user,
[], [],
['test productor'], ['test productor'],
[] []
) )
@@ -36,54 +36,55 @@ class TestProductsService:
assert result == products assert result == products
def test_get_all_products_filter_names( def test_get_all_products_filter_names(
self, self,
session: Session, session: Session,
products: list[models.ProductPublic], products: list[models.ProductPublic],
user: models.UserPublic user: models.UserPublic
): ):
result = products_service.get_all( result = products_service.get_all(
session, session,
user, user,
['product 1 occasionnal'], ['product 1 occasionnal'],
[], [],
[] []
) )
assert len(result) == 1 assert len(result) == 1
def test_get_all_products_filter_types( def test_get_all_products_filter_types(
self, self,
session: Session, session: Session,
products: list[models.ProductPublic], products: list[models.ProductPublic],
user: models.UserPublic user: models.UserPublic
): ):
result = products_service.get_all( result = products_service.get_all(
session, session,
user, user,
[], [],
[], [],
['1'] ['1']
) )
assert len(result) == 1 assert len(result) == 1
def test_get_all_products_all_filters( def test_get_all_products_all_filters(
self, self,
session: Session, session: Session,
products: list[models.ProductPublic], products: list[models.ProductPublic],
user: models.UserPublic user: models.UserPublic
): ):
result = products_service.get_all( result = products_service.get_all(
session, session,
user, user,
['product 1 occasionnal'], ['product 1 occasionnal'],
['test productor'], ['test productor'],
['1'] ['1']
) )
assert len(result) == 1 assert len(result) == 1
def test_get_one_product(self, session: Session, products: list[models.ProductPublic]): def test_get_one_product(self, session: Session,
products: list[models.ProductPublic]):
result = products_service.get_one(session, products[0].id) result = products_service.get_one(session, products[0].id)
assert result == products[0] assert result == products[0]
@@ -91,10 +92,10 @@ class TestProductsService:
def test_get_one_product_notfound(self, session: Session): def test_get_one_product_notfound(self, session: Session):
result = products_service.get_one(session, 122) result = products_service.get_one(session, 122)
assert result == None assert result is None
def test_create_product( def test_create_product(
self, self,
session: Session, session: Session,
productor: models.ProductorPublic, productor: models.ProductorPublic,
referer: models.ProductorPublic referer: models.ProductorPublic
@@ -108,22 +109,23 @@ class TestProductsService:
assert result.id is not None assert result.id is not None
assert result.name == "new test product" assert result.name == "new test product"
assert result.productor.name == "test productor" assert result.productor.name == "test productor"
def test_create_product_invalidinput( def test_create_product_invalidinput(
self, self,
session: Session, session: Session,
productor: models.Productor productor: models.Productor
): ):
product_create = None product_create = None
with pytest.raises(products_exceptions.ProductCreateError): with pytest.raises(products_exceptions.ProductCreateError):
result = products_service.create_one(session, product_create) result = products_service.create_one(session, product_create)
product_create = products_factory.product_create_factory(productor_id=123) product_create = products_factory.product_create_factory(
productor_id=123)
with pytest.raises(products_exceptions.ProductorNotFoundError): with pytest.raises(products_exceptions.ProductorNotFoundError):
result = products_service.create_one(session, product_create) result = products_service.create_one(session, product_create)
def test_update_product( def test_update_product(
self, self,
session: Session, session: Session,
productor: models.ProductorPublic, productor: models.ProductorPublic,
referer: models.ProductorPublic, referer: models.ProductorPublic,
@@ -134,13 +136,14 @@ class TestProductsService:
productor_id=productor.id, productor_id=productor.id,
) )
product_id = products[0].id product_id = products[0].id
result = products_service.update_one(session, product_id, product_update) result = products_service.update_one(
session, product_id, product_update)
assert result.id == product_id assert result.id == product_id
assert result.name == 'updated test product' assert result.name == 'updated test product'
def test_update_product_notfound( def test_update_product_notfound(
self, self,
session: Session, session: Session,
productor: models.ProductorPublic, productor: models.ProductorPublic,
referer: models.ProductorPublic, referer: models.ProductorPublic,
@@ -151,41 +154,43 @@ class TestProductsService:
) )
product_id = 123 product_id = 123
with pytest.raises(products_exceptions.ProductNotFoundError): with pytest.raises(products_exceptions.ProductNotFoundError):
result = products_service.update_one(session, product_id, product_update) result = products_service.update_one(
session, product_id, product_update)
def test_update_product_invalidinput( def test_update_product_invalidinput(
self, self,
session: Session, session: Session,
productor: models.ProductorPublic, productor: models.ProductorPublic,
products: list[models.ProductPublic] products: list[models.ProductPublic]
): ):
product_id = products[0].id product_id = products[0].id
product_update = products_factory.product_update_factory(productor_id=123)
with pytest.raises(products_exceptions.ProductorNotFoundError):
result = products_service.update_one(session, product_id, product_update)
product_update = products_factory.product_update_factory( product_update = products_factory.product_update_factory(
productor_id=productor.id, productor_id=123)
with pytest.raises(products_exceptions.ProductorNotFoundError):
result = products_service.update_one(
session, product_id, product_update)
product_update = products_factory.product_update_factory(
productor_id=productor.id,
referer_id=123 referer_id=123
) )
def test_delete_product( def test_delete_product(
self, self,
session: Session, session: Session,
products: list[models.ProductPublic] products: list[models.ProductPublic]
): ):
product_id = products[0].id product_id = products[0].id
result = products_service.delete_one(session, product_id) result = products_service.delete_one(session, product_id)
check = products_service.get_one(session, product_id) check = products_service.get_one(session, product_id)
assert check == None assert check is None
def test_delete_product_notfound( def test_delete_product_notfound(
self, self,
session: Session, session: Session,
products: list[models.ProductPublic] products: list[models.ProductPublic]
): ):
product_id = 123 product_id = 123
with pytest.raises(products_exceptions.ProductNotFoundError): with pytest.raises(products_exceptions.ProductNotFoundError):
result = products_service.delete_one(session, product_id) result = products_service.delete_one(session, product_id)

View File

@@ -1,16 +1,17 @@
import pytest
import datetime import datetime
from sqlmodel import Session
import src.models as models import pytest
import src.shipments.service as shipments_service
import src.shipments.exceptions as shipments_exceptions import src.shipments.exceptions as shipments_exceptions
import src.shipments.service as shipments_service
import tests.factories.shipments as shipments_factory import tests.factories.shipments as shipments_factory
from sqlmodel import Session
from src import models
class TestShipmentsService: class TestShipmentsService:
def test_get_all_shipments( def test_get_all_shipments(
self, self,
session: Session, session: Session,
shipments: list[models.ShipmentPublic], shipments: list[models.ShipmentPublic],
user: models.UserPublic, user: models.UserPublic,
): ):
@@ -20,49 +21,54 @@ class TestShipmentsService:
assert result == shipments assert result == shipments
def test_get_all_shipments_filter_names( def test_get_all_shipments_filter_names(
self, self,
session: Session, session: Session,
shipments: list[models.ShipmentPublic], shipments: list[models.ShipmentPublic],
user: models.UserPublic, user: models.UserPublic,
): ):
result = shipments_service.get_all(session, user, ['test shipment 1'], [], []) result = shipments_service.get_all(
session, user, ['test shipment 1'], [], [])
assert len(result) == 1 assert len(result) == 1
assert result == [shipments[0]] assert result == [shipments[0]]
def test_get_all_shipments_filter_dates( def test_get_all_shipments_filter_dates(
self, self,
session: Session, session: Session,
shipments: list[models.ShipmentPublic], shipments: list[models.ShipmentPublic],
user: models.UserPublic, user: models.UserPublic,
): ):
result = shipments_service.get_all(session, user, [], ['2025-10-10'], []) result = shipments_service.get_all(
session, user, [], ['2025-10-10'], [])
assert len(result) == 1 assert len(result) == 1
def test_get_all_shipments_filter_forms( def test_get_all_shipments_filter_forms(
self, self,
session: Session, session: Session,
shipments: list[models.ShipmentPublic], shipments: list[models.ShipmentPublic],
forms: list[models.FormPublic], forms: list[models.FormPublic],
user: models.UserPublic, user: models.UserPublic,
): ):
result = shipments_service.get_all(session, user, [], [], [forms[0].name]) result = shipments_service.get_all(
session, user, [], [], [forms[0].name])
assert len(result) == 2 assert len(result) == 2
def test_get_all_shipments_all_filters( def test_get_all_shipments_all_filters(
self, self,
session: Session, session: Session,
shipments: list[models.ShipmentPublic], shipments: list[models.ShipmentPublic],
forms: list[models.FormPublic], forms: list[models.FormPublic],
user: models.UserPublic, user: models.UserPublic,
): ):
result = shipments_service.get_all(session, user, ['test shipment 1'], ['2025-10-10'], [forms[0].name]) result = shipments_service.get_all(session, user, ['test shipment 1'], [
'2025-10-10'], [forms[0].name])
assert len(result) == 1 assert len(result) == 1
def test_get_one_shipment(self, session: Session, shipments: list[models.ShipmentPublic]): def test_get_one_shipment(self, session: Session,
shipments: list[models.ShipmentPublic]):
result = shipments_service.get_one(session, shipments[0].id) result = shipments_service.get_one(session, shipments[0].id)
assert result == shipments[0] assert result == shipments[0]
@@ -70,10 +76,10 @@ class TestShipmentsService:
def test_get_one_shipment_notfound(self, session: Session): def test_get_one_shipment_notfound(self, session: Session):
result = shipments_service.get_one(session, 122) result = shipments_service.get_one(session, 122)
assert result == None assert result is None
def test_create_shipment( def test_create_shipment(
self, self,
session: Session, session: Session,
): ):
shipment_create = shipments_factory.shipment_create_factory( shipment_create = shipments_factory.shipment_create_factory(
@@ -84,17 +90,17 @@ class TestShipmentsService:
assert result.id is not None assert result.id is not None
assert result.name == "new test shipment" assert result.name == "new test shipment"
def test_create_shipment_invalidinput( def test_create_shipment_invalidinput(
self, self,
session: Session, session: Session,
): ):
shipment_create = None shipment_create = None
with pytest.raises(shipments_exceptions.ShipmentCreateError): with pytest.raises(shipments_exceptions.ShipmentCreateError):
result = shipments_service.create_one(session, shipment_create) result = shipments_service.create_one(session, shipment_create)
def test_update_shipment( def test_update_shipment(
self, self,
session: Session, session: Session,
shipments: list[models.ShipmentPublic] shipments: list[models.ShipmentPublic]
): ):
@@ -103,14 +109,15 @@ class TestShipmentsService:
date='2025-12-10', date='2025-12-10',
) )
shipment_id = shipments[0].id shipment_id = shipments[0].id
result = shipments_service.update_one(session, shipment_id, shipment_update) result = shipments_service.update_one(
session, shipment_id, shipment_update)
assert result.id == shipment_id assert result.id == shipment_id
assert result.name == 'updated shipment 1' assert result.name == 'updated shipment 1'
assert result.date == datetime.date(2025, 12, 10) assert result.date == datetime.date(2025, 12, 10)
def test_update_shipment_notfound( def test_update_shipment_notfound(
self, self,
session: Session, session: Session,
): ):
shipment_update = shipments_factory.shipment_update_factory( shipment_update = shipments_factory.shipment_update_factory(
@@ -119,25 +126,25 @@ class TestShipmentsService:
) )
shipment_id = 123 shipment_id = 123
with pytest.raises(shipments_exceptions.ShipmentNotFoundError): with pytest.raises(shipments_exceptions.ShipmentNotFoundError):
result = shipments_service.update_one(session, shipment_id, shipment_update) result = shipments_service.update_one(
session, shipment_id, shipment_update)
def test_delete_shipment( def test_delete_shipment(
self, self,
session: Session, session: Session,
shipments: list[models.ShipmentPublic] shipments: list[models.ShipmentPublic]
): ):
shipment_id = shipments[0].id shipment_id = shipments[0].id
result = shipments_service.delete_one(session, shipment_id) result = shipments_service.delete_one(session, shipment_id)
check = shipments_service.get_one(session, shipment_id) check = shipments_service.get_one(session, shipment_id)
assert check == None assert check is None
def test_delete_shipment_notfound( def test_delete_shipment_notfound(
self, self,
session: Session, session: Session,
shipments: list[models.ShipmentPublic] shipments: list[models.ShipmentPublic]
): ):
shipment_id = 123 shipment_id = 123
with pytest.raises(shipments_exceptions.ShipmentNotFoundError): with pytest.raises(shipments_exceptions.ShipmentNotFoundError):
result = shipments_service.delete_one(session, shipment_id) result = shipments_service.delete_one(session, shipment_id)

View File

@@ -1,35 +1,41 @@
import pytest import pytest
from sqlmodel import Session
import src.models as models
import src.users.service as users_service
import src.users.exceptions as users_exceptions import src.users.exceptions as users_exceptions
import src.users.service as users_service
import tests.factories.users as users_factory import tests.factories.users as users_factory
from sqlmodel import Session
from src import models
class TestUsersService: class TestUsersService:
def test_get_all_users(self, session: Session, users: list[models.UserPublic]): def test_get_all_users(self, session: Session,
users: list[models.UserPublic]):
result = users_service.get_all(session, [], []) result = users_service.get_all(session, [], [])
assert len(result) == 3 assert len(result) == 3
assert result == users assert result == users
def test_get_all_users_filter_names(self, session: Session, users: list[models.UserPublic]): def test_get_all_users_filter_names(
self, session: Session, users: list[models.UserPublic]):
result = users_service.get_all(session, ['test user 1 (admin)'], []) result = users_service.get_all(session, ['test user 1 (admin)'], [])
assert len(result) == 1 assert len(result) == 1
assert result == [users[0]] assert result == [users[0]]
def test_get_all_users_filter_emails(self, session: Session, users: list[models.UserPublic]): def test_get_all_users_filter_emails(
self, session: Session, users: list[models.UserPublic]):
result = users_service.get_all(session, [], ['test1@test.com']) result = users_service.get_all(session, [], ['test1@test.com'])
assert len(result) == 1 assert len(result) == 1
def test_get_all_users_all_filters(self, session: Session, users: list[models.UserPublic]): def test_get_all_users_all_filters(
result = users_service.get_all(session, ['test user 1 (admin)'], ['test1@test.com']) self, session: Session, users: list[models.UserPublic]):
result = users_service.get_all(
session, ['test user 1 (admin)'], ['test1@test.com'])
assert len(result) == 1 assert len(result) == 1
def test_get_one_user(self, session: Session, users: list[models.UserPublic]): def test_get_one_user(self, session: Session,
users: list[models.UserPublic]):
result = users_service.get_one(session, users[0].id) result = users_service.get_one(session, users[0].id)
assert result == users[0] assert result == users[0]
@@ -37,10 +43,10 @@ class TestUsersService:
def test_get_one_user_notfound(self, session: Session): def test_get_one_user_notfound(self, session: Session):
result = users_service.get_one(session, 122) result = users_service.get_one(session, 122)
assert result == None assert result is None
def test_create_user( def test_create_user(
self, self,
session: Session, session: Session,
): ):
user_create = users_factory.user_create_factory( user_create = users_factory.user_create_factory(
@@ -54,17 +60,17 @@ class TestUsersService:
assert result.name == "new test user" assert result.name == "new test user"
assert result.email == "test@test.fr" assert result.email == "test@test.fr"
assert len(result.roles) == 1 assert len(result.roles) == 1
def test_create_user_invalidinput( def test_create_user_invalidinput(
self, self,
session: Session, session: Session,
): ):
user_create = None user_create = None
with pytest.raises(users_exceptions.UserCreateError): with pytest.raises(users_exceptions.UserCreateError):
result = users_service.create_one(session, user_create) result = users_service.create_one(session, user_create)
def test_update_user( def test_update_user(
self, self,
session: Session, session: Session,
users: list[models.UserPublic] users: list[models.UserPublic]
): ):
@@ -79,9 +85,9 @@ class TestUsersService:
assert result.id == user_id assert result.id == user_id
assert result.name == 'updated test user' assert result.name == 'updated test user'
assert result.email == 'test@testttt.fr' assert result.email == 'test@testttt.fr'
def test_update_user_notfound( def test_update_user_notfound(
self, self,
session: Session, session: Session,
): ):
user_update = users_factory.user_update_factory( user_update = users_factory.user_update_factory(
@@ -92,24 +98,23 @@ class TestUsersService:
user_id = 123 user_id = 123
with pytest.raises(users_exceptions.UserNotFoundError): with pytest.raises(users_exceptions.UserNotFoundError):
result = users_service.update_one(session, user_id, user_update) result = users_service.update_one(session, user_id, user_update)
def test_delete_user( def test_delete_user(
self, self,
session: Session, session: Session,
users: list[models.UserPublic] users: list[models.UserPublic]
): ):
user_id = users[0].id user_id = users[0].id
result = users_service.delete_one(session, user_id) result = users_service.delete_one(session, user_id)
check = users_service.get_one(session, user_id) check = users_service.get_one(session, user_id)
assert check == None assert check is None
def test_delete_user_notfound( def test_delete_user_notfound(
self, self,
session: Session, session: Session,
users: list[models.UserPublic] users: list[models.UserPublic]
): ):
user_id = 123 user_id = 123
with pytest.raises(users_exceptions.UserNotFoundError): with pytest.raises(users_exceptions.UserNotFoundError):
result = users_service.delete_one(session, user_id) result = users_service.delete_one(session, user_id)