91 lines
3.1 KiB
Python
91 lines
3.1 KiB
Python
from sqlmodel import Session, select
|
|
import src.models as models
|
|
|
|
def get_all(
    session: Session,
    names: list[str],
    emails: list[str],
) -> list[models.UserPublic]:
    """Return users, optionally filtered by name and/or email, ordered by name.

    Args:
        session: Open database session used to run the query.
        names: When non-empty, keep only users whose ``name`` is in this list.
        emails: When non-empty, keep only users whose ``email`` is in this list.

    An empty list (or ``None``) disables the corresponding filter. When both
    filters are active, successive ``.where()`` calls are AND-ed together, so
    a user must match both lists.

    Returns:
        The matching ``models.User`` rows, sorted by ``name``.
    """
    statement = select(models.User)
    # Truthiness covers both "empty list" and "None" as "no filter".
    if names:
        statement = statement.where(models.User.name.in_(names))
    if emails:
        statement = statement.where(models.User.email.in_(emails))
    return session.exec(statement.order_by(models.User.name)).all()
|
|
|
|
def get_one(session: Session, user_id: int) -> models.UserPublic:
    """Fetch a single user by primary key.

    Args:
        session: Open database session.
        user_id: Primary key of the wanted ``models.User`` row.

    Returns:
        The ``models.User`` instance, or ``None`` when no row has that key
        (``Session.get`` returns ``None`` for a missing identity).
    """
    found = session.get(models.User, user_id)
    return found
|
|
|
|
def get_or_create_roles(session: Session, role_names: list[str]) -> list[models.ContractType]:
    """Return ``ContractType`` rows for *role_names*, creating any that are missing.

    Existing rows are fetched with one ``IN`` query; every requested name
    without a matching row is inserted. Duplicate names are collapsed.

    The session is committed only when at least one new row was inserted, so
    a pure lookup does not commit unrelated pending changes held by
    *session*. New rows are refreshed after the commit so any
    database-generated columns are loaded.

    Args:
        session: Open database session.
        role_names: Role names to resolve.

    Returns:
        The already-existing rows (in query order) followed by the newly
        created rows, in the order their names first appear in *role_names*.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the order
    # of newly created roles is deterministic (iterating a set is not).
    wanted = list(dict.fromkeys(role_names))
    if not wanted:
        return []

    statement = select(models.ContractType).where(models.ContractType.name.in_(wanted))
    existing = list(session.exec(statement).all())
    existing_names = {role.name for role in existing}

    new_roles = [models.ContractType(name=name) for name in wanted if name not in existing_names]
    if new_roles:
        session.add_all(new_roles)
        session.commit()
        for role in new_roles:
            session.refresh(role)

    return existing + new_roles
|
|
|
|
def get_or_create_user(session: Session, user_create: models.UserCreate):
    """Return the user with ``user_create.email``, creating or syncing it as needed.

    - No user has that email: create one via ``create_one``.
    - A user exists but its name, or its *set* of role names, differs from
      *user_create*: overwrite it via ``update_one``.
    - Otherwise: return the existing user untouched.

    Args:
        session: Open database session.
        user_create: Desired name, email and role names.

    Returns:
        The created, updated or existing ``models.User``.
    """
    statement = select(models.User).where(models.User.email == user_create.email)
    user = session.exec(statement).first()

    if not user:
        return create_one(session, user_create)

    # Compare role names as sets: the order of ``user.roles`` comes from the
    # database (no ordering is visible here) and duplicates are collapsed on
    # write, so an order- or duplicate-only difference must not trigger a
    # committing update.
    current_role_names = {role.name for role in user.roles}
    if current_role_names != set(user_create.role_names) or user.name != user_create.name:
        user = update_one(session, user.id, user_create)
    return user
|
|
|
|
def get_roles(session: Session):
    """List every ``ContractType`` row, sorted alphabetically by ``name``.

    Args:
        session: Open database session.

    Returns:
        All ``models.ContractType`` rows ordered by name.
    """
    ordered_query = select(models.ContractType).order_by(models.ContractType.name)
    rows = session.exec(ordered_query)
    return rows.all()
|
|
|
|
def create_one(session: Session, user: models.UserCreate) -> models.UserPublic:
    """Persist a brand-new user described by *user* and return it.

    The names in ``user.role_names`` are resolved through
    ``get_or_create_roles``, which may insert and commit missing roles before
    the user row itself is added. The user is then committed and refreshed.

    Args:
        session: Open database session.
        user: Name, email and role names for the new user.

    Returns:
        The freshly persisted ``models.User`` instance.
    """
    fresh = models.User(name=user.name, email=user.email)
    fresh.roles = get_or_create_roles(session, user.role_names)

    session.add(fresh)
    session.commit()
    session.refresh(fresh)
    return fresh
|
|
|
|
def update_one(session: Session, id: int, user: models.UserCreate) -> models.UserPublic:
    """Overwrite the name, email and roles of the user with primary key *id*.

    Roles are resolved (and any missing ones created) *before* the user's
    fields are modified. ``get_or_create_roles`` queries and may commit the
    session. If the user had already been modified at that point, the query's
    autoflush (SQLAlchemy's default) and that commit would persist a
    half-applied update (new name/email, old roles). Resolving roles first
    means the name, email and role changes are committed together.

    Args:
        session: Open database session.
        id: Primary key of the user to update. The name shadows the builtin
            but is kept for caller compatibility.
        user: New name, email and role names.

    Returns:
        The refreshed ``models.User``, or ``None`` if no user has that id.
    """
    statement = select(models.User).where(models.User.id == id)
    target = session.exec(statement).first()
    if not target:
        return None

    roles = get_or_create_roles(session, user.role_names)

    target.email = user.email
    target.name = user.name
    target.roles = roles
    session.add(target)
    session.commit()
    session.refresh(target)
    return target
|
|
|
|
def delete_one(session: Session, id: int) -> models.UserPublic:
    """Delete the user with primary key *id* and return a snapshot of it.

    The ``UserPublic`` snapshot is built before the row is deleted and
    committed, so the returned value does not depend on the ORM instance's
    state afterwards.

    Args:
        session: Open database session.
        id: Primary key of the user to delete.

    Returns:
        A ``models.UserPublic`` copy of the deleted user, or ``None`` if no
        user has that id.
    """
    lookup = select(models.User).where(models.User.id == id)
    doomed = session.exec(lookup).first()
    if not doomed:
        return None

    snapshot = models.UserPublic.model_validate(doomed)
    session.delete(doomed)
    session.commit()
    return snapshot