Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/sqlmodelgen/codegen/cir_to_full_ast/code_ir_to_ast.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast
from typing import Iterable

from sqlmodelgen.codegen.code_ir.code_ir import AttributeIR, AttrCallIR, ModelIR
from sqlmodelgen.codegen.code_ir.code_ir import AttributeIR, AttrCallIR, ModelIR, SchemaNameArgIR
from sqlmodelgen.codegen.cir_to_full_ast.to_ast_imports import gen_imports


Expand Down Expand Up @@ -57,13 +57,27 @@ def gen_table_args(model_ir: ModelIR) -> ast.Assign | None:
if len(model_ir.table_args) == 0:
return None

# I shall build the value

# first case is if there's only one table argument and it's the
# schema name, then let's just have its dictionary as the __table_args__
if len(model_ir.table_args) == 1 and isinstance(model_ir.table_args[0], SchemaNameArgIR):
value = model_ir.table_args[0].to_expr()
# otherwise just make a tuple of the args
else:
# NOTE: sqlalchemy requires the dictionary to be placed at
# last
# TODO: enforce the sqlalchemy dictionary (at the moment no mergings
# needed, only schema attribute is present) to be at the last
# position
value = ast.Tuple(
elts=[table_arg.to_expr() for table_arg in model_ir.table_args]
)
# at this level we gust generate the unique constraint

return ast.Assign(
targets=[ast.Name('__table_args__')],
value=ast.Tuple(
elts=[table_arg.to_expr() for table_arg in model_ir.table_args]
)
value=value,
)


Expand Down
12 changes: 10 additions & 2 deletions src/sqlmodelgen/codegen/cir_to_full_ast/to_ast_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from itertools import chain
from typing import Iterable, Iterator

# type union for generic import ast node
AST_IMPORT_TYPE = ast.Import | ast.ImportFrom

