diff --git a/backend/.pre-commit-config.yaml b/backend/.pre-commit-config.yaml new file mode 100644 index 0000000..9028aa1 --- /dev/null +++ b/backend/.pre-commit-config.yaml @@ -0,0 +1,26 @@ +default_language_version: + python: python3.13 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 + hooks: + - id: check-added-large-files + - id: trailing-whitespace + - id: check-ast + - id: check-builtin-literals + - id: check-docstring-first + - id: check-yaml + - id: check-toml + - id: mixed-line-ending + - id: end-of-file-fixer + - repo: local + hooks: + - id: check-pylint + name: check-pylint + entry: pylint -d R0801,R0903,W0511,W0603,C0103,R0902 + language: system + types: [python] + pass_filenames: false + args: + - backend diff --git a/backend/README.md b/backend/README.md index 9264553..e2b258c 100644 --- a/backend/README.md +++ b/backend/README.md @@ -35,6 +35,12 @@ hatch run pytest hatch run pytest --cov=src -vv ``` +## Autoformat +```console +find -type f -name '*.py' ! -path 'alembic/*' -exec autopep8 --in-place --aggressive --aggressive '{}' \; +pylint -d R0801,R0903,W0511,W0603,C0103,R0902 . +``` + ## License `backend` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license. diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 7c92e0d..ddef843 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -25,7 +25,8 @@ target_metadata = SQLModel.metadata # other values from the config, defined by the needs of env.py, # can be acquired: -config.set_main_option("sqlalchemy.url", f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}') +config.set_main_option( + "sqlalchemy.url", f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}') # ... etc. diff --git a/backend/alembic/versions/7854064278ce_message.py b/backend/alembic/versions/7854064278ce_message.py index 9521bd5..7c655fd 100644 --- a/backend/alembic/versions/7854064278ce_message.py +++ b/backend/alembic/versions/7854064278ce_message.py @@ -22,7 +22,12 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.add_column('paymentmethod', sa.Column('max', sa.Integer(), nullable=True)) + op.add_column( + 'paymentmethod', + sa.Column( + 'max', + sa.Integer(), + nullable=True)) # ### end Alembic commands ### diff --git a/backend/alembic/versions/c0b1073a8394_initial_repository.py b/backend/alembic/versions/c0b1073a8394_initial_repository.py index 0d280fb..c7af2e7 100644 --- a/backend/alembic/versions/c0b1073a8394_initial_repository.py +++ b/backend/alembic/versions/c0b1073a8394_initial_repository.py @@ -1,7 +1,7 @@ """Initial repository Revision ID: c0b1073a8394 -Revises: +Revises: Create Date: 2026-02-20 00:09:35.920486 """ @@ -22,117 +22,121 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.create_table('contracttype', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('productor', - sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('address', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('id', sa.Integer(), nullable=False), - sa.PrimaryKeyConstraint('id') - ) + op.create_table( + 'contracttype', + sa.Column( + 'id', + sa.Integer(), + nullable=False), + sa.Column( + 'name', + sqlmodel.sql.sqltypes.AutoString(), + nullable=False), + sa.PrimaryKeyConstraint('id')) + op.create_table( + 'productor', sa.Column( + 'name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column( + 'address', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column( + 'type', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column( + 'id', sa.Integer(), nullable=False), sa.PrimaryKeyConstraint('id')) op.create_table('template', - sa.Column('id', sa.Integer(), nullable=False), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('user', - sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('id', sa.Integer(), nullable=False), - sa.PrimaryKeyConstraint('id') - ) + sa.Column('id', sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table( + 'user', sa.Column( + 'name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column( + 'email', sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column( + 'id', sa.Integer(), nullable=False), sa.PrimaryKeyConstraint('id')) op.create_table('form', - sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('productor_id', sa.Integer(), nullable=True), - sa.Column('referer_id', sa.Integer(), nullable=True), - sa.Column('season', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('start', sa.Date(), nullable=False), - sa.Column('end', sa.Date(), nullable=False), - sa.Column('minimum_shipment_value', sa.Float(), nullable=True), - sa.Column('id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ), - sa.ForeignKeyConstraint(['referer_id'], ['user.id'], ), - sa.PrimaryKeyConstraint('id') - ) + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('productor_id', sa.Integer(), nullable=True), + sa.Column('referer_id', sa.Integer(), nullable=True), + sa.Column('season', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('start', sa.Date(), nullable=False), + sa.Column('end', sa.Date(), nullable=False), + sa.Column('minimum_shipment_value', sa.Float(), nullable=True), + sa.Column('id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ), + sa.ForeignKeyConstraint(['referer_id'], ['user.id'], ), + sa.PrimaryKeyConstraint('id') + ) op.create_table('paymentmethod', - sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('details', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('productor_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') - ) + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('details', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('productor_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) op.create_table('product', - sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('unit', sa.Enum('GRAMS', 'KILO', 'PIECE', name='unit'), nullable=False), - sa.Column('price', sa.Float(), nullable=True), - sa.Column('price_kg', sa.Float(), nullable=True), - sa.Column('quantity', sa.Float(), nullable=True), - sa.Column('quantity_unit', sqlmodel.sql.sqltypes.AutoString(), nullable=True), - sa.Column('type', sa.Enum('OCCASIONAL', 'RECCURENT', name='producttype'), nullable=False), - sa.Column('productor_id', sa.Integer(), nullable=True), - sa.Column('id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('usercontracttypelink', - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('contract_type_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['contract_type_id'], ['contracttype.id'], ), - sa.ForeignKeyConstraint(['user_id'], ['user.id'], ), - sa.PrimaryKeyConstraint('user_id', 'contract_type_id') - ) + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('unit', sa.Enum('GRAMS', 'KILO', 'PIECE', name='unit'), nullable=False), + sa.Column('price', sa.Float(), nullable=True), + sa.Column('price_kg', sa.Float(), nullable=True), + sa.Column('quantity', sa.Float(), nullable=True), + sa.Column('quantity_unit', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('type', sa.Enum('OCCASIONAL', 'RECCURENT', name='producttype'), nullable=False), + sa.Column('productor_id', sa.Integer(), nullable=True), + sa.Column('id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['productor_id'], ['productor.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table( + 'usercontracttypelink', sa.Column( + 'user_id', sa.Integer(), nullable=False), sa.Column( + 'contract_type_id', sa.Integer(), nullable=False), sa.ForeignKeyConstraint( + ['contract_type_id'], ['contracttype.id'], ), sa.ForeignKeyConstraint( + ['user_id'], ['user.id'], ), sa.PrimaryKeyConstraint( + 'user_id', 'contract_type_id')) op.create_table('contract', - sa.Column('firstname', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('lastname', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('phone', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('payment_method', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('cheque_quantity', sa.Integer(), nullable=False), - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('form_id', sa.Integer(), nullable=False), - sa.Column('file', sa.LargeBinary(), nullable=True), - sa.Column('total_price', sa.Float(), nullable=True), - sa.ForeignKeyConstraint(['form_id'], ['form.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') - ) + sa.Column('firstname', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('lastname', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('phone', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('payment_method', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('cheque_quantity', sa.Integer(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('form_id', sa.Integer(), nullable=False), + sa.Column('file', sa.LargeBinary(), nullable=True), + sa.Column('total_price', sa.Float(), nullable=True), + sa.ForeignKeyConstraint(['form_id'], ['form.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) op.create_table('shipment', - sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('date', sa.Date(), nullable=False), - sa.Column('form_id', sa.Integer(), nullable=True), - sa.Column('id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['form_id'], ['form.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') - ) + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('date', sa.Date(), nullable=False), + sa.Column('form_id', sa.Integer(), nullable=True), + sa.Column('id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['form_id'], ['form.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) op.create_table('cheque', - sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('value', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('contract_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['contract_id'], ['contract.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') - ) + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('value', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('contract_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['contract_id'], ['contract.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) op.create_table('contractproduct', - sa.Column('product_id', sa.Integer(), nullable=False), - sa.Column('shipment_id', sa.Integer(), nullable=True), - sa.Column('quantity', sa.Float(), nullable=False), - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('contract_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['contract_id'], ['contract.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['product_id'], ['product.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['shipment_id'], ['shipment.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') - ) + sa.Column('product_id', sa.Integer(), nullable=False), + sa.Column('shipment_id', sa.Integer(), nullable=True), + sa.Column('quantity', sa.Float(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('contract_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['contract_id'], ['contract.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['product_id'], ['product.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['shipment_id'], ['shipment.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) op.create_table('shipmentproductlink', - sa.Column('shipment_id', sa.Integer(), nullable=False), - sa.Column('product_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['product_id'], ['product.id'], ), - sa.ForeignKeyConstraint(['shipment_id'], ['shipment.id'], ), - sa.PrimaryKeyConstraint('shipment_id', 'product_id') - ) + sa.Column('shipment_id', sa.Integer(), nullable=False), + sa.Column('product_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['product_id'], ['product.id'], ), + sa.ForeignKeyConstraint(['shipment_id'], ['shipment.id'], ), + sa.PrimaryKeyConstraint('shipment_id', 'product_id') + ) # ### end Alembic commands ### diff --git a/backend/alembic/versions/e777ed5729ce_message.py b/backend/alembic/versions/e777ed5729ce_message.py index d2fa32c..b1b1238 100644 --- a/backend/alembic/versions/e777ed5729ce_message.py +++ b/backend/alembic/versions/e777ed5729ce_message.py @@ -22,7 +22,14 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.add_column('form', sa.Column('visible', sa.Boolean(), nullable=False, default=False, server_default="False")) + op.add_column( + 'form', + sa.Column( + 'visible', + sa.Boolean(), + nullable=False, + default=False, + server_default="False")) # ### end Alembic commands ### diff --git a/backend/pyproject.toml b/backend/pyproject.toml index d7b8006..48506bb 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -34,6 +34,9 @@ dependencies = [ "pytest", "pytest-cov", "pytest-mock", + "autopep8", + "prek", + "pylint", ] [project.urls] diff --git a/backend/requirements.txt b/backend/requirements.txt index e69de29..643296f 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -0,0 +1,84 @@ +alembic==1.18.4 +annotated-doc==0.0.4 +annotated-types==0.7.0 +anyio==4.12.1 +astroid==4.0.4 +autopep8==2.3.2 +brotli==1.2.0 +certifi==2026.2.25 +cffi==2.0.0 +charset-normalizer==3.4.4 +click==8.3.1 +coverage==7.13.4 +cryptography==46.0.5 +cssselect2==0.9.0 +dill==0.4.1 +dnspython==2.8.0 +email-validator==2.3.0 +fastapi==0.135.1 +fastapi-cli==0.0.24 +fastapi-cloud-cli==0.14.0 +fastar==0.8.0 +fonttools==4.61.1 +greenlet==3.3.2 +h11==0.16.0 +httpcore==1.0.9 +httptools==0.7.1 +httpx==0.28.1 +idna==3.11 +iniconfig==2.3.0 +isort==8.0.1 +Jinja2==3.1.6 +lxml==6.0.2 +Mako==1.3.10 +markdown-it-py==4.0.0 +MarkupSafe==3.0.3 +mccabe==0.7.0 +mdurl==0.1.2 +odfdo==3.21.0 +packaging==26.0 +pillow==12.1.1 +platformdirs==4.9.2 +pluggy==1.6.0 +prek==0.3.4 +psycopg2-binary==2.9.11 +pycodestyle==2.14.0 +pycparser==3.0 +pydantic==2.12.5 +pydantic-extra-types==2.11.0 +pydantic-settings==2.13.1 +pydantic_core==2.41.5 +pydyf==0.12.1 +Pygments==2.19.2 +PyJWT==2.11.0 +pylint==4.0.5 +pyphen==0.17.2 +pytest==9.0.2 +pytest-cov==7.0.0 +pytest-mock==3.15.1 +python-dotenv==1.2.2 +python-multipart==0.0.22 +PyYAML==6.0.3 +requests==2.32.5 +rich==14.3.3 +rich-toolkit==0.19.7 +rignore==0.7.6 +sentry-sdk==2.53.0 +shellingham==1.5.4 +SQLAlchemy==2.0.47 +sqlmodel==0.0.37 +starlette==0.52.1 +tinycss2==1.5.1 +tinyhtml5==2.0.0 +tomlkit==0.14.0 +typer==0.24.1 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +urllib3==2.6.3 +uvicorn==0.41.0 +uvloop==0.22.1 +watchfiles==1.1.1 +weasyprint==68.1 +webencodings==0.5.1 +websockets==16.0 +zopfli==0.4.1 diff --git a/backend/src/__init__.py b/backend/src/__init__.py index 10fe5b0..e9a63bc 100644 --- a/backend/src/__init__.py +++ b/backend/src/__init__.py @@ -1,3 +1,3 @@ # SPDX-FileCopyrightText: 2026-present Julien Aldon # -# SPDX-License-Identifier: MIT \ No newline at end of file +# SPDX-License-Identifier: MIT diff --git a/backend/src/auth/auth.py b/backend/src/auth/auth.py index 6abb4ba..02d90e4 100644 --- a/backend/src/auth/auth.py +++ b/backend/src/auth/auth.py @@ -21,6 +21,7 @@ router = APIRouter(prefix='/auth') jwk_client = PyJWKClient(JWKS_URL) security = HTTPBearer() + @router.get('/logout') def logout(): params = { @@ -59,9 +60,11 @@ def login(): 'redirect_uri': settings.keycloak_redirect_uri, 'state': state, } - request_url = requests.Request('GET', AUTH_URL, params=params).prepare().url + request_url = requests.Request( + 'GET', AUTH_URL, params=params).prepare().url return RedirectResponse(request_url) + @router.get('/callback') def callback(code: str, session: Session = Depends(get_session)): data = { @@ -82,10 +85,12 @@ def callback(code: str, session: Session = Depends(get_session)): ) token_data = response.json() - + id_token = token_data['id_token'] decoded_token = jwt.decode(id_token, options={'verify_signature': False}) - decoded_access_token = jwt.decode(token_data['access_token'], options={'verify_signature': False}) + decoded_access_token = jwt.decode( + token_data['access_token'], options={ + 'verify_signature': False}) resource_access = decoded_access_token.get('resource_access') if not resource_access: data = { @@ -141,6 +146,7 @@ def callback(code: str, session: Session = Depends(get_session)): return response + def verify_token(token: str): try: signing_key = jwk_client.get_signing_key_from_jwt(token) @@ -154,28 +160,37 @@ def verify_token(token: str): ) return decoded except jwt.ExpiredSignatureError: - raise HTTPException(status_code=401, detail=messages.Messages.tokenexipired) + raise HTTPException(status_code=401, + detail=messages.Messages.tokenexipired) except jwt.InvalidTokenError: - raise HTTPException(status_code=401, detail=messages.Messages.invalidtoken) + raise HTTPException( + status_code=401, + detail=messages.Messages.invalidtoken) -def get_current_user(request: Request, session: Session = Depends(get_session)): +def get_current_user( + request: Request, + session: Session = Depends(get_session)): access_token = request.cookies.get('access_token') if not access_token: - raise HTTPException(status_code=401, detail=messages.Messages.notauthenticated) + raise HTTPException(status_code=401, + detail=messages.Messages.notauthenticated) payload = verify_token(access_token) if not payload: raise HTTPException(status_code=401, detail='aze') email = payload.get('email') if not email: - raise HTTPException(status_code=401, detail=messages.Messages.notauthenticated) + raise HTTPException(status_code=401, + detail=messages.Messages.notauthenticated) user = session.exec(select(User).where(User.email == email)).first() if not user: - raise HTTPException(status_code=401, detail=messages.Messages.not_found('user')) + raise HTTPException(status_code=401, + detail=messages.Messages.not_found('user')) return user + @router.post('/refresh') def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None): refresh = refresh_token @@ -223,6 +238,7 @@ def refresh_token(refresh_token: Annotated[str | None, Cookie()] = None): ) return response + @router.get('/user/me') def me(user: UserPublic = Depends(get_current_user)): if not user: @@ -235,4 +251,4 @@ def me(user: UserPublic = Depends(get_current_user)): 'id': user.id, 'roles': [role.name for role in user.roles] } - } \ No newline at end of file + } diff --git a/backend/src/contracts/contracts.py b/backend/src/contracts/contracts.py index 2d3082c..f2b13e5 100644 --- a/backend/src/contracts/contracts.py +++ b/backend/src/contracts/contracts.py @@ -1,18 +1,27 @@ -from fastapi import APIRouter, Depends, HTTPException, Query -from fastapi.responses import StreamingResponse -from src.database import get_session -from sqlmodel import Session -from src.contracts.generate_contract import generate_html_contract, generate_recap -from src.auth.auth import get_current_user -import src.models as models -import src.messages as messages -import src.contracts.service as service -import src.forms.service as form_service +"""Router for contract resource""" import io import zipfile + +import src.contracts.service as service +import src.forms.service as form_service +import src.messages as messages +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.responses import StreamingResponse +from sqlmodel import Session +from src import models +from src.auth.auth import get_current_user +from src.contracts.generate_contract import (generate_html_contract, + generate_recap) +from src.database import get_session + router = APIRouter(prefix='/contracts') -def compute_recurrent_prices(products_quantities: list[dict], nb_shipment: int): + +def compute_recurrent_prices( + products_quantities: list[dict], + nb_shipment: int +): + """Compute price for recurrent products""" result = 0 for product_quantity in products_quantities: product = product_quantity['product'] @@ -20,30 +29,50 @@ def compute_recurrent_prices(products_quantities: list[dict], nb_shipment: int): result += compute_product_price(product, quantity, nb_shipment) return result + def compute_occasional_prices(occasionals: list[dict]): + """Compute prices for occassional products""" result = 0 for occasional in occasionals: result += occasional['price'] return result -def compute_product_price(product: models.Product, quantity: int, nb_shipment: int = 1): - product_quantity_unit = 1 if product.unit == models.Unit.KILO else 1000 - final_quantity = quantity if product.price else quantity / product_quantity_unit - final_price = product.price if product.price else product.price_kg - return final_price * final_quantity * nb_shipment + +def compute_product_price( + product: models.Product, + quantity: int, + nb_shipment: int = 1 +): + """Compute price for a product""" + product_quantity_unit = ( + 1 if product.unit == models.Unit.KILO else 1000 + ) + final_quantity = ( + quantity if product.price else quantity / product_quantity_unit + ) + final_price = ( + product.price if product.price else product.price_kg + ) + return final_price * final_quantity * nb_shipment + def find_dict_in_list(lst, key, value): + """Find the index of a dictionnary in a list of dictionnaries given a key + and a value. + """ for i, dic in enumerate(lst): if dic[key].id == value: return i return -1 + def create_occasional_dict(contract_products: list[models.ContractProduct]): + """Create a dictionnary of occasional products""" result = [] for contract_product in contract_products: existing_id = find_dict_in_list( - result, - 'shipment', + result, + 'shipment', contract_product.shipment.id ) if existing_id < 0: @@ -69,18 +98,46 @@ def create_occasional_dict(contract_products: list[models.ContractProduct]): ) return result + @router.post('') async def create_contract( contract: models.ContractCreate, session: Session = Depends(get_session), ): + """Create contract route""" new_contract = service.create_one(session, contract) - occasional_contract_products = list(filter(lambda contract_product: contract_product.product.type == models.ProductType.OCCASIONAL, new_contract.products)) + occasional_contract_products = list( + filter( + lambda contract_product: ( + contract_product.product.type == models.ProductType.OCCASIONAL + ), + new_contract.products + ) + ) occasionals = create_occasional_dict(occasional_contract_products) - recurrents = list(map(lambda x: {"product": x.product, "quantity": x.quantity}, filter(lambda contract_product: contract_product.product.type == models.ProductType.RECCURENT, new_contract.products))) - recurrent_price = compute_recurrent_prices(recurrents, len(new_contract.form.shipments)) + recurrents = list( + map( + lambda x: {'product': x.product, 'quantity': x.quantity}, + filter( + lambda contract_product: ( + contract_product.product.type == + models.ProductType.RECCURENT + ), + new_contract.products + ) + ) + ) + recurrent_price = compute_recurrent_prices( + recurrents, + len(new_contract.form.shipments) + ) price = recurrent_price + compute_occasional_prices(occasionals) - cheques = list(map(lambda x: {"name": x.name, "value": x.value}, new_contract.cheques)) + cheques = list( + map( + lambda x: {'name': x.name, 'value': x.value}, + new_contract.cheques + ) + ) try: pdf_bytes = generate_html_contract( new_contract, @@ -91,43 +148,63 @@ async def create_contract( '{:10.2f}'.format(price) ) pdf_file = io.BytesIO(pdf_bytes) - contract_id = f'{new_contract.firstname}_{new_contract.lastname}_{new_contract.form.productor.type}_{new_contract.form.season}' + contract_id = ( + f'{new_contract.firstname}_' + f'{new_contract.lastname}_' + f'{new_contract.form.productor.type}_' + f'{new_contract.form.season}' + ) service.add_contract_file(session, new_contract.id, pdf_bytes, price) - except Exception: - raise HTTPException(status_code=400, detail=messages.pdferror) + except Exception as error: + raise HTTPException( + status_code=400, + detail=messages.pdferror + ) from error return StreamingResponse( pdf_file, media_type='application/pdf', headers={ - 'Content-Disposition': f'attachment; filename=contract_{contract_id}.pdf' + 'Content-Disposition': ( + f'attachment; filename=contract_{contract_id}.pdf' + ) } ) + @router.get('/{form_id}/base') async def get_base_contract_template( form_id: int, session: Session = Depends(get_session), ): + """Get contract template route""" form = form_service.get_one(session, form_id) - recurrents = list(map(lambda x: {"product": x, "quantity": None}, filter(lambda product: product.type == models.ProductType.RECCURENT, form.productor.products))) + recurrents = [ + {'product': product, 'quantity': None} + for product in form.productor.products + if product.type == models.ProductType.RECCURENT + ] occasionals = [{ - 'shipment': sh, - 'price': None, + 'shipment': sh, + 'price': None, 'products': [{'product': pr, 'quantity': None} for pr in sh.products] } for sh in form.shipments] empty_contract = models.ContractPublic( - firstname="", + firstname='', form=form, - lastname="", - email="", - phone="", + lastname='', + email='', + phone='', products=[], - payment_method="cheque", + payment_method='cheque', cheque_quantity=3, total_price=0, id=1 ) - cheques = [{"name": None, "value": None}, {"name": None, "value": None}, {"name": None, "value": None}] + cheques = [ + {'name': None, 'value': None}, + {'name': None, 'value': None}, + {'name': None, 'value': None} + ] try: pdf_bytes = generate_html_contract( empty_contract, @@ -136,45 +213,68 @@ async def get_base_contract_template( recurrents, ) pdf_file = io.BytesIO(pdf_bytes) - contract_id = f'{empty_contract.form.productor.type}_{empty_contract.form.season}' - except Exception as e: - print(e) - raise HTTPException(status_code=400, detail=messages.pdferror) + contract_id = ( + f'{empty_contract.form.productor.type}_' + f'{empty_contract.form.season}' + ) + except Exception as error: + raise HTTPException( + status_code=400, + detail=messages.pdferror + ) from error return StreamingResponse( pdf_file, media_type='application/pdf', headers={ - 'Content-Disposition': f'attachment; filename=contract_{contract_id}.pdf' + 'Content-Disposition': ( + f'attachment; filename=contract_{contract_id}.pdf' + ) } ) + @router.get('', response_model=list[models.ContractPublic]) def get_contracts( forms: list[str] = Query([]), session: Session = Depends(get_session), user: models.User = Depends(get_current_user) ): + """Get all contracts route""" return service.get_all(session, user, forms) -@router.get('/{id}/file') + +@router.get('/{_id}/file') def get_contract_file( - id: int, + _id: int, session: Session = Depends(get_session), user: models.User = Depends(get_current_user) ): - if not service.is_allowed(session, user, id): - raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contract', 'get')) - contract = service.get_one(session, id) + """Get a contract file (in pdf) route""" + if not service.is_allowed(session, user, _id): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('contract', 'get') + ) + contract = service.get_one(session, _id) if contract is None: - raise HTTPException(status_code=404, detail=messages.Messages.not_found('contract')) - filename = f'{contract.form.name.replace(' ', '_')}_{contract.form.season}_{contract.firstname}-{contract.lastname}' + raise HTTPException( + status_code=404, + detail=messages.Messages.not_found('contract') + ) + filename = ( + f'{contract.form.name.replace(' ', '_')}_' + f'{contract.form.season}_' + f'{contract.firstname}_' + f'{contract.lastname}' + ) return StreamingResponse( io.BytesIO(contract.file), media_type='application/pdf', headers={ 'Content-Disposition': f'attachment; filename={filename}.pdf' } - ) + ) + @router.get('/{form_id}/files') def get_contract_files( @@ -182,17 +282,30 @@ def get_contract_files( session: Session = Depends(get_session), user: models.User = Depends(get_current_user) ): + """Get all contract files for a given form""" if not form_service.is_allowed(session, user, form_id): - raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contracts', 'get')) + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('contracts', 'get') + ) form = form_service.get_one(session, form_id=form_id) contracts = service.get_all(session, user, forms=[form.name]) zipped_contracts = io.BytesIO() - with zipfile.ZipFile(zipped_contracts, "a", zipfile.ZIP_DEFLATED, False) as zip_file: + with zipfile.ZipFile( + zipped_contracts, + 'a', + zipfile.ZIP_DEFLATED, + False + ) as zip_file: for contract in contracts: - contract_filename = f'{contract.form.name.replace(' ', '_')}_{contract.form.season}_{contract.firstname}-{contract.lastname}.pdf' + contract_filename = ( + f'{contract.form.name.replace(' ', '_')}_' + f'{contract.form.season}_' + f'{contract.firstname}_' + f'{contract.lastname}' + ) zip_file.writestr(contract_filename, contract.file) - - filename = f'{form.name.replace(" ", "_")}_{form.season}' + filename = f'{form.name.replace(' ', '_')}_{form.season}' return StreamingResponse( io.BytesIO(zipped_contracts.getvalue()), media_type='application/zip', @@ -201,39 +314,69 @@ def get_contract_files( } ) + @router.get('/{form_id}/recap') def get_contract_recap( form_id: int, session: Session = Depends(get_session), user: models.User = Depends(get_current_user) ): + """Get a contract recap for a given form""" if not form_service.is_allowed(session, user, form_id): - raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contract recap', 'get')) + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('contract recap', 'get') + ) form = form_service.get_one(session, form_id=form_id) contracts = service.get_all(session, user, forms=[form.name]) - return StreamingResponse( io.BytesIO(generate_recap(contracts, form)), media_type='application/zip', headers={ - 'Content-Disposition': f'attachment; filename=filename.ods' + 'Content-Disposition': ( + 'attachment; filename=filename.ods' + ) } ) -@router.get('/{id}', response_model=models.ContractPublic) -def get_contract(id: int, session: Session = Depends(get_session), user: models.User = Depends(get_current_user)): - if not service.is_allowed(session, user, id): - raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contract', 'get')) - result = service.get_one(session, id) + +@router.get('/{_id}', response_model=models.ContractPublic) +def get_contract( + _id: int, + session: Session = Depends(get_session), + user: models.User = Depends(get_current_user) +): + """Get a contract route""" + if not service.is_allowed(session, user, _id): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('contract', 'get') + ) + result = service.get_one(session, _id) if result is None: - raise HTTPException(status_code=404, detail=messages.Messages.not_found('contract')) + raise HTTPException( + status_code=404, + detail=messages.Messages.not_found('contract') + ) return result -@router.delete('/{id}', response_model=models.ContractPublic) -def delete_contract(id: int, session: Session = Depends(get_session), user: models.User = Depends(get_current_user)): - if not service.is_allowed(session, user, id): - raise HTTPException(status_code=403, detail=messages.Messages.not_allowed('contract', 'delete')) - result = service.delete_one(session, id) + +@router.delete('/{_id}', response_model=models.ContractPublic) +def delete_contract( + _id: int, + session: Session = Depends(get_session), + user: models.User = Depends(get_current_user) +): + """Delete contract route""" + if not service.is_allowed(session, user, _id): + raise HTTPException( + status_code=403, + detail=messages.Messages.not_allowed('contract', 'delete') + ) + result = service.delete_one(session, _id) if result is None: - raise HTTPException(status_code=404, detail=messages.Messages.not_found('contract')) + raise HTTPException( + status_code=404, + detail=messages.Messages.not_found('contract') + ) return result diff --git a/backend/src/contracts/generate_contract.py b/backend/src/contracts/generate_contract.py index 735b405..1fea28c 100644 --- a/backend/src/contracts/generate_contract.py +++ b/backend/src/contracts/generate_contract.py @@ -1,11 +1,13 @@ +import html +import io +import pathlib + import jinja2 -import src.models as models -import html +from odfdo import Cell, Document, Row, Table +from src import models from weasyprint import HTML -import io -import pathlib def generate_html_contract( contract: models.Contract, @@ -14,10 +16,11 @@ def generate_html_contract( reccurents: list[dict], recurrent_price: float | None = None, total_price: float | None = None -): +): template_dir = pathlib.Path("./src/contracts/templates").resolve() template_loader = jinja2.FileSystemLoader(searchpath=template_dir) - template_env = jinja2.Environment(loader=template_loader, autoescape=jinja2.select_autoescape(["html", "xml"])) + template_env = jinja2.Environment( + loader=template_loader, autoescape=jinja2.select_autoescape(["html", "xml"])) template_file = "layout.html" template = template_env.get_template(template_file) output_text = template.render( @@ -28,41 +31,36 @@ def generate_html_contract( referer_email=contract.form.referer.email, productor_name=contract.form.productor.name, productor_address=contract.form.productor.address, - payment_methods_map={"cheque": "Ordre du chèque", "transfer": "virements"}, + payment_methods_map={ + "cheque": "Ordre du chèque", + "transfer": "virements"}, productor_payment_methods=contract.form.productor.payment_methods, - member_name=f'{html.escape(contract.firstname)} {html.escape(contract.lastname)}', - member_email=html.escape(contract.email), - member_phone=html.escape(contract.phone), + member_name=f'{ + html.escape( + contract.firstname)} { + html.escape( + contract.lastname)}', + member_email=html.escape( + contract.email), + member_phone=html.escape( + contract.phone), contract_start_date=contract.form.start, contract_end_date=contract.form.end, occasionals=occasionals, recurrents=reccurents, recurrent_price=recurrent_price, total_price=total_price, - contract_payment_method={"cheque": "chèque", "transfer": "virements"}[contract.payment_method], - cheques=cheques - ) - # options = { - # 'page-size': 'Letter', - # 'margin-top': '0.5in', - # 'margin-right': '0.5in', - # 'margin-bottom': '0.5in', - # 'margin-left': '0.5in', - # 'encoding': "UTF-8", - # 'print-media-type': True, - # "disable-javascript": True, - # "disable-external-links": True, - # 'enable-local-file-access': False, - # "disable-local-file-access": True, - # "no-images": True, - # } + contract_payment_method={ + "cheque": "chèque", + "transfer": "virements"}[ + contract.payment_method], + cheques=cheques) return HTML( string=output_text, base_url=template_dir, ).write_pdf() -from odfdo import Document, Table, Row, Cell def generate_recap( contracts: list[models.Contract], @@ -76,9 +74,8 @@ def generate_recap( sheet.set_values(data) doc.body.append(sheet) - + buffer = io.BytesIO() doc.save(buffer) return buffer.getvalue() - diff --git a/backend/src/contracts/service.py b/backend/src/contracts/service.py index 835cc6a..79cb063 100644 --- a/backend/src/contracts/service.py +++ b/backend/src/contracts/service.py @@ -1,28 +1,57 @@ +"""Contract service responsible for read, create, update and delete contracts""" +from sqlalchemy.orm import selectinload from sqlmodel import Session, select -import src.models as models +from src import models + def get_all( session: Session, user: models.User, - forms: list[str] = [], + forms: list[str] | None = None, form_id: int | None = None, ) -> list[models.ContractPublic]: - statement = select(models.Contract)\ - .join(models.Form, models.Contract.form_id == models.Form.id)\ - .join(models.Productor, models.Form.productor_id == models.Productor.id)\ - .where(models.Productor.type.in_([r.name for r in user.roles]))\ + """Get all contracts""" + statement = ( + select(models.Contract) + .join( + models.Form, + models.Contract.form_id == models.Form.id + ) + .join( + models.Productor, + models.Form.productor_id == models.Productor.id + ) + .where( + models.Productor.type.in_( + [r.name for r in user.roles] + ) + ) .distinct() - if len(forms) > 0: + ) + if forms: statement = statement.where(models.Form.name.in_(forms)) if form_id: statement = statement.where(models.Form.id == form_id) return session.exec(statement.order_by(models.Contract.id)).all() -def get_one(session: Session, contract_id: int) -> models.ContractPublic: + +def get_one( + session: Session, + contract_id: int +) -> models.ContractPublic: + """Get one contract""" return session.get(models.Contract, contract_id) -def create_one(session: Session, contract: models.ContractCreate) -> models.ContractPublic: - contract_create = contract.model_dump(exclude_unset=True, exclude=["products", "cheques"]) + +def create_one( + session: Session, + contract: models.ContractCreate +) -> models.ContractPublic: + """Create one contract""" + contract_create = contract.model_dump( + exclude_unset=True, + exclude=["products", "cheques"] + ) new_contract = models.Contract(**contract_create) new_contract.cheques = [ @@ -45,10 +74,27 @@ def create_one(session: Session, contract: models.ContractCreate) -> models.Cont session.add(new_contract) session.commit() session.refresh(new_contract) - return new_contract -def add_contract_file(session: Session, id: int, file: bytes, price: float): - statement = select(models.Contract).where(models.Contract.id == id) + statement = ( + select(models.Contract) + .where(models.Contract.id == new_contract.id) + .options( + selectinload(models.Contract.form) + .selectinload(models.Form.productor) + ) + ) + + return session.exec(statement).one() + + +def add_contract_file( + session: Session, + _id: int, + file: bytes, + price: float +): + """Add a file to an existing contract""" + statement = select(models.Contract).where(models.Contract.id == _id) result = session.exec(statement) contract = result.first() contract.total_price = price @@ -58,8 +104,14 @@ def add_contract_file(session: Session, id: int, file: bytes, price: float): session.refresh(contract) return contract -def update_one(session: Session, id: int, contract: models.ContractUpdate) -> models.ContractPublic: - statement = select(models.Contract).where(models.Contract.id == id) + +def update_one( + session: Session, + _id: int, + contract: models.ContractUpdate +) -> models.ContractPublic: + """Update one contract""" + statement = select(models.Contract).where(models.Contract.id == _id) result = session.exec(statement) new_contract = result.first() if not new_contract: @@ -72,8 +124,13 @@ def update_one(session: Session, id: int, contract: models.ContractUpdate) -> mo session.refresh(new_contract) return new_contract -def delete_one(session: Session, id: int) -> models.ContractPublic: - statement = select(models.Contract).where(models.Contract.id == id) + +def delete_one( + session: Session, + _id: int +) -> models.ContractPublic: + """Delete one contract""" + statement = select(models.Contract).where(models.Contract.id == _id) result = session.exec(statement) contract = result.first() if not contract: @@ -83,11 +140,29 @@ def delete_one(session: Session, id: int) -> models.ContractPublic: session.commit() return result -def is_allowed(session: Session, user: models.User, id: int) -> bool: - statement = select(models.Contract)\ - .join(models.Form, models.Contract.form_id == models.Form.id)\ - .join(models.Productor, models.Form.productor_id == models.Productor.id)\ - .where(models.Contract.id == id)\ - .where(models.Productor.type.in_([r.name for r in user.roles]))\ + +def is_allowed( + session: Session, + user: models.User, + _id: int +) -> bool: + """Determine if a user is allowed to access a contract by id""" + statement = ( + select(models.Contract) + .join( + models.Form, + models.Contract.form_id == models.Form.id + ) + .join( + models.Productor, + models.Form.productor_id == models.Productor.id + ) + .where(models.Contract.id == _id) + .where( + models.Productor.type.in_( + [r.name for r in user.roles] + ) + ) .distinct() - return len(session.exec(statement).all()) > 0 \ No newline at end of file + ) + return len(session.exec(statement).all()) > 0 diff --git a/backend/src/database.py b/backend/src/database.py index c9eaae1..b795367 100644 --- a/backend/src/database.py +++ b/backend/src/database.py @@ -1,11 +1,14 @@ -from sqlmodel import create_engine, SQLModel, Session +from sqlmodel import Session, SQLModel, create_engine from src.settings import settings -engine = create_engine(f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}') +engine = create_engine( + f'postgresql://{settings.db_user}:{settings.db_pass}@{settings.db_host}:5432/{settings.db_name}') + def get_session(): with Session(engine) as session: yield session + def create_all_tables(): SQLModel.metadata.create_all(engine) diff --git a/backend/src/forms/exceptions.py b/backend/src/forms/exceptions.py index 2ce2650..55026fb 100644 --- a/backend/src/forms/exceptions.py +++ b/backend/src/forms/exceptions.py @@ -1,17 +1,26 @@ +"""Forms module exceptions""" +import logging + class FormServiceError(Exception): + """Form service exception""" def __init__(self, message: str): super().__init__(message) + logging.error('FormService : %s', message) + class UserNotFoundError(FormServiceError): pass + class ProductorNotFoundError(FormServiceError): pass + class FormNotFoundError(FormServiceError): pass + class FormCreateError(FormServiceError): def __init__(self, message: str, field: str | None = None): super().__init__(message) - self.field = field \ No newline at end of file + self.field = field diff --git a/backend/src/forms/forms.py b/backend/src/forms/forms.py index 836b863..86fffbd 100644 --- a/backend/src/forms/forms.py +++ b/backend/src/forms/forms.py @@ -1,14 +1,15 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -import src.messages as messages -import src.models as models -from src.database import get_session -from sqlmodel import Session -import src.forms.service as service import src.forms.exceptions as exceptions +import src.forms.service as service +import src.messages as messages +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session +from src import models from src.auth.auth import get_current_user +from src.database import get_session router = APIRouter(prefix='/forms') + @router.get('', response_model=list[models.FormPublic]) async def get_forms( seasons: list[str] = Query([]), @@ -18,6 +19,7 @@ async def get_forms( ): return service.get_all(session, seasons, productors, current_season) + @router.get('/referents', response_model=list[models.FormPublic]) async def get_forms_filtered( seasons: list[str] = Query([]), @@ -28,53 +30,60 @@ async def get_forms_filtered( ): return service.get_all(session, seasons, productors, current_season, user) -@router.get('/{id}', response_model=models.FormPublic) -async def get_form(id: int, session: Session = Depends(get_session)): - result = service.get_one(session, id) + +@router.get('/{_id}', response_model=models.FormPublic) +async def get_form(_id: int, session: Session = Depends(get_session)): + result = service.get_one(session, _id) if result is None: - raise HTTPException(status_code=404, detail=messages.Messages.not_found('form')) + raise HTTPException( + status_code=404, + detail=messages.Messages.not_found('form') + ) return result + @router.post('', response_model=models.FormPublic) async def create_form( - form: models.FormCreate, + form: models.FormCreate, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): try: form = service.create_one(session, form) except exceptions.ProductorNotFoundError as error: - raise HTTPException(status_code=404, detail=str(error)) + raise HTTPException(status_code=404, detail=str(error)) from error except exceptions.UserNotFoundError as error: - raise HTTPException(status_code=404, detail=str(error)) + raise HTTPException(status_code=404, detail=str(error)) from error except exceptions.FormCreateError as error: - raise HTTPException(status_code=400, detail=str(error)) + raise HTTPException(status_code=400, detail=str(error)) from error return form -@router.put('/{id}', response_model=models.FormPublic) + +@router.put('/{_id}', response_model=models.FormPublic) async def update_form( - id: int, form: models.FormUpdate, + _id: int, form: models.FormUpdate, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): try: - result = service.update_one(session, id, form) + result = service.update_one(session, _id, form) except exceptions.FormNotFoundError as error: - raise HTTPException(status_code=404, detail=str(error)) + raise HTTPException(status_code=404, detail=str(error)) from error except exceptions.ProductorNotFoundError as error: - raise HTTPException(status_code=404, detail=str(error)) + raise HTTPException(status_code=404, detail=str(error)) from error except exceptions.UserNotFoundError as error: - raise HTTPException(status_code=404, detail=str(error)) + raise HTTPException(status_code=404, detail=str(error)) from error return result -@router.delete('/{id}', response_model=models.FormPublic) + +@router.delete('/{_id}', response_model=models.FormPublic) async def delete_form( - id: int, + _id: int, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): try: - result = service.delete_one(session, id) + result = service.delete_one(session, _id) except exceptions.FormNotFoundError as error: - raise HTTPException(status_code=404, detail=str(error)) + raise HTTPException(status_code=404, detail=str(error)) from error return result diff --git a/backend/src/forms/service.py b/backend/src/forms/service.py index 0342508..0ae765b 100644 --- a/backend/src/forms/service.py +++ b/backend/src/forms/service.py @@ -1,12 +1,12 @@ -from sqlmodel import Session, select -from sqlalchemy import func - -import src.models as models import src.forms.exceptions as exceptions import src.messages as messages +from sqlalchemy import func +from sqlmodel import Session, select +from src import models + def get_all( - session: Session, + session: Session, seasons: list[str], productors: list[str], current_season: bool, @@ -14,45 +14,54 @@ def get_all( ) -> list[models.FormPublic]: statement = select(models.Form) if user: - statement = statement\ - .join(models.Productor, models.Form.productor_id == models.Productor.id)\ - .where(models.Productor.type.in_([r.name for r in user.roles]))\ - .distinct() + statement = statement .join( + models.Productor, + models.Form.productor_id == models.Productor.id) .where( + models.Productor.type.in_( + [ + r.name for r in user.roles])) .distinct() if len(seasons) > 0: statement = statement.where(models.Form.season.in_(seasons)) if len(productors) > 0: - statement = statement.join(models.Productor).where(models.Productor.name.in_(productors)) + statement = statement.join( + models.Productor).where( + models.Productor.name.in_(productors)) if not user: - statement = statement.where(models.Form.visible == True) + statement = statement.where(models.Form.visible) if current_season: subquery = ( select( - models.Productor.type, + models.Productor.type, func.max(models.Form.start).label("max_start") ) - .join(models.Form)\ - .group_by(models.Productor.type)\ + .join(models.Form) + .group_by(models.Productor.type) .subquery() ) statement = select(models.Form)\ .join(models.Productor)\ - .join(subquery, - (models.Productor.type == subquery.c.type) & - (models.Form.start == subquery.c.max_start) - ) + .join(subquery, + (models.Productor.type == subquery.c.type) & + (models.Form.start == subquery.c.max_start) + ) if not user: - statement = statement.where(models.Form.visible == True) + statement = statement.where(models.Form.visible) return session.exec(statement.order_by(models.Form.name)).all() return session.exec(statement.order_by(models.Form.name)).all() + def get_one(session: Session, form_id: int) -> models.FormPublic: return session.get(models.Form, form_id) + def create_one(session: Session, form: models.FormCreate) -> models.FormPublic: if not form: - raise exceptions.FormCreateError(messages.Messages.invalid_input('form', 'input cannot be None')) + raise exceptions.FormCreateError( + messages.Messages.invalid_input( + 'form', 'input cannot be None')) if not session.get(models.Productor, form.productor_id): - raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor')) + raise exceptions.ProductorNotFoundError( + messages.Messages.not_found('productor')) if not session.get(models.User, form.referer_id): raise exceptions.UserNotFoundError(messages.Messages.not_found('user')) form_create = form.model_dump(exclude_unset=True) @@ -62,14 +71,20 @@ def create_one(session: Session, form: models.FormCreate) -> models.FormPublic: session.refresh(new_form) return new_form -def update_one(session: Session, id: int, form: models.FormUpdate) -> models.FormPublic: - statement = select(models.Form).where(models.Form.id == id) + +def update_one( + session: Session, + _id: int, + form: models.FormUpdate) -> models.FormPublic: + statement = select(models.Form).where(models.Form.id == _id) result = session.exec(statement) new_form = result.first() if not new_form: raise exceptions.FormNotFoundError(messages.Messages.not_found('form')) - if form.productor_id and not session.get(models.Productor, form.productor_id): - raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor')) + if form.productor_id and not session.get( + models.Productor, form.productor_id): + raise exceptions.ProductorNotFoundError( + messages.Messages.not_found('productor')) if form.referer_id and not session.get(models.User, form.referer_id): raise exceptions.UserNotFoundError(messages.Messages.not_found('user')) form_updates = form.model_dump(exclude_unset=True) @@ -80,8 +95,9 @@ def update_one(session: Session, id: int, form: models.FormUpdate) -> models.For session.refresh(new_form) return new_form -def delete_one(session: Session, id: int) -> models.FormPublic: - statement = select(models.Form).where(models.Form.id == id) + +def delete_one(session: Session, _id: int) -> models.FormPublic: + statement = select(models.Form).where(models.Form.id == _id) result = session.exec(statement) form = result.first() if not form: @@ -91,10 +107,19 @@ def delete_one(session: Session, id: int) -> models.FormPublic: session.commit() return result -def is_allowed(session: Session, user: models.User, id: int) -> bool: - statement = select(models.Form)\ - .join(models.Productor, models.Form.productor_id == models.Productor.id)\ - .where(models.Form.id == id)\ - .where(models.Productor.type.in_([r.name for r in user.roles]))\ + +def is_allowed(session: Session, user: models.User, _id: int) -> bool: + statement = ( + select(models.Form) + .join( + models.Productor, + models.Form.productor_id == models.Productor.id) + .where(models.Form.id == _id) + .where( + models.Productor.type.in_( + [r.name for r in user.roles] + ) + ) .distinct() - return len(session.exec(statement).all()) > 0 \ No newline at end of file + ) + return len(session.exec(statement).all()) > 0 diff --git a/backend/src/main.py b/backend/src/main.py index 7b91e07..73b3f80 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -1,18 +1,15 @@ -from sqlmodel import SQLModel - from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware - -from src.templates.templates import router as template_router +from src.auth.auth import router as auth_router from src.contracts.contracts import router as contracts_router from src.forms.forms import router as forms_router from src.productors.productors import router as productors_router from src.products.products import router as products_router -from src.users.users import router as users_router -from src.auth.auth import router as auth_router -from src.shipments.shipments import router as shipment_router from src.settings import settings -from src.database import engine, create_all_tables +from src.shipments.shipments import router as shipment_router +from src.templates.templates import router as template_router +from src.users.users import router as users_router + app = FastAPI() @@ -34,4 +31,4 @@ app.include_router(productors_router, prefix="/api") app.include_router(products_router, prefix="/api") app.include_router(users_router, prefix="/api") app.include_router(auth_router, prefix="/api") -app.include_router(shipment_router, prefix="/api") \ No newline at end of file +app.include_router(shipment_router, prefix="/api") diff --git a/backend/src/messages.py b/backend/src/messages.py index 790cda7..ea8d91f 100644 --- a/backend/src/messages.py +++ b/backend/src/messages.py @@ -1,19 +1,20 @@ pdferror = 'An error occured during PDF generation please contact administrator' + class Messages: unauthorized = 'User is Unauthorized' notauthenticated = 'User is not authenticated' tokenexipired = 'Token has expired' invalidtoken = 'Token is invalid' - + @staticmethod def not_found(resource: str) -> str: return f'{resource.capitalize()} not found' - + @staticmethod def invalid_input(resource: str, reason: str = "") -> str: return f'Invalid {resource} input {':' if reason else ""} {reason}' @staticmethod def not_allowed(resource: str, action: str) -> str: - return f'User is not allowed to {action} this {resource}' \ No newline at end of file + return f'User is not allowed to {action} this {resource}' diff --git a/backend/src/models.py b/backend/src/models.py index dcc407b..c17e8f4 100644 --- a/backend/src/models.py +++ b/backend/src/models.py @@ -1,99 +1,136 @@ -from sqlmodel import Field, SQLModel, Relationship, Column, LargeBinary +import datetime from enum import StrEnum from typing import Optional -import datetime + +from sqlmodel import Column, Field, LargeBinary, Relationship, SQLModel + class ContractType(SQLModel, table=True): - id: int | None = Field(default=None, primary_key=True) + id: int | None = Field( + default=None, + primary_key=True + ) name: str + class UserContractTypeLink(SQLModel, table=True): - user_id: int = Field(foreign_key="user.id", primary_key=True) - contract_type_id: int = Field(foreign_key="contracttype.id", primary_key=True) + user_id: int = Field( + foreign_key='user.id', + primary_key=True + ) + contract_type_id: int = Field( + foreign_key='contracttype.id', + primary_key=True + ) + class UserBase(SQLModel): name: str email: str + class UserPublic(UserBase): id: int roles: list[ContractType] + class User(UserBase, table=True): id: int | None = Field(default=None, primary_key=True) roles: list[ContractType] = Relationship( link_model=UserContractTypeLink ) + class UserUpdate(SQLModel): name: str | None email: str | None role_names: list[str] | None + class UserCreate(UserBase): role_names: list[str] | None + class PaymentMethodBase(SQLModel): name: str details: str max: int | None + class PaymentMethod(PaymentMethodBase, table=True): id: int | None = Field(default=None, primary_key=True) - productor_id: int = Field(foreign_key="productor.id", ondelete="CASCADE") - productor: Optional["Productor"] = Relationship( - back_populates="payment_methods", + productor_id: int = Field(foreign_key='productor.id', ondelete='CASCADE') + productor: Optional['Productor'] = Relationship( + back_populates='payment_methods', ) + class PaymentMethodPublic(PaymentMethodBase): id: int - productor: Optional["Productor"] + productor: Optional['Productor'] + class ProductorBase(SQLModel): name: str address: str type: str + class ProductorPublic(ProductorBase): id: int - products: list["Product"] = [] - payment_methods: list["PaymentMethod"] = [] + products: list['Product'] = Field(default_factory=list) + payment_methods: list['PaymentMethod'] = Field(default_factory=list) + class Productor(ProductorBase, table=True): id: int | None = Field(default=None, primary_key=True) - products: list["Product"] = Relationship( + products: list['Product'] = Relationship( back_populates='productor', sa_relationship_kwargs={ - "order_by": "Product.name" + 'order_by': 'Product.name' }, ) - payment_methods: list["PaymentMethod"] = Relationship( - back_populates="productor", + payment_methods: list['PaymentMethod'] = Relationship( + back_populates='productor', cascade_delete=True ) + class ProductorUpdate(SQLModel): name: str | None address: str | None - payment_methods: list["PaymentMethod"] = [] + payment_methods: list['PaymentMethod'] = Field(default_factory=list) type: str | None + class ProductorCreate(ProductorBase): - payment_methods: list["PaymentMethod"] = [] + payment_methods: list['PaymentMethod'] = Field(default_factory=list) + class Unit(StrEnum): - GRAMS = "1" - KILO = "2" - PIECE = "3" + GRAMS = '1' + KILO = '2' + PIECE = '3' + class ProductType(StrEnum): - OCCASIONAL = "1" - RECCURENT = "2" + OCCASIONAL = '1' + RECCURENT = '2' + class ShipmentProductLink(SQLModel, table=True): - shipment_id: Optional[int] = Field(default=None, foreign_key="shipment.id", primary_key=True) - product_id: Optional[int] = Field(default=None, foreign_key="product.id", primary_key=True) + shipment_id: Optional[int] = Field( + default=None, + foreign_key='shipment.id', + primary_key=True + ) + product_id: Optional[int] = Field( + default=None, + foreign_key='product.id', + primary_key=True + ) + class ProductBase(SQLModel): name: str @@ -103,17 +140,31 @@ class ProductBase(SQLModel): quantity: float | None quantity_unit: str | None type: ProductType - productor_id: int | None = Field(default=None, foreign_key="productor.id") + productor_id: int | None = Field( + default=None, + foreign_key='productor.id' + ) + class ProductPublic(ProductBase): id: int productor: Productor | None - shipments: list["Shipment"] | None + shipments: list['Shipment'] | None + class Product(ProductBase, table=True): - id: int | None = Field(default=None, primary_key=True) - shipments: list["Shipment"] = Relationship(back_populates="products", link_model=ShipmentProductLink) - productor: Optional[Productor] = Relationship(back_populates="products") + id: int | None = Field( + default=None, + primary_key=True + ) + shipments: list['Shipment'] = Relationship( + back_populates='products', + link_model=ShipmentProductLink + ) + productor: Optional[Productor] = Relationship( + back_populates='products' + ) + class ProductUpdate(SQLModel): name: str | None @@ -125,41 +176,46 @@ class ProductUpdate(SQLModel): productor_id: int | None type: ProductType | None + class ProductCreate(ProductBase): pass + class FormBase(SQLModel): name: str - productor_id: int | None = Field(default=None, foreign_key="productor.id") - referer_id: int | None = Field(default=None, foreign_key="user.id") + productor_id: int | None = Field(default=None, foreign_key='productor.id') + referer_id: int | None = Field(default=None, foreign_key='user.id') season: str start: datetime.date end: datetime.date minimum_shipment_value: float | None visible: bool + class FormPublic(FormBase): id: int productor: ProductorPublic | None referer: User | None - shipments: list["ShipmentPublic"] = [] + shipments: list['ShipmentPublic'] = Field(default_factory=list) + class Form(FormBase, table=True): id: int | None = Field(default=None, primary_key=True) productor: Optional['Productor'] = Relationship() referer: Optional['User'] = Relationship() - shipments: list["Shipment"] = Relationship( - back_populates="form", + shipments: list['Shipment'] = Relationship( + back_populates='form', cascade_delete=True, sa_relationship_kwargs={ - "order_by": "Shipment.name" + 'order_by': 'Shipment.name' }, ) - contracts: list["Contract"] = Relationship( - back_populates="form", + contracts: list['Contract'] = Relationship( + back_populates='form', cascade_delete=True ) + class FormUpdate(SQLModel): name: str | None productor_id: int | None @@ -170,35 +226,44 @@ class FormUpdate(SQLModel): minimum_shipment_value: float | None visible: bool | None + class FormCreate(FormBase): pass + class TemplateBase(SQLModel): pass + class TemplatePublic(TemplateBase): id: int + class Template(TemplateBase, table=True): id: int | None = Field(default=None, primary_key=True) - + + class TemplateUpdate(SQLModel): pass + class TemplateCreate(TemplateBase): pass + class ChequeBase(SQLModel): name: str value: str + class Cheque(ChequeBase, table=True): id: int | None = Field(default=None, primary_key=True) - contract_id: int = Field(foreign_key="contract.id", ondelete="CASCADE") - contract: Optional["Contract"] = Relationship( - back_populates="cheques", + contract_id: int = Field(foreign_key='contract.id', ondelete='CASCADE') + contract: Optional['Contract'] = Relationship( + back_populates='cheques', ) + class ContractBase(SQLModel): firstname: str lastname: str @@ -207,105 +272,122 @@ class ContractBase(SQLModel): payment_method: str cheque_quantity: int + class Contract(ContractBase, table=True): id: int | None = Field(default=None, primary_key=True) form_id: int = Field( - foreign_key="form.id", + foreign_key='form.id', nullable=False, - ondelete="CASCADE" + ondelete='CASCADE' ) - products: list["ContractProduct"] = Relationship( - back_populates="contract", + products: list['ContractProduct'] = Relationship( + back_populates='contract', cascade_delete=True ) - form: Optional[Form] = Relationship(back_populates="contracts") + form: Form = Relationship(back_populates='contracts') cheques: list[Cheque] = Relationship( - back_populates="contract", + back_populates='contract', cascade_delete=True ) file: bytes = Field(sa_column=Column(LargeBinary)) total_price: float | None + class ContractCreate(ContractBase): - products: list["ContractProductCreate"] = [] - cheques: list["Cheque"] = [] + products: list['ContractProductCreate'] = Field(default_factory=list) + cheques: list['Cheque'] = Field(default_factory=list) form_id: int + class ContractUpdate(SQLModel): file: bytes + class ContractPublic(ContractBase): id: int - products: list["ContractProduct"] = [] + products: list['ContractProduct'] = Field(default_factory=list) form: Form total_price: float | None # file: bytes + class ContractProductBase(SQLModel): product_id: int = Field( - foreign_key="product.id", + foreign_key='product.id', nullable=False, - ondelete="CASCADE" + ondelete='CASCADE' ) shipment_id: int | None = Field( default=None, - foreign_key="shipment.id", + foreign_key='shipment.id', nullable=True, - ondelete="CASCADE" + ondelete='CASCADE' ) quantity: float + class ContractProduct(ContractProductBase, table=True): id: int | None = Field(default=None, primary_key=True) contract_id: int = Field( - foreign_key="contract.id", + foreign_key='contract.id', nullable=False, - ondelete="CASCADE" + ondelete='CASCADE' ) - contract: Optional["Contract"] = Relationship(back_populates="products") - product: Optional["Product"] = Relationship() - shipment: Optional["Shipment"] = Relationship() + contract: Optional['Contract'] = Relationship(back_populates='products') + product: Optional['Product'] = Relationship() + shipment: Optional['Shipment'] = Relationship() + class ContractProductPublic(ContractProductBase): id: int quantity: float contract: Contract product: Product - shipment: Optional["Shipment"] + shipment: Optional['Shipment'] + class ContractProductCreate(ContractProductBase): pass + class ContractProductUpdate(ContractProductBase): pass + class ShipmentBase(SQLModel): name: str date: datetime.date - form_id: int | None = Field(default=None, foreign_key="form.id", ondelete="CASCADE") + form_id: int | None = Field( + default=None, + foreign_key='form.id', + ondelete='CASCADE') + class ShipmentPublic(ShipmentBase): id: int - products: list[Product] = [] + products: list[Product] = Field(default_factory=list) form: Form | None + class Shipment(ShipmentBase, table=True): id: int | None = Field(default=None, primary_key=True) products: list[Product] = Relationship( - back_populates="shipments", + back_populates='shipments', link_model=ShipmentProductLink, sa_relationship_kwargs={ - "order_by": "Product.name" + 'order_by': 'Product.name' }, ) - form: Optional[Form] = Relationship(back_populates="shipments") + form: Optional[Form] = Relationship(back_populates='shipments') + class ShipmentUpdate(SQLModel): name: str | None date: datetime.date | None - product_ids: list[int] | None = [] + product_ids: list[int] | None = Field(default_factory=list) + class ShipmentCreate(ShipmentBase): - product_ids: list[int] = [] - form_id: int \ No newline at end of file + product_ids: list[int] = Field(default_factory=list) + form_id: int diff --git a/backend/src/productors/__init__.py b/backend/src/productors/__init__.py index 10fe5b0..e9a63bc 100644 --- a/backend/src/productors/__init__.py +++ b/backend/src/productors/__init__.py @@ -1,3 +1,3 @@ # SPDX-FileCopyrightText: 2026-present Julien Aldon # -# SPDX-License-Identifier: MIT \ No newline at end of file +# SPDX-License-Identifier: MIT diff --git a/backend/src/productors/exceptions.py b/backend/src/productors/exceptions.py index 32b0beb..8455109 100644 --- a/backend/src/productors/exceptions.py +++ b/backend/src/productors/exceptions.py @@ -1,11 +1,17 @@ +import logging + + class ProductorServiceError(Exception): def __init__(self, message: str): super().__init__(message) + logging.error('ProductorService : %s', message) + class ProductorNotFoundError(ProductorServiceError): pass + class ProductorCreateError(ProductorServiceError): def __init__(self, message: str, field: str | None = None): super().__init__(message) - self.field = field \ No newline at end of file + self.field = field diff --git a/backend/src/productors/productors.py b/backend/src/productors/productors.py index 2dd390b..b8a2b3d 100644 --- a/backend/src/productors/productors.py +++ b/backend/src/productors/productors.py @@ -1,14 +1,15 @@ -from fastapi import APIRouter, HTTPException, Depends, Query import src.messages as messages -import src.models as models -from src.database import get_session -from sqlmodel import Session -import src.productors.service as service import src.productors.exceptions as exceptions +import src.productors.service as service +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session +from src import models from src.auth.auth import get_current_user +from src.database import get_session router = APIRouter(prefix='/productors') + @router.get('', response_model=list[models.ProductorPublic]) def get_productors( names: list[str] = Query([]), @@ -18,49 +19,56 @@ def get_productors( ): return service.get_all(session, user, names, types) -@router.get('/{id}', response_model=models.ProductorPublic) + +@router.get('/{_id}', response_model=models.ProductorPublic) def get_productor( - id: int, + _id: int, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): - result = service.get_one(session, id) + result = service.get_one(session, _id) if result is None: - raise HTTPException(status_code=404, detail=messages.Messages.not_found('productor')) + raise HTTPException( + status_code=404, + detail=messages.Messages.not_found('productor') + ) return result + @router.post('', response_model=models.ProductorPublic) def create_productor( - productor: models.ProductorCreate, + productor: models.ProductorCreate, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): try: result = service.create_one(session, productor) except exceptions.ProductorCreateError as error: - raise HTTPException(status_code=400, detail=str(error)) + raise HTTPException(status_code=400, detail=str(error)) from error return result -@router.put('/{id}', response_model=models.ProductorPublic) + +@router.put('/{_id}', response_model=models.ProductorPublic) def update_productor( - id: int, productor: models.ProductorUpdate, + _id: int, productor: models.ProductorUpdate, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): try: - result = service.update_one(session, id, productor) + result = service.update_one(session, _id, productor) except exceptions.ProductorNotFoundError as error: - raise HTTPException(status_code=404, detail=str(error)) + raise HTTPException(status_code=404, detail=str(error)) from error return result -@router.delete('/{id}', response_model=models.ProductorPublic) + +@router.delete('/{_id}', response_model=models.ProductorPublic) def delete_productor( - id: int, + _id: int, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): try: - result = service.delete_one(session, id) + result = service.delete_one(session, _id) except exceptions.ProductorNotFoundError as error: - raise HTTPException(status_code=404, detail=str(error)) + raise HTTPException(status_code=404, detail=str(error)) from error return result diff --git a/backend/src/productors/service.py b/backend/src/productors/service.py index 2580573..73248ca 100644 --- a/backend/src/productors/service.py +++ b/backend/src/productors/service.py @@ -1,12 +1,13 @@ -from sqlmodel import Session, select -import src.models as models -import src.productors.exceptions as exceptions import src.messages as messages +import src.productors.exceptions as exceptions +from sqlmodel import Session, select +from src import models + def get_all( session: Session, - user: models.User, - names: list[str], + user: models.User, + names: list[str], types: list[str] ) -> list[models.ProductorPublic]: statement = select(models.Productor)\ @@ -18,13 +19,20 @@ def get_all( statement = statement.where(models.Productor.type.in_(types)) return session.exec(statement.order_by(models.Productor.name)).all() + def get_one(session: Session, productor_id: int) -> models.ProductorPublic: return session.get(models.Productor, productor_id) -def create_one(session: Session, productor: models.ProductorCreate) -> models.ProductorPublic: + +def create_one( + session: Session, + productor: models.ProductorCreate) -> models.ProductorPublic: if not productor: - raise exceptions.ProductorCreateError(messages.Messages.invalid_input('productor', 'input cannot be None')) - productor_create = productor.model_dump(exclude_unset=True, exclude='payment_methods') + raise exceptions.ProductorCreateError( + messages.Messages.invalid_input( + 'productor', 'input cannot be None')) + productor_create = productor.model_dump( + exclude_unset=True, exclude='payment_methods') new_productor = models.Productor(**productor_create) new_productor.payment_methods = [ @@ -39,13 +47,18 @@ def create_one(session: Session, productor: models.ProductorCreate) -> models.Pr session.refresh(new_productor) return new_productor -def update_one(session: Session, id: int, productor: models.ProductorUpdate) -> models.ProductorPublic: + +def update_one( + session: Session, + id: int, + productor: models.ProductorUpdate) -> models.ProductorPublic: statement = select(models.Productor).where(models.Productor.id == id) result = session.exec(statement) new_productor = result.first() if not new_productor: - raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor')) - + raise exceptions.ProductorNotFoundError( + messages.Messages.not_found('productor')) + productor_updates = productor.model_dump(exclude_unset=True) if 'payment_methods' in productor_updates: new_productor.payment_methods.clear() @@ -67,12 +80,14 @@ def update_one(session: Session, id: int, productor: models.ProductorUpdate) -> session.refresh(new_productor) return new_productor + def delete_one(session: Session, id: int) -> models.ProductorPublic: statement = select(models.Productor).where(models.Productor.id == id) result = session.exec(statement) productor = result.first() if not productor: - raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor')) + raise exceptions.ProductorNotFoundError( + messages.Messages.not_found('productor')) result = models.ProductorPublic.model_validate(productor) session.delete(productor) session.commit() diff --git a/backend/src/products/__init__.py b/backend/src/products/__init__.py index 10fe5b0..e9a63bc 100644 --- a/backend/src/products/__init__.py +++ b/backend/src/products/__init__.py @@ -1,3 +1,3 @@ # SPDX-FileCopyrightText: 2026-present Julien Aldon # -# SPDX-License-Identifier: MIT \ No newline at end of file +# SPDX-License-Identifier: MIT diff --git a/backend/src/products/exceptions.py b/backend/src/products/exceptions.py index ae71d08..9675b6d 100644 --- a/backend/src/products/exceptions.py +++ b/backend/src/products/exceptions.py @@ -2,13 +2,16 @@ class ProductServiceError(Exception): def __init__(self, message: str): super().__init__(message) + class ProductorNotFoundError(ProductServiceError): pass + class ProductNotFoundError(ProductServiceError): pass + class ProductCreateError(ProductServiceError): def __init__(self, message: str, field: str | None = None): super().__init__(message) - self.field = field \ No newline at end of file + self.field = field diff --git a/backend/src/products/products.py b/backend/src/products/products.py index afb7988..3ce5007 100644 --- a/backend/src/products/products.py +++ b/backend/src/products/products.py @@ -1,18 +1,19 @@ -from fastapi import APIRouter, HTTPException, Depends, Query import src.messages as messages -import src.models as models -from src.database import get_session -from sqlmodel import Session -import src.products.service as service import src.products.exceptions as exceptions +import src.products.service as service +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session +from src import models from src.auth.auth import get_current_user +from src.database import get_session router = APIRouter(prefix='/products') + @router.get('', response_model=list[models.ProductPublic], ) def get_products( user: models.User = Depends(get_current_user), - session: Session = Depends(get_session), + session: Session = Depends(get_session), names: list[str] = Query([]), types: list[str] = Query([]), productors: list[str] = Query([]), @@ -20,25 +21,28 @@ def get_products( return service.get_all( session, user, - names, - productors, + names, + productors, types, ) + @router.get('/{id}', response_model=models.ProductPublic) def get_product( - id: int, + id: int, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): result = service.get_one(session, id) if result is None: - raise HTTPException(status_code=404, detail=messages.Messages.not_found('product')) + raise HTTPException(status_code=404, + detail=messages.Messages.not_found('product')) return result + @router.post('', response_model=models.ProductPublic) def create_product( - product: models.ProductCreate, + product: models.ProductCreate, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): @@ -50,9 +54,10 @@ def create_product( raise HTTPException(status_code=404, detail=str(error)) return result + @router.put('/{id}', response_model=models.ProductPublic) def update_product( - id: int, product: models.ProductUpdate, + id: int, product: models.ProductUpdate, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): @@ -64,9 +69,10 @@ def update_product( raise HTTPException(status_code=404, detail=str(error)) return result + @router.delete('/{id}', response_model=models.ProductPublic) def delete_product( - id: int, + id: int, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): diff --git a/backend/src/products/service.py b/backend/src/products/service.py index fc200e8..df35dcc 100644 --- a/backend/src/products/service.py +++ b/backend/src/products/service.py @@ -1,19 +1,23 @@ -from sqlmodel import Session, select -import src.models as models -import src.products.exceptions as exceptions import src.messages as messages +import src.products.exceptions as exceptions +from sqlmodel import Session, select +from src import models + def get_all( - session: Session, + session: Session, user: models.User, - names: list[str], + names: list[str], productors: list[str], types: list[str], ) -> list[models.ProductPublic]: - statement = select(models.Product)\ - .join(models.Productor, models.Product.productor_id == models.Productor.id)\ - .where(models.Productor.type.in_([r.name for r in user.roles]))\ - .distinct() + statement = select( + models.Product) .join( + models.Productor, + models.Product.productor_id == models.Productor.id) .where( + models.Productor.type.in_( + [ + r.name for r in user.roles])) .distinct() if len(names) > 0: statement = statement.where(models.Product.name.in_(names)) if len(productors) > 0: @@ -22,14 +26,21 @@ def get_all( statement = statement.where(models.Product.type.in_(types)) return session.exec(statement.order_by(models.Product.name)).all() + def get_one(session: Session, product_id: int) -> models.ProductPublic: return session.get(models.Product, product_id) -def create_one(session: Session, product: models.ProductCreate) -> models.ProductPublic: + +def create_one( + session: Session, + product: models.ProductCreate) -> models.ProductPublic: if not product: - raise exceptions.ProductCreateError(messages.Messages.invalid_input('product', 'input cannot be None')) + raise exceptions.ProductCreateError( + messages.Messages.invalid_input( + 'product', 'input cannot be None')) if not session.get(models.Productor, product.productor_id): - raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor')) + raise exceptions.ProductorNotFoundError( + messages.Messages.not_found('productor')) product_create = product.model_dump(exclude_unset=True) new_product = models.Product(**product_create) session.add(new_product) @@ -37,14 +48,21 @@ def create_one(session: Session, product: models.ProductCreate) -> models.Produc session.refresh(new_product) return new_product -def update_one(session: Session, id: int, product: models.ProductUpdate) -> models.ProductPublic: + +def update_one( + session: Session, + id: int, + product: models.ProductUpdate) -> models.ProductPublic: statement = select(models.Product).where(models.Product.id == id) result = session.exec(statement) new_product = result.first() if not new_product: - raise exceptions.ProductNotFoundError(messages.Messages.not_found('product')) - if product.productor_id and not session.get(models.Productor, product.productor_id): - raise exceptions.ProductorNotFoundError(messages.Messages.not_found('productor')) + raise exceptions.ProductNotFoundError( + messages.Messages.not_found('product')) + if product.productor_id and not session.get( + models.Productor, product.productor_id): + raise exceptions.ProductorNotFoundError( + messages.Messages.not_found('productor')) product_updates = product.model_dump(exclude_unset=True) for key, value in product_updates.items(): @@ -55,12 +73,14 @@ def update_one(session: Session, id: int, product: models.ProductUpdate) -> mode session.refresh(new_product) return new_product + def delete_one(session: Session, id: int) -> models.ProductPublic: statement = select(models.Product).where(models.Product.id == id) result = session.exec(statement) product = result.first() if not product: - raise exceptions.ProductNotFoundError(messages.Messages.not_found('product')) + raise exceptions.ProductNotFoundError( + messages.Messages.not_found('product')) result = models.ProductPublic.model_validate(product) session.delete(product) session.commit() diff --git a/backend/src/settings.py b/backend/src/settings.py index 44649ca..355357a 100644 --- a/backend/src/settings.py +++ b/backend/src/settings.py @@ -1,5 +1,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict + class Settings(BaseSettings): origins: str db_host: str @@ -20,10 +21,21 @@ class Settings(BaseSettings): env_file='../.env' ) + settings = Settings() -AUTH_URL = f"{settings.keycloak_server}/realms/{settings.keycloak_realm}/protocol/openid-connect/auth" -TOKEN_URL = f"{settings.keycloak_server}/realms/{settings.keycloak_realm}/protocol/openid-connect/token" -ISSUER = f"{settings.keycloak_server}/realms/{settings.keycloak_realm}" -JWKS_URL = f"{ISSUER}/protocol/openid-connect/certs" -LOGOUT_URL = f'{settings.keycloak_server}/realms/{settings.keycloak_realm}/protocol/openid-connect/logout' \ No newline at end of file +AUTH_URL = ( + f'{settings.keycloak_server}/realms/' + f'{settings.keycloak_realm}/protocol/openid-connect/auth' +) +TOKEN_URL = ( + f'{settings.keycloak_server}/realms/' + f'{settings.keycloak_realm}/protocol/openid-connect/token' +) + +ISSUER = f'{settings.keycloak_server}/realms/{settings.keycloak_realm}' +JWKS_URL = f'{ISSUER}/protocol/openid-connect/certs' +LOGOUT_URL = ( + f'{settings.keycloak_server}/realms/' + f'{settings.keycloak_realm}/protocol/openid-connect/logout' +) diff --git a/backend/src/shipments/__init__.py b/backend/src/shipments/__init__.py index 10fe5b0..e9a63bc 100644 --- a/backend/src/shipments/__init__.py +++ b/backend/src/shipments/__init__.py @@ -1,3 +1,3 @@ # SPDX-FileCopyrightText: 2026-present Julien Aldon # -# SPDX-License-Identifier: MIT \ No newline at end of file +# SPDX-License-Identifier: MIT diff --git a/backend/src/shipments/exceptions.py b/backend/src/shipments/exceptions.py index 39cf869..8493043 100644 --- a/backend/src/shipments/exceptions.py +++ b/backend/src/shipments/exceptions.py @@ -1,11 +1,17 @@ +import logging + + class ShipmentServiceError(Exception): def __init__(self, message: str): super().__init__(message) + logging.error('ShipmentService : %s', message) + class ShipmentNotFoundError(ShipmentServiceError): pass + class ShipmentCreateError(ShipmentServiceError): def __init__(self, message: str, field: str | None = None): super().__init__(message) - self.field = field \ No newline at end of file + self.field = field diff --git a/backend/src/shipments/service.py b/backend/src/shipments/service.py index 8d71413..02202e9 100644 --- a/backend/src/shipments/service.py +++ b/backend/src/shipments/service.py @@ -1,58 +1,111 @@ -from sqlmodel import Session, select -import src.models as models -import src.shipments.exceptions as exceptions -import src.messages as messages +# pylint: disable=E1101 import datetime +import src.messages as messages +import src.shipments.exceptions as exceptions +from sqlmodel import Session, select +from src import models + + def get_all( session: Session, user: models.User, - names: list[str], - dates: list[str], - forms: list[str] + names: list[str] = None, + dates: list[str] = None, + forms: list[str] = None ) -> list[models.ShipmentPublic]: - statement = select(models.Shipment)\ - .join(models.Form, models.Shipment.form_id == models.Form.id)\ - .join(models.Productor, models.Form.productor_id == models.Productor.id)\ - .where(models.Productor.type.in_([r.name for r in user.roles]))\ + statement = ( + select(models.Shipment) + .join( + models.Form, + models.Shipment.form_id == models.Form.id) + .join( + models.Productor, + models.Form.productor_id == models.Productor.id) + .where( + models.Productor.type.in_( + [r.name for r in user.roles] + ) + ) .distinct() - if len(names) > 0: + ) + if names and len(names) > 0: statement = statement.where(models.Shipment.name.in_(names)) - if len(dates) > 0: - statement = statement.where(models.Shipment.date.in_(list(map(lambda x: datetime.datetime.strptime(x, '%Y-%m-%d').date(), dates)))) - if len(forms) > 0: + if dates and len(dates) > 0: + statement = statement.where( + models.Shipment.date.in_( + list(map( + lambda x: datetime.datetime.strptime( + x, '%Y-%m-%d').date(), + dates + )) + ) + ) + if forms and len(forms) > 0: statement = statement.where(models.Form.name.in_(forms)) return session.exec(statement.order_by(models.Shipment.name)).all() + def get_one(session: Session, shipment_id: int) -> models.ShipmentPublic: return session.get(models.Shipment, shipment_id) -def create_one(session: Session, shipment: models.ShipmentCreate) -> models.ShipmentPublic: + +def create_one( + session: Session, + shipment: models.ShipmentCreate) -> models.ShipmentPublic: if shipment is None: - raise exceptions.ShipmentCreateError(messages.Messages.invalid_input('shipment', 'input cannot be None')) - products = session.exec(select(models.Product).where(models.Product.id.in_(shipment.product_ids))).all() - shipment_create = shipment.model_dump(exclude_unset=True, exclude={'product_ids'}) + raise exceptions.ShipmentCreateError( + messages.Messages.invalid_input( + 'shipment', 'input cannot be None')) + products = session.exec( + select(models.Product) + .where( + models.Product.id.in_( + shipment.product_ids + ) + ) + ).all() + shipment_create = shipment.model_dump( + exclude_unset=True, exclude={'product_ids'} + ) new_shipment = models.Shipment(**shipment_create, products=products) session.add(new_shipment) session.commit() session.refresh(new_shipment) return new_shipment -def update_one(session: Session, id: int, shipment: models.ShipmentUpdate) -> models.ShipmentPublic: + +def update_one( + session: Session, + _id: int, + shipment: models.ShipmentUpdate) -> models.ShipmentPublic: if shipment is None: - raise exceptions.ShipmentCreateError(messages.Messages.invalid_input('shipment', 'input cannot be None')) - statement = select(models.Shipment).where(models.Shipment.id == id) + raise exceptions.ShipmentCreateError( + messages.Messages.invalid_input( + 'shipment', 'input cannot be None')) + statement = select(models.Shipment).where(models.Shipment.id == _id) result = session.exec(statement) new_shipment = result.first() if not new_shipment: - raise exceptions.ShipmentNotFoundError(messages.Messages.not_found('shipment')) + raise exceptions.ShipmentNotFoundError( + messages.Messages.not_found('shipment')) - products_to_add = session.exec(select(models.Product).where(models.Product.id.in_(shipment.product_ids))).all() + products_to_add = session.exec( + select( + models.Product + ).where( + models.Product.id.in_( + shipment.product_ids + ) + ) + ).all() new_shipment.products.clear() for add in products_to_add: new_shipment.products.append(add) - shipment_updates = shipment.model_dump(exclude_unset=True, exclude={"product_ids"}) + shipment_updates = shipment.model_dump( + exclude_unset=True, exclude={"product_ids"} + ) for key, value in shipment_updates.items(): setattr(new_shipment, key, value) @@ -61,14 +114,16 @@ def update_one(session: Session, id: int, shipment: models.ShipmentUpdate) -> mo session.refresh(new_shipment) return new_shipment -def delete_one(session: Session, id: int) -> models.ShipmentPublic: - statement = select(models.Shipment).where(models.Shipment.id == id) + +def delete_one(session: Session, _id: int) -> models.ShipmentPublic: + statement = select(models.Shipment).where(models.Shipment.id == _id) result = session.exec(statement) shipment = result.first() if not shipment: - raise exceptions.ShipmentNotFoundError(messages.Messages.not_found('shipment')) + raise exceptions.ShipmentNotFoundError( + messages.Messages.not_found('shipment')) result = models.ShipmentPublic.model_validate(shipment) session.delete(shipment) session.commit() - return result \ No newline at end of file + return result diff --git a/backend/src/shipments/shipments.py b/backend/src/shipments/shipments.py index f969d02..a2e7f40 100644 --- a/backend/src/shipments/shipments.py +++ b/backend/src/shipments/shipments.py @@ -1,14 +1,15 @@ -from fastapi import APIRouter, HTTPException, Depends, Query import src.messages as messages -import src.models as models -from src.database import get_session -from sqlmodel import Session -import src.shipments.service as service import src.shipments.exceptions as exceptions +import src.shipments.service as service +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session +from src import models from src.auth.auth import get_current_user +from src.database import get_session router = APIRouter(prefix='/shipments') + @router.get('', response_model=list[models.ShipmentPublic], ) def get_shipments( session: Session = Depends(get_session), @@ -25,17 +26,22 @@ def get_shipments( forms, ) -@router.get('/{id}', response_model=models.ShipmentPublic) + +@router.get('/{_id}', response_model=models.ShipmentPublic) def get_shipment( - id: int, + _id: int, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): - result = service.get_one(session, id) + result = service.get_one(session, _id) if result is None: - raise HTTPException(status_code=404, detail=messages.Messages.not_found('shipment')) + raise HTTPException( + status_code=404, + detail=messages.Messages.not_found('shipment') + ) return result + @router.post('', response_model=models.ShipmentPublic) def create_shipment( shipment: models.ShipmentCreate, @@ -45,30 +51,32 @@ def create_shipment( try: result = service.create_one(session, shipment) except exceptions.ShipmentCreateError as error: - raise HTTPException(status_code=400, detail=str(error)) + raise HTTPException(status_code=400, detail=str(error)) from error return result -@router.put('/{id}', response_model=models.ShipmentPublic) + +@router.put('/{_id}', response_model=models.ShipmentPublic) def update_shipment( - id: int, + _id: int, shipment: models.ShipmentUpdate, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): try: - result = service.update_one(session, id, shipment) + result = service.update_one(session, _id, shipment) except exceptions.ShipmentNotFoundError as error: - raise HTTPException(status_code=404, detail=str(error)) + raise HTTPException(status_code=404, detail=str(error)) from error return result -@router.delete('/{id}', response_model=models.ShipmentPublic) + +@router.delete('/{_id}', response_model=models.ShipmentPublic) def delete_shipment( - id: int, + _id: int, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): try: - result = service.delete_one(session, id) + result = service.delete_one(session, _id) except exceptions.ShipmentNotFoundError as error: - raise HTTPException(status_code=404, detail=str(error)) + raise HTTPException(status_code=404, detail=str(error)) from error return result diff --git a/backend/src/templates/__init__.py b/backend/src/templates/__init__.py index 10fe5b0..e9a63bc 100644 --- a/backend/src/templates/__init__.py +++ b/backend/src/templates/__init__.py @@ -1,3 +1,3 @@ # SPDX-FileCopyrightText: 2026-present Julien Aldon # -# SPDX-License-Identifier: MIT \ No newline at end of file +# SPDX-License-Identifier: MIT diff --git a/backend/src/templates/service.py b/backend/src/templates/service.py index ea0e37f..896782e 100644 --- a/backend/src/templates/service.py +++ b/backend/src/templates/service.py @@ -1,14 +1,19 @@ from sqlmodel import Session, select -import src.models as models +from src import models + def get_all(session: Session) -> list[models.TemplatePublic]: statement = select(models.Template) return session.exec(statement.order_by(models.Template.name)).all() + def get_one(session: Session, template_id: int) -> models.TemplatePublic: return session.get(models.Template, template_id) -def create_one(session: Session, template: models.TemplateCreate) -> models.TemplatePublic: + +def create_one( + session: Session, + template: models.TemplateCreate) -> models.TemplatePublic: template_create = template.model_dump(exclude_unset=True) new_template = models.Template(**template_create) session.add(new_template) @@ -16,7 +21,11 @@ def create_one(session: Session, template: models.TemplateCreate) -> models.Temp session.refresh(new_template) return new_template -def update_one(session: Session, id: int, template: models.TemplateUpdate) -> models.TemplatePublic: + +def update_one( + session: Session, + id: int, + template: models.TemplateUpdate) -> models.TemplatePublic: statement = select(models.Template).where(models.Template.id == id) result = session.exec(statement) new_template = result.first() @@ -30,6 +39,7 @@ def update_one(session: Session, id: int, template: models.TemplateUpdate) -> mo session.refresh(new_template) return new_template + def delete_one(session: Session, id: int) -> models.TemplatePublic: statement = select(models.Template).where(models.Template.id == id) result = session.exec(statement) diff --git a/backend/src/templates/templates.py b/backend/src/templates/templates.py index e5c471e..316120e 100644 --- a/backend/src/templates/templates.py +++ b/backend/src/templates/templates.py @@ -1,13 +1,14 @@ -from fastapi import APIRouter, HTTPException, Depends import src.messages as messages -import src.models as models -from src.database import get_session -from sqlmodel import Session import src.templates.service as service +from fastapi import APIRouter, Depends, HTTPException +from sqlmodel import Session +from src import models from src.auth.auth import get_current_user +from src.database import get_session router = APIRouter(prefix='/templates') + @router.get('', response_model=list[models.TemplatePublic]) def get_templates( user: models.User = Depends(get_current_user), @@ -15,43 +16,50 @@ def get_templates( ): return service.get_all(session) + @router.get('/{id}', response_model=models.TemplatePublic) def get_template( - id: int, + id: int, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): result = service.get_one(session, id) if result is None: - raise HTTPException(status_code=404, detail=messages.Messages.not_found('template')) + raise HTTPException(status_code=404, + detail=messages.Messages.not_found('template')) return result + @router.post('', response_model=models.TemplatePublic) def create_template( - template: models.TemplateCreate, + template: models.TemplateCreate, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): return service.create_one(session, template) + @router.put('/{id}', response_model=models.TemplatePublic) def update_template( - id: int, template: models.TemplateUpdate, + id: int, template: models.TemplateUpdate, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): result = service.update_one(session, id, template) if result is None: - raise HTTPException(status_code=404, detail=messages.Messages.not_found('template')) + raise HTTPException(status_code=404, + detail=messages.Messages.not_found('template')) return result + @router.delete('/{id}', response_model=models.TemplatePublic) def delete_template( - id: int, + id: int, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): result = service.delete_one(session, id) if result is None: - raise HTTPException(status_code=404, detail=messages.Messages.not_found('template')) + raise HTTPException(status_code=404, + detail=messages.Messages.not_found('template')) return result diff --git a/backend/src/users/__init__.py b/backend/src/users/__init__.py index 10fe5b0..e9a63bc 100644 --- a/backend/src/users/__init__.py +++ b/backend/src/users/__init__.py @@ -1,3 +1,3 @@ # SPDX-FileCopyrightText: 2026-present Julien Aldon # -# SPDX-License-Identifier: MIT \ No newline at end of file +# SPDX-License-Identifier: MIT diff --git a/backend/src/users/exceptions.py b/backend/src/users/exceptions.py index 771bd04..de7efb5 100644 --- a/backend/src/users/exceptions.py +++ b/backend/src/users/exceptions.py @@ -1,11 +1,17 @@ +import logging + + class UserServiceError(Exception): def __init__(self, message: str): super().__init__(message) + logging.error('UserService : %s', message) + class UserNotFoundError(UserServiceError): pass + class UserCreateError(UserServiceError): def __init__(self, message: str, field: str | None = None): super().__init__(message) - self.field = field \ No newline at end of file + self.field = field diff --git a/backend/src/users/service.py b/backend/src/users/service.py index 466417f..ed8f7c0 100644 --- a/backend/src/users/service.py +++ b/backend/src/users/service.py @@ -1,9 +1,8 @@ -from sqlmodel import Session, select - -import src.models as models import src.messages as messages - import src.users.exceptions as exceptions +from sqlmodel import Session, select +from src import models + def get_all( session: Session, @@ -17,11 +16,15 @@ def get_all( statement = statement.where(models.User.email.in_(emails)) return session.exec(statement.order_by(models.User.name)).all() + def get_one(session: Session, user_id: int) -> models.UserPublic: return session.get(models.User, user_id) -def get_or_create_roles(session: Session, role_names: list[str]) -> list[models.ContractType]: - statement = select(models.ContractType).where(models.ContractType.name.in_(role_names)) + +def get_or_create_roles(session: Session, + role_names: list[str]) -> list[models.ContractType]: + statement = select(models.ContractType).where( + models.ContractType.name.in_(role_names)) existing = session.exec(statement).all() existing_roles = {role.name for role in existing} missing_role = set(role_names) - existing_roles @@ -37,8 +40,11 @@ def get_or_create_roles(session: Session, role_names: list[str]) -> list[models. session.refresh(role) return existing + new_roles + def get_or_create_user(session: Session, user_create: models.UserCreate): - statement = select(models.User).where(models.User.email == user_create.email) + statement = select( + models.User).where( + models.User.email == user_create.email) user = session.exec(statement).first() if user: user_role_names = [r.name for r in user.roles] @@ -48,13 +54,17 @@ def get_or_create_user(session: Session, user_create: models.UserCreate): user = create_one(session, user_create) return user + def get_roles(session: Session): statement = select(models.ContractType) return session.exec(statement.order_by(models.ContractType.name)).all() + def create_one(session: Session, user: models.UserCreate) -> models.UserPublic: if user is None: - raise exceptions.UserCreateError(messages.Messages.invalid_input('user', 'input cannot be None')) + raise exceptions.UserCreateError( + messages.Messages.invalid_input( + 'user', 'input cannot be None')) new_user = models.User( name=user.name, email=user.email @@ -68,9 +78,15 @@ def create_one(session: Session, user: models.UserCreate) -> models.UserPublic: session.refresh(new_user) return new_user -def update_one(session: Session, id: int, user: models.UserCreate) -> models.UserPublic: + +def update_one( + session: Session, + id: int, + user: models.UserCreate) -> models.UserPublic: if user is None: - raise exceptions.UserCreateError(messages.s.invalid_input('user', 'input cannot be None')) + raise exceptions.UserCreateError( + messages.s.invalid_input( + 'user', 'input cannot be None')) statement = select(models.User).where(models.User.id == id) result = session.exec(statement) new_user = result.first() @@ -86,6 +102,7 @@ def update_one(session: Session, id: int, user: models.UserCreate) -> models.Use session.refresh(new_user) return new_user + def delete_one(session: Session, id: int) -> models.UserPublic: statement = select(models.User).where(models.User.id == id) result = session.exec(statement) @@ -95,4 +112,4 @@ def delete_one(session: Session, id: int) -> models.UserPublic: result = models.UserPublic.model_validate(user) session.delete(user) session.commit() - return result \ No newline at end of file + return result diff --git a/backend/src/users/users.py b/backend/src/users/users.py index 73a303f..55e26aa 100644 --- a/backend/src/users/users.py +++ b/backend/src/users/users.py @@ -1,14 +1,15 @@ -from fastapi import APIRouter, HTTPException, Depends, Query import src.messages as messages -import src.models as models -from src.database import get_session -from sqlmodel import Session -import src.users.service as service -from src.auth.auth import get_current_user import src.users.exceptions as exceptions +import src.users.service as service +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session +from src import models +from src.auth.auth import get_current_user +from src.database import get_session router = APIRouter(prefix='/users') + @router.get('', response_model=list[models.UserPublic]) def get_users( session: Session = Depends(get_session), @@ -22,6 +23,7 @@ def get_users( emails, ) + @router.get('/roles', response_model=list[models.ContractType]) def get_roles( user: models.User = Depends(get_current_user), @@ -29,20 +31,23 @@ def get_roles( ): return service.get_roles(session) + @router.get('/{id}', response_model=models.UserPublic) def get_users( - id: int, + id: int, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): result = service.get_one(session, id) if result is None: - raise HTTPException(status_code=404, detail=messages.Messages.not_found('user')) + raise HTTPException(status_code=404, + detail=messages.Messages.not_found('user')) return result + @router.post('', response_model=models.UserPublic) def create_user( - user: models.UserCreate, + user: models.UserCreate, logged_user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): @@ -52,27 +57,31 @@ def create_user( raise HTTPException(status_code=400, detail=str(error)) return user + @router.put('/{id}', response_model=models.UserPublic) def update_user( - id: int, - user: models.UserUpdate, + id: int, + user: models.UserUpdate, logged_user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): try: result = service.update_one(session, id, user) except exceptions.UserNotFoundError as error: - raise HTTPException(status_code=404, detail=messages.Messages.not_found('user')) + raise HTTPException(status_code=404, + detail=messages.Messages.not_found('user')) return result + @router.delete('/{id}', response_model=models.UserPublic) def delete_user( - id: int, + id: int, user: models.User = Depends(get_current_user), session: Session = Depends(get_session) ): try: result = service.delete_one(session, id) except exceptions.UserNotFoundError as error: - raise HTTPException(status_code=404, detail=messages.Messages.not_found('user')) + raise HTTPException(status_code=404, + detail=messages.Messages.not_found('user')) return result diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 70fee98..5096dc8 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,13 +1,13 @@ import pytest from fastapi.testclient import TestClient -from sqlmodel import SQLModel, Session, create_engine from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine +from src import models +from src.auth.auth import get_current_user +from src.database import get_session +from src.main import app from .fixtures import * -from src.main import app -import src.models as models -from src.database import get_session -from src.auth.auth import get_current_user @pytest.fixture diff --git a/backend/tests/factories/contract_products.py b/backend/tests/factories/contract_products.py index 5a43196..461b35d 100644 --- a/backend/tests/factories/contract_products.py +++ b/backend/tests/factories/contract_products.py @@ -1,6 +1,6 @@ -import src.models as models import tests.factories.contracts as contract_factory import tests.factories.products as product_factory +from src import models def contract_product_factory(**kwargs): diff --git a/backend/tests/factories/contracts.py b/backend/tests/factories/contracts.py index e8a142e..00c6a1f 100644 --- a/backend/tests/factories/contracts.py +++ b/backend/tests/factories/contracts.py @@ -1,4 +1,5 @@ -import src.models as models +from src import models + from .forms import form_factory diff --git a/backend/tests/factories/forms.py b/backend/tests/factories/forms.py index a84d4df..d19863c 100644 --- a/backend/tests/factories/forms.py +++ b/backend/tests/factories/forms.py @@ -1,7 +1,9 @@ -import src.models as models +import datetime + +from src import models + from .productors import productor_public_factory from .users import user_factory -import datetime def form_factory(**kwargs): diff --git a/backend/tests/factories/productors.py b/backend/tests/factories/productors.py index 5c2553c..ad1a558 100644 --- a/backend/tests/factories/productors.py +++ b/backend/tests/factories/productors.py @@ -1,4 +1,4 @@ -import src.models as models +from src import models def productor_factory(**kwargs): diff --git a/backend/tests/factories/products.py b/backend/tests/factories/products.py index d19a430..46ce056 100644 --- a/backend/tests/factories/products.py +++ b/backend/tests/factories/products.py @@ -1,4 +1,5 @@ -import src.models as models +from src import models + from .productors import productor_factory diff --git a/backend/tests/factories/shipments.py b/backend/tests/factories/shipments.py index fd86ffa..3705461 100644 --- a/backend/tests/factories/shipments.py +++ b/backend/tests/factories/shipments.py @@ -1,6 +1,7 @@ -import src.models as models import datetime +from src import models + def shipment_factory(**kwargs): data = dict( diff --git a/backend/tests/factories/users.py b/backend/tests/factories/users.py index 5763198..eedd9da 100644 --- a/backend/tests/factories/users.py +++ b/backend/tests/factories/users.py @@ -1,4 +1,4 @@ -import src.models as models +from src import models def user_factory(**kwargs): diff --git a/backend/tests/fixtures.py b/backend/tests/fixtures.py index ad82add..639f47b 100644 --- a/backend/tests/fixtures.py +++ b/backend/tests/fixtures.py @@ -1,18 +1,18 @@ -import pytest import datetime -from sqlmodel import Session -import src.models as models +import pytest import src.forms.service as forms_service -import src.shipments.service as shipments_service import src.productors.service as productors_service import src.products.service as products_service +import src.shipments.service as shipments_service import src.users.service as users_service import tests.factories.forms as forms_factory -import tests.factories.shipments as shipments_factory import tests.factories.productors as productors_factory import tests.factories.products as products_factory +import tests.factories.shipments as shipments_factory import tests.factories.users as users_factory +from sqlmodel import Session +from src import models @pytest.fixture diff --git a/backend/tests/routers/test_contracts.py b/backend/tests/routers/test_contracts.py index 5c6f3d1..460d91b 100644 --- a/backend/tests/routers/test_contracts.py +++ b/backend/tests/routers/test_contracts.py @@ -1,12 +1,11 @@ import src.contracts.service as service -import src.models as models -from src.main import app -from src.auth.auth import get_current_user +import tests.factories.contract_products as contract_products_factory import tests.factories.contracts as contract_factory import tests.factories.forms as form_factory -import tests.factories.contract_products as contract_products_factory - from fastapi.exceptions import HTTPException +from src import models +from src.auth.auth import get_current_user +from src.main import app class TestContracts: diff --git a/backend/tests/routers/test_forms.py b/backend/tests/routers/test_forms.py index e722e40..75a42b7 100644 --- a/backend/tests/routers/test_forms.py +++ b/backend/tests/routers/test_forms.py @@ -1,11 +1,11 @@ -import src.forms.service as service import src.forms.exceptions as forms_exceptions -import src.models as models -from src.main import app -from src.auth.auth import get_current_user +import src.forms.service as service +import src.messages as messages import tests.factories.forms as form_factory from fastapi.exceptions import HTTPException -import src.messages as messages +from src import models +from src.auth.auth import get_current_user +from src.main import app class TestForms: diff --git a/backend/tests/routers/test_productors.py b/backend/tests/routers/test_productors.py index cd229df..5ff9046 100644 --- a/backend/tests/routers/test_productors.py +++ b/backend/tests/routers/test_productors.py @@ -1,14 +1,11 @@ -from fastapi.exceptions import HTTPException - -from src.main import app -import src.models as models import src.messages as messages -from src.auth.auth import get_current_user - -import src.productors.service as service import src.productors.exceptions as exceptions - +import src.productors.service as service import tests.factories.productors as productor_factory +from fastapi.exceptions import HTTPException +from src import models +from src.auth.auth import get_current_user +from src.main import app class TestProductors: diff --git a/backend/tests/routers/test_products.py b/backend/tests/routers/test_products.py index ba661d8..01e79f4 100644 --- a/backend/tests/routers/test_products.py +++ b/backend/tests/routers/test_products.py @@ -1,11 +1,10 @@ -import src.products.service as service import src.products.exceptions as exceptions -import src.models as models -from src.main import app -from src.auth.auth import get_current_user +import src.products.service as service import tests.factories.products as product_factory - from fastapi.exceptions import HTTPException +from src import models +from src.auth.auth import get_current_user +from src.main import app class TestProducts: diff --git a/backend/tests/routers/test_shipments.py b/backend/tests/routers/test_shipments.py index 064f92e..3d60e0f 100644 --- a/backend/tests/routers/test_shipments.py +++ b/backend/tests/routers/test_shipments.py @@ -1,12 +1,11 @@ -import src.shipments.service as service -import src.models as models -from src.main import app import src.messages as messages import src.shipments.exceptions as exceptions -from src.auth.auth import get_current_user +import src.shipments.service as service import tests.factories.shipments as shipment_factory - from fastapi.exceptions import HTTPException +from src import models +from src.auth.auth import get_current_user +from src.main import app class TestShipments: diff --git a/backend/tests/routers/test_users.py b/backend/tests/routers/test_users.py index 932f3fc..55618a8 100644 --- a/backend/tests/routers/test_users.py +++ b/backend/tests/routers/test_users.py @@ -1,11 +1,10 @@ -import src.users.service as service -import src.models as models -from src.main import app -from src.auth.auth import get_current_user -import tests.factories.users as user_factory import src.users.exceptions as exceptions - +import src.users.service as service +import tests.factories.users as user_factory from fastapi.exceptions import HTTPException +from src import models +from src.auth.auth import get_current_user +from src.main import app class TestUsers: diff --git a/backend/tests/services/test_forms_service.py b/backend/tests/services/test_forms_service.py index f5e1153..3aa56ff 100644 --- a/backend/tests/services/test_forms_service.py +++ b/backend/tests/services/test_forms_service.py @@ -1,10 +1,9 @@ import pytest -from sqlmodel import Session - -import src.models as models -import src.forms.service as forms_service import src.forms.exceptions as forms_exceptions +import src.forms.service as forms_service import tests.factories.forms as forms_factory +from sqlmodel import Session +from src import models class TestFormsService: diff --git a/backend/tests/services/test_productors_service.py b/backend/tests/services/test_productors_service.py index 24fb0fb..c9e318e 100644 --- a/backend/tests/services/test_productors_service.py +++ b/backend/tests/services/test_productors_service.py @@ -1,10 +1,9 @@ import pytest -from sqlmodel import Session - -import src.models as models -import src.productors.service as productors_service import src.productors.exceptions as productors_exceptions +import src.productors.service as productors_service import tests.factories.productors as productors_factory +from sqlmodel import Session +from src import models class TestProductorsService: diff --git a/backend/tests/services/test_products_service.py b/backend/tests/services/test_products_service.py index 6d7fff0..18a602c 100644 --- a/backend/tests/services/test_products_service.py +++ b/backend/tests/services/test_products_service.py @@ -1,10 +1,9 @@ import pytest -from sqlmodel import Session - -import src.models as models -import src.products.service as products_service import src.products.exceptions as products_exceptions +import src.products.service as products_service import tests.factories.products as products_factory +from sqlmodel import Session +from src import models class TestProductsService: diff --git a/backend/tests/services/test_shipments_service.py b/backend/tests/services/test_shipments_service.py index 96c3a7e..57e4f8b 100644 --- a/backend/tests/services/test_shipments_service.py +++ b/backend/tests/services/test_shipments_service.py @@ -1,11 +1,11 @@ -import pytest import datetime -from sqlmodel import Session -import src.models as models -import src.shipments.service as shipments_service +import pytest import src.shipments.exceptions as shipments_exceptions +import src.shipments.service as shipments_service import tests.factories.shipments as shipments_factory +from sqlmodel import Session +from src import models class TestShipmentsService: diff --git a/backend/tests/services/test_users_service.py b/backend/tests/services/test_users_service.py index 4315c1c..37d85df 100644 --- a/backend/tests/services/test_users_service.py +++ b/backend/tests/services/test_users_service.py @@ -1,10 +1,9 @@ import pytest -from sqlmodel import Session - -import src.models as models -import src.users.service as users_service import src.users.exceptions as users_exceptions +import src.users.service as users_service import tests.factories.users as users_factory +from sqlmodel import Session +from src import models class TestUsersService: