From 26cf123713ef1760a86bdcdaef2186286602119a Mon Sep 17 00:00:00 2001 From: Bedram Tamang Date: Sun, 10 May 2026 16:44:10 -0700 Subject: [PATCH 1/5] feat: postgres setup --- docker-compose.yml | 13 +++++++++ .../masoniteorm/connections/manager.py | 4 +-- .../connections/postgres_connection.py | 26 +++++++++--------- .../tests/masoniteorm/fixtures/migration.py | 8 +++--- .../tests/masoniteorm/postgres/__init__.py | 0 .../masoniteorm/postgres/fixtures/__init__.py | 0 .../tests/masoniteorm/postgres/fixtures/db.py | 25 +++++++++++++++++ .../masoniteorm/postgres/models/__init__.py | 0 .../masoniteorm/postgres/models/test_model.py | 27 +++++++++++++++++++ .../builder/test_sqlite_builder_insert.py | 2 +- .../builder/test_sqlite_query_builder.py | 2 +- ...test_sqlite_query_builder_eager_loading.py | 2 +- ...test_sqlite_query_builder_relationships.py | 2 +- .../sqlite/builder/test_sqlite_transaction.py | 2 +- .../masoniteorm/sqlite/fixtures/__init__.py | 0 .../masoniteorm/{ => sqlite}/fixtures/db.py | 1 - .../sqlite/models/test_sqlite_model.py | 2 +- .../tests/masoniteorm/sqlite/test_case.py | 15 ++++++----- fastapi_startkit/uv.lock | 2 +- 19 files changed, 99 insertions(+), 34 deletions(-) create mode 100644 fastapi_startkit/tests/masoniteorm/postgres/__init__.py create mode 100644 fastapi_startkit/tests/masoniteorm/postgres/fixtures/__init__.py create mode 100644 fastapi_startkit/tests/masoniteorm/postgres/fixtures/db.py create mode 100644 fastapi_startkit/tests/masoniteorm/postgres/models/__init__.py create mode 100644 fastapi_startkit/tests/masoniteorm/postgres/models/test_model.py create mode 100644 fastapi_startkit/tests/masoniteorm/sqlite/fixtures/__init__.py rename fastapi_startkit/tests/masoniteorm/{ => sqlite}/fixtures/db.py (90%) diff --git a/docker-compose.yml b/docker-compose.yml index 79f7a92b..686ad9fa 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -13,3 +13,16 @@ services: interval: 5s timeout: 5s retries: 10 + postgres: + image: postgres:17 + environment: + POSTGRES_DB: database_app_test + POSTGRES_USER: app + POSTGRES_PASSWORD: secret + ports: + - "5432:5432" + healthcheck: + test: [ "CMD", "pg_isready", "-U", "app", "-d", "database_app_test" ] + interval: 5s + timeout: 5s + retries: 10 diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/manager.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/manager.py index 5add1c27..b988397d 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/manager.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/manager.py @@ -39,7 +39,7 @@ def get_schema_builder(self): return Schema(self) - def clear(self): + async def clear(self): for conn in self.connections.values(): - conn.engine.dispose() + await conn.engine.dispose() self.connections.clear() diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py index c24e43ad..de29c876 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py @@ -1,4 +1,4 @@ -from typing import Any +from sqlalchemy import text from fastapi_startkit.masoniteorm.query.grammars import PostgresGrammar from fastapi_startkit.masoniteorm.query.processors import PostgresPostProcessor from fastapi_startkit.masoniteorm.schema.platforms import PostgresPlatform @@ -20,18 +20,20 @@ def get_default_platform(cls): def get_post_processor(cls): return PostgresPostProcessor - async def insert(self, query: str, bindings: list | None = None) -> Any: - """Postgres uses RETURNING to get the inserted id/row.""" - query, params = self.sql_alchemy_bindings(query, bindings) + async def insert(self, query: str, bindings: list | None = None) -> int | None: + """Execute an INSERT ... RETURNING * and return the generated primary key.""" + query_str, params = self.sql_alchemy_bindings(query, bindings) + conn = await self.get_connection() + result = await conn.execute(text(query_str), params or {}) - from sqlalchemy import text - - async with self.engine.connect() as conn: - result = await conn.execute(text(query), params) + if not self.transactions: await conn.commit() - row = result.fetchone() - if row: - return dict(zip(result.keys(), row)) + row = result.fetchone() + if row: + return row[0] - return None + # Fallback for cases where RETURNING result is unavailable + val_result = await conn.execute(text("SELECT lastval()")) + val_row = val_result.fetchone() + return val_row[0] if val_row else None diff --git a/fastapi_startkit/tests/masoniteorm/fixtures/migration.py b/fastapi_startkit/tests/masoniteorm/fixtures/migration.py index 85d315a1..454d41b4 100644 --- a/fastapi_startkit/tests/masoniteorm/fixtures/migration.py +++ b/fastapi_startkit/tests/masoniteorm/fixtures/migration.py @@ -1,16 +1,14 @@ -from .db import DB +from fastapi_startkit.masoniteorm.schema import Schema -schema = DB.get_schema_builder() - -async def wipe(): +async def wipe(schema: Schema) -> None: for connection in ("default", "dev"): tables = await schema.on(connection).get_all_tables() for table in tables: await schema.on(connection).drop_table_if_exists(table) -async def migrate(): +async def migrate(schema: Schema) -> None: async with await schema.on("default").create_table_if_not_exists("users") as table: table.id() table.string("name") diff --git a/fastapi_startkit/tests/masoniteorm/postgres/__init__.py b/fastapi_startkit/tests/masoniteorm/postgres/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastapi_startkit/tests/masoniteorm/postgres/fixtures/__init__.py b/fastapi_startkit/tests/masoniteorm/postgres/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastapi_startkit/tests/masoniteorm/postgres/fixtures/db.py b/fastapi_startkit/tests/masoniteorm/postgres/fixtures/db.py new file mode 100644 index 00000000..4e328f00 --- /dev/null +++ b/fastapi_startkit/tests/masoniteorm/postgres/fixtures/db.py @@ -0,0 +1,25 @@ +from fastapi_startkit.masoniteorm.connections.factory import ConnectionFactory +from fastapi_startkit.masoniteorm.connections.manager import DatabaseManager + +URL = "postgresql+asyncpg://app:secret@localhost:5432/database_app_test" + +DB = DatabaseManager( + ConnectionFactory(), + { + "default": "postgres", + "connections": { + "postgres": { + "driver": "postgres", + "url": URL, + "database": "database_app_test", + }, + "dev": { + "driver": "postgres", + "url": URL, + "database": "database_app_test", + }, + }, + }, +) + +schema = DB.get_schema_builder() diff --git a/fastapi_startkit/tests/masoniteorm/postgres/models/__init__.py b/fastapi_startkit/tests/masoniteorm/postgres/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastapi_startkit/tests/masoniteorm/postgres/models/test_model.py b/fastapi_startkit/tests/masoniteorm/postgres/models/test_model.py new file mode 100644 index 00000000..97c2b68b --- /dev/null +++ b/fastapi_startkit/tests/masoniteorm/postgres/models/test_model.py @@ -0,0 +1,27 @@ +from unittest import IsolatedAsyncioTestCase + +from fastapi_startkit.masoniteorm.models.model import Model +from fastapi_startkit.masoniteorm.testing.transaction import RefreshDatabase +from ..fixtures.db import DB, schema +from ...fixtures.migration import migrate, wipe +from ...fixtures.model import User + + +class TestPostGresModel(RefreshDatabase, IsolatedAsyncioTestCase): + async def asyncSetUp(self): + Model.db_manager = DB + await DB.clear() + await wipe(schema) + await migrate(schema) + + async def asyncTearDown(self): + await wipe(schema) + await DB.clear() + + async def test_can_create_and_find_user(self): + user = await User.create({"name": "Alice", "email": "alice@example.com", "is_admin": False}) + self.assertIsNotNone(user.id) + + found = await User.find(user.id) + self.assertEqual(found.name, "Alice") + self.assertEqual(found.email, "alice@example.com") diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_builder_insert.py b/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_builder_insert.py index 92c8da9d..8b843b1b 100644 --- a/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_builder_insert.py +++ b/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_builder_insert.py @@ -1,6 +1,6 @@ from unittest.mock import AsyncMock -from ...fixtures.db import DB +from ..fixtures.db import DB from ...fixtures.model import User from ..test_case import TestCase diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_query_builder.py b/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_query_builder.py index 72156f93..ccd36e0e 100644 --- a/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_query_builder.py +++ b/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_query_builder.py @@ -1,7 +1,7 @@ from unittest.mock import AsyncMock from ..test_case import TestCase -from ...fixtures.db import DB +from ..fixtures.db import DB from ...fixtures.model import User diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_query_builder_eager_loading.py b/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_query_builder_eager_loading.py index b09ce94e..7ec6ba91 100644 --- a/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_query_builder_eager_loading.py +++ b/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_query_builder_eager_loading.py @@ -1,7 +1,7 @@ from unittest.mock import AsyncMock from ...fixtures.model import User, Articles, Profile -from ...fixtures.db import DB +from ..fixtures.db import DB from ..test_case import TestCase diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_query_builder_relationships.py b/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_query_builder_relationships.py index b34f85e2..4d92d978 100644 --- a/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_query_builder_relationships.py +++ b/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_query_builder_relationships.py @@ -2,7 +2,7 @@ from unittest.mock import AsyncMock from ..test_case import TestCase -from ...fixtures.db import DB +from ..fixtures.db import DB from ...fixtures.model import User diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_transaction.py b/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_transaction.py index c8ea0e33..a6afc097 100644 --- a/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_transaction.py +++ b/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_transaction.py @@ -1,5 +1,5 @@ from ...fixtures.model import User -from ...fixtures.db import DB +from ..fixtures.db import DB from ..test_case import TestCase diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/fixtures/__init__.py b/fastapi_startkit/tests/masoniteorm/sqlite/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastapi_startkit/tests/masoniteorm/fixtures/db.py b/fastapi_startkit/tests/masoniteorm/sqlite/fixtures/db.py similarity index 90% rename from fastapi_startkit/tests/masoniteorm/fixtures/db.py rename to fastapi_startkit/tests/masoniteorm/sqlite/fixtures/db.py index 12e51e6f..25a88f94 100644 --- a/fastapi_startkit/tests/masoniteorm/fixtures/db.py +++ b/fastapi_startkit/tests/masoniteorm/sqlite/fixtures/db.py @@ -1,7 +1,6 @@ from fastapi_startkit.masoniteorm.connections.factory import ConnectionFactory from fastapi_startkit.masoniteorm.connections.manager import DatabaseManager from fastapi_startkit.masoniteorm.models.model import Model -from fastapi_startkit.masoniteorm.models.registry import Registry DB = DatabaseManager( ConnectionFactory(), diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/models/test_sqlite_model.py b/fastapi_startkit/tests/masoniteorm/sqlite/models/test_sqlite_model.py index 6223a476..9454471e 100644 --- a/fastapi_startkit/tests/masoniteorm/sqlite/models/test_sqlite_model.py +++ b/fastapi_startkit/tests/masoniteorm/sqlite/models/test_sqlite_model.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import AsyncMock -from ...fixtures.db import DB +from ..fixtures.db import DB from ...fixtures.model import User from ..test_case import TestCase diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/test_case.py b/fastapi_startkit/tests/masoniteorm/sqlite/test_case.py index a0e32b10..6feed1ef 100644 --- a/fastapi_startkit/tests/masoniteorm/sqlite/test_case.py +++ b/fastapi_startkit/tests/masoniteorm/sqlite/test_case.py @@ -2,7 +2,7 @@ from fastapi_startkit.masoniteorm.testing.transaction import RefreshDatabase -from ..fixtures.db import DB +from .fixtures.db import DB from ..fixtures.migration import migrate, wipe from ..fixtures.seeder import seeder @@ -10,16 +10,17 @@ class TestCase(RefreshDatabase, IsolatedAsyncioTestCase): async def asyncSetUp(self): self.db = DB - self.schema = self.db.get_schema_builder() + self.schema = DB.get_schema_builder() await self.migrate_database() async def asyncTearDown(self): - DB.clear() - await wipe() + await DB.clear() + await wipe(DB.get_schema_builder()) @staticmethod async def migrate_database(): - DB.clear() - await wipe() - await migrate() + await DB.clear() + schema = DB.get_schema_builder() + await wipe(schema) + await migrate(schema) await seeder() diff --git a/fastapi_startkit/uv.lock b/fastapi_startkit/uv.lock index 97e7d28b..277e2571 100644 --- a/fastapi_startkit/uv.lock +++ b/fastapi_startkit/uv.lock @@ -443,7 +443,7 @@ wheels = [ [[package]] name = "fastapi-startkit" -version = "0.19.0" +version = "0.20.0" source = { editable = "." } dependencies = [ { name = "cleo" }, From cbe0c8802c8009d34bcf05a73b3143d9b13ec411 Mon Sep 17 00:00:00 2001 From: Bedram Tamang Date: Sun, 10 May 2026 17:08:05 -0700 Subject: [PATCH 2/5] feat: wip --- .../connections/postgres_connection.py | 22 ++++++------------- .../masoniteorm/models/builder.py | 4 ++++ .../masoniteorm/query/grammars/BaseGrammar.py | 2 ++ .../query/grammars/PostgresGrammar.py | 4 ++-- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py index de29c876..2ef3bf5f 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py @@ -1,4 +1,3 @@ -from sqlalchemy import text from fastapi_startkit.masoniteorm.query.grammars import PostgresGrammar from fastapi_startkit.masoniteorm.query.processors import PostgresPostProcessor from fastapi_startkit.masoniteorm.schema.platforms import PostgresPlatform @@ -21,19 +20,12 @@ def get_post_processor(cls): return PostgresPostProcessor async def insert(self, query: str, bindings: list | None = None) -> int | None: - """Execute an INSERT ... RETURNING * and return the generated primary key.""" - query_str, params = self.sql_alchemy_bindings(query, bindings) - conn = await self.get_connection() - result = await conn.execute(text(query_str), params or {}) - - if not self.transactions: - await conn.commit() + """Execute INSERT ... RETURNING "pk" and return the generated primary key. + The base Connection.insert() uses lastrowid which is SQLite-specific. + PostgreSQL uses RETURNING to get the pk back; the grammar scopes it to + the exact primary key column so row[0] is always the pk value. + """ + result = await self.execute(query, bindings) row = result.fetchone() - if row: - return row[0] - - # Fallback for cases where RETURNING result is unavailable - val_result = await conn.execute(text("SELECT lastval()")) - val_row = val_result.fetchone() - return val_row[0] if val_row else None + return row[0] if row else None diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py index 60066310..9a56ad4d 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING +from dumpdie import dd from fastapi_startkit.masoniteorm.expressions.expressions import ( JoinClause, QueryExpression, @@ -124,6 +125,8 @@ def run_scopes(self) -> "QueryBuilder": return self def get_grammar(self): + pk = self._model.__primary_key__ if self._model is not None else None + returning = f'"{pk}"' if pk else "*" return self.grammar( columns=self._columns, table=self._table, @@ -136,6 +139,7 @@ def get_grammar(self): group_by=self._group_by, having=self._having, distinct=self._distinct, + returning=returning, ) def to_qmark(self) -> str: diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/BaseGrammar.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/BaseGrammar.py index 987241c0..de427a79 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/BaseGrammar.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/BaseGrammar.py @@ -37,6 +37,7 @@ def __init__( lock=False, having=(), connection_details=None, + returning="*", ): self._columns = columns self.table = table @@ -55,6 +56,7 @@ def __init__( self._connection_details = connection_details or {} self._column = None + self._returning = returning self._bindings = [] self._sql = "" diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/PostgresGrammar.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/PostgresGrammar.py index 4459c6dd..7b97407c 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/PostgresGrammar.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/PostgresGrammar.py @@ -44,10 +44,10 @@ def update_format(self): return "UPDATE {table} SET {key_equals} {wheres}" def insert_format(self): - return "INSERT INTO {table} ({columns}) VALUES ({values}) RETURNING *" + return f"INSERT INTO {{table}} ({{columns}}) VALUES ({{values}}) RETURNING {self._returning}" def bulk_insert_format(self): - return "INSERT INTO {table} ({columns}) VALUES {values} RETURNING *" + return f"INSERT INTO {{table}} ({{columns}}) VALUES {{values}} RETURNING {self._returning}" def delete_format(self): return "DELETE FROM {table} {wheres}" From 7a44f6c02a6fe32fe9e9aa250da19781ed60ea08 Mon Sep 17 00:00:00 2001 From: Bedram Tamang Date: Mon, 11 May 2026 01:52:09 -0700 Subject: [PATCH 3/5] feat: fix the insert id --- .../masoniteorm/connections/connection.py | 7 ++++ .../connections/postgres_connection.py | 11 ------ .../masoniteorm/models/builder.py | 26 ++++++++++---- .../masoniteorm/models/model.py | 12 ++++--- .../masoniteorm/query/grammars/BaseGrammar.py | 35 +++++++++++++++++++ .../query/grammars/PostgresGrammar.py | 20 +++++++++-- 6 files changed, 87 insertions(+), 24 deletions(-) diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py index 6e11180a..28a71ce5 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py @@ -103,6 +103,13 @@ async def insert(self, query: str, bindings: list | None = None) -> int | None: return getattr(result, "lastrowid", None) + async def insert_get_id(self, query: str, bindings: list | None = None) -> int | None: + result = await self.execute(query, bindings) + row = result.fetchone() + if row is not None: + return row[0] + return getattr(result, "lastrowid", None) + async def update(self, query: str, bindings: list | None = None) -> int: result = await self.execute(query, bindings) diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py index 2ef3bf5f..0e270f7a 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py @@ -18,14 +18,3 @@ def get_default_platform(cls): @classmethod def get_post_processor(cls): return PostgresPostProcessor - - async def insert(self, query: str, bindings: list | None = None) -> int | None: - """Execute INSERT ... RETURNING "pk" and return the generated primary key. - - The base Connection.insert() uses lastrowid which is SQLite-specific. - PostgreSQL uses RETURNING to get the pk back; the grammar scopes it to - the exact primary key column so row[0] is always the pk value. - """ - result = await self.execute(query, bindings) - row = result.fetchone() - return row[0] if row else None diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py index 9a56ad4d..ade2e226 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py @@ -1,8 +1,6 @@ import inspect +from typing import TYPE_CHECKING, Any -from typing import TYPE_CHECKING - -from dumpdie import dd from fastapi_startkit.masoniteorm.expressions.expressions import ( JoinClause, QueryExpression, @@ -108,9 +106,9 @@ async def get_models(self, columns=None): collection = self._model.hydrate(models) if ( - self._eager_relation.eagers - or self._eager_relation.nested_eagers - or self._eager_relation.callback_eagers + self._eager_relation.eagers + or self._eager_relation.nested_eagers + or self._eager_relation.callback_eagers ): await self._load_eagers(collection, self._model) @@ -279,6 +277,16 @@ async def insert(self, values: dict | list) -> int | None: bindings = [val for row in values for val in row.values()] return await self.connection.insert(sql, bindings) + async def insert_get_id( + self, + values: dict[str, Any] | list[dict[str, Any]], + sequences: str | None = None, + ) -> int | None: + sql = self.grammar().compile_insert_get_id(self, values, sequences) + bindings = self.clean_bindings(values) + + return await self.connection.insert_get_id(sql, bindings) + async def update(self, values: dict) -> int: updates = [UpdateQueryExpression(col, val) for col, val in values.items()] grammar = self.grammar() @@ -390,3 +398,9 @@ def or_where_has(self, relation: str, callback=None) -> "QueryBuilder": else: related.query_has(self, method="or_where_exists") return self + + @classmethod + def clean_bindings(cls, values): + if isinstance(values, dict): + values = [values] + return [val for row in values for val in row.values()] diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py index 53bf2b35..5a874d61 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py @@ -21,6 +21,7 @@ class Model(Attribute, Relationship, ObservesEvents): __table__ = None __primary_key__ = "id" __timestamps__ = True + __incrementing__ = True __has_events__ = True __observers__ = {} @@ -215,12 +216,15 @@ def finish_saving(self, options: dict | None = None): async def perform_insert(self, query) -> bool: attributes = self.get_attributes_for_insert() - inserted_id = await query.insert(attributes) - - # Store the auto-generated primary key so subsequent saves do an UPDATE - if inserted_id is not None: + """if the model set auto incrementing, we need to set back the primary key to the inserted id.""" + if self.__incrementing__: + inserted_id = await query.insert_get_id(attributes) + self._attributes[self.__primary_key__] = inserted_id self._dirty_attributes[self.__primary_key__] = inserted_id + else: + await query.insert(attributes) + self._exists = True self._was_recently_created = True self.observe_events(self, "created") diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/BaseGrammar.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/BaseGrammar.py index de427a79..5916ead0 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/BaseGrammar.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/BaseGrammar.py @@ -1,4 +1,10 @@ +from __future__ import annotations + import re +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from ...models.builder import QueryBuilder from ...expressions.expressions import ( JoinClause, @@ -161,6 +167,35 @@ def _compile_insert(self, qmark=False): return self + def compile_insert(self, query: QueryBuilder, values:dict[str, Any] | list[dict[str, Any]]): + table = self.wrap_table(query._table) + + if not values: + return f"INSERT INTO {table} DEFAULT VALUES" + + # Normalise a single dict to a one-element list so the rest of the + # logic can treat every case uniformly. + if isinstance(values, dict): + values = [values] + + columns = self.columnize_bulk_columns(list(values[0].keys())) + + parameters = ", ".join( + "({})".format(", ".join("?" for _ in record)) + for record in values + ) + + return f"INSERT INTO {table} ({columns}) VALUES {parameters}" + + def compile_insert_get_id( + self, + query: QueryBuilder, + values: dict[str, Any] | list[dict[str, Any]], + sequences: str | None = None, + ) -> str: + return self.compile_insert(query, values) + + def _compile_bulk_create(self, qmark=False): """Compiles an insert expression. diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/PostgresGrammar.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/PostgresGrammar.py index 7b97407c..cea36030 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/PostgresGrammar.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/PostgresGrammar.py @@ -1,7 +1,10 @@ import re +from typing import TYPE_CHECKING, Any from .BaseGrammar import BaseGrammar +if TYPE_CHECKING: + from ...models.builder import QueryBuilder class PostgresGrammar(BaseGrammar): """Postgres grammar class.""" @@ -50,7 +53,7 @@ def bulk_insert_format(self): return f"INSERT INTO {{table}} ({{columns}}) VALUES {{values}} RETURNING {self._returning}" def delete_format(self): - return "DELETE FROM {table} {wheres}" + return "DELETE FROM {TABLE} {wheres}" def aggregate_string_with_alias(self): return "{aggregate_function}({column}) AS {alias}" @@ -110,11 +113,11 @@ def create_column_string(self): return "{column} {data_type}{length}{nullable}, " def column_exists_string(self): - return "SELECT column_name FROM information_schema.columns WHERE table_name='{clean_table}' and column_name={value}" + return "SELECT column_name FROM information_schema.columns WHERE table_name='{clean_table}' AND column_name={value}" def table_exists_string(self): return ( - "SELECT * from information_schema.tables where table_name='{clean_table}'" + "SELECT * FROM information_schema.tables WHERE table_name='{clean_table}'" ) def create_column_length(self, column_type): @@ -211,5 +214,16 @@ def truncate_table(self, table, foreign_keys=False): """ return f"TRUNCATE TABLE {self.wrap_table(table)}" + def compile_insert_get_id( + self, + query: "QueryBuilder", + values: dict[str, Any] | list[dict[str, Any]], + sequences: str | None = None, + ) -> str: + return ( + self.compile_insert(query, values) + + f" RETURNING {self.wrap_table(sequences or 'id')}" + ) + def compile_random(self): return "random()" From 815b3410b275bb94adfd999dfac93751fb61c7c6 Mon Sep 17 00:00:00 2001 From: Bedram Tamang Date: Mon, 11 May 2026 01:56:59 -0700 Subject: [PATCH 4/5] feat: wip --- .../masoniteorm/connections/connection.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py index 28a71ce5..8ef4e004 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py @@ -69,7 +69,6 @@ async def close(self) -> None: async def reconnect(self) -> None: await self.close() - @staticmethod def sql_alchemy_bindings(query: str, bindings: list | None = None): params = {} @@ -105,10 +104,11 @@ async def insert(self, query: str, bindings: list | None = None) -> int | None: async def insert_get_id(self, query: str, bindings: list | None = None) -> int | None: result = await self.execute(query, bindings) - row = result.fetchone() - if row is not None: - return row[0] - return getattr(result, "lastrowid", None) + last_insert_id = getattr(result, "lastrowid", None) + if not last_insert_id: + row = result.fetchone() + last_insert_id = row[0] if row else None + return last_insert_id async def update(self, query: str, bindings: list | None = None) -> int: result = await self.execute(query, bindings) From 8fc1bbadaa0e71669f70252c1af11d9e66ff5140 Mon Sep 17 00:00:00 2001 From: Bedram Tamang Date: Mon, 11 May 2026 10:38:35 -0700 Subject: [PATCH 5/5] feat: method added --- example/config-app/uv.lock | 4 +- example/database-app/uv.lock | 14 +- .../masoniteorm/connections/connection.py | 6 +- .../connections/postgres_connection.py | 8 + .../masoniteorm/models/builder.py | 47 +++--- .../masoniteorm/models/model.py | 12 +- .../masoniteorm/query/grammars/BaseGrammar.py | 2 - .../query/grammars/PostgresGrammar.py | 8 +- .../masoniteorm/postgres/models/test_model.py | 142 +++++++++++++++--- .../tests/masoniteorm/postgres/test_case.py | 18 +++ .../tests/masoniteorm/sqlite/test_case.py | 2 + 11 files changed, 208 insertions(+), 55 deletions(-) create mode 100644 fastapi_startkit/tests/masoniteorm/postgres/test_case.py diff --git a/example/config-app/uv.lock b/example/config-app/uv.lock index 4404d604..72f1cab2 100644 --- a/example/config-app/uv.lock +++ b/example/config-app/uv.lock @@ -162,7 +162,7 @@ wheels = [ [[package]] name = "fastapi-startkit" -version = "0.13.6" +version = "0.20.0" source = { editable = "../../fastapi_startkit" } dependencies = [ { name = "cleo" }, @@ -185,6 +185,7 @@ requires-dist = [ { name = "faker", marker = "extra == 'database'", specifier = ">=40.13.0" }, { name = "fastapi", extras = ["standard"], marker = "extra == 'fastapi'", specifier = ">=0.124.4,<0.125.0" }, { name = "inflection", specifier = ">=0.5.1" }, + { name = "itsdangerous", marker = "extra == 'fastapi'", specifier = ">=2.2.0" }, { name = "jinja2", marker = "extra == 'vite'", specifier = ">=3.1" }, { name = "pendulum", specifier = ">=3.1.0,<4.0.0" }, { name = "pydantic", specifier = ">=2.12.5" }, @@ -196,6 +197,7 @@ provides-extras = ["fastapi", "database", "sqlite", "postgres", "mysql", "vite"] [package.metadata.requires-dev] dev = [ { name = "dumpdie", specifier = ">=1.5.0" }, + { name = "itsdangerous", specifier = ">=2.2.0" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "pytest-asyncio", specifier = ">=1.3.0" }, { name = "ruff", specifier = ">=0.9.0" }, diff --git a/example/database-app/uv.lock b/example/database-app/uv.lock index e4bb9324..b98aa2b0 100644 --- a/example/database-app/uv.lock +++ b/example/database-app/uv.lock @@ -498,7 +498,7 @@ wheels = [ [[package]] name = "fastapi-startkit" -version = "0.13.6" +version = "0.20.0" source = { editable = "../../fastapi_startkit" } dependencies = [ { name = "cleo" }, @@ -517,6 +517,7 @@ database = [ ] fastapi = [ { name = "fastapi", extra = ["standard"] }, + { name = "itsdangerous" }, ] postgres = [ { name = "asyncpg" }, @@ -533,6 +534,7 @@ requires-dist = [ { name = "faker", marker = "extra == 'database'", specifier = ">=40.13.0" }, { name = "fastapi", extras = ["standard"], marker = "extra == 'fastapi'", specifier = ">=0.124.4,<0.125.0" }, { name = "inflection", specifier = ">=0.5.1" }, + { name = "itsdangerous", marker = "extra == 'fastapi'", specifier = ">=2.2.0" }, { name = "jinja2", marker = "extra == 'vite'", specifier = ">=3.1" }, { name = "pendulum", specifier = ">=3.1.0,<4.0.0" }, { name = "pydantic", specifier = ">=2.12.5" }, @@ -544,6 +546,7 @@ provides-extras = ["fastapi", "database", "sqlite", "postgres", "mysql", "vite"] [package.metadata.requires-dev] dev = [ { name = "dumpdie", specifier = ">=1.5.0" }, + { name = "itsdangerous", specifier = ">=2.2.0" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "pytest-asyncio", specifier = ">=1.3.0" }, { name = "ruff", specifier = ">=0.9.0" }, @@ -762,6 +765,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "itsdangerous" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9c/cb/8ac0172223afbccb63986cc25049b154ecfb5e85932587206f42317be31d/itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173", size = 54410, upload-time = "2024-04-16T21:28:15.614Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234, upload-time = "2024-04-16T21:28:14.499Z" }, +] + [[package]] name = "jinja2" version = "3.1.6" diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py index 8ef4e004..2ba28792 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py @@ -104,11 +104,7 @@ async def insert(self, query: str, bindings: list | None = None) -> int | None: async def insert_get_id(self, query: str, bindings: list | None = None) -> int | None: result = await self.execute(query, bindings) - last_insert_id = getattr(result, "lastrowid", None) - if not last_insert_id: - row = result.fetchone() - last_insert_id = row[0] if row else None - return last_insert_id + return getattr(result, "lastrowid", None) async def update(self, query: str, bindings: list | None = None) -> int: result = await self.execute(query, bindings) diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py index 0e270f7a..18ee20ed 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/postgres_connection.py @@ -7,6 +7,14 @@ class PostgresConnection(Connection): """Async PostgreSQL connection backed by asyncpg via SQLAlchemy.""" + async def insert_get_id(self, query: str, bindings: list | None = None) -> int | None: + result = await self.run(query, bindings) + row = result.fetchone() + if not self.transactions: + conn = await self.get_connection() + await conn.commit() + return row[0] if row is not None else None + @classmethod def get_query_grammar(cls): return PostgresGrammar diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py index ade2e226..37071585 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py @@ -123,8 +123,6 @@ def run_scopes(self) -> "QueryBuilder": return self def get_grammar(self): - pk = self._model.__primary_key__ if self._model is not None else None - returning = f'"{pk}"' if pk else "*" return self.grammar( columns=self._columns, table=self._table, @@ -137,7 +135,6 @@ def get_grammar(self): group_by=self._group_by, having=self._having, distinct=self._distinct, - returning=returning, ) def to_qmark(self) -> str: @@ -218,26 +215,27 @@ def distinct(self) -> "QueryBuilder": self._distinct = True return self - def aggregate(self, aggregate_type: str, column: str, alias: str = None) -> "QueryBuilder": - if alias: - column = f"{column} as {alias}" - self._aggregates += (AggregateExpression(aggregate_type, column),) - return self + async def aggregate(self, function: str, column: str): + self._aggregates += (AggregateExpression(function, column),) + row = await self.connection.select_one(self.to_qmark(), self.get_bindings()) + if row is None: + return None + return next(iter(row.values())) - def count(self, column: str = "*") -> "QueryBuilder": - return self.aggregate("COUNT", column) + async def count(self, column: str = "*"): + return await self.aggregate("COUNT", column) - def sum(self, column: str) -> "QueryBuilder": - return self.aggregate("SUM", column) + async def sum(self, column: str): + return await self.aggregate("SUM", column) - def max(self, column: str) -> "QueryBuilder": - return self.aggregate("MAX", column) + async def max(self, column: str): + return await self.aggregate("MAX", column) - def min(self, column: str) -> "QueryBuilder": - return self.aggregate("MIN", column) + async def min(self, column: str): + return await self.aggregate("MIN", column) - def avg(self, column: str) -> "QueryBuilder": - return self.aggregate("AVG", column) + async def avg(self, column: str): + return await self.aggregate("AVG", column) async def delete(self, column=None, value=None): if column is not None: @@ -259,6 +257,15 @@ async def first_or_create(self, search: dict, attributes: dict | None = None): return await self.create({**(attributes or {}), **search}) + async def update_or_create(self, search: dict, attributes: dict | None = None): + instance = await self.where(search).first() + if instance is not None: + if attributes: + await instance.update(attributes) + return instance + + return await self.create({**(attributes or {}), **search}) + async def insert(self, values: dict | list) -> int | None: self.set_action("bulk_create") @@ -302,9 +309,7 @@ async def paginate(self, per_page: int = 15, page: int = 1): count_builder._wheres = list(self._wheres) count_builder._joins = self._joins count_builder._global_scopes = self._global_scopes - count_builder.count() - count_result = await self.connection.select(count_builder.to_qmark(), count_builder.get_bindings()) - total = list(count_result[0].values())[0] if count_result else 0 + total = await count_builder.count() or 0 offset = (page - 1) * per_page results = await self.limit(per_page).offset(offset).get() diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py index 5a874d61..66c99176 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py @@ -13,7 +13,7 @@ from fastapi_startkit.masoniteorm.models.relationship import Relationship if TYPE_CHECKING: - from fastapi_startkit.orm.models.builder import QueryBuilder + from fastapi_startkit.masoniteorm.models.builder import QueryBuilder class Model(Attribute, Relationship, ObservesEvents): @@ -119,6 +119,10 @@ def on(cls, connection: str): async def all(cls): return await cls.query().get() + @classmethod + async def count(cls, column: str = "*"): + return await cls.query().count(column) + def set_connection(self, connection: str): self.connection = connection @@ -175,6 +179,12 @@ async def first_or_create( ) -> "Model": return await cls.query().first_or_create(search, attributes) + @classmethod + async def update_or_create( + cls, search: dict, attributes: dict | None = None + ) -> "Model": + return await cls.query().update_or_create(search, attributes) + @classmethod async def create(cls, attributes: dict): instance = cls().new_model_instance(attributes) diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/BaseGrammar.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/BaseGrammar.py index 5916ead0..7e79575b 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/BaseGrammar.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/BaseGrammar.py @@ -43,7 +43,6 @@ def __init__( lock=False, having=(), connection_details=None, - returning="*", ): self._columns = columns self.table = table @@ -62,7 +61,6 @@ def __init__( self._connection_details = connection_details or {} self._column = None - self._returning = returning self._bindings = [] self._sql = "" diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/PostgresGrammar.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/PostgresGrammar.py index cea36030..f40ba826 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/PostgresGrammar.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/grammars/PostgresGrammar.py @@ -46,14 +46,8 @@ def select_format(self): def update_format(self): return "UPDATE {table} SET {key_equals} {wheres}" - def insert_format(self): - return f"INSERT INTO {{table}} ({{columns}}) VALUES ({{values}}) RETURNING {self._returning}" - - def bulk_insert_format(self): - return f"INSERT INTO {{table}} ({{columns}}) VALUES {{values}} RETURNING {self._returning}" - def delete_format(self): - return "DELETE FROM {TABLE} {wheres}" + return "DELETE FROM {table} {wheres}" def aggregate_string_with_alias(self): return "{aggregate_function}({column}) AS {alias}" diff --git a/fastapi_startkit/tests/masoniteorm/postgres/models/test_model.py b/fastapi_startkit/tests/masoniteorm/postgres/models/test_model.py index 97c2b68b..0595cef2 100644 --- a/fastapi_startkit/tests/masoniteorm/postgres/models/test_model.py +++ b/fastapi_startkit/tests/masoniteorm/postgres/models/test_model.py @@ -1,23 +1,8 @@ -from unittest import IsolatedAsyncioTestCase - -from fastapi_startkit.masoniteorm.models.model import Model -from fastapi_startkit.masoniteorm.testing.transaction import RefreshDatabase -from ..fixtures.db import DB, schema -from ...fixtures.migration import migrate, wipe +from ..test_case import TestCase from ...fixtures.model import User -class TestPostGresModel(RefreshDatabase, IsolatedAsyncioTestCase): - async def asyncSetUp(self): - Model.db_manager = DB - await DB.clear() - await wipe(schema) - await migrate(schema) - - async def asyncTearDown(self): - await wipe(schema) - await DB.clear() - +class TestPostGresModel(TestCase): async def test_can_create_and_find_user(self): user = await User.create({"name": "Alice", "email": "alice@example.com", "is_admin": False}) self.assertIsNotNone(user.id) @@ -25,3 +10,126 @@ async def test_can_create_and_find_user(self): found = await User.find(user.id) self.assertEqual(found.name, "Alice") self.assertEqual(found.email, "alice@example.com") + + async def test_find_returns_none_for_missing_id(self): + found = await User.find(99999) + self.assertIsNone(found) + + async def test_first_returns_first_record(self): + await User.create({"name": "Bob", "email": "bob@example.com", "is_admin": False}) + await User.create({"name": "Carol", "email": "carol@example.com", "is_admin": True}) + + user = await User.first() + self.assertIsNotNone(user) + self.assertEqual(user.name, "Bob") + + async def test_first_returns_none_when_table_is_empty(self): + user = await User.first() + self.assertIsNone(user) + + async def test_update_changes_attributes(self): + user = await User.create({"name": "Dave", "email": "dave@example.com", "is_admin": False}) + + await user.update({"name": "David"}) + + refreshed = await User.find(user.id) + self.assertEqual(refreshed.name, "David") + self.assertEqual(refreshed.email, "dave@example.com") + + async def test_update_only_dirty_fields(self): + user = await User.create({"name": "Eve", "email": "eve@example.com", "is_admin": False}) + + await user.update({"name": "Eve", "is_admin": True}) + + refreshed = await User.find(user.id) + self.assertTrue(refreshed.is_admin) + self.assertEqual(refreshed.name, "Eve") + + async def test_delete_removes_record(self): + user = await User.create({"name": "Frank", "email": "frank@example.com", "is_admin": False}) + user_id = user.id + + await User.where("id", user_id).delete() + + found = await User.find(user_id) + self.assertIsNone(found) + + async def test_delete_by_column_removes_matching_records(self): + await User.create({"name": "Grace", "email": "grace@example.com", "is_admin": False}) + await User.create({"name": "Heidi", "email": "heidi@example.com", "is_admin": True}) + + await User.query().delete("is_admin", True) + + admin = await User.where("is_admin", True).first() + self.assertIsNone(admin) + + non_admin = await User.where("is_admin", False).first() + self.assertIsNotNone(non_admin) + + async def test_where_filters_results(self): + await User.create({"name": "Ivan", "email": "ivan@example.com", "is_admin": False}) + await User.create({"name": "Judy", "email": "judy@example.com", "is_admin": True}) + + admins = await User.where("is_admin", True).get() + self.assertEqual(len(admins), 1) + self.assertEqual(admins[0].name, "Judy") + + async def test_first_or_create_creates_when_not_found(self): + user = await User.first_or_create( + {"email": "newuser@example.com"}, + {"name": "New User", "is_admin": False}, + ) + self.assertIsNotNone(user.id) + self.assertEqual(user.email, "newuser@example.com") + self.assertEqual(user.name, "New User") + + async def test_first_or_create_returns_existing_when_found(self): + existing = await User.create({"name": "Existing", "email": "existing@example.com", "is_admin": False}) + + user = await User.first_or_create( + {"email": "existing@example.com"}, + {"name": "Should Not Be Created", "is_admin": True}, + ) + self.assertEqual(user.id, existing.id) + self.assertEqual(user.name, "Existing") + + # Confirm no duplicate was inserted + all_users = await User.where("email", "existing@example.com").get() + self.assertEqual(len(all_users), 1) + + async def test_all_returns_all_records(self): + await User.create({"name": "Karl", "email": "karl@example.com", "is_admin": False}) + await User.create({"name": "Laura", "email": "laura@example.com", "is_admin": False}) + + users = await User.all() + self.assertEqual(len(users), 2) + + async def test_update_or_create_creates_when_not_found(self): + user = await User.update_or_create( + {"email": "new@example.com"}, + {"name": "New User", "is_admin": False}, + ) + self.assertIsNotNone(user.id) + self.assertEqual(user.email, "new@example.com") + self.assertEqual(user.name, "New User") + + async def test_update_or_create_updates_when_found(self): + await User.create({"name": "Original", "email": "update@example.com", "is_admin": False}) + + user = await User.update_or_create( + {"email": "update@example.com"}, + {"name": "Updated", "is_admin": True}, + ) + self.assertEqual(user.name, "Updated") + self.assertTrue(user.is_admin) + + # Confirm no duplicate was inserted + count = await User.where("email", "update@example.com").count() + self.assertEqual(count, 1) + + async def test_count_returns_correct_number(self): + await User.create({"name": "Mallory", "email": "mallory@example.com", "is_admin": False}) + await User.create({"name": "Niaj", "email": "niaj@example.com", "is_admin": False}) + + count = await User.count() + self.assertEqual(count, 2) diff --git a/fastapi_startkit/tests/masoniteorm/postgres/test_case.py b/fastapi_startkit/tests/masoniteorm/postgres/test_case.py new file mode 100644 index 00000000..05ff5ba4 --- /dev/null +++ b/fastapi_startkit/tests/masoniteorm/postgres/test_case.py @@ -0,0 +1,18 @@ +from unittest import IsolatedAsyncioTestCase + +from fastapi_startkit.masoniteorm import Model +from fastapi_startkit.masoniteorm.testing.transaction import RefreshDatabase +from .fixtures.db import DB, schema +from ..fixtures.migration import migrate, wipe + + +class TestCase(RefreshDatabase, IsolatedAsyncioTestCase): + async def asyncSetUp(self): + Model.db_manager = DB + await DB.clear() + await wipe(schema) + await migrate(schema) + + async def asyncTearDown(self): + await wipe(schema) + await DB.clear() diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/test_case.py b/fastapi_startkit/tests/masoniteorm/sqlite/test_case.py index 6feed1ef..ae2431ef 100644 --- a/fastapi_startkit/tests/masoniteorm/sqlite/test_case.py +++ b/fastapi_startkit/tests/masoniteorm/sqlite/test_case.py @@ -1,6 +1,7 @@ from unittest import IsolatedAsyncioTestCase from fastapi_startkit.masoniteorm.testing.transaction import RefreshDatabase +from fastapi_startkit.masoniteorm import Model from .fixtures.db import DB from ..fixtures.migration import migrate, wipe @@ -10,6 +11,7 @@ class TestCase(RefreshDatabase, IsolatedAsyncioTestCase): async def asyncSetUp(self): self.db = DB + Model.db_manager = DB self.schema = DB.get_schema_builder() await self.migrate_database()