diff --git a/src/sqlmodelgen/codegen/cir_to_full_ast/code_ir_to_ast.py b/src/sqlmodelgen/codegen/cir_to_full_ast/code_ir_to_ast.py index 5202bec..651af80 100644 --- a/src/sqlmodelgen/codegen/cir_to_full_ast/code_ir_to_ast.py +++ b/src/sqlmodelgen/codegen/cir_to_full_ast/code_ir_to_ast.py @@ -1,7 +1,7 @@ import ast from typing import Iterable -from sqlmodelgen.codegen.code_ir.code_ir import AttributeIR, AttrCallIR, ModelIR +from sqlmodelgen.codegen.code_ir.code_ir import AttributeIR, AttrCallIR, ModelIR, SchemaNameArgIR from sqlmodelgen.codegen.cir_to_full_ast.to_ast_imports import gen_imports @@ -57,13 +57,27 @@ def gen_table_args(model_ir: ModelIR) -> ast.Assign | None: if len(model_ir.table_args) == 0: return None + # I shall build the value + + # first case is if there's only one table argument and it's the + # schema name, then let's just have its dictionary as the __table_args__ + if len(model_ir.table_args) == 1 and isinstance(model_ir.table_args[0], SchemaNameArgIR): + value = model_ir.table_args[0].to_expr() + # otherwise just make a tuple of the args + else: + # NOTE: sqlalchemy requires the dictionary to be placed at + # last + # TODO: enforce the sqlalchemy dictionary (at the moment no mergings + # needed, only schema attribute is present) to be at the last + # position + value = ast.Tuple( + elts=[table_arg.to_expr() for table_arg in model_ir.table_args] + ) # at this level we gust generate the unique constraint return ast.Assign( targets=[ast.Name('__table_args__')], - value=ast.Tuple( - elts=[table_arg.to_expr() for table_arg in model_ir.table_args] - ) + value=value, ) diff --git a/src/sqlmodelgen/codegen/cir_to_full_ast/to_ast_imports.py b/src/sqlmodelgen/codegen/cir_to_full_ast/to_ast_imports.py index 442d8bf..be69b34 100644 --- a/src/sqlmodelgen/codegen/cir_to_full_ast/to_ast_imports.py +++ b/src/sqlmodelgen/codegen/cir_to_full_ast/to_ast_imports.py @@ -7,8 +7,10 @@ from itertools import chain from typing import Iterable, Iterator +# type union 
for generic import ast node AST_IMPORT_TYPE = ast.Import | ast.ImportFrom +# imports for every specific type TYPE_IMPORTS = { 'datetime': ast.ImportFrom( module='datetime', @@ -33,6 +35,9 @@ } def gen_imports(cdefs: Iterable[ast.ClassDef]) -> Iterator[AST_IMPORT_TYPE]: + ''' + generates import statements from the class nodes + ''' data_types_names = set(chain(*map(_iter_data_type_names, cdefs))) call_names = set(chain(*map(_iter_call_names, cdefs))) @@ -46,6 +51,10 @@ def gen_imports(cdefs: Iterable[ast.ClassDef]) -> Iterator[AST_IMPORT_TYPE]: def gen_sqlmodel_import(call_names: set[str]) -> ast.ImportFrom: + ''' + based on the collected calls this function returns an AST element + with the imports necessary from the sqlmodel library + ''' sqlmodel_import = ast.ImportFrom( module='sqlmodel', names=[ @@ -53,16 +62,15 @@ def gen_sqlmodel_import(call_names: set[str]) -> ast.ImportFrom: ] ) + # checking the call names for specific imports if 'Field' in call_names: sqlmodel_import.names.append( ast.alias('Field') ) - if 'Relationship' in call_names: sqlmodel_import.names.append( ast.alias('Relationship') ) - if 'UniqueConstraint' in call_names: sqlmodel_import.names.append( ast.alias('UniqueConstraint') diff --git a/src/sqlmodelgen/codegen/code_ir/build_cir.py b/src/sqlmodelgen/codegen/code_ir/build_cir.py index df5fabe..4bb673c 100644 --- a/src/sqlmodelgen/codegen/code_ir/build_cir.py +++ b/src/sqlmodelgen/codegen/code_ir/build_cir.py @@ -2,7 +2,7 @@ from sqlmodelgen.codegen.code_ir.code_ir import ModelIR from sqlmodelgen.codegen.code_ir.build_rels import add_relationships_attrs from sqlmodelgen.codegen.code_ir.build_col_attrs import attribute_from_col -from sqlmodelgen.codegen.code_ir.build_table_args import build_unique_constraints +from sqlmodelgen.codegen.code_ir.build_table_args import build_table_args from sqlmodelgen.ir.ir import SchemaIR, TableIR def build_model_irs(schema_ir: SchemaIR, gen_relationships: bool, table_name_transform: Callable[[str], str] | None 
= None, column_name_transform: Callable[[str], str] | None = None) -> list[ModelIR]: @@ -10,7 +10,13 @@ def build_model_irs(schema_ir: SchemaIR, gen_relationships: bool, table_name_tra models_by_table_name: dict[str, ModelIR] = dict() for table_ir in schema_ir.table_irs: - model_ir = build_model_ir(table_ir=table_ir, class_names=class_names, table_name_transform=table_name_transform, column_name_transform=column_name_transform) + model_ir = build_model_ir( + table_ir=table_ir, + class_names=class_names, + table_name_transform=table_name_transform, + column_name_transform=column_name_transform, + schema_name=schema_ir.schema_name, + ) models_by_table_name[model_ir.table_name] = model_ir @@ -33,10 +39,16 @@ def gen_class_name(table_name: str, class_names: set[str], table_name_transform: -def build_model_ir(table_ir: TableIR, class_names: set[str], table_name_transform: Callable[[str], str] | None = None, column_name_transform: Callable[[str], str] | None = None) -> ModelIR: +def build_model_ir( + table_ir: TableIR, + class_names: set[str], + table_name_transform: Callable[[str], str] | None = None, + column_name_transform: Callable[[str], str] | None = None, + schema_name: str | None = None, +) -> ModelIR: return ModelIR( class_name=gen_class_name(table_ir.name, class_names, table_name_transform), table_name=table_ir.name, attrs=[attribute_from_col(col_ir, column_name_transform) for col_ir in table_ir.col_irs], - table_args=list(build_unique_constraints(table_ir)), + table_args=list(build_table_args(table_ir, schema_name)), ) diff --git a/src/sqlmodelgen/codegen/code_ir/build_table_args.py b/src/sqlmodelgen/codegen/code_ir/build_table_args.py index 2a0878d..523a086 100644 --- a/src/sqlmodelgen/codegen/code_ir/build_table_args.py +++ b/src/sqlmodelgen/codegen/code_ir/build_table_args.py @@ -1,8 +1,19 @@ from typing import Iterator -from sqlmodelgen.codegen.code_ir.code_ir import UniqueTableArgIR +from sqlmodelgen.codegen.code_ir.code_ir import UniqueTableArgIR, 
SchemaNameArgIR from sqlmodelgen.ir.ir import TableIR +def build_table_args(table_ir: TableIR, schema_name: str | None) -> Iterator[UniqueTableArgIR]: + yield from build_unique_constraints(table_ir) + + # yield the schema name at last in order to be consistent + # with the sqlalchemy requirement of having the schema + # dictionary at last + if schema_name is not None and schema_name != 'public': + yield SchemaNameArgIR(schema_name=schema_name) + + + def build_unique_constraints(table_ir: TableIR) -> Iterator[UniqueTableArgIR]: # TODO: still no code tp generate unique for multiple columns diff --git a/src/sqlmodelgen/codegen/code_ir/code_ir.py b/src/sqlmodelgen/codegen/code_ir/code_ir.py index 40001f3..3f5f141 100644 --- a/src/sqlmodelgen/codegen/code_ir/code_ir.py +++ b/src/sqlmodelgen/codegen/code_ir/code_ir.py @@ -41,9 +41,20 @@ def to_expr(self) -> ast.Call: return ast.Call( func=ast.Name('UniqueConstraint'), args=[ast.Constant(col_name) for col_name in self._col_names], - keywords=[] + keywords=[], ) + + +class SchemaNameArgIR(): + def __init__(self, schema_name: str): + self._schema_name = schema_name + + def to_expr(self)-> ast.Dict: + return ast.Dict( + keys=[ast.Constant(value='schema')], + values=[ast.Constant(value=self._schema_name)], + ) @dataclass diff --git a/src/sqlmodelgen/codegen/codegen.py b/src/sqlmodelgen/codegen/codegen.py index 6377c80..8e637d8 100644 --- a/src/sqlmodelgen/codegen/codegen.py +++ b/src/sqlmodelgen/codegen/codegen.py @@ -3,20 +3,29 @@ ''' import ast -from typing import Callable +from typing import Callable, Iterable from sqlmodelgen.ir.ir import SchemaIR from sqlmodelgen.codegen.code_ir.build_cir import build_model_irs +from sqlmodelgen.codegen.code_ir.code_ir import ModelIR from sqlmodelgen.codegen.cir_to_full_ast.code_ir_to_ast import models_to_ast def gen_code( - schema_ir: SchemaIR, + schema_ir: Iterable[SchemaIR] | SchemaIR, generate_relationships: bool = False, table_name_transform: Callable[[str], str] | None = None, 
column_name_transform: Callable[[str], str] | None = None, ) -> str: - model_irs = build_model_irs(schema_ir, generate_relationships, table_name_transform, column_name_transform) + # in case the schema_ir attribute is a single SchemaIR + if isinstance(schema_ir, SchemaIR): + model_irs = build_model_irs(schema_ir, generate_relationships, table_name_transform, column_name_transform) + # otherwise I assume schema_ir is an iterable + else: + schema_irs = schema_ir + model_irs: list[ModelIR] = [] + for schema_ir in schema_irs: + model_irs += build_model_irs(schema_ir, generate_relationships, table_name_transform, column_name_transform) models_ast = models_to_ast(model_irs) return ast.unparse(models_ast) diff --git a/src/sqlmodelgen/ir/ir.py b/src/sqlmodelgen/ir/ir.py index 32e424f..18c0106 100644 --- a/src/sqlmodelgen/ir/ir.py +++ b/src/sqlmodelgen/ir/ir.py @@ -40,6 +40,7 @@ def get_col_ir(self, name: str) -> ColIR | None: @dataclass class SchemaIR: table_irs: list[TableIR] + schema_name: str | None = None def get_table_ir(self, name: str) -> TableIR | None: ''' diff --git a/src/sqlmodelgen/ir/postgres/postgres_collect.py b/src/sqlmodelgen/ir/postgres/postgres_collect.py index 58f19f5..73fed56 100644 --- a/src/sqlmodelgen/ir/postgres/postgres_collect.py +++ b/src/sqlmodelgen/ir/postgres/postgres_collect.py @@ -4,7 +4,7 @@ import psycopg from dataclasses import dataclass -from typing import Iterator +from typing import Generator, Iterator from sqlmodelgen.ir.ir import ( ColIR, @@ -13,6 +13,14 @@ FKIR ) + +SCHEMAS_TO_AVOID = { + 'pg_toast', + 'pg_catalog', + 'information_schema', +} + + @dataclass class ContraintsData: uniques: dict[str, set[str]] @@ -34,11 +42,27 @@ def get_foreign_key(self, table_name: str, column_name: str) -> FKIR | None: return table_fks.get(column_name) -def collect_postgres_ir(postgres_conn_addr: str, schema_name: str = 'public') -> SchemaIR: +def collect_postgres_ir(postgres_conn_addr: str, schema_name: str | None = None) -> Generator[SchemaIR, 
None, None]: conn = psycopg.connect(postgres_conn_addr) cursor = conn.cursor() + # obtaining the schemas in the database + cursor.execute('SELECT nspname FROM pg_catalog.pg_namespace') + schema_names = [schema_row[0] for schema_row in cursor.fetchall() if schema_row[0] not in SCHEMAS_TO_AVOID] + + for schema_name in schema_names: + yield collect_schema_ir(cursor, schema_name) + + # TODO: potentially collect contraints regarding foreign keys + + conn.close() + + +def collect_schema_ir( + cursor: psycopg.Cursor, + schema_name: str, +) -> SchemaIR: constraints = collect_contraints(cursor, schema_name) cursor.execute('SELECT * FROM pg_catalog.pg_tables WHERE schemaname=%s', (schema_name, )) @@ -58,12 +82,9 @@ def collect_postgres_ir(postgres_conn_addr: str, schema_name: str = 'public') -> )) )) - # TODO: potentially collect contraints regarding foreign keys - - conn.close() - return SchemaIR( - table_irs=table_irs + schema_name=schema_name, + table_irs=table_irs, ) diff --git a/tests/helpers/helpers.py b/tests/helpers/helpers.py index 0482cdb..dd87867 100644 --- a/tests/helpers/helpers.py +++ b/tests/helpers/helpers.py @@ -93,8 +93,9 @@ class ColumnAstInfo: class ClassAstInfo: class_name: str table_name: str | None - uniques: set[tuple[str]] + uniques: set[tuple[str, ...]] cols_info: dict[str, ColumnAstInfo] + schema_name_arg: str | None = None @dataclass @@ -149,7 +150,8 @@ def collect_sqlmodel_class(class_def: ast.ClassDef) -> ClassAstInfo | None: class_name = class_def.name table_name: str | None = None - uniques: list[tuple[str]] = list() + uniques: set[tuple[str]] = set() + schema_name: str | None = None cols_info: dict[str, ColumnAstInfo] = dict() for stat in class_def.body: @@ -170,6 +172,7 @@ def collect_sqlmodel_class(class_def: ast.ClassDef) -> ClassAstInfo | None: table_name = collect_table_name(stat) elif var_name == '__table_args__': uniques = collect_uniques(stat.value) + schema_name = collect_schema_name_table_arg(stat.value) elif type(stat) is 
ast.AnnAssign: col_info = collect_col_info(stat) @@ -179,7 +182,8 @@ def collect_sqlmodel_class(class_def: ast.ClassDef) -> ClassAstInfo | None: class_name=class_name, table_name=table_name, uniques=uniques, - cols_info=cols_info + cols_info=cols_info, + schema_name_arg=schema_name, ) @@ -259,6 +263,37 @@ def collect_uniques(table_args: ast.expr) -> set[tuple[str]]: return uniques +def collect_schema_name_table_arg(table_args: ast.AST) -> str | None: + schema_arg: ast.AST | None = None + + # TODO: this shall support the parsing of all the possible + # types of values __table_args__ could possess, I remember + # also a dictionary being possible and maybe something else + # other than a tuple + if isinstance(table_args, ast.Tuple): + for elt in table_args.elts: + # looking for the dictionary with the table args + if not isinstance(elt, ast.Dict): + continue + for key, val in zip(elt.keys, elt.values): + if not isinstance(key, ast.Constant): + continue + if key.value == 'schema': + schema_arg = val + break + elif isinstance(table_args, ast.Dict): + for key, val in zip(table_args.keys, table_args.values): + if not isinstance(key, ast.Constant): + continue + if key.value == 'schema': + schema_arg = val + break + + schema_name = schema_arg.value if isinstance(schema_arg, ast.Constant) else None + + return schema_name + + def is_valid_sqlmodel_class(class_def: ast.ClassDef) -> bool: # ensuring that the class inherits from 'SQLModel' for base in class_def.bases: diff --git a/tests/helpers/postgres_container.py b/tests/helpers/postgres_container.py index 8f2120f..d38f76f 100644 --- a/tests/helpers/postgres_container.py +++ b/tests/helpers/postgres_container.py @@ -4,17 +4,19 @@ import docker import psycopg +import uuid class PostgresContainer: def __init__(self): self.client = docker.from_env() self.image = 'postgres:16' - self.host_port = 8111#self._find_free_port() + # self.host_port = 8111#self._find_free_port() + self.host_port = self._find_free_port() 
self.container_port = 5432 self.username = 'tester' self.password = 'password' - self.database = 'test' + self.database = f'test_{uuid.uuid4()}' self.container = None diff --git a/tests/test_exec_postgres.py b/tests/test_exec_postgres.py new file mode 100644 index 0000000..0aa56fc --- /dev/null +++ b/tests/test_exec_postgres.py @@ -0,0 +1,162 @@ +''' +this test shall somehow execute the python code generated against a real postgres database +''' + +import psycopg +import pytest +from sqlmodel import SQLModel + +from sqlmodelgen import gen_code_from_postgres +from helpers.postgres_container import postgres_container + + +@pytest.fixture(autouse=True) +def reset_sqlmodel(): + ''' + this hereby implemented fixture is used to reset the metadata, + i.e. the data regarding the declared SQLModel classes + representing the tables. In this way several tests interfacing + with different database instances can regenerate different + tables with the same name + ''' + yield + # drops table objects + SQLModel.metadata.clear() + + +def test_exec_single_schema_name_with_uniques(): + ''' + verifies that a table in a schema with uniques can be + actually inserted and then selected rows + ''' + + sql = '''CREATE SCHEMA IF NOT EXISTS user_data; + +CREATE TABLE user_data.users( + id uuid NOT NULL, + PRIMARY KEY (id), + email TEXT NOT NULL UNIQUE, + name TEXT NOT NULL UNIQUE, + psw TEXT NOT NULL +); +''' + + with postgres_container() as pgc: + conn_str = pgc.get_conn_string() + with psycopg.connect(conn_str) as conn: + # creating schema and tables with cursor + cursor = conn.cursor() + if isinstance(sql, str): + cursor.execute(sql) + elif isinstance(sql, list): + for statement in sql: + cursor.execute(statement) + conn.commit() + + # generating code + generated_code = gen_code_from_postgres( + postgres_conn_addr=conn_str, + schema_name='user_data', + ) + + # support_code is the code to be executed against + # the existing database, in order to verify the actual + # code's functionality 
to interact with the database + support_code = f''' + +from sqlmodel import Session, create_engine, select + +conn_str = conn_str.replace('postgres', 'postgresql+psycopg') +engine = create_engine(conn_str, echo=False) + +with Session(engine) as session: + + + hero = Users( + name='Robin', + email='robin@waine_ind.com', + psw='bruceWayneBoomer' + ) + session.add(hero) + session.commit() + + heroes = session.exec(select(Users)).all() + + assert len(heroes) == 1 + assert heroes[0].name == 'Robin' + assert heroes[0].psw == 'bruceWayneBoomer' +''' + exec_code = generated_code + support_code + + print(exec_code) + + exec(exec_code, locals()) + + +def test_exec_single_schema_name_without_uniques(): + ''' + verifies that a table in a schema without uniques can be + actually inserted and selected rows + ''' + + sql = '''CREATE SCHEMA IF NOT EXISTS user_data; + +CREATE TABLE user_data.users( + id uuid NOT NULL, + PRIMARY KEY (id), + email TEXT NOT NULL, + name TEXT NOT NULL, + psw TEXT NOT NULL +); +''' + + with postgres_container() as pgc: + conn_str = pgc.get_conn_string() + with psycopg.connect(conn_str) as conn: + # creating schema and tables with cursor + cursor = conn.cursor() + if isinstance(sql, str): + cursor.execute(sql) + elif isinstance(sql, list): + for statement in sql: + cursor.execute(statement) + conn.commit() + + # generating code + generated_code = gen_code_from_postgres( + postgres_conn_addr=conn_str, + schema_name='user_data', + ) + + # support_code is the code to be executed against + # the existing database, in order to verify the actual + # code's functionality to interact with the database + support_code = f''' + +from sqlmodel import Session, create_engine, select + +conn_str = conn_str.replace('postgres', 'postgresql+psycopg') +engine = create_engine(conn_str, echo=False) + +with Session(engine) as session: + + + hero = Users( + name='Robin', + email='robin@waine_ind.com', + psw='bruceWayneBoomer' + ) + session.add(hero) + session.commit() + + 
heroes = session.exec(select(Users)).all() + + assert len(heroes) == 1 + assert heroes[0].name == 'Robin' + assert heroes[0].psw == 'bruceWayneBoomer' +''' + exec_code = generated_code + support_code + + print(exec_code) + + exec(exec_code, locals()) \ No newline at end of file diff --git a/tests/test_gen_from_postgres.py b/tests/test_gen_from_postgres.py index bc8f1ac..b903ed4 100644 --- a/tests/test_gen_from_postgres.py +++ b/tests/test_gen_from_postgres.py @@ -123,3 +123,54 @@ class Athletes(SQLModel, table=True): bio: str | None nickname: str | None nation: 'Nations' | None = Relationship(back_populates='athletess')''') + + +def test_postgres_schema_varying_uniques(): + ''' + the purpose of this test is to verify that a postgres database + with a schema and tables with and without uniques is correctly + processed + ''' + + sql = '''CREATE SCHEMA IF NOT EXISTS user_data; + +CREATE TABLE user_data.users( + id uuid NOT NULL, + PRIMARY KEY (id), + email TEXT NOT NULL UNIQUE, + name TEXT NOT NULL UNIQUE, + psw TEXT NOT NULL +); + +CREATE TABLE user_data.stuff( + id uuid NOT NULL, + PRIMARY KEY (id), + email TEXT NOT NULL, + name TEXT NOT NULL, + psw TEXT NOT NULL +); +''' + + code_generated = postgres_verify(sql, rels=True) + + assert collect_code_info(code_generated) == collect_code_info(''' +from sqlmodel import SQLModel, Field, UniqueConstraint +from uuid import UUID, uuid4 + +class Users(SQLModel, table=True): + __tablename__ = 'users' + __table_args__ = (UniqueConstraint('email'), UniqueConstraint('name'), {'schema': 'user_data'}) + id: UUID = Field(primary_key=True, default_factory=uuid4) + email: str + name: str + psw: str + + +class Stuff(SQLModel, table=True): + __tablename__ = 'stuff' + __table_args__ = {'schema': 'user_data'} + id: UUID = Field(primary_key=True, default_factory=uuid4) + email: str + name: str + psw: str +''') diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 737b6c4..b9a1222 100644 --- a/tests/test_helpers.py +++ 
b/tests/test_helpers.py @@ -9,6 +9,7 @@ from helpers.helpers import ( type_data_from_ast_annassign, collect_code_info, + collect_schema_name_table_arg, ModuleAstInfo, ClassAstInfo, ColumnAstInfo, @@ -41,7 +42,10 @@ def test_collect_code_info(): class a_table(SQLModel, table = True): __tablename__ = 'a_table' - __table_args__ = (UniqueConstraint('name'), ) + __table_args__ = ( + UniqueConstraint('name'), + {'schema':'a_schema'}, + ) id: int | None = Field(primary_key=True) name: str email: str | None''') @@ -81,7 +85,76 @@ class a_table(SQLModel, table = True): optional=True ) ), - } + }, + schema_name_arg='a_schema', ) } ) + +def test_collect_schema_info(): + ''' + this test shall assert that the schema name is + collected successfully if present in different types of table args + ''' + + # declaring code with several SQLModel classes, all with + # different or non existing __table_args__ + table_arg_only_schema_info = collect_code_info(''' +class a_table(SQLModel, table = True): + __tablename__ = 'a_table' + __table_args__ = {'schema':'a_schema'} + id: int | None = Field(primary_key=True) + name: str + +class b_table(SQLModel, table = True): + __tablename__ = 'b_table' + __table_args__ = ({'schema':'b_schema'},) + id: int | None = Field(primary_key=True) + name: str + +class c_table(SQLModel, table = True): + __tablename__ = 'c_table' + __table_args__ = (UniqueConstraint('name'),{'schema':'c_schema'},) + id: int | None = Field(primary_key=True) + name: str + +class d_table(SQLModel, table = True): + __tablename__ = 'd_table' + __table_args__ = (UniqueConstraint('name'),) + id: int | None = Field(primary_key=True) + name: str + +class e_table(SQLModel, table = True): + __tablename__ = 'e_table' + id: int | None = Field(primary_key=True) + name: str +''') + + # out of the collected code info a dictionary associating + # to every table name is built + schema_names_dict = { + class_name : class_info.schema_name_arg + for class_name, class_info in 
table_arg_only_schema_info.classes_info.items() + } + + assert schema_names_dict == { + 'a_table':'a_schema', + 'b_table':'b_schema', + 'c_table':'c_schema', + 'd_table':None, + 'e_table':None, + } + + +def test_collect_schema_name_table_arg(): + expr = ast.parse('{\'schema\' : \'another_schema\'}', mode='eval') + assert collect_schema_name_table_arg(expr.body) == 'another_schema' + + expr = ast.parse('(\'schema\', \'another_schema\')', mode='eval') + assert collect_schema_name_table_arg(expr.body) is None + + expr = ast.parse('2 + 2', mode='eval') + assert collect_schema_name_table_arg(expr.body) is None + + expr = ast.parse('2 + 2', mode='eval') + assert collect_schema_name_table_arg(expr.body) is None \ No newline at end of file