From 040a442cca500144fb0d4a149b4560d2061f0c18 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Sat, 9 May 2026 13:24:30 +0800 Subject: [PATCH 01/11] Strip invisible chars from all string args in graphsample_insert main() Co-Authored-By: Claude Opus 4.7 --- sqlite/graphsample_insert.py | 56 ++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/sqlite/graphsample_insert.py b/sqlite/graphsample_insert.py index afb40816ab..5c303a0e2b 100755 --- a/sqlite/graphsample_insert.py +++ b/sqlite/graphsample_insert.py @@ -442,51 +442,59 @@ def insert_sample_input_tensor_meta( # main func def main(args): + model_path_prefix = args.model_path_prefix.strip() + relative_model_path = args.relative_model_path.strip() + repo_uid = args.repo_uid.strip() + sample_type = args.sample_type.strip() + db_path = args.db_path.strip() + op_names_path_prefix = ( + args.op_names_path_prefix.strip() if args.op_names_path_prefix else "" + ) data = get_graph_sample_data( - model_path_prefix=args.model_path_prefix, - relative_model_path=args.relative_model_path, - repo_uid=args.repo_uid, - sample_type=args.sample_type, + model_path_prefix=model_path_prefix, + relative_model_path=relative_model_path, + repo_uid=repo_uid, + sample_type=sample_type, order_value=args.order_value, ) - print(f"\ninsert into database: {args.db_path}") + print(f"\ninsert into database: {db_path}") try: - insert_graph_sample(args.db_path, data, args.model_path_prefix) + insert_graph_sample(db_path, data, model_path_prefix) if data["is_subgraph"]: subgraph_source_data = insert_subgraph_source( subgraph_uuid=data["uuid"], - model_path_prefix=args.model_path_prefix, - sample_type=args.sample_type, - relative_model_path=args.relative_model_path, - db_path=args.db_path, + model_path_prefix=model_path_prefix, + sample_type=sample_type, + relative_model_path=relative_model_path, + db_path=db_path, ) insert_sample_op_name_list( sample_uuid=data["uuid"], - 
model_path_prefix=args.model_path_prefix, - op_names_path_prefix=args.op_names_path_prefix, - relative_model_path=args.relative_model_path, - db_path=args.db_path, + model_path_prefix=model_path_prefix, + op_names_path_prefix=op_names_path_prefix, + relative_model_path=relative_model_path, + db_path=db_path, ) insert_sample_input_tensor_meta( sample_uuid=data["uuid"], - model_path_prefix=args.model_path_prefix, - relative_model_path=args.relative_model_path, - db_path=args.db_path, + model_path_prefix=model_path_prefix, + relative_model_path=relative_model_path, + db_path=db_path, ) - if args.sample_type in ["fusible_graph", "typical_graph"]: + if sample_type in ["fusible_graph", "typical_graph"]: insert_dimension_generalization_source( subgraph_source_data["subgraph_uuid"], subgraph_source_data["full_graph_uuid"], - args.model_path_prefix, - args.relative_model_path, - args.db_path, + model_path_prefix, + relative_model_path, + db_path, ) insert_datatype_generalization_source( subgraph_source_data["subgraph_uuid"], subgraph_source_data["full_graph_uuid"], - args.model_path_prefix, - args.relative_model_path, - args.db_path, + model_path_prefix, + relative_model_path, + db_path, ) print(f"success insert: {data['relative_model_path']}") except sqlite3.IntegrityError as e: From fa7257d56f769d6d47c0f561aef55074c93d897c Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Sat, 9 May 2026 14:59:48 +0800 Subject: [PATCH 02/11] Extract insert_one_sample from graphsample_insert and add build_db.py Extract the insertion logic into a reusable insert_one_sample() function so build_db.py can import it directly instead of duplicating the code. 
Co-Authored-By: Claude Opus 4.7 --- sqlite/build_db.py | 129 +++++++++++++++++++++++++++++++++++ sqlite/graphsample_insert.py | 39 ++++++++--- 2 files changed, 158 insertions(+), 10 deletions(-) create mode 100755 sqlite/build_db.py diff --git a/sqlite/build_db.py b/sqlite/build_db.py new file mode 100755 index 0000000000..d7d86c0983 --- /dev/null +++ b/sqlite/build_db.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +import argparse +import os +import sys + +import graph_net + +from graphsample_insert import insert_one_sample + + +GRAPH_NET_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(graph_net.__file__))) + + +def insert_from_list( + list_file, + model_path_prefix, + sample_type, + repo_uid, + db_path, + op_names_path_prefix, + start_order=0, +): + if not os.path.isfile(list_file): + print(f"List file not found: {list_file}, skipping") + return start_order + + with open(list_file) as f: + paths = [line.strip() for line in f if line.strip()] + + order_value = start_order + for relative_model_path in paths: + print(f"insert : {relative_model_path}") + insert_one_sample( + model_path_prefix=model_path_prefix, + relative_model_path=relative_model_path, + repo_uid=repo_uid, + sample_type=sample_type, + order_value=order_value, + db_path=db_path, + op_names_path_prefix=op_names_path_prefix, + ) + order_value += 1 + + return order_value + + +def main(): + parser = argparse.ArgumentParser( + description="Batch insert graph samples from list files" + ) + parser.add_argument( + "--db_path", + type=str, + default=None, + help="Database file path (default: GRAPH_NET_ROOT/sqlite/GraphNet.db)", + ) + parser.add_argument( + "--dataset_root", + type=str, + default=None, + help="Dataset root directory (default: GRAPH_NET_ROOT/20260317)", + ) + parser.add_argument( + "--repo_uid", + type=str, + default="hf_torch_samples", + help="Repository uid", + ) + args = parser.parse_args() + + dataset_root = args.dataset_root or os.path.join(GRAPH_NET_ROOT, "20260317") + db_path = 
args.db_path or os.path.join(GRAPH_NET_ROOT, "sqlite", "GraphNet.db") + repo_uid = args.repo_uid.strip() + + if not os.path.isfile(db_path): + print(f"Fail ! No Database ! : {db_path}") + sys.exit(1) + + order_value = 0 + + # full_graph + order_value = insert_from_list( + list_file=os.path.join(dataset_root, "full_graph.txt"), + model_path_prefix=os.path.join(dataset_root, "full_graph"), + sample_type="full_graph", + repo_uid=repo_uid, + db_path=db_path, + op_names_path_prefix="", + start_order=order_value, + ) + + # typical_graph + order_value = insert_from_list( + list_file=os.path.join(dataset_root, "typical_graph.txt"), + model_path_prefix=os.path.join(dataset_root, "typical_graph"), + sample_type="typical_graph", + repo_uid=repo_uid, + db_path=db_path, + op_names_path_prefix=os.path.join(dataset_root, "03_sample_op_names"), + start_order=order_value, + ) + + # fusible_graph + order_value = insert_from_list( + list_file=os.path.join(dataset_root, "fusible_graph.txt"), + model_path_prefix=os.path.join(dataset_root, "fusible_graph"), + sample_type="fusible_graph", + repo_uid=repo_uid, + db_path=db_path, + op_names_path_prefix=os.path.join(dataset_root, "03_sample_op_names"), + start_order=order_value, + ) + + # sole_op_graph + order_value = insert_from_list( + list_file=os.path.join(dataset_root, "sole_op_graph.txt"), + model_path_prefix=os.path.join(dataset_root, "sole_op_graph"), + sample_type="sole_op_graph", + repo_uid=repo_uid, + db_path=db_path, + op_names_path_prefix=os.path.join(dataset_root, "03_sample_op_names"), + start_order=order_value, + ) + + print("all done") + + +if __name__ == "__main__": + main() diff --git a/sqlite/graphsample_insert.py b/sqlite/graphsample_insert.py index 5c303a0e2b..7fbc5e52c4 100755 --- a/sqlite/graphsample_insert.py +++ b/sqlite/graphsample_insert.py @@ -441,21 +441,28 @@ def insert_sample_input_tensor_meta( # main func -def main(args): - model_path_prefix = args.model_path_prefix.strip() - relative_model_path = 
args.relative_model_path.strip() - repo_uid = args.repo_uid.strip() - sample_type = args.sample_type.strip() - db_path = args.db_path.strip() - op_names_path_prefix = ( - args.op_names_path_prefix.strip() if args.op_names_path_prefix else "" - ) +def insert_one_sample( + model_path_prefix: str, + relative_model_path: str, + repo_uid: str, + sample_type: str, + order_value: int, + db_path: str, + op_names_path_prefix: str = "", +): + model_path_prefix = model_path_prefix.strip() + relative_model_path = relative_model_path.strip() + repo_uid = repo_uid.strip() + sample_type = sample_type.strip() + db_path = db_path.strip() + op_names_path_prefix = op_names_path_prefix.strip() if op_names_path_prefix else "" + data = get_graph_sample_data( model_path_prefix=model_path_prefix, relative_model_path=relative_model_path, repo_uid=repo_uid, sample_type=sample_type, - order_value=args.order_value, + order_value=order_value, ) print(f"\ninsert into database: {db_path}") try: @@ -504,6 +511,18 @@ def main(args): print(f"insert failed: {e}") +def main(args): + insert_one_sample( + model_path_prefix=args.model_path_prefix, + relative_model_path=args.relative_model_path, + repo_uid=args.repo_uid, + sample_type=args.sample_type, + order_value=args.order_value, + db_path=args.db_path, + op_names_path_prefix=args.op_names_path_prefix, + ) + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="insert graph sample to database") parser.add_argument( From e1393bf014c4cf8f71f9b70c8c15eca2c6dae645 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Sat, 9 May 2026 15:29:26 +0800 Subject: [PATCH 03/11] Fix path issues in build_db and init_db - init_db: compute migrates_dir from script location instead of CWD-relative path - build_db: use main(args), add --op_names_path_prefix as required arg, auto-create db via migrate() - Remove unused GRAPH_NET_ROOT and graph_net import from build_db Co-Authored-By: Claude Opus 4.7 --- sqlite/build_db.py | 78 
++++++++++++++++++++++++---------------------- sqlite/init_db.py | 10 ++++-- 2 files changed, 48 insertions(+), 40 deletions(-) diff --git a/sqlite/build_db.py b/sqlite/build_db.py index d7d86c0983..02affed180 100755 --- a/sqlite/build_db.py +++ b/sqlite/build_db.py @@ -3,12 +3,8 @@ import os import sys -import graph_net - from graphsample_insert import insert_one_sample - - -GRAPH_NET_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(graph_net.__file__))) +from init_db import migrate def insert_from_list( @@ -44,36 +40,16 @@ def insert_from_list( return order_value -def main(): - parser = argparse.ArgumentParser( - description="Batch insert graph samples from list files" - ) - parser.add_argument( - "--db_path", - type=str, - default=None, - help="Database file path (default: GRAPH_NET_ROOT/sqlite/GraphNet.db)", - ) - parser.add_argument( - "--dataset_root", - type=str, - default=None, - help="Dataset root directory (default: GRAPH_NET_ROOT/20260317)", - ) - parser.add_argument( - "--repo_uid", - type=str, - default="hf_torch_samples", - help="Repository uid", - ) - args = parser.parse_args() - - dataset_root = args.dataset_root or os.path.join(GRAPH_NET_ROOT, "20260317") - db_path = args.db_path or os.path.join(GRAPH_NET_ROOT, "sqlite", "GraphNet.db") +def main(args): + dataset_root = args.dataset_root.strip() + db_path = args.db_path.strip() repo_uid = args.repo_uid.strip() + op_names_path_prefix = args.op_names_path_prefix.strip() - if not os.path.isfile(db_path): - print(f"Fail ! No Database ! : {db_path}") + if not os.path.exists(db_path): + migrate(db_path) + else: + print(f"Fail ! 
Path is not a file: {db_path}") sys.exit(1) order_value = 0 @@ -96,7 +72,7 @@ def main(): sample_type="typical_graph", repo_uid=repo_uid, db_path=db_path, - op_names_path_prefix=os.path.join(dataset_root, "03_sample_op_names"), + op_names_path_prefix=op_names_path_prefix, start_order=order_value, ) @@ -107,7 +83,7 @@ def main(): sample_type="fusible_graph", repo_uid=repo_uid, db_path=db_path, - op_names_path_prefix=os.path.join(dataset_root, "03_sample_op_names"), + op_names_path_prefix=op_names_path_prefix, start_order=order_value, ) @@ -118,7 +94,7 @@ def main(): sample_type="sole_op_graph", repo_uid=repo_uid, db_path=db_path, - op_names_path_prefix=os.path.join(dataset_root, "03_sample_op_names"), + op_names_path_prefix=op_names_path_prefix, start_order=order_value, ) @@ -126,4 +102,32 @@ def main(): if __name__ == "__main__": - main() + parser = argparse.ArgumentParser( + description="Batch insert graph samples from list files" + ) + parser.add_argument( + "--db_path", + type=str, + required=True, + help="Database file path", + ) + parser.add_argument( + "--dataset_root", + type=str, + required=True, + help="Dataset root directory", + ) + parser.add_argument( + "--repo_uid", + type=str, + default="hf_torch_samples", + help="Repository uid", + ) + parser.add_argument( + "--op_names_path_prefix", + type=str, + required=True, + help="Path prefix of op names files", + ) + args = parser.parse_args() + main(args) diff --git a/sqlite/init_db.py b/sqlite/init_db.py index da1d0ab577..b4214c00c0 100755 --- a/sqlite/init_db.py +++ b/sqlite/init_db.py @@ -1,6 +1,7 @@ import sqlite3 import re import argparse +import os from pathlib import Path @@ -12,7 +13,10 @@ def parse_timestamp(filename: str) -> int: return 0 -def migrate(db_path: str = "sqlite/GraphNet.db", migrates_dir: str = "sqlite/migrates"): +def migrate(db_path: str): + script_dir = os.path.dirname(os.path.abspath(__file__)) + migrates_dir = os.path.join(script_dir, "migrates") + db_path_obj = Path(db_path) 
migrates_path = Path(migrates_dir) @@ -58,8 +62,8 @@ def migrate(db_path: str = "sqlite/GraphNet.db", migrates_dir: str = "sqlite/mig parser.add_argument( "--db_path", type=str, - default="sqlite/GraphNet.db", - help="Database file path (default: sqlite/GraphNet.db)", + default="GraphNet.db", + help="Database file path", ) args = parser.parse_args() migrate(args.db_path) From 585d1fc6502f2894748bb414d2dfbbadb48e2494 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Sat, 9 May 2026 16:00:34 +0800 Subject: [PATCH 04/11] Improve build_db: auto-collect samples, loop over types, track stats - Auto-collect sample paths by scanning for model.py when list file is missing - Use loop over sample_types instead of repeated code blocks - Track and print success/fail counts and order range per type Co-Authored-By: Claude Opus 4.7 --- sqlite/build_db.py | 103 +++++++++++++++++------------------ sqlite/graphsample_insert.py | 4 ++ 2 files changed, 53 insertions(+), 54 deletions(-) diff --git a/sqlite/build_db.py b/sqlite/build_db.py index 02affed180..90bd5a896f 100755 --- a/sqlite/build_db.py +++ b/sqlite/build_db.py @@ -7,8 +7,19 @@ from init_db import migrate +def collect_sample_paths(model_path_prefix): + """Collect relative paths of directories containing model.py under model_path_prefix.""" + sample_paths = [] + for dirpath, _, filenames in os.walk(model_path_prefix): + if "model.py" in filenames: + rel = os.path.relpath(dirpath, model_path_prefix) + sample_paths.append(rel) + sample_paths.sort() + return sample_paths + + def insert_from_list( - list_file, + list_file_path, model_path_prefix, sample_type, repo_uid, @@ -16,17 +27,21 @@ def insert_from_list( op_names_path_prefix, start_order=0, ): - if not os.path.isfile(list_file): - print(f"List file not found: {list_file}, skipping") - return start_order - - with open(list_file) as f: - paths = [line.strip() for line in f if line.strip()] + if os.path.isfile(list_file_path): + with open(list_file_path) as f: + sample_paths = 
[line.strip() for line in f if line.strip()] + sample_paths.sort() + else: + print( + f"List file not found: {list_file_path}, collecting from {model_path_prefix}" + ) + sample_paths = collect_sample_paths(model_path_prefix) + total = len(sample_paths) order_value = start_order - for relative_model_path in paths: + for relative_model_path in sample_paths: print(f"insert : {relative_model_path}") - insert_one_sample( + successed = insert_one_sample( model_path_prefix=model_path_prefix, relative_model_path=relative_model_path, repo_uid=repo_uid, @@ -35,9 +50,10 @@ def insert_from_list( db_path=db_path, op_names_path_prefix=op_names_path_prefix, ) - order_value += 1 + if successed: + order_value += 1 - return order_value + return order_value, total def main(args): @@ -54,49 +70,28 @@ def main(args): order_value = 0 - # full_graph - order_value = insert_from_list( - list_file=os.path.join(dataset_root, "full_graph.txt"), - model_path_prefix=os.path.join(dataset_root, "full_graph"), - sample_type="full_graph", - repo_uid=repo_uid, - db_path=db_path, - op_names_path_prefix="", - start_order=order_value, - ) - - # typical_graph - order_value = insert_from_list( - list_file=os.path.join(dataset_root, "typical_graph.txt"), - model_path_prefix=os.path.join(dataset_root, "typical_graph"), - sample_type="typical_graph", - repo_uid=repo_uid, - db_path=db_path, - op_names_path_prefix=op_names_path_prefix, - start_order=order_value, - ) - - # fusible_graph - order_value = insert_from_list( - list_file=os.path.join(dataset_root, "fusible_graph.txt"), - model_path_prefix=os.path.join(dataset_root, "fusible_graph"), - sample_type="fusible_graph", - repo_uid=repo_uid, - db_path=db_path, - op_names_path_prefix=op_names_path_prefix, - start_order=order_value, - ) - - # sole_op_graph - order_value = insert_from_list( - list_file=os.path.join(dataset_root, "sole_op_graph.txt"), - model_path_prefix=os.path.join(dataset_root, "sole_op_graph"), - sample_type="sole_op_graph", - 
repo_uid=repo_uid, - db_path=db_path, - op_names_path_prefix=op_names_path_prefix, - start_order=order_value, - ) + sample_types = [ + ("full_graph", ""), + ("typical_graph", op_names_path_prefix), + ("fusible_graph", op_names_path_prefix), + ("sole_op_graph", op_names_path_prefix), + ] + for sample_type, op_prefix in sample_types: + order_start = order_value + order_value, total = insert_from_list( + list_file_path=os.path.join(dataset_root, f"{sample_type}.txt"), + model_path_prefix=os.path.join(dataset_root, sample_type), + sample_type=sample_type, + repo_uid=repo_uid, + db_path=db_path, + op_names_path_prefix=op_prefix, + start_order=order_value, + ) + num_success = order_value - order_start + print( + f"[{sample_type}] total={total}, success={num_success}, " + f"fail={total - num_success}, order=[{order_start}, {order_value})" + ) print("all done") diff --git a/sqlite/graphsample_insert.py b/sqlite/graphsample_insert.py index 7fbc5e52c4..153b620921 100755 --- a/sqlite/graphsample_insert.py +++ b/sqlite/graphsample_insert.py @@ -507,8 +507,12 @@ def insert_one_sample( except sqlite3.IntegrityError as e: print("insert failed: integrity error (possible duplicate uuid or graph_hash)") print(f"error info: {e}") + return False except Exception as e: print(f"insert failed: {e}") + return False + + return True def main(args): From 34150d24594dfd4149e569ab3cb746903f04a9b8 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Sat, 9 May 2026 16:11:14 +0800 Subject: [PATCH 05/11] Add directory check and path logging for each sample type in build_db - Skip non-full_graph types when directory is missing - Print sample dir and list file paths before processing each type Co-Authored-By: Claude Opus 4.7 --- sqlite/build_db.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/sqlite/build_db.py b/sqlite/build_db.py index 90bd5a896f..773d17dc15 100755 --- a/sqlite/build_db.py +++ b/sqlite/build_db.py @@ -77,10 +77,19 @@ def main(args): ("sole_op_graph", 
op_names_path_prefix), ] for sample_type, op_prefix in sample_types: + model_path_prefix = os.path.join(dataset_root, sample_type) + list_file_path = os.path.join(dataset_root, f"{sample_type}_list.txt") + print(f"\n[{sample_type}] samples={model_path_prefix}, list={list_file_path}") + if not os.path.isdir(model_path_prefix): + if sample_type == "full_graph": + print(f"Fail ! full_graph directory not found: {model_path_prefix}") + sys.exit(1) + print(f"[{sample_type}] skipped, directory not found") + continue order_start = order_value order_value, total = insert_from_list( - list_file_path=os.path.join(dataset_root, f"{sample_type}.txt"), - model_path_prefix=os.path.join(dataset_root, sample_type), + list_file_path=list_file_path, + model_path_prefix=model_path_prefix, sample_type=sample_type, repo_uid=repo_uid, db_path=db_path, From a3f7c0be222140d4669b36191145f4dd94bfe6dc Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Sat, 9 May 2026 16:13:19 +0800 Subject: [PATCH 06/11] Simplify sample_types to a plain string list in build_db Co-Authored-By: Claude Opus 4.7 --- sqlite/build_db.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/sqlite/build_db.py b/sqlite/build_db.py index 773d17dc15..d42c4c330e 100755 --- a/sqlite/build_db.py +++ b/sqlite/build_db.py @@ -64,28 +64,23 @@ def main(args): if not os.path.exists(db_path): migrate(db_path) - else: - print(f"Fail ! 
Path is not a file: {db_path}") - sys.exit(1) + print(f"db_path={db_path}, repo_uid={repo_uid}") order_value = 0 - sample_types = [ - ("full_graph", ""), - ("typical_graph", op_names_path_prefix), - ("fusible_graph", op_names_path_prefix), - ("sole_op_graph", op_names_path_prefix), - ] - for sample_type, op_prefix in sample_types: + sample_types = ["full_graph", "typical_graph", "fusible_graph", "sole_op_graph"] + for sample_type in sample_types: model_path_prefix = os.path.join(dataset_root, sample_type) list_file_path = os.path.join(dataset_root, f"{sample_type}_list.txt") print(f"\n[{sample_type}] samples={model_path_prefix}, list={list_file_path}") + if not os.path.isdir(model_path_prefix): if sample_type == "full_graph": print(f"Fail ! full_graph directory not found: {model_path_prefix}") sys.exit(1) print(f"[{sample_type}] skipped, directory not found") continue + order_start = order_value order_value, total = insert_from_list( list_file_path=list_file_path, @@ -93,7 +88,7 @@ def main(args): sample_type=sample_type, repo_uid=repo_uid, db_path=db_path, - op_names_path_prefix=op_prefix, + op_names_path_prefix=op_names_path_prefix, start_order=order_value, ) num_success = order_value - order_start From 71a1d21ed6d94567405a2920d85d5ef38a3625d6 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 11 May 2026 10:37:03 +0800 Subject: [PATCH 07/11] Rename and minor fix. 
--- sqlite/build_db.py | 1 + sqlite/download.py | 22 ----- sqlite/download_dataset.py | 44 +++++++++ ...or.py => graph_sample_bucket_generator.py} | 0 ...nsert.py => graph_sample_groups_insert.py} | 0 sqlite/graphsample_insert.py | 10 ++- sqlite/upload.py | 65 -------------- sqlite/upload_dataset.py | 90 +++++++++++++++++++ 8 files changed, 141 insertions(+), 91 deletions(-) delete mode 100755 sqlite/download.py create mode 100755 sqlite/download_dataset.py rename sqlite/{graph_net_sample_bucket_generator.py => graph_sample_bucket_generator.py} (100%) rename sqlite/{graph_net_sample_groups_insert.py => graph_sample_groups_insert.py} (100%) delete mode 100755 sqlite/upload.py create mode 100755 sqlite/upload_dataset.py diff --git a/sqlite/build_db.py b/sqlite/build_db.py index d42c4c330e..2ee7176ed4 100755 --- a/sqlite/build_db.py +++ b/sqlite/build_db.py @@ -52,6 +52,7 @@ def insert_from_list( ) if successed: order_value += 1 + assert successed return order_value, total diff --git a/sqlite/download.py b/sqlite/download.py deleted file mode 100755 index d9cec2c722..0000000000 --- a/sqlite/download.py +++ /dev/null @@ -1,22 +0,0 @@ -import os -from datasets import load_dataset -from huggingface_hub import hf_hub_download - -REPO_ID = "PaddlePaddle/GraphNet" -REVISION = "20260224" -SAVE_DIR = "./workspace" - -ds = load_dataset(REPO_ID, split="GraphNet", revision=REVISION) -for item in ds: - full_path = os.path.join(SAVE_DIR, item["path"]) - os.makedirs(os.path.dirname(full_path), exist_ok=True) - with open(full_path, "w", encoding="utf-8") as f: - f.write(item["content"]) - -hf_hub_download( - repo_id=REPO_ID, - filename="GraphNet.db", - repo_type="dataset", - revision=REVISION, - local_dir=SAVE_DIR, -) diff --git a/sqlite/download_dataset.py b/sqlite/download_dataset.py new file mode 100755 index 0000000000..9e2702b1e6 --- /dev/null +++ b/sqlite/download_dataset.py @@ -0,0 +1,44 @@ +import argparse +import os +from datasets import load_dataset +from huggingface_hub 
import hf_hub_download + + +def main(args): + ds = load_dataset(args.repo_id, split=args.split, revision=args.revision) + for item in ds: + full_path = os.path.join(args.save_dir, item["path"]) + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, "w", encoding="utf-8") as f: + f.write(item["content"]) + + hf_hub_download( + repo_id=args.repo_id, + filename=args.db_file, + repo_type="dataset", + revision=args.revision, + local_dir=args.save_dir, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Download dataset and DB from HuggingFace Hub" + ) + parser.add_argument( + "--repo_id", type=str, default="PaddlePaddle/GraphNet", help="HF repo ID" + ) + parser.add_argument( + "--revision", type=str, default="main", help="HF repo revision/branch" + ) + parser.add_argument( + "--save_dir", type=str, default="./workspace", help="Local save directory" + ) + parser.add_argument( + "--split", type=str, default="GraphNet", help="Dataset split name" + ) + parser.add_argument( + "--db_file", type=str, default="GraphNet.db", help="DB filename to download" + ) + args = parser.parse_args() + main(args) diff --git a/sqlite/graph_net_sample_bucket_generator.py b/sqlite/graph_sample_bucket_generator.py similarity index 100% rename from sqlite/graph_net_sample_bucket_generator.py rename to sqlite/graph_sample_bucket_generator.py diff --git a/sqlite/graph_net_sample_groups_insert.py b/sqlite/graph_sample_groups_insert.py similarity index 100% rename from sqlite/graph_net_sample_groups_insert.py rename to sqlite/graph_sample_groups_insert.py diff --git a/sqlite/graphsample_insert.py b/sqlite/graphsample_insert.py index 153b620921..f3638d1bf1 100755 --- a/sqlite/graphsample_insert.py +++ b/sqlite/graphsample_insert.py @@ -464,7 +464,7 @@ def insert_one_sample( sample_type=sample_type, order_value=order_value, ) - print(f"\ninsert into database: {db_path}") + print(f"\ninsert into database: {db_path=}, {sample_type=}") try: 
insert_graph_sample(db_path, data, model_path_prefix) if data["is_subgraph"]: @@ -503,13 +503,15 @@ def insert_one_sample( relative_model_path, db_path, ) - print(f"success insert: {data['relative_model_path']}") + print(f"insert {sample_type} success: {data['relative_model_path']}") except sqlite3.IntegrityError as e: - print("insert failed: integrity error (possible duplicate uuid or graph_hash)") + print( + "insert {sample_type} failed: integrity error (possible duplicate uuid or graph_hash)" + ) print(f"error info: {e}") return False except Exception as e: - print(f"insert failed: {e}") + print(f"insert {sample_type} failed: {e}") return False return True diff --git a/sqlite/upload.py b/sqlite/upload.py deleted file mode 100755 index 72308f28a4..0000000000 --- a/sqlite/upload.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -from datasets import Dataset -from huggingface_hub import HfApi, login - - -HF_TOKEN = "" -REPO_ID = "PaddlePaddle/GraphNet" -REVISION = "20260203" -BASE_DIR = "/work/GraphNet/torch_paddle_samples/subgraph_dataset_20260203" -FOLDERS_TO_PACK = ["full_graph", "fusible_graph", "sole_op_graph", "typical_graph"] -DB_FILE = "GraphNet.db" - - -def is_clean_file(filename, root): - ext = os.path.splitext(filename)[1].lower() - if ext in {".pyc", ".pyo", ".pyd", ".so"}: - return False - if any(x in root for x in ["__pycache__", ".git", ".ipynb_checkpoints"]): - return False - return True - - -def file_generator(): - file_list = [ - (os.path.join(root, f), folder) - for folder in FOLDERS_TO_PACK - if os.path.exists(os.path.join(BASE_DIR, folder)) - for root, _, files in os.walk(os.path.join(BASE_DIR, folder)) - for f in files - if is_clean_file(f, root) - and os.path.splitext(f)[1].lower() in {".py", ".json", ".txt", ".yaml", ".md"} - ] - - return ( - { - "path": os.path.relpath(fp, BASE_DIR), - "content": open(fp, "r", encoding="utf-8", errors="ignore").read(), - "source_folder": src, - } - for fp, src in file_list - ) - - -def main(): - 
login(token=HF_TOKEN) - - ds = Dataset.from_generator(file_generator) - ds.push_to_hub(REPO_ID, split="GraphNet", max_shard_size="500MB", revision=REVISION) - print("Folder data uploaded successfully!") - - api = HfApi() - db_path = os.path.join(BASE_DIR, DB_FILE) - if os.path.exists(db_path): - api.upload_file( - path_or_fileobj=db_path, - path_in_repo=DB_FILE, - repo_id=REPO_ID, - repo_type="dataset", - revision=REVISION, - ) - print(f"{DB_FILE} uploaded successfully!") - - -if __name__ == "__main__": - main() diff --git a/sqlite/upload_dataset.py b/sqlite/upload_dataset.py new file mode 100755 index 0000000000..d2e40756b0 --- /dev/null +++ b/sqlite/upload_dataset.py @@ -0,0 +1,90 @@ +import argparse +import os +from datasets import Dataset +from huggingface_hub import HfApi, login + + +def is_clean_file(filename, root): + ext = os.path.splitext(filename)[1].lower() + if ext in {".pyc", ".pyo", ".pyd", ".so"}: + return False + if any(x in root for x in ["__pycache__", ".git", ".ipynb_checkpoints"]): + return False + return True + + +def file_generator(base_dir, folders): + file_list = [ + (os.path.join(root, f), folder) + for folder in folders + if os.path.exists(os.path.join(base_dir, folder)) + for root, _, files in os.walk(os.path.join(base_dir, folder)) + for f in files + if is_clean_file(f, root) + and os.path.splitext(f)[1].lower() in {".py", ".json", ".txt", ".yaml", ".md"} + ] + + return ( + { + "path": os.path.relpath(fp, base_dir), + "content": open(fp, "r", encoding="utf-8", errors="ignore").read(), + "source_folder": src, + } + for fp, src in file_list + ) + + +def main(args): + folders = ["full_graph", "fusible_graph", "sole_op_graph", "typical_graph"] + + login(token=args.hf_token) + + ds = Dataset.from_generator(lambda: file_generator(args.base_dir, folders)) + ds.push_to_hub( + args.repo_id, + split=args.split, + max_shard_size=args.max_shard_size, + revision=args.revision, + ) + print("Folder data uploaded successfully!") + + api = HfApi() + 
db_path = os.path.join(args.base_dir, args.db_file) + if os.path.exists(db_path): + api.upload_file( + path_or_fileobj=db_path, + path_in_repo=args.db_file, + repo_id=args.repo_id, + repo_type="dataset", + revision=args.revision, + ) + print(f"{args.db_file} uploaded successfully!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Upload dataset and DB to HuggingFace Hub" + ) + parser.add_argument( + "--hf_token", type=str, required=True, help="HuggingFace API token" + ) + parser.add_argument( + "--repo_id", type=str, default="PaddlePaddle/GraphNet", help="HF repo ID" + ) + parser.add_argument( + "--revision", type=str, default="main", help="HF repo revision/branch" + ) + parser.add_argument( + "--base_dir", type=str, required=True, help="Local dataset root directory" + ) + parser.add_argument( + "--db_file", type=str, default="GraphNet.db", help="DB filename in base_dir" + ) + parser.add_argument( + "--split", type=str, default="GraphNet", help="Dataset split name" + ) + parser.add_argument( + "--max_shard_size", type=str, default="500MB", help="Max shard size" + ) + args = parser.parse_args() + main(args) From ed2240baf6ff9081ad8bbf6c35b6a1c4ad994c8a Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 11 May 2026 10:58:08 +0800 Subject: [PATCH 08/11] Update README. 
--- sqlite/README.md | 135 +++++++++++++++++++++++++++++++++++++++++++++++ sqlite/Readme.md | 55 ------------------- 2 files changed, 135 insertions(+), 55 deletions(-) create mode 100755 sqlite/README.md delete mode 100755 sqlite/Readme.md diff --git a/sqlite/README.md b/sqlite/README.md new file mode 100755 index 0000000000..58cc59f1b5 --- /dev/null +++ b/sqlite/README.md @@ -0,0 +1,135 @@ +# GraphNet SQLite 操作指南 + +## 目录结构 + +``` +sqlite/ +├── migrates/ # SQL 迁移文件(按时间戳顺序执行) +├── orm_models.py # SQLAlchemy ORM 模型定义 +├── init_db.py # 数据库初始化 +├── build_db.py # 批量建库(推荐) +├── graphsample_insert.py # 单条样本插入 +├── graphsample_delete.py # 单条样本删除(软删除) +├── merge_db.py # 数据库合并 +├── graph_sample_bucket_generator.py # 样本分桶元数据生成 +├── graph_sample_groups_insert.py # 采样分组生成 +├── upload_dataset.py # HuggingFace 上传 +├── download_dataset.py # HuggingFace 下载 +``` + +## 数据表概览 + +| 表名 | 用途 | +|------|------| +| `repo` | 仓库源信息 | +| `graph_sample` | 计算图样本主表 | +| `subgraph_source` | 子图来源映射 | +| `dimension_generalization_source` | 维度泛化来源 | +| `datatype_generalization_source` | 数据类型泛化来源 | +| `backward_graph_source` | 反向图来源 | +| `sample_op_name` / `sample_op_name_list` | 算子名称序列 | +| `sample_input_tensor_meta` | 输入张量元信息 | +| `graph_net_sample_buckets` | 样本分桶元数据 | +| `graph_net_sample_groups` | 采样分组 | + +所有删除操作均为软删除(`deleted` 字段标记),不物理删除数据。 + +## 数据库初始化 + +从 `migrates/` 目录按时间戳顺序执行 SQL 文件,创建所有表结构。**库文件已存在时会被删除重建。** + +```bash +# 默认路径 GraphNet.db +python init_db.py 2>&1 | tee logs/init_db_$(date +"%Y%m%d_%H%M%S").log + +# 自定义路径 +python init_db.py --db_path xxx.db +``` + +## 批量建库(推荐) + +一次性处理 `full_graph`、`typical_graph`、`fusible_graph`、`sole_op_graph` 四种样本类型,自动收集目录或读取 list 文件后逐条插入。库文件不存在时自动初始化。 + +```bash +python build_db.py \ + --db_path GraphNet.db \ + --dataset_root /path/to/dataset \ + --repo_uid "hf_torch_samples" \ + --op_names_path_prefix /path/to/sample_op_names +``` + +## 单条样本操作 + +```bash +# 插入单条 +python graphsample_insert.py \ + --model_path_prefix /path/to/dataset/full_graph \ + 
--relative_model_path models/torch/resnet18 \ + --repo_uid "hf_torch_samples" \ + --sample_type "full_graph" \ + --order_value 0 \ + --db_path GraphNet.db + +# 删除单条(软删除,设置 deleted=1) +python graphsample_delete.py \ + --db_path GraphNet.db \ + --repo_uid "hf_torch_samples" \ + --relative_model_path "models/torch/resnet18" +``` + +## Shell 批量脚本 + +```bash +# 批量插入(从 list 文件逐行读取) +bash graphsample_insert.sh [db_path] + +# 批量删除(从 graph_net/config/delete_list.txt 读取) +bash graphsample_delete.sh [db_path] +``` + +## 数据库合并 + +将新库的所有记录合并到主库,自动跳过已存在的 repo 和 graph_sample。 + +```bash +python merge_db.py \ + --main_db_path GraphNet.db \ + --new_db_path new.db +``` + +## 样本分桶与分组 + +分两步:先生成样本的分桶元数据(op 序列、input shape、dtype 的哈希 ID),再基于分桶结果按策略生成采样分组。 + +```bash +# 生成分桶元数据 → 写入 graph_net_sample_buckets 表 +python graph_sample_bucket_generator.py --db_path GraphNet.db + +# 生成采样分组 → 写入 graph_net_sample_groups 表 +# 策略: bucket_policy_v1 (stride-16 + cross-shape) + bucket_policy_v2 (dtype coverage + sparse) +python graph_sample_groups_insert.py --db_path GraphNet.db --num_dtypes 3 +``` + +## HuggingFace 上传/下载 + +```bash +# 上传:打包 dataset 目录 + GraphNet.db 到 HF Hub +python upload_dataset.py \ + --hf_token "$HF_TOKEN" \ + --base_dir /path/to/dataset \ + --repo_id "PaddlePaddle/GraphNet" \ + --revision "20260203" \ + --split "GraphNet" + +# 下载:从 HF Hub 拉取 dataset 和 GraphNet.db +python download_dataset.py \ + --repo_id "PaddlePaddle/GraphNet" \ + --revision "20260224" \ + --save_dir ./workspace \ + --split "GraphNet" +``` + +## 关联资源 + +- ORM 模型定义: [orm_models.py](orm_models.py) +- SQL 迁移文件: [migrates/](migrates/) \ No newline at end of file diff --git a/sqlite/Readme.md b/sqlite/Readme.md deleted file mode 100755 index dd181bf4d2..0000000000 --- a/sqlite/Readme.md +++ /dev/null @@ -1,55 +0,0 @@ -# SQLite Operations Guide - -**Working Directory:** `/GraphNet` - -## Setup - -```bash -# Create logs directory -mkdir -p sqlite/logs -``` - -## Initialize Database - -```bash -# Default DB path (sqlite/GraphNet.db) 
-python ./sqlite/init_db.py 2>&1 | tee sqlite/logs/init_db_$(date +"%Y%m%d_%H%M%S").log - -# Custom DB path -python ./sqlite/init_db.py --db_path sqlite/xxx.db 2>&1 | tee sqlite/logs/init_db_$(date +"%Y%m%d_%H%M%S").log -``` - -## Insert Graph Samples - -```bash -# Insert to custom DB -bash ./sqlite/graphsample_insert.sh sqlite/xxx.db 2>&1 | tee sqlite/logs/insert_$(date +"%Y%m%d_%H%M%S").log - -# Insert to default DB (sqlite/GraphNet.db) -bash ./sqlite/graphsample_insert.sh 2>&1 | tee sqlite/logs/insert_$(date +"%Y%m%d_%H%M%S").log -``` - -## Delete Graph Samples - -```bash -# Delete from custom DB -bash ./sqlite/graphsample_delete.sh sqlite/xxx.db 2>&1 | tee sqlite/logs/delete_$(date +"%Y%m%d_%H%M%S").log - -# Delete from default DB (sqlite/GraphNet.db) -bash ./sqlite/graphsample_delete.sh 2>&1 | tee sqlite/logs/delete_$(date +"%Y%m%d_%H%M%S").log -``` - -## Merge Databases - -```bash -# Usage: python ./sqlite/merge_db.py --main_db_path --new_db_path -python ./sqlite/merge_db.py --main_db_path sqlite/GraphNet.db --new_db_path sqlite/new.db -``` - -## Upload to Hugging Face - -```bash -python ./sqlite/upload.py -``` - -**Note:** Set `HF_TOKEN` variable in `upload.py` before running. From 7b1c685ca4f871bbdcd008e81b94e7d48df9dfd8 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 11 May 2026 11:32:21 +0800 Subject: [PATCH 09/11] Add generation of buckets and groups into the build of db. 
--- graph_net/tools/generate_subgraph_dataset.sh | 55 ---------------- sqlite/build_db.py | 6 ++ sqlite/graph_sample_bucket_generator.py | 66 ++++++++------------ sqlite/graph_sample_groups_insert.py | 45 +++++-------- 4 files changed, 47 insertions(+), 125 deletions(-) diff --git a/graph_net/tools/generate_subgraph_dataset.sh b/graph_net/tools/generate_subgraph_dataset.sh index 208ae0d182..e29a0cbb33 100755 --- a/graph_net/tools/generate_subgraph_dataset.sh +++ b/graph_net/tools/generate_subgraph_dataset.sh @@ -93,36 +93,6 @@ function generate_subgraph_list() { | tee $sample_list } -function insert_graph_sample(){ - local target_dir="$1" - local repo_uid="$2" - local sample_type="$3" - local sample_list="$4" - echo ">>> [0] Inserting samples into database: ${DB_PATH}." - echo ">>>" - - if [ ! -f "$DB_PATH" ]; then - echo "Fail ! No Database ! : $DB_PATH" - exit 1 - fi - - local order_value=0 - while IFS= read -r model_rel_path; do - echo "insert : $model_rel_path" - python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \ - --model_path_prefix "${target_dir}" \ - --relative_model_path "$model_rel_path" \ - --repo_uid "${repo_uid}" \ - --sample_type "${sample_type}" \ - --order_value "$order_value" \ - --db_path "$DB_PATH" - - ((order_value++)) - - done < "$sample_list" -} - - function rewrite_device() { echo ">>> [1] Rewrite devices for subgraph samples under ${GRAPH_NET_ROOT}." echo ">>>" @@ -658,29 +628,6 @@ function generate_typical_subgraphs() { # generate_unittest_for_typical_subgraphs 2>&1 | tee ${DECOMPOSE_WORKSPACE}/log_unittests_typical_subgraphs_${suffix}.txt } -function generate_database() { - timestamp=`date +%Y%m%d_%H%M` - - # init database - if [ ! 
-f ${DB_PATH} ]; then - python ${GRAPH_NET_ROOT}/sqlite/init_db.py --db_path ${DB_PATH} 2>&1 | tee ${DECOMPOSE_WORKSPACE}/log_init_db_${timestamp}.txt - fi - - # full_graph - insert_graph_sample ${GRAPH_NET_ROOT} "hf_torch_samples" "full_graph" ${model_list} - - # fusible_graph, typical_graph - for sample_type in fusible_graph typical_graph; do - insert_graph_sample $OUTPUT_DIR/$sample_type "hf_torch_samples" $sample_type $OUTPUT_DIR/${sample_type}/sample_list.txt - done - - # insert buckets - python ${GRAPH_NET_ROOT}/sqlite/graph_net_sample_bucket_generator.py --db_path ${DB_PATH} - - # insert groups - python ${GRAPH_NET_ROOT}/sqlite/graph_net_sample_groups_insert.py --db_path ${DB_PATH} -} - function main() { do_common_generalzation_and_decompose @@ -693,8 +640,6 @@ function main() { generate_typical_subgraphs #cp -rf $DTYPE_GENERALIZED_TYPICAL_SUBGRAPH_DIR $OUTPUT_DIR/$sample_type #cp -rf $dtype_generalized_typical_subgraph_list $OUTPUT_DIR/$sample_type/sample_list.txt - - #generate_database } function summary() { diff --git a/sqlite/build_db.py b/sqlite/build_db.py index 2ee7176ed4..889b468fb0 100755 --- a/sqlite/build_db.py +++ b/sqlite/build_db.py @@ -100,6 +100,12 @@ def main(args): print("all done") + from graph_sample_bucket_generator import generate_buckets + from graph_sample_groups_insert import generate_groups + + generate_buckets(db_path) + generate_groups(db_path, num_dtypes=3) + if __name__ == "__main__": parser = argparse.ArgumentParser( diff --git a/sqlite/graph_sample_bucket_generator.py b/sqlite/graph_sample_bucket_generator.py index 0a690e4c67..dbbf3b2682 100644 --- a/sqlite/graph_sample_bucket_generator.py +++ b/sqlite/graph_sample_bucket_generator.py @@ -213,52 +213,36 @@ def save_bucket_results( return count -def main(): - parser = argparse.ArgumentParser( - description="Generate graph_net_sample_buckets from graph_sample" - ) - parser.add_argument( - "--db_path", - type=str, - required=True, - help="Path to the SQLite database file", - ) - 
parser.add_argument( - "--dry_run", - action="store_true", - help="Only print what would be done, don't actually insert into database", - ) - - args = parser.parse_args() - - session = get_session(args.db_path) - - print("=" * 70) - print("Step 1: Generating bucket info from graph_sample...") - sample_type_results, all_bucket_info_map = generate_sample_buckets(session) - print(f" Total samples: {len(all_bucket_info_map)}") - print(f" Number of sample_types: {len(sample_type_results)}") - - print() - for result in sorted(sample_type_results, key=lambda x: -len(x)): - flag = " [sole-op]" if result.is_sole_op else "" - print(f" {result.sample_type}{flag}: {len(result)} samples") +def generate_buckets(db_path, dry_run=False): + """Generate buckets and save to DB.""" + session = get_session(db_path) + try: + print("=" * 70) + print("Step 1: Generating bucket info from graph_sample...") + _, bucket_info_map = generate_sample_buckets(session) + if dry_run: + print("Dry run mode - skipping database insert") + print( + f" Would insert {len(bucket_info_map)} records into graph_net_sample_buckets" + ) + return 0 - print("=" * 70) - if args.dry_run: - print("Dry run mode - skipping database insert") - print( - f" Would insert {len(all_bucket_info_map)} records into graph_net_sample_buckets" - ) - else: print("Step 2: Saving to database...") - count = save_bucket_results(session, all_bucket_info_map) + count = save_bucket_results(session, bucket_info_map) print(f" Inserted {count} records into graph_net_sample_buckets") + return count + finally: + session.close() - print("=" * 70) - print("Done!") - session.close() +def main(): + parser = argparse.ArgumentParser( + description="Generate graph_net_sample_buckets from graph_sample" + ) + parser.add_argument("--db_path", type=str, required=True) + parser.add_argument("--dry_run", action="store_true") + args = parser.parse_args() + generate_buckets(args.db_path, dry_run=args.dry_run) if __name__ == "__main__": diff --git 
a/sqlite/graph_sample_groups_insert.py b/sqlite/graph_sample_groups_insert.py index d22e64e231..80f490f91d 100755 --- a/sqlite/graph_sample_groups_insert.py +++ b/sqlite/graph_sample_groups_insert.py @@ -7,8 +7,6 @@ from orm_models import get_session, GraphNetSampleGroup -# ── Types ── - BucketGroup = namedtuple( "BucketGroup", ["head_uid", "op_seq", "shapes", "sample_type", "all_uids_csv"], @@ -20,9 +18,6 @@ ) -# ── Helpers ── - - def _new_group_id(): return str(uuid_module.uuid4()) @@ -51,9 +46,6 @@ def _print_stats(stats): print(f"\n Total: {total_records} records, {total_groups} groups.") -# ── Database Queries ── - - class DB: def __init__(self, path): self.path = path @@ -229,39 +221,24 @@ def _insert_groups(session, rows, policy): return stats -# ═══════════════════════════════════════════════════════════════════ -# Main -# ═══════════════════════════════════════════════════════════════════ - - -def main(): - parser = argparse.ArgumentParser( - description="Generate graph_net_sample_groups (v1 + v2)" - ) - parser.add_argument("--db_path", type=str, required=True) - parser.add_argument("--num_dtypes", type=int, default=3) - args = parser.parse_args() - - db = DB(args.db_path) +def generate_groups(db_path, num_dtypes=3): + """Generate sample groups and save to DB.""" + db = DB(db_path) db.connect() - session = get_session(args.db_path) - + session = get_session(db_path) all_stats = defaultdict(lambda: {"records": 0, "groups": set()}) - try: - # V1 buckets = query_bucket_groups(db) print(f"Bucket groups: {len(buckets)}") v1 = _insert_groups(session, generate_v1_groups(buckets), "bucket_policy_v1") _merge_stats(all_stats, v1) - # V2 candidates = query_v2_candidates(db) print(f"V2 candidates: {len(candidates)}") if candidates: v2 = _insert_groups( session, - generate_v2_groups(candidates, args.num_dtypes), + generate_v2_groups(candidates, num_dtypes), "bucket_policy_v2", ) _merge_stats(all_stats, v2) @@ -276,7 +253,17 @@ def main(): print("=" * 60) 
_print_stats(all_stats) - print("\nDone!") + return all_stats + + +def main(): + parser = argparse.ArgumentParser( + description="Generate graph_net_sample_groups (v1 + v2)" + ) + parser.add_argument("--db_path", type=str, required=True) + parser.add_argument("--num_dtypes", type=int, default=3) + args = parser.parse_args() + generate_groups(args.db_path, args.num_dtypes) if __name__ == "__main__": From c865655a2854d0498645079fc0679fd7aa080b60 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 11 May 2026 11:50:58 +0800 Subject: [PATCH 10/11] Rename variables. --- sqlite/build_db.py | 3 +- sqlite/graph_sample_groups_insert.py | 76 +++++++++++++--------------- 2 files changed, 37 insertions(+), 42 deletions(-) diff --git a/sqlite/build_db.py b/sqlite/build_db.py index 889b468fb0..12ee6615d9 100755 --- a/sqlite/build_db.py +++ b/sqlite/build_db.py @@ -98,11 +98,10 @@ def main(args): f"fail={total - num_success}, order=[{order_start}, {order_value})" ) - print("all done") - from graph_sample_bucket_generator import generate_buckets from graph_sample_groups_insert import generate_groups + print("\nGenerate buckets and groups:") generate_buckets(db_path) generate_groups(db_path, num_dtypes=3) diff --git a/sqlite/graph_sample_groups_insert.py b/sqlite/graph_sample_groups_insert.py index 80f490f91d..aaf83adb11 100755 --- a/sqlite/graph_sample_groups_insert.py +++ b/sqlite/graph_sample_groups_insert.py @@ -38,11 +38,11 @@ def _print_stats(stats): for rule in rule_order: key = (sample_type, rule) if key in stats: - n_records = stats[key]["records"] - n_groups = len(stats[key]["groups"]) - print(f" {rule}: {n_records} records, {n_groups} groups") - total_records += n_records - total_groups += n_groups + record_count = stats[key]["records"] + group_count = len(stats[key]["groups"]) + print(f" {rule}: {record_count} records, {group_count} groups") + total_records += record_count + total_groups += group_count print(f"\n Total: {total_records} records, {total_groups} groups.") 
@@ -114,12 +114,10 @@ def query_v2_candidates(db: DB) -> list[Candidate]: # ═══════════════════════════════════════════════════════════════════ # V1: Rule 1 (bucket-internal stride sampling) + Rule 2 (cross-shape) # ═══════════════════════════════════════════════════════════════════ - - def generate_v1_groups(bucket_groups: list[BucketGroup]): """Yields (sample_type, uid, group_id, rule_name). - Rule 1: stride-16 sampling within each bucket, one group per sample. + Rule 1: stride-16 sampling within each bucket, 1 subgraph per group. Rule 2: aggregate all bucket heads sharing the same (sample_type, op_seq). """ # Rule 1 @@ -133,54 +131,55 @@ def generate_v1_groups(bucket_groups: list[BucketGroup]): heads_by_type_op = defaultdict(list) for bucket in bucket_groups: heads_by_type_op[(bucket.sample_type, bucket.op_seq)].append(bucket.head_uid) - for (sample_type, _op), heads in heads_by_type_op.items(): - gid = _new_group_id() + for (sample_type, _), heads in heads_by_type_op.items(): + group_id = _new_group_id() for uid in heads: - yield sample_type, uid, gid, "rule2" + yield sample_type, uid, group_id, "rule2" # ═══════════════════════════════════════════════════════════════════ # V2: Rule 4 (dtype coverage) + Rule 3 (sparse sampling on remainder) # ═══════════════════════════════════════════════════════════════════ - - def generate_v2_groups(candidates: list[Candidate], num_dtypes: int): """Yields (sample_type, uid, group_id, rule_name). Rule 4 (first): per (sample_type, op_seq, shape), pick up to - num_dtypes samples with distinct dtypes. + num_dtypes samples with distinct dtypes. Rule 3 (second): window-based sparse sampling on the remainder, window_size = num_dtypes * 5, pick first num_dtypes. 
""" by_type_op = defaultdict(list) - for c in candidates: - by_type_op[(c.sample_type, c.op_seq)].append(c) + for candidate in candidates: + by_type_op[(candidate.sample_type, candidate.op_seq)].append(candidate) covered_uids = set() # Rule 4: dtype coverage - for (sample_type, _op), group in by_type_op.items(): + for (sample_type, _), group in by_type_op.items(): by_shape = defaultdict(list) - for c in group: - by_shape[c.shapes].append(c) + for candidate in group: + by_shape[candidate.shapes].append(candidate) picked = [] - for _shape, shape_group in by_shape.items(): - seen_dtypes = set() - for c in shape_group: - if c.dtypes not in seen_dtypes and len(seen_dtypes) < num_dtypes: - seen_dtypes.add(c.dtypes) - picked.append(c.uid) - covered_uids.add(c.uid) + for _, shape_group in by_shape.items(): + picked_dtypes = set() + for candidate in shape_group: + if ( + candidate.dtypes not in picked_dtypes + and len(picked_dtypes) < num_dtypes + ): + picked_dtypes.add(candidate.dtypes) + picked.append(candidate.uid) + covered_uids.add(candidate.uid) if picked: - gid = _new_group_id() + group_id = _new_group_id() for uid in picked: - yield sample_type, uid, gid, "rule4" + yield sample_type, uid, group_id, "rule4" # Rule 3: sparse sampling on remainder window_size = num_dtypes * 5 - for (sample_type, _op), group in by_type_op.items(): + for (sample_type, _), group in by_type_op.items(): remaining = sorted( (c for c in group if c.uid not in covered_uids), key=lambda c: c.uid, @@ -189,14 +188,9 @@ def generate_v2_groups(candidates: list[Candidate], num_dtypes: int): c.uid for i, c in enumerate(remaining) if (i % window_size) < num_dtypes ] if picked: - gid = _new_group_id() + group_id = _new_group_id() for uid in picked: - yield sample_type, uid, gid, "rule3" - - -# ═══════════════════════════════════════════════════════════════════ -# Insert -# ═══════════════════════════════════════════════════════════════════ + yield sample_type, uid, group_id, "rule3" def 
_insert_groups(session, rows, policy): @@ -230,18 +224,20 @@ def generate_groups(db_path, num_dtypes=3): try: buckets = query_bucket_groups(db) print(f"Bucket groups: {len(buckets)}") - v1 = _insert_groups(session, generate_v1_groups(buckets), "bucket_policy_v1") - _merge_stats(all_stats, v1) + v1_stats = _insert_groups( + session, generate_v1_groups(buckets), "bucket_policy_v1" + ) + _merge_stats(all_stats, v1_stats) candidates = query_v2_candidates(db) print(f"V2 candidates: {len(candidates)}") if candidates: - v2 = _insert_groups( + v2_stats = _insert_groups( session, generate_v2_groups(candidates, num_dtypes), "bucket_policy_v2", ) - _merge_stats(all_stats, v2) + _merge_stats(all_stats, v2_stats) else: print("No V2 candidates found. Skipping.") except Exception: From 7d9ee77473e815168a77daef81e94b70d19eaff1 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 11 May 2026 14:50:13 +0800 Subject: [PATCH 11/11] Optimize session. --- sqlite/graphsample_insert.py | 291 +++++++++++++++-------------------- 1 file changed, 122 insertions(+), 169 deletions(-) diff --git a/sqlite/graphsample_insert.py b/sqlite/graphsample_insert.py index f3638d1bf1..c321c569ce 100755 --- a/sqlite/graphsample_insert.py +++ b/sqlite/graphsample_insert.py @@ -16,7 +16,6 @@ SampleInputTensorMeta, ) from sqlalchemy import delete as sql_delete -from sqlalchemy.exc import IntegrityError # graph_sample insert func @@ -44,76 +43,60 @@ def get_graph_sample_data( return data -def insert_graph_sample(db_path: str, data: dict, model_path_prefix: str): - session = get_session(db_path) - try: - graph_sample = GraphSample(**data) - session.add(graph_sample) - session.commit() - return graph_sample - except IntegrityError as e: - session.rollback() - raise e - finally: - session.close() +def insert_graph_sample(session, data): + graph_sample = GraphSample(**data) + session.add(graph_sample) + return graph_sample # subgraph source insert func def insert_subgraph_source( + session, subgraph_uuid: str, 
model_path_prefix: str, sample_type: str, relative_model_path: str, - db_path: str, ): - session = get_session(db_path) - try: - parent_relative_path = get_parent_relative_path(relative_model_path) - if sample_type == "fusible_graph" or sample_type == "typical_graph": - parent_parts = parent_relative_path.split("/") - parent_parts = parent_parts[2:] - parent_relative_path = "/".join(parent_parts) - if sample_type == "sole_op_graph": - parent_parts = parent_relative_path.split("/") - parent_parts = parent_parts[1:] - parent_relative_path = "/".join(parent_parts) - - full_graph = ( - session.query(GraphSample) - .filter( - GraphSample.relative_model_path == parent_relative_path, - GraphSample.sample_type == "full_graph", - ) - .first() + parent_relative_path = get_parent_relative_path(relative_model_path) + if sample_type == "fusible_graph" or sample_type == "typical_graph": + parent_parts = parent_relative_path.split("/") + parent_parts = parent_parts[2:] + parent_relative_path = "/".join(parent_parts) + if sample_type == "sole_op_graph": + parent_parts = parent_relative_path.split("/") + parent_parts = parent_parts[1:] + parent_relative_path = "/".join(parent_parts) + + full_graph = ( + session.query(GraphSample) + .filter( + GraphSample.relative_model_path == parent_relative_path, + GraphSample.sample_type == "full_graph", ) + .first() + ) - if not full_graph: - raise ValueError(f"Full graph not found for path: {parent_relative_path}") + if not full_graph: + raise ValueError(f"Full graph not found for path: {parent_relative_path}") - range_info = _get_parent_key_and_range(model_path_prefix, relative_model_path) - subgraph_source = SubgraphSource( - subgraph_uuid=subgraph_uuid, - full_graph_uuid=full_graph.uuid, - range_start=range_info["start"], - range_end=range_info["end"], - create_at=datetime.now(), - deleted=False, - delete_at=None, - ) - session.add(subgraph_source) - session.commit() + range_info = _get_parent_key_and_range(model_path_prefix, 
relative_model_path) + subgraph_source = SubgraphSource( + subgraph_uuid=subgraph_uuid, + full_graph_uuid=full_graph.uuid, + range_start=range_info["start"], + range_end=range_info["end"], + create_at=datetime.now(), + deleted=False, + delete_at=None, + ) + session.add(subgraph_source) - return { - "subgraph_uuid": subgraph_source.subgraph_uuid, - "full_graph_uuid": subgraph_source.full_graph_uuid, - "range_start": subgraph_source.range_start, - "range_end": subgraph_source.range_end, - } - except IntegrityError as e: - session.rollback() - raise e - finally: - session.close() + return { + "subgraph_uuid": subgraph_source.subgraph_uuid, + "full_graph_uuid": subgraph_source.full_graph_uuid, + "range_start": subgraph_source.range_start, + "range_end": subgraph_source.range_end, + } def get_parent_relative_path(relative_path: str) -> str: @@ -177,31 +160,23 @@ def _get_create_at() -> datetime: # DimensionGeneralizationSource insert func def insert_dimension_generalization_source( + session, generalized_graph_uuid: str, original_graph_uuid: str, model_path_prefix: str, relative_model_path: str, - db_path: str, ): - session = get_session(db_path) - try: - dimension_source = DimensionGeneralizationSource( - generalized_graph_uuid=generalized_graph_uuid, - original_graph_uuid=original_graph_uuid, - total_element_size=_get_total_element_size( - model_path_prefix, relative_model_path - ), - create_at=datetime.now(), - deleted=False, - delete_at=None, - ) - session.add(dimension_source) - session.commit() - except IntegrityError as e: - session.rollback() - raise e - finally: - session.close() + dimension_source = DimensionGeneralizationSource( + generalized_graph_uuid=generalized_graph_uuid, + original_graph_uuid=original_graph_uuid, + total_element_size=_get_total_element_size( + model_path_prefix, relative_model_path + ), + create_at=datetime.now(), + deleted=False, + delete_at=None, + ) + session.add(dimension_source) def _get_total_element_size(model_path_prefix: str, 
relative_model_path: str): @@ -233,29 +208,21 @@ def _get_total_element_size(model_path_prefix: str, relative_model_path: str): # DataTypeGeneralizationSource insert func def insert_datatype_generalization_source( + session, generalized_graph_uuid: str, original_graph_uuid: str, model_path_prefix: str, relative_model_path: str, - db_path: str, ): - session = get_session(db_path) - try: - data_type_source = DataTypeGeneralizationSource( - generalized_graph_uuid=generalized_graph_uuid, - original_graph_uuid=original_graph_uuid, - data_type=_get_data_type(model_path_prefix, relative_model_path), - create_at=datetime.now(), - deleted=False, - delete_at=None, - ) - session.add(data_type_source) - session.commit() - except IntegrityError as e: - session.rollback() - raise e - finally: - session.close() + data_type_source = DataTypeGeneralizationSource( + generalized_graph_uuid=generalized_graph_uuid, + original_graph_uuid=original_graph_uuid, + data_type=_get_data_type(model_path_prefix, relative_model_path), + create_at=datetime.now(), + deleted=False, + delete_at=None, + ) + session.add(data_type_source) def _get_data_type(model_path_prefix: str, relative_model_path: str): @@ -284,11 +251,11 @@ def _get_parent_key_and_range(model_path_prefix: str, relative_model_path: str) def insert_sample_op_name_list( + session, sample_uuid: str, model_path_prefix: str, op_names_path_prefix: str, relative_model_path: str, - db_path: str, ): if not op_names_path_prefix: print("op_names_path_prefix not provided, skipping insert_sample_op_name_list") @@ -333,54 +300,42 @@ def insert_sample_op_name_list( op_names_json = json.dumps( [{"op_name": name, "op_idx": i} for i, name in enumerate(selected_op_names)] ) - session = get_session(db_path) - try: - session.execute( - sql_delete(SampleOpNameList).where( - SampleOpNameList.sample_uuid == sample_uuid - ) - ) - session.execute( - sql_delete(SampleOpName).where(SampleOpName.sample_uuid == sample_uuid) - ) - sample_op_name_list = 
SampleOpNameList( + session.execute( + sql_delete(SampleOpNameList).where(SampleOpNameList.sample_uuid == sample_uuid) + ) + session.execute( + sql_delete(SampleOpName).where(SampleOpName.sample_uuid == sample_uuid) + ) + sample_op_name_list = SampleOpNameList( + sample_uuid=sample_uuid, + op_names_json=op_names_json, + create_at=datetime.now(), + deleted=False, + delete_at=None, + ) + session.add(sample_op_name_list) + + for idx, op_name in enumerate(selected_op_names): + sample_op_name = SampleOpName( sample_uuid=sample_uuid, - op_names_json=op_names_json, + op_name=op_name, + op_idx=idx, + op_size=op_size, create_at=datetime.now(), deleted=False, delete_at=None, ) - session.add(sample_op_name_list) + session.add(sample_op_name) - for idx, op_name in enumerate(selected_op_names): - sample_op_name = SampleOpName( - sample_uuid=sample_uuid, - op_name=op_name, - op_idx=idx, - op_size=op_size, - create_at=datetime.now(), - deleted=False, - delete_at=None, - ) - session.add(sample_op_name) - - session.commit() - print( - f"Inserted {len(selected_op_names)} op_names for sample_uuid={sample_uuid}" - ) - except IntegrityError as e: - session.rollback() - raise e - finally: - session.close() + print(f"Inserted {len(selected_op_names)} op_names for sample_uuid={sample_uuid}") # SampleInputTensorMeta insert func def insert_sample_input_tensor_meta( + session, sample_uuid: str, model_path_prefix: str, relative_model_path: str, - db_path: str, ): from graph_net.tensor_meta import TensorMeta @@ -408,15 +363,14 @@ def insert_sample_input_tensor_meta( print(f"No tensor meta found in {weight_meta_file}") return - session = get_session(db_path) - try: - session.execute( - sql_delete(SampleInputTensorMeta).where( - SampleInputTensorMeta.sample_uuid == sample_uuid - ) + session.execute( + sql_delete(SampleInputTensorMeta).where( + SampleInputTensorMeta.sample_uuid == sample_uuid ) - for meta in input_tensor_metas: - sample_input_tensor_meta = SampleInputTensorMeta( + ) + for meta in 
input_tensor_metas: + session.add( + SampleInputTensorMeta( sample_uuid=sample_uuid, input_name=meta["input_name"], input_idx=meta["input_idx"], @@ -426,18 +380,11 @@ def insert_sample_input_tensor_meta( deleted=False, delete_at=None, ) - session.add(sample_input_tensor_meta) - - session.commit() - print( - f"Inserted {len(input_tensor_metas)} input tensor meta(s) for sample_uuid={sample_uuid}" ) - except IntegrityError as e: - session.rollback() - print(f"Error inserting input tensor meta: {e}") - raise e - finally: - session.close() + + print( + f"Inserted {len(input_tensor_metas)} input tensor meta(s) for sample_uuid={sample_uuid}" + ) # main func @@ -465,56 +412,62 @@ def insert_one_sample( order_value=order_value, ) print(f"\ninsert into database: {db_path=}, {sample_type=}") + successed = True + session = get_session(db_path) try: - insert_graph_sample(db_path, data, model_path_prefix) + insert_graph_sample(session, data) if data["is_subgraph"]: subgraph_source_data = insert_subgraph_source( + session=session, subgraph_uuid=data["uuid"], model_path_prefix=model_path_prefix, sample_type=sample_type, relative_model_path=relative_model_path, - db_path=db_path, ) insert_sample_op_name_list( + session=session, sample_uuid=data["uuid"], model_path_prefix=model_path_prefix, op_names_path_prefix=op_names_path_prefix, relative_model_path=relative_model_path, - db_path=db_path, ) insert_sample_input_tensor_meta( + session=session, sample_uuid=data["uuid"], model_path_prefix=model_path_prefix, relative_model_path=relative_model_path, - db_path=db_path, ) if sample_type in ["fusible_graph", "typical_graph"]: insert_dimension_generalization_source( - subgraph_source_data["subgraph_uuid"], - subgraph_source_data["full_graph_uuid"], - model_path_prefix, - relative_model_path, - db_path, + session=session, + generalized_graph_uuid=subgraph_source_data["subgraph_uuid"], + original_graph_uuid=subgraph_source_data["full_graph_uuid"], + model_path_prefix=model_path_prefix, + 
relative_model_path=relative_model_path, ) insert_datatype_generalization_source( - subgraph_source_data["subgraph_uuid"], - subgraph_source_data["full_graph_uuid"], - model_path_prefix, - relative_model_path, - db_path, + session=session, + generalized_graph_uuid=subgraph_source_data["subgraph_uuid"], + original_graph_uuid=subgraph_source_data["full_graph_uuid"], + model_path_prefix=model_path_prefix, + relative_model_path=relative_model_path, ) + session.commit() print(f"insert {sample_type} success: {data['relative_model_path']}") except sqlite3.IntegrityError as e: + session.rollback() print( "insert {sample_type} failed: integrity error (possible duplicate uuid or graph_hash)" ) print(f"error info: {e}") - return False + successed = False except Exception as e: + session.rollback() print(f"insert {sample_type} failed: {e}") - return False - - return True + successed = False + finally: + session.close() + return successed def main(args):