From eaddaf3465a813c7e38e84e386ff40b5efabc3c2 Mon Sep 17 00:00:00 2001 From: cwalter <cwalter@ethz.ch> Date: Sun, 8 Dec 2024 21:48:22 +0100 Subject: [PATCH 1/5] switch to uuid7 --- .devcontainer/requirements.txt | 1 + README.md | 3 +++ app/api/routes/files.py | 3 ++- app/api/routes/internalTransfers.py | 1 + app/models/addresses.py | 11 +++++++++-- app/models/bills.py | 8 +++++--- app/models/creditPayments.py | 4 ++-- app/models/creditor.py | 19 ++++++++++++++++++- app/models/debitor.py | 3 ++- app/models/ezags.py | 3 ++- app/models/internalTransfers.py | 3 ++- app/models/invoices.py | 5 +++-- app/models/items.py | 3 ++- app/models/ksts.py | 3 ++- app/models/ledgers.py | 3 ++- app/models/reimbursements.py | 3 ++- app/models/unlinked_files.py | 3 ++- requirements.txt | 1 + 18 files changed, 61 insertions(+), 19 deletions(-) diff --git a/.devcontainer/requirements.txt b/.devcontainer/requirements.txt index 38e1f7e..d974550 100644 --- a/.devcontainer/requirements.txt +++ b/.devcontainer/requirements.txt @@ -24,3 +24,4 @@ qrbill authlib itsdangerous requests +uuid6 diff --git a/README.md b/README.md index 9544e78..aeacdeb 100644 --- a/README.md +++ b/README.md @@ -42,3 +42,6 @@ To upgrade the database to the latest migration, run: ```bash alembic upgrade head ``` + +## TODOs +- with python 3.14 uuid7 should be part of the uuid library. rebuild all of the uuid6.uuid7() calls to uuid.uuid7() this shoudn't change anything else. \ No newline at end of file diff --git a/app/api/routes/files.py b/app/api/routes/files.py index e796737..d9ad277 100644 --- a/app/api/routes/files.py +++ b/app/api/routes/files.py @@ -4,6 +4,7 @@ import uuid from datetime import timedelta # Import timedelta from typing import Optional +import uuid6 from fastapi import APIRouter, FastAPI, File, HTTPException, UploadFile, status from fastapi.responses import RedirectResponse from minio import Minio @@ -73,7 +74,7 @@ async def upload_file(file: UploadFile = File(...)): """ Upload a file to MinIO and return its UUID. """ - file_uuid = str(uuid.uuid4()) + file_uuid = str(uuid6.uuid7()) object_name = file_uuid # Using UUID as the object name try: diff --git a/app/api/routes/internalTransfers.py b/app/api/routes/internalTransfers.py index e9f6cc5..3bddfc5 100644 --- a/app/api/routes/internalTransfers.py +++ b/app/api/routes/internalTransfers.py @@ -78,6 +78,7 @@ def create_InternalTransfer( "debitor": debitor, "creditor_id": creditor.id, "debitor_id": debitor.id, + "name": "InternalTransfer", }, ) session.add(internalTransfer) diff --git a/app/models/addresses.py b/app/models/addresses.py index 612a61b..d780e19 100644 --- a/app/models/addresses.py +++ b/app/models/addresses.py @@ -2,6 +2,7 @@ import uuid from datetime import datetime from typing import Optional +import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel @@ -18,7 +19,7 @@ class AddressBase(SQLModel): class Address(AddressBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) time_create: datetime = Field(default=datetime.now()) time_modified: datetime = Field(default=datetime.now()) @@ -34,7 +35,13 @@ class AddressesPublic(SQLModel): class AddressFilter(Filter): - accountnumber: Optional[int] = None + name: str | None = None + address1: str | None = None + address2: str | None = None + address3: str | None = None + city: str | None = None + country: str | None = None + plz: int | None = None search: Optional[str] = None class Constants(Filter.Constants): diff --git a/app/models/bills.py b/app/models/bills.py index f289789..780eceb 100644 --- a/app/models/bills.py +++ b/app/models/bills.py @@ -1,11 +1,13 @@ import uuid from datetime import datetime +from typing import Optional +import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel -from app.models.addresses import Address, AddressBase, AddressPublic +from app.models.addresses import Address, AddressBase, AddressFilter, AddressPublic from app.models.creditor import Creditor, CreditorBase, CreditorPublic @@ -19,7 +21,7 @@ class BillBase(SQLModel): class Bill(BillBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) time_create: datetime = Field(default=datetime.now()) time_modified: datetime = Field(default=datetime.now()) ezag_timestamp: datetime | None = Field(default=None) @@ -55,7 +57,7 @@ class BillsList(SQLModel): class BillFilter(Filter): - creditor: str | None = None + # number_filter: Optional[Filter] = FilterDepends(with_prefix("address", AddressFilter)) address: str | None = None reference: str | None = None iban: str | None = None diff --git a/app/models/creditPayments.py b/app/models/creditPayments.py index 52d29c8..58fe522 100644 --- a/app/models/creditPayments.py +++ b/app/models/creditPayments.py @@ -4,7 +4,7 @@ from enum import Enum from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr -from sqlmodel import Field, Relationship, SQLModel +import uuid6 from app.models.addresses import Address, AddressBase, AddressPublic from app.models.creditor import Creditor, CreditorBase, CreditorPublic @@ -23,7 +23,7 @@ class CreditPaymentBase(SQLModel): class CreditPayment(CreditPaymentBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) name: str = Field(max_length=200) time_create: datetime = Field(default=datetime.now()) time_modified: datetime = Field(default=datetime.now()) diff --git a/app/models/creditor.py b/app/models/creditor.py index 833c8f5..c009d0e 100644 --- a/app/models/creditor.py +++ b/app/models/creditor.py @@ -2,6 +2,7 @@ import uuid from datetime import datetime from enum import Enum +import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter from sqlmodel import Field, Relationship, SQLModel @@ -29,7 +30,7 @@ class CreditorBase(SQLModel): class Creditor(CreditorBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) time_create: datetime = Field(default=datetime.now()) time_modified: datetime = Field(default=datetime.now()) ledger: "Ledger" = Relationship(sa_relationship_kwargs={"lazy": "selectin"}) @@ -38,3 +39,19 @@ class Creditor(CreditorBase, table=True): class CreditorPublic(CreditorBase): id: uuid.UUID = Field() + + +class CreditorFilter(Filter): + kst: str | None = None + ledger: str | None = None + amount: int | None = None + accounting_year: int | None = None + currency: Currency | None = None + comment: str | None = None + qcomment: str | None = None + search: str | None = None + + class Constants(Filter.Constants): + model = Creditor + search_field_name = "search" + search_model_fields = ["comment", "qcomment"] diff --git a/app/models/debitor.py b/app/models/debitor.py index f245439..984edaa 100644 --- a/app/models/debitor.py +++ b/app/models/debitor.py @@ -1,6 +1,7 @@ import uuid from datetime import datetime +import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel @@ -21,7 +22,7 @@ class DebitorBase(SQLModel): class Debitor(DebitorBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) time_create: datetime = Field(default=datetime.now()) time_modified: datetime = Field(default=datetime.now()) kst: "Kst" = Relationship(sa_relationship_kwargs={"lazy": "selectin"}) diff --git a/app/models/ezags.py b/app/models/ezags.py index b0acf15..1ecc6e5 100644 --- a/app/models/ezags.py +++ b/app/models/ezags.py @@ -1,6 +1,7 @@ import uuid from datetime import datetime +import uuid6 from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel @@ -12,6 +13,6 @@ class EzagBase(SQLModel): class Ezag(EzagBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) time_create: datetime = Field(default=datetime.now()) time_modified: datetime = Field(default=datetime.now()) diff --git a/app/models/internalTransfers.py b/app/models/internalTransfers.py index 62f1399..820ad30 100644 --- a/app/models/internalTransfers.py +++ b/app/models/internalTransfers.py @@ -1,6 +1,7 @@ import uuid from datetime import datetime +import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel @@ -21,7 +22,7 @@ class InternalTransferBase(SQLModel): class InternalTransfer(InternalTransferBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) time_create: datetime = Field(default=datetime.now()) time_modified: datetime = Field(default=datetime.now()) debitor: "Debitor" = Relationship(sa_relationship_kwargs={"lazy": "selectin"}) diff --git a/app/models/invoices.py b/app/models/invoices.py index 222dea1..1270729 100644 --- a/app/models/invoices.py +++ b/app/models/invoices.py @@ -1,6 +1,7 @@ import uuid from datetime import datetime +import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel @@ -45,7 +46,7 @@ class InvoiceBase(SQLModel): class Invoice(InvoiceBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) time_create: datetime = Field(default=datetime.now()) time_modified: datetime = Field(default=datetime.now()) @@ -60,7 +61,7 @@ class InvoiceCreate(InvoiceBase): class InvoicePublic(InvoiceBase): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) time_create: datetime = Field(default=None) time_modified: datetime = Field(default=None) items: list[ItemstoInvoicePublic] diff --git a/app/models/items.py b/app/models/items.py index 5016821..4e5fbce 100644 --- a/app/models/items.py +++ b/app/models/items.py @@ -2,6 +2,7 @@ import uuid from datetime import datetime from typing import Optional +import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel @@ -18,7 +19,7 @@ class ItemBase(SQLModel): class Item(ItemBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) time_create: datetime = Field(default=datetime.now()) time_modified: datetime = Field(default=datetime.now()) diff --git a/app/models/ksts.py b/app/models/ksts.py index fd486a7..f737e4c 100644 --- a/app/models/ksts.py +++ b/app/models/ksts.py @@ -2,6 +2,7 @@ import uuid from datetime import datetime from typing import Optional +import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel @@ -18,7 +19,7 @@ class KstBase(SQLModel): class Kst(KstBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) time_create: datetime = Field(default=datetime.now()) time_modified: datetime = Field(default=datetime.now()) diff --git a/app/models/ledgers.py b/app/models/ledgers.py index 9043997..8386b55 100644 --- a/app/models/ledgers.py +++ b/app/models/ledgers.py @@ -2,6 +2,7 @@ import uuid from datetime import datetime from typing import Optional +import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel @@ -15,7 +16,7 @@ class LedgerBase(SQLModel): class Ledger(LedgerBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) time_create: datetime = Field(default=datetime.now()) time_modified: datetime = Field(default=datetime.now()) diff --git a/app/models/reimbursements.py b/app/models/reimbursements.py index 9925ff0..8c1af89 100644 --- a/app/models/reimbursements.py +++ b/app/models/reimbursements.py @@ -1,6 +1,7 @@ import uuid from datetime import datetime +import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel @@ -15,7 +16,7 @@ class ReimbursementBase(SQLModel): class Reimbursement(ReimbursementBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) name: str = Field(max_length=200) time_create: datetime = Field(default=datetime.now()) time_modified: datetime = Field(default=datetime.now()) diff --git a/app/models/unlinked_files.py b/app/models/unlinked_files.py index 6d77238..97d6e08 100644 --- a/app/models/unlinked_files.py +++ b/app/models/unlinked_files.py @@ -2,12 +2,13 @@ import uuid from datetime import datetime +import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel class UnlinkedFiles(SQLModel, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) name: str = Field(max_length=200) creator: str = Field(max_length=30) # TODO: change to user_id foreign key diff --git a/requirements.txt b/requirements.txt index 7bc5ce2..f9e677e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,3 +26,4 @@ svglib authlib itsdangerous requests +uuid6 \ No newline at end of file -- GitLab From 12685384ac51032af94cd824c01b7f543e48dc26 Mon Sep 17 00:00:00 2001 From: cwalter <cwalter@ethz.ch> Date: Sun, 8 Dec 2024 21:48:59 +0100 Subject: [PATCH 2/5] filter experiments v1 --- .devcontainer/devcontainer.json | 1 - app/api/routes/api_helper.py | 2 +- app/api/routes/creditPayments.py | 5 +-- app/models/creditPayments.py | 60 +++++++++++++++++++++++++++++--- requirements.txt | 2 +- 5 files changed, 61 insertions(+), 9 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 9b4df28..2668129 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -34,7 +34,6 @@ // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. // "remoteUser": "root" "workspaceFolder": "/workspace", - //"containerEnv": {"NO_AUTH":"1"}, use this if there shouldn't be any auth. "remoteUser": "vscode" } diff --git a/app/api/routes/api_helper.py b/app/api/routes/api_helper.py index 9f79f0f..54e3315 100644 --- a/app/api/routes/api_helper.py +++ b/app/api/routes/api_helper.py @@ -77,7 +77,7 @@ def read_objects( stmt = filter.filter( select(Obj_type).offset((page) * limit).limit(limit).order_by(order_by) ) - + print(stmt) # Execute the query out = session.exec(stmt).all() diff --git a/app/api/routes/creditPayments.py b/app/api/routes/creditPayments.py index 4a1c0ed..cd57e5d 100644 --- a/app/api/routes/creditPayments.py +++ b/app/api/routes/creditPayments.py @@ -1,5 +1,5 @@ import uuid -from typing import Any +from typing import Annotated, Any from fastapi import APIRouter, HTTPException, Query, Request from fastapi_filter import FilterDepends @@ -30,14 +30,15 @@ router = APIRouter() @router.get("/", response_model=CreditPaymentsList) def read_creditPayments( + filter: Annotated[CreditPaymentFilter, Query()], page: int = 0, limit: int = Query(default=100, le=1000), - filter: CreditPaymentFilter = FilterDepends(CreditPaymentFilter), ) -> CreditPaymentsList: """ retrieve all CreditPaymentes, sorted by name. paginate the results. """ + print(filter.__dict__) # we need to have a smart way of ordering. THe fastapi_filter library does not support ordering. return read_objects( CreditPayment, diff --git a/app/models/creditPayments.py b/app/models/creditPayments.py index 58fe522..30f1ae0 100644 --- a/app/models/creditPayments.py +++ b/app/models/creditPayments.py @@ -1,13 +1,16 @@ import uuid from datetime import datetime from enum import Enum +from typing import Any, Optional +import uuid6 +from fastapi_filter import FilterDepends, with_prefix from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr -import uuid6 +from sqlmodel import Field, Relationship, SQLModel, or_ from app.models.addresses import Address, AddressBase, AddressPublic -from app.models.creditor import Creditor, CreditorBase, CreditorPublic +from app.models.creditor import Creditor, CreditorBase, CreditorFilter, CreditorPublic class Card(str, Enum): @@ -55,7 +58,56 @@ class CreditPaymentsList(SQLModel): total: int -class CreditPaymentFilter(Filter): - creditor: str | None = None +class comptype(str, Enum): + eq = "eq" + ne = "ne" + gt = "gt" + ge = "ge" + lt = "lt" + le = "le" + like = "like" + ilike = "ilike" + in_ = "in" + is_ = "is" + is_not = "is_not" + not_in = "not_in" + + +class CreditPaymentFilter(SQLModel): + recipt: str | None = None card: str | None = None + creator: str | None = None + name: str | None = None + creditor: CreditorFilter | None = None + search: str | None = None + + class Constants: + model = CreditPayment + search_field_name = "search" + ordering_field_name = "name" + + def filter(self, query): + print(self.model_fields.keys()) + search_filters = [] + query = query + for field in self.model_fields: + if getattr(self, field) is not None: + if getattr(self.Constants.model, field, None) is not None: + search_filters.append( + getattr(self.Constants.model, field).ilike( + f"{getattr(self, field)}" + ) + ) + if self.search: + fields = self.model_fields + fields.pop("search", None) + search_filters.extend( + [ + getattr(self.Constants.model, field).ilike(f"%{self.search}%") + for field in fields + ] + ) + if len(search_filters) > 0: + query = query.filter(or_(*search_filters)) + return query diff --git a/requirements.txt b/requirements.txt index f9e677e..1e610c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,4 +26,4 @@ svglib authlib itsdangerous requests -uuid6 \ No newline at end of file +uuid6 -- GitLab From c8307f4167e34e63a20cbeb45f21a5ab79bf6696 Mon Sep 17 00:00:00 2001 From: cwalter <cwalter@ethz.ch> Date: Sun, 15 Dec 2024 22:27:23 +0100 Subject: [PATCH 3/5] filter working including search --- app/api/routes/creditPayments.py | 130 ++++++++++++++- app/models/creditPayments.py | 267 +++++++++++++++++++++++++------ 2 files changed, 344 insertions(+), 53 deletions(-) diff --git a/app/api/routes/creditPayments.py b/app/api/routes/creditPayments.py index cd57e5d..8cb437a 100644 --- a/app/api/routes/creditPayments.py +++ b/app/api/routes/creditPayments.py @@ -1,9 +1,9 @@ import uuid -from typing import Annotated, Any +from typing import Annotated, Any, Dict, Optional, Type, Union, get_args, get_origin -from fastapi import APIRouter, HTTPException, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends -from sqlmodel import Session, create_engine, func, insert, select +from sqlmodel import Session, SQLModel, create_engine, func, insert, select from app.api.routes.api_helper import ( create_object, @@ -28,9 +28,104 @@ from app.model import ( router = APIRouter() +def parse_filter_params(request: Request): + # Convert query_params (which is a MultiDict) to a normal dict + query_dict = dict(request.query_params) + + # We'll build a nested dictionary from keys like "creditor.amount" + nested_data = {} + for key, value in query_dict.items(): + parts = key.split(".") # split by '.' + current = nested_data + for i, part in enumerate(parts): + if i == len(parts) - 1: + # Last part, set the value + current[part] = value + else: + # Intermediate dict creation + if part not in current: + current[part] = {} + current = current[part] + + return nested_data + + +def get_credit_payment_filter(request: Request): + data = parse_filter_params(request) + # Now data might look like {"creditor": {"amount": "10"}} + # Convert strings to the right types if needed. + # Pydantic will handle type conversion on initialization if compatible. + filter_obj = CreditPaymentFilter.model_validate(data) + return filter_obj + + +def resolve_field_type(field_annotation): + """ + Resolves the actual type from field annotations, handling Optional/Union types. + """ + if get_origin(field_annotation) is Union: + # If the field is Optional[Type], extract the non-None type + args = get_args(field_annotation) + return next(arg for arg in args if arg is not type(None)) + return field_annotation + + +from pydantic import BaseModel, Field, create_model + + +def flatten_pydantic_model( + model: Type[BaseModel], parent_key: str = "", sep: str = "__" +) -> Dict[str, Any]: + """ + Recursively flattens a Pydantic model's fields using dot notation for nested fields. + + Args: + model (Type[BaseModel]): The Pydantic model to flatten. + parent_key (str, optional): The base key string for nested fields. Defaults to ''. + sep (str, optional): The separator between parent and child keys. Defaults to '.'. + + Returns: + Dict[str, Any]: A dictionary mapping flattened field names to their types and defaults. + """ + fields = {} + for field_name, field in model.model_fields.items(): + # Generate the full key for flattened fields + full_key = f"{parent_key}{sep}{field_name}" if parent_key else field_name + + # Resolve the field type (handle Optional, Union) + field_type = resolve_field_type(field.annotation) + + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + # Recursively flatten nested Pydantic models + nested_fields = flatten_pydantic_model( + field_type, parent_key=full_key, sep=sep + ) + fields.update(nested_fields) + else: + fields[full_key] = (Optional[field_type], Field(None, alias=full_key)) + return fields + + +def create_flattened_model(name: str, model: Type[BaseModel]) -> Type[BaseModel]: + """ + Creates a flattened Pydantic model with dot notation aliases for nested fields. + + Args: + name (str): The name of the new flattened model. + model (Type[BaseModel]): The original nested Pydantic model. + + Returns: + Type[BaseModel]: The new flattened Pydantic model. + """ + flattened_fields = flatten_pydantic_model(model) + flattened_model = create_model(name, **flattened_fields, __base__=model) + print(flattened_model.model_fields) + return flattened_model + + @router.get("/", response_model=CreditPaymentsList) def read_creditPayments( - filter: Annotated[CreditPaymentFilter, Query()], + filter: CreditPaymentFilter = Depends(get_credit_payment_filter), page: int = 0, limit: int = Query(default=100, le=1000), ) -> CreditPaymentsList: @@ -50,6 +145,33 @@ def read_creditPayments( ) +FlattenedCreditPaymentFilter: Type[SQLModel] = create_flattened_model( + "FlattenedCreditPaymentFilter", CreditPaymentFilter +) + + +@router.get("/list", response_model=CreditPaymentsList) +def read2_creditPayments( + filters: FlattenedCreditPaymentFilter = Depends(), + page: int = 0, + limit: int = Query(default=100, le=1000), +) -> CreditPaymentsList: + """ + retrieve all CreditPaymentes, sorted by name. + paginate the results. + """ + # print(filters.model_fields) + # we need to have a smart way of ordering. THe fastapi_filter library does not support ordering. + return read_objects( + CreditPayment, + page, + limit, + filters, + CreditPayment.id, + CreditPaymentsList, + ) + + @router.get("/{CreditPayment_id}", response_model=CreditPaymentPublic) def read_creditPayment(creditPayment_id: uuid.UUID) -> CreditPaymentPublic: """ diff --git a/app/models/creditPayments.py b/app/models/creditPayments.py index 30f1ae0..d3c35ae 100644 --- a/app/models/creditPayments.py +++ b/app/models/creditPayments.py @@ -1,16 +1,19 @@ +import inspect import uuid from datetime import datetime from enum import Enum -from typing import Any, Optional +from typing import Any, Dict, List, Optional, Type import uuid6 from fastapi_filter import FilterDepends, with_prefix from fastapi_filter.contrib.sqlalchemy import Filter -from pydantic import EmailStr -from sqlmodel import Field, Relationship, SQLModel, or_ +from pydantic import BaseModel, EmailStr +from sqlalchemy import ColumnElement, String, and_, or_ +from sqlalchemy.orm import InstrumentedAttribute, joinedload +from sqlmodel import Field, Relationship, Session, SQLModel, cast, or_, select from app.models.addresses import Address, AddressBase, AddressPublic -from app.models.creditor import Creditor, CreditorBase, CreditorFilter, CreditorPublic +from app.models.creditor import Creditor, CreditorBase, CreditorPublic class Card(str, Enum): @@ -58,56 +61,222 @@ class CreditPaymentsList(SQLModel): total: int -class comptype(str, Enum): - eq = "eq" - ne = "ne" - gt = "gt" - ge = "ge" - lt = "lt" - le = "le" - like = "like" - ilike = "ilike" - in_ = "in" - is_ = "is" - is_not = "is_not" - not_in = "not_in" +class BaseFilter(SQLModel): + """ + Base filter class to provide dynamic filtering and search functionality. + """ - -class CreditPaymentFilter(SQLModel): - - recipt: str | None = None - card: str | None = None - creator: str | None = None - name: str | None = None - creditor: CreditorFilter | None = None - search: str | None = None + search: Optional[str] = None class Constants: - model = CreditPayment + model: Type[SQLModel] = None # To be set in derived classes search_field_name = "search" - ordering_field_name = "name" def filter(self, query): - print(self.model_fields.keys()) - search_filters = [] - query = query - for field in self.model_fields: - if getattr(self, field) is not None: - if getattr(self.Constants.model, field, None) is not None: - search_filters.append( - getattr(self.Constants.model, field).ilike( - f"{getattr(self, field)}" - ) - ) + """ + Applies filtering logic to a given SQLModel query, including nested filters and search. + """ + model = self.Constants.model + filter_data = self.model_dump(exclude_unset=True) + join_paths = set() + + for field_name, field_value in filter_data.items(): + if field_value is None: + continue + + if field_name == self.Constants.search_field_name: + continue # Handle search separately later + + field_name = field_name.replace("__", ".") # Support nested fields + field_path = field_name.split(".") + model_field = getattr(model, field_path[0], None) + + # Build joins for nested fields + for path in field_path[1:]: + if isinstance(model_field, InstrumentedAttribute): + model_field = model_field.property.mapper.class_ + model_field = getattr(model_field, path, None) + join_paths.add(path) + else: + model_field = getattr(model_field, path, None) + + if model_field is None: + continue # Skip non-existent fields + + column = getattr(model, field_path[0], None) + for path in field_path[1:]: + if isinstance(column, InstrumentedAttribute): + query = query.join(column) + column = column.property.mapper.class_ + column = getattr(column, path, None) + + if column.type.python_type == str: + query = query.filter(column.ilike(f"%{field_value}%")) + else: + query = query.filter(column == field_value) + # Handle global search if self.search: - fields = self.model_fields - fields.pop("search", None) - search_filters.extend( - [ - getattr(self.Constants.model, field).ilike(f"%{self.search}%") - for field in fields - ] - ) - if len(search_filters) > 0: - query = query.filter(or_(*search_filters)) + search_conditions = self._build_search_conditions(model, self.search) + if search_conditions: + # print(f"joins: {self._build_search_joins(model, self.search)}") + for join_path in self._build_search_joins(model, self.search): + # Join the path if not already joined + query = query.join(join_path) + # print(f"search_conditions: {[cond.all_() for cond in search_conditions]}") + query = query.filter(or_(*search_conditions)) return query + + def _build_search_joins( + self, model: SQLModel, search_value: str + ) -> List[InstrumentedAttribute]: + """ + Dynamically creates a list of relationship attributes to be joined for searching. + + This method inspects the given model and its related models recursively. + For each relationship found, it will add the corresponding column to a list + of joins so that conditions on nested fields can be applied. + + Args: + model: The root SQLModel class to inspect. + search_value: The search string. + + Returns: + A list of relationship columns that need to be joined. + """ + search_joins: List[InstrumentedAttribute] = [] + visited_models: set[Type[SQLModel]] = set() + + def process_model(m: Type[SQLModel]): + # Prevent infinite recursion if cyclic relationships exist + if m in visited_models: + return + visited_models.add(m) + + for field_name, column in m.__dict__.items(): + # Check if this is an ORM mapped attribute + if isinstance(column, InstrumentedAttribute): + # Determine if column is a relationship or a simple column + # Relationships have a `.property.mapper` attribute + # print(f"column: {column}") + prop = getattr(column, "property", None) + # print(f"prop: {prop}, prop.mapper: {getattr(prop, 'mapper', None)}") + if prop and hasattr(prop, "mapper"): + # It's a relationship field + # print(f"related_model: {prop.mapper.class_}") + related_model = prop.mapper.class_ + # Add to joins for later + search_joins.append(column) + # Recurse into related model + process_model(related_model) + + # Start from the root model + process_model(model) + return search_joins + + def _build_search_conditions( + self, model: SQLModel, search_value: str + ) -> List[ColumnElement]: + """ + Dynamically creates search conditions for all string-like fields, + including nested fields. + """ + search_conditions = [] + + def process_model(model, parent_alias=None, prefix=""): + """ + Recursively process a model to extract string and numeric columns for search. + Handles nested relationships using joins. + """ + for field_name, column in model.__dict__.items(): + # Check for valid ORM attributes + if isinstance(column, InstrumentedAttribute): + try: + # Resolve the column type + column_type = column.type + # print(f"column: {column} column_type: {column_type}") + if column_type is None: + continue + try: + if hasattr(column_type, "python_type"): + column_type = column_type.python_type + # print(f"column_type: {column_type}") + except NotImplementedError: + pass + # Add string-like fields to search conditions + if column_type == str: + # print(f"str column: {column}") + search_conditions.append( + column.contains(f"%{search_value}%") + ) + + # Add numeric fields, casting to string + elif column_type in [int, float]: + # print(f"int column: {column}") + search_conditions.append( + cast(column, String).contains(f"%{search_value}%") + ) + elif column_type == uuid.UUID: + # skip uuid fields + pass + else: + search_conditions.append( + cast(column, String).contains(f"%{search_value}%") + ) + # print(f"elsecolumn: {column}") + except AttributeError: + # Handle relationships or mappers + related_model = ( + column.property.mapper.class_ + if hasattr(column, "property") + else None + ) + if related_model: + # Recurse into related model + process_model( + related_model, column, prefix=f"{prefix}{field_name}__" + ) + + # Start processing from the root model + process_model(model) + return search_conditions + + +class Kstfilter(BaseFilter): + kst_number: Optional[int] = None + name_de: Optional[str] = None + name_en: Optional[str] = None + owner: Optional[str] = None + active: Optional[bool] = None + budget_plus: Optional[int] = None + budget_minus: Optional[int] = None + + +class CreditorFilter(BaseFilter): + """ + Example nested filter for Creditor. + """ + + name: Optional[str] = None + amount: Optional[int] = None + kst: Optional[Kstfilter] = None + + class Constants: + model = Creditor # Replace with your Creditor SQLModel class + search_field_name = "search" + + +class CreditPaymentFilter(BaseFilter): + """ + Filter class for CreditPayment. + """ + + recipt: Optional[str] = None + card: Optional[str] = None + creator: Optional[str] = None + name: Optional[str] = None + creditor: Optional[CreditorFilter] = None # Nested filter + search: Optional[str] = None + + class Constants: + model = CreditPayment # Replace with your CreditPayment SQLModel class + search_field_name = "search" -- GitLab From 65d54c9b3a80bb7fc0da3443bed5bfdbb9d01ba0 Mon Sep 17 00:00:00 2001 From: cwalter <cwalter@ethz.ch> Date: Mon, 16 Dec 2024 10:07:29 +0100 Subject: [PATCH 4/5] extracted filter to separate file and class. Implemented it in most places --- app/api/routes/addresses.py | 4 +- app/api/routes/bills.py | 4 +- app/api/routes/creditPayments.py | 111 +----------- app/api/routes/internalTransfers.py | 4 +- app/api/routes/invoices.py | 4 +- app/api/routes/items.py | 4 +- app/api/routes/ksts.py | 4 +- app/api/routes/ledgers.py | 4 +- app/api/routes/reimbursements.py | 12 +- app/models/Users.py | 3 + app/models/addresses.py | 14 +- app/models/bills.py | 5 +- app/models/creditPayments.py | 218 +----------------------- app/models/creditor.py | 13 +- app/models/debitor.py | 14 +- app/models/filter.py | 252 ++++++++++++++++++++++++++++ app/models/internalTransfers.py | 12 +- app/models/invoices.py | 14 +- app/models/items.py | 4 +- app/models/ksts.py | 18 +- app/models/ledgers.py | 9 +- app/models/reimbursements.py | 12 +- 22 files changed, 345 insertions(+), 394 deletions(-) create mode 100644 app/models/filter.py diff --git a/app/api/routes/addresses.py b/app/api/routes/addresses.py index 6e2ec8f..1b1687d 100644 --- a/app/api/routes/addresses.py +++ b/app/api/routes/addresses.py @@ -1,7 +1,7 @@ import uuid from typing import Any -from fastapi import APIRouter, HTTPException, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlmodel import Session, create_engine, func, insert, select @@ -28,7 +28,7 @@ router = APIRouter() def read_addresses( page: int = 0, limit: int = Query(default=100, le=1000), - filter: AddressFilter = FilterDepends(AddressFilter), + filter: AddressFilter().create_flattened_model() = Depends(), ) -> AddressesPublic: """ retrieve all addresses, sorted by name. diff --git a/app/api/routes/bills.py b/app/api/routes/bills.py index 24d944e..c938a4d 100644 --- a/app/api/routes/bills.py +++ b/app/api/routes/bills.py @@ -1,7 +1,7 @@ import uuid from typing import Any -from fastapi import APIRouter, HTTPException, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlmodel import Session, create_engine, func, insert, select @@ -32,7 +32,7 @@ router = APIRouter() def read_Bills( page: int = 0, limit: int = Query(default=100, le=1000), - filter: BillFilter = FilterDepends(BillFilter), + filter: BillFilter().create_flattened_model() = Depends(), ) -> BillsList: """ retrieve all Billes, sorted by name. diff --git a/app/api/routes/creditPayments.py b/app/api/routes/creditPayments.py index 8cb437a..087c801 100644 --- a/app/api/routes/creditPayments.py +++ b/app/api/routes/creditPayments.py @@ -28,37 +28,6 @@ from app.model import ( router = APIRouter() -def parse_filter_params(request: Request): - # Convert query_params (which is a MultiDict) to a normal dict - query_dict = dict(request.query_params) - - # We'll build a nested dictionary from keys like "creditor.amount" - nested_data = {} - for key, value in query_dict.items(): - parts = key.split(".") # split by '.' - current = nested_data - for i, part in enumerate(parts): - if i == len(parts) - 1: - # Last part, set the value - current[part] = value - else: - # Intermediate dict creation - if part not in current: - current[part] = {} - current = current[part] - - return nested_data - - -def get_credit_payment_filter(request: Request): - data = parse_filter_params(request) - # Now data might look like {"creditor": {"amount": "10"}} - # Convert strings to the right types if needed. - # Pydantic will handle type conversion on initialization if compatible. - filter_obj = CreditPaymentFilter.model_validate(data) - return filter_obj - - def resolve_field_type(field_annotation): """ Resolves the actual type from field annotations, handling Optional/Union types. @@ -73,86 +42,9 @@ def resolve_field_type(field_annotation): from pydantic import BaseModel, Field, create_model -def flatten_pydantic_model( - model: Type[BaseModel], parent_key: str = "", sep: str = "__" -) -> Dict[str, Any]: - """ - Recursively flattens a Pydantic model's fields using dot notation for nested fields. - - Args: - model (Type[BaseModel]): The Pydantic model to flatten. - parent_key (str, optional): The base key string for nested fields. Defaults to ''. - sep (str, optional): The separator between parent and child keys. Defaults to '.'. - - Returns: - Dict[str, Any]: A dictionary mapping flattened field names to their types and defaults. - """ - fields = {} - for field_name, field in model.model_fields.items(): - # Generate the full key for flattened fields - full_key = f"{parent_key}{sep}{field_name}" if parent_key else field_name - - # Resolve the field type (handle Optional, Union) - field_type = resolve_field_type(field.annotation) - - if isinstance(field_type, type) and issubclass(field_type, BaseModel): - # Recursively flatten nested Pydantic models - nested_fields = flatten_pydantic_model( - field_type, parent_key=full_key, sep=sep - ) - fields.update(nested_fields) - else: - fields[full_key] = (Optional[field_type], Field(None, alias=full_key)) - return fields - - -def create_flattened_model(name: str, model: Type[BaseModel]) -> Type[BaseModel]: - """ - Creates a flattened Pydantic model with dot notation aliases for nested fields. - - Args: - name (str): The name of the new flattened model. - model (Type[BaseModel]): The original nested Pydantic model. - - Returns: - Type[BaseModel]: The new flattened Pydantic model. - """ - flattened_fields = flatten_pydantic_model(model) - flattened_model = create_model(name, **flattened_fields, __base__=model) - print(flattened_model.model_fields) - return flattened_model - - @router.get("/", response_model=CreditPaymentsList) def read_creditPayments( - filter: CreditPaymentFilter = Depends(get_credit_payment_filter), - page: int = 0, - limit: int = Query(default=100, le=1000), -) -> CreditPaymentsList: - """ - retrieve all CreditPaymentes, sorted by name. - paginate the results. - """ - print(filter.__dict__) - # we need to have a smart way of ordering. THe fastapi_filter library does not support ordering. - return read_objects( - CreditPayment, - page, - limit, - filter, - CreditPayment.id, - CreditPaymentsList, - ) - - -FlattenedCreditPaymentFilter: Type[SQLModel] = create_flattened_model( - "FlattenedCreditPaymentFilter", CreditPaymentFilter -) - - -@router.get("/list", response_model=CreditPaymentsList) -def read2_creditPayments( - filters: FlattenedCreditPaymentFilter = Depends(), + filters: CreditPaymentFilter().create_flattened_model() = Depends(), page: int = 0, limit: int = Query(default=100, le=1000), ) -> CreditPaymentsList: @@ -160,7 +52,6 @@ def read2_creditPayments( retrieve all CreditPaymentes, sorted by name. paginate the results. """ - # print(filters.model_fields) # we need to have a smart way of ordering. THe fastapi_filter library does not support ordering. return read_objects( CreditPayment, diff --git a/app/api/routes/internalTransfers.py b/app/api/routes/internalTransfers.py index 3bddfc5..79de117 100644 --- a/app/api/routes/internalTransfers.py +++ b/app/api/routes/internalTransfers.py @@ -1,7 +1,7 @@ import uuid from typing import Any -from fastapi import APIRouter, HTTPException, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlmodel import Session, create_engine, func, insert, select @@ -31,7 +31,7 @@ router = APIRouter() def read_InternalTransferes( page: int = 0, limit: int = Query(default=100, le=1000), - filter: InternalTransferFilter = FilterDepends(InternalTransferFilter), + filter: InternalTransferFilter().create_flattened_model() = Depends(), ) -> InternalTransfersList: """ retrieve all InternalTransferes, sorted by name. diff --git a/app/api/routes/invoices.py b/app/api/routes/invoices.py index c88bcec..9cf7743 100644 --- a/app/api/routes/invoices.py +++ b/app/api/routes/invoices.py @@ -2,7 +2,7 @@ import uuid from datetime import datetime from typing import Any -from fastapi import APIRouter, HTTPException, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlalchemy.orm import selectinload from sqlmodel import Session, create_engine, func, insert, select @@ -47,7 +47,7 @@ def read_Invoice(invoice_id: uuid.UUID) -> InvoicePublic: def read_Invoices( page: int = 0, limit: int = Query(default=100, le=1000), - filter: InvoiceFilter = FilterDepends(InvoiceFilter), + filter: InvoiceFilter().create_flattened_model() = Depends(), ) -> InvoicesList: """ Retrieve all Invoices, sorted by id. diff --git a/app/api/routes/items.py b/app/api/routes/items.py index 0450255..7c21f37 100644 --- a/app/api/routes/items.py +++ b/app/api/routes/items.py @@ -1,7 +1,7 @@ import uuid from typing import Any -from fastapi import APIRouter, HTTPException, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlmodel import Session, create_engine, func, insert, select @@ -22,7 +22,7 @@ router = APIRouter() def read_items( page: int = 0, limit: int = Query(default=100, le=1000), - filter: ItemFilter = FilterDepends(ItemFilter), + filter: ItemFilter().create_flattened_model() = Depends(), ) -> ItemsPublic: """ retrieve all items, sorted by id. diff --git a/app/api/routes/ksts.py b/app/api/routes/ksts.py index 56cc581..5c6d817 100644 --- a/app/api/routes/ksts.py +++ b/app/api/routes/ksts.py @@ -1,7 +1,7 @@ import uuid from typing import Any -from fastapi import APIRouter, HTTPException, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlmodel import Session, create_engine, func, insert, select @@ -22,7 +22,7 @@ router = APIRouter() def read_ksts( page: int = 0, limit: int = Query(default=100, le=1000), - filter: KstFilter = FilterDepends(KstFilter), + filter: KstFilter().create_flattened_model() = Depends(), ) -> KstsPublic: """ retrieve all ksts, sorted by accountnumber. diff --git a/app/api/routes/ledgers.py b/app/api/routes/ledgers.py index 90c7155..b3df35a 100644 --- a/app/api/routes/ledgers.py +++ b/app/api/routes/ledgers.py @@ -1,7 +1,7 @@ import uuid from typing import Any -from fastapi import APIRouter, HTTPException, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlmodel import Session, create_engine, func, insert, select @@ -22,7 +22,7 @@ router = APIRouter() def read_ledgers( page: int = 0, limit: int = Query(default=100, le=1000), - filter: LedgerFilter = FilterDepends(LedgerFilter), + filter: LedgerFilter().create_flattened_model() = Depends(), ) -> LedgersPublic: """ retrieve all ledgers, sorted by accountnumber. diff --git a/app/api/routes/reimbursements.py b/app/api/routes/reimbursements.py index 95b67fc..1809049 100644 --- a/app/api/routes/reimbursements.py +++ b/app/api/routes/reimbursements.py @@ -11,7 +11,7 @@ from fastapi import ( Query, Request, UploadFile, - status + status, ) from fastapi_filter import FilterDepends from sqlmodel import Session, create_engine, func, insert, select @@ -45,13 +45,13 @@ def read_Reimbursements( current_user: Annotated[Any, Depends(get_user_info)], page: int = 0, limit: int = Query(default=100, le=1000), - filter: ReimbursementFilter = FilterDepends(ReimbursementFilter), + filter: ReimbursementFilter().create_flattened_model() = Depends(), ) -> ReimbursementsList: """ retrieve all Reimbursementes, sorted by name. paginate the results. """ - + print(current_user) # we need to have a smart way of ordering. THe fastapi_filter library does not support ordering. return read_objects( @@ -71,8 +71,10 @@ def read_Reimbursement( """ retrieve a single Reimbursement by id. """ - - object:ReimbursementPublic= read_object(Reimbursement, Reimbursement_id, ReimbursementPublic) + + object: ReimbursementPublic = read_object( + Reimbursement, Reimbursement_id, ReimbursementPublic + ) return object diff --git a/app/models/Users.py b/app/models/Users.py index b0ecf5c..2b03d20 100644 --- a/app/models/Users.py +++ b/app/models/Users.py @@ -10,6 +10,7 @@ from app.models.addresses import Address, AddressBase, AddressPublic class DbUserBase(SQLModel): amiv_id: str = Field(max_length=30, primary_key=True) + handle: str = Field(max_length=30) address_id: uuid.UUID = Field(foreign_key="address.id") iban: str = Field(max_length=30) @@ -35,6 +36,7 @@ class DbUsersList(SQLModel): count: int total: int + class BasicUser(SQLModel): amiv_id: str nethz: str @@ -42,6 +44,7 @@ class BasicUser(SQLModel): lastname: str email: EmailStr + class DbUserFilter(Filter): amiv_id: str | None = None address: str | None = None diff --git a/app/models/addresses.py b/app/models/addresses.py index d780e19..e85485c 100644 --- a/app/models/addresses.py +++ b/app/models/addresses.py @@ -7,6 +7,8 @@ from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel +from app.models.filter import BaseFilter + class AddressBase(SQLModel): name: str = Field(max_length=30) @@ -34,7 +36,7 @@ class AddressesPublic(SQLModel): total: int -class AddressFilter(Filter): +class AddressFilter(BaseFilter): name: str | None = None address1: str | None = None address2: str | None = None @@ -44,14 +46,6 @@ class AddressFilter(Filter): plz: int | None = None search: Optional[str] = None - class Constants(Filter.Constants): + class Constants: model = Address search_field_name = "search" - search_model_fields = [ - "name", - "adress1", - "adress2", - "adress3", - "city", - "country", - ] diff --git a/app/models/bills.py b/app/models/bills.py index 780eceb..1fbd557 100644 --- a/app/models/bills.py +++ b/app/models/bills.py @@ -9,6 +9,7 @@ from sqlmodel import Field, Relationship, SQLModel from app.models.addresses import Address, AddressBase, AddressFilter, AddressPublic from app.models.creditor import Creditor, CreditorBase, CreditorPublic +from app.models.filter import BaseFilter class BillBase(SQLModel): @@ -56,9 +57,9 @@ class BillsList(SQLModel): total: int -class BillFilter(Filter): +class BillFilter(BaseFilter): # number_filter: Optional[Filter] = FilterDepends(with_prefix("address", AddressFilter)) - address: str | None = None + address: Optional[AddressFilter] = None reference: str | None = None iban: str | None = None recipt: str | None = None diff --git a/app/models/creditPayments.py b/app/models/creditPayments.py index d3c35ae..f16b938 100644 --- a/app/models/creditPayments.py +++ b/app/models/creditPayments.py @@ -2,18 +2,14 @@ import inspect import uuid from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Type +from typing import Optional import uuid6 -from fastapi_filter import FilterDepends, with_prefix -from fastapi_filter.contrib.sqlalchemy import Filter -from pydantic import BaseModel, EmailStr -from sqlalchemy import ColumnElement, String, and_, or_ -from sqlalchemy.orm import InstrumentedAttribute, joinedload -from sqlmodel import Field, Relationship, Session, SQLModel, cast, or_, select +from sqlmodel import Field, Relationship, SQLModel -from app.models.addresses import Address, AddressBase, AddressPublic -from app.models.creditor import Creditor, CreditorBase, CreditorPublic +from app.models.addresses import Address, AddressBase, AddressFilter, AddressPublic +from app.models.creditor import Creditor, CreditorBase, CreditorFilter, CreditorPublic +from app.models.filter import BaseFilter class Card(str, Enum): @@ -61,210 +57,6 @@ class CreditPaymentsList(SQLModel): total: int -class BaseFilter(SQLModel): - """ - Base filter class to provide dynamic filtering and search functionality. - """ - - search: Optional[str] = None - - class Constants: - model: Type[SQLModel] = None # To be set in derived classes - search_field_name = "search" - - def filter(self, query): - """ - Applies filtering logic to a given SQLModel query, including nested filters and search. - """ - model = self.Constants.model - filter_data = self.model_dump(exclude_unset=True) - join_paths = set() - - for field_name, field_value in filter_data.items(): - if field_value is None: - continue - - if field_name == self.Constants.search_field_name: - continue # Handle search separately later - - field_name = field_name.replace("__", ".") # Support nested fields - field_path = field_name.split(".") - model_field = getattr(model, field_path[0], None) - - # Build joins for nested fields - for path in field_path[1:]: - if isinstance(model_field, InstrumentedAttribute): - model_field = model_field.property.mapper.class_ - model_field = getattr(model_field, path, None) - join_paths.add(path) - else: - model_field = getattr(model_field, path, None) - - if model_field is None: - continue # Skip non-existent fields - - column = getattr(model, field_path[0], None) - for path in field_path[1:]: - if isinstance(column, InstrumentedAttribute): - query = query.join(column) - column = column.property.mapper.class_ - column = getattr(column, path, None) - - if column.type.python_type == str: - query = query.filter(column.ilike(f"%{field_value}%")) - else: - query = query.filter(column == field_value) - # Handle global search - if self.search: - search_conditions = self._build_search_conditions(model, self.search) - if search_conditions: - # print(f"joins: {self._build_search_joins(model, self.search)}") - for join_path in self._build_search_joins(model, self.search): - # Join the path if not already joined - query = query.join(join_path) - # print(f"search_conditions: {[cond.all_() for cond in search_conditions]}") - query = query.filter(or_(*search_conditions)) - return query - - def _build_search_joins( - self, model: SQLModel, search_value: str - ) -> List[InstrumentedAttribute]: - """ - Dynamically creates a list of relationship attributes to be joined for searching. - - This method inspects the given model and its related models recursively. - For each relationship found, it will add the corresponding column to a list - of joins so that conditions on nested fields can be applied. - - Args: - model: The root SQLModel class to inspect. - search_value: The search string. - - Returns: - A list of relationship columns that need to be joined. - """ - search_joins: List[InstrumentedAttribute] = [] - visited_models: set[Type[SQLModel]] = set() - - def process_model(m: Type[SQLModel]): - # Prevent infinite recursion if cyclic relationships exist - if m in visited_models: - return - visited_models.add(m) - - for field_name, column in m.__dict__.items(): - # Check if this is an ORM mapped attribute - if isinstance(column, InstrumentedAttribute): - # Determine if column is a relationship or a simple column - # Relationships have a `.property.mapper` attribute - # print(f"column: {column}") - prop = getattr(column, "property", None) - # print(f"prop: {prop}, prop.mapper: {getattr(prop, 'mapper', None)}") - if prop and hasattr(prop, "mapper"): - # It's a relationship field - # print(f"related_model: {prop.mapper.class_}") - related_model = prop.mapper.class_ - # Add to joins for later - search_joins.append(column) - # Recurse into related model - process_model(related_model) - - # Start from the root model - process_model(model) - return search_joins - - def _build_search_conditions( - self, model: SQLModel, search_value: str - ) -> List[ColumnElement]: - """ - Dynamically creates search conditions for all string-like fields, - including nested fields. - """ - search_conditions = [] - - def process_model(model, parent_alias=None, prefix=""): - """ - Recursively process a model to extract string and numeric columns for search. - Handles nested relationships using joins. - """ - for field_name, column in model.__dict__.items(): - # Check for valid ORM attributes - if isinstance(column, InstrumentedAttribute): - try: - # Resolve the column type - column_type = column.type - # print(f"column: {column} column_type: {column_type}") - if column_type is None: - continue - try: - if hasattr(column_type, "python_type"): - column_type = column_type.python_type - # print(f"column_type: {column_type}") - except NotImplementedError: - pass - # Add string-like fields to search conditions - if column_type == str: - # print(f"str column: {column}") - search_conditions.append( - column.contains(f"%{search_value}%") - ) - - # Add numeric fields, casting to string - elif column_type in [int, float]: - # print(f"int column: {column}") - search_conditions.append( - cast(column, String).contains(f"%{search_value}%") - ) - elif column_type == uuid.UUID: - # skip uuid fields - pass - else: - search_conditions.append( - cast(column, String).contains(f"%{search_value}%") - ) - # print(f"elsecolumn: {column}") - except AttributeError: - # Handle relationships or mappers - related_model = ( - column.property.mapper.class_ - if hasattr(column, "property") - else None - ) - if related_model: - # Recurse into related model - process_model( - related_model, column, prefix=f"{prefix}{field_name}__" - ) - - # Start processing from the root model - process_model(model) - return search_conditions - - -class Kstfilter(BaseFilter): - kst_number: Optional[int] = None - name_de: Optional[str] = None - name_en: Optional[str] = None - owner: Optional[str] = None - active: Optional[bool] = None - budget_plus: Optional[int] = None - budget_minus: Optional[int] = None - - -class CreditorFilter(BaseFilter): - """ - Example nested filter for Creditor. - """ - - name: Optional[str] = None - amount: Optional[int] = None - kst: Optional[Kstfilter] = None - - class Constants: - model = Creditor # Replace with your Creditor SQLModel class - search_field_name = "search" - - class CreditPaymentFilter(BaseFilter): """ Filter class for CreditPayment. diff --git a/app/models/creditor.py b/app/models/creditor.py index c009d0e..37fd79e 100644 --- a/app/models/creditor.py +++ b/app/models/creditor.py @@ -1,13 +1,15 @@ import uuid from datetime import datetime from enum import Enum +from typing import Optional import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter from sqlmodel import Field, Relationship, SQLModel -from app.models.ksts import Kst -from app.models.ledgers import Ledger +from app.models.filter import BaseFilter +from app.models.ksts import Kst, KstFilter +from app.models.ledgers import Ledger, LedgerFilter # create a currency enum @@ -41,9 +43,9 @@ class CreditorPublic(CreditorBase): id: uuid.UUID = Field() -class CreditorFilter(Filter): - kst: str | None = None - ledger: str | None = None +class CreditorFilter(BaseFilter): + kst: Optional[KstFilter] = None + ledger: Optional[LedgerFilter] = None amount: int | None = None accounting_year: int | None = None currency: Currency | None = None @@ -54,4 +56,3 @@ class CreditorFilter(Filter): class Constants(Filter.Constants): model = Creditor search_field_name = "search" - search_model_fields = ["comment", "qcomment"] diff --git a/app/models/debitor.py b/app/models/debitor.py index 984edaa..98810ed 100644 --- a/app/models/debitor.py +++ b/app/models/debitor.py @@ -7,9 +7,10 @@ from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel from app.model import * +from app.models.filter import BaseFilter from app.models.items import Item, ItemPublic -from app.models.ksts import Kst -from app.models.ledgers import Ledger +from app.models.ksts import Kst, KstFilter +from app.models.ledgers import Ledger, LedgerFilter class DebitorBase(SQLModel): @@ -33,12 +34,11 @@ class DebitorPublic(DebitorBase): id: uuid.UUID -class DebitorFilter(Filter): - kst: str | None = None - ledger: str | None = None +class DebitorFilter(BaseFilter): + kst: Optional[KstFilter] = None + ledger: Optional[LedgerFilter] = None mwst: str | None = None - class Constants(Filter.Constants): + class Constants: model = Debitor search_field_name = "search" - search_model_fields = ["adress", "kst", "ledger", "mwst"] diff --git a/app/models/filter.py b/app/models/filter.py new file mode 100644 index 0000000..8df42dd --- /dev/null +++ b/app/models/filter.py @@ -0,0 +1,252 @@ +import uuid +from typing import Any, Dict, List, Optional, Type, Union, get_args, get_origin + +from pydantic import BaseModel +from pydantic import Field as pydantic_Field +from pydantic import create_model +from sqlalchemy import ColumnElement, String +from sqlalchemy.orm import InstrumentedAttribute +from sqlmodel import SQLModel, cast, or_ + + +class BaseFilter(SQLModel): + """ + Base filter class to provide dynamic filtering and search functionality. + """ + + search: Optional[str] = None + + class Constants: + model: Type[SQLModel] = None # To be set in derived classes + search_field_name = "search" + + def resolve_field_type(self, field_annotation): + """ + Resolves the actual type from field annotations, handling Optional/Union types. + """ + if get_origin(field_annotation) is Union: + # If the field is Optional[Type], extract the non-None type + args = get_args(field_annotation) + return next(arg for arg in args if arg is not type(None)) + return field_annotation + + def flatten_pydantic_model( + self, model: Type[BaseModel], parent_key: str = "", sep: str = "__" + ) -> Dict[str, Any]: + """ + Recursively flattens a Pydantic model's fields using dot notation for nested fields. + + Args: + model (Type[BaseModel]): The Pydantic model to flatten. + parent_key (str, optional): The base key string for nested fields. Defaults to ''. + sep (str, optional): The separator between parent and child keys. Defaults to '.'. + + Returns: + Dict[str, Any]: A dictionary mapping flattened field names to their types and defaults. + """ + # print(type(self)) + fields = {} + for field_name, field in model.model_fields.items(): + # Generate the full key for flattened fields + full_key = f"{parent_key}{sep}{field_name}" if parent_key else field_name + + # Resolve the field type (handle Optional, Union) + field_type = self.resolve_field_type(field.annotation) + + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + # Recursively flatten nested Pydantic models + nested_fields = self.flatten_pydantic_model( + field_type, parent_key=full_key, sep=sep + ) + fields.update(nested_fields) + else: + fields[full_key] = ( + Optional[field_type], + pydantic_Field(None, alias=full_key), + ) + return fields + + def create_flattened_model(self) -> Type[BaseModel]: + """ + Creates a flattened Pydantic model with dot notation aliases for nested fields. + + Args: + name (str): The name of the new flattened model. + + Returns: + Type[BaseModel]: The new flattened Pydantic model. + """ + flattened_fields = self.flatten_pydantic_model(self.__class__) + flattened_model = create_model( + "flattened filter", **flattened_fields, __base__=self.__class__ + ) + # print(flattened_model.model_fields) + return flattened_model + + def filter(self, query): + """ + Applies filtering logic to a given SQLModel query, including nested filters and search. + """ + model = self.Constants.model + filter_data = self.model_dump(exclude_unset=True) + join_paths = set() + + for field_name, field_value in filter_data.items(): + if field_value is None: + continue + + if field_name == self.Constants.search_field_name: + continue # Handle search separately later + + field_name = field_name.replace("__", ".") # Support nested fields + field_path = field_name.split(".") + model_field = getattr(model, field_path[0], None) + + # Build joins for nested fields + for path in field_path[1:]: + if isinstance(model_field, InstrumentedAttribute): + model_field = model_field.property.mapper.class_ + model_field = getattr(model_field, path, None) + join_paths.add(path) + else: + model_field = getattr(model_field, path, None) + + if model_field is None: + continue # Skip non-existent fields + + column = getattr(model, field_path[0], None) + for path in field_path[1:]: + if isinstance(column, InstrumentedAttribute): + query = query.join(column) + column = column.property.mapper.class_ + column = getattr(column, path, None) + + if column.type.python_type == str: + query = query.filter(column.ilike(f"%{field_value}%")) + else: + query = query.filter(column == field_value) + # Handle global search + if self.search: + search_conditions = self._build_search_conditions(model, self.search) + if search_conditions: + # print(f"joins: {self._build_search_joins(model, self.search)}") + for join_path in self._build_search_joins(model, self.search): + # Join the path if not already joined + query = query.join(join_path) + # print(f"search_conditions: {[cond.all_() for cond in search_conditions]}") + query = query.filter(or_(*search_conditions)) + return query + + def _build_search_joins( + self, model: SQLModel, search_value: str + ) -> List[InstrumentedAttribute]: + """ + Dynamically creates a list of relationship attributes to be joined for searching. + + This method inspects the given model and its related models recursively. + For each relationship found, it will add the corresponding column to a list + of joins so that conditions on nested fields can be applied. + + Args: + model: The root SQLModel class to inspect. + search_value: The search string. + + Returns: + A list of relationship columns that need to be joined. + """ + search_joins: List[InstrumentedAttribute] = [] + visited_models: set[Type[SQLModel]] = set() + + def process_model(m: Type[SQLModel]): + # Prevent infinite recursion if cyclic relationships exist + if m in visited_models: + return + visited_models.add(m) + + for field_name, column in m.__dict__.items(): + # Check if this is an ORM mapped attribute + if isinstance(column, InstrumentedAttribute): + # Determine if column is a relationship or a simple column + # Relationships have a `.property.mapper` attribute + # print(f"column: {column}") + prop = getattr(column, "property", None) + # print(f"prop: {prop}, prop.mapper: {getattr(prop, 'mapper', None)}") + if prop and hasattr(prop, "mapper"): + # It's a relationship field + # print(f"related_model: {prop.mapper.class_}") + related_model = prop.mapper.class_ + # Add to joins for later + search_joins.append(column) + # Recurse into related model + process_model(related_model) + + # Start from the root model + process_model(model) + return search_joins + + def _build_search_conditions( + self, model: SQLModel, search_value: str + ) -> List[ColumnElement]: + """ + Dynamically creates search conditions for all string-like fields, + including nested fields. + """ + search_conditions = [] + + def process_model(model, parent_alias=None, prefix=""): + """ + Recursively process a model to extract string and numeric columns for search. + Handles nested relationships using joins. + """ + for field_name, column in model.__dict__.items(): + # Check for valid ORM attributes + if isinstance(column, InstrumentedAttribute): + try: + # Resolve the column type + column_type = column.type + # print(f"column: {column} column_type: {column_type}") + if column_type is None: + continue + try: + if hasattr(column_type, "python_type"): + column_type = column_type.python_type + # print(f"column_type: {column_type}") + except NotImplementedError: + pass + # Add string-like fields to search conditions + if column_type == str: + # print(f"str column: {column}") + search_conditions.append( + column.contains(f"%{search_value}%") + ) + + # Add numeric fields, casting to string + elif column_type in [int, float]: + # print(f"int column: {column}") + search_conditions.append( + cast(column, String).contains(f"%{search_value}%") + ) + elif column_type == uuid.UUID: + # skip uuid fields + pass + else: + search_conditions.append( + cast(column, String).contains(f"%{search_value}%") + ) + # print(f"elsecolumn: {column}") + except AttributeError: + # Handle relationships or mappers + related_model = ( + column.property.mapper.class_ + if hasattr(column, "property") + else None + ) + if related_model: + # Recurse into related model + process_model( + related_model, column, prefix=f"{prefix}{field_name}__" + ) + + # Start processing from the root model + process_model(model) + return search_conditions diff --git a/app/models/internalTransfers.py b/app/models/internalTransfers.py index 820ad30..df1bd16 100644 --- a/app/models/internalTransfers.py +++ b/app/models/internalTransfers.py @@ -1,5 +1,6 @@ import uuid from datetime import datetime +from typing import Optional import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter @@ -7,8 +8,9 @@ from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel from app.model import Creditor, CreditorBase, Debitor, DebitorBase -from app.models.creditor import CreditorPublic -from app.models.debitor import DebitorPublic +from app.models.creditor import CreditorFilter, CreditorPublic +from app.models.debitor import DebitorFilter, DebitorPublic +from app.models.filter import BaseFilter class InternalTransferBase(SQLModel): @@ -48,7 +50,7 @@ class InternalTransfersList(SQLModel): total: int -class InternalTransferFilter(Filter): - debitor: str | None = None - creditor: str | None = None +class InternalTransferFilter(BaseFilter): + debitor: Optional[DebitorFilter] = None + creditor: Optional[CreditorFilter] = None amount: str | None = None diff --git a/app/models/invoices.py b/app/models/invoices.py index 1270729..5728de0 100644 --- a/app/models/invoices.py +++ b/app/models/invoices.py @@ -1,13 +1,17 @@ import uuid from datetime import datetime +from typing import Optional import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel +from app.models.addresses import Address, AddressFilter from app.models.debitor import Debitor, DebitorBase, DebitorFilter, DebitorPublic +from app.models.filter import BaseFilter from app.models.items import Item, ItemPublic +from app.models.ksts import KstFilter class ItemstoInvoice(SQLModel, table=True): @@ -42,12 +46,10 @@ class ItemstoInvoiceCreate(SQLModel): class InvoiceBase(SQLModel): payinterval: int = Field(default=30) - adress: str = Field(max_length=30) class Invoice(InvoiceBase, table=True): id: uuid.UUID = Field(default_factory=uuid6.uuid7, primary_key=True) - time_create: datetime = Field(default=datetime.now()) time_modified: datetime = Field(default=datetime.now()) items: list[ItemstoInvoice] = Relationship( @@ -74,14 +76,12 @@ class InvoicesList(SQLModel): total: int -class InvoiceFilter(Filter): - adress: str | None = None - kst: str | None = None +class InvoiceFilter(BaseFilter): + kst: Optional[KstFilter] = None mwst: str | None = None payinterval: int | None = None comment: str | None = None - debitor: DebitorFilter | None = None + debitor: Optional[DebitorFilter] = None time_create: datetime | None = None time_modified: datetime | None = None - items: list[ItemstoInvoicePublic] | None = [] # creator_id: uuid.UUID diff --git a/app/models/items.py b/app/models/items.py index 4e5fbce..1ee0975 100644 --- a/app/models/items.py +++ b/app/models/items.py @@ -7,6 +7,8 @@ from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel +from app.models.filter import BaseFilter + class ItemBase(SQLModel): title_de: str = Field(min_length=1, max_length=255) @@ -34,7 +36,7 @@ class ItemsPublic(SQLModel): total: int -class ItemFilter(Filter): +class ItemFilter(BaseFilter): price: Optional[int] = None unit: Optional[str] = None active: Optional[bool] = None diff --git a/app/models/ksts.py b/app/models/ksts.py index f737e4c..663ea35 100644 --- a/app/models/ksts.py +++ b/app/models/ksts.py @@ -7,6 +7,8 @@ from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel +from app.models.filter import BaseFilter + class KstBase(SQLModel): kst_number: int = Field() @@ -34,11 +36,11 @@ class KstsPublic(SQLModel): total: int -class KstFilter(Filter): - accountnumber: Optional[int] = None - search: Optional[str] = None - - class Constants(Filter.Constants): - model = Kst - search_field_name = "search" - search_model_fields = ["namede", "nameen"] +class KstFilter(BaseFilter): + kst_number: Optional[int] = None + name_de: Optional[str] = None + name_en: Optional[str] = None + owner: Optional[str] = None + active: Optional[bool] = None + budget_plus: Optional[int] = None + budget_minus: Optional[int] = None diff --git a/app/models/ledgers.py b/app/models/ledgers.py index 8386b55..318825b 100644 --- a/app/models/ledgers.py +++ b/app/models/ledgers.py @@ -7,6 +7,8 @@ from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel +from app.models.filter import BaseFilter + class LedgerBase(SQLModel): accountnumber: int = Field(default=0) @@ -31,11 +33,12 @@ class LedgersPublic(SQLModel): total: int -class LedgerFilter(Filter): +class LedgerFilter(BaseFilter): accountnumber: Optional[int] = None + namede: Optional[str] = None + nameen: Optional[str] = None search: Optional[str] = None - class Constants(Filter.Constants): + class Constants: model = Ledger search_field_name = "search" - search_model_fields = ["namede", "nameen"] diff --git a/app/models/reimbursements.py b/app/models/reimbursements.py index 8c1af89..621869d 100644 --- a/app/models/reimbursements.py +++ b/app/models/reimbursements.py @@ -1,5 +1,6 @@ import uuid from datetime import datetime +from typing import Optional import uuid6 from fastapi_filter.contrib.sqlalchemy import Filter @@ -7,7 +8,8 @@ from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel from app.models.addresses import Address, AddressBase, AddressPublic -from app.models.creditor import Creditor, CreditorBase, CreditorPublic +from app.models.creditor import Creditor, CreditorBase, CreditorFilter, CreditorPublic +from app.models.filter import BaseFilter class ReimbursementBase(SQLModel): @@ -49,6 +51,10 @@ class ReimbursementsList(SQLModel): total: int -class ReimbursementFilter(Filter): - creditor: str | None = None +class ReimbursementFilter(BaseFilter): + creditor: Optional[CreditorFilter] = None recipt: str | None = None + + class Constants: + model = Reimbursement + search_field_name = "search" -- GitLab From 77956b5e2a721aa8536e8ea2c3d07f9eb8d808aa Mon Sep 17 00:00:00 2001 From: cwalter <cwalter@ethz.ch> Date: Wed, 18 Dec 2024 17:01:53 +0100 Subject: [PATCH 5/5] updated filters add sort fix bugs fix credit payments --- app/api/routes/addresses.py | 23 +++++++--- app/api/routes/api_helper.py | 4 +- app/api/routes/bills.py | 22 +++++++--- app/api/routes/creditPayments.py | 37 +++++++--------- app/api/routes/files.py | 32 ++++++++++---- app/api/routes/internalTransfers.py | 19 +++++--- app/api/routes/invoices.py | 20 ++++++--- app/api/routes/items.py | 23 +++++++--- app/api/routes/ksts.py | 23 +++++++--- app/api/routes/ledgers.py | 23 +++++++--- app/models/Users.py | 1 - app/models/filter.py | 68 ++++++++++++++++++++++++++--- 12 files changed, 221 insertions(+), 74 deletions(-) diff --git a/app/api/routes/addresses.py b/app/api/routes/addresses.py index 1b1687d..310656f 100644 --- a/app/api/routes/addresses.py +++ b/app/api/routes/addresses.py @@ -1,10 +1,11 @@ import uuid -from typing import Any +from typing import Annotated, Any from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlmodel import Session, create_engine, func, insert, select +from app.api.auth import get_user_info from app.api.routes.api_helper import ( create_object, delete_object, @@ -20,12 +21,14 @@ from app.model import ( AddressFilter, AddressPublic, ) +from app.models.Users import User router = APIRouter() @router.get("/", response_model=AddressesPublic) def read_addresses( + current_user: Annotated[User, Depends(get_user_info)], page: int = 0, limit: int = Query(default=100, le=1000), filter: AddressFilter().create_flattened_model() = Depends(), @@ -39,7 +42,9 @@ def read_addresses( @router.get("/{address_id}", response_model=Address) -def read_address(address_id: uuid.UUID) -> Address: +def read_address( + current_user: Annotated[User, Depends(get_user_info)], address_id: uuid.UUID +) -> Address: """ retrieve a single Address by id. """ @@ -47,7 +52,9 @@ def read_address(address_id: uuid.UUID) -> Address: @router.post("/", response_model=Address) -def create_address(*, address_in: AddressBase) -> Address: +def create_address( + current_user: Annotated[User, Depends(get_user_info)], address_in: AddressBase +) -> Address: """ create a new Address. """ @@ -55,7 +62,11 @@ def create_address(*, address_in: AddressBase) -> Address: @router.patch("/{address_id}", response_model=Address) -def update_address(*, address_id: uuid.UUID, address_in: AddressPublic) -> Address: +def update_address( + current_user: Annotated[User, Depends(get_user_info)], + address_id: uuid.UUID, + address_in: AddressPublic, +) -> Address: """ update an Address. """ @@ -63,7 +74,9 @@ def update_address(*, address_id: uuid.UUID, address_in: AddressPublic) -> Addre @router.delete("/{address_id}", response_model=Address) -def delete_address(address_id: uuid.UUID) -> Address: +def delete_address( + current_user: Annotated[User, Depends(get_user_info)], address_id: uuid.UUID +) -> Address: """ delete an Address. """ diff --git a/app/api/routes/api_helper.py b/app/api/routes/api_helper.py index 54e3315..ba88418 100644 --- a/app/api/routes/api_helper.py +++ b/app/api/routes/api_helper.py @@ -74,9 +74,7 @@ def read_objects( with Session(engine) as session: # Apply the filter and order the results - stmt = filter.filter( - select(Obj_type).offset((page) * limit).limit(limit).order_by(order_by) - ) + stmt = filter.filter(select(Obj_type).offset((page) * limit).limit(limit)) print(stmt) # Execute the query out = session.exec(stmt).all() diff --git a/app/api/routes/bills.py b/app/api/routes/bills.py index c938a4d..f3db1e9 100644 --- a/app/api/routes/bills.py +++ b/app/api/routes/bills.py @@ -1,10 +1,11 @@ import uuid -from typing import Any +from typing import Annotated, Any from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlmodel import Session, create_engine, func, insert, select +from app.api.auth import get_user_info from app.api.routes.api_helper import ( create_object, delete_object, @@ -30,6 +31,7 @@ router = APIRouter() @router.get("/", response_model=BillsList) def read_Bills( + current_user: Annotated[Any, Depends(get_user_info)], page: int = 0, limit: int = Query(default=100, le=1000), filter: BillFilter().create_flattened_model() = Depends(), @@ -50,7 +52,9 @@ def read_Bills( @router.get("/{Bill_id}", response_model=BillPublic) -def read_Bill(bill_id: uuid.UUID) -> BillPublic: +def read_Bill( + current_user: Annotated[Any, Depends(get_user_info)], bill_id: uuid.UUID +) -> BillPublic: """ retrieve a single Bill by id. """ @@ -58,7 +62,9 @@ def read_Bill(bill_id: uuid.UUID) -> BillPublic: @router.post("/", response_model=BillPublic) -def create_Bill(*, bill_in: BillCreate) -> BillPublic: +def create_Bill( + current_user: Annotated[Any, Depends(get_user_info)], bill_in: BillCreate +) -> BillPublic: """ create a new Bill. """ @@ -88,7 +94,11 @@ def create_Bill(*, bill_in: BillCreate) -> BillPublic: @router.patch("/{Bill_id}", response_model=BillPublic) -def update_Bill(*, Bill_id: uuid.UUID, bill_in: BillPublic) -> BillPublic: +def update_Bill( + current_user: Annotated[Any, Depends(get_user_info)], + Bill_id: uuid.UUID, + bill_in: BillPublic, +) -> BillPublic: """ update an Bill. """ @@ -107,7 +117,9 @@ def update_Bill(*, Bill_id: uuid.UUID, bill_in: BillPublic) -> BillPublic: @router.delete("/{Bill_id}", response_model=Bill) -def delete_Bill(Bill_id: uuid.UUID) -> Bill: +def delete_Bill( + current_user: Annotated[Any, Depends(get_user_info)], Bill_id: uuid.UUID +) -> Bill: """ delete an Bill. """ diff --git a/app/api/routes/creditPayments.py b/app/api/routes/creditPayments.py index 087c801..480562e 100644 --- a/app/api/routes/creditPayments.py +++ b/app/api/routes/creditPayments.py @@ -5,6 +5,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlmodel import Session, SQLModel, create_engine, func, insert, select +from app.api.auth import get_user_info from app.api.routes.api_helper import ( create_object, delete_object, @@ -28,22 +29,9 @@ from app.model import ( router = APIRouter() -def resolve_field_type(field_annotation): - """ - Resolves the actual type from field annotations, handling Optional/Union types. - """ - if get_origin(field_annotation) is Union: - # If the field is Optional[Type], extract the non-None type - args = get_args(field_annotation) - return next(arg for arg in args if arg is not type(None)) - return field_annotation - - -from pydantic import BaseModel, Field, create_model - - @router.get("/", response_model=CreditPaymentsList) def read_creditPayments( + current_user: Annotated[Any, Depends(get_user_info)], filters: CreditPaymentFilter().create_flattened_model() = Depends(), page: int = 0, limit: int = Query(default=100, le=1000), @@ -63,8 +51,10 @@ def read_creditPayments( ) -@router.get("/{CreditPayment_id}", response_model=CreditPaymentPublic) -def read_creditPayment(creditPayment_id: uuid.UUID) -> CreditPaymentPublic: +@router.get("/{creditPayment_id}", response_model=CreditPaymentPublic) +def read_creditPayment( + creditPayment_id: uuid.UUID, current_user: Annotated[Any, Depends(get_user_info)] +) -> CreditPaymentPublic: """ retrieve a single CreditPayment by id. """ @@ -73,7 +63,8 @@ def read_creditPayment(creditPayment_id: uuid.UUID) -> CreditPaymentPublic: @router.post("/", response_model=CreditPaymentPublic) def create_creditPayment( - *, creditPayment_in: CreditPaymentCreate + current_user: Annotated[Any, Depends(get_user_info)], + creditPayment_in: CreditPaymentCreate, ) -> CreditPaymentPublic: """ create a new CreditPayment. @@ -99,9 +90,11 @@ def create_creditPayment( return returncreditPayment -@router.patch("/{CreditPayment_id}", response_model=CreditPaymentPublic) +@router.patch("/{creditPayment_id}", response_model=CreditPaymentPublic) def update_creditPayment( - *, creditPayment_id: uuid.UUID, creditPayment_in: CreditPaymentPublic + current_user: Annotated[Any, Depends(get_user_info)], + creditPayment_id: uuid.UUID, + creditPayment_in: CreditPaymentPublic, ) -> CreditPaymentPublic: """ update an CreditPayment. @@ -121,8 +114,10 @@ def update_creditPayment( ) -@router.delete("/{CreditPayment_id}", response_model=CreditPayment) -def delete_creditPayment(creditPayment_id: uuid.UUID) -> CreditPayment: +@router.delete("/{creditPayment_id}", response_model=CreditPayment) +def delete_creditPayment( + current_user: Annotated[Any, Depends(get_user_info)], creditPayment_id: uuid.UUID +) -> CreditPayment: """ delete an CreditPayment. """ diff --git a/app/api/routes/files.py b/app/api/routes/files.py index d9ad277..6549337 100644 --- a/app/api/routes/files.py +++ b/app/api/routes/files.py @@ -2,17 +2,19 @@ import io import os import uuid from datetime import timedelta # Import timedelta -from typing import Optional +from typing import Annotated, Optional import uuid6 -from fastapi import APIRouter, FastAPI, File, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, FastAPI, File, HTTPException, UploadFile, status from fastapi.responses import RedirectResponse from minio import Minio from minio.error import S3Error from sqlmodel import Session, select +from app.api.auth import get_user_info from app.core.config import engine from app.models.unlinked_files import UnlinkedFiles +from app.models.Users import User # Configuration MINIO_ENDPOINT = "minio:9000" # Replace with your MinIO server URL @@ -55,7 +57,9 @@ router = APIRouter() @router.get("/unlinked_files", status_code=200) -async def get_unlinked_files(): +async def get_unlinked_files( + current_user: Annotated[User, Depends(get_user_info)], +): """ Get a list of all files in the MinIO bucket that are not linked to any item. """ @@ -70,7 +74,9 @@ async def get_unlinked_files(): @router.post("/upload", status_code=201) -async def upload_file(file: UploadFile = File(...)): +async def upload_file( + current_user: Annotated[User, Depends(get_user_info)], file: UploadFile = File(...) +): """ Upload a file to MinIO and return its UUID. """ @@ -104,7 +110,11 @@ async def upload_file(file: UploadFile = File(...)): @router.get("/files/{file_id}") -def get_file(file_id: str, expires: Optional[int] = 3600): +def get_file( + current_user: Annotated[User, Depends(get_user_info)], + file_id: str, + expires: Optional[int] = 3600, +): """ Generate a presigned URL for downloading the file from MinIO. """ @@ -135,7 +145,11 @@ def get_file(file_id: str, expires: Optional[int] = 3600): @router.patch("/files/{file_id}", status_code=200) -async def replace_file(file_id: str, file: UploadFile = File(...)): +async def replace_file( + current_user: Annotated[User, Depends(get_user_info)], + file_id: str, + file: UploadFile = File(...), +): """ Replace an existing file in MinIO with a new file using the same UUID. """ @@ -173,7 +187,9 @@ async def replace_file(file_id: str, file: UploadFile = File(...)): # Optional: Endpoint to list all files (for debugging purposes) @router.get("/files") -def list_files(): +def list_files( + current_user: Annotated[User, Depends(get_user_info)], +): """ List all files in the MinIO bucket. """ @@ -189,7 +205,7 @@ def list_files(): # Optional: Endpoint to delete a file @router.delete("/files/{file_id}", status_code=204) -def delete_file(file_id: str): +def delete_file(current_user: Annotated[User, Depends(get_user_info)], file_id: str): """ Delete a file from MinIO. """ diff --git a/app/api/routes/internalTransfers.py b/app/api/routes/internalTransfers.py index 79de117..a6fdeb4 100644 --- a/app/api/routes/internalTransfers.py +++ b/app/api/routes/internalTransfers.py @@ -1,10 +1,11 @@ import uuid -from typing import Any +from typing import Annotated, Any from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlmodel import Session, create_engine, func, insert, select +from app.api.auth import get_user_info from app.api.routes.api_helper import ( create_object, delete_object, @@ -29,6 +30,7 @@ router = APIRouter() @router.get("/", response_model=InternalTransfersList) def read_InternalTransferes( + current_user: Annotated[Any, Depends(get_user_info)], page: int = 0, limit: int = Query(default=100, le=1000), filter: InternalTransferFilter().create_flattened_model() = Depends(), @@ -49,7 +51,9 @@ def read_InternalTransferes( @router.get("/{InternalTransfer_id}", response_model=InternalTransferPublic) -def read_InternalTransfer(InternalTransfer_id: uuid.UUID) -> InternalTransferPublic: +def read_InternalTransfer( + current_user: Annotated[Any, Depends(get_user_info)], InternalTransfer_id: uuid.UUID +) -> InternalTransferPublic: """ retrieve a single InternalTransfer by id. """ @@ -58,7 +62,8 @@ def read_InternalTransfer(InternalTransfer_id: uuid.UUID) -> InternalTransferPub @router.post("/", response_model=InternalTransferPublic) def create_InternalTransfer( - *, InternalTransfer_in: InternalTransferCreate + current_user: Annotated[Any, Depends(get_user_info)], + InternalTransfer_in: InternalTransferCreate, ) -> InternalTransferPublic: """ create a new InternalTransfer. @@ -91,7 +96,9 @@ def create_InternalTransfer( @router.patch("/{InternalTransfer_id}", response_model=InternalTransferPublic) def update_InternalTransfer( - *, internalTransfer_id: uuid.UUID, internalTransfer_in: InternalTransferPublic + current_user: Annotated[Any, Depends(get_user_info)], + internalTransfer_id: uuid.UUID, + internalTransfer_in: InternalTransferPublic, ) -> InternalTransferPublic: """ update an InternalTransfer. @@ -117,7 +124,9 @@ def update_InternalTransfer( @router.delete("/{InternalTransfer_id}", response_model=InternalTransfer) -def delete_InternalTransfer(InternalTransfer_id: uuid.UUID) -> InternalTransfer: +def delete_InternalTransfer( + current_user: Annotated[Any, Depends(get_user_info)], InternalTransfer_id: uuid.UUID +) -> InternalTransfer: """ delete an InternalTransfer. """ diff --git a/app/api/routes/invoices.py b/app/api/routes/invoices.py index 9cf7743..43e5261 100644 --- a/app/api/routes/invoices.py +++ b/app/api/routes/invoices.py @@ -1,12 +1,13 @@ import uuid from datetime import datetime -from typing import Any +from typing import Annotated, Any from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlalchemy.orm import selectinload from sqlmodel import Session, create_engine, func, insert, select +from app.api.auth import get_user_info from app.api.routes.api_helper import ( create_object, delete_object, @@ -36,7 +37,9 @@ router = APIRouter() @router.get("/{invoice_id}", response_model=InvoicePublic) -def read_Invoice(invoice_id: uuid.UUID) -> InvoicePublic: +def read_Invoice( + current_user: Annotated[Any, Depends(get_user_info)], invoice_id: uuid.UUID +) -> InvoicePublic: """ Retrieve a single Invoice by id. """ @@ -45,6 +48,7 @@ def read_Invoice(invoice_id: uuid.UUID) -> InvoicePublic: @router.get("/", response_model=InvoicesList) def read_Invoices( + current_user: Annotated[Any, Depends(get_user_info)], page: int = 0, limit: int = Query(default=100, le=1000), filter: InvoiceFilter().create_flattened_model() = Depends(), @@ -57,7 +61,9 @@ def read_Invoices( @router.post("/", response_model=InvoicePublic) -def create_Invoice(*, Invoice_in: InvoiceCreate) -> InvoicePublic: +def create_Invoice( + current_user: Annotated[Any, Depends(get_user_info)], Invoice_in: InvoiceCreate +) -> InvoicePublic: """ Create a new Invoice. """ @@ -119,7 +125,9 @@ def create_Invoice(*, Invoice_in: InvoiceCreate) -> InvoicePublic: @router.patch("/{Invoice_id}", response_model=InvoicePublic) def update_Invoice( - *, invoice_id: uuid.UUID, invoice_in: InvoicePublic + current_user: Annotated[Any, Depends(get_user_info)], + invoice_id: uuid.UUID, + invoice_in: InvoicePublic, ) -> InvoicePublic: """ update an Invoice. @@ -143,7 +151,9 @@ def update_Invoice( @router.delete("/{Invoice_id}", response_model=Invoice) -def delete_Invoice(Invoice_id: uuid.UUID) -> Invoice: +def delete_Invoice( + current_user: Annotated[Any, Depends(get_user_info)], Invoice_id: uuid.UUID +) -> Invoice: """ delete an Invoice. """ diff --git a/app/api/routes/items.py b/app/api/routes/items.py index 7c21f37..a122ebd 100644 --- a/app/api/routes/items.py +++ b/app/api/routes/items.py @@ -1,10 +1,11 @@ import uuid -from typing import Any +from typing import Annotated, Any from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlmodel import Session, create_engine, func, insert, select +from app.api.auth import get_user_info from app.api.routes.api_helper import ( create_object, delete_object, @@ -14,12 +15,14 @@ from app.api.routes.api_helper import ( ) from app.core.config import engine from app.model import Item, ItemBase, ItemFilter, ItemPublic, ItemsPublic +from app.models.Users import User router = APIRouter() @router.get("/", response_model=ItemsPublic) def read_items( + current_user: Annotated[User, Depends(get_user_info)], page: int = 0, limit: int = Query(default=100, le=1000), filter: ItemFilter().create_flattened_model() = Depends(), @@ -33,7 +36,9 @@ def read_items( @router.get("/{item_id}", response_model=Item) -def read_item(item_id: uuid.UUID) -> Item: +def read_item( + current_user: Annotated[User, Depends(get_user_info)], item_id: uuid.UUID +) -> Item: """ retrieve a single item by id. """ @@ -41,7 +46,9 @@ def read_item(item_id: uuid.UUID) -> Item: @router.post("/", response_model=Item) -def create_item(*, item_in: ItemBase) -> Item: +def create_item( + current_user: Annotated[User, Depends(get_user_info)], item_in: ItemBase +) -> Item: """ create a new item. """ @@ -49,7 +56,11 @@ def create_item(*, item_in: ItemBase) -> Item: @router.patch("/{item_id}", response_model=Item) -def update_item(*, item_id: uuid.UUID, item_in: ItemPublic) -> Item: +def update_item( + current_user: Annotated[User, Depends(get_user_info)], + item_id: uuid.UUID, + item_in: ItemPublic, +) -> Item: """ update an item. """ @@ -57,7 +68,9 @@ def update_item(*, item_id: uuid.UUID, item_in: ItemPublic) -> Item: @router.delete("/{item_id}", response_model=Item) -def delete_item(item_id: uuid.UUID) -> Item: +def delete_item( + current_user: Annotated[User, Depends(get_user_info)], item_id: uuid.UUID +) -> Item: """ delete an item. """ diff --git a/app/api/routes/ksts.py b/app/api/routes/ksts.py index 5c6d817..e126a55 100644 --- a/app/api/routes/ksts.py +++ b/app/api/routes/ksts.py @@ -1,10 +1,11 @@ import uuid -from typing import Any +from typing import Annotated, Any from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlmodel import Session, create_engine, func, insert, select +from app.api.auth import get_user_info from app.api.routes.api_helper import ( create_object, delete_object, @@ -14,12 +15,14 @@ from app.api.routes.api_helper import ( ) from app.core.config import engine from app.model import Kst, KstBase, KstFilter, KstPublic, KstsPublic +from app.models.Users import User router = APIRouter() @router.get("/", response_model=KstsPublic) def read_ksts( + current_user: Annotated[User, Depends(get_user_info)], page: int = 0, limit: int = Query(default=100, le=1000), filter: KstFilter().create_flattened_model() = Depends(), @@ -33,7 +36,9 @@ def read_ksts( @router.get("/{kst_id}", response_model=Kst) -def read_kst(kst_id: uuid.UUID) -> Kst: +def read_kst( + current_user: Annotated[User, Depends(get_user_info)], kst_id: uuid.UUID +) -> Kst: """ retrieve a single Kst by id. """ @@ -41,7 +46,9 @@ def read_kst(kst_id: uuid.UUID) -> Kst: @router.post("/", response_model=Kst) -def create_kst(*, kst_in: KstBase) -> Kst: +def create_kst( + current_user: Annotated[User, Depends(get_user_info)], kst_in: KstBase +) -> Kst: """ create a new Kst. """ @@ -49,7 +56,11 @@ def create_kst(*, kst_in: KstBase) -> Kst: @router.patch("/{kst_id}", response_model=Kst) -def update_kst(*, kst_id: uuid.UUID, kst_in: KstPublic) -> Kst: +def update_kst( + current_user: Annotated[User, Depends(get_user_info)], + kst_id: uuid.UUID, + kst_in: KstPublic, +) -> Kst: """ update an Kst. """ @@ -57,7 +68,9 @@ def update_kst(*, kst_id: uuid.UUID, kst_in: KstPublic) -> Kst: @router.delete("/{kst_id}", response_model=Kst) -def delete_kst(kst_id: uuid.UUID) -> Kst: +def delete_kst( + current_user: Annotated[User, Depends(get_user_info)], kst_id: uuid.UUID +) -> Kst: """ delete an Kst. """ diff --git a/app/api/routes/ledgers.py b/app/api/routes/ledgers.py index b3df35a..75b5995 100644 --- a/app/api/routes/ledgers.py +++ b/app/api/routes/ledgers.py @@ -1,10 +1,11 @@ import uuid -from typing import Any +from typing import Annotated, Any from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi_filter import FilterDepends from sqlmodel import Session, create_engine, func, insert, select +from app.api.auth import get_user_info from app.api.routes.api_helper import ( create_object, delete_object, @@ -14,12 +15,14 @@ from app.api.routes.api_helper import ( ) from app.core.config import engine from app.model import Ledger, LedgerBase, LedgerFilter, LedgerPublic, LedgersPublic +from app.models.Users import User router = APIRouter() @router.get("/", response_model=LedgersPublic) def read_ledgers( + current_user: Annotated[User, Depends(get_user_info)], page: int = 0, limit: int = Query(default=100, le=1000), filter: LedgerFilter().create_flattened_model() = Depends(), @@ -35,7 +38,9 @@ def read_ledgers( @router.get("/{ledger_id}", response_model=Ledger) -def read_ledger(ledger_id: uuid.UUID) -> Ledger: +def read_ledger( + current_user: Annotated[User, Depends(get_user_info)], ledger_id: uuid.UUID +) -> Ledger: """ retrieve a single Ledger by id. """ @@ -43,7 +48,9 @@ def read_ledger(ledger_id: uuid.UUID) -> Ledger: @router.post("/", response_model=Ledger) -def create_ledger(*, ledger_in: LedgerBase) -> Ledger: +def create_ledger( + current_user: Annotated[User, Depends(get_user_info)], ledger_in: LedgerBase +) -> Ledger: """ create a new Ledger. """ @@ -51,7 +58,11 @@ def create_ledger(*, ledger_in: LedgerBase) -> Ledger: @router.patch("/{ledger_id}", response_model=Ledger) -def update_ledger(*, ledger_id: uuid.UUID, ledger_in: LedgerPublic) -> Ledger: +def update_ledger( + current_user: Annotated[User, Depends(get_user_info)], + ledger_id: uuid.UUID, + ledger_in: LedgerPublic, +) -> Ledger: """ update an Ledger. """ @@ -59,7 +70,9 @@ def update_ledger(*, ledger_id: uuid.UUID, ledger_in: LedgerPublic) -> Ledger: @router.delete("/{ledger_id}", response_model=Ledger) -def delete_ledger(ledger_id: uuid.UUID) -> Ledger: +def delete_ledger( + current_user: Annotated[User, Depends(get_user_info)], ledger_id: uuid.UUID +) -> Ledger: """ delete an Ledger. """ diff --git a/app/models/Users.py b/app/models/Users.py index 2b03d20..e3c050d 100644 --- a/app/models/Users.py +++ b/app/models/Users.py @@ -10,7 +10,6 @@ from app.models.addresses import Address, AddressBase, AddressPublic class DbUserBase(SQLModel): amiv_id: str = Field(max_length=30, primary_key=True) - handle: str = Field(max_length=30) address_id: uuid.UUID = Field(foreign_key="address.id") iban: str = Field(max_length=30) diff --git a/app/models/filter.py b/app/models/filter.py index 8df42dd..181226a 100644 --- a/app/models/filter.py +++ b/app/models/filter.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Type, Union, get_args, get_origin from pydantic import BaseModel from pydantic import Field as pydantic_Field from pydantic import create_model -from sqlalchemy import ColumnElement, String +from sqlalchemy import ColumnElement, String, asc, desc from sqlalchemy.orm import InstrumentedAttribute from sqlmodel import SQLModel, cast, or_ @@ -15,6 +15,7 @@ class BaseFilter(SQLModel): """ search: Optional[str] = None + sort: Optional[str] = None class Constants: model: Type[SQLModel] = None # To be set in derived classes @@ -90,6 +91,7 @@ class BaseFilter(SQLModel): model = self.Constants.model filter_data = self.model_dump(exclude_unset=True) join_paths = set() + print(filter_data) for field_name, field_value in filter_data.items(): if field_value is None: @@ -97,11 +99,15 @@ class BaseFilter(SQLModel): if field_name == self.Constants.search_field_name: continue # Handle search separately later + if field_name == "sort": + continue field_name = field_name.replace("__", ".") # Support nested fields field_path = field_name.split(".") model_field = getattr(model, field_path[0], None) - + print( + f"model_field: {model_field}, field_path: {field_path}, field_value: {field_value}" + ) # Build joins for nested fields for path in field_path[1:]: if isinstance(model_field, InstrumentedAttribute): @@ -120,11 +126,11 @@ class BaseFilter(SQLModel): query = query.join(column) column = column.property.mapper.class_ column = getattr(column, path, None) - - if column.type.python_type == str: - query = query.filter(column.ilike(f"%{field_value}%")) + print(column.type) + if column.type == String: + query = query.filter(column.contains(f"%{field_value}%")) else: - query = query.filter(column == field_value) + query = query.filter(cast(column, String).contains(f"%{field_value}%")) # Handle global search if self.search: search_conditions = self._build_search_conditions(model, self.search) @@ -135,6 +141,56 @@ class BaseFilter(SQLModel): query = query.join(join_path) # print(f"search_conditions: {[cond.all_() for cond in search_conditions]}") query = query.filter(or_(*search_conditions)) + if self.sort: + for join_path in self._build_search_joins(model, self.search): + # Join the path if not already joined + query = query.join(join_path) + query = self._apply_sorting(query, model) + + return query + + def _resolve_column(self, model, field_path, join_paths, query): + """ + Resolves the column for a nested field path and ensures joins are added. + """ + column = getattr(model, field_path[0], None) + for path in field_path[1:]: + if isinstance(column, InstrumentedAttribute): + query = query.join(column) + column = column.property.mapper.class_ + column = getattr(column, path, None) + return column + + def _apply_sorting(self, query, model): + """ + Dynamically applies sorting logic to the query based on the `sort` field. + + Example sort string: "creditor__amount:desc" + """ + sort_param = self.sort.strip() + field_order = "asc" # Default order + + # Extract field path and order + if ":" in sort_param: + sort_field, field_order = sort_param.split(":") + field_order = field_order.lower() + else: + sort_field = sort_param + + # Resolve nested field path + sort_field = sort_field.replace("__", ".") + field_path = sort_field.split(".") + + column = self._resolve_column(model, field_path, set(), query) + if not column: + raise ValueError(f"Invalid sorting field: {sort_field}") + + # Apply sorting order + if field_order == "desc": + query = query.order_by(desc(column)) + else: + query = query.order_by(asc(column)) + return query def _build_search_joins( -- GitLab