Compare commits

31 Commits

Author SHA1 Message Date
Julien Aldon
e970bb683a fix a bug that could prevent user to selet their payment methods
All checks were successful
Deploy Amap / deploy (push) Successful in 1m52s
2026-03-06 11:59:02 +01:00
Julien Aldon
c27c7598b5 fix tests 2026-03-06 11:26:02 +01:00
b4b4fa7643 fix all pylint warnings, add tests (wip) fix recap 2026-03-06 00:00:01 +01:00
60812652cf Merge branch 'feat/permissions' of gitea.aldon.fr:Mop/amap into feature/export-recap 2026-03-05 20:58:05 +01:00
cb0235e19f fix contract recap 2026-03-05 20:58:00 +01:00
Julien Aldon
5c356f5802 fix header width order 2026-03-05 17:20:44 +01:00
Julien Aldon
ff19448991 add functionnal recap ready for tests 2026-03-05 17:17:23 +01:00
5e413b11e0 add permission check for form productor and product 2026-03-04 23:36:17 +01:00
Julien Aldon
3cfa60507e [WIP] add styles 2026-03-03 17:58:33 +01:00
Julien Aldon
6679107b13 downgrade python version in tests
All checks were successful
Deploy Amap / deploy (push) Successful in 1m47s
2026-03-03 14:34:00 +01:00
Julien Aldon
20eba7f183 remove debug in router
Some checks failed
Deploy Amap / deploy (push) Failing after 11s
2026-03-03 11:37:40 +01:00
Julien Aldon
c6d75831c9 add debug for tests
Some checks failed
Deploy Amap / deploy (push) Failing after 11s
2026-03-03 11:37:23 +01:00
Julien Aldon
b2e2d02818 add debug for tests
Some checks failed
Deploy Amap / deploy (push) Failing after 10s
2026-03-03 11:36:22 +01:00
Julien Aldon
8cb7893aff add debug for tests
Some checks failed
Deploy Amap / deploy (push) Failing after 10s
2026-03-03 11:34:41 +01:00
Julien Aldon
015e09a980 add debug for tests
Some checks failed
Deploy Amap / deploy (push) Failing after 10s
2026-03-03 11:31:16 +01:00
Julien Aldon
a70ab5d3cb add debug for tests
Some checks failed
Deploy Amap / deploy (push) Failing after 10s
2026-03-03 11:29:57 +01:00
Julien Aldon
9d5dbd80cc add debug for tests
Some checks failed
Deploy Amap / deploy (push) Failing after 10s
2026-03-03 11:27:01 +01:00
Julien Aldon
1c6e810ec1 fix python test version
Some checks failed
Deploy Amap / deploy (push) Failing after 1m30s
2026-03-03 11:15:39 +01:00
Julien Aldon
8352097ffb fix pylint errors
Some checks failed
Deploy Amap / deploy (push) Failing after 16s
2026-03-03 11:08:08 +01:00
Julien Aldon
0e48d1bbaa fix test format 2026-03-02 16:33:52 +01:00
Julien Aldon
5dd9e19877 add partial contract tests
Some checks failed
Deploy Amap / deploy (push) Failing after 44s
2026-03-02 11:45:05 +01:00
Julien Aldon
4a4c1225dc fix python version for tests
All checks were successful
Deploy Amap / deploy (push) Successful in 7m25s
2026-02-27 13:45:33 +01:00
Julien Aldon
9f57b11fcf fix version dependencies
Some checks failed
Deploy Amap / deploy (push) Failing after 31s
2026-02-27 13:34:55 +01:00
Julien Aldon
e303e0723e add forms, shipments tests
Some checks failed
Deploy Amap / deploy (push) Failing after 13s
2026-02-27 12:29:07 +01:00
Julien Aldon
d28640711c add forms, shipments tests
Some checks failed
Deploy Amap / deploy (push) Failing after 52s
2026-02-27 12:21:50 +01:00
Julien Aldon
61cbbf0366 add tests forn forms, products, productors
All checks were successful
Deploy Amap / deploy (push) Successful in 3m45s
2026-02-25 16:39:12 +01:00
Julien Aldon
cfb8d435a8 Merge branch 'main' of gitea.aldon.fr:Mop/amap
All checks were successful
Deploy Amap / deploy (push) Successful in 35s
2026-02-23 15:38:45 +01:00
Julien Aldon
124b0700da add visible field to form 2026-02-23 15:38:29 +01:00
da22f24198 fix layout
All checks were successful
Deploy Amap / deploy (push) Successful in 9s
2026-02-20 18:51:45 +01:00
f4bb71a296 fix layout 2026-02-20 18:51:11 +01:00
8c6b25ded8 WIP contract recap 2026-02-19 16:19:40 +01:00
86 changed files with 6337 additions and 698 deletions

View File

@@ -11,6 +11,14 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Test backend
uses: actions/setup-python@v6
with:
python-version: "3.12"
- run: |
python -m pip install --upgrade pip
pip install -r backend/requirements.txt
pytest -sv
- name: Build & deploy - name: Build & deploy
run: | run: |
docker compose -f docker-compose.yaml up -d --build docker compose -f docker-compose.yaml up -d --build

View File

@@ -4,23 +4,12 @@
- Extract recap - Extract recap
## Payment method max cheque number
## Link products to a form ## Link products to a form
## Wording ## Wording
- all translations - all translations
## Draft / Publish form
- By default form is in draft mode
- Validate a form (button)
- check if productor
- check if shipments
- check products
- Publish
## Footer ## Footer
### Legal ### Legal
@@ -29,4 +18,9 @@
### Contact ### Contact
## Pagination
## Confirmation modal on suppression
### Show on cascade deletion
## Update contract after (without registration) ## Update contract after (without registration)

View File

@@ -29,6 +29,17 @@ alembic revision --autogenerate -m "message"
```console ```console
alembic upgrade head alembic upgrade head
``` ```
## Tests
```
hatch run pytest
hatch run pytest --cov=src -vv
```
## Autoformat
```console
find -type f -name '*.py' ! -path 'alembic/*' -exec autopep8 --in-place --aggressive --aggressive '{}' \;
pylint -d R0801,R0903,W0511,W0603,C0103,R0902 .
```
## License ## License

View File

@@ -25,7 +25,8 @@ target_metadata = SQLModel.metadata
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,
# can be acquired: # can be acquired:
config.set_main_option("sqlalchemy.url", f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}') config.set_main_option(
"sqlalchemy.url", f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}')
# ... etc. # ... etc.

View File

