Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,16 @@ services:
interval: 5s
timeout: 5s
retries: 10
postgres:
image: postgres:17
environment:
POSTGRES_DB: database_app_test
POSTGRES_USER: app
POSTGRES_PASSWORD: secret
ports:
- "5432:5432"
healthcheck:
test: [ "CMD", "pg_isready", "-U", "app", "-d", "database_app_test" ]
interval: 5s
timeout: 5s
retries: 10
4 changes: 3 additions & 1 deletion example/config-app/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 13 additions & 1 deletion example/database-app/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ async def close(self) -> None:
async def reconnect(self) -> None:
await self.close()


@staticmethod
def sql_alchemy_bindings(query: str, bindings: list | None = None):
params = {}
Expand Down Expand Up @@ -103,6 +102,10 @@ async def insert(self, query: str, bindings: list | None = None) -> int | None:

return getattr(result, "lastrowid", None)

async def insert_get_id(self, query: str, bindings: list | None = None) -> int | None:
result = await self.execute(query, bindings)
return getattr(result, "lastrowid", None)

async def update(self, query: str, bindings: list | None = None) -> int:
result = await self.execute(query, bindings)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_schema_builder(self):

return Schema(self)

def clear(self):
async def clear(self):
for conn in self.connections.values():
conn.engine.dispose()
await conn.engine.dispose()
self.connections.clear()
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import Any
from fastapi_startkit.masoniteorm.query.grammars import PostgresGrammar
from fastapi_startkit.masoniteorm.query.processors import PostgresPostProcessor
from fastapi_startkit.masoniteorm.schema.platforms import PostgresPlatform
Expand All @@ -8,6 +7,14 @@
class PostgresConnection(Connection):
"""Async PostgreSQL connection backed by asyncpg via SQLAlchemy."""

async def insert_get_id(self, query: str, bindings: list | None = None) -> int | None:
result = await self.run(query, bindings)
row = result.fetchone()
if not self.transactions:
conn = await self.get_connection()
await conn.commit()
return row[0] if row is not None else None

@classmethod
def get_query_grammar(cls):
return PostgresGrammar
Expand All @@ -19,19 +26,3 @@ def get_default_platform(cls):
@classmethod
def get_post_processor(cls):
return PostgresPostProcessor

async def insert(self, query: str, bindings: list | None = None) -> Any:
"""Postgres uses RETURNING to get the inserted id/row."""
query, params = self.sql_alchemy_bindings(query, bindings)

from sqlalchemy import text

async with self.engine.connect() as conn:
result = await conn.execute(text(query), params)
await conn.commit()

row = result.fetchone()
if row:
return dict(zip(result.keys(), row))