# imports for every specific type
TYPE_IMPORTS = {
'datetime': ast.ImportFrom(
module='datetime',
Expand All @@ -33,6 +35,9 @@
}

def gen_imports(cdefs: Iterable[ast.ClassDef]) -> Iterator[AST_IMPORT_TYPE]:
'''
generates import statements from the class nodes
'''
data_types_names = set(chain(*map(_iter_data_type_names, cdefs)))

call_names = set(chain(*map(_iter_call_names, cdefs)))
Expand All @@ -46,23 +51,26 @@ def gen_imports(cdefs: Iterable[ast.ClassDef]) -> Iterator[AST_IMPORT_TYPE]:


def gen_sqlmodel_import(call_names: set[str]) -> ast.ImportFrom:
'''
based on the collected calls this function returns an est elements
with the imports necessary from the sqlmodel library
'''
sqlmodel_import = ast.ImportFrom(
module='sqlmodel',
names=[
ast.alias('SQLModel')
]
)

# checking the call names for specific imports
if 'Field' in call_names:
sqlmodel_import.names.append(
ast.alias('Field')
)

if 'Relationship' in call_names:
sqlmodel_import.names.append(
ast.alias('Relationship')
)

if 'UniqueConstraint' in call_names:
sqlmodel_import.names.append(
ast.alias('UniqueConstraint')
Expand Down
20 changes: 16 additions & 4 deletions src/sqlmodelgen/codegen/code_ir/build_cir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@
from sqlmodelgen.codegen.code_ir.code_ir import ModelIR
from sqlmodelgen.codegen.code_ir.build_rels import add_relationships_attrs
from sqlmodelgen.codegen.code_ir.build_col_attrs import attribute_from_col
from sqlmodelgen.codegen.code_ir.build_table_args import build_unique_constraints
from sqlmodelgen.codegen.code_ir.build_table_args import build_table_args
from sqlmodelgen.ir.ir import SchemaIR, TableIR

def build_model_irs(schema_ir: SchemaIR, gen_relationships: bool, table_name_transform: Callable[[str], str] | None = None, column_name_transform: Callable[[str], str] | None = None) -> list[ModelIR]:
class_names: set[str] = set()
models_by_table_name: dict[str, ModelIR] = dict()

for table_ir in schema_ir.table_irs:
model_ir = build_model_ir(table_ir=table_ir, class_names=class_names, table_name_transform=table_name_transform, column_name_transform=column_name_transform)
model_ir = build_model_ir(
table_ir=table_ir,
class_names=class_names,
table_name_transform=table_name_transform,
column_name_transform=column_name_transform,
schema_name=schema_ir.schema_name,
)

models_by_table_name[model_ir.table_name] = model_ir

Expand All @@ -33,10 +39,16 @@ def gen_class_name(table_name: str, class_names: set[str], table_name_transform:



def build_model_ir(table_ir: TableIR, class_names: set[str], table_name_transform: Callable[[str], str] | None = None, column_name_transform: Callable[[str], str] | None = None) -> ModelIR:
def build_model_ir(
table_ir: TableIR,
class_names: set[str],
table_name_transform: Callable[[str], str] | None = None,
column_name_transform: Callable[[str], str] | None = None,
schema_name: str | None = None,
) -> ModelIR:
return ModelIR(
class_name=gen_class_name(table_ir.name, class_names, table_name_transform),
table_name=table_ir.name,
attrs=[attribute_from_col(col_ir, column_name_transform) for col_ir in table_ir.col_irs],
table_args=list(build_unique_constraints(table_ir)),
table_args=list(build_table_args(table_ir, schema_name)),
)
13 changes: 12 additions & 1 deletion src/sqlmodelgen/codegen/code_ir/build_table_args.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
from typing import Iterator

from sqlmodelgen.codegen.code_ir.code_ir import UniqueTableArgIR
from sqlmodelgen.codegen.code_ir.code_ir import UniqueTableArgIR, SchemaNameArgIR
from sqlmodelgen.ir.ir import TableIR

def build_table_args(table_ir: TableIR, schema_name: str | None) -> Iterator[UniqueTableArgIR]:
yield from build_unique_constraints(table_ir)

# yield the schema name at last in order to be consistent
# with the sqlalchemy requirement of having the schema
# dictionary at last
if schema_name is not None and schema_name != 'public':
yield SchemaNameArgIR(schema_name=schema_name)



def build_unique_constraints(table_ir: TableIR) -> Iterator[UniqueTableArgIR]:
# TODO: still no code tp generate unique for multiple columns

Expand Down
13 changes: 12 additions & 1 deletion src/sqlmodelgen/codegen/code_ir/code_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,20 @@ def to_expr(self) -> ast.Call:
return ast.Call(
func=ast.Name('UniqueConstraint'),
args=[ast.Constant(col_name) for col_name in self._col_names],
keywords=[]
keywords=[],
)


class SchemaNameArgIR():

def __init__(self, schema_name: str):
self._schema_name = schema_name

def to_expr(self)-> ast.Dict:
return ast.Dict(
keys=[ast.Constant(value='schema')],
values=[ast.Constant(value=self._schema_name)],
)


@dataclass
Expand Down
15 changes: 12 additions & 3 deletions src/sqlmodelgen/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,29 @@
'''

import ast
from typing import Callable
from typing import Callable, Iterable

from sqlmodelgen.ir.ir import SchemaIR
from sqlmodelgen.codegen.code_ir.build_cir import build_model_irs
from sqlmodelgen.codegen.code_ir.code_ir import ModelIR
from sqlmodelgen.codegen.cir_to_full_ast.code_ir_to_ast import models_to_ast


def gen_code(
schema_ir: SchemaIR,
schema_ir: Iterable[SchemaIR] | SchemaIR,
generate_relationships: bool = False,
table_name_transform: Callable[[str], str] | None = None,
column_name_transform: Callable[[str], str] | None = None,
) -> str:
model_irs = build_model_irs(schema_ir, generate_relationships, table_name_transform, column_name_transform)
# in case the schema_ir attribute is a single SchemaIR
if isinstance(schema_ir, SchemaIR):
model_irs = build_model_irs(schema_ir, generate_relationships, table_name_transform, column_name_transform)
# otherwise I assume schema_ir is an iterable
else:
schema_irs = schema_ir
model_irs: list[ModelIR] = []
for schema_ir in schema_irs:
model_irs += build_model_irs(schema_ir, generate_relationships, table_name_transform, column_name_transform)
models_ast = models_to_ast(model_irs)

return ast.unparse(models_ast)
1 change: 1 addition & 0 deletions src/sqlmodelgen/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def get_col_ir(self, name: str) -> ColIR | None:
@dataclass
class SchemaIR:
table_irs: list[TableIR]
schema_name: str | None = None

def get_table_ir(self, name: str) -> TableIR | None:
'''
Expand Down
35 changes: 28 additions & 7 deletions src/sqlmodelgen/ir/postgres/postgres_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import psycopg

from dataclasses import dataclass
from typing import Iterator
from typing import Generator, Iterator

from sqlmodelgen.ir.ir import (
ColIR,
Expand All @@ -13,6 +13,14 @@
FKIR
)


SCHEMAS_TO_AVOID = {
'pg_toast',
'pg_catalog',
'information_schema',
}


@dataclass
class ContraintsData:
uniques: dict[str, set[str]]
Expand All @@ -34,11 +42,27 @@ def get_foreign_key(self, table_name: str, column_name: str) -> FKIR | None:
return table_fks.get(column_name)


def collect_postgres_ir(postgres_conn_addr: str, schema_name: str = 'public') -> SchemaIR:
def collect_postgres_ir(postgres_conn_addr: str, schema_name: str | None = None) -> Generator[SchemaIR, None, None]:

conn = psycopg.connect(postgres_conn_addr)
cursor = conn.cursor()

# obtaining the schemas in the database
cursor.execute('SELECT nspname FROM pg_catalog.pg_namespace')
schema_names = [schema_row[0] for schema_row in cursor.fetchall() if schema_row[0] not in SCHEMAS_TO_AVOID]

for schema_name in schema_names:
yield collect_schema_ir(cursor, schema_name)

# TODO: potentially collect contraints regarding foreign keys

conn.close()


def collect_schema_ir(
cursor: psycopg.Cursor,
schema_name: str,
) -> SchemaIR:
constraints = collect_contraints(cursor, schema_name)

cursor.execute('SELECT * FROM pg_catalog.pg_tables WHERE schemaname=%s', (schema_name, ))
Expand All @@ -58,12 +82,9 @@ def collect_postgres_ir(postgres_conn_addr: str, schema_name: str = 'public') ->
))
))

