This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
from sqlmodel import Session, select
|
||||
|
||||
import src.models as models
|
||||
import src.messages as messages
|
||||
|
||||
import src.users.exceptions as exceptions
|
||||
from sqlmodel import Session, select
|
||||
from src import models
|
||||
|
||||
|
||||
def get_all(
|
||||
session: Session,
|
||||
@@ -17,11 +16,15 @@ def get_all(
|
||||
statement = statement.where(models.User.email.in_(emails))
|
||||
return session.exec(statement.order_by(models.User.name)).all()
|
||||
|
||||
|
||||
def get_one(session: Session, user_id: int) -> models.UserPublic:
|
||||
return session.get(models.User, user_id)
|
||||
|
||||
def get_or_create_roles(session: Session, role_names: list[str]) -> list[models.ContractType]:
|
||||
statement = select(models.ContractType).where(models.ContractType.name.in_(role_names))
|
||||
|
||||
def get_or_create_roles(session: Session,
|
||||
role_names: list[str]) -> list[models.ContractType]:
|
||||
statement = select(models.ContractType).where(
|
||||
models.ContractType.name.in_(role_names))
|
||||
existing = session.exec(statement).all()
|
||||
existing_roles = {role.name for role in existing}
|
||||
missing_role = set(role_names) - existing_roles
|
||||
@@ -37,8 +40,11 @@ def get_or_create_roles(session: Session, role_names: list[str]) -> list[models.
|
||||
session.refresh(role)
|
||||
return existing + new_roles
|
||||
|
||||
|
||||
def get_or_create_user(session: Session, user_create: models.UserCreate):
|
||||
statement = select(models.User).where(models.User.email == user_create.email)
|
||||
statement = select(
|
||||
models.User).where(
|
||||
models.User.email == user_create.email)
|
||||
user = session.exec(statement).first()
|
||||
if user:
|
||||
user_role_names = [r.name for r in user.roles]
|
||||
@@ -48,13 +54,17 @@ def get_or_create_user(session: Session, user_create: models.UserCreate):
|
||||
user = create_one(session, user_create)
|
||||
return user
|
||||
|
||||
|
||||
def get_roles(session: Session):
|
||||
statement = select(models.ContractType)
|
||||
return session.exec(statement.order_by(models.ContractType.name)).all()
|
||||
|
||||
|
||||
def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
|
||||
if user is None:
|
||||
raise exceptions.UserCreateError(messages.Messages.invalid_input('user', 'input cannot be None'))
|
||||
raise exceptions.UserCreateError(
|
||||
messages.Messages.invalid_input(
|
||||
'user', 'input cannot be None'))
|
||||
new_user = models.User(
|
||||
name=user.name,
|
||||
email=user.email
|
||||
@@ -68,9 +78,15 @@ def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
|
||||
session.refresh(new_user)
|
||||
return new_user
|
||||
|
||||
def update_one(session: Session, id: int, user: models.UserCreate) -> models.UserPublic:
|
||||
|
||||
def update_one(
|
||||
session: Session,
|
||||
id: int,
|
||||
user: models.UserCreate) -> models.UserPublic:
|
||||
if user is None:
|
||||
raise exceptions.UserCreateError(messages.s.invalid_input('user', 'input cannot be None'))
|
||||
raise exceptions.UserCreateError(
|
||||
messages.s.invalid_input(
|
||||
'user', 'input cannot be None'))
|
||||
statement = select(models.User).where(models.User.id == id)
|
||||
result = session.exec(statement)
|
||||
new_user = result.first()
|
||||
@@ -86,6 +102,7 @@ def update_one(session: Session, id: int, user: models.UserCreate) -> models.Use
|
||||
session.refresh(new_user)
|
||||
return new_user
|
||||
|
||||
|
||||
def delete_one(session: Session, id: int) -> models.UserPublic:
|
||||
statement = select(models.User).where(models.User.id == id)
|
||||
result = session.exec(statement)
|
||||
@@ -95,4 +112,4 @@ def delete_one(session: Session, id: int) -> models.UserPublic:
|
||||
result = models.UserPublic.model_validate(user)
|
||||
session.delete(user)
|
||||
session.commit()
|
||||
return result
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user