return None
69 changes: 46 additions & 23 deletions fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import inspect

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from fastapi_startkit.masoniteorm.expressions.expressions import (
JoinClause,
Expand Down Expand Up @@ -107,9 +106,9 @@ async def get_models(self, columns=None):
collection = self._model.hydrate(models)

if (
self._eager_relation.eagers
or self._eager_relation.nested_eagers
or self._eager_relation.callback_eagers
self._eager_relation.eagers
or self._eager_relation.nested_eagers
or self._eager_relation.callback_eagers
):
await self._load_eagers(collection, self._model)

Expand Down Expand Up @@ -216,26 +215,27 @@ def distinct(self) -> "QueryBuilder":
self._distinct = True
return self

def aggregate(self, aggregate_type: str, column: str, alias: str = None) -> "QueryBuilder":
if alias:
column = f"{column} as {alias}"
self._aggregates += (AggregateExpression(aggregate_type, column),)
return self
async def aggregate(self, function: str, column: str):
self._aggregates += (AggregateExpression(function, column),)
row = await self.connection.select_one(self.to_qmark(), self.get_bindings())
if row is None:
return None
return next(iter(row.values()))

def count(self, column: str = "*") -> "QueryBuilder":
return self.aggregate("COUNT", column)
async def count(self, column: str = "*"):
return await self.aggregate("COUNT", column)

def sum(self, column: str) -> "QueryBuilder":
return self.aggregate("SUM", column)
async def sum(self, column: str):
return await self.aggregate("SUM", column)

def max(self, column: str) -> "QueryBuilder":
return self.aggregate("MAX", column)
async def max(self, column: str):
return await self.aggregate("MAX", column)

def min(self, column: str) -> "QueryBuilder":
return self.aggregate("MIN", column)
async def min(self, column: str):
return await self.aggregate("MIN", column)

def avg(self, column: str) -> "QueryBuilder":
return self.aggregate("AVG", column)
async def avg(self, column: str):
return await self.aggregate("AVG", column)

async def delete(self, column=None, value=None):
if column is not None:
Expand All @@ -257,6 +257,15 @@ async def first_or_create(self, search: dict, attributes: dict | None = None):

return await self.create({**(attributes or {}), **search})

async def update_or_create(self, search: dict, attributes: dict | None = None):
instance = await self.where(search).first()
if instance is not None:
if attributes:
await instance.update(attributes)
return instance

return await self.create({**(attributes or {}), **search})

async def insert(self, values: dict | list) -> int | None:
self.set_action("bulk_create")

Expand All @@ -275,6 +284,16 @@ async def insert(self, values: dict | list) -> int | None:
bindings = [val for row in values for val in row.values()]
return await self.connection.insert(sql, bindings)

async def insert_get_id(
self,
values: dict[str, Any] | list[dict[str, Any]],
sequences: str | None = None,
) -> int | None:
sql = self.grammar().compile_insert_get_id(self, values, sequences)
bindings = self.clean_bindings(values)

return await self.connection.insert_get_id(sql, bindings)

async def update(self, values: dict) -> int:
updates = [UpdateQueryExpression(col, val) for col, val in values.items()]
grammar = self.grammar()
Expand All @@ -290,9 +309,7 @@ async def paginate(self, per_page: int = 15, page: int = 1):
count_builder._wheres = list(self._wheres)
count_builder._joins = self._joins
count_builder._global_scopes = self._global_scopes
count_builder.count()
count_result = await self.connection.select(count_builder.to_qmark(), count_builder.get_bindings())
total = list(count_result[0].values())[0] if count_result else 0
total = await count_builder.count() or 0

offset = (page - 1) * per_page
results = await self.limit(per_page).offset(offset).get()
Expand Down Expand Up @@ -386,3 +403,9 @@ def or_where_has(self, relation: str, callback=None) -> "QueryBuilder":
else:
related.query_has(self, method="or_where_exists")
return self

@classmethod
def clean_bindings(cls, values):
if isinstance(values, dict):
values = [values]
return [val for row in values for val in row.values()]
24 changes: 19 additions & 5 deletions fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
from fastapi_startkit.masoniteorm.models.relationship import Relationship

if TYPE_CHECKING:
from fastapi_startkit.orm.models.builder import QueryBuilder
from fastapi_startkit.masoniteorm.models.builder import QueryBuilder


class Model(Attribute, Relationship, ObservesEvents):
db_manager: "DatabaseManager" = None
__table__ = None
__primary_key__ = "id"
__timestamps__ = True
__incrementing__ = True

__has_events__ = True
__observers__ = {}
Expand Down Expand Up @@ -118,6 +119,10 @@ def on(cls, connection: str):
async def all(cls):
return await cls.query().get()

@classmethod
async def count(cls, column: str = "*"):
return await cls.query().count(column)

def set_connection(self, connection: str):
self.connection = connection

Expand Down Expand Up @@ -174,6 +179,12 @@ async def first_or_create(
) -> "Model":
return await cls.query().first_or_create(search, attributes)

@classmethod
async def update_or_create(
cls, search: dict, attributes: dict | None = None
) -> "Model":
return await cls.query().update_or_create(search, attributes)

@classmethod
async def create(cls, attributes: dict):
instance = cls().new_model_instance(attributes)
Expand Down Expand Up @@ -215,12 +226,15 @@ def finish_saving(self, options: dict | None = None):
async def perform_insert(self, query) -> bool:
attributes = self.get_attributes_for_insert()

inserted_id = await query.insert(attributes)

# Store the auto-generated primary key so subsequent saves do an UPDATE
if inserted_id is not None:
"""if the model set auto incrementing, we need to set back the primary key to the inserted id."""
if self.__incrementing__:
inserted_id = await query.insert_get_id(attributes)
self._attributes[self.__primary_key__] = inserted_id
self._dirty_attributes[self.__primary_key__] = inserted_id

else:
await query.insert(attributes)

self._exists = True
self._was_recently_created = True
self.observe_events(self, "created")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from ...models.builder import QueryBuilder

from ...expressions.expressions import (
JoinClause,
Expand Down Expand Up @@ -159,6 +165,35 @@ def _compile_insert(self, qmark=False):

return self

def compile_insert(self, query: QueryBuilder, values:dict[str, Any] | list[dict[str, Any]]):
table = self.wrap_table(query._table)

if not values:
return f"INSERT INTO {table} DEFAULT VALUES"

# Normalise a single dict to a one-element list so the rest of the
# logic can treat every case uniformly.
if isinstance(values, dict):
values = [values]

columns = self.columnize_bulk_columns(list(values[0].keys()))

parameters = ", ".join(
"({})".format(", ".join("?" for _ in record))
for record in values
)

return f"INSERT INTO {table} ({columns}) VALUES {parameters}"

def compile_insert_get_id(
self,
query: QueryBuilder,
values: dict[str, Any] | list[dict[str, Any]],
sequences: str | None = None,
) -> str:
return self.compile_insert(query, values)


def _compile_bulk_create(self, qmark=False):
"""Compiles an insert expression.

Expand Down
Loading
Loading