@@ -22,7 +22,12 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
"""Upgrade schema.""" """Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.add_column('paymentmethod', sa.Column('max', sa.Integer(), nullable=True)) op.add_column(
'paymentmethod',
sa.Column(
'max',
sa.Integer(),
nullable=True))
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -22,28 +22,32 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
"""Upgrade schema.""" """Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table('contracttype', op.create_table(
sa.Column('id', sa.Integer(), nullable=False), 'contracttype',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
sa.PrimaryKeyConstraint('id') 'id',
) sa.Integer(),
op.create_table('productor', nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
sa.Column('address', sqlmodel.sql.sqltypes.AutoString(), nullable=False), 'name',
sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sqlmodel.sql.sqltypes.AutoString(),
sa.Column('id', sa.Integer(), nullable=False), nullable=False),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id'))
) op.create_table(
'productor', sa.Column(
'name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
'address', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
'type', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
'id', sa.Integer(), nullable=False), sa.PrimaryKeyConstraint('id'))
op.create_table('template', op.create_table('template',
sa.Column('id', sa.Integer(), nullable=False), sa.Column('id', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id')
) )
op.create_table('user', op.create_table(
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), 'user', sa.Column(
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False), 'name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
sa.Column('id', sa.Integer(), nullable=False), 'email', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
sa.PrimaryKeyConstraint('id') 'id', sa.Integer(), nullable=False), sa.PrimaryKeyConstraint('id'))
)
op.create_table('form', op.create_table('form',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('productor_id', sa.Integer(), nullable=True), sa.Column('productor_id', sa.Integer(), nullable=True),
@@ -78,13 +82,13 @@ def upgrade() -> None:
sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ), sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id')
) )
op.create_table('usercontracttypelink', op.create_table(
sa.Column('user_id', sa.Integer(), nullable=False), 'usercontracttypelink', sa.Column(
sa.Column('contract_type_id', sa.Integer(), nullable=False), 'user_id', sa.Integer(), nullable=False), sa.Column(
sa.ForeignKeyConstraint(['contract_type_id'], ['contracttype.id'], ), 'contract_type_id', sa.Integer(), nullable=False), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ), ['contract_type_id'], ['contracttype.id'], ), sa.ForeignKeyConstraint(
sa.PrimaryKeyConstraint('user_id', 'contract_type_id') ['user_id'], ['user.id'], ), sa.PrimaryKeyConstraint(
) 'user_id', 'contract_type_id'))
op.create_table('contract', op.create_table('contract',
sa.Column('firstname', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('firstname', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('lastname', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column('lastname', sqlmodel.sql.sqltypes.AutoString(), nullable=False),

View File

@@ -0,0 +1,40 @@
"""message
Revision ID: e777ed5729ce
Revises: 7854064278ce
Create Date: 2026-02-23 13:53:09.999893
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
# revision identifiers, used by Alembic.
revision: str = 'e777ed5729ce'
down_revision: Union[str, Sequence[str], None] = '7854064278ce'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
'form',
sa.Column(
'visible',
sa.Boolean(),
nullable=False,
default=False,
server_default="False"))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('form', 'visible')
# ### end Alembic commands ###

View File

@@ -30,7 +30,11 @@ dependencies = [
"requests", "requests",
"weasyprint", "weasyprint",
"odfdo", "odfdo",
"alembic" "alembic",
"pytest",
"pytest-cov",
"pytest-mock",
"pylint",
] ]
[project.urls] [project.urls]

View File

@@ -1,34 +1,48 @@
alembic==1.18.4
annotated-doc==0.0.4 annotated-doc==0.0.4
annotated-types==0.7.0 annotated-types==0.7.0
anyio==4.12.1 anyio==4.12.1
astroid==4.0.4
autopep8==2.3.2
brotli==1.2.0 brotli==1.2.0
certifi==2026.1.4 certifi==2026.2.25
cffi==2.0.0 cffi==2.0.0
charset-normalizer==3.4.4 charset-normalizer==3.4.4
click==8.3.1 click==8.3.1
coverage==7.13.4
cryptography==46.0.5 cryptography==46.0.5
cssselect2==0.9.0 cssselect2==0.9.0
dill==0.4.1
dnspython==2.8.0 dnspython==2.8.0
email-validator==2.3.0 email-validator==2.3.0
fastapi==0.129.0 fastapi==0.135.1
fastapi-cli==0.0.23 fastapi-cli==0.0.24
fastapi-cloud-cli==0.13.0 fastapi-cloud-cli==0.14.0
fastar==0.8.0 fastar==0.8.0
fonttools==4.61.1 fonttools==4.61.1
greenlet==3.3.1 greenlet==3.3.2
h11==0.16.0 h11==0.16.0
httpcore==1.0.9 httpcore==1.0.9
httptools==0.7.1 httptools==0.7.1
httpx==0.28.1 httpx==0.28.1
idna==3.11 idna==3.11
iniconfig==2.3.0
isort==8.0.1
Jinja2==3.1.6 Jinja2==3.1.6
lxml==6.0.2 lxml==6.0.2
Mako==1.3.10
markdown-it-py==4.0.0 markdown-it-py==4.0.0
MarkupSafe==3.0.3 MarkupSafe==3.0.3
mccabe==0.7.0
mdurl==0.1.2 mdurl==0.1.2
odfdo==3.20.2 odfdo==3.21.0
packaging==26.0
pillow==12.1.1 pillow==12.1.1
platformdirs==4.9.2
pluggy==1.6.0
prek==0.3.4
psycopg2-binary==2.9.11 psycopg2-binary==2.9.11
pycodestyle==2.14.0
pycparser==3.0 pycparser==3.0
pydantic==2.12.5 pydantic==2.12.5
pydantic-extra-types==2.11.0 pydantic-extra-types==2.11.0
@@ -37,22 +51,27 @@ pydantic_core==2.41.5
pydyf==0.12.1 pydyf==0.12.1
Pygments==2.19.2 Pygments==2.19.2
PyJWT==2.11.0 PyJWT==2.11.0
pylint==4.0.5
pyphen==0.17.2 pyphen==0.17.2
python-dotenv==1.2.1 pytest==9.0.2
pytest-cov==7.0.0
pytest-mock==3.15.1
python-dotenv==1.2.2
python-multipart==0.0.22 python-multipart==0.0.22
PyYAML==6.0.3 PyYAML==6.0.3
requests==2.32.5 requests==2.32.5
rich==14.3.2 rich==14.3.3
rich-toolkit==0.19.4 rich-toolkit==0.19.7
rignore==0.7.6 rignore==0.7.6
sentry-sdk==2.53.0 sentry-sdk==2.53.0
shellingham==1.5.4 shellingham==1.5.4
SQLAlchemy==2.0.46 SQLAlchemy==2.0.47
sqlmodel==0.0.34 sqlmodel==0.0.37
starlette==0.52.1 starlette==0.52.1
tinycss2==1.5.1 tinycss2==1.5.1
tinyhtml5==2.0.0 tinyhtml5==2.0.0
typer==0.24.0 tomlkit==0.14.0
typer==0.24.1
typing-inspection==0.4.2 typing-inspection==0.4.2
typing_extensions==4.15.0 typing_extensions==4.15.0
urllib3==2.6.3 urllib3==2.6.3
@@ -63,4 +82,3 @@ weasyprint==68.1
webencodings==0.5.1 webencodings==0.5.1
websockets==16.0 websockets==16.0
zopfli==0.4.1 zopfli==0.4.1
alembic==1.18.4

View File

View File

@@ -1,26 +1,27 @@
from typing import Annotated
from fastapi import APIRouter, Security, HTTPException, Depends, Request, Cookie
from fastapi.responses import RedirectResponse, Response
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlmodel import Session, select
import jwt
from jwt import PyJWKClient
from src.settings import AUTH_URL, TOKEN_URL, JWKS_URL, ISSUER, LOGOUT_URL, settings
import src.users.service as service
from src.database import get_session
from src.models import UserCreate, User, UserPublic
import secrets import secrets
import requests from typing import Annotated
from urllib.parse import urlencode from urllib.parse import urlencode
import src.messages as messages
import jwt
import requests
import src.users.service as service
from fastapi import APIRouter, Cookie, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse, Response
from fastapi.security import HTTPBearer
from jwt import PyJWKClient
from sqlmodel import Session, select
from src import messages
from src.database import get_session
from src.models import User, UserCreate, UserPublic
from src.settings import (AUTH_URL, ISSUER, JWKS_URL, LOGOUT_URL, TOKEN_URL,
settings)
router = APIRouter(prefix='/auth') router = APIRouter(prefix='/auth')
jwk_client = PyJWKClient(JWKS_URL) jwk_client = PyJWKClient(JWKS_URL)
security = HTTPBearer() security = HTTPBearer()
@router.get('/logout') @router.get('/logout')
def logout(): def logout():
params = { params = {
@@ -59,9 +60,11 @@ def login():
'redirect_uri': settings.keycloak_redirect_uri, 'redirect_uri': settings.keycloak_redirect_uri,
'state': state, 'state': state,
} }
request_url = requests.Request('GET', AUTH_URL, params=params).prepare().url request_url = requests.Request(
'GET', AUTH_URL, params=params).prepare().url
return RedirectResponse(request_url) return RedirectResponse(request_url)
@router.get('/callback') @router.get('/callback')
def callback(code: str, session: Session = Depends(get_session)): def callback(code: str, session: Session = Depends(get_session)):
data = { data = {
@@ -74,18 +77,31 @@ def callback(code: str, session: Session = Depends(get_session)):
headers = { headers = {
'Content-Type': 'application/x-www-form-urlencoded' 'Content-Type': 'application/x-www-form-urlencoded'
} }
response = requests.post(TOKEN_URL, data=data, headers=headers) try:
response = requests.post(
TOKEN_URL,
data=data,
headers=headers,
timeout=10
)
except requests.exceptions.Timeout as error:
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('token')
) from error
if response.status_code != 200: if response.status_code != 200:
raise HTTPException( raise HTTPException(
status_code=400, status_code=404,
detail=messages.failtogettoken detail=messages.Messages.not_found('token')
) )
token_data = response.json() token_data = response.json()
id_token = token_data['id_token'] id_token = token_data['id_token']
decoded_token = jwt.decode(id_token, options={'verify_signature': False}) decoded_token = jwt.decode(id_token, options={'verify_signature': False})
decoded_access_token = jwt.decode(token_data['access_token'], options={'verify_signature': False}) decoded_access_token = jwt.decode(
token_data['access_token'], options={
'verify_signature': False})
resource_access = decoded_access_token.get('resource_access') resource_access = decoded_access_token.get('resource_access')
if not resource_access: if not resource_access:
data = { data = {
@@ -93,7 +109,13 @@ def callback(code: str, session: Session = Depends(get_session)):
'client_secret': settings.keycloak_client_secret, 'client_secret': settings.keycloak_client_secret,
'refresh_token': token_data['refresh_token'], 'refresh_token': token_data['refresh_token'],
} }
res = requests.post(LOGOUT_URL, data=data) try:
requests.post(LOGOUT_URL, data=data, timeout=10)
except requests.exceptions.Timeout as error:
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('token')
) from error
resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true') resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true')
return resp return resp
roles = resource_access.get(settings.keycloak_client_id) roles = resource_access.get(settings.keycloak_client_id)
@@ -103,7 +125,13 @@ def callback(code: str, session: Session = Depends(get_session)):
'client_secret': settings.keycloak_client_secret, 'client_secret': settings.keycloak_client_secret,
'refresh_token': token_data['refresh_token'], 'refresh_token': token_data['refresh_token'],
} }
res = requests.post(LOGOUT_URL, data=data) try:
requests.post(LOGOUT_URL, data=data, timeout=10)
except requests.exceptions.Timeout as error:
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('token')
) from error
resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true') resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true')
return resp return resp
@@ -141,6 +169,7 @@ def callback(code: str, session: Session = Depends(get_session)):
return response return response
def verify_token(token: str): def verify_token(token: str):
try: try:
signing_key = jwk_client.get_signing_key_from_jwt(token) signing_key = jwk_client.get_signing_key_from_jwt(token)
@@ -153,31 +182,52 @@ def verify_token(token: str):
leeway=60, leeway=60,
) )
return decoded return decoded
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError as error:
raise HTTPException(status_code=401, detail=messages.tokenexipired) raise HTTPException(
except jwt.InvalidTokenError: status_code=401,
raise HTTPException(status_code=401, detail=messages.invalidtoken) detail=messages.Messages.tokenexipired
) from error
except jwt.InvalidTokenError as error:
raise HTTPException(
status_code=401,
detail=messages.Messages.invalidtoken
) from error
def get_current_user(request: Request, session: Session = Depends(get_session)): def get_current_user(
request: Request,
session: Session = Depends(get_session)):
access_token = request.cookies.get('access_token') access_token = request.cookies.get('access_token')
if not access_token: if not access_token:
raise HTTPException(status_code=401, detail=messages.notauthenticated) raise HTTPException(
status_code=401,
detail=messages.Messages.notauthenticated
)
payload = verify_token(access_token) payload = verify_token(access_token)
if not payload: if not payload:
raise HTTPException(status_code=401, detail='aze') raise HTTPException(
status_code=401,
detail='aze'
)
email = payload.get('email') email = payload.get('email')
if not email: if not email:
raise HTTPException(status_code=401, detail=messages.notauthenticated) raise HTTPException(
status_code=401,
detail=messages.Messages.notauthenticated
)
user = session.exec(select(User).where(User.email == email)).first() user = session.exec(select(User).where(User.email == email)).first()
if not user: if not user:
raise HTTPException(status_code=401, detail=messages.usernotfound) raise HTTPException(
status_code=401,
detail=messages.Messages.not_found('user')
)
return user return user
@router.post('/refresh') @router.post('/refresh')
def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None): def refresh_user_token(refresh_token: Annotated[str | None, Cookie()] = None):
refresh = refresh_token refresh = refresh_token
data = { data = {
'grant_type': 'refresh_token', 'grant_type': 'refresh_token',
@@ -188,11 +238,22 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
headers = { headers = {
'Content-Type': 'application/x-www-form-urlencoded' 'Content-Type': 'application/x-www-form-urlencoded'
} }
result = requests.post(TOKEN_URL, data=data, headers=headers) try:
result = requests.post(
TOKEN_URL,
data=data,
headers=headers,
timeout=10,
)
except requests.exceptions.Timeout as error:
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('token')
) from error
if result.status_code != 200: if result.status_code != 200:
raise HTTPException( raise HTTPException(
status_code=400, status_code=404,
detail=messages.failtogettoken detail=messages.Messages.not_found('token')
) )
token_data = result.json() token_data = result.json()
@@ -201,7 +262,7 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
key='access_token', key='access_token',
value=token_data['access_token'], value=token_data['access_token'],
httponly=True, httponly=True,
secure=True if settings.debug == False else True, secure=True if settings.debug is False else True,
samesite='strict', samesite='strict',
max_age=settings.max_age max_age=settings.max_age
) )
@@ -209,7 +270,7 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
key='refresh_token', key='refresh_token',
value=token_data['refresh_token'] or '', value=token_data['refresh_token'] or '',
httponly=True, httponly=True,
secure=True if settings.debug == False else True, secure=True if settings.debug is False else True,
samesite='strict', samesite='strict',
max_age=30 * 24 * settings.max_age max_age=30 * 24 * settings.max_age
) )
@@ -223,6 +284,7 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
) )
return response return response
@router.get('/user/me') @router.get('/user/me')
def me(user: UserPublic = Depends(get_current_user)): def me(user: UserPublic = Depends(get_current_user)):
if not user: if not user:
@@ -233,6 +295,6 @@ def me(user: UserPublic = Depends(get_current_user)):
'name': user.name, 'name': user.name,
'email': user.email, 'email': user.email,
'id': user.id, 'id': user.id,
'roles': [role.name for role in user.roles] 'roles': user.roles
} }
} }

View File

@@ -1,88 +1,62 @@
from fastapi import APIRouter, Depends, HTTPException, Query """Router for contract resource"""
from fastapi.responses import StreamingResponse
from src.database import get_session
from sqlmodel import Session
from src.contracts.generate_contract import generate_html_contract, generate_recap
from src.auth.auth import get_current_user
import src.models as models
import src.messages as messages
import src.contracts.service as service
import src.forms.service as form_service
import io import io
import zipfile import zipfile
import src.contracts.service as service
import src.forms.service as form_service
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlmodel import Session
from src import messages, models
from src.auth.auth import get_current_user
from src.contracts.generate_contract import (generate_html_contract,
generate_recap)
from src.database import get_session
router = APIRouter(prefix='/contracts') router = APIRouter(prefix='/contracts')
def compute_recurrent_prices(products_quantities: list[dict], nb_shipment: int):
result = 0
for product_quantity in products_quantities:
product = product_quantity['product']
quantity = product_quantity['quantity']
result += compute_product_price(product, quantity, nb_shipment)
return result
def compute_occasional_prices(occasionals: list[dict]):
result = 0
for occasional in occasionals:
result += occasional['price']
return result
def compute_product_price(product: models.Product, quantity: int, nb_shipment: int = 1):
product_quantity_unit = 1 if product.unit == models.Unit.KILO else 1000
final_quantity = quantity if product.price else quantity / product_quantity_unit
final_price = product.price if product.price else product.price_kg
return final_price * final_quantity * nb_shipment
def find_dict_in_list(lst, key, value):
for i, dic in enumerate(lst):
if dic[key].id == value:
return i
return -1
def create_occasional_dict(contract_products: list[models.ContractProduct]):
result = []
for contract_product in contract_products:
existing_id = find_dict_in_list(
result,
'shipment',
contract_product.shipment.id
)
if existing_id < 0:
result.append({
'shipment': contract_product.shipment,
'price': compute_product_price(
contract_product.product,
contract_product.quantity
),
'products': [{
'product': contract_product.product,
'quantity': contract_product.quantity
}]
})
else:
result[existing_id]['products'].append({
'product': contract_product.product,
'quantity': contract_product.quantity
})
result[existing_id]['price'] += compute_product_price(
contract_product.product,
contract_product.quantity
)
return result
@router.post('') @router.post('')
async def create_contract( async def create_contract(
contract: models.ContractCreate, contract: models.ContractCreate,
session: Session = Depends(get_session), session: Session = Depends(get_session),
): ):
"""Create contract route"""
new_contract = service.create_one(session, contract) new_contract = service.create_one(session, contract)
occasional_contract_products = list(filter(lambda contract_product: contract_product.product.type == models.ProductType.OCCASIONAL, new_contract.products)) occasional_contract_products = list(
occasionals = create_occasional_dict(occasional_contract_products) filter(
recurrents = list(map(lambda x: {"product": x.product, "quantity": x.quantity}, filter(lambda contract_product: contract_product.product.type == models.ProductType.RECCURENT, new_contract.products))) lambda contract_product: (
recurrent_price = compute_recurrent_prices(recurrents, len(new_contract.form.shipments)) contract_product.product.type == models.ProductType.OCCASIONAL
price = recurrent_price + compute_occasional_prices(occasionals) ),
cheques = list(map(lambda x: {"name": x.name, "value": x.value}, new_contract.cheques)) new_contract.products
# TODO: send contract to referer )
)
occasionals = service.create_occasional_dict(occasional_contract_products)
recurrents = list(
map(
lambda x: {'product': x.product, 'quantity': x.quantity},
filter(
lambda contract_product: (
contract_product.product.type ==
models.ProductType.RECCURENT
),
new_contract.products
)
)
)
prices = service.generate_products_prices(
occasionals,
recurrents,
new_contract.form.shipments
)
recurrent_price = prices['recurrent']
total_price = prices['total']
cheques = list(
map(
lambda x: {'name': x.name, 'value': x.value},
new_contract.cheques
)
)
try: try:
pdf_bytes = generate_html_contract( pdf_bytes = generate_html_contract(
new_contract, new_contract,
@@ -90,46 +64,67 @@ async def create_contract(
occasionals, occasionals,
recurrents, recurrents,
'{:10.2f}'.format(recurrent_price), '{:10.2f}'.format(recurrent_price),
'{:10.2f}'.format(price) '{:10.2f}'.format(total_price)
) )
pdf_file = io.BytesIO(pdf_bytes) pdf_file = io.BytesIO(pdf_bytes)
contract_id = f'{new_contract.firstname}_{new_contract.lastname}_{new_contract.form.productor.type}_{new_contract.form.season}' contract_id = (
service.add_contract_file(session, new_contract.id, pdf_bytes, price) f'{new_contract.firstname}_'
except Exception: f'{new_contract.lastname}_'
raise HTTPException(status_code=400, detail=messages.pdferror) f'{new_contract.form.productor.type}_'
f'{new_contract.form.season}'
)
service.add_contract_file(
session, new_contract.id, pdf_bytes, total_price)
except Exception as error:
raise HTTPException(
status_code=400,
detail=messages.pdferror
) from error
return StreamingResponse( return StreamingResponse(
pdf_file, pdf_file,
media_type='application/pdf', media_type='application/pdf',
headers={ headers={
'Content-Disposition': f'attachment; filename=contract_{contract_id}.pdf' 'Content-Disposition': (
f'attachment; filename=contract_{contract_id}.pdf'
)
} }
) )
@router.get('/{form_id}/base') @router.get('/{form_id}/base')
async def get_base_contract_template( async def get_base_contract_template(
form_id: int, form_id: int,
session: Session = Depends(get_session), session: Session = Depends(get_session),
): ):
"""Get contract template route"""
form = form_service.get_one(session, form_id) form = form_service.get_one(session, form_id)
recurrents = list(map(lambda x: {"product": x, "quantity": None}, filter(lambda product: product.type == models.ProductType.RECCURENT, form.productor.products))) recurrents = [
{'product': product, 'quantity': None}
for product in form.productor.products
if product.type == models.ProductType.RECCURENT
]
occasionals = [{ occasionals = [{
'shipment': sh, 'shipment': sh,
'price': None, 'price': None,
'products': [{'product': pr, 'quantity': None} for pr in sh.products] 'products': [{'product': pr, 'quantity': None} for pr in sh.products]
} for sh in form.shipments] } for sh in form.shipments]
empty_contract = models.ContractPublic( empty_contract = models.ContractPublic(
firstname="", firstname='',
form=form, form=form,
lastname="", lastname='',
email="", email='',
phone="", phone='',
products=[], products=[],
payment_method="cheque", payment_method='cheque',
cheque_quantity=3, cheque_quantity=3,
total_price=0, total_price=0,
id=1 id=1
) )
cheques = [{"name": None, "value": None}, {"name": None, "value": None}, {"name": None, "value": None}] cheques = [
{'name': None, 'value': None},
{'name': None, 'value': None},
{'name': None, 'value': None}
]
try: try:
pdf_bytes = generate_html_contract( pdf_bytes = generate_html_contract(
empty_contract, empty_contract,
@@ -138,38 +133,60 @@ async def get_base_contract_template(
recurrents, recurrents,
) )
pdf_file = io.BytesIO(pdf_bytes) pdf_file = io.BytesIO(pdf_bytes)
contract_id = f'{empty_contract.form.productor.type}_{empty_contract.form.season}' contract_id = (
except Exception as e: f'{empty_contract.form.productor.type}_'
print(e) f'{empty_contract.form.season}'
raise HTTPException(status_code=400, detail=messages.pdferror) )
except Exception as error:
raise HTTPException(
status_code=400,
detail=messages.pdferror
) from error
return StreamingResponse( return StreamingResponse(
pdf_file, pdf_file,
media_type='application/pdf', media_type='application/pdf',
headers={ headers={
'Content-Disposition': f'attachment; filename=contract_{contract_id}.pdf' 'Content-Disposition': (
f'attachment; filename=contract_{contract_id}.pdf'
)
} }
) )
@router.get('', response_model=list[models.ContractPublic]) @router.get('', response_model=list[models.ContractPublic])
def get_contracts( def get_contracts(
forms: list[str] = Query([]), forms: list[str] = Query([]),
session: Session = Depends(get_session), session: Session = Depends(get_session),
user: models.User = Depends(get_current_user) user: models.User = Depends(get_current_user)
): ):
"""Get all contracts route"""
return service.get_all(session, user, forms) return service.get_all(session, user, forms)
@router.get('/{id}/file')
@router.get('/{_id}/file')
def get_contract_file( def get_contract_file(
id: int, _id: int,
session: Session = Depends(get_session), session: Session = Depends(get_session),
user: models.User = Depends(get_current_user) user: models.User = Depends(get_current_user)
): ):
if not service.is_allowed(session, user, id): """Get a contract file (in pdf) route"""
raise HTTPException(status_code=403, detail=messages.notallowed) if not service.is_allowed(session, user, _id):
contract = service.get_one(session, id) raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract', 'get')
)
contract = service.get_one(session, _id)
if contract is None: if contract is None:
raise HTTPException(status_code=404, detail=messages.notfound) raise HTTPException(
filename = f'{contract.form.name.replace(' ', '_')}_{contract.form.season}_{contract.firstname}-{contract.lastname}' status_code=404,
detail=messages.Messages.not_found('contract')
)
filename = (
f'{contract.form.name.replace(' ', '_')}_'
f'{contract.form.season}_'
f'{contract.firstname}_'
f'{contract.lastname}'
)
return StreamingResponse( return StreamingResponse(
io.BytesIO(contract.file), io.BytesIO(contract.file),
media_type='application/pdf', media_type='application/pdf',
@@ -178,23 +195,37 @@ def get_contract_file(
} }
) )
@router.get('/{form_id}/files') @router.get('/{form_id}/files')
def get_contract_files( def get_contract_files(
form_id: int, form_id: int,
session: Session = Depends(get_session), session: Session = Depends(get_session),
user: models.User = Depends(get_current_user) user: models.User = Depends(get_current_user)
): ):
if not form_service.is_allowed(session, user, form_id): """Get all contract files for a given form"""
raise HTTPException(status_code=403, detail=messages.notallowed) if not service.is_allowed(session, user, form_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contracts', 'get')
)
form = form_service.get_one(session, form_id=form_id) form = form_service.get_one(session, form_id=form_id)
contracts = service.get_all(session, user, forms=[form.name]) contracts = service.get_all(session, user, forms=[form.name])
zipped_contracts = io.BytesIO() zipped_contracts = io.BytesIO()
with zipfile.ZipFile(zipped_contracts, "a", zipfile.ZIP_DEFLATED, False) as zip_file: with zipfile.ZipFile(
zipped_contracts,
'a',
zipfile.ZIP_DEFLATED,
False
) as zip_file:
for contract in contracts: for contract in contracts:
contract_filename = f'{contract.form.name.replace(' ', '_')}_{contract.form.season}_{contract.firstname}-{contract.lastname}.pdf' contract_filename = (
f'{contract.form.name.replace(' ', '_')}_'
f'{contract.form.season}_'
f'{contract.firstname}_'
f'{contract.lastname}'
)
zip_file.writestr(contract_filename, contract.file) zip_file.writestr(contract_filename, contract.file)
filename = f'{form.name.replace(' ', '_')}_{form.season}'
filename = f'{form.name.replace(" ", "_")}_{form.season}'
return StreamingResponse( return StreamingResponse(
io.BytesIO(zipped_contracts.getvalue()), io.BytesIO(zipped_contracts.getvalue()),
media_type='application/zip', media_type='application/zip',
@@ -203,39 +234,70 @@ def get_contract_files(
} }
) )
@router.get('/{form_id}/recap') @router.get('/{form_id}/recap')
def get_contract_recap( def get_contract_recap(
form_id: int, form_id: int,
session: Session = Depends(get_session), session: Session = Depends(get_session),
user: models.User = Depends(get_current_user) user: models.User = Depends(get_current_user)
): ):
"""Get a contract recap for a given form"""
if not form_service.is_allowed(session, user, form_id): if not form_service.is_allowed(session, user, form_id):
raise HTTPException(status_code=403, detail=messages.notallowed) raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract recap', 'get')
)
form = form_service.get_one(session, form_id=form_id) form = form_service.get_one(session, form_id=form_id)
contracts = service.get_all(session, user, forms=[form.name]) contracts = service.get_all(session, user, forms=[form.name])
filename = f'{form.name}_recapitulatif_contrats.ods'
return StreamingResponse( return StreamingResponse(
io.BytesIO(generate_recap(contracts, form)), io.BytesIO(generate_recap(contracts, form)),
media_type='application/zip', media_type='application/vnd.oasis.opendocument.spreadsheet',
headers={ headers={
'Content-Disposition': f'attachment; filename=filename.ods' 'Content-Disposition': (
f'attachment; filename={filename}'
)
} }
) )
@router.get('/{id}', response_model=models.ContractPublic)
def get_contract(id: int, session: Session = Depends(get_session), user: models.User = Depends(get_current_user)): @router.get('/{_id}', response_model=models.ContractPublic)
if not service.is_allowed(session, user, id): def get_contract(
raise HTTPException(status_code=403, detail=messages.notallowed) _id: int,
result = service.get_one(session, id) session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
"""Get a contract route"""
if not service.is_allowed(session, user, _id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract', 'get')
)
result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.notfound) raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('contract')
)
return result return result
@router.delete('/{id}', response_model=models.ContractPublic)
def delete_contract(id: int, session: Session = Depends(get_session), user: models.User = Depends(get_current_user)): @router.delete('/{_id}', response_model=models.ContractPublic)
if not service.is_allowed(session, user, id): def delete_contract(
raise HTTPException(status_code=403, detail=messages.notallowed) _id: int,
result = service.delete_one(session, id) session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
"""Delete contract route"""
if not service.is_allowed(session, user, _id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract', 'delete')
)
result = service.delete_one(session, _id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.notfound) raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('contract')
)
return result return result

View File

@@ -1,11 +1,15 @@
import html
import io
import pathlib
import string
import jinja2 import jinja2
import src.models as models import odfdo
import html from src import models
from src.contracts import service
from weasyprint import HTML from weasyprint import HTML
import io
import pathlib
def generate_html_contract( def generate_html_contract(
contract: models.Contract, contract: models.Contract,
@@ -14,10 +18,24 @@ def generate_html_contract(
reccurents: list[dict], reccurents: list[dict],
recurrent_price: float | None = None, recurrent_price: float | None = None,
total_price: float | None = None total_price: float | None = None
): ) -> bytes:
"""Generate a html contract
Arguments:
contract(models.Contract): Contract source.
cheques(list[dict]): cheques formated in dict.
occasionals(list[dict]): occasional products.
reccurents(list[dict]): recurrent products.
recurrent_price(float | None = None): total price of recurent products.
total_price(float | None = Non): total price.
Return:
result(bytes): contract file in pdf as bytes.
"""
template_dir = pathlib.Path("./src/contracts/templates").resolve() template_dir = pathlib.Path("./src/contracts/templates").resolve()
template_loader = jinja2.FileSystemLoader(searchpath=template_dir) template_loader = jinja2.FileSystemLoader(searchpath=template_dir)
template_env = jinja2.Environment(loader=template_loader, autoescape=jinja2.select_autoescape(["html", "xml"])) template_env = jinja2.Environment(
loader=template_loader,
autoescape=jinja2.select_autoescape(["html", "xml"])
)
template_file = "layout.html" template_file = "layout.html"
template = template_env.get_template(template_file) template = template_env.get_template(template_file)
output_text = template.render( output_text = template.render(
@@ -28,57 +46,532 @@ def generate_html_contract(
referer_email=contract.form.referer.email, referer_email=contract.form.referer.email,
productor_name=contract.form.productor.name, productor_name=contract.form.productor.name,
productor_address=contract.form.productor.address, productor_address=contract.form.productor.address,
payment_methods_map={"cheque": "Ordre du chèque", "transfer": "virements"}, payment_methods_map={
"cheque": "Ordre du chèque",
"transfer": "virements"},
productor_payment_methods=contract.form.productor.payment_methods, productor_payment_methods=contract.form.productor.payment_methods,
member_name=f'{html.escape(contract.firstname)} {html.escape(contract.lastname)}', member_name=f'{
member_email=html.escape(contract.email), html.escape(
member_phone=html.escape(contract.phone), contract.firstname)} {
html.escape(
contract.lastname)}',
member_email=html.escape(
contract.email),
member_phone=html.escape(
contract.phone),
contract_start_date=contract.form.start, contract_start_date=contract.form.start,
contract_end_date=contract.form.end, contract_end_date=contract.form.end,
occasionals=occasionals, occasionals=occasionals,
recurrents=reccurents, recurrents=reccurents,
recurrent_price=recurrent_price, recurrent_price=recurrent_price,
total_price=total_price, total_price=total_price,
contract_payment_method={"cheque": "chèque", "transfer": "virements"}[contract.payment_method], contract_payment_method={
cheques=cheques "cheque": "chèque",
) "transfer": "virements"}[
# options = { contract.payment_method],
# 'page-size': 'Letter', cheques=cheques)
# 'margin-top': '0.5in',
# 'margin-right': '0.5in',
# 'margin-bottom': '0.5in',
# 'margin-left': '0.5in',
# 'encoding': "UTF-8",
# 'print-media-type': True,
# "disable-javascript": True,
# "disable-external-links": True,
# 'enable-local-file-access': False,
# "disable-local-file-access": True,
# "no-images": True,
# }
return HTML( return HTML(
string=output_text, string=output_text,
base_url=template_dir, base_url=template_dir,
).write_pdf() ).write_pdf()
from odfdo import Document, Table, Row, Cell
def flatten(xss):
"""flatten a list of list.
"""
return [x for xs in xss for x in xs]
def create_column_style_width(size: str) -> odfdo.Style:
"""Create a table columm style for a given width.
Paramenters:
size(str): size of the style (format <number><unit>)
unit can be in, cm... see odfdo documentation.
Returns:
odfdo.Style with the correct column-width attribute.
"""
return odfdo.Element.from_tag(
'<style:style style:name="product-table.A" style:family="table-column">'
f'<style:table-column-properties style:column-width="{size}"/>'
'</style:style>'
)
def create_row_style_height(size: str) -> odfdo.Style:
"""Create a table height style for a given height.
Paramenters:
size(str): size of the style (format <number><unit>)
unit can be in, cm... see odfdo documentation.
Returns:
odfdo.Style with the correct column-height attribute.
"""
return odfdo.Element.from_tag(
'<style:style style:name="product-table.A" style:family="table-row">'
f'<style:table-row-properties style:row-height="{size}"/>'
'</style:style>'
)
def create_currency_style(name: str = 'currency-euro'):
"""Create a table currency style.
Paramenters:
name(str): name of the style (default to `currency-euro`).
Returns:
odfdo.Style with the correct column-height attribute.
"""
return odfdo.Element.from_tag(
f"""
<number:currency-style style:name="{name}">
<number:number number:min-integer-digits="1"
number:decimal-places="2"/>
<number:text> €</number:text>
</number:currency-style>"""
)
def create_cell_style(
name: str = "centered-cell",
font_size: str = '10pt',
bold: bool = False,
background_color: str = '#FFFFFF',
color: str = '#000000',
currency: bool = False,
) -> odfdo.Style:
"""Create a cell style
Paramenters:
name(str): name of the style (default to `centered-cell`).
font_size(str): font_size of the cell (default to `10pt`).
bold(str): is the text bold (default to `False`).
background_color(str): background_color of the cell
(default to `#FFFFFF`).
color(str): color of the text of the cell (default to `#000000`).
currency(str): is the cell a currency (default to `False`).
Returns:
odfdo.Style with the correct column-height attribute.
"""
bold_attr = """
fo:font-weight="bold"
style:font-weight-asian="bold"
style:font-weight-complex="bold"
""" if bold else ''
currency_attr = """
style:data-style-name="currency-euro">
""" if currency else ''
return odfdo.Element.from_tag(
f"""<style:style style:name="{name}" style:family="table-cell"
{currency_attr}>
<style:table-cell-properties
fo:border="0.75pt solid #000000"
style:vertical-align="middle"
fo:wrap-option="wrap"
fo:background-color="{background_color}"/>
<style:paragraph-properties fo:text-align="center"/>
<style:text-properties
{bold_attr}
fo:font-size="{font_size}"
fo:color="{color}"/>
</style:style>"""
)
def apply_cell_style(
document: odfdo.Document,
table: odfdo.Table,
currency_cols: list[int]
):
"""Apply cell style
"""
document.insert_style(
style=create_currency_style(),
)
header_style = document.insert_style(
create_cell_style(
name="header-cells",
bold=True,
font_size='12pt',
background_color="#3480eb",
color="#FFF"
)
)
body_style_even = document.insert_style(
create_cell_style(
name="body-style-even",
bold=False,
background_color="#e8eaed",
color="#000000",
)
)
body_style_odd = document.insert_style(
create_cell_style(
name="body-style-odd",
bold=False,
background_color="#FFFFFF",
color="#000000",
)
)
footer_style = document.insert_style(
create_cell_style(
name="footer-cells",
bold=True,
font_size='12pt',
)
)
body_style_even_currency = document.insert_style(
create_cell_style(
name="body-style-even-currency",
bold=False,
background_color="#e8eaed",
color="#000000",
currency=True,
)
)
body_style_odd_currency = document.insert_style(
create_cell_style(
name="body-style-odd-currency",
bold=False,
background_color="#FFFFFF",
color="#000000",
currency=True,
)
)
footer_style_currency = document.insert_style(
create_cell_style(
name="footer-cells-currency",
bold=True,
font_size='12pt',
currency=True,
)
)
for index, row in enumerate(table.get_rows()):
style = body_style_even
currency_style = body_style_even_currency
if index == 0 or index == 1:
style = header_style
elif index == len(table.get_rows()) - 1:
style = footer_style
currency_style = footer_style_currency
elif index % 2 == 0:
style = body_style_even
currency_style = body_style_even_currency
else:
style = body_style_odd
currency_style = body_style_odd_currency
for cell_index, cell in enumerate(row.get_cells()):
if cell_index in currency_cols and not (index == 0 or index == 1):
cell.style = currency_style
else:
cell.style = style
def apply_column_height_style(
document: odfdo.Document,
table: odfdo.Table
):
"""Apply column height for a given table
"""
header_style = document.insert_style(
style=create_row_style_height('1.60cm'), name='1.60cm', automatic=True
)
body_style = document.insert_style(
style=create_row_style_height('0.90cm'), name='0.90cm', automatic=True
)
for index, row in enumerate(table.get_rows()):
if index == 1:
row.style = header_style
else:
row.style = body_style
def apply_cell_style_by_column(
table: odfdo.Table,
style: odfdo.Style,
col_index: int
):
"""Apply cell style for a given table
"""
for cell in table.get_column_cells(col_index):
cell.style = style
def apply_column_width_style(
document: odfdo.Document,
table: odfdo.Table,
widths: list[str]
):
"""Apply column width style to a table.
Parameters:
document(odfdo.Document): Document where the table is located.
table(odfdo.Table): Table to apply columns widths.
widths(list[str]): list of width in format <number><unit> unit ca be
in, cm... see odfdo documentation.
"""
styles = []
for w in widths:
styles.append(document.insert_style(
style=create_column_style_width(w), name=w, automatic=True)
)
for position in range(table.width):
col = table.get_column(position)
col.style = styles[position]
table.set_column(position, col)
def generate_ods_letters(n: int):
"""Generate letters following excel format.
Arguments:
n(int): `n` letters to generate.
Return:
result(list[str]): list of `n` letters that follow excel pattern.
"""
letters = string.ascii_lowercase
result = []
for i in range(n):
if i > len(letters) - 1:
letter = f'{letters[int(i / len(letters)) - 1]}'
letter += f'{letters[i % len(letters)]}'
result.append(letter)
continue
letter = letters[i]
result.append(letters[i])
return result
def compute_contract_prices(contract: models.Contract) -> dict:
"""Compute price for a give contract.
"""
occasional_contract_products = list(
filter(
lambda contract_product: (
contract_product.product.type == models.ProductType.OCCASIONAL
),
contract.products
)
)
occasionals_dict = service.create_occasional_dict(
occasional_contract_products)
recurrents_dict = list(
map(
lambda x: {'product': x.product, 'quantity': x.quantity},
filter(
lambda contract_product: (
contract_product.product.type ==
models.ProductType.RECCURENT
),
contract.products
)
)
)
prices = service.generate_products_prices(
occasionals_dict,
recurrents_dict,
contract.form.shipments
)
return prices
def transform_formula_cells(sheet: odfdo.Spreadsheet):
"""Transform cell value to a formula using odfdo.
"""
for row in sheet.get_rows():
for cell in row.get_cells():
if not cell.value or cell.get_attribute("office:value-type") == "float":
continue
if '=' in cell.value:
formula = cell.value
cell.clear()
cell.formula = formula
def merge_shipment_cells(
sheet: odfdo.Spreadsheet,
prefix_header: list[str],
recurrents: list[str],
occasionnals: list[str],
shipments: list[models.Shipment]
):
"""Merge cells for shipment header.
"""
index = len(prefix_header) + len(recurrents) + 1
for _ in enumerate(shipments):
startcol = index
endcol = index+len(occasionnals) - 1
sheet.set_span((startcol, 0, endcol, 0), merge=True)
index += len(occasionnals)
def generate_recap( def generate_recap(
contracts: list[models.Contract], contracts: list[models.Contract],
form: models.Form, form: models.Form,
): ):
data = [ """Generate excel recap for a list of contracts.
["nom", "email"], """
product_unit_map = {
'1': 'g',
'2': 'Kg',
'3': 'Piece'
}
recurrents = [
f'{pr.name}{f' - {pr.quantity}{pr.quantity_unit}'
if pr.quantity else ''} ({product_unit_map[pr.unit]})'
for pr in form.productor.products
if pr.type == models.ProductType.RECCURENT
] ]
doc = Document("spreadsheet") recurrents.sort()
sheet = Table(name="Recap") occasionnals = [
f'{pr.name}{f' - {pr.quantity}{pr.quantity_unit}'
if pr.quantity else ''} ({product_unit_map[pr.unit]})'
for pr in form.productor.products
if pr.type == models.ProductType.OCCASIONAL
]
occasionnals.sort()
shipments = form.shipments
occasionnals_header = [
occ for shipment in shipments for occ in occasionnals
]
info_header: list[str] = ['', 'Nom', 'Email']
cheque_header: list[str] = ['Cheque 1', 'Cheque 2', 'Cheque 3']
payment_header = (
cheque_header +
[f'Total {len(shipments)} livraisons + produits occasionnels']
)
prefix_header: list[str] = (
info_header +
payment_header
)
suffix_header: list[str] = [
'Total produits occasionnels',
'Remarques',
'Nom'
]
shipment_header = flatten([
[f'{shipment.name} - {shipment.date.strftime('%Y-%m-%d')}'] +
['' * len(occasionnals)] for shipment in shipments] +
[''] * len(suffix_header)
)
header: list[str] = (
prefix_header +
recurrents +
['Total produits récurrents'] +
occasionnals_header +
suffix_header
)
letters = generate_ods_letters(len(header))
payment_formula_letters = letters[
len(info_header):len(info_header) + len(payment_header)
]
recurent_formula_letters = letters[
len(info_header)+len(payment_formula_letters):
len(info_header)+len(payment_formula_letters)+len(recurrents) + 1
]
occasionnals_formula_letters = letters[
len(info_header)+len(payment_formula_letters) +
len(recurent_formula_letters):
len(info_header)+len(payment_formula_letters) +
len(recurent_formula_letters)+len(occasionnals_header) + 1
]
footer = (
['', 'Total contrats', ''] +
[f'=SUM({letter}3:{letter}{2+len(contracts)})'
for letter in payment_formula_letters] +
[f'=SUM({letter}3:{letter}{2+len(contracts)})'
for letter in recurent_formula_letters] +
[f'=SUM({letter}3:{letter}{2+len(contracts)})'
for letter in occasionnals_formula_letters]
)
main_data = []
for index, contract in enumerate(contracts):
prices = compute_contract_prices(contract)
occasionnal_sorted = sorted(
[
product for product in contract.products
if product.product.type == models.ProductType.OCCASIONAL
],
key=lambda x: (x.shipment.name, x.product.name)
)
recurrent_sorted = sorted(
[
product for product in contract.products
if product.product.type == models.ProductType.RECCURENT
],
key=lambda x: x.product.name
)
main_data.append([
f'{index + 1}',
f'{contract.firstname} {contract.lastname}',
f'{contract.email}',
*[float(contract.cheques[i].value)
if len(contract.cheques) > i
else ''
for i in range(3)],
prices['total'],
*[pr.quantity for pr in recurrent_sorted],
prices['recurrent'],
*[pr.quantity for pr in occasionnal_sorted],
prices['occasionnal'],
'',
f'{contract.firstname} {contract.lastname}',
])
data = [
[''] * (len(prefix_header) + len(recurrents) + 1) + shipment_header,
header,
*main_data,
footer
]
doc = odfdo.Document('spreadsheet')
sheet = doc.body.get_sheet(0)
sheet.name = 'Recap'
sheet.set_values(data) sheet.set_values(data)
doc.body.append(sheet) if len(occasionnals) > 0:
merge_shipment_cells(
sheet,
prefix_header,
recurrents,
occasionnals,
shipments
)
transform_formula_cells(sheet)
apply_column_width_style(
doc,
doc.body.get_table(0),
['2cm'] +
['6cm'] * 2 +
['2.40cm'] * (len(payment_header) - 1) +
['4cm'] * len(recurrents) +
['4cm'] +
['4cm'] * (len(occasionnals_header) + 1) +
['4cm', '8cm', '6cm']
)
apply_column_height_style(
doc,
doc.body.get_table(0),
)
apply_cell_style(
doc,
doc.body.get_table(0),
[
3,
4,
5,
6,
len(info_header) + len(payment_header),
len(info_header) + len(payment_header) + 1 + len(occasionnals),
]
)
doc.body.append(sheet)
buffer = io.BytesIO() buffer = io.BytesIO()
doc.save(buffer) doc.save(buffer)
return buffer.getvalue() return buffer.getvalue()

View File

@@ -1,28 +1,57 @@
"""Contract service responsible for read, create, update and delete contracts"""
from sqlalchemy.orm import selectinload
from sqlmodel import Session, select from sqlmodel import Session, select
import src.models as models from src import models
def get_all( def get_all(
session: Session, session: Session,
user: models.User, user: models.User,
forms: list[str] = [], forms: list[str] | None = None,
form_id: int | None = None, form_id: int | None = None,
) -> list[models.ContractPublic]: ) -> list[models.ContractPublic]:
statement = select(models.Contract)\ """Get all contracts"""
.join(models.Form, models.Contract.form_id == models.Form.id)\ statement = (
.join(models.Productor, models.Form.productor_id == models.Productor.id)\ select(models.Contract)
.where(models.Productor.type.in_([r.name for r in user.roles]))\ .join(
models.Form,
models.Contract.form_id == models.Form.id
)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id
)
.where(
models.Productor.type.in_(
[r.name for r in user.roles]
)
)
.distinct() .distinct()
if len(forms) > 0: )
if forms:
statement = statement.where(models.Form.name.in_(forms)) statement = statement.where(models.Form.name.in_(forms))
if form_id: if form_id:
statement = statement.where(models.Form.id == form_id) statement = statement.where(models.Form.id == form_id)
return session.exec(statement.order_by(models.Contract.id)).all() return session.exec(statement.order_by(models.Contract.id)).all()
def get_one(session: Session, contract_id: int) -> models.ContractPublic:
def get_one(
session: Session,
contract_id: int
) -> models.ContractPublic:
"""Get one contract"""
return session.get(models.Contract, contract_id) return session.get(models.Contract, contract_id)
def create_one(session: Session, contract: models.ContractCreate) -> models.ContractPublic:
contract_create = contract.model_dump(exclude_unset=True, exclude=["products", "cheques"]) def create_one(
session: Session,
contract: models.ContractCreate
) -> models.ContractPublic:
"""Create one contract"""
contract_create = contract.model_dump(
exclude_unset=True,
exclude=["products", "cheques"]
)
new_contract = models.Contract(**contract_create) new_contract = models.Contract(**contract_create)
new_contract.cheques = [ new_contract.cheques = [
@@ -45,10 +74,27 @@ def create_one(session: Session, contract: models.ContractCreate) -> models.Cont
session.add(new_contract) session.add(new_contract)
session.commit() session.commit()
session.refresh(new_contract) session.refresh(new_contract)
return new_contract
def add_contract_file(session: Session, id: int, file: bytes, price: float): statement = (
statement = select(models.Contract).where(models.Contract.id == id) select(models.Contract)
.where(models.Contract.id == new_contract.id)
.options(
selectinload(models.Contract.form)
.selectinload(models.Form.productor)
)
)
return session.exec(statement).one()
def add_contract_file(
session: Session,
_id: int,
file: bytes,
price: float
):
"""Add a file to an existing contract"""
statement = select(models.Contract).where(models.Contract.id == _id)
result = session.exec(statement) result = session.exec(statement)
contract = result.first() contract = result.first()
contract.total_price = price contract.total_price = price
@@ -58,8 +104,14 @@ def add_contract_file(session: Session, id: int, file: bytes, price: float):
session.refresh(contract) session.refresh(contract)
return contract return contract
def update_one(session: Session, id: int, contract: models.ContractUpdate) -> models.ContractPublic:
statement = select(models.Contract).where(models.Contract.id == id) def update_one(
session: Session,
_id: int,
contract: models.ContractUpdate
) -> models.ContractPublic:
"""Update one contract"""
statement = select(models.Contract).where(models.Contract.id == _id)
result = session.exec(statement) result = session.exec(statement)
new_contract = result.first() new_contract = result.first()
if not new_contract: if not new_contract:
@@ -72,8 +124,13 @@ def update_one(session: Session, id: int, contract: models.ContractUpdate) -> mo
session.refresh(new_contract) session.refresh(new_contract)
return new_contract return new_contract
def delete_one(session: Session, id: int) -> models.ContractPublic:
statement = select(models.Contract).where(models.Contract.id == id) def delete_one(
session: Session,
_id: int
) -> models.ContractPublic:
"""Delete one contract"""
statement = select(models.Contract).where(models.Contract.id == _id)
result = session.exec(statement) result = session.exec(statement)
contract = result.first() contract = result.first()
if not contract: if not contract:
@@ -83,11 +140,129 @@ def delete_one(session: Session, id: int) -> models.ContractPublic:
session.commit() session.commit()
return result return result
def is_allowed(session: Session, user: models.User, id: int) -> bool:
statement = select(models.Contract)\ def is_allowed(
.join(models.Form, models.Contract.form_id == models.Form.id)\ session: Session,
.join(models.Productor, models.Form.productor_id == models.Productor.id)\ user: models.User,
.where(models.Contract.id == id)\ _id: int
.where(models.Productor.type.in_([r.name for r in user.roles]))\ ) -> bool:
"""Determine if a user is allowed to access a contract by id"""
statement = (
select(models.Contract)
.join(
models.Form,
models.Contract.form_id == models.Form.id
)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id
)
.where(models.Contract.id == _id)
.where(
models.Productor.type.in_(
[r.name for r in user.roles]
)
)
.distinct() .distinct()
)
return len(session.exec(statement).all()) > 0 return len(session.exec(statement).all()) > 0
def compute_recurrent_prices(
products_quantities: list[dict],
nb_shipment: int
):
"""Compute price for recurrent products"""
result = 0
for product_quantity in products_quantities:
product = product_quantity['product']
quantity = product_quantity['quantity']
result += compute_product_price(product, quantity, nb_shipment)
return result
def compute_occasional_prices(occasionals: list[dict]):
"""Compute prices for occassional products"""
result = 0
for occasional in occasionals:
result += occasional['price']
return result
def compute_product_price(
product: models.Product,
quantity: int,
nb_shipment: int = 1
):
"""Compute price for a product"""
product_quantity_unit = (
1 if product.unit == models.Unit.KILO else 1000
)
final_quantity = (
quantity if product.price else quantity / product_quantity_unit
)
final_price = (
product.price if product.price else product.price_kg
)
return final_price * final_quantity * nb_shipment
def find_dict_in_list(lst, key, value):
"""Find the index of a dictionnary in a list of dictionnaries given a key
and a value.
"""
for i, dic in enumerate(lst):
if dic[key].id == value:
return i
return -1
def create_occasional_dict(contract_products: list[models.ContractProduct]):
"""Create a dictionnary of occasional products"""
result = []
for contract_product in contract_products:
existing_id = find_dict_in_list(
result,
'shipment',
contract_product.shipment.id
)
if existing_id < 0:
result.append({
'shipment': contract_product.shipment,
'price': compute_product_price(
contract_product.product,
contract_product.quantity
),
'products': [{
'product': contract_product.product,
'quantity': contract_product.quantity
}]
})
else:
result[existing_id]['products'].append({
'product': contract_product.product,
'quantity': contract_product.quantity
})
result[existing_id]['price'] += compute_product_price(
contract_product.product,
contract_product.quantity
)
return result
def generate_products_prices(
occasionals: list[dict],
recurrents: list[dict],
shipments: list[models.ShipmentPublic]
):
recurrent_price = compute_recurrent_prices(
recurrents,
len(shipments)
)
occasional_price = compute_occasional_prices(occasionals)
price = recurrent_price + occasional_price
return {
'total': price,
'recurrent': recurrent_price,
'occasionnal': occasional_price
}

View File

@@ -274,7 +274,7 @@
else ""}} else ""}}
</td> </td>
<td> <td>
{{rec.product.quantity if rec.product.quantity != None else ""}}{{"g" if rec.product.unit == "1" else "kg" if {{rec.quantity if rec.quantity != None else ""}}{{"g" if rec.product.unit == "1" else "kg" if
rec.product.unit == "2" else "p" }} rec.product.unit == "2" else "p" }}
</td> </td>
</tr> </tr>
@@ -317,7 +317,7 @@
product.product.quantity_unit != None else ""}} product.product.quantity_unit != None else ""}}
</td> </td>
<td> <td>
{{product.product.quantity if product.product.quantity != None {{product.quantity if product.quantity != None
else ""}}{{"g" if product.product.unit == "1" else else ""}}{{"g" if product.product.unit == "1" else
"kg" if product.product.unit == "2" else "p" }} "kg" if product.product.unit == "2" else "p" }}
</td> </td>

View File

@@ -1,11 +1,14 @@
from sqlmodel import create_engine, SQLModel, Session from sqlmodel import Session, SQLModel, create_engine
from src.settings import settings from src.settings import settings
engine = create_engine(f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}') engine = create_engine(
f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}')
def get_session(): def get_session():
with Session(engine) as session: with Session(engine) as session:
yield session yield session
def create_all_tables(): def create_all_tables():
SQLModel.metadata.create_all(engine) SQLModel.metadata.create_all(engine)

View File

@@ -0,0 +1,26 @@
"""Forms module exceptions"""
import logging
class FormServiceError(Exception):
"""Form service exception"""
def __init__(self, message: str):
super().__init__(message)
logging.error('FormService : %s', message)
class UserNotFoundError(FormServiceError):
pass
class ProductorNotFoundError(FormServiceError):
pass
class FormNotFoundError(FormServiceError):
pass
class FormCreateError(FormServiceError):
def __init__(self, message: str, field: str | None = None):
super().__init__(message)
self.field = field

View File

@@ -1,55 +1,107 @@
from fastapi import APIRouter, HTTPException, Depends, Query import src.forms.exceptions as exceptions
import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.forms.service as service import src.forms.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import messages, models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/forms') router = APIRouter(prefix='/forms')
@router.get('', response_model=list[models.FormPublic]) @router.get('', response_model=list[models.FormPublic])
async def get_forms( async def get_forms(
seasons: list[str] = Query([]), seasons: list[str] = Query([]),
productors: list[str] = Query([]), productors: list[str] = Query([]),
current_season: bool = False, current_season: bool = False,
session: Session = Depends(get_session) session: Session = Depends(get_session),
): ):
return service.get_all(session, seasons, productors, current_season) return service.get_all(session, seasons, productors, current_season)
@router.get('/{id}', response_model=models.FormPublic)
async def get_form(id: int, session: Session = Depends(get_session)): @router.get('/referents', response_model=list[models.FormPublic])
result = service.get_one(session, id) async def get_forms_filtered(
seasons: list[str] = Query([]),
productors: list[str] = Query([]),
current_season: bool = False,
session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
return service.get_all(session, seasons, productors, current_season, user)
@router.get('/{_id}', response_model=models.FormPublic)
async def get_form(
_id: int,
session: Session = Depends(get_session)
):
result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.notfound) raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('form')
)
return result return result
@router.post('', response_model=models.FormPublic) @router.post('', response_model=models.FormPublic)
async def create_form( async def create_form(
form: models.FormCreate, form: models.FormCreate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
return service.create_one(session, form) if not service.is_allowed(session, user, form=form):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('forms', 'update')
)
try:
form = service.create_one(session, form)
except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) from error
except exceptions.UserNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) from error
except exceptions.FormCreateError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
return form
@router.put('/{id}', response_model=models.FormPublic)
@router.put('/{_id}', response_model=models.FormPublic)
async def update_form( async def update_form(
id: int, form: models.FormUpdate, _id: int,
form: models.FormUpdate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.update_one(session, id, form) if not service.is_allowed(session, user, _id=_id):
if result is None: raise HTTPException(
raise HTTPException(status_code=404, detail=messages.notfound) status_code=403,
detail=messages.Messages.not_allowed('forms', 'update')
)
try:
result = service.update_one(session, _id, form)
except exceptions.FormNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) from error
except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) from error
except exceptions.UserNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) from error
return result return result
@router.delete('/{id}', response_model=models.FormPublic)
@router.delete('/{_id}', response_model=models.FormPublic)
async def delete_form( async def delete_form(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.delete_one(session, id) if not service.is_allowed(session, user, _id=_id):
if result is None: raise HTTPException(
raise HTTPException(status_code=404, detail=messages.notfound) status_code=403,
detail=messages.Messages.not_allowed('forms', 'delete')
)
try:
result = service.delete_one(session, _id)
except exceptions.FormNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) from error
return result return result

View File

@@ -1,26 +1,40 @@
from sqlmodel import Session, select import src.forms.exceptions as exceptions
import src.models as models
from sqlalchemy import func from sqlalchemy import func
from sqlmodel import Session, select
from src import messages, models
def get_all( def get_all(
session: Session, session: Session,
seasons: list[str], seasons: list[str],
productors: list[str], productors: list[str],
current_season: bool, current_season: bool,
user: models.User = None
) -> list[models.FormPublic]: ) -> list[models.FormPublic]:
statement = select(models.Form) statement = select(models.Form)
if user:
statement = statement .join(
models.Productor,
models.Form.productor_id == models.Productor.id) .where(
models.Productor.type.in_(
[
r.name for r in user.roles])) .distinct()
if len(seasons) > 0: if len(seasons) > 0:
statement = statement.where(models.Form.season.in_(seasons)) statement = statement.where(models.Form.season.in_(seasons))
if len(productors) > 0: if len(productors) > 0:
statement = statement.join(models.Productor).where(models.Productor.name.in_(productors)) statement = statement.join(
models.Productor).where(
models.Productor.name.in_(productors))
if not user:
statement = statement.where(models.Form.visible)
if current_season: if current_season:
subquery = ( subquery = (
select( select(
models.Productor.type, models.Productor.type,
func.max(models.Form.start).label("max_start") func.max(models.Form.start).label("max_start")
) )
.join(models.Form)\ .join(models.Form)
.group_by(models.Productor.type)\ .group_by(models.Productor.type)
.subquery() .subquery()
) )
statement = select(models.Form)\ statement = select(models.Form)\
@@ -29,13 +43,26 @@ def get_all(
(models.Productor.type == subquery.c.type) & (models.Productor.type == subquery.c.type) &
(models.Form.start == subquery.c.max_start) (models.Form.start == subquery.c.max_start)
) )
if not user:
statement = statement.where(models.Form.visible)
return session.exec(statement.order_by(models.Form.name)).all() return session.exec(statement.order_by(models.Form.name)).all()
return session.exec(statement.order_by(models.Form.name)).all() return session.exec(statement.order_by(models.Form.name)).all()
def get_one(session: Session, form_id: int) -> models.FormPublic: def get_one(session: Session, form_id: int) -> models.FormPublic:
return session.get(models.Form, form_id) return session.get(models.Form, form_id)
def create_one(session: Session, form: models.FormCreate) -> models.FormPublic: def create_one(session: Session, form: models.FormCreate) -> models.FormPublic:
if not form:
raise exceptions.FormCreateError(
messages.Messages.invalid_input(
'form', 'input cannot be None'))
if not session.get(models.Productor, form.productor_id):
raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
if not session.get(models.User, form.referer_id):
raise exceptions.UserNotFoundError(messages.Messages.not_found('user'))
form_create = form.model_dump(exclude_unset=True) form_create = form.model_dump(exclude_unset=True)
new_form = models.Form(**form_create) new_form = models.Form(**form_create)
session.add(new_form) session.add(new_form)
@@ -43,12 +70,22 @@ def create_one(session: Session, form: models.FormCreate) -> models.FormPublic:
session.refresh(new_form) session.refresh(new_form)
return new_form return new_form
def update_one(session: Session, id: int, form: models.FormUpdate) -> models.FormPublic:
statement = select(models.Form).where(models.Form.id == id) def update_one(
session: Session,
_id: int,
form: models.FormUpdate) -> models.FormPublic:
statement = select(models.Form).where(models.Form.id == _id)
result = session.exec(statement) result = session.exec(statement)
new_form = result.first() new_form = result.first()
if not new_form: if not new_form:
return None raise exceptions.FormNotFoundError(messages.Messages.not_found('form'))
if form.productor_id and not session.get(
models.Productor, form.productor_id):
raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
if form.referer_id and not session.get(models.User, form.referer_id):
raise exceptions.UserNotFoundError(messages.Messages.not_found('user'))
form_updates = form.model_dump(exclude_unset=True) form_updates = form.model_dump(exclude_unset=True)
for key, value in form_updates.items(): for key, value in form_updates.items():
setattr(new_form, key, value) setattr(new_form, key, value)
@@ -57,21 +94,46 @@ def update_one(session: Session, id: int, form: models.FormUpdate) -> models.For
session.refresh(new_form) session.refresh(new_form)
return new_form return new_form
def delete_one(session: Session, id: int) -> models.FormPublic:
statement = select(models.Form).where(models.Form.id == id) def delete_one(session: Session, _id: int) -> models.FormPublic:
statement = select(models.Form).where(models.Form.id == _id)
result = session.exec(statement) result = session.exec(statement)
form = result.first() form = result.first()
if not form: if not form:
return None raise exceptions.FormNotFoundError(messages.Messages.not_found('form'))
result = models.FormPublic.model_validate(form) result = models.FormPublic.model_validate(form)
session.delete(form) session.delete(form)
session.commit() session.commit()
return result return result
def is_allowed(session: Session, user: models.User, id: int) -> bool:
statement = select(models.Form)\ def is_allowed(
.join(models.Productor, models.Form.productor_id == models.Productor.id)\ session: Session,
.where(models.Form.id == id)\ user: models.User,
.where(models.Productor.type.in_([r.name for r in user.roles]))\ _id: int = None,
form: models.FormCreate = None
) -> bool:
if not _id and not form:
return False
if not _id:
statement = (
select(models.Productor)
.where(models.Productor.id == form.productor_id)
)
productor = session.exec(statement).first()
return productor.type in [r.name for r in user.roles]
statement = (
select(models.Form)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id
)
.where(models.Form.id == _id)
.where(
models.Productor.type.in_(
[r.name for r in user.roles]
)
)
.distinct() .distinct()
)
return len(session.exec(statement).all()) > 0 return len(session.exec(statement).all()) > 0

View File

@@ -1,18 +1,15 @@
from sqlmodel import SQLModel
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from src.auth.auth import router as auth_router
from src.templates.templates import router as template_router
from src.contracts.contracts import router as contracts_router from src.contracts.contracts import router as contracts_router
from src.forms.forms import router as forms_router from src.forms.forms import router as forms_router
from src.productors.productors import router as productors_router from src.productors.productors import router as productors_router
from src.products.products import router as products_router from src.products.products import router as products_router
from src.users.users import router as users_router
from src.auth.auth import router as auth_router
from src.shipments.shipments import router as shipment_router
from src.settings import settings from src.settings import settings
from src.database import engine, create_all_tables from src.shipments.shipments import router as shipment_router
from src.templates.templates import router as template_router
from src.users.users import router as users_router
app = FastAPI() app = FastAPI()

View File

@@ -1,10 +1,20 @@
notfound = "Resource was not found." pdferror = 'An error occured during PDF generation please contact administrator'
pdferror = "An error occured during PDF generation please contact administrator"
tokenexipired = "Token expired"
invalidtoken = "Invalid token" class Messages:
notauthenticated = "Not authenticated" unauthorized = 'User is Unauthorized'
usernotfound = "User not found" notauthenticated = 'User is not authenticated'
userloggedout = "User logged out" tokenexipired = 'Token has expired'
failtogettoken = "Failed to get token" invalidtoken = 'Token is invalid'
unauthorized = "Unauthorized"
notallowed = "Not Allowed" @staticmethod
def not_found(resource: str) -> str:
return f'{resource.capitalize()} not found'
@staticmethod
def invalid_input(resource: str, reason: str = "") -> str:
return f'Invalid {resource} input {':' if reason else ""} {reason}'
@staticmethod
def not_allowed(resource: str, action: str) -> str:
return f'User is not allowed to {action} this {resource}'

View File

@@ -1,99 +1,136 @@
from sqlmodel import Field, SQLModel, Relationship, Column, LargeBinary import datetime
from enum import StrEnum from enum import StrEnum
from typing import Optional from typing import Optional
import datetime
from sqlmodel import Column, Field, LargeBinary, Relationship, SQLModel
class ContractType(SQLModel, table=True): class ContractType(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(
default=None,
primary_key=True
)
name: str name: str
class UserContractTypeLink(SQLModel, table=True): class UserContractTypeLink(SQLModel, table=True):
user_id: int = Field(foreign_key="user.id", primary_key=True) user_id: int = Field(
contract_type_id: int = Field(foreign_key="contracttype.id", primary_key=True) foreign_key='user.id',
primary_key=True
)
contract_type_id: int = Field(
foreign_key='contracttype.id',
primary_key=True
)
class UserBase(SQLModel): class UserBase(SQLModel):
name: str name: str
email: str email: str
class UserPublic(UserBase): class UserPublic(UserBase):
id: int id: int
roles: list[ContractType] roles: list[ContractType]
class User(UserBase, table=True): class User(UserBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
roles: list[ContractType] = Relationship( roles: list[ContractType] = Relationship(
link_model=UserContractTypeLink link_model=UserContractTypeLink
) )
class UserUpdate(SQLModel): class UserUpdate(SQLModel):
name: str | None name: str | None
email: str | None email: str | None
role_names: list[str] | None role_names: list[str] | None
class UserCreate(UserBase): class UserCreate(UserBase):
role_names: list[str] | None role_names: list[str] | None
class PaymentMethodBase(SQLModel): class PaymentMethodBase(SQLModel):
name: str name: str
details: str details: str
max: int | None max: int | None
class PaymentMethod(PaymentMethodBase, table=True): class PaymentMethod(PaymentMethodBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
productor_id: int = Field(foreign_key="productor.id", ondelete="CASCADE") productor_id: int = Field(foreign_key='productor.id', ondelete='CASCADE')
productor: Optional["Productor"] = Relationship( productor: Optional['Productor'] = Relationship(
back_populates="payment_methods", back_populates='payment_methods',
) )
class PaymentMethodPublic(PaymentMethodBase): class PaymentMethodPublic(PaymentMethodBase):
id: int id: int
productor: Optional["Productor"] productor: Optional['Productor']
class ProductorBase(SQLModel): class ProductorBase(SQLModel):
name: str name: str
address: str address: str
type: str type: str
class ProductorPublic(ProductorBase): class ProductorPublic(ProductorBase):
id: int id: int
products: list["Product"] = [] products: list['Product'] = Field(default_factory=list)
payment_methods: list["PaymentMethod"] = [] payment_methods: list['PaymentMethod'] = Field(default_factory=list)
class Productor(ProductorBase, table=True): class Productor(ProductorBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
products: list["Product"] = Relationship( products: list['Product'] = Relationship(
back_populates='productor', back_populates='productor',
sa_relationship_kwargs={ sa_relationship_kwargs={
"order_by": "Product.name" 'order_by': 'Product.name'
}, },
) )
payment_methods: list["PaymentMethod"] = Relationship( payment_methods: list['PaymentMethod'] = Relationship(
back_populates="productor", back_populates='productor',
cascade_delete=True cascade_delete=True
) )
class ProductorUpdate(SQLModel): class ProductorUpdate(SQLModel):
name: str | None name: str | None
address: str | None address: str | None
payment_methods: list["PaymentMethod"] = [] payment_methods: list['PaymentMethod'] = Field(default_factory=list)
type: str | None type: str | None
class ProductorCreate(ProductorBase): class ProductorCreate(ProductorBase):
payment_methods: list["PaymentMethod"] = [] payment_methods: list['PaymentMethod'] = Field(default_factory=list)
class Unit(StrEnum): class Unit(StrEnum):
GRAMS = "1" GRAMS = '1'
KILO = "2" KILO = '2'
PIECE = "3" PIECE = '3'
class ProductType(StrEnum): class ProductType(StrEnum):
OCCASIONAL = "1" OCCASIONAL = '1'
RECCURENT = "2" RECCURENT = '2'
class ShipmentProductLink(SQLModel, table=True): class ShipmentProductLink(SQLModel, table=True):
shipment_id: Optional[int] = Field(default=None, foreign_key="shipment.id", primary_key=True) shipment_id: Optional[int] = Field(
product_id: Optional[int] = Field(default=None, foreign_key="product.id", primary_key=True) default=None,
foreign_key='shipment.id',
primary_key=True
)
product_id: Optional[int] = Field(
default=None,
foreign_key='product.id',
primary_key=True
)
class ProductBase(SQLModel): class ProductBase(SQLModel):
name: str name: str
@@ -103,17 +140,31 @@ class ProductBase(SQLModel):
quantity: float | None quantity: float | None
quantity_unit: str | None quantity_unit: str | None
type: ProductType type: ProductType
productor_id: int | None = Field(default=None, foreign_key="productor.id") productor_id: int | None = Field(
default=None,
foreign_key='productor.id'
)
class ProductPublic(ProductBase): class ProductPublic(ProductBase):
id: int id: int
productor: Productor | None productor: Productor | None
shipments: list["Shipment"] | None shipments: list['Shipment'] | None
class Product(ProductBase, table=True): class Product(ProductBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(
shipments: list["Shipment"] = Relationship(back_populates="products", link_model=ShipmentProductLink) default=None,
productor: Optional[Productor] = Relationship(back_populates="products") primary_key=True
)
shipments: list['Shipment'] = Relationship(
back_populates='products',
link_model=ShipmentProductLink
)
productor: Optional[Productor] = Relationship(
back_populates='products'
)
class ProductUpdate(SQLModel): class ProductUpdate(SQLModel):
name: str | None name: str | None
@@ -125,40 +176,46 @@ class ProductUpdate(SQLModel):
productor_id: int | None productor_id: int | None
type: ProductType | None type: ProductType | None
class ProductCreate(ProductBase): class ProductCreate(ProductBase):
pass pass
class FormBase(SQLModel): class FormBase(SQLModel):
name: str name: str
productor_id: int | None = Field(default=None, foreign_key="productor.id") productor_id: int | None = Field(default=None, foreign_key='productor.id')
referer_id: int | None = Field(default=None, foreign_key="user.id") referer_id: int | None = Field(default=None, foreign_key='user.id')
season: str season: str
start: datetime.date start: datetime.date
end: datetime.date end: datetime.date
minimum_shipment_value: float | None minimum_shipment_value: float | None
visible: bool
class FormPublic(FormBase): class FormPublic(FormBase):
id: int id: int
productor: ProductorPublic | None productor: ProductorPublic | None
referer: User | None referer: User | None
shipments: list["ShipmentPublic"] = [] shipments: list['ShipmentPublic'] = Field(default_factory=list)
class Form(FormBase, table=True): class Form(FormBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
productor: Optional['Productor'] = Relationship() productor: Optional['Productor'] = Relationship()
referer: Optional['User'] = Relationship() referer: Optional['User'] = Relationship()
shipments: list["Shipment"] = Relationship( shipments: list['Shipment'] = Relationship(
back_populates="form", back_populates='form',
cascade_delete=True, cascade_delete=True,
sa_relationship_kwargs={ sa_relationship_kwargs={
"order_by": "Shipment.name" 'order_by': 'Shipment.name'
}, },
) )
contracts: list["Contract"] = Relationship( contracts: list['Contract'] = Relationship(
back_populates="form", back_populates='form',
cascade_delete=True cascade_delete=True
) )
class FormUpdate(SQLModel): class FormUpdate(SQLModel):
name: str | None name: str | None
productor_id: int | None productor_id: int | None
@@ -167,36 +224,46 @@ class FormUpdate(SQLModel):
start: datetime.date | None start: datetime.date | None
end: datetime.date | None end: datetime.date | None
minimum_shipment_value: float | None minimum_shipment_value: float | None
visible: bool | None
class FormCreate(FormBase): class FormCreate(FormBase):
pass pass
class TemplateBase(SQLModel): class TemplateBase(SQLModel):
pass pass
class TemplatePublic(TemplateBase): class TemplatePublic(TemplateBase):
id: int id: int
class Template(TemplateBase, table=True): class Template(TemplateBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
class TemplateUpdate(SQLModel): class TemplateUpdate(SQLModel):
pass pass
class TemplateCreate(TemplateBase): class TemplateCreate(TemplateBase):
pass pass
class ChequeBase(SQLModel): class ChequeBase(SQLModel):
name: str name: str
value: str value: str
class Cheque(ChequeBase, table=True): class Cheque(ChequeBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
contract_id: int = Field(foreign_key="contract.id", ondelete="CASCADE") contract_id: int = Field(foreign_key='contract.id', ondelete='CASCADE')
contract: Optional["Contract"] = Relationship( contract: Optional['Contract'] = Relationship(
back_populates="cheques", back_populates='cheques',
) )
class ContractBase(SQLModel): class ContractBase(SQLModel):
firstname: str firstname: str
lastname: str lastname: str
@@ -205,105 +272,122 @@ class ContractBase(SQLModel):
payment_method: str payment_method: str
cheque_quantity: int cheque_quantity: int
class Contract(ContractBase, table=True): class Contract(ContractBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
form_id: int = Field( form_id: int = Field(
foreign_key="form.id", foreign_key='form.id',
nullable=False, nullable=False,
ondelete="CASCADE" ondelete='CASCADE'
) )
products: list["ContractProduct"] = Relationship( products: list['ContractProduct'] = Relationship(
back_populates="contract", back_populates='contract',
cascade_delete=True cascade_delete=True
) )
form: Optional[Form] = Relationship(back_populates="contracts") form: Form = Relationship(back_populates='contracts')
cheques: list[Cheque] = Relationship( cheques: list[Cheque] = Relationship(
back_populates="contract", back_populates='contract',
cascade_delete=True cascade_delete=True
) )
file: bytes = Field(sa_column=Column(LargeBinary)) file: bytes = Field(sa_column=Column(LargeBinary))
total_price: float | None total_price: float | None
class ContractCreate(ContractBase): class ContractCreate(ContractBase):
products: list["ContractProductCreate"] = [] products: list['ContractProductCreate'] = Field(default_factory=list)
cheques: list["Cheque"] = [] cheques: list['Cheque'] = Field(default_factory=list)
form_id: int form_id: int
class ContractUpdate(SQLModel): class ContractUpdate(SQLModel):
file: bytes file: bytes
class ContractPublic(ContractBase): class ContractPublic(ContractBase):
id: int id: int
products: list["ContractProduct"] = [] products: list['ContractProduct'] = Field(default_factory=list)
form: Form form: Form
total_price: float | None total_price: float | None
# file: bytes # file: bytes
class ContractProductBase(SQLModel): class ContractProductBase(SQLModel):
product_id: int = Field( product_id: int = Field(
foreign_key="product.id", foreign_key='product.id',
nullable=False, nullable=False,
ondelete="CASCADE" ondelete='CASCADE'
) )
shipment_id: int | None = Field( shipment_id: int | None = Field(
default=None, default=None,
foreign_key="shipment.id", foreign_key='shipment.id',
nullable=True, nullable=True,
ondelete="CASCADE" ondelete='CASCADE'
) )
quantity: float quantity: float
class ContractProduct(ContractProductBase, table=True): class ContractProduct(ContractProductBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
contract_id: int = Field( contract_id: int = Field(
foreign_key="contract.id", foreign_key='contract.id',
nullable=False, nullable=False,
ondelete="CASCADE" ondelete='CASCADE'
) )
contract: Optional["Contract"] = Relationship(back_populates="products") contract: Optional['Contract'] = Relationship(back_populates='products')
product: Optional["Product"] = Relationship() product: Optional['Product'] = Relationship()
shipment: Optional["Shipment"] = Relationship() shipment: Optional['Shipment'] = Relationship()
class ContractProductPublic(ContractProductBase): class ContractProductPublic(ContractProductBase):
id: int id: int
quantity: float quantity: float
contract: Contract contract: Contract
product: Product product: Product
shipment: Optional["Shipment"] shipment: Optional['Shipment']
class ContractProductCreate(ContractProductBase): class ContractProductCreate(ContractProductBase):
pass pass
class ContractProductUpdate(ContractProductBase): class ContractProductUpdate(ContractProductBase):
pass pass
class ShipmentBase(SQLModel): class ShipmentBase(SQLModel):
name: str name: str
date: datetime.date date: datetime.date
form_id: int | None = Field(default=None, foreign_key="form.id", ondelete="CASCADE") form_id: int | None = Field(
default=None,
foreign_key='form.id',
ondelete='CASCADE')
class ShipmentPublic(ShipmentBase): class ShipmentPublic(ShipmentBase):
id: int id: int
products: list[Product] = [] products: list[Product] = Field(default_factory=list)
form: Form | None form: Form | None
class Shipment(ShipmentBase, table=True): class Shipment(ShipmentBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
products: list[Product] = Relationship( products: list[Product] = Relationship(
back_populates="shipments", back_populates='shipments',
link_model=ShipmentProductLink, link_model=ShipmentProductLink,
sa_relationship_kwargs={ sa_relationship_kwargs={
"order_by": "Product.name" 'order_by': 'Product.name'
}, },
) )
form: Optional[Form] = Relationship(back_populates="shipments") form: Optional[Form] = Relationship(back_populates='shipments')
class ShipmentUpdate(SQLModel): class ShipmentUpdate(SQLModel):
name: str | None name: str | None
date: str | None date: datetime.date | None
product_ids: list[int] | None = [] product_ids: list[int] | None = Field(default_factory=list)
class ShipmentCreate(ShipmentBase): class ShipmentCreate(ShipmentBase):
product_ids: list[int] = [] product_ids: list[int] = Field(default_factory=list)
form_id: int form_id: int

View File

@@ -0,0 +1,17 @@
import logging
class ProductorServiceError(Exception):
def __init__(self, message: str):
super().__init__(message)
logging.error('ProductorService : %s', message)
class ProductorNotFoundError(ProductorServiceError):
pass
class ProductorCreateError(ProductorServiceError):
def __init__(self, message: str, field: str | None = None):
super().__init__(message)
self.field = field

View File

@@ -1,13 +1,13 @@
from fastapi import APIRouter, HTTPException, Depends, Query from fastapi import APIRouter, Depends, HTTPException, Query
import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session from sqlmodel import Session
import src.productors.service as service from src import messages, models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session
from src.productors import exceptions, service
router = APIRouter(prefix='/productors') router = APIRouter(prefix='/productors')
@router.get('', response_model=list[models.ProductorPublic]) @router.get('', response_model=list[models.ProductorPublic])
def get_productors( def get_productors(
names: list[str] = Query([]), names: list[str] = Query([]),
@@ -15,45 +15,78 @@ def get_productors(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
return service.get_all(session, names, types) return service.get_all(session, user, names, types)
@router.get('/{id}', response_model=models.ProductorPublic)
@router.get('/{_id}', response_model=models.ProductorPublic)
def get_productor( def get_productor(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.get_one(session, id) if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('productor', 'get')
)
result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.notfound) raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('productor')
)
return result return result
@router.post('', response_model=models.ProductorPublic) @router.post('', response_model=models.ProductorPublic)
def create_productor( def create_productor(
productor: models.ProductorCreate, productor: models.ProductorCreate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
return service.create_one(session, productor) if not service.is_allowed(session, user, productor=productor):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('productor', 'create')
)
try:
result = service.create_one(session, productor)
except exceptions.ProductorCreateError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
return result
@router.put('/{id}', response_model=models.ProductorPublic)
@router.put('/{_id}', response_model=models.ProductorPublic)
def update_productor( def update_productor(
id: int, productor: models.ProductorUpdate, _id: int, productor: models.ProductorUpdate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.update_one(session, id, productor) if not service.is_allowed(session, user, _id=_id):
if result is None: raise HTTPException(
raise HTTPException(status_code=404, detail=messages.notfound) status_code=403,
detail=messages.Messages.not_allowed('productor', 'update')
)
try:
result = service.update_one(session, _id, productor)
except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) from error
return result return result
@router.delete('/{id}', response_model=models.ProductorPublic)
@router.delete('/{_id}', response_model=models.ProductorPublic)
def delete_productor( def delete_productor(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.delete_one(session, id) if not service.is_allowed(session, user, _id=_id):
if result is None: raise HTTPException(
raise HTTPException(status_code=404, detail=messages.notfound) status_code=403,
detail=messages.Messages.not_allowed('productor', 'delete')
)
try:
result = service.delete_one(session, _id)
except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) from error
return result return result

View File

@@ -1,23 +1,37 @@
from sqlmodel import Session, select from sqlmodel import Session, select
import src.models as models from src import messages, models
from src.productors import exceptions
def get_all( def get_all(
session: Session, session: Session,
user: models.User,
names: list[str], names: list[str],
types: list[str] types: list[str]
) -> list[models.ProductorPublic]: ) -> list[models.ProductorPublic]:
statement = select(models.Productor) statement = select(models.Productor)\
.where(models.Productor.type.in_([r.name for r in user.roles]))\
.distinct()
if len(names) > 0: if len(names) > 0:
statement = statement.where(models.Productor.name.in_(names)) statement = statement.where(models.Productor.name.in_(names))
if len(types) > 0: if len(types) > 0:
statement = statement.where(models.Productor.type.in_(types)) statement = statement.where(models.Productor.type.in_(types))
return session.exec(statement.order_by(models.Productor.name)).all() return session.exec(statement.order_by(models.Productor.name)).all()
def get_one(session: Session, productor_id: int) -> models.ProductorPublic: def get_one(session: Session, productor_id: int) -> models.ProductorPublic:
return session.get(models.Productor, productor_id) return session.get(models.Productor, productor_id)
def create_one(session: Session, productor: models.ProductorCreate) -> models.ProductorPublic:
productor_create = productor.model_dump(exclude_unset=True, exclude="payment_methods") def create_one(
session: Session,
productor: models.ProductorCreate) -> models.ProductorPublic:
if not productor:
raise exceptions.ProductorCreateError(
messages.Messages.invalid_input(
'productor', 'input cannot be None'))
productor_create = productor.model_dump(
exclude_unset=True, exclude='payment_methods')
new_productor = models.Productor(**productor_create) new_productor = models.Productor(**productor_create)
new_productor.payment_methods = [ new_productor.payment_methods = [
@@ -32,26 +46,32 @@ def create_one(session: Session, productor: models.ProductorCreate) -> models.Pr
session.refresh(new_productor) session.refresh(new_productor)
return new_productor return new_productor
def update_one(session: Session, id: int, productor: models.ProductorUpdate) -> models.ProductorPublic:
statement = select(models.Productor).where(models.Productor.id == id) def update_one(
session: Session,
_id: int,
productor: models.ProductorUpdate
) -> models.ProductorPublic:
statement = select(models.Productor).where(models.Productor.id == _id)
result = session.exec(statement) result = session.exec(statement)
new_productor = result.first() new_productor = result.first()
if not new_productor: if not new_productor:
return None raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
productor_updates = productor.model_dump(exclude_unset=True) productor_updates = productor.model_dump(exclude_unset=True)
if "payment_methods" in productor_updates: if 'payment_methods' in productor_updates:
new_productor.payment_methods.clear() new_productor.payment_methods.clear()
for pm in productor_updates["payment_methods"]: for pm in productor_updates['payment_methods']:
new_productor.payment_methods.append( new_productor.payment_methods.append(
models.PaymentMethod( models.PaymentMethod(
name=pm["name"], name=pm['name'],
details=pm["details"], details=pm['details'],
productor_id=id, productor_id=id,
max=pm["max"] max=pm['max']
) )
) )
del productor_updates["payment_methods"] del productor_updates['payment_methods']
for key, value in productor_updates.items(): for key, value in productor_updates.items():
setattr(new_productor, key, value) setattr(new_productor, key, value)
@@ -60,13 +80,33 @@ def update_one(session: Session, id: int, productor: models.ProductorUpdate) ->
session.refresh(new_productor) session.refresh(new_productor)
return new_productor return new_productor
def delete_one(session: Session, id: int) -> models.ProductorPublic:
statement = select(models.Productor).where(models.Productor.id == id) def delete_one(session: Session, _id: int) -> models.ProductorPublic:
statement = select(models.Productor).where(models.Productor.id == _id)
result = session.exec(statement) result = session.exec(statement)
productor = result.first() productor = result.first()
if not productor: if not productor:
return None raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
result = models.ProductorPublic.model_validate(productor) result = models.ProductorPublic.model_validate(productor)
session.delete(productor) session.delete(productor)
session.commit() session.commit()
return result return result
def is_allowed(
session: Session,
user: models.User,
_id: int = None,
productor: models.ProductorCreate = None
) -> bool:
if not _id and not productor:
return False
if not _id:
return productor.type in [r.name for r in user.roles]
statement = (
select(models.Productor)
.where(models.Productor.id == _id)
.where(models.Productor.type.in_([r.name for r in user.roles]))
.distinct()
)
return len(session.exec(statement).all()) > 0

View File

@@ -0,0 +1,17 @@
class ProductServiceError(Exception):
def __init__(self, message: str):
super().__init__(message)
class ProductorNotFoundError(ProductServiceError):
pass
class ProductNotFoundError(ProductServiceError):
pass
class ProductCreateError(ProductServiceError):
def __init__(self, message: str, field: str | None = None):
super().__init__(message)
self.field = field

View File

@@ -1,12 +1,14 @@
from fastapi import APIRouter, HTTPException, Depends, Query
import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.products.service as service import src.products.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import messages, models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session
from src.products import exceptions
router = APIRouter(prefix='/products') router = APIRouter(prefix='/products')
@router.get('', response_model=list[models.ProductPublic], ) @router.get('', response_model=list[models.ProductPublic], )
def get_products( def get_products(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
@@ -17,48 +19,99 @@ def get_products(
): ):
return service.get_all( return service.get_all(
session, session,
user,
names, names,
productors, productors,
types, types,
) )
@router.get('/{id}', response_model=models.ProductPublic)
@router.get('/{_id}', response_model=models.ProductPublic)
def get_product( def get_product(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.get_one(session, id) if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('product', 'create')
)
result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.notfound) raise HTTPException(status_code=404,
detail=messages.Messages.not_found('product'))
return result return result
@router.post('', response_model=models.ProductPublic) @router.post('', response_model=models.ProductPublic)
def create_product( def create_product(
product: models.ProductCreate, product: models.ProductCreate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
return service.create_one(session, product) if not service.is_allowed(session, user, product=product):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('product', 'create')
)
try:
result = service.create_one(session, product)
except exceptions.ProductCreateError as error:
raise HTTPException(
status_code=400,
detail=str(error)
) from error
except exceptions.ProductorNotFoundError as error:
raise HTTPException(
status_code=404,
detail=str(error)
) from error
return result
@router.put('/{id}', response_model=models.ProductPublic)
@router.put('/{_id}', response_model=models.ProductPublic)
def update_product( def update_product(
id: int, product: models.ProductUpdate, _id: int, product: models.ProductUpdate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.update_one(session, id, product) if not service.is_allowed(session, user, _id=_id):
if result is None: raise HTTPException(
raise HTTPException(status_code=404, detail=messages.notfound) status_code=403,
detail=messages.Messages.not_allowed('product', 'update')
)
try:
result = service.update_one(session, _id, product)
except exceptions.ProductNotFoundError as error:
raise HTTPException(
status_code=404,
detail=str(error)
) from error
except exceptions.ProductorNotFoundError as error:
raise HTTPException(
status_code=404,
detail=str(error)
) from error
return result return result
@router.delete('/{id}', response_model=models.ProductPublic)
@router.delete('/{_id}', response_model=models.ProductPublic)
def delete_product( def delete_product(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.delete_one(session, id) if not service.is_allowed(session, user, _id=_id):
if result is None: raise HTTPException(
raise HTTPException(status_code=404, detail=messages.notfound) status_code=403,
detail=messages.Messages.not_allowed('product', 'delete')
)
try:
result = service.delete_one(session, _id)
except exceptions.ProductNotFoundError as error:
raise HTTPException(
status_code=404,
detail=str(error)
) from error
return result return result

View File

@@ -1,25 +1,49 @@
from sqlmodel import Session, select from sqlmodel import Session, select
import src.models as models from src import messages, models
from src.products import exceptions
def get_all( def get_all(
session: Session, session: Session,
user: models.User,
names: list[str], names: list[str],
productors: list[str], productors: list[str],
types: list[str], types: list[str],
) -> list[models.ProductPublic]: ) -> list[models.ProductPublic]:
statement = select(models.Product) statement = select(
models.Product) .join(
models.Productor,
models.Product.productor_id == models.Productor.id) .where(
models.Productor.type.in_(
[
r.name for r in user.roles])) .distinct()
if len(names) > 0: if len(names) > 0:
statement = statement.where(models.Product.name.in_(names)) statement = statement.where(models.Product.name.in_(names))
if len(productors) > 0: if len(productors) > 0:
statement = statement.join(models.Productor).where(models.Productor.name.in_(productors)) statement = statement.where(models.Productor.name.in_(productors))
if len(types) > 0: if len(types) > 0:
statement = statement.where(models.Product.type.in_(types)) statement = statement.where(models.Product.type.in_(types))
return session.exec(statement.order_by(models.Product.name)).all() return session.exec(statement.order_by(models.Product.name)).all()
def get_one(session: Session, product_id: int) -> models.ProductPublic:
def get_one(
session: Session,
product_id: int,
) -> models.ProductPublic:
return session.get(models.Product, product_id) return session.get(models.Product, product_id)
def create_one(session: Session, product: models.ProductCreate) -> models.ProductPublic:
def create_one(
session: Session,
product: models.ProductCreate,
) -> models.ProductPublic:
if not product:
raise exceptions.ProductCreateError(
messages.Messages.invalid_input(
'product', 'input cannot be None'))
if not session.get(models.Productor, product.productor_id):
raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
product_create = product.model_dump(exclude_unset=True) product_create = product.model_dump(exclude_unset=True)
new_product = models.Product(**product_create) new_product = models.Product(**product_create)
session.add(new_product) session.add(new_product)
@@ -27,12 +51,23 @@ def create_one(session: Session, product: models.ProductCreate) -> models.Produc
session.refresh(new_product) session.refresh(new_product)
return new_product return new_product
def update_one(session: Session, id: int, product: models.ProductUpdate) -> models.ProductPublic:
statement = select(models.Product).where(models.Product.id == id) def update_one(
session: Session,
_id: int,
product: models.ProductUpdate
) -> models.ProductPublic:
statement = select(models.Product).where(models.Product.id == _id)
result = session.exec(statement) result = session.exec(statement)
new_product = result.first() new_product = result.first()
if not new_product: if not new_product:
return None raise exceptions.ProductNotFoundError(
messages.Messages.not_found('product'))
if product.productor_id and not session.get(
models.Productor, product.productor_id):
raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
product_updates = product.model_dump(exclude_unset=True) product_updates = product.model_dump(exclude_unset=True)
for key, value in product_updates.items(): for key, value in product_updates.items():
setattr(new_product, key, value) setattr(new_product, key, value)
@@ -42,13 +77,49 @@ def update_one(session: Session, id: int, product: models.ProductUpdate) -> mode
session.refresh(new_product) session.refresh(new_product)
return new_product return new_product
def delete_one(session: Session, id: int) -> models.ProductPublic:
statement = select(models.Product).where(models.Product.id == id) def delete_one(
session: Session,
_id: int
) -> models.ProductPublic:
statement = select(models.Product).where(models.Product.id == _id)
result = session.exec(statement) result = session.exec(statement)
product = result.first() product = result.first()
if not product: if not product:
return None raise exceptions.ProductNotFoundError(
messages.Messages.not_found('product'))
result = models.ProductPublic.model_validate(product) result = models.ProductPublic.model_validate(product)
session.delete(product) session.delete(product)
session.commit() session.commit()
return result return result
def is_allowed(
session: Session,
user: models.User,
_id: int = None,
product: models.ProductCreate = None,
) -> bool:
if not _id and not product:
return False
if not _id:
statement = (
select(models.Product)
.join(
models.Productor,
models.Product.productor_id == models.Productor.id
)
.where(models.Product.id == product.productor_id)
)
productor = session.exec(statement).first()
return productor.type in [r.name for r in user.roles]
statement = (
select(models.Product)
.join(
models.Productor,
models.Product.productor_id == models.Productor.id
)
.where(models.Product.id == _id)
.where(models.Productor.type.in_([r.name for r in user.roles]))
.distinct()
)
return len(session.exec(statement).all()) > 0

View File

@@ -1,4 +1,5 @@
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings): class Settings(BaseSettings):
origins: str origins: str
@@ -16,13 +17,25 @@ class Settings(BaseSettings):
max_age: int max_age: int
debug: bool debug: bool
class Config: model_config = SettingsConfigDict(
env_file = "../.env" env_file='../.env'
)
settings = Settings() settings = Settings()
AUTH_URL = f"{settings.keycloak_server}/realms/{settings.keycloak_realm}/protocol/openid-connect/auth" AUTH_URL = (
TOKEN_URL = f"{settings.keycloak_server}/realms/{settings.keycloak_realm}/protocol/openid-connect/token" f'{settings.keycloak_server}/realms/'
ISSUER = f"{settings.keycloak_server}/realms/{settings.keycloak_realm}" f'{settings.keycloak_realm}/protocol/openid-connect/auth'
JWKS_URL = f"{ISSUER}/protocol/openid-connect/certs" )
LOGOUT_URL = f'{settings.keycloak_server}/realms/{settings.keycloak_realm}/protocol/openid-connect/logout' TOKEN_URL = (
f'{settings.keycloak_server}/realms/'
f'{settings.keycloak_realm}/protocol/openid-connect/token'
)
ISSUER = f'{settings.keycloak_server}/realms/{settings.keycloak_realm}'
JWKS_URL = f'{ISSUER}/protocol/openid-connect/certs'
LOGOUT_URL = (
f'{settings.keycloak_server}/realms/'
f'{settings.keycloak_realm}/protocol/openid-connect/logout'
)

View File

@@ -0,0 +1,17 @@
import logging
class ShipmentServiceError(Exception):
def __init__(self, message: str):
super().__init__(message)
logging.error('ShipmentService : %s', message)
class ShipmentNotFoundError(ShipmentServiceError):
pass
class ShipmentCreateError(ShipmentServiceError):
def __init__(self, message: str, field: str | None = None):
super().__init__(message)
self.field = field

View File

@@ -1,46 +1,110 @@
# pylint: disable=E1101
import datetime
import src.shipments.exceptions as exceptions
from sqlmodel import Session, select from sqlmodel import Session, select
import src.models as models from src import messages, models
def get_all( def get_all(
session: Session, session: Session,
names: list[str], user: models.User,
dates: list[str], names: list[str] = None,
forms: list[int] dates: list[str] = None,
forms: list[str] = None
) -> list[models.ShipmentPublic]: ) -> list[models.ShipmentPublic]:
statement = select(models.Shipment) statement = (
if len(names) > 0: select(models.Shipment)
.join(
models.Form,
models.Shipment.form_id == models.Form.id)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id)
.where(
models.Productor.type.in_(
[r.name for r in user.roles]
)
)
.distinct()
)
if names and len(names) > 0:
statement = statement.where(models.Shipment.name.in_(names)) statement = statement.where(models.Shipment.name.in_(names))
if len(dates) > 0: if dates and len(dates) > 0:
statement = statement.where(models.Shipment.date.in_(list(map(lambda x: datetime.strptime(x, '%Y-%m-%d'), dates)))) statement = statement.where(
if len(forms) > 0: models.Shipment.date.in_(
statement = statement.join(models.Form).where(models.Form.name.in_(forms)) list(map(
lambda x: datetime.datetime.strptime(
x, '%Y-%m-%d').date(),
dates
))
)
)
if forms and len(forms) > 0:
statement = statement.where(models.Form.name.in_(forms))
return session.exec(statement.order_by(models.Shipment.name)).all() return session.exec(statement.order_by(models.Shipment.name)).all()
def get_one(session: Session, shipment_id: int) -> models.ShipmentPublic: def get_one(session: Session, shipment_id: int) -> models.ShipmentPublic:
return session.get(models.Shipment, shipment_id) return session.get(models.Shipment, shipment_id)
def create_one(session: Session, shipment: models.ShipmentCreate) -> models.ShipmentPublic:
products = session.exec(select(models.Product).where(models.Product.id.in_(shipment.product_ids))).all() def create_one(
shipment_create = shipment.model_dump(exclude_unset=True, exclude={'product_ids'}) session: Session,
shipment: models.ShipmentCreate) -> models.ShipmentPublic:
if shipment is None:
raise exceptions.ShipmentCreateError(
messages.Messages.invalid_input(
'shipment', 'input cannot be None'))
products = session.exec(
select(models.Product)
.where(
models.Product.id.in_(
shipment.product_ids
)
)
).all()
shipment_create = shipment.model_dump(
exclude_unset=True, exclude={'product_ids'}
)
new_shipment = models.Shipment(**shipment_create, products=products) new_shipment = models.Shipment(**shipment_create, products=products)
session.add(new_shipment) session.add(new_shipment)
session.commit() session.commit()
session.refresh(new_shipment) session.refresh(new_shipment)
return new_shipment return new_shipment
def update_one(session: Session, id: int, shipment: models.ShipmentUpdate) -> models.ShipmentPublic:
statement = select(models.Shipment).where(models.Shipment.id == id) def update_one(
session: Session,
_id: int,
shipment: models.ShipmentUpdate) -> models.ShipmentPublic:
if shipment is None:
raise exceptions.ShipmentCreateError(
messages.Messages.invalid_input(
'shipment', 'input cannot be None'))
statement = select(models.Shipment).where(models.Shipment.id == _id)
result = session.exec(statement) result = session.exec(statement)
new_shipment = result.first() new_shipment = result.first()
if not new_shipment: if not new_shipment:
return None raise exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment'))
products_to_add = session.exec(select(models.Product).where(models.Product.id.in_(shipment.product_ids))).all() products_to_add = session.exec(
select(
models.Product
).where(
models.Product.id.in_(
shipment.product_ids
)
)
).all()
new_shipment.products.clear() new_shipment.products.clear()
for add in products_to_add: for add in products_to_add:
new_shipment.products.append(add) new_shipment.products.append(add)
shipment_updates = shipment.model_dump(exclude_unset=True, exclude={"product_ids"}) shipment_updates = shipment.model_dump(
exclude_unset=True, exclude={"product_ids"}
)
for key, value in shipment_updates.items(): for key, value in shipment_updates.items():
setattr(new_shipment, key, value) setattr(new_shipment, key, value)
@@ -49,13 +113,53 @@ def update_one(session: Session, id: int, shipment: models.ShipmentUpdate) -> mo
session.refresh(new_shipment) session.refresh(new_shipment)
return new_shipment return new_shipment
def delete_one(session: Session, id: int) -> models.ShipmentPublic:
statement = select(models.Shipment).where(models.Shipment.id == id) def delete_one(session: Session, _id: int) -> models.ShipmentPublic:
statement = select(models.Shipment).where(models.Shipment.id == _id)
result = session.exec(statement) result = session.exec(statement)
shipment = result.first() shipment = result.first()
if not shipment: if not shipment:
return None raise exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment'))
result = models.ShipmentPublic.model_validate(shipment) result = models.ShipmentPublic.model_validate(shipment)
session.delete(shipment) session.delete(shipment)
session.commit() session.commit()
return result return result
def is_allowed(
session: Session,
user: models.User,
_id: int = None,
shipment: models.ShipmentCreate = None,
):
if not _id and not shipment:
return False
if not _id:
statement = (
select(models.Shipment)
.join(
models.Form,
models.Shipment.form_id == models.Form.id
)
.where(models.Form.id == shipment.form_id)
)
form = session.exec(statement).first()
return form.productor.type in [r.name for r in user.roles]
statement = (
select(models.Shipment)
.join(
models.Form,
models.Shipment.form_id == models.Form.id
)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id
)
.where(models.Shipment.id == _id)
.where(models.Productor.type.in_([r.name for r in user.roles]))
.distinct()
)
return len(session.exec(statement).all()) > 0

View File

@@ -1,64 +1,102 @@
from fastapi import APIRouter, HTTPException, Depends, Query import src.shipments.exceptions as exceptions
import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.shipments.service as service import src.shipments.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import messages, models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/shipments') router = APIRouter(prefix='/shipments')
@router.get('', response_model=list[models.ShipmentPublic], ) @router.get('', response_model=list[models.ShipmentPublic], )
def get_shipments( def get_shipments(
session: Session = Depends(get_session), session: Session = Depends(get_session),
user: models.User = Depends(get_current_user),
names: list[str] = Query([]), names: list[str] = Query([]),
dates: list[str] = Query([]), dates: list[str] = Query([]),
forms: list[str] = Query([]), forms: list[str] = Query([]),
): ):
return service.get_all( return service.get_all(
session, session,
user,
names, names,
dates, dates,
forms, forms,
) )
@router.get('/{id}', response_model=models.ShipmentPublic)
@router.get('/{_id}', response_model=models.ShipmentPublic)
def get_shipment( def get_shipment(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.get_one(session, id) if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('shipment', 'get')
)
result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.notfound) raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('shipment')
)
return result return result
@router.post('', response_model=models.ShipmentPublic) @router.post('', response_model=models.ShipmentPublic)
def create_shipment( def create_shipment(
shipment: models.ShipmentCreate, shipment: models.ShipmentCreate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
return service.create_one(session, shipment) if not service.is_allowed(session, user, shipment=shipment):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('shipment', 'create')
)
try:
result = service.create_one(session, shipment)
except exceptions.ShipmentCreateError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
return result
@router.put('/{id}', response_model=models.ShipmentPublic)
@router.put('/{_id}', response_model=models.ShipmentPublic)
def update_shipment( def update_shipment(
id: int, shipment: models.ShipmentUpdate, _id: int,
shipment: models.ShipmentUpdate,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.update_one(session, id, shipment) if not service.is_allowed(session, user, _id=_id):
if result is None: raise HTTPException(
raise HTTPException(status_code=404, detail=messages.notfound) status_code=403,
detail=messages.Messages.not_allowed('shipment', 'update')
)
try:
result = service.update_one(session, _id, shipment)
except exceptions.ShipmentNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) from error
return result return result
@router.delete('/{id}', response_model=models.ShipmentPublic)
@router.delete('/{_id}', response_model=models.ShipmentPublic)
def delete_shipment( def delete_shipment(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.delete_one(session, id) if not service.is_allowed(session, user, _id=_id):
if result is None: raise HTTPException(
raise HTTPException(status_code=404, detail=messages.notfound) status_code=403,
detail=messages.Messages.not_allowed('shipment', 'delete')
)
try:
result = service.delete_one(session, _id)
except exceptions.ShipmentNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error)) from error
return result return result

View File

@@ -1,14 +1,19 @@
from sqlmodel import Session, select from sqlmodel import Session, select
import src.models as models from src import models
def get_all(session: Session) -> list[models.TemplatePublic]: def get_all(session: Session) -> list[models.TemplatePublic]:
statement = select(models.Template) statement = select(models.Template)
return session.exec(statement.order_by(models.Template.name)).all() return session.exec(statement.order_by(models.Template.name)).all()
def get_one(session: Session, template_id: int) -> models.TemplatePublic: def get_one(session: Session, template_id: int) -> models.TemplatePublic:
return session.get(models.Template, template_id) return session.get(models.Template, template_id)
def create_one(session: Session, template: models.TemplateCreate) -> models.TemplatePublic:
def create_one(
session: Session,
template: models.TemplateCreate) -> models.TemplatePublic:
template_create = template.model_dump(exclude_unset=True) template_create = template.model_dump(exclude_unset=True)
new_template = models.Template(**template_create) new_template = models.Template(**template_create)
session.add(new_template) session.add(new_template)
@@ -16,7 +21,11 @@ def create_one(session: Session, template: models.TemplateCreate) -> models.Temp
session.refresh(new_template) session.refresh(new_template)
return new_template return new_template
def update_one(session: Session, id: int, template: models.TemplateUpdate) -> models.TemplatePublic:
def update_one(
session: Session,
id: int,
template: models.TemplateUpdate) -> models.TemplatePublic:
statement = select(models.Template).where(models.Template.id == id) statement = select(models.Template).where(models.Template.id == id)
result = session.exec(statement) result = session.exec(statement)
new_template = result.first() new_template = result.first()
@@ -30,6 +39,7 @@ def update_one(session: Session, id: int, template: models.TemplateUpdate) -> mo
session.refresh(new_template) session.refresh(new_template)
return new_template return new_template
def delete_one(session: Session, id: int) -> models.TemplatePublic: def delete_one(session: Session, id: int) -> models.TemplatePublic:
statement = select(models.Template).where(models.Template.id == id) statement = select(models.Template).where(models.Template.id == id)
result = session.exec(statement) result = session.exec(statement)

View File

@@ -1,13 +1,13 @@
from fastapi import APIRouter, HTTPException, Depends
import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.templates.service as service import src.templates.service as service
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session
from src import messages, models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/templates') router = APIRouter(prefix='/templates')
@router.get('', response_model=list[models.TemplatePublic]) @router.get('', response_model=list[models.TemplatePublic])
def get_templates( def get_templates(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
@@ -15,6 +15,7 @@ def get_templates(
): ):
return service.get_all(session) return service.get_all(session)
@router.get('/{id}', response_model=models.TemplatePublic) @router.get('/{id}', response_model=models.TemplatePublic)
def get_template( def get_template(
id: int, id: int,
@@ -23,9 +24,11 @@ def get_template(
): ):
result = service.get_one(session, id) result = service.get_one(session, id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.notfound) raise HTTPException(status_code=404,
detail=messages.Messages.not_found('template'))
return result return result
@router.post('', response_model=models.TemplatePublic) @router.post('', response_model=models.TemplatePublic)
def create_template( def create_template(
template: models.TemplateCreate, template: models.TemplateCreate,
@@ -34,6 +37,7 @@ def create_template(
): ):
return service.create_one(session, template) return service.create_one(session, template)
@router.put('/{id}', response_model=models.TemplatePublic) @router.put('/{id}', response_model=models.TemplatePublic)
def update_template( def update_template(
id: int, template: models.TemplateUpdate, id: int, template: models.TemplateUpdate,
@@ -42,9 +46,11 @@ def update_template(
): ):
result = service.update_one(session, id, template) result = service.update_one(session, id, template)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.notfound) raise HTTPException(status_code=404,
detail=messages.Messages.not_found('template'))
return result return result
@router.delete('/{id}', response_model=models.TemplatePublic) @router.delete('/{id}', response_model=models.TemplatePublic)
def delete_template( def delete_template(
id: int, id: int,
@@ -53,5 +59,6 @@ def delete_template(
): ):
result = service.delete_one(session, id) result = service.delete_one(session, id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.notfound) raise HTTPException(status_code=404,
detail=messages.Messages.not_found('template'))
return result return result

View File

@@ -0,0 +1,17 @@
import logging
class UserServiceError(Exception):
def __init__(self, message: str):
super().__init__(message)
logging.error('UserService : %s', message)
class UserNotFoundError(UserServiceError):
pass
class UserCreateError(UserServiceError):
def __init__(self, message: str, field: str | None = None):
super().__init__(message)
self.field = field

View File

@@ -1,5 +1,7 @@
import src.users.exceptions as exceptions
from sqlmodel import Session, select from sqlmodel import Session, select
import src.models as models from src import messages, models
def get_all( def get_all(
session: Session, session: Session,
@@ -13,11 +15,15 @@ def get_all(
statement = statement.where(models.User.email.in_(emails)) statement = statement.where(models.User.email.in_(emails))
return session.exec(statement.order_by(models.User.name)).all() return session.exec(statement.order_by(models.User.name)).all()
def get_one(session: Session, user_id: int) -> models.UserPublic: def get_one(session: Session, user_id: int) -> models.UserPublic:
return session.get(models.User, user_id) return session.get(models.User, user_id)
def get_or_create_roles(session: Session, role_names: list[str]) -> list[models.ContractType]:
statement = select(models.ContractType).where(models.ContractType.name.in_(role_names)) def get_or_create_roles(session: Session,
role_names: list[str]) -> list[models.ContractType]:
statement = select(models.ContractType).where(
models.ContractType.name.in_(role_names))
existing = session.exec(statement).all() existing = session.exec(statement).all()
existing_roles = {role.name for role in existing} existing_roles = {role.name for role in existing}
missing_role = set(role_names) - existing_roles missing_role = set(role_names) - existing_roles
@@ -33,22 +39,36 @@ def get_or_create_roles(session: Session, role_names: list[str]) -> list[models.
session.refresh(role) session.refresh(role)
return existing + new_roles return existing + new_roles
def get_or_create_user(session: Session, user_create: models.UserCreate): def get_or_create_user(session: Session, user_create: models.UserCreate):
statement = select(models.User).where(models.User.email == user_create.email) statement = select(
models.User).where(
models.User.email == user_create.email)
user = session.exec(statement).first() user = session.exec(statement).first()
if user: if user:
user_role_names = [r.name for r in user.roles] user_role_names = [r.name for r in user.roles]
if user_role_names != user_create.role_names or user.name != user_create.name: if (user_role_names != user_create.role_names or
user.name != user_create.name):
user = update_one(session, user.id, user_create) user = update_one(session, user.id, user_create)
return user return user
user = create_one(session, user_create) user = create_one(session, user_create)
return user return user
def get_roles(session: Session): def get_roles(session: Session):
statement = select(models.ContractType) statement = (
select(models.ContractType)
)
return session.exec(statement.order_by(models.ContractType.name)).all() return session.exec(statement.order_by(models.ContractType.name)).all()
def create_one(session: Session, user: models.UserCreate) -> models.UserPublic: def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
if user is None:
raise exceptions.UserCreateError(
messages.Messages.invalid_input(
'user', 'input cannot be None'
)
)
new_user = models.User( new_user = models.User(
name=user.name, name=user.name,
email=user.email email=user.email
@@ -62,13 +82,22 @@ def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
session.refresh(new_user) session.refresh(new_user)
return new_user return new_user
def update_one(session: Session, id: int, user: models.UserCreate) -> models.UserPublic:
statement = select(models.User).where(models.User.id == id) def update_one(
session: Session,
_id: int,
user: models.UserCreate) -> models.UserPublic:
if user is None:
raise exceptions.UserCreateError(
messages.Messages.invalid_input(
'user', 'input cannot be None'
)
)
statement = select(models.User).where(models.User.id == _id)
result = session.exec(statement) result = session.exec(statement)
new_user = result.first() new_user = result.first()
if not new_user: if not new_user:
return None raise exceptions.UserNotFoundError(f'User {_id} not found')
new_user.email = user.email new_user.email = user.email
new_user.name = user.name new_user.name = user.name
@@ -79,13 +108,20 @@ def update_one(session: Session, id: int, user: models.UserCreate) -> models.Use
session.refresh(new_user) session.refresh(new_user)
return new_user return new_user
def delete_one(session: Session, id: int) -> models.UserPublic:
statement = select(models.User).where(models.User.id == id) def delete_one(session: Session, _id: int) -> models.UserPublic:
statement = select(models.User).where(models.User.id == _id)
result = session.exec(statement) result = session.exec(statement)
user = result.first() user = result.first()
if not user: if not user:
return None raise exceptions.UserNotFoundError(f'User {_id} not found')
result = models.UserPublic.model_validate(user) result = models.UserPublic.model_validate(user)
session.delete(user) session.delete(user)
session.commit() session.commit()
return result return result
def is_allowed(
logged_user: models.User,
):
return len(logged_user.roles) >= 5

View File

@@ -1,17 +1,18 @@
from fastapi import APIRouter, HTTPException, Depends, Query import src.users.exceptions as exceptions
import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.users.service as service import src.users.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import messages, models
from src.auth.auth import get_current_user from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/users') router = APIRouter(prefix='/users')
@router.get('', response_model=list[models.UserPublic]) @router.get('', response_model=list[models.UserPublic])
def get_users( def get_users(
session: Session = Depends(get_session), session: Session = Depends(get_session),
user: models.User = Depends(get_current_user), _: models.User = Depends(get_current_user),
names: list[str] = Query([]), names: list[str] = Query([]),
emails: list[str] = Query([]), emails: list[str] = Query([]),
): ):
@@ -21,51 +22,99 @@ def get_users(
emails, emails,
) )
@router.get('/roles', response_model=list[models.ContractType]) @router.get('/roles', response_model=list[models.ContractType])
def get_roles( def get_roles(
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
if not service.is_allowed(user):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('roles', 'get all')
)
return service.get_roles(session) return service.get_roles(session)
@router.get('/{id}', response_model=models.UserPublic)
def get_users( @router.get('/{_id}', response_model=models.UserPublic)
id: int, def get_user(
_id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.get_one(session, id) if not service.is_allowed(user):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('user', 'get')
)
result = service.get_one(session, _id)
if result is None: if result is None:
raise HTTPException(status_code=404, detail=messages.notfound) raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('user')
)
return result return result
@router.post('', response_model=models.UserPublic) @router.post('', response_model=models.UserPublic)
def create_user( def create_user(
user: models.UserCreate, user: models.UserCreate,
logged_user: models.User = Depends(get_current_user), logged_user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
return service.create_one(session, user) if not service.is_allowed(logged_user):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('user', 'create')
)
try:
user = service.create_one(session, user)
except exceptions.UserCreateError as error:
raise HTTPException(
status_code=400,
detail=str(error)
) from error
return user
@router.put('/{id}', response_model=models.UserPublic)
@router.put('/{_id}', response_model=models.UserPublic)
def update_user( def update_user(
id: int, _id: int,
user: models.UserUpdate, user: models.UserUpdate,
logged_user: models.User = Depends(get_current_user), logged_user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.update_one(session, id, user) if not service.is_allowed(logged_user):
if result is None: raise HTTPException(
raise HTTPException(status_code=404, detail=messages.notfound) status_code=403,
detail=messages.Messages.not_allowed('user', 'update')
)
try:
result = service.update_one(session, _id, user)
except exceptions.UserNotFoundError as error:
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('user')
) from error
return result return result
@router.delete('/{id}', response_model=models.UserPublic)
@router.delete('/{_id}', response_model=models.UserPublic)
def delete_user( def delete_user(
id: int, _id: int,
user: models.User = Depends(get_current_user), user: models.User = Depends(get_current_user),
session: Session = Depends(get_session) session: Session = Depends(get_session)
): ):
result = service.delete_one(session, id) if not service.is_allowed(user):
if result is None: raise HTTPException(
raise HTTPException(status_code=404, detail=messages.notfound) status_code=403,
detail=messages.Messages.not_allowed('user', 'delete')
)
try:
result = service.delete_one(session, _id)
except exceptions.UserNotFoundError as error:
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('user')
) from error
return result return result

Binary file not shown.

62
backend/tests/conftest.py Normal file
View File

@@ -0,0 +1,62 @@
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src import models
from src.auth.auth import get_current_user
from src.database import get_session
from src.main import app
from .fixtures import *
@pytest.fixture
def mock_session(mocker):
session = mocker.Mock()
def override():
return session
app.dependency_overrides[get_session] = override
yield session
app.dependency_overrides.clear()
@pytest.fixture
def mock_user():
user = models.User(id=1, name='test user', email='test@user.com')
def override():
return user
app.dependency_overrides[get_current_user] = override
yield user
app.dependency_overrides.clear()
@pytest.fixture
def client():
return TestClient(app)
@pytest.fixture(name='session')
def session_fixture():
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
connection = engine.connect()
transaction = connection.begin()
session = Session(bind=connection)
try:
yield session
finally:
transaction.rollback()
session.close()
connection.close()
engine.dispose()

View File

@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
#
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,65 @@
import tests.factories.contracts as contract_factory
import tests.factories.products as product_factory
from src import models
def contract_product_factory(**kwargs):
contract = contract_factory.contract_factory(id=1)
product = product_factory.product_public_factory(
id=1, type=models.ProductType.RECCURENT)
data = dict(
product_id=1,
shipment_id=1,
quantity=1,
contract_id=1,
product=product,
contract=contract
)
data.update(kwargs)
return models.ContractProduct(**data)
def contract_product_public_factory(**kwargs):
contract = contract_factory.contract_factory(id=1)
product = product_factory.product_public_factory(id=1)
data = dict(
id=1,
product_id=1,
shipment_id=None,
contract=contract,
product=product,
shipment=None,
quantity=1
)
data.update(kwargs)
return models.ContractProductPublic(**data)
def contract_product_create_factory(**kwargs):
data = dict(
product_id=1,
shipment_id=1,
quantity=1,
)
data.update(kwargs)
return models.ContractProductCreate(**data)
def contract_product_update_factory(**kwargs):
data = dict(
product_id=1,
shipment_id=1,
quantity=1,
)
data.update(kwargs)
return models.ContractProductUpdate(**data)
def contract_product_body_factory(**kwargs):
data = dict(
product_id=1,
shipment_id=1,
quantity=1,
)
data.update(kwargs)
return data

View File

@@ -0,0 +1,82 @@
from src import models
from .forms import form_factory
def contract_factory(**kwargs):
data = dict(
id=1,
firstname="test",
lastname="test",
email="test@test.test",
phone="00000000",
payment_method="cheque",
cheque_quantity=1,
form_id=1,
products=[],
cheques=[],
)
data.update(kwargs)
return models.Contract(**data)
def contract_public_factory(**kwargs):
data = dict(
id=1,
firstname="test",
lastname="test",
email="test@test.test",
phone="00000000",
payment_method="cheque",
cheque_quantity=1,
total_price=10,
products=[],
form=form_factory()
)
data.update(kwargs)
return models.ContractPublic(**data)
def contract_create_factory(**kwargs):
data = dict(
firstname="test",
lastname="test",
email="test@test.test",
phone="00000000",
payment_method="cheque",
cheque_quantity=1,
products=[],
cheques=[],
form_id=1,
)
data.update(kwargs)
return models.ContractCreate(**data)
def contract_update_factory(**kwargs):
data = dict(
firstname="test",
lastname="test",
email="test@test.test",
phone="00000000",
payment_method="cheque",
cheque_quantity=1,
)
data.update(kwargs)
return models.ContractUpdate(**data)
def contract_body_factory(**kwargs):
data = dict(
firstname="test",
lastname="test",
email="test@test.test",
phone="00000000",
payment_method="cheque",
cheque_quantity=1,
products=[],
cheques=[],
form_id=1
)
data.update(kwargs)
return data

View File

@@ -0,0 +1,89 @@
import datetime
from src import models
from .productors import productor_public_factory
from .users import user_factory
def form_factory(**kwargs):
data = dict(
id=1,
name='form 1',
productor_id=1,
referer_id=1,
season='hiver-2026',
start=datetime.date(2025, 10, 10),
end=datetime.date(2025, 10, 10),
minimum_shipment_value=0,
visible=True,
referer=user_factory(),
shipments=[],
productor=productor_public_factory(),
)
data.update(kwargs)
return models.Form(**data)
def form_body_factory(**kwargs):
data = dict(
name='form 1',
productor_id=1,
referer_id=1,
season='hiver-2026',
start='2025-10-10',
end='2025-10-10',
minimum_shipment_value=0,
visible=True
)
data.update(kwargs)
return data
def form_create_factory(**kwargs):
data = dict(
name='form 1',
productor_id=1,
referer_id=1,
season='hiver-2026',
start=datetime.date(2025, 10, 10),
end=datetime.date(2025, 10, 10),
minimum_shipment_value=0,
visible=True
)
data.update(kwargs)
return models.FormCreate(**data)
def form_update_factory(**kwargs):
data = dict(
name='form 1',
productor_id=1,
referer_id=1,
season='hiver-2026',
start=datetime.date(2025, 10, 10),
end=datetime.date(2025, 10, 10),
minimum_shipment_value=0,
visible=True
)
data.update(kwargs)
return models.FormUpdate(**data)
def form_public_factory(form=None, shipments=[], **kwargs):
data = dict(
id=1,
name='form 1',
productor_id=1,
referer_id=1,
season='hiver-2026',
start=datetime.date(2025, 10, 10),
end=datetime.date(2025, 10, 10),
minimum_shipment_value=0,
visible=True,
referer=user_factory(),
shipments=[],
productor=productor_public_factory(),
)
data.update(kwargs)
return models.FormPublic(**data)

View File

@@ -0,0 +1,64 @@
from src import models
def productor_factory(**kwargs):
data = dict(
id=1,
name="test productor",
address="test address",
type="test type"
)
data.update(kwargs)
return models.Productor(**data)
def productor_public_factory(**kwargs):
data = dict(
id=1,
name="test productor",
address="test address",
type="test type",
products=[],
payment_methods=[],
)
data.update(kwargs)
return models.ProductorPublic(**data)
def productor_create_factory(**kwargs):
data = dict(
id=1,
name="test productor",
address="test address",
type="test type",
products=[],
payment_methods=[],
)
data.update(kwargs)
return models.ProductorCreate(**data)
def productor_update_factory(**kwargs):
data = dict(
id=1,
name="test productor",
address="test address",
type="test type",
products=[],
payment_methods=[],
)
data.update(kwargs)
return models.ProductorUpdate(**data)
def productor_body_factory(**kwargs):
data = dict(
id=1,
name="test productor",
address="test address",
type="test type",
products=[],
payment_methods=[],
)
data.update(kwargs)
return data

View File

@@ -0,0 +1,68 @@
from src import models
from .productors import productor_factory
def product_body_factory(**kwargs):
data = dict(
name='product test 1',
unit=models.Unit.PIECE,
price=10.2,
price_kg=20.4,
quantity=500,
quantity_unit='g',
type=models.ProductType.OCCASIONAL,
productor_id=1,
)
data.update(kwargs)
return data
def product_create_factory(**kwargs):
data = dict(
name='product test 1',
unit=models.Unit.PIECE,
price=10.2,
price_kg=20.4,
quantity=500,
quantity_unit='g',
type=models.ProductType.OCCASIONAL,
productor_id=1,
)
data.update(kwargs)
return models.ProductCreate(**data)
def product_update_factory(**kwargs):
data = dict(
name='product test 1',
unit=models.Unit.PIECE,
price=10.2,
price_kg=20.4,
quantity=500,
quantity_unit='g',
type=models.ProductType.OCCASIONAL,
productor_id=1,
)
data.update(kwargs)
return models.ProductUpdate(**data)
def product_public_factory(productor=None, shipments=[], **kwargs):
if productor is None:
productor = productor_factory()
data = dict(
id=1,
name='product test 1',
unit=models.Unit.PIECE,
price=10.2,
price_kg=20.4,
quantity=500,
quantity_unit='g',
type=models.ProductType.OCCASIONAL,
productor_id=1,
productor=productor,
shipments=shipments,
)
data.update(kwargs)
return models.ProductPublic(**data)

View File

@@ -0,0 +1,59 @@
import datetime
from src import models
def shipment_factory(**kwargs):
data = dict(
id=1,
name="test shipment",
date=datetime.date(2025, 10, 10),
form_id=1,
)
data.update(kwargs)
return models.Shipment(**data)
def shipment_public_factory(**kwargs):
data = dict(
id=1,
name="test shipment",
date=datetime.date(2025, 10, 10),
form_id=1,
products=[],
form=models.Form(id=1, name="test")
)
data.update(kwargs)
return models.ShipmentPublic(**data)
def shipment_create_factory(**kwargs):
data = dict(
name="test shipment",
form_id=1,
date='2025-10-10',
product_ids=[],
)
data.update(kwargs)
return models.ShipmentCreate(**data)
def shipment_update_factory(**kwargs):
data = dict(
name="test shipment",
form_id=1,
date='2025-10-10',
product_ids=[],
)
data.update(kwargs)
return models.ShipmentUpdate(**data)
def shipment_body_factory(**kwargs):
data = dict(
name="test shipment",
form_id=1,
date="2025-10-10",
)
data.update(kwargs)
return data

View File

@@ -0,0 +1,53 @@
from src import models
def user_factory(**kwargs):
data = dict(
id=1,
name="test user",
email="test.test@test.test",
roles=[]
)
data.update(kwargs)
return models.User(**data)
def user_public_factory(**kwargs):
data = dict(
id=1,
name="test user",
email="test.test@test.test",
roles=[]
)
data.update(kwargs)
return models.UserPublic(**data)
def user_create_factory(**kwargs):
data = dict(
name="test user",
email="test.test@test.test",
role_names=[],
)
data.update(kwargs)
return models.UserCreate(**data)
def user_update_factory(**kwargs):
data = dict(
name="test user",
email="test.test@test.test",
role_names=[],
)
data.update(kwargs)
return models.UserUpdate(**data)
def user_body_factory(**kwargs):
data = dict(
name="test user",
email="test.test@test.test",
role_names=[],
)
data.update(kwargs)
return data

188
backend/tests/fixtures.py Normal file
View File

@@ -0,0 +1,188 @@
import datetime
import pytest
import src.forms.service as forms_service
import src.productors.service as productors_service
import src.products.service as products_service
import src.shipments.service as shipments_service
import src.users.service as users_service
import tests.factories.forms as forms_factory
import tests.factories.productors as productors_factory
import tests.factories.products as products_factory
import tests.factories.shipments as shipments_factory
import tests.factories.users as users_factory
from sqlmodel import Session
from src import models
@pytest.fixture
def productor(session: Session) -> models.ProductorPublic:
result = productors_service.create_one(
session,
productors_factory.productor_create_factory(
name='test productor',
type='Légumineuses',
)
)
return result
@pytest.fixture
def productors(session: Session) -> models.ProductorPublic:
result = [
productors_service.create_one(
session,
productors_factory.productor_create_factory(
name='test productor 1',
type='Légumineuses',
)
),
productors_service.create_one(
session,
productors_factory.productor_create_factory(
name='test productor 2',
type='Légumes',
)
)
]
return result
@pytest.fixture
def products(
session: Session,
productor: models.ProductorPublic
) -> list[models.ProductPublic]:
result = [
products_service.create_one(
session,
products_factory.product_create_factory(
name='product 1 occasionnal',
type=models.ProductType.OCCASIONAL,
productor_id=productor.id
)
),
products_service.create_one(
session,
products_factory.product_create_factory(
name='product 2 recurrent',
type=models.ProductType.RECCURENT,
productor_id=productor.id
)
),
]
return result
@pytest.fixture
def user(session: Session) -> models.UserPublic:
user = users_service.create_one(
session,
users_factory.user_create_factory(
name='test user',
email='test@test.com',
role_names=['Légumineuses']
)
)
return user
@pytest.fixture
def users(session: Session) -> list[models.UserPublic]:
result = [
users_service.create_one(
session,
users_factory.user_create_factory(
name='test user 1 (admin)',
email='test1@test.com',
role_names=[
'Légumineuses',
'Légumes',
'Oeufs',
'Porc-Agneau',
'Vin',
'Fruits'])),
users_service.create_one(
session,
users_factory.user_create_factory(
name='test user 2',
email='test2@test.com',
role_names=['Légumineuses'])),
users_service.create_one(
session,
users_factory.user_create_factory(
name='test user 3',
email='test3@test.com',
role_names=['Porc-Agneau']))]
return result
@pytest.fixture
def referer(session: Session) -> models.UserPublic:
result = users_service.create_one(
session,
users_factory.user_create_factory(
name='test referer',
email='test@test.com',
role_names=['Légumineuses'],
)
)
return result
@pytest.fixture
def shipments(
session: Session,
forms: list[models.FormPublic],
products: list[models.ProductPublic]
):
result = [
shipments_service.create_one(
session,
shipments_factory.shipment_create_factory(
name='test shipment 1',
date=datetime.date(2025, 10, 10),
form_id=forms[0].id,
product_ids=[p.id for p in products]
)
),
shipments_service.create_one(
session,
shipments_factory.shipment_create_factory(
name='test shipment 2',
date=datetime.date(2025, 11, 10),
form_id=forms[0].id,
product_ids=[p.id for p in products]
)
),
]
return result
@pytest.fixture
def forms(
session: Session,
productor: models.ProductorPublic,
referer: models.UserPublic
) -> list[models.FormPublic]:
result = [
forms_service.create_one(
session,
forms_factory.form_create_factory(
name='test form 1',
productor_id=productor.id,
referer_id=referer.id,
season='test season 1',
)
),
forms_service.create_one(
session,
forms_factory.form_create_factory(
name='test form 2',
productor_id=productor.id,
referer_id=referer.id,
season='test season 2',
)
)
]
return result

View File

@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
#
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,248 @@
import src.contracts.service as service
import tests.factories.contracts as contract_factory
from fastapi.exceptions import HTTPException
from src.auth.auth import get_current_user
from src.main import app
class TestContracts:
def test_get_all(
self,
client,
mocker,
mock_session,
mock_user
):
mock_results = [
contract_factory.contract_public_factory(id=1),
contract_factory.contract_public_factory(id=2),
contract_factory.contract_public_factory(id=3),
]
mock = mocker.patch.object(
service,
'get_all',
return_value=mock_results
)
response = client.get('/api/contracts')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 1
assert len(response_data) == len(mock_results)
mock.assert_called_once_with(
mock_session,
mock_user,
[],
)
def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user
):
mock_results = [
contract_factory.contract_public_factory(id=2),
]
mock = mocker.patch.object(
service,
'get_all',
return_value=mock_results
)
response = client.get('/api/contracts?forms=form test')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 2
assert len(response_data) == len(mock_results)
mock.assert_called_once_with(
mock_session,
mock_user,
['form test'],
)
def test_get_all_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.contracts.service.get_all')
response = client.get('/api/contracts')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_get_one(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = contract_factory.contract_public_factory(id=2)
mock = mocker.patch.object(
service,
'get_one',
return_value=mock_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.get('/api/contracts/2')
response_data = response.json()
assert response.status_code == 200
assert response_data['id'] == 2
mock.assert_called_once_with(
mock_session,
2
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
2
)
def test_get_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = None
mock = mocker.patch.object(
service,
'get_one',
return_value=mock_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.get('/api/contracts/2')
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
2
)
def test_get_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.contracts.service.get_one')
response = client.get('/api/contracts/2')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_delete_one(
self,
client,
mocker,
mock_session,
mock_user,
):
contract_result = contract_factory.contract_public_factory()
mock = mocker.patch.object(
service,
'delete_one',
return_value=contract_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/contracts/2')
assert response.status_code == 200
mock.assert_called_once_with(
mock_session,
2,
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
2
)
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
contract_result = None
mock = mocker.patch.object(
service,
'delete_one',
return_value=contract_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/contracts/2')
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2,
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
2
)
def test_delete_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.contracts.service.delete_one')
response = client.delete('/api/contracts/2')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()

View File

@@ -0,0 +1,517 @@
import src.forms.exceptions as forms_exceptions
import src.forms.service as service
import tests.factories.forms as form_factory
from fastapi.exceptions import HTTPException
from src import messages
from src.auth.auth import get_current_user
from src.main import app
class TestForms:
def test_get_all(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [
form_factory.form_public_factory(name="test 1", id=1),
form_factory.form_public_factory(name="test 2", id=2),
form_factory.form_public_factory(name="test 3", id=3),
]
mock = mocker.patch.object(
service,
'get_all',
return_value=mock_results
)
response = client.get('/api/forms/referents')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 1
assert len(response_data) == len(mock_results)
mock.assert_called_once_with(
mock_session,
[],
[],
False,
mock_user,
)
def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [
form_factory.form_public_factory(name="test 2", id=2),
]
mock = mocker.patch.object(
service,
'get_all',
return_value=mock_results
)
response = client.get(
'/api/forms/referents?current_season=true&seasons=hiver-2025&productors=test productor')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 2
assert len(response_data) == len(mock_results)
mock.assert_called_once_with(
mock_session,
['hiver-2025'],
['test productor'],
True,
mock_user,
)
def test_get_all_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.forms.service.get_all')
response = client.get('/api/forms/referents')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_get_one(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = form_factory.form_public_factory(name="test 2", id=2)
mock = mocker.patch.object(
service,
'get_one',
return_value=mock_result
)
response = client.get('/api/forms/2')
response_data = response.json()
assert response.status_code == 200
assert response_data['id'] == 2
mock.assert_called_once_with(
mock_session,
2
)
assert mock_user
def test_get_one_notfound(
self,
client,
mocker,
mock_session,
):
mock_result = None
mock = mocker.patch.object(
service,
'get_one',
return_value=mock_result
)
response = client.get('/api/forms/2')
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2
)
def test_create_one(
self,
client,
mocker,
mock_session,
mock_user,
):
form_body = form_factory.form_body_factory(name='test form create')
form_create = form_factory.form_create_factory(name='test form create')
form_result = form_factory.form_public_factory(name='test form create')
mock = mocker.patch.object(
service,
'create_one',
return_value=form_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.post('/api/forms', json=form_body)
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test form create'
mock.assert_called_once_with(
mock_session,
form_create
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
form=form_create
)
def test_create_one_referer_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
form_body = form_factory.form_body_factory(
name='test form create', referer_id=12312
)
form_create = form_factory.form_create_factory(
name='test form create', referer_id=12312
)
mock = mocker.patch.object(
service,
'create_one',
side_effect=forms_exceptions.UserNotFoundError(
messages.Messages.not_found('referer')
)
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.post('/api/forms', json=form_body)
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
form_create
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
form=form_create
)
def test_create_one_productor_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
form_body = form_factory.form_body_factory(
name='test form create', productor_id=1231
)
form_create = form_factory.form_create_factory(
name='test form create', productor_id=1231
)
mock = mocker.patch.object(
service,
'create_one',
side_effect=forms_exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')
)
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.post('/api/forms', json=form_body)
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
form_create
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
form=form_create
)
def test_create_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
form_body = form_factory.form_body_factory(name='test form create')
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.forms.service.create_one')
response = client.post('/api/forms', json=form_body)
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_update_one(
self,
client,
mocker,
mock_session,
mock_user,
):
form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update')
form_result = form_factory.form_public_factory(name='test form update')
mock = mocker.patch.object(
service,
'update_one',
return_value=form_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/forms/2', json=form_body)
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test form update'
mock.assert_called_once_with(
mock_session,
2,
form_update
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update')
mock = mocker.patch.object(
service,
'update_one',
side_effect=forms_exceptions.FormNotFoundError(
messages.Messages.not_found('form')
)
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/forms/2', json=form_body)
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2,
form_update
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_referer_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update')
mock = mocker.patch.object(
service, 'update_one', side_effect=forms_exceptions.UserNotFoundError(
messages.Messages.not_found('referer')
)
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/forms/2', json=form_body)
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2,
form_update
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_productor_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update')
mock = mocker.patch.object(
service,
'update_one',
side_effect=forms_exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')
)
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/forms/2', json=form_body)
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2,
form_update
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
form_body = form_factory.form_body_factory(name='test form update')
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.forms.service.update_one')
response = client.put('/api/forms/2', json=form_body)
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_delete_one(
self,
client,
mocker,
mock_session,
mock_user,
):
form_result = form_factory.form_public_factory(name='test form delete')
mock = mocker.patch.object(
service,
'delete_one',
return_value=form_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/forms/2')
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test form delete'
mock.assert_called_once_with(
mock_session,
2,
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock = mocker.patch.object(
service,
'delete_one',
side_effect=forms_exceptions.FormNotFoundError(
messages.Messages.not_found('form'))
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/forms/2')
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2,
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.forms.service.delete_one')
response = client.delete('/api/forms/2')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()

View File

@@ -0,0 +1,429 @@
import tests.factories.productors as productor_factory
from fastapi.exceptions import HTTPException
from src import messages
from src.auth.auth import get_current_user
from src.main import app
from src.productors import exceptions, service
class TestProductors:
def test_get_all(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [
productor_factory.productor_public_factory(name="test 1", id=1),
productor_factory.productor_public_factory(name="test 2", id=2),
productor_factory.productor_public_factory(name="test 3", id=3),
]
mock = mocker.patch.object(
service,
'get_all',
return_value=mock_results
)
response = client.get('/api/productors')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 1
assert len(response_data) == len(mock_results)
mock.assert_called_once_with(
mock_session,
mock_user,
[],
[],
)
def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [
productor_factory.productor_public_factory(name="test 2", id=2),
]
mock = mocker.patch.object(
service,
'get_all',
return_value=mock_results
)
response = client.get(
'/api/productors?types=Légumineuses&names=test 2')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 2
assert len(response_data) == len(mock_results)
mock.assert_called_once_with(
mock_session,
mock_user,
['test 2'],
['Légumineuses'],
)
def test_get_all_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.productors.service.get_all')
response = client.get('/api/productors')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_get_one(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = productor_factory.productor_public_factory(
name="test 2", id=2)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'get_one',
return_value=mock_result
)
response = client.get('/api/productors/2')
response_data = response.json()
assert response.status_code == 200
assert response_data['id'] == 2
mock.assert_called_once_with(
mock_session,
2
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_get_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock_result = None
mock = mocker.patch.object(
service,
'get_one',
return_value=mock_result
)
response = client.get('/api/productors/2')
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_get_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.productors.service.get_one')
response = client.get('/api/productors/2')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_create_one(
self,
client,
mocker,
mock_session,
mock_user,
):
productor_body = productor_factory.productor_body_factory(
name='test productor create')
productor_create = productor_factory.productor_create_factory(
name='test productor create')
productor_result = productor_factory.productor_public_factory(
name='test productor create')
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'create_one',
return_value=productor_result
)
response = client.post('/api/productors', json=productor_body)
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test productor create'
mock.assert_called_once_with(
mock_session,
productor_create
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
productor=productor_create
)
def test_create_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
productor_body = productor_factory.productor_body_factory(
name='test productor create')
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.productors.service.create_one')
response = client.post('/api/productors', json=productor_body)
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_update_one(
self,
client,
mocker,
mock_session,
mock_user,
):
productor_body = productor_factory.productor_body_factory(
name='test productor update')
productor_update = productor_factory.productor_update_factory(
name='test productor update')
productor_result = productor_factory.productor_public_factory(
name='test productor update')
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'update_one',
return_value=productor_result
)
response = client.put('/api/productors/2', json=productor_body)
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test productor update'
mock.assert_called_once_with(
mock_session,
2,
productor_update
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
productor_body = productor_factory.productor_body_factory(
name='test productor update',
)
productor_update = productor_factory.productor_update_factory(
name='test productor update',
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'update_one',
side_effect=exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')
)
)
response = client.put('/api/productors/2', json=productor_body)
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2,
productor_update
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
productor_body = productor_factory.productor_body_factory(
name='test productor update')
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.productors.service.update_one')
response = client.put('/api/productors/2', json=productor_body)
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_delete_one(
self,
client,
mocker,
mock_session,
mock_user,
):
productor_result = productor_factory.productor_public_factory(
name='test productor delete')
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'delete_one',
return_value=productor_result
)
response = client.delete('/api/productors/2')
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test productor delete'
mock.assert_called_once_with(
mock_session,
2,
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'delete_one',
side_effect=exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')
)
)
response = client.delete('/api/productors/2')
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2,
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.productors.service.delete_one')
response = client.delete('/api/productors/2')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()

View File

@@ -0,0 +1,420 @@
import src.products.service as service
import tests.factories.products as product_factory
from fastapi.exceptions import HTTPException
from src.auth.auth import get_current_user
from src.main import app
from src.products import exceptions
class TestProducts:
def test_get_all(
self,
client,
mocker,
mock_session,
mock_user
):
mock_results = [
product_factory.product_public_factory(name="test 1", id=1),
product_factory.product_public_factory(name="test 2", id=2),
product_factory.product_public_factory(name="test 3", id=3),
]
mock = mocker.patch.object(
service,
'get_all',
return_value=mock_results
)
response = client.get('/api/products')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 1
assert len(response_data) == len(mock_results)
mock.assert_called_once_with(
mock_session,
mock_user,
[],
[],
[]
)
def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user
):
mock_results = [
product_factory.product_public_factory(name="test 2", id=2),
]
mock = mocker.patch.object(
service,
'get_all',
return_value=mock_results
)
response = client.get('/api/products?types=1&names=test 2')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 2
assert len(response_data) == len(mock_results)
mock.assert_called_once_with(
mock_session,
mock_user,
['test 2'],
[],
['1'],
)
def test_get_all_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.products.service.get_all')
response = client.get('/api/products')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_get_one(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = product_factory.product_public_factory(
name="test 2", id=2)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'get_one',
return_value=mock_result
)
response = client.get('/api/products/2')
response_data = response.json()
assert response.status_code == 200
assert response_data['id'] == 2
mock.assert_called_once_with(
mock_session,
2
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_get_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = None
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'get_one',
return_value=mock_result
)
response = client.get('/api/products/2')
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_get_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.products.service.get_one')
response = client.get('/api/products/2')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_create_one(
self,
client,
mocker,
mock_session,
mock_user,
):
product_body = product_factory.product_body_factory(
name='test product create')
product_create = product_factory.product_create_factory(
name='test product create')
product_result = product_factory.product_public_factory(
name='test product create')
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'create_one',
return_value=product_result
)
response = client.post('/api/products', json=product_body)
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test product create'
mock.assert_called_once_with(
mock_session,
product_create
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
product=product_create
)
def test_create_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
product_body = product_factory.product_body_factory(
name='test product create')
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.products.service.create_one')
response = client.post('/api/products', json=product_body)
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_update_one(
self,
client,
mocker,
mock_session,
mock_user,
):
product_body = product_factory.product_body_factory(
name='test product update'
)
product_update = product_factory.product_update_factory(
name='test product update'
)
product_result = product_factory.product_public_factory(
name='test product update'
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'update_one',
return_value=product_result
)
response = client.put('/api/products/2', json=product_body)
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test product update'
mock.assert_called_once_with(
mock_session,
2,
product_update
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
product_body = product_factory.product_body_factory(
name='test product update'
)
product_update = product_factory.product_update_factory(
name='test product update'
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
mock = mocker.patch.object(
service,
'update_one',
side_effect=exceptions.ProductNotFoundError('Product not found')
)
response = client.put('/api/products/2', json=product_body)
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2,
product_update
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
product_body = product_factory.product_body_factory(
name='test product update')
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.products.service.update_one')
response = client.put('/api/products/2', json=product_body)
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_delete_one(
self,
client,
mocker,
mock_session,
mock_user,
):
product_result = product_factory.product_public_factory(
name='test product delete')
mock = mocker.patch.object(
service,
'delete_one',
return_value=product_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/products/2')
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test product delete'
mock.assert_called_once_with(
mock_session,
2,
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock = mocker.patch.object(
service,
'delete_one',
side_effect=exceptions.ProductNotFoundError('Product not found')
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/products/2')
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2,
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.products.service.delete_one')
response = client.delete('/api/products/2')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()

View File

@@ -0,0 +1,430 @@
import src.shipments.exceptions as exceptions
import src.shipments.service as service
import tests.factories.shipments as shipment_factory
from fastapi.exceptions import HTTPException
from src import messages
from src.auth.auth import get_current_user
from src.main import app
class TestShipments:
def test_get_all(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [
shipment_factory.shipment_public_factory(name="test 1", id=1),
shipment_factory.shipment_public_factory(name="test 2", id=2),
shipment_factory.shipment_public_factory(name="test 3", id=3),
]
mock = mocker.patch.object(
service,
'get_all',
return_value=mock_results
)
response = client.get('/api/shipments')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 1
assert len(response_data) == len(mock_results)
mock.assert_called_once_with(
mock_session,
mock_user,
[],
[],
[],
)
def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [
shipment_factory.shipment_public_factory(name="test 2", id=2),
]
mock = mocker.patch.object(
service,
'get_all',
return_value=mock_results
)
response = client.get(
'/api/shipments?dates=2025-10-10&names=test 2&forms=contract form 1')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 2
assert len(response_data) == len(mock_results)
mock.assert_called_once_with(
mock_session,
mock_user,
['test 2'],
['2025-10-10'],
['contract form 1'],
)
def test_get_all_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.shipments.service.get_all')
response = client.get('/api/shipments')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_get_one(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = shipment_factory.shipment_public_factory(
name="test 2", id=2)
mock = mocker.patch.object(
service,
'get_one',
return_value=mock_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.get('/api/shipments/2')
response_data = response.json()
assert response.status_code == 200
assert response_data['id'] == 2
mock.assert_called_once_with(
mock_session,
2
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_get_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = None
mock = mocker.patch.object(
service,
'get_one',
return_value=mock_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.get('/api/shipments/2')
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_get_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.shipments.service.get_one')
response = client.get('/api/shipments/2')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_create_one(
self,
client,
mocker,
mock_session,
mock_user,
):
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment create'
)
shipment_create = shipment_factory.shipment_create_factory(
name='test shipment create'
)
shipment_result = shipment_factory.shipment_public_factory(
name='test shipment create'
)
mock = mocker.patch.object(
service,
'create_one',
return_value=shipment_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.post('/api/shipments', json=shipment_body)
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test shipment create'
mock.assert_called_once_with(
mock_session,
shipment_create
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
shipment=shipment_create
)
def test_create_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment create'
)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.shipments.service.create_one')
response = client.post('/api/shipments', json=shipment_body)
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_update_one(
self,
client,
mocker,
mock_session,
mock_user,
):
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment update'
)
shipment_update = shipment_factory.shipment_update_factory(
name='test shipment update'
)
shipment_result = shipment_factory.shipment_public_factory(
name='test shipment update'
)
mock = mocker.patch.object(
service,
'update_one',
return_value=shipment_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/shipments/2', json=shipment_body)
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test shipment update'
mock.assert_called_once_with(
mock_session,
2,
shipment_update
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment update'
)
shipment_update = shipment_factory.shipment_update_factory(
name='test shipment update'
)
mock = mocker.patch.object(
service,
'update_one',
side_effect=exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment')
)
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/shipments/2', json=shipment_body)
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2,
shipment_update
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_update_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment update'
)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.shipments.service.update_one')
response = client.put('/api/shipments/2', json=shipment_body)
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_delete_one(
self,
client,
mocker,
mock_session,
mock_user,
):
shipment_result = shipment_factory.shipment_public_factory(
name='test shipment delete'
)
mock = mocker.patch.object(
service,
'delete_one',
return_value=shipment_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/shipments/2')
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test shipment delete'
mock.assert_called_once_with(
mock_session,
2,
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock = mocker.patch.object(
service,
'delete_one',
side_effect=exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment')
)
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/shipments/2')
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2,
)
mock_is_allowed.assert_called_once_with(
mock_session,
mock_user,
_id=2
)
def test_delete_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.shipments.service.delete_one')
response = client.delete('/api/shipments/2')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()

View File

@@ -0,0 +1,383 @@
import src.users.exceptions as exceptions
import src.users.service as service
import tests.factories.users as user_factory
from fastapi.exceptions import HTTPException
from src.auth.auth import get_current_user
from src.main import app
class TestUsers:
def test_get_all(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [
user_factory.user_public_factory(name="test 1", id=1),
user_factory.user_public_factory(name="test 2", id=2),
user_factory.user_public_factory(name="test 3", id=3),
]
mock = mocker.patch.object(
service,
'get_all',
return_value=mock_results
)
response = client.get('/api/users')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 1
assert len(response_data) == len(mock_results)
mock.assert_called_once_with(
mock_session,
[],
[],
)
assert mock_user
def test_get_all_filters(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_results = [
user_factory.user_public_factory(name="test 2", id=2),
]
mock = mocker.patch.object(
service,
'get_all',
return_value=mock_results
)
response = client.get('/api/users?emails=test@test.test&names=test 2')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 2
assert len(response_data) == len(mock_results)
mock.assert_called_once_with(
mock_session,
['test 2'],
['test@test.test'],
)
assert mock_user
def test_get_all_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.users.service.get_all')
response = client.get('/api/users')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_get_one(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = user_factory.user_public_factory(name="test 2", id=2)
mock = mocker.patch.object(
service,
'get_one',
return_value=mock_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.get('/api/users/2')
response_data = response.json()
assert response.status_code == 200
assert response_data['id'] == 2
mock.assert_called_once_with(
mock_session,
2
)
mock_is_allowed.assert_called_once_with(
mock_user
)
def test_get_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock_result = None
mock = mocker.patch.object(
service,
'get_one',
return_value=mock_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.get('/api/users/2')
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2
)
mock_is_allowed.assert_called_once_with(
mock_user
)
def test_get_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.users.service.get_one')
response = client.get('/api/users/2')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_create_one(
self,
client,
mocker,
mock_session,
mock_user,
):
user_body = user_factory.user_body_factory(name='test user create')
user_create = user_factory.user_create_factory(name='test user create')
user_result = user_factory.user_public_factory(name='test user create')
mock = mocker.patch.object(
service,
'create_one',
return_value=user_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.post('/api/users', json=user_body)
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test user create'
mock.assert_called_once_with(
mock_session,
user_create
)
mock_is_allowed.assert_called_once_with(
mock_user
)
def test_create_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
user_body = user_factory.user_body_factory(name='test user create')
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.users.service.create_one')
response = client.post('/api/users', json=user_body)
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_update_one(
self,
client,
mocker,
mock_session,
mock_user,
):
user_body = user_factory.user_body_factory(name='test user update')
user_update = user_factory.user_update_factory(name='test user update')
user_result = user_factory.user_public_factory(name='test user update')
mock = mocker.patch.object(
service,
'update_one',
return_value=user_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/users/2', json=user_body)
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test user update'
mock.assert_called_once_with(
mock_session,
2,
user_update
)
mock_is_allowed.assert_called_once_with(
mock_user
)
def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
user_body = user_factory.user_body_factory(name='test user update')
user_update = user_factory.user_update_factory(name='test user update')
mock = mocker.patch.object(
service,
'update_one',
side_effect=exceptions.UserNotFoundError('User 2 not found')
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.put('/api/users/2', json=user_body)
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2,
user_update
)
mock_is_allowed.assert_called_once_with(
mock_user
)
def test_update_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
user_body = user_factory.user_body_factory(name='test user update')
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.users.service.update_one')
response = client.put('/api/users/2', json=user_body)
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()
def test_delete_one(
self,
client,
mocker,
mock_session,
mock_user,
):
user_result = user_factory.user_public_factory(name='test user delete')
mock = mocker.patch.object(
service,
'delete_one',
return_value=user_result
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/users/2')
response_data = response.json()
assert response.status_code == 200
assert response_data['name'] == 'test user delete'
mock.assert_called_once_with(
mock_session,
2,
)
mock_is_allowed.assert_called_once_with(
mock_user
)
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user,
):
mock = mocker.patch.object(
service,
'delete_one',
side_effect=exceptions.UserNotFoundError('User 2 not found')
)
mock_is_allowed = mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/users/2')
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2,
)
mock_is_allowed.assert_called_once_with(
mock_user
)
def test_delete_one_unauthorized(
self,
client,
mocker,
):
def unauthorized():
raise HTTPException(status_code=401)
app.dependency_overrides[get_current_user] = unauthorized
mock = mocker.patch('src.users.service.delete_one')
response = client.delete('/api/users/2')
assert response.status_code == 401
mock.assert_not_called()
app.dependency_overrides.clear()

View File

@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2026-present Julien Aldon <julien.aldon@wanadoo.fr>
#
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,158 @@
import pytest
import src.forms.exceptions as forms_exceptions
import src.forms.service as forms_service
import tests.factories.forms as forms_factory
from sqlmodel import Session
from src import models
class TestFormsService:
def test_get_all_forms(self, session: Session,
forms: list[models.FormPublic]):
result = forms_service.get_all(session, [], [], False)
assert len(result) == 2
assert result == forms
def test_get_all_forms_filter_productors(
self, session: Session, forms: list[models.FormPublic]):
result = forms_service.get_all(session, [], ['test productor'], False)
assert len(result) == 2
assert result == forms
def test_get_all_forms_filter_season(
self, session: Session, forms: list[models.FormPublic]):
result = forms_service.get_all(session, ['test season 1'], [], False)
assert len(result) == 1
def test_get_all_forms_all_filters(
self, session: Session, forms: list[models.FormPublic]):
result = forms_service.get_all(
session, ['test season 1'], ['test productor'], True)
assert result == forms
def test_get_one_form(self, session: Session,
forms: list[models.FormPublic]):
result = forms_service.get_one(session, forms[0].id)
assert result == forms[0]
def test_get_one_form_notfound(self, session: Session):
result = forms_service.get_one(session, 122)
assert result is None
def test_create_form(
self,
session: Session,
productor: models.ProductorPublic,
referer: models.ProductorPublic
):
form_create = forms_factory.form_create_factory(
name="new test form",
productor_id=productor.id,
referer=referer.id,
season="new test season",
)
result = forms_service.create_one(session, form_create)
assert result.id is not None
assert result.name == "new test form"
assert result.productor.name == "test productor"
def test_create_form_invalidinput(
self,
session: Session,
productor: models.Productor
):
form_create = None
with pytest.raises(forms_exceptions.FormCreateError):
result = forms_service.create_one(session, form_create)
form_create = forms_factory.form_create_factory(productor_id=123)
with pytest.raises(forms_exceptions.ProductorNotFoundError):
result = forms_service.create_one(session, form_create)
form_create = forms_factory.form_create_factory(
productor_id=productor.id,
referer_id=123
)
with pytest.raises(forms_exceptions.UserNotFoundError):
result = forms_service.create_one(session, form_create)
def test_update_form(
self,
session: Session,
productor: models.ProductorPublic,
referer: models.ProductorPublic,
forms: list[models.FormPublic]
):
form_update = forms_factory.form_update_factory(
name='updated test form',
productor_id=productor.id,
referer_id=referer.id,
season='updated test season'
)
form_id = forms[0].id
result = forms_service.update_one(session, form_id, form_update)
assert result.id == form_id
assert result.name == 'updated test form'
assert result.season == 'updated test season'
def test_update_form_notfound(
self,
session: Session,
productor: models.ProductorPublic,
referer: models.ProductorPublic,
):
form_update = forms_factory.form_update_factory(
name='updated test form',
productor_id=productor.id,
referer_id=referer.id,
season='updated test season'
)
form_id = 123
with pytest.raises(forms_exceptions.FormNotFoundError):
result = forms_service.update_one(session, form_id, form_update)
def test_update_form_invalidinput(
self,
session: Session,
productor: models.ProductorPublic,
forms: list[models.FormPublic]
):
form_id = forms[0].id
form_update = forms_factory.form_update_factory(productor_id=123)
with pytest.raises(forms_exceptions.ProductorNotFoundError):
result = forms_service.update_one(session, form_id, form_update)
form_update = forms_factory.form_update_factory(
productor_id=productor.id,
referer_id=123
)
with pytest.raises(forms_exceptions.UserNotFoundError):
result = forms_service.update_one(session, form_id, form_update)
def test_delete_form(
self,
session: Session,
forms: list[models.FormPublic]
):
form_id = forms[0].id
result = forms_service.delete_one(session, form_id)
check = forms_service.get_one(session, form_id)
assert check is None
def test_delete_form_notfound(
self,
session: Session,
forms: list[models.FormPublic]
):
form_id = 123
with pytest.raises(forms_exceptions.FormNotFoundError):
result = forms_service.delete_one(session, form_id)

View File

@@ -0,0 +1,146 @@
import pytest
import src.productors.exceptions as productors_exceptions
import src.productors.service as productors_service
import tests.factories.productors as productors_factory
from sqlmodel import Session
from src import models
class TestProductorsService:
def test_get_all_productors(
self,
session: Session,
productors: list[models.ProductorPublic],
user: models.UserPublic
):
result = productors_service.get_all(session, user, [], [])
assert len(result) == 1
assert result == [productors[0]]
def test_get_all_productors_filter_names(
self,
session: Session,
productors: list[models.ProductorPublic],
user: models.UserPublic
):
result = productors_service.get_all(
session,
user,
['test productor 1'],
[]
)
assert len(result) == 1
def test_get_all_productors_filter_types(
self,
session: Session,
productors: list[models.ProductorPublic],
user: models.UserPublic
):
result = productors_service.get_all(
session,
user,
[],
['Légumineuses'],
)
assert len(result) == 1
def test_get_all_productors_all_filters(
self,
session: Session,
productors: list[models.ProductorPublic],
user: models.UserPublic
):
result = productors_service.get_all(
session,
user,
['test productor 1'],
['Légumineuses'],
)
assert len(result) == 1
def test_get_one_productor(self,
session: Session,
productors: list[models.ProductorPublic]):
result = productors_service.get_one(session, productors[0].id)
assert result == productors[0]
def test_get_one_productor_notfound(self, session: Session):
result = productors_service.get_one(session, 122)
assert result is None
def test_create_productor(
self,
session: Session,
referer: models.ProductorPublic
):
productor_create = productors_factory.productor_create_factory(
name="new test productor",
)
result = productors_service.create_one(session, productor_create)
assert result.id is not None
assert result.name == "new test productor"
def test_create_productor_invalidinput(
self,
session: Session,
):
productor_create = None
with pytest.raises(productors_exceptions.ProductorCreateError):
result = productors_service.create_one(session, productor_create)
def test_update_productor(
self,
session: Session,
referer: models.ProductorPublic,
productors: list[models.ProductorPublic]
):
productor_update = productors_factory.productor_update_factory(
name='updated test productor',
)
productor_id = productors[0].id
result = productors_service.update_one(
session, productor_id, productor_update)
assert result.id == productor_id
assert result.name == 'updated test productor'
def test_update_productor_notfound(
self,
session: Session,
referer: models.ProductorPublic,
):
productor_update = productors_factory.productor_update_factory(
name='updated test productor',
)
productor_id = 123
with pytest.raises(productors_exceptions.ProductorNotFoundError):
result = productors_service.update_one(
session, productor_id, productor_update)
def test_delete_productor(
self,
session: Session,
productors: list[models.ProductorPublic]
):
productor_id = productors[0].id
result = productors_service.delete_one(session, productor_id)
check = productors_service.get_one(session, productor_id)
assert check is None
def test_delete_productor_notfound(
self,
session: Session,
productors: list[models.ProductorPublic]
):
productor_id = 123
with pytest.raises(productors_exceptions.ProductorNotFoundError):
result = productors_service.delete_one(session, productor_id)

View File

@@ -0,0 +1,196 @@
import pytest
import src.products.exceptions as products_exceptions
import src.products.service as products_service
import tests.factories.products as products_factory
from sqlmodel import Session
from src import models
class TestProductsService:
def test_get_all_products(
self,
session: Session,
products: list[models.ProductPublic],
user: models.UserPublic
):
result = products_service.get_all(session, user, [], [], [])
assert len(result) == 2
assert result == products
def test_get_all_products_filter_productors(
self,
session: Session,
products: list[models.ProductPublic],
user: models.UserPublic
):
result = products_service.get_all(
session,
user,
[],
['test productor'],
[]
)
assert len(result) == 2
assert result == products
def test_get_all_products_filter_names(
self,
session: Session,
products: list[models.ProductPublic],
user: models.UserPublic
):
result = products_service.get_all(
session,
user,
['product 1 occasionnal'],
[],
[]
)
assert len(result) == 1
def test_get_all_products_filter_types(
self,
session: Session,
products: list[models.ProductPublic],
user: models.UserPublic
):
result = products_service.get_all(
session,
user,
[],
[],
['1']
)
assert len(result) == 1
def test_get_all_products_all_filters(
self,
session: Session,
products: list[models.ProductPublic],
user: models.UserPublic
):
result = products_service.get_all(
session,
user,
['product 1 occasionnal'],
['test productor'],
['1']
)
assert len(result) == 1
def test_get_one_product(self, session: Session,
products: list[models.ProductPublic]):
result = products_service.get_one(session, products[0].id)
assert result == products[0]
def test_get_one_product_notfound(self, session: Session):
result = products_service.get_one(session, 122)
assert result is None
def test_create_product(
self,
session: Session,
productor: models.ProductorPublic,
referer: models.ProductorPublic
):
product_create = products_factory.product_create_factory(
name="new test product",
productor_id=productor.id,
)
result = products_service.create_one(session, product_create)
assert result.id is not None
assert result.name == "new test product"
assert result.productor.name == "test productor"
def test_create_product_invalidinput(
self,
session: Session,
productor: models.Productor
):
product_create = None
with pytest.raises(products_exceptions.ProductCreateError):
result = products_service.create_one(session, product_create)
product_create = products_factory.product_create_factory(
productor_id=123)
with pytest.raises(products_exceptions.ProductorNotFoundError):
result = products_service.create_one(session, product_create)
def test_update_product(
self,
session: Session,
productor: models.ProductorPublic,
referer: models.ProductorPublic,
products: list[models.ProductPublic]
):
product_update = products_factory.product_update_factory(
name='updated test product',
productor_id=productor.id,
)
product_id = products[0].id
result = products_service.update_one(
session, product_id, product_update)
assert result.id == product_id
assert result.name == 'updated test product'
def test_update_product_notfound(
self,
session: Session,
productor: models.ProductorPublic,
referer: models.ProductorPublic,
):
product_update = products_factory.product_update_factory(
name='updated test product',
productor_id=productor.id,
)
product_id = 123
with pytest.raises(products_exceptions.ProductNotFoundError):
result = products_service.update_one(
session, product_id, product_update)
def test_update_product_invalidinput(
self,
session: Session,
productor: models.ProductorPublic,
products: list[models.ProductPublic]
):
product_id = products[0].id
product_update = products_factory.product_update_factory(
productor_id=123)
with pytest.raises(products_exceptions.ProductorNotFoundError):
result = products_service.update_one(
session, product_id, product_update)
product_update = products_factory.product_update_factory(
productor_id=productor.id,
referer_id=123
)
def test_delete_product(
self,
session: Session,
products: list[models.ProductPublic]
):
product_id = products[0].id
result = products_service.delete_one(session, product_id)
check = products_service.get_one(session, product_id)
assert check is None
def test_delete_product_notfound(
self,
session: Session,
products: list[models.ProductPublic]
):
product_id = 123
with pytest.raises(products_exceptions.ProductNotFoundError):
result = products_service.delete_one(session, product_id)

View File

@@ -0,0 +1,150 @@
import datetime
import pytest
import src.shipments.exceptions as shipments_exceptions
import src.shipments.service as shipments_service
import tests.factories.shipments as shipments_factory
from sqlmodel import Session
from src import models
class TestShipmentsService:
def test_get_all_shipments(
self,
session: Session,
shipments: list[models.ShipmentPublic],
user: models.UserPublic,
):
result = shipments_service.get_all(session, user, [], [], [])
assert len(result) == 2
assert result == shipments
def test_get_all_shipments_filter_names(
self,
session: Session,
shipments: list[models.ShipmentPublic],
user: models.UserPublic,
):
result = shipments_service.get_all(
session, user, ['test shipment 1'], [], [])
assert len(result) == 1
assert result == [shipments[0]]
def test_get_all_shipments_filter_dates(
self,
session: Session,
shipments: list[models.ShipmentPublic],
user: models.UserPublic,
):
result = shipments_service.get_all(
session, user, [], ['2025-10-10'], [])
assert len(result) == 1
def test_get_all_shipments_filter_forms(
self,
session: Session,
shipments: list[models.ShipmentPublic],
forms: list[models.FormPublic],
user: models.UserPublic,
):
result = shipments_service.get_all(
session, user, [], [], [forms[0].name])
assert len(result) == 2
def test_get_all_shipments_all_filters(
self,
session: Session,
shipments: list[models.ShipmentPublic],
forms: list[models.FormPublic],
user: models.UserPublic,
):
result = shipments_service.get_all(session, user, ['test shipment 1'], [
'2025-10-10'], [forms[0].name])
assert len(result) == 1
def test_get_one_shipment(self, session: Session,
shipments: list[models.ShipmentPublic]):
result = shipments_service.get_one(session, shipments[0].id)
assert result == shipments[0]
def test_get_one_shipment_notfound(self, session: Session):
result = shipments_service.get_one(session, 122)
assert result is None
def test_create_shipment(
self,
session: Session,
):
shipment_create = shipments_factory.shipment_create_factory(
name='new test shipment',
date='2025-10-10',
)
result = shipments_service.create_one(session, shipment_create)
assert result.id is not None
assert result.name == "new test shipment"
def test_create_shipment_invalidinput(
self,
session: Session,
):
shipment_create = None
with pytest.raises(shipments_exceptions.ShipmentCreateError):
result = shipments_service.create_one(session, shipment_create)
def test_update_shipment(
self,
session: Session,
shipments: list[models.ShipmentPublic]
):
shipment_update = shipments_factory.shipment_update_factory(
name='updated shipment 1',
date='2025-12-10',
)
shipment_id = shipments[0].id
result = shipments_service.update_one(
session, shipment_id, shipment_update)
assert result.id == shipment_id
assert result.name == 'updated shipment 1'
assert result.date == datetime.date(2025, 12, 10)
def test_update_shipment_notfound(
self,
session: Session,
):
shipment_update = shipments_factory.shipment_update_factory(
name='updated shipment 1',
date=datetime.date(2025, 10, 10),
)
shipment_id = 123
with pytest.raises(shipments_exceptions.ShipmentNotFoundError):
result = shipments_service.update_one(
session, shipment_id, shipment_update)
def test_delete_shipment(
self,
session: Session,
shipments: list[models.ShipmentPublic]
):
shipment_id = shipments[0].id
result = shipments_service.delete_one(session, shipment_id)
check = shipments_service.get_one(session, shipment_id)
assert check is None
def test_delete_shipment_notfound(
self,
session: Session,
shipments: list[models.ShipmentPublic]
):
shipment_id = 123
with pytest.raises(shipments_exceptions.ShipmentNotFoundError):
result = shipments_service.delete_one(session, shipment_id)

View File

@@ -0,0 +1,120 @@
import pytest
import src.users.exceptions as users_exceptions
import src.users.service as users_service
import tests.factories.users as users_factory
from sqlmodel import Session
from src import models
class TestUsersService:
def test_get_all_users(self, session: Session,
users: list[models.UserPublic]):
result = users_service.get_all(session, [], [])
assert len(result) == 3
assert result == users
def test_get_all_users_filter_names(
self, session: Session, users: list[models.UserPublic]):
result = users_service.get_all(session, ['test user 1 (admin)'], [])
assert len(result) == 1
assert result == [users[0]]
def test_get_all_users_filter_emails(
self, session: Session, users: list[models.UserPublic]):
result = users_service.get_all(session, [], ['test1@test.com'])
assert len(result) == 1
def test_get_all_users_all_filters(
self, session: Session, users: list[models.UserPublic]):
result = users_service.get_all(
session, ['test user 1 (admin)'], ['test1@test.com'])
assert len(result) == 1
def test_get_one_user(self, session: Session,
users: list[models.UserPublic]):
result = users_service.get_one(session, users[0].id)
assert result == users[0]
def test_get_one_user_notfound(self, session: Session):
result = users_service.get_one(session, 122)
assert result is None
def test_create_user(
self,
session: Session,
):
user_create = users_factory.user_create_factory(
name="new test user",
email='test@test.fr',
role_names=['test role']
)
result = users_service.create_one(session, user_create)
assert result.id is not None
assert result.name == "new test user"
assert result.email == "test@test.fr"
assert len(result.roles) == 1
def test_create_user_invalidinput(
self,
session: Session,
):
user_create = None
with pytest.raises(users_exceptions.UserCreateError):
result = users_service.create_one(session, user_create)
def test_update_user(
self,
session: Session,
users: list[models.UserPublic]
):
user_update = users_factory.user_update_factory(
name="updated test user",
email='test@testttt.fr',
role_names=['test role']
)
user_id = users[0].id
result = users_service.update_one(session, user_id, user_update)
assert result.id == user_id
assert result.name == 'updated test user'
assert result.email == 'test@testttt.fr'
def test_update_user_notfound(
self,
session: Session,
):
user_update = users_factory.user_update_factory(
name="updated test user",
email='test@testttt.fr',
role_names=['test role']
)
user_id = 123
with pytest.raises(users_exceptions.UserNotFoundError):
result = users_service.update_one(session, user_id, user_update)
def test_delete_user(
self,
session: Session,
users: list[models.UserPublic]
):
user_id = users[0].id
result = users_service.delete_one(session, user_id)
check = users_service.get_one(session, user_id)
assert check is None
def test_delete_user_notfound(
self,
session: Session,
users: list[models.UserPublic]
):
user_id = 123
with pytest.raises(users_exceptions.UserNotFoundError):
result = users_service.delete_one(session, user_id)

View File

@@ -82,7 +82,10 @@
"you can download all contracts for your form using the export all": "you can download all contracts for your form using the export all", "you can download all contracts for your form using the export all": "you can download all contracts for your form using the export all",
"in the same corner you can download a recap by clicking on the button": "in the same corner you can download a recap by clicking on the", "in the same corner you can download a recap by clicking on the button": "in the same corner you can download a recap by clicking on the",
"once all contracts downloaded, you can delete the form (to avoid new submissions) and hide it from the home page": "once all contracts downloaded, you can delete the form (to avoid new submissions) and hide it from the home page", "once all contracts downloaded, you can delete the form (to avoid new submissions) and hide it from the home page": "once all contracts downloaded, you can delete the form (to avoid new submissions) and hide it from the home page",
"by checking this option the form will be accessible publicly on the home page, only check it if everything is fine with your form": "by checking this option the form will be accessible publicly on the home page, only check it if everything is fine with your form",
"contracts": "contracts", "contracts": "contracts",
"hidden": "hidden",
"visible": "visible",
"minimum price for this shipment should be at least": "minimum price for this shipment should be at least", "minimum price for this shipment should be at least": "minimum price for this shipment should be at least",
"there is": "there is", "there is": "there is",
"for this contract": "for this contract.", "for this contract": "for this contract.",
@@ -157,6 +160,7 @@
"and/or": "and/or", "and/or": "and/or",
"form name recommandation": "recommendation: Contract <contract-type> (Example: Pork-Lamb Contract)", "form name recommandation": "recommendation: Contract <contract-type> (Example: Pork-Lamb Contract)",
"submit contract": "submit contract", "submit contract": "submit contract",
"submit": "submit",
"example in user forms": "example in user contract form", "example in user forms": "example in user contract form",
"occasional product": "occasional product", "occasional product": "occasional product",
"recurrent product": "recurrent product", "recurrent product": "recurrent product",

View File

@@ -73,7 +73,10 @@
"shipment products is necessary only for occasional products (if all products are recurrent leave empty)": "il est nécessaire de configurer les produits pour la livraison uniquement si il y a des produits occasionnels (laisser vide si tous les produits sont récurents).", "shipment products is necessary only for occasional products (if all products are recurrent leave empty)": "il est nécessaire de configurer les produits pour la livraison uniquement si il y a des produits occasionnels (laisser vide si tous les produits sont récurents).",
"recurrent product is for all shipments, occasional product is for a specific shipment (see shipment form)": "les produits récurrents sont pour toutes les livraisons, les produits occasionnels sont pour une livraison particulière (voir formulaire de création de livraison).", "recurrent product is for all shipments, occasional product is for a specific shipment (see shipment form)": "les produits récurrents sont pour toutes les livraisons, les produits occasionnels sont pour une livraison particulière (voir formulaire de création de livraison).",
"some contracts require a minimum value per shipment, ignore this field if it's not the case": "certains contrats nécessitent une valeur minimum par livraison. Ce champ peut être ignoré sil ne sapplique pas à votre contrat.", "some contracts require a minimum value per shipment, ignore this field if it's not the case": "certains contrats nécessitent une valeur minimum par livraison. Ce champ peut être ignoré sil ne sapplique pas à votre contrat.",
"by checking this option the form will be accessible publicly on the home page, only check it if everything is fine with your form": "en cochant cette option le formulaire sera accessible publiquement sur la page d'accueil, cochez cette option uniquement si tout est prêt avec votre formulaire.",
"contracts": "contrats", "contracts": "contrats",
"hidden": "caché",
"visible": "visible",
"minimum price for this shipment should be at least": "le prix minimum d'une livraison doit être au moins de", "minimum price for this shipment should be at least": "le prix minimum d'une livraison doit être au moins de",
"there is": "il y a", "there is": "il y a",
"for this contract": "pour ce contrat.", "for this contract": "pour ce contrat.",
@@ -157,6 +160,7 @@
"and/or": "et/ou", "and/or": "et/ou",
"form name recommandation": "recommandation : Contrat <contract-type> (Exemple : Contrat Porc-Agneau)", "form name recommandation": "recommandation : Contrat <contract-type> (Exemple : Contrat Porc-Agneau)",
"submit contract": "envoyer le contrat", "submit contract": "envoyer le contrat",
"submit": "envoyer",
"example in user forms": "exemple dans le formulaire à destination des amapiens", "example in user forms": "exemple dans le formulaire à destination des amapiens",
"occasional product": "produit occasionnel", "occasional product": "produit occasionnel",
"recurrent product": "produit récurrent", "recurrent product": "produit récurrent",
@@ -166,7 +170,7 @@
"with cheque and transfer": "avec chèques et virements configuré pour le producteur", "with cheque and transfer": "avec chèques et virements configuré pour le producteur",
"mililiter": "mililitres (ml)", "mililiter": "mililitres (ml)",
"this field is optionnal a product can have a quantity if configured inside the product it will be shown inside the form": "ce champ est optionnel dans la configuration d'un produit, il représente la quantité d'un produit (poids d'une tranche de foie, poids d'un panier, taille d'un bocal...). Si ce champs est renseigné il sera affiché dans le formulaire à destination des amapiens.", "this field is optionnal a product can have a quantity if configured inside the product it will be shown inside the form": "ce champ est optionnel dans la configuration d'un produit, il représente la quantité d'un produit (poids d'une tranche de foie, poids d'un panier, taille d'un bocal...). Si ce champs est renseigné il sera affiché dans le formulaire à destination des amapiens.",
"this field is also optionnal if a product have a quantity you can select the correct unit (metric system). It will be shown next to product quantity inside the form": "ce champs est optionnel dans la configuation d'un produit, il représente l'unité de mesure associé à la quantité d'un produit (g, kg, ml, L). Si ce champs est renseigné il sera affiché dans le formulaire à destination des amapiens à coté de la quantité du produit.", "this field is also optionnal if a product have a quantity you can select the correct unit (metric system). It will be shown next to product quantity inside the form": "ce champs est optionnel dans la configuation d'un produit, il représente l'unité de mesure associée à la quantité d'un produit (g, kg, ml, L). Si ce champs est renseigné il sera affiché dans le formulaire à destination des amapiens à coté de la quantité du produit.",
"with 150 set as quantity and g as quantity unit in product": "avec \"150\" en quantité de produit et \"grammes\" selectionné dans l'unité de quantité du produit", "with 150 set as quantity and g as quantity unit in product": "avec \"150\" en quantité de produit et \"grammes\" selectionné dans l'unité de quantité du produit",
"all shipments should be recreated for each form creation": "les livraisons étant liées à un formulaire elles doivent être recréés pour chaque nouveau formulaire.", "all shipments should be recreated for each form creation": "les livraisons étant liées à un formulaire elles doivent être recréés pour chaque nouveau formulaire.",
"a productor can be edited if its informations change, it should not be recreated for each contracts": "un(e) producteur·trice peut être édité si ses informations changent, il/elle ne doit pas être recréé pour chaque nouveau contrat.", "a productor can be edited if its informations change, it should not be recreated for each contracts": "un(e) producteur·trice peut être édité si ses informations changent, il/elle ne doit pas être recréé pour chaque nouveau contrat.",

View File

@@ -26,6 +26,9 @@ export function ContractModal({ opened, onClose, handleSubmit }: ContractModalPr
}); });
const formSelect = useMemo(() => { const formSelect = useMemo(() => {
if (!allForms) {
return [];
}
return allForms?.map((form) => ({ return allForms?.map((form) => ({
value: String(form.id), value: String(form.id),
label: `${form.season} ${form.name}`, label: `${form.season} ${form.name}`,

View File

@@ -1,5 +1,6 @@
import { import {
Button, Button,
Checkbox,
Group, Group,
Modal, Modal,
NumberInput, NumberInput,
@@ -33,6 +34,7 @@ export default function FormModal({ opened, onClose, currentForm, handleSubmit }
productor_id: currentForm?.productor?.id.toString() ?? "", productor_id: currentForm?.productor?.id.toString() ?? "",
referer_id: currentForm?.referer?.id.toString() ?? "", referer_id: currentForm?.referer?.id.toString() ?? "",
minimum_shipment_value: currentForm?.minimum_shipment_value ?? null, minimum_shipment_value: currentForm?.minimum_shipment_value ?? null,
visible: currentForm?.visible ?? false
}, },
validate: { validate: {
name: (value) => name: (value) =>
@@ -51,6 +53,8 @@ export default function FormModal({ opened, onClose, currentForm, handleSubmit }
}); });
const usersSelect = useMemo(() => { const usersSelect = useMemo(() => {
if (!users)
return [];
return users?.map((user) => ({ return users?.map((user) => ({
value: String(user.id), value: String(user.id),
label: `${user.name}`, label: `${user.name}`,
@@ -58,6 +62,8 @@ export default function FormModal({ opened, onClose, currentForm, handleSubmit }
}, [users]); }, [users]);
const productorsSelect = useMemo(() => { const productorsSelect = useMemo(() => {
if (!productors)
return [];
return productors?.map((prod) => ({ return productors?.map((prod) => ({
value: String(prod.id), value: String(prod.id),
label: `${prod.name}`, label: `${prod.name}`,
@@ -136,6 +142,11 @@ export default function FormModal({ opened, onClose, currentForm, handleSubmit }
radius="sm" radius="sm"
{...form.getInputProps("minimum_shipment_value")} {...form.getInputProps("minimum_shipment_value")}
/> />
<Checkbox mt="lg"
label={t("visible", {capfirst: true})}
description={t("by checking this option the form will be accessible publicly on the home page, only check it if everything is fine with your form", {capfirst: true})}
{...form.getInputProps("visible", {type: "checkbox"})}
/>
<Group mt="sm" justify="space-between"> <Group mt="sm" justify="space-between">
<Button <Button
variant="filled" variant="filled"

View File

@@ -1,4 +1,4 @@
import { ActionIcon, Table, Tooltip } from "@mantine/core"; import { ActionIcon, Badge, Table, Tooltip } from "@mantine/core";
import { useNavigate, useSearchParams } from "react-router"; import { useNavigate, useSearchParams } from "react-router";
import { useDeleteForm } from "@/services/api"; import { useDeleteForm } from "@/services/api";
import { IconEdit, IconX } from "@tabler/icons-react"; import { IconEdit, IconX } from "@tabler/icons-react";
@@ -16,6 +16,12 @@ export default function FormRow({ form }: FormRowProps) {
return ( return (
<Table.Tr key={form.id}> <Table.Tr key={form.id}>
<Table.Td>
{form.visible ?
<Badge color="green">{t("visible", {capfirst: true})}</Badge> :
<Badge color="red">{t("hidden", {capfirst: true})}</Badge>
}
</Table.Td>
<Table.Td>{form.name}</Table.Td> <Table.Td>{form.name}</Table.Td>
<Table.Td>{form.season}</Table.Td> <Table.Td>{form.season}</Table.Td>
<Table.Td>{form.start}</Table.Td> <Table.Td>{form.start}</Table.Td>

View File

@@ -4,9 +4,12 @@ import "./index.css";
import { Group, Loader } from "@mantine/core"; import { Group, Loader } from "@mantine/core";
import { Config } from "@/config/config"; import { Config } from "@/config/config";
import { useAuth } from "@/services/auth/AuthProvider"; import { useAuth } from "@/services/auth/AuthProvider";
import { useMediaQuery } from "@mantine/hooks";
import { IconHome, IconLogin, IconLogout, IconSettings } from "@tabler/icons-react";
export function Navbar() { export function Navbar() {
const { loggedUser: user, isLoading } = useAuth(); const { loggedUser: user, isLoading } = useAuth();
const isPhone = useMediaQuery("(max-width: 760px");
if (!user && isLoading) { if (!user && isLoading) {
return ( return (
@@ -20,11 +23,11 @@ export function Navbar() {
<nav> <nav>
<Group> <Group>
<NavLink className={"navLink"} aria-label={t("home")} to="/"> <NavLink className={"navLink"} aria-label={t("home")} to="/">
{t("home", { capfirst: true })} {isPhone ? <IconHome/> : t("home", { capfirst: true })}
</NavLink> </NavLink>
{user?.logged ? ( {user?.logged ? (
<NavLink className={"navLink"} aria-label={t("dashboard")} to="/dashboard/help"> <NavLink className={"navLink"} aria-label={t("dashboard")} to="/dashboard/help">
{t("dashboard", { capfirst: true })} {isPhone ? <IconSettings/> : t("dashboard", { capfirst: true })}
</NavLink> </NavLink>
) : null} ) : null}
</Group> </Group>
@@ -34,7 +37,7 @@ export function Navbar() {
className={"navLink"} className={"navLink"}
aria-label={t("login with keycloak", { capfirst: true })} aria-label={t("login with keycloak", { capfirst: true })}
> >
{t("login with keycloak", { capfirst: true })} {isPhone ? <IconLogin/> : t("login with keycloak", { capfirst: true })}
</a> </a>
) : ( ) : (
<a <a
@@ -42,7 +45,7 @@ export function Navbar() {
className={"navLink"} className={"navLink"}
aria-label={t("logout", { capfirst: true })} aria-label={t("logout", { capfirst: true })}
> >
{t("logout", { capfirst: true })} {isPhone ? <IconLogout/> : t("logout", { capfirst: true })}
</a> </a>
)} )}
</nav> </nav>

View File

@@ -19,7 +19,7 @@ import {
type ProductorInputs, type ProductorInputs,
} from "@/services/resources/productors"; } from "@/services/resources/productors";
import { useMemo } from "react"; import { useMemo } from "react";
import { useGetRoles } from "@/services/api"; import { useAuth } from "@/services/auth/AuthProvider";
export type ProductorModalProps = ModalBaseProps & { export type ProductorModalProps = ModalBaseProps & {
currentProductor?: Productor; currentProductor?: Productor;
@@ -32,7 +32,7 @@ export function ProductorModal({
currentProductor, currentProductor,
handleSubmit, handleSubmit,
}: ProductorModalProps) { }: ProductorModalProps) {
const { data: allRoles } = useGetRoles(); const { loggedUser } = useAuth();
const form = useForm<ProductorInputs>({ const form = useForm<ProductorInputs>({
initialValues: { initialValues: {
@@ -58,8 +58,8 @@ export function ProductorModal({
}); });
const roleSelect = useMemo(() => { const roleSelect = useMemo(() => {
return allRoles?.map((role) => ({ value: String(role.name), label: role.name })); return loggedUser?.user?.roles?.map((role) => ({ value: String(role.name), label: role.name }));
}, [allRoles]); }, [loggedUser?.user?.roles]);
return ( return (
<Modal opened={opened} onClose={onClose} title={t("create productor", { capfirst: true })}> <Modal opened={opened} onClose={onClose} title={t("create productor", { capfirst: true })}>

View File

@@ -59,6 +59,8 @@ export function ProductModal({ opened, onClose, currentProduct, handleSubmit }:
}); });
const productorsSelect = useMemo(() => { const productorsSelect = useMemo(() => {
if (!productors)
return [];
return productors?.map((productor) => ({ return productors?.map((productor) => ({
value: String(productor.id), value: String(productor.id),
label: `${productor.name}`, label: `${productor.name}`,

View File

@@ -13,7 +13,7 @@ import { IconCancel, IconEdit, IconPlus } from "@tabler/icons-react";
import { useForm } from "@mantine/form"; import { useForm } from "@mantine/form";
import { useMemo } from "react"; import { useMemo } from "react";
import { type Shipment, type ShipmentInputs } from "@/services/resources/shipments"; import { type Shipment, type ShipmentInputs } from "@/services/resources/shipments";
import { useGetForms, useGetProductors, useGetProducts } from "@/services/api"; import { useGetReferentForms, useGetProductors, useGetProducts } from "@/services/api";
export type ShipmentModalProps = ModalBaseProps & { export type ShipmentModalProps = ModalBaseProps & {
currentShipment?: Shipment; currentShipment?: Shipment;
@@ -43,11 +43,13 @@ export default function ShipmentModal({
}, },
}); });
const { data: allForms } = useGetForms(); const { data: allForms } = useGetReferentForms();
const { data: allProducts } = useGetProducts(new URLSearchParams("types=1")); const { data: allProducts } = useGetProducts(new URLSearchParams("types=1"));
const { data: allProductors } = useGetProductors(); const { data: allProductors } = useGetProductors();
const formsSelect = useMemo(() => { const formsSelect = useMemo(() => {
if (!allForms)
return [];
return allForms?.map((currentForm) => ({ return allForms?.map((currentForm) => ({
value: String(currentForm.id), value: String(currentForm.id),
label: `${currentForm.name} ${currentForm.season}`, label: `${currentForm.name} ${currentForm.season}`,
@@ -55,7 +57,7 @@ export default function ShipmentModal({
}, [allForms]); }, [allForms]);
const productsSelect = useMemo(() => { const productsSelect = useMemo(() => {
if (!allProducts || !allProductors) return; if (!allProducts || !allProductors) return [];
return allProductors?.map((productor) => { return allProductors?.map((productor) => {
return { return {
group: productor.name, group: productor.name,

View File

@@ -36,6 +36,8 @@ export function UserModal({ opened, onClose, currentUser, handleSubmit }: UserMo
}); });
const roleSelect = useMemo(() => { const roleSelect = useMemo(() => {
if (!allRoles)
return [];
return allRoles?.map((role) => ({ value: String(role.name), label: role.name })); return allRoles?.map((role) => ({ value: String(role.name), label: role.name }));
}, [allRoles]); }, [allRoles]);

View File

@@ -165,7 +165,7 @@ export function Contract() {
); );
return ( return (
<Stack w={{ base: "100%", md: "80%", lg: "50%" }}> <Stack w={{ base: "100%", md: "80%", lg: "50%" }} p={{base: 'xs'}}>
<Title order={2}>{form.name}</Title> <Title order={2}>{form.name}</Title>
<Title order={3}>{t("informations", { capfirst: true })}</Title> <Title order={3}>{t("informations", { capfirst: true })}</Title>
<Text size="sm"> <Text size="sm">
@@ -283,6 +283,10 @@ export function Contract() {
ref={(el) => { ref={(el) => {
inputRefs.current.payment_method = el; inputRefs.current.payment_method = el;
}} }}
comboboxProps={{
withinPortal: false,
position: "bottom-start",
}}
/> />
{inputForm.values.payment_method === "cheque" ? ( {inputForm.values.payment_method === "cheque" ? (
<ContractCheque <ContractCheque
@@ -319,7 +323,7 @@ export function Contract() {
<Button <Button
leftSection={<IconDownload/>} leftSection={<IconDownload/>}
aria-label={t("submit contracts")} onClick={handleSubmit}> aria-label={t("submit contracts")} onClick={handleSubmit}>
{t("submit contract", {capfirst: true})} {t("submit", {capfirst: true})}
</Button> </Button>
</Overlay> </Overlay>
</Stack> </Stack>

View File

@@ -27,6 +27,8 @@ export default function Contracts() {
const { data: allContracts } = useGetContracts(); const { data: allContracts } = useGetContracts();
const forms = useMemo(() => { const forms = useMemo(() => {
if (!allContracts)
return [];
return allContracts return allContracts
?.map((contract: Contract) => contract.form.name) ?.map((contract: Contract) => contract.form.name)
.filter((contract, index, array) => array.indexOf(contract) === index); .filter((contract, index, array) => array.indexOf(contract) === index);
@@ -89,7 +91,7 @@ export default function Contracts() {
label={t("download recap", { capfirst: true })} label={t("download recap", { capfirst: true })}
> >
<ActionIcon <ActionIcon
disabled={true} disabled={false}
onClick={(e) => { onClick={(e) => {
e.stopPropagation(); e.stopPropagation();
navigate( navigate(

View File

@@ -1,5 +1,5 @@
import { Stack, Loader, Title, Group, ActionIcon, Tooltip, Table, ScrollArea } from "@mantine/core"; import { Stack, Loader, Title, Group, ActionIcon, Tooltip, Table, ScrollArea } from "@mantine/core";
import { useCreateForm, useEditForm, useGetForm, useGetForms } from "@/services/api"; import { useCreateForm, useEditForm, useGetForm, useGetReferentForms } from "@/services/api";
import { t } from "@/config/i18n"; import { t } from "@/config/i18n";
import { useLocation, useNavigate, useSearchParams } from "react-router"; import { useLocation, useNavigate, useSearchParams } from "react-router";
import { IconPlus } from "@tabler/icons-react"; import { IconPlus } from "@tabler/icons-react";
@@ -28,12 +28,12 @@ export function Forms() {
navigate(`/dashboard/forms${searchParams ? `?${searchParams.toString()}` : ""}`); navigate(`/dashboard/forms${searchParams ? `?${searchParams.toString()}` : ""}`);
}, [navigate, searchParams]); }, [navigate, searchParams]);
const { isPending, data } = useGetForms(searchParams); const { isPending, data } = useGetReferentForms(searchParams);
const { data: currentForm } = useGetForm(Number(editId), { const { data: currentForm } = useGetForm(Number(editId), {
enabled: !!editId, enabled: !!editId,
}); });
const { data: allForms } = useGetForms(); const { data: allForms } = useGetReferentForms();
const seasons = useMemo(() => { const seasons = useMemo(() => {
return allForms return allForms
@@ -148,6 +148,7 @@ export function Forms() {
<Table striped> <Table striped>
<Table.Thead> <Table.Thead>
<Table.Tr> <Table.Tr>
<Table.Th>{t("visible", { capfirst: true })}</Table.Th>
<Table.Th>{t("name", { capfirst: true })}</Table.Th> <Table.Th>{t("name", { capfirst: true })}</Table.Th>
<Table.Th>{t("type", { capfirst: true })}</Table.Th> <Table.Th>{t("type", { capfirst: true })}</Table.Th>
<Table.Th>{t("start", { capfirst: true })}</Table.Th> <Table.Th>{t("start", { capfirst: true })}</Table.Th>

View File

@@ -40,12 +40,16 @@ export default function Productors() {
}, [navigate, searchParams]); }, [navigate, searchParams]);
const names = useMemo(() => { const names = useMemo(() => {
if (!allProductors)
return [];
return allProductors return allProductors
?.map((productor: Productor) => productor.name) ?.map((productor: Productor) => productor.name)
.filter((season, index, array) => array.indexOf(season) === index); .filter((season, index, array) => array.indexOf(season) === index);
}, [allProductors]); }, [allProductors]);
const types = useMemo(() => { const types = useMemo(() => {
if (!allProductors)
return [];
return allProductors return allProductors
?.map((productor: Productor) => productor.type) ?.map((productor: Productor) => productor.type)
.filter((productor, index, array) => array.indexOf(productor) === index); .filter((productor, index, array) => array.indexOf(productor) === index);

View File

@@ -38,12 +38,16 @@ export default function Products() {
const { data: allProducts } = useGetProducts(); const { data: allProducts } = useGetProducts();
const names = useMemo(() => { const names = useMemo(() => {
if (!allProducts)
return [];
return allProducts return allProducts
?.map((product: Product) => product.name) ?.map((product: Product) => product.name)
.filter((season, index, array) => array.indexOf(season) === index); .filter((season, index, array) => array.indexOf(season) === index);
}, [allProducts]); }, [allProducts]);
const productors = useMemo(() => { const productors = useMemo(() => {
if (!allProducts)
return [];
return allProducts return allProducts
?.map((product: Product) => product.productor.name) ?.map((product: Product) => product.productor.name)
.filter((productor, index, array) => array.indexOf(productor) === index); .filter((productor, index, array) => array.indexOf(productor) === index);

View File

@@ -44,12 +44,16 @@ export default function Shipments() {
const { data: allShipments } = useGetShipments(); const { data: allShipments } = useGetShipments();
const names = useMemo(() => { const names = useMemo(() => {
if (!allShipments)
return [];
return allShipments return allShipments
?.map((shipment: Shipment) => shipment.name) ?.map((shipment: Shipment) => shipment.name)
.filter((season, index, array) => array.indexOf(season) === index); .filter((season, index, array) => array.indexOf(season) === index);
}, [allShipments]); }, [allShipments]);
const forms = useMemo(() => { const forms = useMemo(() => {
if (!allShipments)
return [];
return allShipments return allShipments
?.map((shipment: Shipment) => shipment.form.name) ?.map((shipment: Shipment) => shipment.form.name)
.filter((season, index, array) => array.indexOf(season) === index); .filter((season, index, array) => array.indexOf(season) === index);

View File

@@ -36,6 +36,8 @@ export default function Users() {
const { data: allUsers } = useGetUsers(); const { data: allUsers } = useGetUsers();
const names = useMemo(() => { const names = useMemo(() => {
if (!allUsers)
return [];
return allUsers return allUsers
?.map((user: User) => user.name) ?.map((user: User) => user.name)
.filter((season, index, array) => array.indexOf(season) === index); .filter((season, index, array) => array.indexOf(season) === index);

View File

@@ -330,6 +330,15 @@ export function useGetForms(filters?: URLSearchParams): UseQueryResult<Form[], E
}); });
} }
export function useGetReferentForms(filters?: URLSearchParams): UseQueryResult<Form[], Error> {
const queryString = filters?.toString();
return useQuery<Form[]>({
queryKey: ["forms", queryString],
queryFn: () =>
fetchWithAuth(`${Config.backend_uri}/forms/referents${filters ? `?${queryString}` : ""}`).then((res) => res.json()),
});
}
export function useCreateForm() { export function useCreateForm() {
const queryClient = useQueryClient(); const queryClient = useQueryClient();
@@ -711,7 +720,6 @@ export function useGetContractFile() {
}); });
} }
export function useGetContractFileTemplate() { export function useGetContractFileTemplate() {
return useMutation({ return useMutation({
mutationFn: async (form_id: number) => { mutationFn: async (form_id: number) => {

View File

@@ -12,6 +12,7 @@ export type Form = {
referer: User; referer: User;
shipments: Shipment[]; shipments: Shipment[];
minimum_shipment_value: number | null; minimum_shipment_value: number | null;
visible: boolean;
}; };
export type FormCreate = { export type FormCreate = {
@@ -22,6 +23,7 @@ export type FormCreate = {
productor_id: number; productor_id: number;
referer_id: number; referer_id: number;
minimum_shipment_value: number | null; minimum_shipment_value: number | null;
visible: boolean;
}; };
export type FormEdit = { export type FormEdit = {
@@ -32,6 +34,7 @@ export type FormEdit = {
productor_id?: number | null; productor_id?: number | null;
referer_id?: number | null; referer_id?: number | null;
minimum_shipment_value: number | null; minimum_shipment_value: number | null;
visible: boolean;
}; };
export type FormEditPayload = { export type FormEditPayload = {
@@ -47,4 +50,5 @@ export type FormInputs = {
productor_id: string; productor_id: string;
referer_id: string; referer_id: string;
minimum_shipment_value: number | string | null; minimum_shipment_value: number | string | null;
visible: boolean;
}; };