Compare commits

12 Commits

Author SHA1 Message Date
5e413b11e0 add permission check for form productor and product 2026-03-04 23:36:17 +01:00
Julien Aldon
6679107b13 downgrade python version in tests
All checks were successful
Deploy Amap / deploy (push) Successful in 1m47s
2026-03-03 14:34:00 +01:00
Julien Aldon
20eba7f183 remove debug in router
Some checks failed
Deploy Amap / deploy (push) Failing after 11s
2026-03-03 11:37:40 +01:00
Julien Aldon
c6d75831c9 add debug for tests
Some checks failed
Deploy Amap / deploy (push) Failing after 11s
2026-03-03 11:37:23 +01:00
Julien Aldon
b2e2d02818 add debug for tests
Some checks failed
Deploy Amap / deploy (push) Failing after 10s
2026-03-03 11:36:22 +01:00
Julien Aldon
8cb7893aff add debug for tests
Some checks failed
Deploy Amap / deploy (push) Failing after 10s
2026-03-03 11:34:41 +01:00
Julien Aldon
015e09a980 add debug for tests
Some checks failed
Deploy Amap / deploy (push) Failing after 10s
2026-03-03 11:31:16 +01:00
Julien Aldon
a70ab5d3cb add debug for tests
Some checks failed
Deploy Amap / deploy (push) Failing after 10s
2026-03-03 11:29:57 +01:00
Julien Aldon
9d5dbd80cc add debug for tests
Some checks failed
Deploy Amap / deploy (push) Failing after 10s
2026-03-03 11:27:01 +01:00
Julien Aldon
1c6e810ec1 fix python test version
Some checks failed
Deploy Amap / deploy (push) Failing after 1m30s
2026-03-03 11:15:39 +01:00
Julien Aldon
8352097ffb fix pylint errors
Some checks failed
Deploy Amap / deploy (push) Failing after 16s
2026-03-03 11:08:08 +01:00
Julien Aldon
0e48d1bbaa fix test format 2026-03-02 16:33:52 +01:00
61 changed files with 2155 additions and 1108 deletions

View File

@@ -0,0 +1,26 @@
default_language_version:
python: python3.13
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-added-large-files
- id: trailing-whitespace
- id: check-ast
- id: check-builtin-literals
- id: check-docstring-first
- id: check-yaml
- id: check-toml
- id: mixed-line-ending
- id: end-of-file-fixer
- repo: local
hooks:
- id: check-pylint
name: check-pylint
entry: pylint -d R0801,R0903,W0511,W0603,C0103,R0902
language: system
types: [python]
pass_filenames: false
args:
- backend

View File

