Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 135 additions & 142 deletions scripts/sqlpp23-ddl2cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import sys
import re
import os
import unittest
import jinja2
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Dict, List
Expand Down Expand Up @@ -553,154 +554,146 @@ class DdlExecutor:
class ModelWriter:
"""This class writes the database model as C++ headers and/or a module file"""

TEMPLATE = """{% if is_module %}
module;
{% else %}
#pragma once
{% endif %}

// clang-format off
// generated by {{ command_line }}

{% if use_import_std %}
import std;
{% else %}
#include <optional>
{% endif %}

{% if is_module %}
#include <sqlpp23/core/name/create_name_tag.h>

import sqlpp23.core;

export module {{ module_name }};

{% else %}
{% if use_import_sqlpp23 %}
import sqlpp23.core;

#include <sqlpp23/core/name/create_name_tag.h>

{% else %}
#include <sqlpp23/core/basic/table.h>
#include <sqlpp23/core/basic/table_columns.h>
#include <sqlpp23/core/name/create_name_tag.h>
#include <sqlpp23/core/type_traits.h>

{% endif %}
{% if custom_directives %}
{% for directive in custom_directives.split(',') %}
{{ directive }}
{% endfor %}

{% endif %}
{% endif %}
namespace {{ namespace }} {
{% for table in tables %}
{% if generate_table_creation_helper %}
{{ export_prefix }}template<typename Db>
void {{ table.creation_helper_func }}(Db& db) {
db(R"+++(DROP TABLE IF EXISTS {{ table.sql_name }})+++");
{% for command in table.commands %}
db(R"+++({{ command }})+++");
{% endfor %}
}

{% endif %}
{{ export_prefix }}struct {{ table.class_name }}_ {
{% for column in table.columns %}
struct {{ column.class_name }} {
SQLPP_CREATE_NAME_TAG_FOR_SQL_AND_CPP({{ column.sql_name }}, {{ column.member_name }});
using data_type = {{ column.const_prefix }}{% if column.is_nullable %}std::optional<{{ column.cpp_type }}>{% else %}{{ column.cpp_type }}{% endif %};
using has_default = {{ 'std::true_type' if column.has_default else 'std::false_type' }};
};
{% endfor %}
SQLPP_CREATE_NAME_TAG_FOR_SQL_AND_CPP({{ table.sql_name }}, {{ table.member_name }});
template<typename T>
using _table_columns = sqlpp::table_columns<T,
{{ table.column_classes | join(',\n ') }}>;
using _required_insert_columns = sqlpp::detail::type_set<{% if table.required_insert_columns %}
{{ table.required_insert_columns | join(',\n ') }}{% endif %}>;
};
{{ export_prefix }}using {{ table.class_name }} = ::sqlpp::table_t<{{ table.class_name }}_>;

{% endfor %}
} // namespace {{ namespace }}
"""

@classmethod
def write(cls, tables, args):
template = jinja2.Template(cls.TEMPLATE, trim_blocks=True, lstrip_blocks=True)

common_context = {
"command_line": " ".join(sys.argv),
"use_import_std": args.use_import_std,
"use_import_sqlpp23": args.use_import_sqlpp23,
"custom_directives": args.custom_directives,
"namespace": args.namespace,
"generate_table_creation_helper": args.generate_table_creation_helper,
"naming_style": args.naming_style,
"export_prefix": "export " if args.path_to_module else "",
}

if args.path_to_header:
cls._create_header(tables, args)
cls._write_to_file(args.path_to_header, tables.values(), {**common_context, "is_module": False}, template, args)
if args.path_to_header_directory:
cls._create_split_headers(tables, args)
for t in tables.values():
path = os.path.join(args.path_to_header_directory, cls._to_class_name(t.name, args) + ".h")
cls._write_to_file(path, [t], {**common_context, "is_module": False}, template, args)
if args.path_to_module:
cls._create_module(tables, args)

@classmethod
def _create_header(cls, tables, args):
header = cls._begin_header(args.path_to_header, args)
for t in tables.values():
cls._write_table(t, header, args)
cls._end_header(header, args)

@classmethod
def _create_split_headers(cls, tables, args):
for t in tables.values():
header = cls._begin_header(os.path.join(args.path_to_header_directory, cls._to_class_name(t.name, args) + ".h"), args)
cls._write_table(t, header, args)
cls._end_header(header, args)

@staticmethod
def _begin_header(path_to_header, args):
header = open(path_to_header, "w")
print("#pragma once", file=header)
print("", file=header)
print("// clang-format off", file=header)
print("// generated by " + " ".join(sys.argv), file=header)
print("", file=header)
if args.use_import_std:
print("import std;", file=header)
else:
print("#include <optional>", file=header)
if args.use_import_sqlpp23:
print("import sqlpp23.core;", file=header)
print("", file=header)
print("#include <sqlpp23/core/name/create_name_tag.h>", file=header)
else:
print("", file=header)
print("#include <sqlpp23/core/basic/table.h>", file=header)
print("#include <sqlpp23/core/basic/table_columns.h>", file=header)
print("#include <sqlpp23/core/name/create_name_tag.h>", file=header)
print("#include <sqlpp23/core/type_traits.h>", file=header)
if args.custom_directives:
for directive in args.custom_directives.split(","):
print(directive, file=header)
print("", file=header)
print("namespace " + args.namespace + " {", file=header)
return header

@staticmethod
def _end_header(header, args):
print("} // namespace " + args.namespace, file=header)
header.close()

@classmethod
def _create_module(cls, tables, args):
module = cls._begin_module(args.path_to_module, args)
for t in tables.values():
cls._write_table(t, module, args)
cls._end_module(module, args)