# TODO: potentially collect contraints regarding foreign keys

conn.close()

return SchemaIR(
table_irs=table_irs
schema_name=schema_name,
table_irs=table_irs,
)


Expand Down
41 changes: 38 additions & 3 deletions tests/helpers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,9 @@ class ColumnAstInfo:
class ClassAstInfo:
class_name: str
table_name: str | None
uniques: set[tuple[str]]
uniques: set[tuple[str, ...]]
cols_info: dict[str, ColumnAstInfo]
schema_name_arg: str | None = None


@dataclass
Expand Down Expand Up @@ -149,7 +150,8 @@ def collect_sqlmodel_class(class_def: ast.ClassDef) -> ClassAstInfo | None:

class_name = class_def.name
table_name: str | None = None
uniques: list[tuple[str]] = list()
uniques: set[tuple[str]] = set()
schema_name: str | None = None
cols_info: dict[str, ColumnAstInfo] = dict()

for stat in class_def.body:
Expand All @@ -170,6 +172,7 @@ def collect_sqlmodel_class(class_def: ast.ClassDef) -> ClassAstInfo | None:
table_name = collect_table_name(stat)
elif var_name == '__table_args__':
uniques = collect_uniques(stat.value)
schema_name = collect_schema_name_table_arg(stat.value)

elif type(stat) is ast.AnnAssign:
col_info = collect_col_info(stat)
Expand All @@ -179,7 +182,8 @@ def collect_sqlmodel_class(class_def: ast.ClassDef) -> ClassAstInfo | None:
class_name=class_name,
table_name=table_name,
uniques=uniques,
cols_info=cols_info
cols_info=cols_info,
schema_name_arg=schema_name,
)


Expand Down Expand Up @@ -259,6 +263,37 @@ def collect_uniques(table_args: ast.expr) -> set[tuple[str]]:
return uniques


def collect_schema_name_table_arg(table_args: ast.AST) -> str | None:
schema_arg: ast.AST | None = None

# TODO: this shall support the parsing of all the possible
# types of values __table_args__ could possess, I remember
# also a dictionary being possible and maybe something else
# other than a tuple
if isinstance(table_args, ast.Tuple):
for elt in table_args.elts:
# looking for the dictionary with the table args
if not isinstance(elt, ast.Dict):
continue
for key, val in zip(elt.keys, elt.values):
if not isinstance(key, ast.Constant):
continue
if key.value == 'schema':
schema_arg = val
break
elif isinstance(table_args, ast.Dict):
for key, val in zip(table_args.keys, table_args.values):
if not isinstance(key, ast.Constant):
continue
if key.value == 'schema':
schema_arg = val
break

schema_name = schema_arg.value if isinstance(schema_arg, ast.Constant) else None

return schema_name


def is_valid_sqlmodel_class(class_def: ast.ClassDef) -> bool:
# ensuring that the class inherits from 'SQLModel'
for base in class_def.bases:
Expand Down
6 changes: 4 additions & 2 deletions tests/helpers/postgres_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@

import docker
import psycopg
import uuid

class PostgresContainer:

def __init__(self):
self.client = docker.from_env()
self.image = 'postgres:16'
self.host_port = 8111#self._find_free_port()
# self.host_port = 8111#self._find_free_port()
self.host_port = self._find_free_port()
self.container_port = 5432
self.username = 'tester'
self.password = 'password'
self.database = 'test'
self.database = f'test_{uuid.uuid4()}'
self.container = None


Expand Down
Loading
Loading