@@ -35,6 +35,12 @@ hatch run pytest
hatch run pytest --cov=src -vv
```
## Autoformat
```console
find -type f -name '*.py' ! -path 'alembic/*' -exec autopep8 --in-place --aggressive --aggressive '{}' \;
pylint -d R0801,R0903,W0511,W0603,C0103,R0902 .
```
## License
`backend` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license.

View File

@@ -25,7 +25,8 @@ target_metadata = SQLModel.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
config.set_main_option("sqlalchemy.url", f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}')
config.set_main_option(
"sqlalchemy.url", f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}')
# ... etc.

View File

@@ -22,7 +22,12 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('paymentmethod', sa.Column('max', sa.Integer(), nullable=True))
op.add_column(
'paymentmethod',
sa.Column(
'max',
sa.Integer(),
nullable=True))
# ### end Alembic commands ###

View File

@@ -22,117 +22,121 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('contracttype',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_table('productor',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('address', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_table(
'contracttype',
sa.Column(
'id',
sa.Integer(),
nullable=False),
sa.Column(
'name',
sqlmodel.sql.sqltypes.AutoString(),
nullable=False),
sa.PrimaryKeyConstraint('id'))
op.create_table(
'productor', sa.Column(
'name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
'address', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
'type', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
'id', sa.Integer(), nullable=False), sa.PrimaryKeyConstraint('id'))
op.create_table('template',
sa.Column('id', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_table('user',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
sa.Column('id', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_table(
'user', sa.Column(
'name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
'email', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column(
'id', sa.Integer(), nullable=False), sa.PrimaryKeyConstraint('id'))
op.create_table('form',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('productor_id', sa.Integer(), nullable=True),
sa.Column('referer_id', sa.Integer(), nullable=True),
sa.Column('season', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('start', sa.Date(), nullable=False),
sa.Column('end', sa.Date(), nullable=False),
sa.Column('minimum_shipment_value', sa.Float(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ),
sa.ForeignKeyConstraint(['referer_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('productor_id', sa.Integer(), nullable=True),
sa.Column('referer_id', sa.Integer(), nullable=True),
sa.Column('season', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('start', sa.Date(), nullable=False),
sa.Column('end', sa.Date(), nullable=False),
sa.Column('minimum_shipment_value', sa.Float(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ),
sa.ForeignKeyConstraint(['referer_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('paymentmethod',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('details', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('productor_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('details', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('productor_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('product',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('unit', sa.Enum('GRAMS', 'KILO', 'PIECE', name='unit'), nullable=False),
sa.Column('price', sa.Float(), nullable=True),
sa.Column('price_kg', sa.Float(), nullable=True),
sa.Column('quantity', sa.Float(), nullable=True),
sa.Column('quantity_unit', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('type', sa.Enum('OCCASIONAL', 'RECCURENT', name='producttype'), nullable=False),
sa.Column('productor_id', sa.Integer(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('usercontracttypelink',
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('contract_type_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['contract_type_id'], ['contracttype.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('user_id', 'contract_type_id')
)
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('unit', sa.Enum('GRAMS', 'KILO', 'PIECE', name='unit'), nullable=False),
sa.Column('price', sa.Float(), nullable=True),
sa.Column('price_kg', sa.Float(), nullable=True),
sa.Column('quantity', sa.Float(), nullable=True),
sa.Column('quantity_unit', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('type', sa.Enum('OCCASIONAL', 'RECCURENT', name='producttype'), nullable=False),
sa.Column('productor_id', sa.Integer(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table(
'usercontracttypelink', sa.Column(
'user_id', sa.Integer(), nullable=False), sa.Column(
'contract_type_id', sa.Integer(), nullable=False), sa.ForeignKeyConstraint(
['contract_type_id'], ['contracttype.id'], ), sa.ForeignKeyConstraint(
['user_id'], ['user.id'], ), sa.PrimaryKeyConstraint(
'user_id', 'contract_type_id'))
op.create_table('contract',
sa.Column('firstname', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('lastname', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('phone', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('payment_method', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('cheque_quantity', sa.Integer(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('form_id', sa.Integer(), nullable=False),
sa.Column('file', sa.LargeBinary(), nullable=True),
sa.Column('total_price', sa.Float(), nullable=True),
sa.ForeignKeyConstraint(['form_id'], ['form.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
sa.Column('firstname', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('lastname', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('phone', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('payment_method', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('cheque_quantity', sa.Integer(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('form_id', sa.Integer(), nullable=False),
sa.Column('file', sa.LargeBinary(), nullable=True),
sa.Column('total_price', sa.Float(), nullable=True),
sa.ForeignKeyConstraint(['form_id'], ['form.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('shipment',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('date', sa.Date(), nullable=False),
sa.Column('form_id', sa.Integer(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['form_id'], ['form.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('date', sa.Date(), nullable=False),
sa.Column('form_id', sa.Integer(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['form_id'], ['form.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('cheque',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('value', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('contract_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['contract_id'], ['contract.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('value', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('contract_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['contract_id'], ['contract.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('contractproduct',
sa.Column('product_id', sa.Integer(), nullable=False),
sa.Column('shipment_id', sa.Integer(), nullable=True),
sa.Column('quantity', sa.Float(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('contract_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['contract_id'], ['contract.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['product_id'], ['product.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['shipment_id'], ['shipment.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
sa.Column('product_id', sa.Integer(), nullable=False),
sa.Column('shipment_id', sa.Integer(), nullable=True),
sa.Column('quantity', sa.Float(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('contract_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['contract_id'], ['contract.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['product_id'], ['product.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['shipment_id'], ['shipment.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('shipmentproductlink',
sa.Column('shipment_id', sa.Integer(), nullable=False),
sa.Column('product_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['product_id'], ['product.id'], ),
sa.ForeignKeyConstraint(['shipment_id'], ['shipment.id'], ),
sa.PrimaryKeyConstraint('shipment_id', 'product_id')
)
sa.Column('shipment_id', sa.Integer(), nullable=False),
sa.Column('product_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['product_id'], ['product.id'], ),
sa.ForeignKeyConstraint(['shipment_id'], ['shipment.id'], ),
sa.PrimaryKeyConstraint('shipment_id', 'product_id')
)
# ### end Alembic commands ###

View File

@@ -22,7 +22,14 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('form', sa.Column('visible', sa.Boolean(), nullable=False, default=False, server_default="False"))
op.add_column(
'form',
sa.Column(
'visible',
sa.Boolean(),
nullable=False,
default=False,
server_default="False"))
# ### end Alembic commands ###

View File

@@ -34,6 +34,9 @@ dependencies = [
"pytest",
"pytest-cov",
"pytest-mock",
"autopep8",
"prek",
"pylint",
]
[project.urls]

View File

@@ -0,0 +1,84 @@
alembic==1.18.4
annotated-doc==0.0.4
annotated-types==0.7.0
anyio==4.12.1
astroid==4.0.4
autopep8==2.3.2
brotli==1.2.0
certifi==2026.2.25
cffi==2.0.0
charset-normalizer==3.4.4
click==8.3.1
coverage==7.13.4
cryptography==46.0.5
cssselect2==0.9.0
dill==0.4.1
dnspython==2.8.0
email-validator==2.3.0
fastapi==0.135.1
fastapi-cli==0.0.24
fastapi-cloud-cli==0.14.0
fastar==0.8.0
fonttools==4.61.1
greenlet==3.3.2
h11==0.16.0
httpcore==1.0.9
httptools==0.7.1
httpx==0.28.1
idna==3.11
iniconfig==2.3.0
isort==8.0.1
Jinja2==3.1.6
lxml==6.0.2
Mako==1.3.10
markdown-it-py==4.0.0
MarkupSafe==3.0.3
mccabe==0.7.0
mdurl==0.1.2
odfdo==3.21.0
packaging==26.0
pillow==12.1.1
platformdirs==4.9.2
pluggy==1.6.0
prek==0.3.4
psycopg2-binary==2.9.11
pycodestyle==2.14.0
pycparser==3.0
pydantic==2.12.5
pydantic-extra-types==2.11.0
pydantic-settings==2.13.1
pydantic_core==2.41.5
pydyf==0.12.1
Pygments==2.19.2
PyJWT==2.11.0
pylint==4.0.5
pyphen==0.17.2
pytest==9.0.2
pytest-cov==7.0.0
pytest-mock==3.15.1
python-dotenv==1.2.2
python-multipart==0.0.22
PyYAML==6.0.3
requests==2.32.5
rich==14.3.3
rich-toolkit==0.19.7
rignore==0.7.6
sentry-sdk==2.53.0
shellingham==1.5.4
SQLAlchemy==2.0.47
sqlmodel==0.0.37
starlette==0.52.1
tinycss2==1.5.1
tinyhtml5==2.0.0
tomlkit==0.14.0
typer==0.24.1
typing-inspection==0.4.2
typing_extensions==4.15.0
urllib3==2.6.3
uvicorn==0.41.0
uvloop==0.22.1
watchfiles==1.1.1
weasyprint==68.1
webencodings==0.5.1
websockets==16.0
zopfli==0.4.1

View File

@@ -1,26 +1,28 @@
from typing import Annotated
from fastapi import APIRouter, Security, HTTPException, Depends, Request, Cookie
from fastapi.responses import RedirectResponse, Response
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlmodel import Session, select
import jwt
from jwt import PyJWKClient
from src.settings import AUTH_URL, TOKEN_URL, JWKS_URL, ISSUER, LOGOUT_URL, settings
import src.users.service as service
from src.database import get_session
from src.models import UserCreate, User, UserPublic
import secrets
import requests
from typing import Annotated
from urllib.parse import urlencode
import jwt
import requests
import src.messages as messages
import src.users.service as service
from fastapi import (APIRouter, Cookie, Depends, HTTPException, Request,
Security)
from fastapi.responses import RedirectResponse, Response
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jwt import PyJWKClient
from sqlmodel import Session, select
from src.database import get_session
from src.models import User, UserCreate, UserPublic
from src.settings import (AUTH_URL, ISSUER, JWKS_URL, LOGOUT_URL, TOKEN_URL,
settings)
router = APIRouter(prefix='/auth')
jwk_client = PyJWKClient(JWKS_URL)
security = HTTPBearer()
@router.get('/logout')
def logout():
params = {
@@ -59,9 +61,11 @@ def login():
'redirect_uri': settings.keycloak_redirect_uri,
'state': state,
}
request_url = requests.Request('GET', AUTH_URL, params=params).prepare().url
request_url = requests.Request(
'GET', AUTH_URL, params=params).prepare().url
return RedirectResponse(request_url)
@router.get('/callback')
def callback(code: str, session: Session = Depends(get_session)):
data = {
@@ -85,7 +89,9 @@ def callback(code: str, session: Session = Depends(get_session)):
id_token = token_data['id_token']
decoded_token = jwt.decode(id_token, options={'verify_signature': False})
decoded_access_token = jwt.decode(token_data['access_token'], options={'verify_signature': False})
decoded_access_token = jwt.decode(
token_data['access_token'], options={
'verify_signature': False})
resource_access = decoded_access_token.get('resource_access')
if not resource_access:
data = {
@@ -93,7 +99,7 @@ def callback(code: str, session: Session = Depends(get_session)):
'client_secret': settings.keycloak_client_secret,
'refresh_token': token_data['refresh_token'],
}
res = requests.post(LOGOUT_URL, data=data)
requests.post(LOGOUT_URL, data=data)
resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true')
return resp
roles = resource_access.get(settings.keycloak_client_id)
@@ -103,7 +109,7 @@ def callback(code: str, session: Session = Depends(get_session)):
'client_secret': settings.keycloak_client_secret,
'refresh_token': token_data['refresh_token'],
}
res = requests.post(LOGOUT_URL, data=data)
requests.post(LOGOUT_URL, data=data)
resp = RedirectResponse(f'{settings.origins}?userNotAllowed=true')
return resp
@@ -141,6 +147,7 @@ def callback(code: str, session: Session = Depends(get_session)):
return response
def verify_token(token: str):
try:
signing_key = jwk_client.get_signing_key_from_jwt(token)
@@ -154,28 +161,49 @@ def verify_token(token: str):
)
return decoded
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail=messages.Messages.tokenexipired)
raise HTTPException(
status_code=401,
detail=messages.Messages.tokenexipired
)
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail=messages.Messages.invalidtoken)
raise HTTPException(
status_code=401,
detail=messages.Messages.invalidtoken
)
def get_current_user(request: Request, session: Session = Depends(get_session)):
def get_current_user(
request: Request,
session: Session = Depends(get_session)):
access_token = request.cookies.get('access_token')
if not access_token:
raise HTTPException(status_code=401, detail=messages.Messages.notauthenticated)
raise HTTPException(
status_code=401,
detail=messages.Messages.notauthenticated
)
payload = verify_token(access_token)
if not payload:
raise HTTPException(status_code=401, detail='aze')
raise HTTPException(
status_code=401,
detail='aze'
)
email = payload.get('email')
if not email:
raise HTTPException(status_code=401, detail=messages.Messages.notauthenticated)
raise HTTPException(
status_code=401,
detail=messages.Messages.notauthenticated
)
user = session.exec(select(User).where(User.email == email)).first()
if not user:
raise HTTPException(status_code=401, detail=messages.Messages.not_found('user'))
raise HTTPException(
status_code=401,
detail=messages.Messages.not_found('user')
)
return user
@router.post('/refresh')
def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
refresh = refresh_token
@@ -223,6 +251,7 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None):
)
return response
@router.get('/user/me')
def me(user: UserPublic = Depends(get_current_user)):
if not user:
@@ -233,6 +262,6 @@ def me(user: UserPublic = Depends(get_current_user)):
'name': user.name,
'email': user.email,
'id': user.id,
'roles': [role.name for role in user.roles]
'roles': user.roles
}
}

View File

@@ -1,18 +1,27 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from src.database import get_session
from sqlmodel import Session
from src.contracts.generate_contract import generate_html_contract, generate_recap
from src.auth.auth import get_current_user
import src.models as models
import src.messages as messages
import src.contracts.service as service
import src.forms.service as form_service
"""Router for contract resource"""
import io
import zipfile
import src.contracts.service as service
import src.forms.service as form_service
import src.messages as messages
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user
from src.contracts.generate_contract import (generate_html_contract,
generate_recap)
from src.database import get_session
router = APIRouter(prefix='/contracts')
def compute_recurrent_prices(products_quantities: list[dict], nb_shipment: int):
def compute_recurrent_prices(
products_quantities: list[dict],
nb_shipment: int
):
"""Compute price for recurrent products"""
result = 0
for product_quantity in products_quantities:
product = product_quantity['product']
@@ -20,25 +29,45 @@ def compute_recurrent_prices(products_quantities: list[dict], nb_shipment: int):
result += compute_product_price(product, quantity, nb_shipment)
return result
def compute_occasional_prices(occasionals: list[dict]):
"""Compute prices for occassional products"""
result = 0
for occasional in occasionals:
result += occasional['price']
return result
def compute_product_price(product: models.Product, quantity: int, nb_shipment: int = 1):
product_quantity_unit = 1 if product.unit == models.Unit.KILO else 1000
final_quantity = quantity if product.price else quantity / product_quantity_unit
final_price = product.price if product.price else product.price_kg
def compute_product_price(
product: models.Product,
quantity: int,
nb_shipment: int = 1
):
"""Compute price for a product"""
product_quantity_unit = (
1 if product.unit == models.Unit.KILO else 1000
)
final_quantity = (
quantity if product.price else quantity / product_quantity_unit
)
final_price = (
product.price if product.price else product.price_kg
)
return final_price * final_quantity * nb_shipment
def find_dict_in_list(lst, key, value):
"""Find the index of a dictionnary in a list of dictionnaries given a key
and a value.
"""
for i, dic in enumerate(lst):
if dic[key].id == value:
return i
return -1
def create_occasional_dict(contract_products: list[models.ContractProduct]):
"""Create a dictionnary of occasional products"""
result = []
for contract_product in contract_products:
existing_id = find_dict_in_list(
@@ -69,18 +98,46 @@ def create_occasional_dict(contract_products: list[models.ContractProduct]):
)
return result
@router.post('')
async def create_contract(
contract: models.ContractCreate,
session: Session = Depends(get_session),
):
"""Create contract route"""
new_contract = service.create_one(session, contract)
occasional_contract_products = list(filter(lambda contract_product: contract_product.product.type == models.ProductType.OCCASIONAL, new_contract.products))
occasional_contract_products = list(
filter(
lambda contract_product: (
contract_product.product.type == models.ProductType.OCCASIONAL
),
new_contract.products
)
)
occasionals = create_occasional_dict(occasional_contract_products)
recurrents = list(map(lambda x: {"product": x.product, "quantity": x.quantity}, filter(lambda contract_product: contract_product.product.type == models.ProductType.RECCURENT, new_contract.products)))
recurrent_price = compute_recurrent_prices(recurrents, len(new_contract.form.shipments))
recurrents = list(
map(
lambda x: {'product': x.product, 'quantity': x.quantity},
filter(
lambda contract_product: (
contract_product.product.type ==
models.ProductType.RECCURENT
),
new_contract.products
)
)
)
recurrent_price = compute_recurrent_prices(
recurrents,
len(new_contract.form.shipments)
)
price = recurrent_price + compute_occasional_prices(occasionals)
cheques = list(map(lambda x: {"name": x.name, "value": x.value}, new_contract.cheques))
cheques = list(
map(
lambda x: {'name': x.name, 'value': x.value},
new_contract.cheques
)
)
try:
pdf_bytes = generate_html_contract(
new_contract,
@@ -91,43 +148,63 @@ async def create_contract(
'{:10.2f}'.format(price)
)
pdf_file = io.BytesIO(pdf_bytes)
contract_id = f'{new_contract.firstname}_{new_contract.lastname}_{new_contract.form.productor.type}_{new_contract.form.season}'
contract_id = (
f'{new_contract.firstname}_'
f'{new_contract.lastname}_'
f'{new_contract.form.productor.type}_'
f'{new_contract.form.season}'
)
service.add_contract_file(session, new_contract.id, pdf_bytes, price)
except Exception:
raise HTTPException(status_code=400, detail=messages.pdferror)
except Exception as error:
raise HTTPException(
status_code=400,
detail=messages.pdferror
) from error
return StreamingResponse(
pdf_file,
media_type='application/pdf',
headers={
'Content-Disposition': f'attachment; filename=contract_{contract_id}.pdf'
'Content-Disposition': (
f'attachment; filename=contract_{contract_id}.pdf'
)
}
)
@router.get('/{form_id}/base')
async def get_base_contract_template(
form_id: int,
session: Session = Depends(get_session),
):
"""Get contract template route"""
form = form_service.get_one(session, form_id)
recurrents = list(map(lambda x: {"product": x, "quantity": None}, filter(lambda product: product.type == models.ProductType.RECCURENT, form.productor.products)))
recurrents = [
{'product': product, 'quantity': None}
for product in form.productor.products
if product.type == models.ProductType.RECCURENT
]
occasionals = [{
'shipment': sh,
'price': None,
'products': [{'product': pr, 'quantity': None} for pr in sh.products]
} for sh in form.shipments]
empty_contract = models.ContractPublic(
firstname="",
firstname='',
form=form,
lastname="",
email="",
phone="",
lastname='',
email='',
phone='',
products=[],
payment_method="cheque",
payment_method='cheque',
cheque_quantity=3,
total_price=0,
id=1
)
cheques = [{"name": None, "value": None}, {"name": None, "value": None}, {"name": None, "value": None}]
cheques = [
{'name': None, 'value': None},
{'name': None, 'value': None},
{'name': None, 'value': None}
]
try:
pdf_bytes = generate_html_contract(
empty_contract,
@@ -136,38 +213,60 @@ async def get_base_contract_template(
recurrents,
)
pdf_file = io.BytesIO(pdf_bytes)
contract_id = f'{empty_contract.form.productor.type}_{empty_contract.form.season}'
except Exception as e:
print(e)
raise HTTPException(status_code=400, detail=messages.pdferror)
contract_id = (
f'{empty_contract.form.productor.type}_'
f'{empty_contract.form.season}'
)
except Exception as error:
raise HTTPException(
status_code=400,
detail=messages.pdferror
) from error
return StreamingResponse(
pdf_file,
media_type='application/pdf',
headers={
'Content-Disposition': f'attachment; filename=contract_{contract_id}.pdf'
'Content-Disposition': (
f'attachment; filename=contract_{contract_id}.pdf'
)
}
)
@router.get('', response_model=list[models.ContractPublic])
def get_contracts(
forms: list[str] = Query([]),
session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
"""Get all contracts route"""
return service.get_all(session, user, forms)
@router.get('/{id}/file')
@router.get('/{_id}/file')
def get_contract_file(
id: int,
_id: int,
session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
if not service.is_allowed(session, user, id):
raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contract', 'get'))
contract = service.get_one(session, id)
"""Get a contract file (in pdf) route"""
if not service.is_allowed(session, user, _id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract', 'get')
)
contract = service.get_one(session, _id)
if contract is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('contract'))
filename = f'{contract.form.name.replace(' ', '_')}_{contract.form.season}_{contract.firstname}-{contract.lastname}'
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('contract')
)
filename = (
f'{contract.form.name.replace(' ', '_')}_'
f'{contract.form.season}_'
f'{contract.firstname}_'
f'{contract.lastname}'
)
return StreamingResponse(
io.BytesIO(contract.file),
media_type='application/pdf',
@@ -176,23 +275,37 @@ def get_contract_file(
}
)
@router.get('/{form_id}/files')
def get_contract_files(
form_id: int,
session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
"""Get all contract files for a given form"""
if not form_service.is_allowed(session, user, form_id):
raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contracts', 'get'))
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contracts', 'get')
)
form = form_service.get_one(session, form_id=form_id)
contracts = service.get_all(session, user, forms=[form.name])
zipped_contracts = io.BytesIO()
with zipfile.ZipFile(zipped_contracts, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
with zipfile.ZipFile(
zipped_contracts,
'a',
zipfile.ZIP_DEFLATED,
False
) as zip_file:
for contract in contracts:
contract_filename = f'{contract.form.name.replace(' ', '_')}_{contract.form.season}_{contract.firstname}-{contract.lastname}.pdf'
contract_filename = (
f'{contract.form.name.replace(' ', '_')}_'
f'{contract.form.season}_'
f'{contract.firstname}_'
f'{contract.lastname}'
)
zip_file.writestr(contract_filename, contract.file)
filename = f'{form.name.replace(" ", "_")}_{form.season}'
filename = f'{form.name.replace(' ', '_')}_{form.season}'
return StreamingResponse(
io.BytesIO(zipped_contracts.getvalue()),
media_type='application/zip',
@@ -201,39 +314,69 @@ def get_contract_files(
}
)
@router.get('/{form_id}/recap')
def get_contract_recap(
form_id: int,
session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
"""Get a contract recap for a given form"""
if not form_service.is_allowed(session, user, form_id):
raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contract recap', 'get'))
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract recap', 'get')
)
form = form_service.get_one(session, form_id=form_id)
contracts = service.get_all(session, user, forms=[form.name])
return StreamingResponse(
io.BytesIO(generate_recap(contracts, form)),
media_type='application/zip',
headers={
'Content-Disposition': f'attachment; filename=filename.ods'
'Content-Disposition': (
'attachment; filename=filename.ods'
)
}
)
@router.get('/{id}', response_model=models.ContractPublic)
def get_contract(id: int, session: Session = Depends(get_session), user: models.User = Depends(get_current_user)):
if not service.is_allowed(session, user, id):
raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contract', 'get'))
result = service.get_one(session, id)
@router.get('/{_id}', response_model=models.ContractPublic)
def get_contract(
_id: int,
session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
"""Get a contract route"""
if not service.is_allowed(session, user, _id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract', 'get')
)
result = service.get_one(session, _id)
if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('contract'))
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('contract')
)
return result
@router.delete('/{id}', response_model=models.ContractPublic)
def delete_contract(id: int, session: Session = Depends(get_session), user: models.User = Depends(get_current_user)):
if not service.is_allowed(session, user, id):
raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contract', 'delete'))
result = service.delete_one(session, id)
@router.delete('/{_id}', response_model=models.ContractPublic)
def delete_contract(
_id: int,
session: Session = Depends(get_session),
user: models.User = Depends(get_current_user)
):
"""Delete contract route"""
if not service.is_allowed(session, user, _id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('contract', 'delete')
)
result = service.delete_one(session, _id)
if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('contract'))
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('contract')
)
return result

View File

@@ -1,11 +1,13 @@
import html
import io
import pathlib
import jinja2
import src.models as models
import html
from odfdo import Cell, Document, Row, Table
from src import models
from weasyprint import HTML
import io
import pathlib
def generate_html_contract(
contract: models.Contract,
@@ -17,7 +19,8 @@ def generate_html_contract(
):
template_dir = pathlib.Path("./src/contracts/templates").resolve()
template_loader = jinja2.FileSystemLoader(searchpath=template_dir)
template_env = jinja2.Environment(loader=template_loader, autoescape=jinja2.select_autoescape(["html", "xml"]))
template_env = jinja2.Environment(
loader=template_loader, autoescape=jinja2.select_autoescape(["html", "xml"]))
template_file = "layout.html"
template = template_env.get_template(template_file)
output_text = template.render(
@@ -28,41 +31,36 @@ def generate_html_contract(
referer_email=contract.form.referer.email,
productor_name=contract.form.productor.name,
productor_address=contract.form.productor.address,
payment_methods_map={"cheque": "Ordre du chèque", "transfer": "virements"},
payment_methods_map={
"cheque": "Ordre du chèque",
"transfer": "virements"},
productor_payment_methods=contract.form.productor.payment_methods,
member_name=f'{html.escape(contract.firstname)} {html.escape(contract.lastname)}',
member_email=html.escape(contract.email),
member_phone=html.escape(contract.phone),
member_name=f'{
html.escape(
contract.firstname)} {
html.escape(
contract.lastname)}',
member_email=html.escape(
contract.email),
member_phone=html.escape(
contract.phone),
contract_start_date=contract.form.start,
contract_end_date=contract.form.end,
occasionals=occasionals,
recurrents=reccurents,
recurrent_price=recurrent_price,
total_price=total_price,
contract_payment_method={"cheque": "chèque", "transfer": "virements"}[contract.payment_method],
cheques=cheques
)
# options = {
# 'page-size': 'Letter',
# 'margin-top': '0.5in',
# 'margin-right': '0.5in',
# 'margin-bottom': '0.5in',
# 'margin-left': '0.5in',
# 'encoding': "UTF-8",
# 'print-media-type': True,
# "disable-javascript": True,
# "disable-external-links": True,
# 'enable-local-file-access': False,
# "disable-local-file-access": True,
# "no-images": True,
# }
contract_payment_method={
"cheque": "chèque",
"transfer": "virements"}[
contract.payment_method],
cheques=cheques)
return HTML(
string=output_text,
base_url=template_dir,
).write_pdf()
from odfdo import Document, Table, Row, Cell
def generate_recap(
contracts: list[models.Contract],
@@ -81,4 +79,3 @@ def generate_recap(
doc.save(buffer)
return buffer.getvalue()

View File

@@ -1,28 +1,57 @@
"""Contract service responsible for read, create, update and delete contracts"""
from sqlalchemy.orm import selectinload
from sqlmodel import Session, select
import src.models as models
from src import models
def get_all(
session: Session,
user: models.User,
forms: list[str] = [],
forms: list[str] | None = None,
form_id: int | None = None,
) -> list[models.ContractPublic]:
statement = select(models.Contract)\
.join(models.Form, models.Contract.form_id == models.Form.id)\
.join(models.Productor, models.Form.productor_id == models.Productor.id)\
.where(models.Productor.type.in_([r.name for r in user.roles]))\
"""Get all contracts"""
statement = (
select(models.Contract)
.join(
models.Form,
models.Contract.form_id == models.Form.id
)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id
)
.where(
models.Productor.type.in_(
[r.name for r in user.roles]
)
)
.distinct()
if len(forms) > 0:
)
if forms:
statement = statement.where(models.Form.name.in_(forms))
if form_id:
statement = statement.where(models.Form.id == form_id)
return session.exec(statement.order_by(models.Contract.id)).all()
def get_one(session: Session, contract_id: int) -> models.ContractPublic:
def get_one(
session: Session,
contract_id: int
) -> models.ContractPublic:
"""Get one contract"""
return session.get(models.Contract, contract_id)
def create_one(session: Session, contract: models.ContractCreate) -> models.ContractPublic:
contract_create = contract.model_dump(exclude_unset=True, exclude=["products", "cheques"])
def create_one(
session: Session,
contract: models.ContractCreate
) -> models.ContractPublic:
"""Create one contract"""
contract_create = contract.model_dump(
exclude_unset=True,
exclude=["products", "cheques"]
)
new_contract = models.Contract(**contract_create)
new_contract.cheques = [
@@ -45,10 +74,27 @@ def create_one(session: Session, contract: models.ContractCreate) -> models.Cont
session.add(new_contract)
session.commit()
session.refresh(new_contract)
return new_contract
def add_contract_file(session: Session, id: int, file: bytes, price: float):
statement = select(models.Contract).where(models.Contract.id == id)
statement = (
select(models.Contract)
.where(models.Contract.id == new_contract.id)
.options(
selectinload(models.Contract.form)
.selectinload(models.Form.productor)
)
)
return session.exec(statement).one()
def add_contract_file(
session: Session,
_id: int,
file: bytes,
price: float
):
"""Add a file to an existing contract"""
statement = select(models.Contract).where(models.Contract.id == _id)
result = session.exec(statement)
contract = result.first()
contract.total_price = price
@@ -58,8 +104,14 @@ def add_contract_file(session: Session, id: int, file: bytes, price: float):
session.refresh(contract)
return contract
def update_one(session: Session, id: int, contract: models.ContractUpdate) -> models.ContractPublic:
statement = select(models.Contract).where(models.Contract.id == id)
def update_one(
session: Session,
_id: int,
contract: models.ContractUpdate
) -> models.ContractPublic:
"""Update one contract"""
statement = select(models.Contract).where(models.Contract.id == _id)
result = session.exec(statement)
new_contract = result.first()
if not new_contract:
@@ -72,8 +124,13 @@ def update_one(session: Session, id: int, contract: models.ContractUpdate) -> mo
session.refresh(new_contract)
return new_contract
def delete_one(session: Session, id: int) -> models.ContractPublic:
statement = select(models.Contract).where(models.Contract.id == id)
def delete_one(
session: Session,
_id: int
) -> models.ContractPublic:
"""Delete one contract"""
statement = select(models.Contract).where(models.Contract.id == _id)
result = session.exec(statement)
contract = result.first()
if not contract:
@@ -83,11 +140,29 @@ def delete_one(session: Session, id: int) -> models.ContractPublic:
session.commit()
return result
def is_allowed(session: Session, user: models.User, id: int) -> bool:
statement = select(models.Contract)\
.join(models.Form, models.Contract.form_id == models.Form.id)\
.join(models.Productor, models.Form.productor_id == models.Productor.id)\
.where(models.Contract.id == id)\
.where(models.Productor.type.in_([r.name for r in user.roles]))\
def is_allowed(
session: Session,
user: models.User,
_id: int
) -> bool:
"""Determine if a user is allowed to access a contract by id"""
statement = (
select(models.Contract)
.join(
models.Form,
models.Contract.form_id == models.Form.id
)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id
)
.where(models.Contract.id == _id)
.where(
models.Productor.type.in_(
[r.name for r in user.roles]
)
)
.distinct()
)
return len(session.exec(statement).all()) > 0

View File

@@ -1,11 +1,14 @@
from sqlmodel import create_engine, SQLModel, Session
from sqlmodel import Session, SQLModel, create_engine
from src.settings import settings
engine = create_engine(f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}')
engine = create_engine(
f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}')
def get_session():
with Session(engine) as session:
yield session
def create_all_tables():
SQLModel.metadata.create_all(engine)

View File

@@ -1,16 +1,25 @@
"""Forms module exceptions"""
import logging
class FormServiceError(Exception):
"""Form service exception"""
def __init__(self, message: str):
super().__init__(message)
logging.error('FormService : %s', message)
class UserNotFoundError(FormServiceError):
pass
class ProductorNotFoundError(FormServiceError):
pass
class FormNotFoundError(FormServiceError):
pass
class FormCreateError(FormServiceError):
def __init__(self, message: str, field: str | None = None):
super().__init__(message)

View File

@@ -1,14 +1,15 @@
from fastapi import APIRouter, HTTPException, Depends, Query
import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.forms.service as service
import src.forms.exceptions as exceptions
import src.forms.service as service
import src.messages as messages
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/forms')
@router.get('', response_model=list[models.FormPublic])
async def get_forms(
seasons: list[str] = Query([]),
@@ -18,6 +19,7 @@ async def get_forms(
):
return service.get_all(session, seasons, productors, current_season)
@router.get('/referents', response_model=list[models.FormPublic])
async def get_forms_filtered(
seasons: list[str] = Query([]),
@@ -28,53 +30,79 @@ async def get_forms_filtered(
):
return service.get_all(session, seasons, productors, current_season, user)
@router.get('/{id}', response_model=models.FormPublic)
async def get_form(id: int, session: Session = Depends(get_session)):
result = service.get_one(session, id)
@router.get('/{_id}', response_model=models.FormPublic)
async def get_form(
_id: int,
session: Session = Depends(get_session)
):
result = service.get_one(session, _id)
if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('form'))
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('form')
)
return result
@router.post('', response_model=models.FormPublic)
async def create_form(
form: models.FormCreate,
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(session, user, form=form):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('forms', 'update')
)
try:
form = service.create_one(session, form)
except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error))
raise HTTPException(status_code=404, detail=str(error)) from error
except exceptions.UserNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error))
raise HTTPException(status_code=404, detail=str(error)) from error
except exceptions.FormCreateError as error:
raise HTTPException(status_code=400, detail=str(error))
raise HTTPException(status_code=400, detail=str(error)) from error
return form
@router.put('/{id}', response_model=models.FormPublic)
@router.put('/{_id}', response_model=models.FormPublic)
async def update_form(
id: int, form: models.FormUpdate,
_id: int,
form: models.FormUpdate,
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('forms', 'update')
)
try:
result = service.update_one(session, id, form)
result = service.update_one(session, _id, form)
except exceptions.FormNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error))
raise HTTPException(status_code=404, detail=str(error)) from error
except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error))
raise HTTPException(status_code=404, detail=str(error)) from error
except exceptions.UserNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error))
raise HTTPException(status_code=404, detail=str(error)) from error
return result
@router.delete('/{id}', response_model=models.FormPublic)
@router.delete('/{_id}', response_model=models.FormPublic)
async def delete_form(
id: int,
_id: int,
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
if not service.is_allowed(session, user, _id=_id):
raise HTTPException(
status_code=403,
detail=messages.Messages.not_allowed('forms', 'delete')
)
try:
result = service.delete_one(session, id)
result = service.delete_one(session, _id)
except exceptions.FormNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error))
raise HTTPException(status_code=404, detail=str(error)) from error
return result

View File

@@ -1,9 +1,9 @@
from sqlmodel import Session, select
from sqlalchemy import func
import src.models as models
import src.forms.exceptions as exceptions
import src.messages as messages
from sqlalchemy import func
from sqlmodel import Session, select
from src import models
def get_all(
session: Session,
@@ -14,45 +14,54 @@ def get_all(
) -> list[models.FormPublic]:
statement = select(models.Form)
if user:
statement = statement\
.join(models.Productor, models.Form.productor_id == models.Productor.id)\
.where(models.Productor.type.in_([r.name for r in user.roles]))\
.distinct()
statement = statement .join(
models.Productor,
models.Form.productor_id == models.Productor.id) .where(
models.Productor.type.in_(
[
r.name for r in user.roles])) .distinct()
if len(seasons) > 0:
statement = statement.where(models.Form.season.in_(seasons))
if len(productors) > 0:
statement = statement.join(models.Productor).where(models.Productor.name.in_(productors))
statement = statement.join(
models.Productor).where(
models.Productor.name.in_(productors))
if not user:
statement = statement.where(models.Form.visible == True)
statement = statement.where(models.Form.visible)
if current_season:
subquery = (
select(
models.Productor.type,
func.max(models.Form.start).label("max_start")
)
.join(models.Form)\
.group_by(models.Productor.type)\
.join(models.Form)
.group_by(models.Productor.type)
.subquery()
)
statement = select(models.Form)\
.join(models.Productor)\
.join(subquery,
(models.Productor.type == subquery.c.type) &
(models.Form.start == subquery.c.max_start)
)
(models.Productor.type == subquery.c.type) &
(models.Form.start == subquery.c.max_start)
)
if not user:
statement = statement.where(models.Form.visible == True)
statement = statement.where(models.Form.visible)
return session.exec(statement.order_by(models.Form.name)).all()
return session.exec(statement.order_by(models.Form.name)).all()
def get_one(session: Session, form_id: int) -> models.FormPublic:
return session.get(models.Form, form_id)
def create_one(session: Session, form: models.FormCreate) -> models.FormPublic:
if not form:
raise exceptions.FormCreateError(messages.Messages.invalid_input('form', 'input cannot be None'))
raise exceptions.FormCreateError(
messages.Messages.invalid_input(
'form', 'input cannot be None'))
if not session.get(models.Productor, form.productor_id):
raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor'))
raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
if not session.get(models.User, form.referer_id):
raise exceptions.UserNotFoundError(messages.Messages.not_found('user'))
form_create = form.model_dump(exclude_unset=True)
@@ -62,14 +71,20 @@ def create_one(session: Session, form: models.FormCreate) -> models.FormPublic:
session.refresh(new_form)
return new_form
def update_one(session: Session, id: int, form: models.FormUpdate) -> models.FormPublic:
statement = select(models.Form).where(models.Form.id == id)
def update_one(
session: Session,
_id: int,
form: models.FormUpdate) -> models.FormPublic:
statement = select(models.Form).where(models.Form.id == _id)
result = session.exec(statement)
new_form = result.first()
if not new_form:
raise exceptions.FormNotFoundError(messages.Messages.not_found('form'))
if form.productor_id and not session.get(models.Productor, form.productor_id):
raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor'))
if form.productor_id and not session.get(
models.Productor, form.productor_id):
raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
if form.referer_id and not session.get(models.User, form.referer_id):
raise exceptions.UserNotFoundError(messages.Messages.not_found('user'))
form_updates = form.model_dump(exclude_unset=True)
@@ -80,8 +95,9 @@ def update_one(session: Session, id: int, form: models.FormUpdate) -> models.For
session.refresh(new_form)
return new_form
def delete_one(session: Session, id: int) -> models.FormPublic:
statement = select(models.Form).where(models.Form.id == id)
def delete_one(session: Session, _id: int) -> models.FormPublic:
statement = select(models.Form).where(models.Form.id == _id)
result = session.exec(statement)
form = result.first()
if not form:
@@ -91,10 +107,32 @@ def delete_one(session: Session, id: int) -> models.FormPublic:
session.commit()
return result
def is_allowed(session: Session, user: models.User, id: int) -> bool:
statement = select(models.Form)\
.join(models.Productor, models.Form.productor_id == models.Productor.id)\
.where(models.Form.id == id)\
.where(models.Productor.type.in_([r.name for r in user.roles]))\
def is_allowed(
session: Session,
user: models.User,
_id: int = None,
form: models.FormCreate = None
) -> bool:
if not _id:
statement = (
select(models.Productor)
.where(models.Productor.id == form.productor_id)
)
productor = session.exec(statement).first()
return productor.type in [r.name for r in user.roles]
statement = (
select(models.Form)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id
)
.where(models.Form.id == _id)
.where(
models.Productor.type.in_(
[r.name for r in user.roles]
)
)
.distinct()
)
return len(session.exec(statement).all()) > 0

View File

@@ -1,18 +1,15 @@
from sqlmodel import SQLModel
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from src.templates.templates import router as template_router
from src.auth.auth import router as auth_router
from src.contracts.contracts import router as contracts_router
from src.forms.forms import router as forms_router
from src.productors.productors import router as productors_router
from src.products.products import router as products_router
from src.users.users import router as users_router
from src.auth.auth import router as auth_router
from src.shipments.shipments import router as shipment_router
from src.settings import settings
from src.database import engine, create_all_tables
from src.shipments.shipments import router as shipment_router
from src.templates.templates import router as template_router
from src.users.users import router as users_router
app = FastAPI()

View File

@@ -1,5 +1,6 @@
pdferror = 'An error occured during PDF generation please contact administrator'
class Messages:
unauthorized = 'User is Unauthorized'
notauthenticated = 'User is not authenticated'

View File

@@ -1,99 +1,136 @@
from sqlmodel import Field, SQLModel, Relationship, Column, LargeBinary
import datetime
from enum import StrEnum
from typing import Optional
import datetime
from sqlmodel import Column, Field, LargeBinary, Relationship, SQLModel
class ContractType(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
id: int | None = Field(
default=None,
primary_key=True
)
name: str
class UserContractTypeLink(SQLModel, table=True):
user_id: int = Field(foreign_key="user.id", primary_key=True)
contract_type_id: int = Field(foreign_key="contracttype.id", primary_key=True)
user_id: int = Field(
foreign_key='user.id',
primary_key=True
)
contract_type_id: int = Field(
foreign_key='contracttype.id',
primary_key=True
)
class UserBase(SQLModel):
name: str
email: str
class UserPublic(UserBase):
id: int
roles: list[ContractType]
class User(UserBase, table=True):
id: int | None = Field(default=None, primary_key=True)
roles: list[ContractType] = Relationship(
link_model=UserContractTypeLink
)
class UserUpdate(SQLModel):
name: str | None
email: str | None
role_names: list[str] | None
class UserCreate(UserBase):
role_names: list[str] | None
class PaymentMethodBase(SQLModel):
name: str
details: str
max: int | None
class PaymentMethod(PaymentMethodBase, table=True):
id: int | None = Field(default=None, primary_key=True)
productor_id: int = Field(foreign_key="productor.id", ondelete="CASCADE")
productor: Optional["Productor"] = Relationship(
back_populates="payment_methods",
productor_id: int = Field(foreign_key='productor.id', ondelete='CASCADE')
productor: Optional['Productor'] = Relationship(
back_populates='payment_methods',
)
class PaymentMethodPublic(PaymentMethodBase):
id: int
productor: Optional["Productor"]
productor: Optional['Productor']
class ProductorBase(SQLModel):
name: str
address: str
type: str
class ProductorPublic(ProductorBase):
id: int
products: list["Product"] = []
payment_methods: list["PaymentMethod"] = []
products: list['Product'] = Field(default_factory=list)
payment_methods: list['PaymentMethod'] = Field(default_factory=list)
class Productor(ProductorBase, table=True):
id: int | None = Field(default=None, primary_key=True)
products: list["Product"] = Relationship(
products: list['Product'] = Relationship(
back_populates='productor',
sa_relationship_kwargs={
"order_by": "Product.name"
'order_by': 'Product.name'
},
)
payment_methods: list["PaymentMethod"] = Relationship(
back_populates="productor",
payment_methods: list['PaymentMethod'] = Relationship(
back_populates='productor',
cascade_delete=True
)
class ProductorUpdate(SQLModel):
name: str | None
address: str | None
payment_methods: list["PaymentMethod"] = []
payment_methods: list['PaymentMethod'] = Field(default_factory=list)
type: str | None
class ProductorCreate(ProductorBase):
payment_methods: list["PaymentMethod"] = []
payment_methods: list['PaymentMethod'] = Field(default_factory=list)
class Unit(StrEnum):
GRAMS = "1"
KILO = "2"
PIECE = "3"
GRAMS = '1'
KILO = '2'
PIECE = '3'
class ProductType(StrEnum):
OCCASIONAL = "1"
RECCURENT = "2"
OCCASIONAL = '1'
RECCURENT = '2'
class ShipmentProductLink(SQLModel, table=True):
shipment_id: Optional[int] = Field(default=None, foreign_key="shipment.id", primary_key=True)
product_id: Optional[int] = Field(default=None, foreign_key="product.id", primary_key=True)
shipment_id: Optional[int] = Field(
default=None,
foreign_key='shipment.id',
primary_key=True
)
product_id: Optional[int] = Field(
default=None,
foreign_key='product.id',
primary_key=True
)
class ProductBase(SQLModel):
name: str
@@ -103,17 +140,31 @@ class ProductBase(SQLModel):
quantity: float | None
quantity_unit: str | None
type: ProductType
productor_id: int | None = Field(default=None, foreign_key="productor.id")
productor_id: int | None = Field(
default=None,
foreign_key='productor.id'
)
class ProductPublic(ProductBase):
id: int
productor: Productor | None
shipments: list["Shipment"] | None
shipments: list['Shipment'] | None
class Product(ProductBase, table=True):
id: int | None = Field(default=None, primary_key=True)
shipments: list["Shipment"] = Relationship(back_populates="products", link_model=ShipmentProductLink)
productor: Optional[Productor] = Relationship(back_populates="products")
id: int | None = Field(
default=None,
primary_key=True
)
shipments: list['Shipment'] = Relationship(
back_populates='products',
link_model=ShipmentProductLink
)
productor: Optional[Productor] = Relationship(
back_populates='products'
)
class ProductUpdate(SQLModel):
name: str | None
@@ -125,41 +176,46 @@ class ProductUpdate(SQLModel):
productor_id: int | None
type: ProductType | None
class ProductCreate(ProductBase):
pass
class FormBase(SQLModel):
name: str
productor_id: int | None = Field(default=None, foreign_key="productor.id")
referer_id: int | None = Field(default=None, foreign_key="user.id")
productor_id: int | None = Field(default=None, foreign_key='productor.id')
referer_id: int | None = Field(default=None, foreign_key='user.id')
season: str
start: datetime.date
end: datetime.date
minimum_shipment_value: float | None
visible: bool
class FormPublic(FormBase):
id: int
productor: ProductorPublic | None
referer: User | None
shipments: list["ShipmentPublic"] = []
shipments: list['ShipmentPublic'] = Field(default_factory=list)
class Form(FormBase, table=True):
id: int | None = Field(default=None, primary_key=True)
productor: Optional['Productor'] = Relationship()
referer: Optional['User'] = Relationship()
shipments: list["Shipment"] = Relationship(
back_populates="form",
shipments: list['Shipment'] = Relationship(
back_populates='form',
cascade_delete=True,
sa_relationship_kwargs={
"order_by": "Shipment.name"
'order_by': 'Shipment.name'
},
)
contracts: list["Contract"] = Relationship(
back_populates="form",
contracts: list['Contract'] = Relationship(
back_populates='form',
cascade_delete=True
)
class FormUpdate(SQLModel):
name: str | None
productor_id: int | None
@@ -170,35 +226,44 @@ class FormUpdate(SQLModel):
minimum_shipment_value: float | None
visible: bool | None
class FormCreate(FormBase):
pass
class TemplateBase(SQLModel):
pass
class TemplatePublic(TemplateBase):
id: int
class Template(TemplateBase, table=True):
id: int | None = Field(default=None, primary_key=True)
class TemplateUpdate(SQLModel):
pass
class TemplateCreate(TemplateBase):
pass
class ChequeBase(SQLModel):
name: str
value: str
class Cheque(ChequeBase, table=True):
id: int | None = Field(default=None, primary_key=True)
contract_id: int = Field(foreign_key="contract.id", ondelete="CASCADE")
contract: Optional["Contract"] = Relationship(
back_populates="cheques",
contract_id: int = Field(foreign_key='contract.id', ondelete='CASCADE')
contract: Optional['Contract'] = Relationship(
back_populates='cheques',
)
class ContractBase(SQLModel):
firstname: str
lastname: str
@@ -207,105 +272,122 @@ class ContractBase(SQLModel):
payment_method: str
cheque_quantity: int
class Contract(ContractBase, table=True):
id: int | None = Field(default=None, primary_key=True)
form_id: int = Field(
foreign_key="form.id",
foreign_key='form.id',
nullable=False,
ondelete="CASCADE"
ondelete='CASCADE'
)
products: list["ContractProduct"] = Relationship(
back_populates="contract",
products: list['ContractProduct'] = Relationship(
back_populates='contract',
cascade_delete=True
)
form: Optional[Form] = Relationship(back_populates="contracts")
form: Form = Relationship(back_populates='contracts')
cheques: list[Cheque] = Relationship(
back_populates="contract",
back_populates='contract',
cascade_delete=True
)
file: bytes = Field(sa_column=Column(LargeBinary))
total_price: float | None
class ContractCreate(ContractBase):
products: list["ContractProductCreate"] = []
cheques: list["Cheque"] = []
products: list['ContractProductCreate'] = Field(default_factory=list)
cheques: list['Cheque'] = Field(default_factory=list)
form_id: int
class ContractUpdate(SQLModel):
file: bytes
class ContractPublic(ContractBase):
id: int
products: list["ContractProduct"] = []
products: list['ContractProduct'] = Field(default_factory=list)
form: Form
total_price: float | None
# file: bytes
class ContractProductBase(SQLModel):
product_id: int = Field(
foreign_key="product.id",
foreign_key='product.id',
nullable=False,
ondelete="CASCADE"
ondelete='CASCADE'
)
shipment_id: int | None = Field(
default=None,
foreign_key="shipment.id",
foreign_key='shipment.id',
nullable=True,
ondelete="CASCADE"
ondelete='CASCADE'
)
quantity: float
class ContractProduct(ContractProductBase, table=True):
id: int | None = Field(default=None, primary_key=True)
contract_id: int = Field(
foreign_key="contract.id",
foreign_key='contract.id',
nullable=False,
ondelete="CASCADE"
ondelete='CASCADE'
)
contract: Optional["Contract"] = Relationship(back_populates="products")
product: Optional["Product"] = Relationship()
shipment: Optional["Shipment"] = Relationship()
contract: Optional['Contract'] = Relationship(back_populates='products')
product: Optional['Product'] = Relationship()
shipment: Optional['Shipment'] = Relationship()
class ContractProductPublic(ContractProductBase):
id: int
quantity: float
contract: Contract
product: Product
shipment: Optional["Shipment"]
shipment: Optional['Shipment']
class ContractProductCreate(ContractProductBase):
pass
class ContractProductUpdate(ContractProductBase):
pass
class ShipmentBase(SQLModel):
name: str
date: datetime.date
form_id: int | None = Field(default=None, foreign_key="form.id", ondelete="CASCADE")
form_id: int | None = Field(
default=None,
foreign_key='form.id',
ondelete='CASCADE')
class ShipmentPublic(ShipmentBase):
id: int
products: list[Product] = []
products: list[Product] = Field(default_factory=list)
form: Form | None
class Shipment(ShipmentBase, table=True):
id: int | None = Field(default=None, primary_key=True)
products: list[Product] = Relationship(
back_populates="shipments",
back_populates='shipments',
link_model=ShipmentProductLink,
sa_relationship_kwargs={
"order_by": "Product.name"
'order_by': 'Product.name'
},
)
form: Optional[Form] = Relationship(back_populates="shipments")
form: Optional[Form] = Relationship(back_populates='shipments')
class ShipmentUpdate(SQLModel):
name: str | None
date: datetime.date | None
product_ids: list[int] | None = []
product_ids: list[int] | None = Field(default_factory=list)
class ShipmentCreate(ShipmentBase):
product_ids: list[int] = []
product_ids: list[int] = Field(default_factory=list)
form_id: int

View File

@@ -1,10 +1,16 @@
import logging
class ProductorServiceError(Exception):
def __init__(self, message: str):
super().__init__(message)
logging.error('ProductorService : %s', message)
class ProductorNotFoundError(ProductorServiceError):
pass
class ProductorCreateError(ProductorServiceError):
def __init__(self, message: str, field: str | None = None):
super().__init__(message)

View File

@@ -1,14 +1,15 @@
from fastapi import APIRouter, HTTPException, Depends, Query
import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.productors.service as service
import src.productors.exceptions as exceptions
import src.productors.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/productors')
@router.get('', response_model=list[models.ProductorPublic])
def get_productors(
names: list[str] = Query([]),
@@ -18,17 +19,22 @@ def get_productors(
):
return service.get_all(session, user, names, types)
@router.get('/{id}', response_model=models.ProductorPublic)
@router.get('/{_id}', response_model=models.ProductorPublic)
def get_productor(
id: int,
_id: int,
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
result = service.get_one(session, id)
result = service.get_one(session, _id)
if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('productor'))
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('productor')
)
return result
@router.post('', response_model=models.ProductorPublic)
def create_productor(
productor: models.ProductorCreate,
@@ -38,29 +44,31 @@ def create_productor(
try:
result = service.create_one(session, productor)
except exceptions.ProductorCreateError as error:
raise HTTPException(status_code=400, detail=str(error))
raise HTTPException(status_code=400, detail=str(error)) from error
return result
@router.put('/{id}', response_model=models.ProductorPublic)
@router.put('/{_id}', response_model=models.ProductorPublic)
def update_productor(
id: int, productor: models.ProductorUpdate,
_id: int, productor: models.ProductorUpdate,
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
try:
result = service.update_one(session, id, productor)
result = service.update_one(session, _id, productor)
except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error))
raise HTTPException(status_code=404, detail=str(error)) from error
return result
@router.delete('/{id}', response_model=models.ProductorPublic)
@router.delete('/{_id}', response_model=models.ProductorPublic)
def delete_productor(
id: int,
_id: int,
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
try:
result = service.delete_one(session, id)
result = service.delete_one(session, _id)
except exceptions.ProductorNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error))
raise HTTPException(status_code=404, detail=str(error)) from error
return result

View File

@@ -1,7 +1,8 @@
from sqlmodel import Session, select
import src.models as models
import src.productors.exceptions as exceptions
import src.messages as messages
import src.productors.exceptions as exceptions
from sqlmodel import Session, select
from src import models
def get_all(
session: Session,
@@ -18,13 +19,20 @@ def get_all(
statement = statement.where(models.Productor.type.in_(types))
return session.exec(statement.order_by(models.Productor.name)).all()
def get_one(session: Session, productor_id: int) -> models.ProductorPublic:
return session.get(models.Productor, productor_id)
def create_one(session: Session, productor: models.ProductorCreate) -> models.ProductorPublic:
def create_one(
session: Session,
productor: models.ProductorCreate) -> models.ProductorPublic:
if not productor:
raise exceptions.ProductorCreateError(messages.Messages.invalid_input('productor', 'input cannot be None'))
productor_create = productor.model_dump(exclude_unset=True, exclude='payment_methods')
raise exceptions.ProductorCreateError(
messages.Messages.invalid_input(
'productor', 'input cannot be None'))
productor_create = productor.model_dump(
exclude_unset=True, exclude='payment_methods')
new_productor = models.Productor(**productor_create)
new_productor.payment_methods = [
@@ -39,12 +47,17 @@ def create_one(session: Session, productor: models.ProductorCreate) -> models.Pr
session.refresh(new_productor)
return new_productor
def update_one(session: Session, id: int, productor: models.ProductorUpdate) -> models.ProductorPublic:
def update_one(
session: Session,
id: int,
productor: models.ProductorUpdate) -> models.ProductorPublic:
statement = select(models.Productor).where(models.Productor.id == id)
result = session.exec(statement)
new_productor = result.first()
if not new_productor:
raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor'))
raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
productor_updates = productor.model_dump(exclude_unset=True)
if 'payment_methods' in productor_updates:
@@ -67,13 +80,31 @@ def update_one(session: Session, id: int, productor: models.ProductorUpdate) ->
session.refresh(new_productor)
return new_productor
def delete_one(session: Session, id: int) -> models.ProductorPublic:
statement = select(models.Productor).where(models.Productor.id == id)
result = session.exec(statement)
productor = result.first()
if not productor:
raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor'))
raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
result = models.ProductorPublic.model_validate(productor)
session.delete(productor)
session.commit()
return result
def is_allowed(
session: Session,
user: models.User,
_id: int,
productor: models.ProductorCreate
) -> bool:
if not _id:
return productor.type in [r.name for r in user.roles]
statement = (
select(models.Productor)
.where(models.Productor.id == _id)
.where(models.Productor.type.in_([r.name for r in user.roles]))
.distinct()
)
return len(session.exec(statement).all()) > 0

View File

@@ -2,12 +2,15 @@ class ProductServiceError(Exception):
def __init__(self, message: str):
super().__init__(message)
class ProductorNotFoundError(ProductServiceError):
pass
class ProductNotFoundError(ProductServiceError):
pass
class ProductCreateError(ProductServiceError):
def __init__(self, message: str, field: str | None = None):
super().__init__(message)

View File

@@ -1,14 +1,15 @@
from fastapi import APIRouter, HTTPException, Depends, Query
import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.products.service as service
import src.products.exceptions as exceptions
import src.products.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/products')
@router.get('', response_model=list[models.ProductPublic], )
def get_products(
user: models.User = Depends(get_current_user),
@@ -25,6 +26,7 @@ def get_products(
types,
)
@router.get('/{id}', response_model=models.ProductPublic)
def get_product(
id: int,
@@ -33,9 +35,11 @@ def get_product(
):
result = service.get_one(session, id)
if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('product'))
raise HTTPException(status_code=404,
detail=messages.Messages.not_found('product'))
return result
@router.post('', response_model=models.ProductPublic)
def create_product(
product: models.ProductCreate,
@@ -50,6 +54,7 @@ def create_product(
raise HTTPException(status_code=404, detail=str(error))
return result
@router.put('/{id}', response_model=models.ProductPublic)
def update_product(
id: int, product: models.ProductUpdate,
@@ -64,6 +69,7 @@ def update_product(
raise HTTPException(status_code=404, detail=str(error))
return result
@router.delete('/{id}', response_model=models.ProductPublic)
def delete_product(
id: int,

View File

@@ -1,7 +1,8 @@
from sqlmodel import Session, select
import src.models as models
import src.products.exceptions as exceptions
import src.messages as messages
import src.products.exceptions as exceptions
from sqlmodel import Session, select
from src import models
def get_all(
session: Session,
@@ -10,10 +11,13 @@ def get_all(
productors: list[str],
types: list[str],
) -> list[models.ProductPublic]:
statement = select(models.Product)\
.join(models.Productor, models.Product.productor_id == models.Productor.id)\
.where(models.Productor.type.in_([r.name for r in user.roles]))\
.distinct()
statement = select(
models.Product) .join(
models.Productor,
models.Product.productor_id == models.Productor.id) .where(
models.Productor.type.in_(
[
r.name for r in user.roles])) .distinct()
if len(names) > 0:
statement = statement.where(models.Product.name.in_(names))
if len(productors) > 0:
@@ -22,14 +26,21 @@ def get_all(
statement = statement.where(models.Product.type.in_(types))
return session.exec(statement.order_by(models.Product.name)).all()
def get_one(session: Session, product_id: int) -> models.ProductPublic:
return session.get(models.Product, product_id)
def create_one(session: Session, product: models.ProductCreate) -> models.ProductPublic:
def create_one(
session: Session,
product: models.ProductCreate) -> models.ProductPublic:
if not product:
raise exceptions.ProductCreateError(messages.Messages.invalid_input('product', 'input cannot be None'))
raise exceptions.ProductCreateError(
messages.Messages.invalid_input(
'product', 'input cannot be None'))
if not session.get(models.Productor, product.productor_id):
raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor'))
raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
product_create = product.model_dump(exclude_unset=True)
new_product = models.Product(**product_create)
session.add(new_product)
@@ -37,14 +48,21 @@ def create_one(session: Session, product: models.ProductCreate) -> models.Produc
session.refresh(new_product)
return new_product
def update_one(session: Session, id: int, product: models.ProductUpdate) -> models.ProductPublic:
def update_one(
session: Session,
id: int,
product: models.ProductUpdate) -> models.ProductPublic:
statement = select(models.Product).where(models.Product.id == id)
result = session.exec(statement)
new_product = result.first()
if not new_product:
raise exceptions.ProductNotFoundError(messages.Messages.not_found('product'))
if product.productor_id and not session.get(models.Productor, product.productor_id):
raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor'))
raise exceptions.ProductNotFoundError(
messages.Messages.not_found('product'))
if product.productor_id and not session.get(
models.Productor, product.productor_id):
raise exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor'))
product_updates = product.model_dump(exclude_unset=True)
for key, value in product_updates.items():
@@ -55,13 +73,44 @@ def update_one(session: Session, id: int, product: models.ProductUpdate) -> mode
session.refresh(new_product)
return new_product
def delete_one(session: Session, id: int) -> models.ProductPublic:
statement = select(models.Product).where(models.Product.id == id)
result = session.exec(statement)
product = result.first()
if not product:
raise exceptions.ProductNotFoundError(messages.Messages.not_found('product'))
raise exceptions.ProductNotFoundError(
messages.Messages.not_found('product'))
result = models.ProductPublic.model_validate(product)
session.delete(product)
session.commit()
return result
def is_allowed(
session: Session,
user: models.User,
_id: int,
product: models.ProductCreate
) -> bool:
if not _id:
statement = (
select(models.Product)
.join(
models.Productor,
models.Product.productor_id == models.Productor.id
)
.where(models.Product.id == product.productor_id)
)
productor = session.exec(statement).first()
return productor.type in [r.name for r in user.roles]
statement = (
select(models.Product)
.join(
models.Productor,
models.Product.productor_id == models.Productor.id
)
.where(models.Product.id == _id)
.where(models.Productor.type.in_([r.name for r in user.roles]))
.distinct()
)
return len(session.exec(statement).all()) > 0

View File

@@ -1,5 +1,6 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
origins: str
db_host: str
@@ -20,10 +21,21 @@ class Settings(BaseSettings):
env_file='../.env'
)
settings = Settings()
AUTH_URL = f"{settings.keycloak_server}/realms/{settings.keycloak_realm}/protocol/openid-connect/auth"
TOKEN_URL = f"{settings.keycloak_server}/realms/{settings.keycloak_realm}/protocol/openid-connect/token"
ISSUER = f"{settings.keycloak_server}/realms/{settings.keycloak_realm}"
JWKS_URL = f"{ISSUER}/protocol/openid-connect/certs"
LOGOUT_URL = f'{settings.keycloak_server}/realms/{settings.keycloak_realm}/protocol/openid-connect/logout'
AUTH_URL = (
f'{settings.keycloak_server}/realms/'
f'{settings.keycloak_realm}/protocol/openid-connect/auth'
)
TOKEN_URL = (
f'{settings.keycloak_server}/realms/'
f'{settings.keycloak_realm}/protocol/openid-connect/token'
)
ISSUER = f'{settings.keycloak_server}/realms/{settings.keycloak_realm}'
JWKS_URL = f'{ISSUER}/protocol/openid-connect/certs'
LOGOUT_URL = (
f'{settings.keycloak_server}/realms/'
f'{settings.keycloak_realm}/protocol/openid-connect/logout'
)

View File

@@ -1,10 +1,16 @@
import logging
class ShipmentServiceError(Exception):
def __init__(self, message: str):
super().__init__(message)
logging.error('ShipmentService : %s', message)
class ShipmentNotFoundError(ShipmentServiceError):
pass
class ShipmentCreateError(ShipmentServiceError):
def __init__(self, message: str, field: str | None = None):
super().__init__(message)

View File

@@ -1,58 +1,111 @@
from sqlmodel import Session, select
import src.models as models
import src.shipments.exceptions as exceptions
import src.messages as messages
# pylint: disable=E1101
import datetime
import src.messages as messages
import src.shipments.exceptions as exceptions
from sqlmodel import Session, select
from src import models
def get_all(
session: Session,
user: models.User,
names: list[str],
dates: list[str],
forms: list[str]
names: list[str] = None,
dates: list[str] = None,
forms: list[str] = None
) -> list[models.ShipmentPublic]:
statement = select(models.Shipment)\
.join(models.Form, models.Shipment.form_id == models.Form.id)\
.join(models.Productor, models.Form.productor_id == models.Productor.id)\
.where(models.Productor.type.in_([r.name for r in user.roles]))\
statement = (
select(models.Shipment)
.join(
models.Form,
models.Shipment.form_id == models.Form.id)
.join(
models.Productor,
models.Form.productor_id == models.Productor.id)
.where(
models.Productor.type.in_(
[r.name for r in user.roles]
)
)
.distinct()
if len(names) > 0:
)
if names and len(names) > 0:
statement = statement.where(models.Shipment.name.in_(names))
if len(dates) > 0:
statement = statement.where(models.Shipment.date.in_(list(map(lambda x: datetime.datetime.strptime(x, '%Y-%m-%d').date(), dates))))
if len(forms) > 0:
if dates and len(dates) > 0:
statement = statement.where(
models.Shipment.date.in_(
list(map(
lambda x: datetime.datetime.strptime(
x, '%Y-%m-%d').date(),
dates
))
)
)
if forms and len(forms) > 0:
statement = statement.where(models.Form.name.in_(forms))
return session.exec(statement.order_by(models.Shipment.name)).all()
def get_one(session: Session, shipment_id: int) -> models.ShipmentPublic:
return session.get(models.Shipment, shipment_id)
def create_one(session: Session, shipment: models.ShipmentCreate) -> models.ShipmentPublic:
def create_one(
session: Session,
shipment: models.ShipmentCreate) -> models.ShipmentPublic:
if shipment is None:
raise exceptions.ShipmentCreateError(messages.Messages.invalid_input('shipment', 'input cannot be None'))
products = session.exec(select(models.Product).where(models.Product.id.in_(shipment.product_ids))).all()
shipment_create = shipment.model_dump(exclude_unset=True, exclude={'product_ids'})
raise exceptions.ShipmentCreateError(
messages.Messages.invalid_input(
'shipment', 'input cannot be None'))
products = session.exec(
select(models.Product)
.where(
models.Product.id.in_(
shipment.product_ids
)
)
).all()
shipment_create = shipment.model_dump(
exclude_unset=True, exclude={'product_ids'}
)
new_shipment = models.Shipment(**shipment_create, products=products)
session.add(new_shipment)
session.commit()
session.refresh(new_shipment)
return new_shipment
def update_one(session: Session, id: int, shipment: models.ShipmentUpdate) -> models.ShipmentPublic:
def update_one(
session: Session,
_id: int,
shipment: models.ShipmentUpdate) -> models.ShipmentPublic:
if shipment is None:
raise exceptions.ShipmentCreateError(messages.Messages.invalid_input('shipment', 'input cannot be None'))
statement = select(models.Shipment).where(models.Shipment.id == id)
raise exceptions.ShipmentCreateError(
messages.Messages.invalid_input(
'shipment', 'input cannot be None'))
statement = select(models.Shipment).where(models.Shipment.id == _id)
result = session.exec(statement)
new_shipment = result.first()
if not new_shipment:
raise exceptions.ShipmentNotFoundError(messages.Messages.not_found('shipment'))
raise exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment'))
products_to_add = session.exec(select(models.Product).where(models.Product.id.in_(shipment.product_ids))).all()
products_to_add = session.exec(
select(
models.Product
).where(
models.Product.id.in_(
shipment.product_ids
)
)
).all()
new_shipment.products.clear()
for add in products_to_add:
new_shipment.products.append(add)
shipment_updates = shipment.model_dump(exclude_unset=True, exclude={"product_ids"})
shipment_updates = shipment.model_dump(
exclude_unset=True, exclude={"product_ids"}
)
for key, value in shipment_updates.items():
setattr(new_shipment, key, value)
@@ -61,12 +114,14 @@ def update_one(session: Session, id: int, shipment: models.ShipmentUpdate) -> mo
session.refresh(new_shipment)
return new_shipment
def delete_one(session: Session, id: int) -> models.ShipmentPublic:
statement = select(models.Shipment).where(models.Shipment.id == id)
def delete_one(session: Session, _id: int) -> models.ShipmentPublic:
statement = select(models.Shipment).where(models.Shipment.id == _id)
result = session.exec(statement)
shipment = result.first()
if not shipment:
raise exceptions.ShipmentNotFoundError(messages.Messages.not_found('shipment'))
raise exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment'))
result = models.ShipmentPublic.model_validate(shipment)
session.delete(shipment)

View File

@@ -1,14 +1,15 @@
from fastapi import APIRouter, HTTPException, Depends, Query
import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.shipments.service as service
import src.shipments.exceptions as exceptions
import src.shipments.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/shipments')
@router.get('', response_model=list[models.ShipmentPublic], )
def get_shipments(
session: Session = Depends(get_session),
@@ -25,17 +26,22 @@ def get_shipments(
forms,
)
@router.get('/{id}', response_model=models.ShipmentPublic)
@router.get('/{_id}', response_model=models.ShipmentPublic)
def get_shipment(
id: int,
_id: int,
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
result = service.get_one(session, id)
result = service.get_one(session, _id)
if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('shipment'))
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('shipment')
)
return result
@router.post('', response_model=models.ShipmentPublic)
def create_shipment(
shipment: models.ShipmentCreate,
@@ -45,30 +51,32 @@ def create_shipment(
try:
result = service.create_one(session, shipment)
except exceptions.ShipmentCreateError as error:
raise HTTPException(status_code=400, detail=str(error))
raise HTTPException(status_code=400, detail=str(error)) from error
return result
@router.put('/{id}', response_model=models.ShipmentPublic)
@router.put('/{_id}', response_model=models.ShipmentPublic)
def update_shipment(
id: int,
_id: int,
shipment: models.ShipmentUpdate,
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
try:
result = service.update_one(session, id, shipment)
result = service.update_one(session, _id, shipment)
except exceptions.ShipmentNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error))
raise HTTPException(status_code=404, detail=str(error)) from error
return result
@router.delete('/{id}', response_model=models.ShipmentPublic)
@router.delete('/{_id}', response_model=models.ShipmentPublic)
def delete_shipment(
id: int,
_id: int,
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
try:
result = service.delete_one(session, id)
result = service.delete_one(session, _id)
except exceptions.ShipmentNotFoundError as error:
raise HTTPException(status_code=404, detail=str(error))
raise HTTPException(status_code=404, detail=str(error)) from error
return result

View File

@@ -1,14 +1,19 @@
from sqlmodel import Session, select
import src.models as models
from src import models
def get_all(session: Session) -> list[models.TemplatePublic]:
statement = select(models.Template)
return session.exec(statement.order_by(models.Template.name)).all()
def get_one(session: Session, template_id: int) -> models.TemplatePublic:
return session.get(models.Template, template_id)
def create_one(session: Session, template: models.TemplateCreate) -> models.TemplatePublic:
def create_one(
session: Session,
template: models.TemplateCreate) -> models.TemplatePublic:
template_create = template.model_dump(exclude_unset=True)
new_template = models.Template(**template_create)
session.add(new_template)
@@ -16,7 +21,11 @@ def create_one(session: Session, template: models.TemplateCreate) -> models.Temp
session.refresh(new_template)
return new_template
def update_one(session: Session, id: int, template: models.TemplateUpdate) -> models.TemplatePublic:
def update_one(
session: Session,
id: int,
template: models.TemplateUpdate) -> models.TemplatePublic:
statement = select(models.Template).where(models.Template.id == id)
result = session.exec(statement)
new_template = result.first()
@@ -30,6 +39,7 @@ def update_one(session: Session, id: int, template: models.TemplateUpdate) -> mo
session.refresh(new_template)
return new_template
def delete_one(session: Session, id: int) -> models.TemplatePublic:
statement = select(models.Template).where(models.Template.id == id)
result = session.exec(statement)

View File

@@ -1,13 +1,14 @@
from fastapi import APIRouter, HTTPException, Depends
import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.templates.service as service
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/templates')
@router.get('', response_model=list[models.TemplatePublic])
def get_templates(
user: models.User = Depends(get_current_user),
@@ -15,6 +16,7 @@ def get_templates(
):
return service.get_all(session)
@router.get('/{id}', response_model=models.TemplatePublic)
def get_template(
id: int,
@@ -23,9 +25,11 @@ def get_template(
):
result = service.get_one(session, id)
if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('template'))
raise HTTPException(status_code=404,
detail=messages.Messages.not_found('template'))
return result
@router.post('', response_model=models.TemplatePublic)
def create_template(
template: models.TemplateCreate,
@@ -34,6 +38,7 @@ def create_template(
):
return service.create_one(session, template)
@router.put('/{id}', response_model=models.TemplatePublic)
def update_template(
id: int, template: models.TemplateUpdate,
@@ -42,9 +47,11 @@ def update_template(
):
result = service.update_one(session, id, template)
if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('template'))
raise HTTPException(status_code=404,
detail=messages.Messages.not_found('template'))
return result
@router.delete('/{id}', response_model=models.TemplatePublic)
def delete_template(
id: int,
@@ -53,5 +60,6 @@ def delete_template(
):
result = service.delete_one(session, id)
if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('template'))
raise HTTPException(status_code=404,
detail=messages.Messages.not_found('template'))
return result

View File

@@ -1,10 +1,16 @@
import logging
class UserServiceError(Exception):
def __init__(self, message: str):
super().__init__(message)
logging.error('UserService : %s', message)
class UserNotFoundError(UserServiceError):
pass
class UserCreateError(UserServiceError):
def __init__(self, message: str, field: str | None = None):
super().__init__(message)

View File

@@ -1,9 +1,8 @@
from sqlmodel import Session, select
import src.models as models
import src.messages as messages
import src.users.exceptions as exceptions
from sqlmodel import Session, select
from src import models
def get_all(
session: Session,
@@ -17,11 +16,15 @@ def get_all(
statement = statement.where(models.User.email.in_(emails))
return session.exec(statement.order_by(models.User.name)).all()
def get_one(session: Session, user_id: int) -> models.UserPublic:
return session.get(models.User, user_id)
def get_or_create_roles(session: Session, role_names: list[str]) -> list[models.ContractType]:
statement = select(models.ContractType).where(models.ContractType.name.in_(role_names))
def get_or_create_roles(session: Session,
role_names: list[str]) -> list[models.ContractType]:
statement = select(models.ContractType).where(
models.ContractType.name.in_(role_names))
existing = session.exec(statement).all()
existing_roles = {role.name for role in existing}
missing_role = set(role_names) - existing_roles
@@ -37,8 +40,11 @@ def get_or_create_roles(session: Session, role_names: list[str]) -> list[models.
session.refresh(role)
return existing + new_roles
def get_or_create_user(session: Session, user_create: models.UserCreate):
statement = select(models.User).where(models.User.email == user_create.email)
statement = select(
models.User).where(
models.User.email == user_create.email)
user = session.exec(statement).first()
if user:
user_role_names = [r.name for r in user.roles]
@@ -48,13 +54,21 @@ def get_or_create_user(session: Session, user_create: models.UserCreate):
user = create_one(session, user_create)
return user
def get_roles(session: Session):
statement = select(models.ContractType)
statement = (
select(models.ContractType)
)
return session.exec(statement.order_by(models.ContractType.name)).all()
def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
if user is None:
raise exceptions.UserCreateError(messages.Messages.invalid_input('user', 'input cannot be None'))
raise exceptions.UserCreateError(
messages.Messages.invalid_input(
'user', 'input cannot be None'
)
)
new_user = models.User(
name=user.name,
email=user.email
@@ -68,14 +82,22 @@ def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
session.refresh(new_user)
return new_user
def update_one(session: Session, id: int, user: models.UserCreate) -> models.UserPublic:
def update_one(
session: Session,
_id: int,
user: models.UserCreate) -> models.UserPublic:
if user is None:
raise exceptions.UserCreateError(messages.s.invalid_input('user', 'input cannot be None'))
statement = select(models.User).where(models.User.id == id)
raise exceptions.UserCreateError(
messages.Messages.invalid_input(
'user', 'input cannot be None'
)
)
statement = select(models.User).where(models.User.id == _id)
result = session.exec(statement)
new_user = result.first()
if not new_user:
raise exceptions.UserNotFoundError(f'User {id} not found')
raise exceptions.UserNotFoundError(f'User {_id} not found')
new_user.email = user.email
new_user.name = user.name
@@ -86,12 +108,13 @@ def update_one(session: Session, id: int, user: models.UserCreate) -> models.Use
session.refresh(new_user)
return new_user
def delete_one(session: Session, id: int) -> models.UserPublic:
statement = select(models.User).where(models.User.id == id)
def delete_one(session: Session, _id: int) -> models.UserPublic:
statement = select(models.User).where(models.User.id == _id)
result = session.exec(statement)
user = result.first()
if not user:
raise exceptions.UserNotFoundError(f'User {id} not found')
raise exceptions.UserNotFoundError(f'User {_id} not found')
result = models.UserPublic.model_validate(user)
session.delete(user)
session.commit()

View File

@@ -1,14 +1,15 @@
from fastapi import APIRouter, HTTPException, Depends, Query
import src.messages as messages
import src.models as models
from src.database import get_session
from sqlmodel import Session
import src.users.service as service
from src.auth.auth import get_current_user
import src.users.exceptions as exceptions
import src.users.service as service
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel import Session
from src import models
from src.auth.auth import get_current_user
from src.database import get_session
router = APIRouter(prefix='/users')
@router.get('', response_model=list[models.UserPublic])
def get_users(
session: Session = Depends(get_session),
@@ -22,6 +23,7 @@ def get_users(
emails,
)
@router.get('/roles', response_model=list[models.ContractType])
def get_roles(
user: models.User = Depends(get_current_user),
@@ -29,17 +31,22 @@ def get_roles(
):
return service.get_roles(session)
@router.get('/{id}', response_model=models.UserPublic)
def get_users(
id: int,
@router.get('/{_id}', response_model=models.UserPublic)
def get_user(
_id: int,
user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
result = service.get_one(session, id)
result = service.get_one(session, _id)
if result is None:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('user'))
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('user')
)
return result
@router.post('', response_model=models.UserPublic)
def create_user(
user: models.UserCreate,
@@ -49,22 +56,30 @@ def create_user(
try:
user = service.create_one(session, user)
except exceptions.UserCreateError as error:
raise HTTPException(status_code=400, detail=str(error))
raise HTTPException(
status_code=400,
detail=str(error)
) from error
return user
@router.put('/{id}', response_model=models.UserPublic)
@router.put('/{_id}', response_model=models.UserPublic)
def update_user(
id: int,
_id: int,
user: models.UserUpdate,
logged_user: models.User = Depends(get_current_user),
session: Session = Depends(get_session)
):
try:
result = service.update_one(session, id, user)
result = service.update_one(session, _id, user)
except exceptions.UserNotFoundError as error:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('user'))
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('user')
) from error
return result
@router.delete('/{id}', response_model=models.UserPublic)
def delete_user(
id: int,
@@ -74,5 +89,8 @@ def delete_user(
try:
result = service.delete_one(session, id)
except exceptions.UserNotFoundError as error:
raise HTTPException(status_code=404, detail=messages.Messages.not_found('user'))
raise HTTPException(
status_code=404,
detail=messages.Messages.not_found('user')
) from error
return result

View File

@@ -1,13 +1,14 @@
import pytest
from fastapi.testclient import TestClient
from sqlmodel import SQLModel, Session, create_engine
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src import models
from src.auth.auth import get_current_user
from src.database import get_session
from src.main import app
from .fixtures import *
from src.main import app
import src.models as models
from src.database import get_session
from src.auth.auth import get_current_user
@pytest.fixture
def mock_session(mocker):
@@ -20,6 +21,7 @@ def mock_session(mocker):
yield session
app.dependency_overrides.clear()
@pytest.fixture
def mock_user():
user = models.User(id=1, name='test user', email='test@user.com')
@@ -31,10 +33,12 @@ def mock_user():
yield user
app.dependency_overrides.clear()
@pytest.fixture
def client():
return TestClient(app)
@pytest.fixture(name='session')
def session_fixture():
engine = create_engine(

View File

@@ -1,10 +1,12 @@
import src.models as models
import tests.factories.contracts as contract_factory
import tests.factories.products as product_factory
from src import models
def contract_product_factory(**kwargs):
contract = contract_factory.contract_factory(id=1)
product = product_factory.product_public_factory(id=1, type=models.ProductType.RECCURENT)
product = product_factory.product_public_factory(
id=1, type=models.ProductType.RECCURENT)
data = dict(
product_id=1,
shipment_id=1,
@@ -16,6 +18,7 @@ def contract_product_factory(**kwargs):
data.update(kwargs)
return models.ContractProduct(**data)
def contract_product_public_factory(**kwargs):
contract = contract_factory.contract_factory(id=1)
product = product_factory.product_public_factory(id=1)
@@ -31,6 +34,7 @@ def contract_product_public_factory(**kwargs):
data.update(kwargs)
return models.ContractProductPublic(**data)
def contract_product_create_factory(**kwargs):
data = dict(
product_id=1,
@@ -40,6 +44,7 @@ def contract_product_create_factory(**kwargs):
data.update(kwargs)
return models.ContractProductCreate(**data)
def contract_product_update_factory(**kwargs):
data = dict(
product_id=1,
@@ -49,6 +54,7 @@ def contract_product_update_factory(**kwargs):
data.update(kwargs)
return models.ContractProductUpdate(**data)
def contract_product_body_factory(**kwargs):
data = dict(
product_id=1,

View File

@@ -1,6 +1,8 @@
import src.models as models
from src import models
from .forms import form_factory
def contract_factory(**kwargs):
data = dict(
id=1,
@@ -17,6 +19,7 @@ def contract_factory(**kwargs):
data.update(kwargs)
return models.Contract(**data)
def contract_public_factory(**kwargs):
data = dict(
id=1,
@@ -33,6 +36,7 @@ def contract_public_factory(**kwargs):
data.update(kwargs)
return models.ContractPublic(**data)
def contract_create_factory(**kwargs):
data = dict(
firstname="test",
@@ -48,6 +52,7 @@ def contract_create_factory(**kwargs):
data.update(kwargs)
return models.ContractCreate(**data)
def contract_update_factory(**kwargs):
data = dict(
firstname="test",
@@ -60,6 +65,7 @@ def contract_update_factory(**kwargs):
data.update(kwargs)
return models.ContractUpdate(**data)
def contract_body_factory(**kwargs):
data = dict(
firstname="test",

View File

@@ -1,9 +1,11 @@
import src.models as models
from .productors import productor_public_factory
from .shipments import shipment_public_factory
from .users import user_factory
import datetime
from src import models
from .productors import productor_public_factory
from .users import user_factory
def form_factory(**kwargs):
data = dict(
id=1,
@@ -37,6 +39,7 @@ def form_body_factory(**kwargs):
data.update(kwargs)
return data
def form_create_factory(**kwargs):
data = dict(
name='form 1',
@@ -51,6 +54,7 @@ def form_create_factory(**kwargs):
data.update(kwargs)
return models.FormCreate(**data)
def form_update_factory(**kwargs):
data = dict(
name='form 1',
@@ -65,7 +69,8 @@ def form_update_factory(**kwargs):
data.update(kwargs)
return models.FormUpdate(**data)
def form_public_factory(form=None, shipments=[],**kwargs):
def form_public_factory(form=None, shipments=[], **kwargs):
data = dict(
id=1,
name='form 1',

View File

@@ -1,4 +1,5 @@
import src.models as models
from src import models
def productor_factory(**kwargs):
data = dict(
@@ -10,6 +11,7 @@ def productor_factory(**kwargs):
data.update(kwargs)
return models.Productor(**data)
def productor_public_factory(**kwargs):
data = dict(
id=1,
@@ -22,6 +24,7 @@ def productor_public_factory(**kwargs):
data.update(kwargs)
return models.ProductorPublic(**data)
def productor_create_factory(**kwargs):
data = dict(
id=1,
@@ -34,6 +37,7 @@ def productor_create_factory(**kwargs):
data.update(kwargs)
return models.ProductorCreate(**data)
def productor_update_factory(**kwargs):
data = dict(
id=1,
@@ -46,6 +50,7 @@ def productor_update_factory(**kwargs):
data.update(kwargs)
return models.ProductorUpdate(**data)
def productor_body_factory(**kwargs):
data = dict(
id=1,

View File

@@ -1,6 +1,7 @@
import src.models as models
from src import models
from .productors import productor_factory
from .shipments import shipment_factory
def product_body_factory(**kwargs):
data = dict(
@@ -16,6 +17,7 @@ def product_body_factory(**kwargs):
data.update(kwargs)
return data
def product_create_factory(**kwargs):
data = dict(
name='product test 1',
@@ -30,6 +32,7 @@ def product_create_factory(**kwargs):
data.update(kwargs)
return models.ProductCreate(**data)
def product_update_factory(**kwargs):
data = dict(
name='product test 1',
@@ -44,7 +47,8 @@ def product_update_factory(**kwargs):
data.update(kwargs)
return models.ProductUpdate(**data)
def product_public_factory(productor=None, shipments=[],**kwargs):
def product_public_factory(productor=None, shipments=[], **kwargs):
if productor is None:
productor = productor_factory()
data = dict(

View File

@@ -1,6 +1,8 @@
import src.models as models
import datetime
from src import models
def shipment_factory(**kwargs):
data = dict(
id=1,
@@ -11,6 +13,7 @@ def shipment_factory(**kwargs):
data.update(kwargs)
return models.Shipment(**data)
def shipment_public_factory(**kwargs):
data = dict(
id=1,
@@ -23,6 +26,7 @@ def shipment_public_factory(**kwargs):
data.update(kwargs)
return models.ShipmentPublic(**data)
def shipment_create_factory(**kwargs):
data = dict(
name="test shipment",
@@ -33,6 +37,7 @@ def shipment_create_factory(**kwargs):
data.update(kwargs)
return models.ShipmentCreate(**data)
def shipment_update_factory(**kwargs):
data = dict(
name="test shipment",
@@ -43,6 +48,7 @@ def shipment_update_factory(**kwargs):
data.update(kwargs)
return models.ShipmentUpdate(**data)
def shipment_body_factory(**kwargs):
data = dict(
name="test shipment",

View File

@@ -1,4 +1,5 @@
import src.models as models
from src import models
def user_factory(**kwargs):
data = dict(
@@ -10,6 +11,7 @@ def user_factory(**kwargs):
data.update(kwargs)
return models.User(**data)
def user_public_factory(**kwargs):
data = dict(
id=1,
@@ -20,6 +22,7 @@ def user_public_factory(**kwargs):
data.update(kwargs)
return models.UserPublic(**data)
def user_create_factory(**kwargs):
data = dict(
name="test user",
@@ -29,6 +32,7 @@ def user_create_factory(**kwargs):
data.update(kwargs)
return models.UserCreate(**data)
def user_update_factory(**kwargs):
data = dict(
name="test user",
@@ -38,6 +42,7 @@ def user_update_factory(**kwargs):
data.update(kwargs)
return models.UserUpdate(**data)
def user_body_factory(**kwargs):
data = dict(
name="test user",

View File

@@ -1,18 +1,19 @@
import pytest
import datetime
from sqlmodel import Session
import src.models as models
import pytest
import src.forms.service as forms_service
import src.shipments.service as shipments_service
import src.productors.service as productors_service
import src.products.service as products_service
import src.shipments.service as shipments_service
import src.users.service as users_service
import tests.factories.forms as forms_factory
import tests.factories.shipments as shipments_factory
import tests.factories.productors as productors_factory
import tests.factories.products as products_factory
import tests.factories.shipments as shipments_factory
import tests.factories.users as users_factory
from sqlmodel import Session
from src import models
@pytest.fixture
def productor(session: Session) -> models.ProductorPublic:
@@ -46,8 +47,10 @@ def productors(session: Session) -> models.ProductorPublic:
]
return productors
@pytest.fixture
def products(session: Session, productor: models.ProductorPublic) -> list[models.ProductPublic]:
def products(session: Session,
productor: models.ProductorPublic) -> list[models.ProductPublic]:
products = [
products_service.create_one(
session,
@@ -68,6 +71,7 @@ def products(session: Session, productor: models.ProductorPublic) -> list[models
]
return products
@pytest.fixture
def user(session: Session) -> models.UserPublic:
user = users_service.create_one(
@@ -80,6 +84,7 @@ def user(session: Session) -> models.UserPublic:
)
return user
@pytest.fixture
def users(session: Session) -> list[models.UserPublic]:
users = [
@@ -88,28 +93,28 @@ def users(session: Session) -> list[models.UserPublic]:
users_factory.user_create_factory(
name='test user 1 (admin)',
email='test1@test.com',
role_names=['Légumineuses', 'Légumes', 'Oeufs', 'Porc-Agneau', 'Vin', 'Fruits']
)
),
role_names=[
'Légumineuses',
'Légumes',
'Oeufs',
'Porc-Agneau',
'Vin',
'Fruits'])),
users_service.create_one(
session,
users_factory.user_create_factory(
name='test user 2',
email='test2@test.com',
role_names=['Légumineuses']
)
),
role_names=['Légumineuses'])),
users_service.create_one(
session,
users_factory.user_create_factory(
name='test user 3',
email='test3@test.com',
role_names=['Porc-Agneau']
)
)
]
role_names=['Porc-Agneau']))]
return users
@pytest.fixture
def referer(session: Session) -> models.UserPublic:
referer = users_service.create_one(
@@ -122,8 +127,11 @@ def referer(session: Session) -> models.UserPublic:
)
return referer
@pytest.fixture
def shipments(session: Session, forms: list[models.FormPublic], products: list[models.ProductPublic]):
def shipments(session: Session,
forms: list[models.FormPublic],
products: list[models.ProductPublic]):
shipments = [
shipments_service.create_one(
session,
@@ -146,6 +154,7 @@ def shipments(session: Session, forms: list[models.FormPublic], products: list[m
]
return shipments
@pytest.fixture
def forms(
session: Session,
@@ -173,4 +182,3 @@ def forms(
)
]
return forms

View File

@@ -1,12 +1,12 @@
import src.contracts.service as service
import src.models as models
from src.main import app
from src.auth.auth import get_current_user
import tests.factories.contract_products as contract_products_factory
import tests.factories.contracts as contract_factory
import tests.factories.forms as form_factory
import tests.factories.contract_products as contract_products_factory
from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user
from src.main import app
class TestContracts:
def test_get_all(self, client, mocker, mock_session, mock_user):
@@ -31,6 +31,7 @@ class TestContracts:
mock_user,
[],
)
def test_get_all_filters(self, client, mocker, mock_session, mock_user):
mock_results = [
contract_factory.contract_public_factory(id=2),
@@ -52,7 +53,12 @@ class TestContracts:
['form test'],
)
def test_get_all_unauthorized(self, client, mocker, mock_session, mock_user):
def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
@@ -74,7 +80,7 @@ class TestContracts:
'get_one',
return_value=mock_result
)
mock_is_allowed = mocker.patch.object(
mocker.patch.object(
service,
'is_allowed',
return_value=True
@@ -97,20 +103,24 @@ class TestContracts:
return_value=mock_result
)
mock_is_allowed = mocker.patch.object(
mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.get('/api/contracts/2')
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
mock_session,
2
)
def test_get_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_get_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
@@ -124,41 +134,6 @@ class TestContracts:
app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user):
contract_body = contract_factory.contract_body_factory(
products=[
contract_products_factory.contract_product_body_factory(product_id=1),
contract_products_factory.contract_product_body_factory(product_id=2),
contract_products_factory.contract_product_body_factory(product_id=3)
],
cheques=[{'name': '123123', 'value': '100'}]
)
contract_result = contract_factory.contract_factory(
products=[
contract_products_factory.contract_product_factory(product_id=1),
contract_products_factory.contract_product_factory(product_id=2),
contract_products_factory.contract_product_factory(product_id=3)
],
form=form_factory.form_factory(),
cheques=[models.Cheque(name='123123', value='100')]
)
mock_create_one = mocker.patch.object(
service,
'create_one',
return_value=contract_result
)
mock_add_contract_file = mocker.patch.object(
service,
'add_contract_file',
return_value=True
)
mock_generate_html_contract = mocker.patch('src.contracts.generate_contract.generate_html_contract')
response = client.post('/api/contracts', json=contract_body)
assert response.status_code == 200
contract_id = 'test_test_test type_hiver-2026'
assert response.headers['Content-Disposition'] == f'attachment; filename=contract_{contract_id}.pdf'
def test_delete_one(self, client, mocker, mock_session, mock_user):
contract_result = contract_factory.contract_public_factory()
@@ -168,14 +143,13 @@ class TestContracts:
return_value=contract_result
)
mock_is_allowed = mocker.patch.object(
mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/contracts/2')
response_data = response.json()
assert response.status_code == 200
mock.assert_called_once_with(
@@ -183,7 +157,13 @@ class TestContracts:
2,
)
def test_delete_one_notfound(self, client, mocker, mock_session, mock_user):
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user
):
contract_result = None
mock = mocker.patch.object(
@@ -192,14 +172,13 @@ class TestContracts:
return_value=contract_result
)
mock_is_allowed = mocker.patch.object(
mocker.patch.object(
service,
'is_allowed',
return_value=True
)
response = client.delete('/api/contracts/2')
response_data = response.json()
assert response.status_code == 404
mock.assert_called_once_with(
@@ -207,10 +186,15 @@ class TestContracts:
2,
)
def test_delete_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user
):
def unauthorized():
raise HTTPException(status_code=401)
contract_body = contract_factory.contract_body_factory()
app.dependency_overrides[get_current_user] = unauthorized

View File

@@ -1,11 +1,12 @@
import src.forms.service as service
import src.forms.exceptions as forms_exceptions
import src.models as models
from src.main import app
from src.auth.auth import get_current_user
import src.forms.service as service
import src.messages as messages
import tests.factories.forms as form_factory
from fastapi.exceptions import HTTPException
import src.messages as messages
from src import models
from src.auth.auth import get_current_user
from src.main import app
class TestForms:
def test_get_all(self, client, mocker, mock_session, mock_user):
@@ -32,6 +33,7 @@ class TestForms:
False,
mock_user,
)
def test_get_all_filters(self, client, mocker, mock_session, mock_user):
mock_results = [
form_factory.form_public_factory(name="test 2", id=2),
@@ -42,7 +44,8 @@ class TestForms:
return_value=mock_results
)
response = client.get('/api/forms/referents?current_season=true&seasons=hiver-2025&productors=test productor')
response = client.get(
'/api/forms/referents?current_season=true&seasons=hiver-2025&productors=test productor')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 2
@@ -55,7 +58,12 @@ class TestForms:
mock_user,
)
def test_get_all_unauthorized(self, client, mocker, mock_session, mock_user):
def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
@@ -103,7 +111,6 @@ class TestForms:
2
)
def test_create_one(self, client, mocker, mock_session, mock_user):
form_body = form_factory.form_body_factory(name='test form create')
form_create = form_factory.form_create_factory(name='test form create')
@@ -125,15 +132,16 @@ class TestForms:
form_create
)
def test_create_one_referer_notfound(self, client, mocker, mock_session, mock_user):
form_body = form_factory.form_body_factory(name='test form create', referer_id=12312)
form_create = form_factory.form_create_factory(name='test form create', referer_id=12312)
def test_create_one_referer_notfound(
self, client, mocker, mock_session, mock_user):
form_body = form_factory.form_body_factory(
name='test form create', referer_id=12312)
form_create = form_factory.form_create_factory(
name='test form create', referer_id=12312)
mock = mocker.patch.object(
service,
'create_one',
side_effect=forms_exceptions.UserNotFoundError(messages.Messages.not_found('referer'))
)
service, 'create_one', side_effect=forms_exceptions.UserNotFoundError(
messages.Messages.not_found('referer')))
response = client.post('/api/forms', json=form_body)
response_data = response.json()
@@ -144,15 +152,16 @@ class TestForms:
form_create
)
def test_create_one_productor_notfound(self, client, mocker, mock_session, mock_user):
form_body = form_factory.form_body_factory(name='test form create', productor_id=1231)
form_create = form_factory.form_create_factory(name='test form create', productor_id=1231)
def test_create_one_productor_notfound(
self, client, mocker, mock_session, mock_user):
form_body = form_factory.form_body_factory(
name='test form create', productor_id=1231)
form_create = form_factory.form_create_factory(
name='test form create', productor_id=1231)
mock = mocker.patch.object(
service,
'create_one',
side_effect=forms_exceptions.ProductorNotFoundError(messages.Messages.not_found('productor'))
)
service, 'create_one', side_effect=forms_exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')))
response = client.post('/api/forms', json=form_body)
response_data = response.json()
@@ -163,7 +172,12 @@ class TestForms:
form_create
)
def test_create_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_create_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
form_body = form_factory.form_body_factory(name='test form create')
@@ -200,15 +214,18 @@ class TestForms:
form_update
)
def test_update_one_notfound(self, client, mocker, mock_session, mock_user):
def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update')
mock = mocker.patch.object(
service,
'update_one',
side_effect=forms_exceptions.FormNotFoundError(messages.Messages.not_found('form'))
)
service, 'update_one', side_effect=forms_exceptions.FormNotFoundError(
messages.Messages.not_found('form')))
response = client.put('/api/forms/2', json=form_body)
response_data = response.json()
@@ -220,15 +237,14 @@ class TestForms:
form_update
)
def test_update_one_referer_notfound(self, client, mocker, mock_session, mock_user):
def test_update_one_referer_notfound(
self, client, mocker, mock_session, mock_user):
form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update')
mock = mocker.patch.object(
service,
'update_one',
side_effect=forms_exceptions.UserNotFoundError(messages.Messages.not_found('referer'))
)
service, 'update_one', side_effect=forms_exceptions.UserNotFoundError(
messages.Messages.not_found('referer')))
response = client.put('/api/forms/2', json=form_body)
response_data = response.json()
@@ -240,15 +256,14 @@ class TestForms:
form_update
)
def test_update_one_productor_notfound(self, client, mocker, mock_session, mock_user):
def test_update_one_productor_notfound(
self, client, mocker, mock_session, mock_user):
form_body = form_factory.form_body_factory(name='test form update')
form_update = form_factory.form_update_factory(name='test form update')
mock = mocker.patch.object(
service,
'update_one',
side_effect=forms_exceptions.ProductorNotFoundError(messages.Messages.not_found('productor'))
)
service, 'update_one', side_effect=forms_exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')))
response = client.put('/api/forms/2', json=form_body)
response_data = response.json()
@@ -260,7 +275,12 @@ class TestForms:
form_update
)
def test_update_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_update_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
form_body = form_factory.form_body_factory(name='test form update')
@@ -294,14 +314,17 @@ class TestForms:
2,
)
def test_delete_one_notfound(self, client, mocker, mock_session, mock_user):
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
form_result = None
mock = mocker.patch.object(
service,
'delete_one',
side_effect=forms_exceptions.FormNotFoundError(messages.Messages.not_found('form'))
)
service, 'delete_one', side_effect=forms_exceptions.FormNotFoundError(
messages.Messages.not_found('form')))
response = client.delete('/api/forms/2')
response_data = response.json()
@@ -312,7 +335,12 @@ class TestForms:
2,
)
def test_delete_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)

View File

@@ -1,14 +1,12 @@
from fastapi.exceptions import HTTPException
from src.main import app
import src.models as models
import src.messages as messages
from src.auth.auth import get_current_user
import src.productors.service as service
import src.productors.exceptions as exceptions
import src.productors.service as service
import tests.factories.productors as productor_factory
from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user
from src.main import app
class TestProductors:
def test_get_all(self, client, mocker, mock_session, mock_user):
@@ -34,6 +32,7 @@ class TestProductors:
[],
[],
)
def test_get_all_filters(self, client, mocker, mock_session, mock_user):
mock_results = [
productor_factory.productor_public_factory(name="test 2", id=2),
@@ -44,7 +43,8 @@ class TestProductors:
return_value=mock_results
)
response = client.get('/api/productors?types=Légumineuses&names=test 2')
response = client.get(
'/api/productors?types=Légumineuses&names=test 2')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 2
@@ -56,7 +56,12 @@ class TestProductors:
['Légumineuses'],
)
def test_get_all_unauthorized(self, client, mocker, mock_session, mock_user):
def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
@@ -71,7 +76,8 @@ class TestProductors:
app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user):
mock_result = productor_factory.productor_public_factory(name="test 2", id=2)
mock_result = productor_factory.productor_public_factory(
name="test 2", id=2)
mock = mocker.patch.object(
service,
@@ -104,7 +110,12 @@ class TestProductors:
2
)
def test_get_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_get_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
@@ -119,9 +130,12 @@ class TestProductors:
app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user):
productor_body = productor_factory.productor_body_factory(name='test productor create')
productor_create = productor_factory.productor_create_factory(name='test productor create')
productor_result = productor_factory.productor_public_factory(name='test productor create')
productor_body = productor_factory.productor_body_factory(
name='test productor create')
productor_create = productor_factory.productor_create_factory(
name='test productor create')
productor_result = productor_factory.productor_public_factory(
name='test productor create')
mock = mocker.patch.object(
service,
@@ -139,10 +153,16 @@ class TestProductors:
productor_create
)
def test_create_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_create_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
productor_body = productor_factory.productor_body_factory(name='test productor create')
productor_body = productor_factory.productor_body_factory(
name='test productor create')
app.dependency_overrides[get_current_user] = unauthorized
@@ -155,9 +175,12 @@ class TestProductors:
app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user):
productor_body = productor_factory.productor_body_factory(name='test productor update')
productor_update = productor_factory.productor_update_factory(name='test productor update')
productor_result = productor_factory.productor_public_factory(name='test productor update')
productor_body = productor_factory.productor_body_factory(
name='test productor update')
productor_update = productor_factory.productor_update_factory(
name='test productor update')
productor_result = productor_factory.productor_public_factory(
name='test productor update')
mock = mocker.patch.object(
service,
@@ -176,16 +199,21 @@ class TestProductors:
productor_update
)
def test_update_one_notfound(self, client, mocker, mock_session, mock_user):
productor_body = productor_factory.productor_body_factory(name='test productor update')
productor_update = productor_factory.productor_update_factory(name='test productor update')
def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
productor_body = productor_factory.productor_body_factory(
name='test productor update')
productor_update = productor_factory.productor_update_factory(
name='test productor update')
productor_result = None
mock = mocker.patch.object(
service,
'update_one',
side_effect=exceptions.ProductorNotFoundError(messages.Messages.not_found('productor'))
)
service, 'update_one', side_effect=exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')))
response = client.put('/api/productors/2', json=productor_body)
response_data = response.json()
@@ -197,10 +225,16 @@ class TestProductors:
productor_update
)
def test_update_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_update_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
productor_body = productor_factory.productor_body_factory(name='test productor update')
productor_body = productor_factory.productor_body_factory(
name='test productor update')
app.dependency_overrides[get_current_user] = unauthorized
@@ -213,7 +247,8 @@ class TestProductors:
app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user):
productor_result = productor_factory.productor_public_factory(name='test productor delete')
productor_result = productor_factory.productor_public_factory(
name='test productor delete')
mock = mocker.patch.object(
service,
@@ -231,14 +266,17 @@ class TestProductors:
2,
)
def test_delete_one_notfound(self, client, mocker, mock_session, mock_user):
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
productor_result = None
mock = mocker.patch.object(
service,
'delete_one',
side_effect=exceptions.ProductorNotFoundError(messages.Messages.not_found('productor'))
)
service, 'delete_one', side_effect=exceptions.ProductorNotFoundError(
messages.Messages.not_found('productor')))
response = client.delete('/api/productors/2')
response_data = response.json()
@@ -249,10 +287,16 @@ class TestProductors:
2,
)
def test_delete_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
productor_body = productor_factory.productor_body_factory(name='test productor delete')
productor_body = productor_factory.productor_body_factory(
name='test productor delete')
app.dependency_overrides[get_current_user] = unauthorized

View File

@@ -1,11 +1,11 @@
import src.products.service as service
import src.products.exceptions as exceptions
import src.models as models
from src.main import app
from src.auth.auth import get_current_user
import src.products.service as service
import tests.factories.products as product_factory
from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user
from src.main import app
class TestProducts:
def test_get_all(self, client, mocker, mock_session, mock_user):
@@ -32,6 +32,7 @@ class TestProducts:
[],
[]
)
def test_get_all_filters(self, client, mocker, mock_session, mock_user):
mock_results = [
product_factory.product_public_factory(name="test 2", id=2),
@@ -55,7 +56,12 @@ class TestProducts:
['1'],
)
def test_get_all_unauthorized(self, client, mocker, mock_session, mock_user):
def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
@@ -70,7 +76,8 @@ class TestProducts:
app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user):
mock_result = product_factory.product_public_factory(name="test 2", id=2)
mock_result = product_factory.product_public_factory(
name="test 2", id=2)
mock = mocker.patch.object(
service,
@@ -103,7 +110,12 @@ class TestProducts:
2
)
def test_get_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_get_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
@@ -118,9 +130,12 @@ class TestProducts:
app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user):
product_body = product_factory.product_body_factory(name='test product create')
product_create = product_factory.product_create_factory(name='test product create')
product_result = product_factory.product_public_factory(name='test product create')
product_body = product_factory.product_body_factory(
name='test product create')
product_create = product_factory.product_create_factory(
name='test product create')
product_result = product_factory.product_public_factory(
name='test product create')
mock = mocker.patch.object(
service,
@@ -138,10 +153,16 @@ class TestProducts:
product_create
)
def test_create_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_create_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
product_body = product_factory.product_body_factory(name='test product create')
product_body = product_factory.product_body_factory(
name='test product create')
app.dependency_overrides[get_current_user] = unauthorized
@@ -154,9 +175,12 @@ class TestProducts:
app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user):
product_body = product_factory.product_body_factory(name='test product update')
product_update = product_factory.product_update_factory(name='test product update')
product_result = product_factory.product_public_factory(name='test product update')
product_body = product_factory.product_body_factory(
name='test product update')
product_update = product_factory.product_update_factory(
name='test product update')
product_result = product_factory.product_public_factory(
name='test product update')
mock = mocker.patch.object(
service,
@@ -175,9 +199,16 @@ class TestProducts:
product_update
)
def test_update_one_notfound(self, client, mocker, mock_session, mock_user):
product_body = product_factory.product_body_factory(name='test product update')
product_update = product_factory.product_update_factory(name='test product update')
def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
product_body = product_factory.product_body_factory(
name='test product update')
product_update = product_factory.product_update_factory(
name='test product update')
product_result = None
mock = mocker.patch.object(
@@ -196,10 +227,16 @@ class TestProducts:
product_update
)
def test_update_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_update_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
product_body = product_factory.product_body_factory(name='test product update')
product_body = product_factory.product_body_factory(
name='test product update')
app.dependency_overrides[get_current_user] = unauthorized
@@ -212,7 +249,8 @@ class TestProducts:
app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user):
product_result = product_factory.product_public_factory(name='test product delete')
product_result = product_factory.product_public_factory(
name='test product delete')
mock = mocker.patch.object(
service,
@@ -230,7 +268,12 @@ class TestProducts:
2,
)
def test_delete_one_notfound(self, client, mocker, mock_session, mock_user):
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
product_result = None
mock = mocker.patch.object(
@@ -248,10 +291,16 @@ class TestProducts:
2,
)
def test_delete_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
product_body = product_factory.product_body_factory(name='test product delete')
product_body = product_factory.product_body_factory(
name='test product delete')
app.dependency_overrides[get_current_user] = unauthorized

View File

@@ -1,12 +1,12 @@
import src.shipments.service as service
import src.models as models
from src.main import app
import src.messages as messages
import src.shipments.exceptions as exceptions
from src.auth.auth import get_current_user
import src.shipments.service as service
import tests.factories.shipments as shipment_factory
from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user
from src.main import app
class TestShipments:
def test_get_all(self, client, mocker, mock_session, mock_user):
@@ -33,6 +33,7 @@ class TestShipments:
[],
[],
)
def test_get_all_filters(self, client, mocker, mock_session, mock_user):
mock_results = [
shipment_factory.shipment_public_factory(name="test 2", id=2),
@@ -43,7 +44,8 @@ class TestShipments:
return_value=mock_results
)
response = client.get('/api/shipments?dates=2025-10-10&names=test 2&forms=contract form 1')
response = client.get(
'/api/shipments?dates=2025-10-10&names=test 2&forms=contract form 1')
response_data = response.json()
assert response.status_code == 200
assert response_data[0]['id'] == 2
@@ -56,7 +58,12 @@ class TestShipments:
['contract form 1'],
)
def test_get_all_unauthorized(self, client, mocker, mock_session, mock_user):
def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
@@ -71,7 +78,8 @@ class TestShipments:
app.dependency_overrides.clear()
def test_get_one(self, client, mocker, mock_session, mock_user):
mock_result = shipment_factory.shipment_public_factory(name="test 2", id=2)
mock_result = shipment_factory.shipment_public_factory(
name="test 2", id=2)
mock = mocker.patch.object(
service,
@@ -104,7 +112,12 @@ class TestShipments:
2
)
def test_get_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_get_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
@@ -119,9 +132,12 @@ class TestShipments:
app.dependency_overrides.clear()
def test_create_one(self, client, mocker, mock_session, mock_user):
shipment_body = shipment_factory.shipment_body_factory(name='test shipment create')
shipment_create = shipment_factory.shipment_create_factory(name='test shipment create')
shipment_result = shipment_factory.shipment_public_factory(name='test shipment create')
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment create')
shipment_create = shipment_factory.shipment_create_factory(
name='test shipment create')
shipment_result = shipment_factory.shipment_public_factory(
name='test shipment create')
mock = mocker.patch.object(
service,
@@ -139,10 +155,16 @@ class TestShipments:
shipment_create
)
def test_create_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_create_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
shipment_body = shipment_factory.shipment_body_factory(name='test shipment create')
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment create')
app.dependency_overrides[get_current_user] = unauthorized
@@ -155,9 +177,12 @@ class TestShipments:
app.dependency_overrides.clear()
def test_update_one(self, client, mocker, mock_session, mock_user):
shipment_body = shipment_factory.shipment_body_factory(name='test shipment update')
shipment_update = shipment_factory.shipment_update_factory(name='test shipment update')
shipment_result = shipment_factory.shipment_public_factory(name='test shipment update')
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment update')
shipment_update = shipment_factory.shipment_update_factory(
name='test shipment update')
shipment_result = shipment_factory.shipment_public_factory(
name='test shipment update')
mock = mocker.patch.object(
service,
@@ -176,15 +201,20 @@ class TestShipments:
shipment_update
)
def test_update_one_notfound(self, client, mocker, mock_session, mock_user):
shipment_body = shipment_factory.shipment_body_factory(name='test shipment update')
shipment_update = shipment_factory.shipment_update_factory(name='test shipment update')
def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment update')
shipment_update = shipment_factory.shipment_update_factory(
name='test shipment update')
mock = mocker.patch.object(
service,
'update_one',
side_effect=exceptions.ShipmentNotFoundError(messages.Messages.not_found('shipment'))
)
service, 'update_one', side_effect=exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment')))
response = client.put('/api/shipments/2', json=shipment_body)
response_data = response.json()
@@ -196,10 +226,16 @@ class TestShipments:
shipment_update
)
def test_update_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_update_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
shipment_body = shipment_factory.shipment_body_factory(name='test shipment update')
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment update')
app.dependency_overrides[get_current_user] = unauthorized
@@ -212,7 +248,8 @@ class TestShipments:
app.dependency_overrides.clear()
def test_delete_one(self, client, mocker, mock_session, mock_user):
shipment_result = shipment_factory.shipment_public_factory(name='test shipment delete')
shipment_result = shipment_factory.shipment_public_factory(
name='test shipment delete')
mock = mocker.patch.object(
service,
@@ -230,14 +267,17 @@ class TestShipments:
2,
)
def test_delete_one_notfound(self, client, mocker, mock_session, mock_user):
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
shipment_result = None
mock = mocker.patch.object(
service,
'delete_one',
side_effect=exceptions.ShipmentNotFoundError(messages.Messages.not_found('shipment'))
)
service, 'delete_one', side_effect=exceptions.ShipmentNotFoundError(
messages.Messages.not_found('shipment')))
response = client.delete('/api/shipments/2')
response_data = response.json()
@@ -248,10 +288,16 @@ class TestShipments:
2,
)
def test_delete_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
shipment_body = shipment_factory.shipment_body_factory(name='test shipment delete')
shipment_body = shipment_factory.shipment_body_factory(
name='test shipment delete')
app.dependency_overrides[get_current_user] = unauthorized

View File

@@ -1,11 +1,11 @@
import src.users.service as service
import src.models as models
from src.main import app
from src.auth.auth import get_current_user
import tests.factories.users as user_factory
import src.users.exceptions as exceptions
import src.users.service as service
import tests.factories.users as user_factory
from fastapi.exceptions import HTTPException
from src import models
from src.auth.auth import get_current_user
from src.main import app
class TestUsers:
def test_get_all(self, client, mocker, mock_session, mock_user):
@@ -30,6 +30,7 @@ class TestUsers:
[],
[],
)
def test_get_all_filters(self, client, mocker, mock_session, mock_user):
mock_results = [
user_factory.user_public_factory(name="test 2", id=2),
@@ -51,7 +52,12 @@ class TestUsers:
['test@test.test'],
)
def test_get_all_unauthorized(self, client, mocker, mock_session, mock_user):
def test_get_all_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
@@ -99,7 +105,12 @@ class TestUsers:
2
)
def test_get_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_get_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
@@ -134,7 +145,12 @@ class TestUsers:
user_create
)
def test_create_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_create_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
user_body = user_factory.user_body_factory(name='test user create')
@@ -171,7 +187,12 @@ class TestUsers:
user_update
)
def test_update_one_notfound(self, client, mocker, mock_session, mock_user):
def test_update_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
user_body = user_factory.user_body_factory(name='test user update')
user_update = user_factory.user_update_factory(name='test user update')
user_result = None
@@ -192,7 +213,12 @@ class TestUsers:
user_update
)
def test_update_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_update_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
user_body = user_factory.user_body_factory(name='test user update')
@@ -226,7 +252,12 @@ class TestUsers:
2,
)
def test_delete_one_notfound(self, client, mocker, mock_session, mock_user):
def test_delete_one_notfound(
self,
client,
mocker,
mock_session,
mock_user):
user_result = None
mock = mocker.patch.object(
@@ -244,7 +275,12 @@ class TestUsers:
2,
)
def test_delete_one_unauthorized(self, client, mocker, mock_session, mock_user):
def test_delete_one_unauthorized(
self,
client,
mocker,
mock_session,
mock_user):
def unauthorized():
raise HTTPException(status_code=401)
user_body = user_factory.user_body_factory(name='test user delete')

View File

@@ -1,35 +1,41 @@
import pytest
from sqlmodel import Session
import src.models as models
import src.forms.service as forms_service
import src.forms.exceptions as forms_exceptions
import src.forms.service as forms_service
import tests.factories.forms as forms_factory
from sqlmodel import Session
from src import models
class TestFormsService:
def test_get_all_forms(self, session: Session, forms: list[models.FormPublic]):
def test_get_all_forms(self, session: Session,
forms: list[models.FormPublic]):
result = forms_service.get_all(session, [], [], False)
assert len(result) == 2
assert result == forms
def test_get_all_forms_filter_productors(self, session: Session, forms: list[models.FormPublic]):
def test_get_all_forms_filter_productors(
self, session: Session, forms: list[models.FormPublic]):
result = forms_service.get_all(session, [], ['test productor'], False)
assert len(result) == 2
assert result == forms
def test_get_all_forms_filter_season(self, session: Session, forms: list[models.FormPublic]):
def test_get_all_forms_filter_season(
self, session: Session, forms: list[models.FormPublic]):
result = forms_service.get_all(session, ['test season 1'], [], False)
assert len(result) == 1
def test_get_all_forms_all_filters(self, session: Session, forms: list[models.FormPublic]):
result = forms_service.get_all(session, ['test season 1'], ['test productor'], True)
def test_get_all_forms_all_filters(
self, session: Session, forms: list[models.FormPublic]):
result = forms_service.get_all(
session, ['test season 1'], ['test productor'], True)
assert result == forms
def test_get_one_form(self, session: Session, forms: list[models.FormPublic]):
def test_get_one_form(self, session: Session,
forms: list[models.FormPublic]):
result = forms_service.get_one(session, forms[0].id)
assert result == forms[0]
@@ -37,7 +43,7 @@ class TestFormsService:
def test_get_one_form_notfound(self, session: Session):
result = forms_service.get_one(session, 122)
assert result == None
assert result is None
def test_create_form(
self,
@@ -77,7 +83,6 @@ class TestFormsService:
with pytest.raises(forms_exceptions.UserNotFoundError):
result = forms_service.create_one(session, form_create)
def test_update_form(
self,
session: Session,
@@ -141,7 +146,7 @@ class TestFormsService:
result = forms_service.delete_one(session, form_id)
check = forms_service.get_one(session, form_id)
assert check == None
assert check is None
def test_delete_form_notfound(
self,
@@ -151,4 +156,3 @@ class TestFormsService:
form_id = 123
with pytest.raises(forms_exceptions.FormNotFoundError):
result = forms_service.delete_one(session, form_id)

View File

@@ -1,10 +1,10 @@
import pytest
from sqlmodel import Session
import src.models as models
import src.productors.service as productors_service
import src.productors.exceptions as productors_exceptions
import src.productors.service as productors_service
import tests.factories.productors as productors_factory
from sqlmodel import Session
from src import models
class TestProductorsService:
def test_get_all_productors(
@@ -63,7 +63,9 @@ class TestProductorsService:
assert len(result) == 1
def test_get_one_productor(self, session: Session, productors: list[models.ProductorPublic]):
def test_get_one_productor(self,
session: Session,
productors: list[models.ProductorPublic]):
result = productors_service.get_one(session, productors[0].id)
assert result == productors[0]
@@ -71,7 +73,7 @@ class TestProductorsService:
def test_get_one_productor_notfound(self, session: Session):
result = productors_service.get_one(session, 122)
assert result == None
assert result is None
def test_create_productor(
self,
@@ -104,7 +106,8 @@ class TestProductorsService:
name='updated test productor',
)
productor_id = productors[0].id
result = productors_service.update_one(session, productor_id, productor_update)
result = productors_service.update_one(
session, productor_id, productor_update)
assert result.id == productor_id
assert result.name == 'updated test productor'
@@ -119,7 +122,8 @@ class TestProductorsService:
)
productor_id = 123
with pytest.raises(productors_exceptions.ProductorNotFoundError):
result = productors_service.update_one(session, productor_id, productor_update)
result = productors_service.update_one(
session, productor_id, productor_update)
def test_delete_productor(
self,
@@ -130,7 +134,7 @@ class TestProductorsService:
result = productors_service.delete_one(session, productor_id)
check = productors_service.get_one(session, productor_id)
assert check == None
assert check is None
def test_delete_productor_notfound(
self,
@@ -140,4 +144,3 @@ class TestProductorsService:
productor_id = 123
with pytest.raises(productors_exceptions.ProductorNotFoundError):
result = productors_service.delete_one(session, productor_id)

View File

@@ -1,10 +1,10 @@
import pytest
from sqlmodel import Session
import src.models as models
import src.products.service as products_service
import src.products.exceptions as products_exceptions
import src.products.service as products_service
import tests.factories.products as products_factory
from sqlmodel import Session
from src import models
class TestProductsService:
def test_get_all_products(
@@ -83,7 +83,8 @@ class TestProductsService:
assert len(result) == 1
def test_get_one_product(self, session: Session, products: list[models.ProductPublic]):
def test_get_one_product(self, session: Session,
products: list[models.ProductPublic]):
result = products_service.get_one(session, products[0].id)
assert result == products[0]
@@ -91,7 +92,7 @@ class TestProductsService:
def test_get_one_product_notfound(self, session: Session):
result = products_service.get_one(session, 122)
assert result == None
assert result is None
def test_create_product(
self,
@@ -118,7 +119,8 @@ class TestProductsService:
with pytest.raises(products_exceptions.ProductCreateError):
result = products_service.create_one(session, product_create)
product_create = products_factory.product_create_factory(productor_id=123)
product_create = products_factory.product_create_factory(
productor_id=123)
with pytest.raises(products_exceptions.ProductorNotFoundError):
result = products_service.create_one(session, product_create)
@@ -134,7 +136,8 @@ class TestProductsService:
productor_id=productor.id,
)
product_id = products[0].id
result = products_service.update_one(session, product_id, product_update)
result = products_service.update_one(
session, product_id, product_update)
assert result.id == product_id
assert result.name == 'updated test product'
@@ -151,7 +154,8 @@ class TestProductsService:
)
product_id = 123
with pytest.raises(products_exceptions.ProductNotFoundError):
result = products_service.update_one(session, product_id, product_update)
result = products_service.update_one(
session, product_id, product_update)
def test_update_product_invalidinput(
self,
@@ -160,9 +164,11 @@ class TestProductsService:
products: list[models.ProductPublic]
):
product_id = products[0].id
product_update = products_factory.product_update_factory(productor_id=123)
product_update = products_factory.product_update_factory(
productor_id=123)
with pytest.raises(products_exceptions.ProductorNotFoundError):
result = products_service.update_one(session, product_id, product_update)
result = products_service.update_one(
session, product_id, product_update)
product_update = products_factory.product_update_factory(
productor_id=productor.id,
@@ -178,7 +184,7 @@ class TestProductsService:
result = products_service.delete_one(session, product_id)
check = products_service.get_one(session, product_id)
assert check == None
assert check is None
def test_delete_product_notfound(
self,
@@ -188,4 +194,3 @@ class TestProductsService:
product_id = 123
with pytest.raises(products_exceptions.ProductNotFoundError):
result = products_service.delete_one(session, product_id)

View File

@@ -1,11 +1,12 @@
import pytest
import datetime
from sqlmodel import Session
import src.models as models
import src.shipments.service as shipments_service
import pytest
import src.shipments.exceptions as shipments_exceptions
import src.shipments.service as shipments_service
import tests.factories.shipments as shipments_factory
from sqlmodel import Session
from src import models
class TestShipmentsService:
def test_get_all_shipments(
@@ -25,7 +26,8 @@ class TestShipmentsService:
shipments: list[models.ShipmentPublic],
user: models.UserPublic,
):
result = shipments_service.get_all(session, user, ['test shipment 1'], [], [])
result = shipments_service.get_all(
session, user, ['test shipment 1'], [], [])
assert len(result) == 1
assert result == [shipments[0]]
@@ -36,7 +38,8 @@ class TestShipmentsService:
shipments: list[models.ShipmentPublic],
user: models.UserPublic,
):
result = shipments_service.get_all(session, user, [], ['2025-10-10'], [])
result = shipments_service.get_all(
session, user, [], ['2025-10-10'], [])
assert len(result) == 1
@@ -47,7 +50,8 @@ class TestShipmentsService:
forms: list[models.FormPublic],
user: models.UserPublic,
):
result = shipments_service.get_all(session, user, [], [], [forms[0].name])
result = shipments_service.get_all(
session, user, [], [], [forms[0].name])
assert len(result) == 2
@@ -58,11 +62,13 @@ class TestShipmentsService:
forms: list[models.FormPublic],
user: models.UserPublic,
):
result = shipments_service.get_all(session, user, ['test shipment 1'], ['2025-10-10'], [forms[0].name])
result = shipments_service.get_all(session, user, ['test shipment 1'], [
'2025-10-10'], [forms[0].name])
assert len(result) == 1
def test_get_one_shipment(self, session: Session, shipments: list[models.ShipmentPublic]):
def test_get_one_shipment(self, session: Session,
shipments: list[models.ShipmentPublic]):
result = shipments_service.get_one(session, shipments[0].id)
assert result == shipments[0]
@@ -70,7 +76,7 @@ class TestShipmentsService:
def test_get_one_shipment_notfound(self, session: Session):
result = shipments_service.get_one(session, 122)
assert result == None
assert result is None
def test_create_shipment(
self,
@@ -103,7 +109,8 @@ class TestShipmentsService:
date='2025-12-10',
)
shipment_id = shipments[0].id
result = shipments_service.update_one(session, shipment_id, shipment_update)
result = shipments_service.update_one(
session, shipment_id, shipment_update)
assert result.id == shipment_id
assert result.name == 'updated shipment 1'
@@ -119,7 +126,8 @@ class TestShipmentsService:
)
shipment_id = 123
with pytest.raises(shipments_exceptions.ShipmentNotFoundError):
result = shipments_service.update_one(session, shipment_id, shipment_update)
result = shipments_service.update_one(
session, shipment_id, shipment_update)
def test_delete_shipment(
self,
@@ -130,7 +138,7 @@ class TestShipmentsService:
result = shipments_service.delete_one(session, shipment_id)
check = shipments_service.get_one(session, shipment_id)
assert check == None
assert check is None
def test_delete_shipment_notfound(
self,
@@ -140,4 +148,3 @@ class TestShipmentsService:
shipment_id = 123
with pytest.raises(shipments_exceptions.ShipmentNotFoundError):
result = shipments_service.delete_one(session, shipment_id)

View File

@@ -1,35 +1,41 @@
import pytest
from sqlmodel import Session
import src.models as models
import src.users.service as users_service
import src.users.exceptions as users_exceptions
import src.users.service as users_service
import tests.factories.users as users_factory
from sqlmodel import Session
from src import models
class TestUsersService:
def test_get_all_users(self, session: Session, users: list[models.UserPublic]):
def test_get_all_users(self, session: Session,
users: list[models.UserPublic]):
result = users_service.get_all(session, [], [])
assert len(result) == 3
assert result == users
def test_get_all_users_filter_names(self, session: Session, users: list[models.UserPublic]):
def test_get_all_users_filter_names(
self, session: Session, users: list[models.UserPublic]):
result = users_service.get_all(session, ['test user 1 (admin)'], [])
assert len(result) == 1
assert result == [users[0]]
def test_get_all_users_filter_emails(self, session: Session, users: list[models.UserPublic]):
def test_get_all_users_filter_emails(
self, session: Session, users: list[models.UserPublic]):
result = users_service.get_all(session, [], ['test1@test.com'])
assert len(result) == 1
def test_get_all_users_all_filters(self, session: Session, users: list[models.UserPublic]):
result = users_service.get_all(session, ['test user 1 (admin)'], ['test1@test.com'])
def test_get_all_users_all_filters(
self, session: Session, users: list[models.UserPublic]):
result = users_service.get_all(
session, ['test user 1 (admin)'], ['test1@test.com'])
assert len(result) == 1
def test_get_one_user(self, session: Session, users: list[models.UserPublic]):
def test_get_one_user(self, session: Session,
users: list[models.UserPublic]):
result = users_service.get_one(session, users[0].id)
assert result == users[0]
@@ -37,7 +43,7 @@ class TestUsersService:
def test_get_one_user_notfound(self, session: Session):
result = users_service.get_one(session, 122)
assert result == None
assert result is None
def test_create_user(
self,
@@ -102,7 +108,7 @@ class TestUsersService:
result = users_service.delete_one(session, user_id)
check = users_service.get_one(session, user_id)
assert check == None
assert check is None
def test_delete_user_notfound(
self,
@@ -112,4 +118,3 @@ class TestUsersService:
user_id = 123
with pytest.raises(users_exceptions.UserNotFoundError):
result = users_service.delete_one(session, user_id)

View File

@@ -19,7 +19,7 @@ import {
type ProductorInputs,
} from "@/services/resources/productors";
import { useMemo } from "react";
import { useGetRoles } from "@/services/api";
import { useAuth } from "@/services/auth/AuthProvider";
export type ProductorModalProps = ModalBaseProps & {
currentProductor?: Productor;
@@ -32,7 +32,7 @@ export function ProductorModal({
currentProductor,
handleSubmit,
}: ProductorModalProps) {
const { data: allRoles } = useGetRoles();
const { loggedUser } = useAuth();
const form = useForm<ProductorInputs>({
initialValues: {
@@ -58,8 +58,8 @@ export function ProductorModal({
});
const roleSelect = useMemo(() => {
return allRoles?.map((role) => ({ value: String(role.name), label: role.name }));
}, [allRoles]);
return loggedUser?.user?.roles?.map((role) => ({ value: String(role.name), label: role.name }));
}, [loggedUser?.user?.roles]);
return (
<Modal opened={opened} onClose={onClose} title={t("create productor", { capfirst: true })}>