@staticmethod
def _begin_module(path_to_module, args):
module = open(path_to_module, "w")
print("module;", file=module)
print("", file=module)
print("// clang-format off", file=module)
print("// generated by " + " ".join(sys.argv), file=module)
print("", file=module)
if args.use_import_std:
print("import std;", file=module)
else:
print("#include <optional>", file=module)
print("", file=module)
print("#include <sqlpp23/core/name/create_name_tag.h>", file=module)
print("", file=module)
print("import sqlpp23.core;", file=module)
print("", file=module)
print("export module " + args.module_name + ";", file=module)
print("", file=module)
print("namespace " + args.namespace + " {", file=module)
return module

@staticmethod
def _end_module(module, args):
print("} // namespace " + args.namespace, file=module)
module.close()
cls._write_to_file(args.path_to_module, tables.values(), {**common_context, "is_module": True, "module_name": args.module_name}, template, args)

@classmethod
def _write_table(cls, table, header, args):
export = "export " if args.path_to_module else ""
table_class = cls._to_class_name(table.name, args)
table_member = cls._to_member_name(table.name, args)
table_spec = table_class + "_"
table_template_parameters = ""
table_required_insert_columns = ""
if args.generate_table_creation_helper:
creation_helper_func = "create" + ("" if args.naming_style == "camel-case" else "_") + table_class
print(" " + export + "template<typename Db>", file=header)
print(" void " + creation_helper_func + "(Db& db) {", file=header)
print(" db(R\"+++(DROP TABLE IF EXISTS " + table.name + ")+++\");", file=header)
for command in table.commands:
print(" db(R\"+++(" + command + ")+++\");", file=header)
print(" }", file=header)
print("", file=header)
print(" " + export + "struct " + table_spec + " {", file=header)
for column in table.columns.values():
column_class = cls._to_class_name(column.name, args)
column_member = cls._to_member_name(column.name, args)
print(" struct " + column_class + " {", file=header)
print(" SQLPP_CREATE_NAME_TAG_FOR_SQL_AND_CPP("
+ cls._escape_if_reserved(column.name) + ", " + column_member + ");"
, file=header)
const_prefix = "const " if column.is_const else ""
type_str = column.cpp_type
if column.is_nullable:
print(" using data_type = " + const_prefix + "std::optional<" + type_str + ">;", file=header)
else:
print(" using data_type = " + const_prefix + type_str + ";", file=header)
if column.has_default:
print(" using has_default = std::true_type;", file=header)
else:
print(" using has_default = std::false_type;", file=header)
print(" };", file=header)
if table_template_parameters:
table_template_parameters += ","
table_template_parameters += "\n " + column_class
if not column.has_default:
if table_required_insert_columns:
table_required_insert_columns += ","
table_required_insert_columns += "\n sqlpp::column_t<sqlpp::table_t<" + table_spec + ">, " + column_class + ">";
print(" SQLPP_CREATE_NAME_TAG_FOR_SQL_AND_CPP("
+ cls._escape_if_reserved(table.name) + ", " + table_member + ");"
, file=header)
print(" template<typename T>", file=header)
print(" using _table_columns = sqlpp::table_columns<T,"
+ table_template_parameters
+ ">;", file=header)
print(" using _required_insert_columns = sqlpp::detail::type_set<"
+ table_required_insert_columns
+ ">;", file=header)
print(" };", file=header)
print(
" " + export + "using " + table_class + " = ::sqlpp::table_t<" + table_spec + ">;", file=header)
print("", file=header)
def _write_to_file(cls, path, tables, context, template, args):
table_contexts = []
for t in tables:
table_class = cls._to_class_name(t.name, args)
table_member = cls._to_member_name(t.name, args)
table_spec = table_class + "_"

column_contexts = []
column_classes = []
required_insert_columns = []
for col in t.columns.values():
col_class = cls._to_class_name(col.name, args)
col_member = cls._to_member_name(col.name, args)
column_contexts.append({
"class_name": col_class,
"member_name": col_member,
"sql_name": cls._escape_if_reserved(col.name),
"const_prefix": "const " if col.is_const else "",
"cpp_type": col.cpp_type,
"is_nullable": col.is_nullable,
"has_default": col.has_default,
})
column_classes.append(col_class)
if not col.has_default:
required_insert_columns.append(f"sqlpp::column_t<sqlpp::table_t<{table_spec}>, {col_class}>")

table_contexts.append({
"class_name": table_class,
"member_name": table_member,
"sql_name": cls._escape_if_reserved(t.name),
"creation_helper_func": "create" + ("" if args.naming_style == "camel-case" else "_") + table_class,
"commands": t.commands,
"columns": column_contexts,
"column_classes": column_classes,
"required_insert_columns": required_insert_columns,
})

full_context = {**context, "tables": table_contexts}
with open(path, "w") as f:
f.write(template.render(full_context))

@classmethod
# Prepends optional schema with an `_`
Expand Down Expand Up @@ -758,7 +751,7 @@ def parse_commandline_args():
required.add_argument("--path-to-ddl", nargs="*", help="one or more path(s) to DDL input file(s)")
required.add_argument("--namespace", help="namespace for generated table classes")

paths = arg_parser.add_argument_group("Paths", "Choose one or more paths for code generation:")
paths = arg_parser.add_argument_group("Paths", "Choose one or more path(s) for code generation:")
paths.add_argument("--path-to-module", help="path to generated module file (also requires --module-name)")
paths.add_argument("--path-to-header", help="path to generated header file (one file for all tables)")
paths.add_argument("--path-to-header-directory", help="path to directory for generated header files (one file per table)")
Expand Down