diff --git a/graph_net/tools/generate_subgraph_dataset.sh b/graph_net/tools/generate_subgraph_dataset.sh index 208ae0d182..e29a0cbb33 100755 --- a/graph_net/tools/generate_subgraph_dataset.sh +++ b/graph_net/tools/generate_subgraph_dataset.sh @@ -93,36 +93,6 @@ function generate_subgraph_list() { | tee $sample_list } -function insert_graph_sample(){ - local target_dir="$1" - local repo_uid="$2" - local sample_type="$3" - local sample_list="$4" - echo ">>> [0] Inserting samples into database: ${DB_PATH}." - echo ">>>" - - if [ ! -f "$DB_PATH" ]; then - echo "Fail ! No Database ! : $DB_PATH" - exit 1 - fi - - local order_value=0 - while IFS= read -r model_rel_path; do - echo "insert : $model_rel_path" - python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \ - --model_path_prefix "${target_dir}" \ - --relative_model_path "$model_rel_path" \ - --repo_uid "${repo_uid}" \ - --sample_type "${sample_type}" \ - --order_value "$order_value" \ - --db_path "$DB_PATH" - - ((order_value++)) - - done < "$sample_list" -} - - function rewrite_device() { echo ">>> [1] Rewrite devices for subgraph samples under ${GRAPH_NET_ROOT}." echo ">>>" @@ -658,29 +628,6 @@ function generate_typical_subgraphs() { # generate_unittest_for_typical_subgraphs 2>&1 | tee ${DECOMPOSE_WORKSPACE}/log_unittests_typical_subgraphs_${suffix}.txt } -function generate_database() { - timestamp=`date +%Y%m%d_%H%M` - - # init database - if [ ! 
-f ${DB_PATH} ]; then - python ${GRAPH_NET_ROOT}/sqlite/init_db.py --db_path ${DB_PATH} 2>&1 | tee ${DECOMPOSE_WORKSPACE}/log_init_db_${timestamp}.txt - fi - - # full_graph - insert_graph_sample ${GRAPH_NET_ROOT} "hf_torch_samples" "full_graph" ${model_list} - - # fusible_graph, typical_graph - for sample_type in fusible_graph typical_graph; do - insert_graph_sample $OUTPUT_DIR/$sample_type "hf_torch_samples" $sample_type $OUTPUT_DIR/${sample_type}/sample_list.txt - done - - # insert buckets - python ${GRAPH_NET_ROOT}/sqlite/graph_net_sample_bucket_generator.py --db_path ${DB_PATH} - - # insert groups - python ${GRAPH_NET_ROOT}/sqlite/graph_net_sample_groups_insert.py --db_path ${DB_PATH} -} - function main() { do_common_generalzation_and_decompose @@ -693,8 +640,6 @@ function main() { generate_typical_subgraphs #cp -rf $DTYPE_GENERALIZED_TYPICAL_SUBGRAPH_DIR $OUTPUT_DIR/$sample_type #cp -rf $dtype_generalized_typical_subgraph_list $OUTPUT_DIR/$sample_type/sample_list.txt - - #generate_database } function summary() { diff --git a/sqlite/README.md b/sqlite/README.md new file mode 100755 index 0000000000..58cc59f1b5 --- /dev/null +++ b/sqlite/README.md @@ -0,0 +1,135 @@ +# GraphNet SQLite 操作指南 + +## 目录结构 + +``` +sqlite/ +├── migrates/ # SQL 迁移文件(按时间戳顺序执行) +├── orm_models.py # SQLAlchemy ORM 模型定义 +├── init_db.py # 数据库初始化 +├── build_db.py # 批量建库(推荐) +├── graphsample_insert.py # 单条样本插入 +├── graphsample_delete.py # 单条样本删除(软删除) +├── merge_db.py # 数据库合并 +├── graph_sample_bucket_generator.py # 样本分桶元数据生成 +├── graph_sample_groups_insert.py # 采样分组生成 +├── upload_dataset.py # HuggingFace 上传 +├── download_dataset.py # HuggingFace 下载 +``` + +## 数据表概览 + +| 表名 | 用途 | +|------|------| +| `repo` | 仓库源信息 | +| `graph_sample` | 计算图样本主表 | +| `subgraph_source` | 子图来源映射 | +| `dimension_generalization_source` | 维度泛化来源 | +| `datatype_generalization_source` | 数据类型泛化来源 | +| `backward_graph_source` | 反向图来源 | +| `sample_op_name` / `sample_op_name_list` | 算子名称序列 | +| `sample_input_tensor_meta` | 
输入张量元信息 | +| `graph_net_sample_buckets` | 样本分桶元数据 | +| `graph_net_sample_groups` | 采样分组 | + +所有删除操作均为软删除(`deleted` 字段标记),不物理删除数据。 + +## 数据库初始化 + +从 `migrates/` 目录按时间戳顺序执行 SQL 文件,创建所有表结构。**库文件已存在时会被删除重建。** + +```bash +# 默认路径 GraphNet.db +python init_db.py 2>&1 | tee logs/init_db_$(date +"%Y%m%d_%H%M%S").log + +# 自定义路径 +python init_db.py --db_path xxx.db +``` + +## 批量建库(推荐) + +一次性处理 `full_graph`、`typical_graph`、`fusible_graph`、`sole_op_graph` 四种样本类型:优先读取 `dataset_root/{sample_type}_list.txt`,list 文件不存在时自动收集 `dataset_root/{sample_type}/` 下包含 `model.py` 的目录,然后逐条插入。`full_graph` 目录必须存在,其余类型的目录不存在时跳过。库文件不存在时自动初始化;全部插入完成后会自动生成分桶与分组(`num_dtypes=3`)。 + +```bash +python build_db.py \ + --db_path GraphNet.db \ + --dataset_root /path/to/dataset \ + --repo_uid "hf_torch_samples" \ + --op_names_path_prefix /path/to/sample_op_names +``` + +## 单条样本操作 + +```bash +# 插入单条 +python graphsample_insert.py \ + --model_path_prefix /path/to/dataset/full_graph \ + --relative_model_path models/torch/resnet18 \ + --repo_uid "hf_torch_samples" \ + --sample_type "full_graph" \ + --order_value 0 \ + --db_path GraphNet.db + +# 删除单条(软删除,设置 deleted=1) +python graphsample_delete.py \ + --db_path GraphNet.db \ + --repo_uid "hf_torch_samples" \ + --relative_model_path "models/torch/resnet18" +``` + +## Shell 批量脚本 + +```bash +# 批量插入(从 list 文件逐行读取) +bash graphsample_insert.sh [db_path] + +# 批量删除(从 graph_net/config/delete_list.txt 读取) +bash graphsample_delete.sh [db_path] +``` + +## 数据库合并 + +将新库的所有记录合并到主库,自动跳过已存在的 repo 和 graph_sample。 + +```bash +python merge_db.py \ + --main_db_path GraphNet.db \ + --new_db_path new.db +``` + +## 样本分桶与分组 + +分两步:先生成样本的分桶元数据(op 序列、input shape、dtype 的哈希 ID),再基于分桶结果按策略生成采样分组。 + +```bash +# 生成分桶元数据 → 写入 graph_net_sample_buckets 表 +python graph_sample_bucket_generator.py --db_path GraphNet.db + +# 生成采样分组 → 写入 graph_net_sample_groups 表 +# 策略: bucket_policy_v1 (stride-16 + cross-shape) + bucket_policy_v2 (dtype coverage + sparse) +python graph_sample_groups_insert.py --db_path GraphNet.db --num_dtypes 3 +``` + +## HuggingFace 上传/下载 + +```bash +# 上传:打包 dataset 目录 + GraphNet.db 到 HF Hub +python upload_dataset.py \ + --hf_token "$HF_TOKEN" \
--base_dir /path/to/dataset \ + --repo_id "PaddlePaddle/GraphNet" \ + --revision "20260203" \ + --split "GraphNet" + +# 下载:从 HF Hub 拉取 dataset 和 GraphNet.db +python download_dataset.py \ + --repo_id "PaddlePaddle/GraphNet" \ + --revision "20260224" \ + --save_dir ./workspace \ + --split "GraphNet" +``` + +## 关联资源 + +- ORM 模型定义: [orm_models.py](orm_models.py) +- SQL 迁移文件: [migrates/](migrates/) \ No newline at end of file diff --git a/sqlite/Readme.md b/sqlite/Readme.md deleted file mode 100755 index dd181bf4d2..0000000000 --- a/sqlite/Readme.md +++ /dev/null @@ -1,55 +0,0 @@ -# SQLite Operations Guide - -**Working Directory:** `/GraphNet` - -## Setup - -```bash -# Create logs directory -mkdir -p sqlite/logs -``` - -## Initialize Database - -```bash -# Default DB path (sqlite/GraphNet.db) -python ./sqlite/init_db.py 2>&1 | tee sqlite/logs/init_db_$(date +"%Y%m%d_%H%M%S").log - -# Custom DB path -python ./sqlite/init_db.py --db_path sqlite/xxx.db 2>&1 | tee sqlite/logs/init_db_$(date +"%Y%m%d_%H%M%S").log -``` - -## Insert Graph Samples - -```bash -# Insert to custom DB -bash ./sqlite/graphsample_insert.sh sqlite/xxx.db 2>&1 | tee sqlite/logs/insert_$(date +"%Y%m%d_%H%M%S").log - -# Insert to default DB (sqlite/GraphNet.db) -bash ./sqlite/graphsample_insert.sh 2>&1 | tee sqlite/logs/insert_$(date +"%Y%m%d_%H%M%S").log -``` - -## Delete Graph Samples - -```bash -# Delete from custom DB -bash ./sqlite/graphsample_delete.sh sqlite/xxx.db 2>&1 | tee sqlite/logs/delete_$(date +"%Y%m%d_%H%M%S").log - -# Delete from default DB (sqlite/GraphNet.db) -bash ./sqlite/graphsample_delete.sh 2>&1 | tee sqlite/logs/delete_$(date +"%Y%m%d_%H%M%S").log -``` - -## Merge Databases - -```bash -# Usage: python ./sqlite/merge_db.py --main_db_path --new_db_path -python ./sqlite/merge_db.py --main_db_path sqlite/GraphNet.db --new_db_path sqlite/new.db -``` - -## Upload to Hugging Face - -```bash -python ./sqlite/upload.py -``` - -**Note:** Set `HF_TOKEN` variable in `upload.py` before 
running. diff --git a/sqlite/build_db.py b/sqlite/build_db.py new file mode 100755 index 0000000000..12ee6615d9 --- /dev/null +++ b/sqlite/build_db.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +import argparse +import os +import sys + +from graphsample_insert import insert_one_sample +from init_db import migrate + + +def collect_sample_paths(model_path_prefix): + """Collect relative paths of directories containing model.py under model_path_prefix.""" + sample_paths = [] + for dirpath, _, filenames in os.walk(model_path_prefix): + if "model.py" in filenames: + rel = os.path.relpath(dirpath, model_path_prefix) + sample_paths.append(rel) + sample_paths.sort() + return sample_paths + + +def insert_from_list( + list_file_path, + model_path_prefix, + sample_type, + repo_uid, + db_path, + op_names_path_prefix, + start_order=0, +): + if os.path.isfile(list_file_path): + with open(list_file_path) as f: + sample_paths = [line.strip() for line in f if line.strip()] + sample_paths.sort() + else: + print( + f"List file not found: {list_file_path}, collecting from {model_path_prefix}" + ) + sample_paths = collect_sample_paths(model_path_prefix) + + total = len(sample_paths) + order_value = start_order + for relative_model_path in sample_paths: + print(f"insert : {relative_model_path}") + successed = insert_one_sample( + model_path_prefix=model_path_prefix, + relative_model_path=relative_model_path, + repo_uid=repo_uid, + sample_type=sample_type, + order_value=order_value, + db_path=db_path, + op_names_path_prefix=op_names_path_prefix, + ) + if successed: + order_value += 1 + assert successed + + return order_value, total + + +def main(args): + dataset_root = args.dataset_root.strip() + db_path = args.db_path.strip() + repo_uid = args.repo_uid.strip() + op_names_path_prefix = args.op_names_path_prefix.strip() + + if not os.path.exists(db_path): + migrate(db_path) + + print(f"db_path={db_path}, repo_uid={repo_uid}") + order_value = 0 + + sample_types = ["full_graph", 
"typical_graph", "fusible_graph", "sole_op_graph"] + for sample_type in sample_types: + model_path_prefix = os.path.join(dataset_root, sample_type) + list_file_path = os.path.join(dataset_root, f"{sample_type}_list.txt") + print(f"\n[{sample_type}] samples={model_path_prefix}, list={list_file_path}") + + if not os.path.isdir(model_path_prefix): + if sample_type == "full_graph": + print(f"Fail ! full_graph directory not found: {model_path_prefix}") + sys.exit(1) + print(f"[{sample_type}] skipped, directory not found") + continue + + order_start = order_value + order_value, total = insert_from_list( + list_file_path=list_file_path, + model_path_prefix=model_path_prefix, + sample_type=sample_type, + repo_uid=repo_uid, + db_path=db_path, + op_names_path_prefix=op_names_path_prefix, + start_order=order_value, + ) + num_success = order_value - order_start + print( + f"[{sample_type}] total={total}, success={num_success}, " + f"fail={total - num_success}, order=[{order_start}, {order_value})" + ) + + from graph_sample_bucket_generator import generate_buckets + from graph_sample_groups_insert import generate_groups + + print("\nGenerate buckets and groups:") + generate_buckets(db_path) + generate_groups(db_path, num_dtypes=3) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Batch insert graph samples from list files" + ) + parser.add_argument( + "--db_path", + type=str, + required=True, + help="Database file path", + ) + parser.add_argument( + "--dataset_root", + type=str, + required=True, + help="Dataset root directory", + ) + parser.add_argument( + "--repo_uid", + type=str, + default="hf_torch_samples", + help="Repository uid", + ) + parser.add_argument( + "--op_names_path_prefix", + type=str, + required=True, + help="Path prefix of op names files", + ) + args = parser.parse_args() + main(args) diff --git a/sqlite/download.py b/sqlite/download.py deleted file mode 100755 index d9cec2c722..0000000000 --- a/sqlite/download.py +++ /dev/null 
@@ -1,22 +0,0 @@ -import os -from datasets import load_dataset -from huggingface_hub import hf_hub_download - -REPO_ID = "PaddlePaddle/GraphNet" -REVISION = "20260224" -SAVE_DIR = "./workspace" - -ds = load_dataset(REPO_ID, split="GraphNet", revision=REVISION) -for item in ds: - full_path = os.path.join(SAVE_DIR, item["path"]) - os.makedirs(os.path.dirname(full_path), exist_ok=True) - with open(full_path, "w", encoding="utf-8") as f: - f.write(item["content"]) - -hf_hub_download( - repo_id=REPO_ID, - filename="GraphNet.db", - repo_type="dataset", - revision=REVISION, - local_dir=SAVE_DIR, -) diff --git a/sqlite/download_dataset.py b/sqlite/download_dataset.py new file mode 100755 index 0000000000..9e2702b1e6 --- /dev/null +++ b/sqlite/download_dataset.py @@ -0,0 +1,44 @@ +import argparse +import os +from datasets import load_dataset +from huggingface_hub import hf_hub_download + + +def main(args): + ds = load_dataset(args.repo_id, split=args.split, revision=args.revision) + for item in ds: + full_path = os.path.join(args.save_dir, item["path"]) + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, "w", encoding="utf-8") as f: + f.write(item["content"]) + + hf_hub_download( + repo_id=args.repo_id, + filename=args.db_file, + repo_type="dataset", + revision=args.revision, + local_dir=args.save_dir, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Download dataset and DB from HuggingFace Hub" + ) + parser.add_argument( + "--repo_id", type=str, default="PaddlePaddle/GraphNet", help="HF repo ID" + ) + parser.add_argument( + "--revision", type=str, default="main", help="HF repo revision/branch" + ) + parser.add_argument( + "--save_dir", type=str, default="./workspace", help="Local save directory" + ) + parser.add_argument( + "--split", type=str, default="GraphNet", help="Dataset split name" + ) + parser.add_argument( + "--db_file", type=str, default="GraphNet.db", help="DB filename to download" + ) + args = 
parser.parse_args() + main(args) diff --git a/sqlite/graph_net_sample_bucket_generator.py b/sqlite/graph_sample_bucket_generator.py similarity index 85% rename from sqlite/graph_net_sample_bucket_generator.py rename to sqlite/graph_sample_bucket_generator.py index 0a690e4c67..dbbf3b2682 100644 --- a/sqlite/graph_net_sample_bucket_generator.py +++ b/sqlite/graph_sample_bucket_generator.py @@ -213,52 +213,36 @@ def save_bucket_results( return count -def main(): - parser = argparse.ArgumentParser( - description="Generate graph_net_sample_buckets from graph_sample" - ) - parser.add_argument( - "--db_path", - type=str, - required=True, - help="Path to the SQLite database file", - ) - parser.add_argument( - "--dry_run", - action="store_true", - help="Only print what would be done, don't actually insert into database", - ) - - args = parser.parse_args() - - session = get_session(args.db_path) - - print("=" * 70) - print("Step 1: Generating bucket info from graph_sample...") - sample_type_results, all_bucket_info_map = generate_sample_buckets(session) - print(f" Total samples: {len(all_bucket_info_map)}") - print(f" Number of sample_types: {len(sample_type_results)}") - - print() - for result in sorted(sample_type_results, key=lambda x: -len(x)): - flag = " [sole-op]" if result.is_sole_op else "" - print(f" {result.sample_type}{flag}: {len(result)} samples") +def generate_buckets(db_path, dry_run=False): + """Generate buckets and save to DB.""" + session = get_session(db_path) + try: + print("=" * 70) + print("Step 1: Generating bucket info from graph_sample...") + _, bucket_info_map = generate_sample_buckets(session) + if dry_run: + print("Dry run mode - skipping database insert") + print( + f" Would insert {len(bucket_info_map)} records into graph_net_sample_buckets" + ) + return 0 - print("=" * 70) - if args.dry_run: - print("Dry run mode - skipping database insert") - print( - f" Would insert {len(all_bucket_info_map)} records into graph_net_sample_buckets" - ) - else: 
print("Step 2: Saving to database...") - count = save_bucket_results(session, all_bucket_info_map) + count = save_bucket_results(session, bucket_info_map) print(f" Inserted {count} records into graph_net_sample_buckets") + return count + finally: + session.close() - print("=" * 70) - print("Done!") - session.close() +def main(): + parser = argparse.ArgumentParser( + description="Generate graph_net_sample_buckets from graph_sample" + ) + parser.add_argument("--db_path", type=str, required=True) + parser.add_argument("--dry_run", action="store_true") + args = parser.parse_args() + generate_buckets(args.db_path, dry_run=args.dry_run) if __name__ == "__main__": diff --git a/sqlite/graph_net_sample_groups_insert.py b/sqlite/graph_sample_groups_insert.py similarity index 72% rename from sqlite/graph_net_sample_groups_insert.py rename to sqlite/graph_sample_groups_insert.py index d22e64e231..aaf83adb11 100755 --- a/sqlite/graph_net_sample_groups_insert.py +++ b/sqlite/graph_sample_groups_insert.py @@ -7,8 +7,6 @@ from orm_models import get_session, GraphNetSampleGroup -# ── Types ── - BucketGroup = namedtuple( "BucketGroup", ["head_uid", "op_seq", "shapes", "sample_type", "all_uids_csv"], @@ -20,9 +18,6 @@ ) -# ── Helpers ── - - def _new_group_id(): return str(uuid_module.uuid4()) @@ -43,17 +38,14 @@ def _print_stats(stats): for rule in rule_order: key = (sample_type, rule) if key in stats: - n_records = stats[key]["records"] - n_groups = len(stats[key]["groups"]) - print(f" {rule}: {n_records} records, {n_groups} groups") - total_records += n_records - total_groups += n_groups + record_count = stats[key]["records"] + group_count = len(stats[key]["groups"]) + print(f" {rule}: {record_count} records, {group_count} groups") + total_records += record_count + total_groups += group_count print(f"\n Total: {total_records} records, {total_groups} groups.") -# ── Database Queries ── - - class DB: def __init__(self, path): self.path = path @@ -122,12 +114,10 @@ def 
query_v2_candidates(db: DB) -> list[Candidate]: # ═══════════════════════════════════════════════════════════════════ # V1: Rule 1 (bucket-internal stride sampling) + Rule 2 (cross-shape) # ═══════════════════════════════════════════════════════════════════ - - def generate_v1_groups(bucket_groups: list[BucketGroup]): """Yields (sample_type, uid, group_id, rule_name). - Rule 1: stride-16 sampling within each bucket, one group per sample. + Rule 1: stride-16 sampling within each bucket, 1 subgraph per group. Rule 2: aggregate all bucket heads sharing the same (sample_type, op_seq). """ # Rule 1 @@ -141,54 +131,55 @@ def generate_v1_groups(bucket_groups: list[BucketGroup]): heads_by_type_op = defaultdict(list) for bucket in bucket_groups: heads_by_type_op[(bucket.sample_type, bucket.op_seq)].append(bucket.head_uid) - for (sample_type, _op), heads in heads_by_type_op.items(): - gid = _new_group_id() + for (sample_type, _), heads in heads_by_type_op.items(): + group_id = _new_group_id() for uid in heads: - yield sample_type, uid, gid, "rule2" + yield sample_type, uid, group_id, "rule2" # ═══════════════════════════════════════════════════════════════════ # V2: Rule 4 (dtype coverage) + Rule 3 (sparse sampling on remainder) # ═══════════════════════════════════════════════════════════════════ - - def generate_v2_groups(candidates: list[Candidate], num_dtypes: int): """Yields (sample_type, uid, group_id, rule_name). Rule 4 (first): per (sample_type, op_seq, shape), pick up to - num_dtypes samples with distinct dtypes. + num_dtypes samples with distinct dtypes. Rule 3 (second): window-based sparse sampling on the remainder, window_size = num_dtypes * 5, pick first num_dtypes. 
""" by_type_op = defaultdict(list) - for c in candidates: - by_type_op[(c.sample_type, c.op_seq)].append(c) + for candidate in candidates: + by_type_op[(candidate.sample_type, candidate.op_seq)].append(candidate) covered_uids = set() # Rule 4: dtype coverage - for (sample_type, _op), group in by_type_op.items(): + for (sample_type, _), group in by_type_op.items(): by_shape = defaultdict(list) - for c in group: - by_shape[c.shapes].append(c) + for candidate in group: + by_shape[candidate.shapes].append(candidate) picked = [] - for _shape, shape_group in by_shape.items(): - seen_dtypes = set() - for c in shape_group: - if c.dtypes not in seen_dtypes and len(seen_dtypes) < num_dtypes: - seen_dtypes.add(c.dtypes) - picked.append(c.uid) - covered_uids.add(c.uid) + for _, shape_group in by_shape.items(): + picked_dtypes = set() + for candidate in shape_group: + if ( + candidate.dtypes not in picked_dtypes + and len(picked_dtypes) < num_dtypes + ): + picked_dtypes.add(candidate.dtypes) + picked.append(candidate.uid) + covered_uids.add(candidate.uid) if picked: - gid = _new_group_id() + group_id = _new_group_id() for uid in picked: - yield sample_type, uid, gid, "rule4" + yield sample_type, uid, group_id, "rule4" # Rule 3: sparse sampling on remainder window_size = num_dtypes * 5 - for (sample_type, _op), group in by_type_op.items(): + for (sample_type, _), group in by_type_op.items(): remaining = sorted( (c for c in group if c.uid not in covered_uids), key=lambda c: c.uid, @@ -197,14 +188,9 @@ def generate_v2_groups(candidates: list[Candidate], num_dtypes: int): c.uid for i, c in enumerate(remaining) if (i % window_size) < num_dtypes ] if picked: - gid = _new_group_id() + group_id = _new_group_id() for uid in picked: - yield sample_type, uid, gid, "rule3" - - -# ═══════════════════════════════════════════════════════════════════ -# Insert -# ═══════════════════════════════════════════════════════════════════ + yield sample_type, uid, group_id, "rule3" def 
_insert_groups(session, rows, policy): @@ -229,42 +215,29 @@ def _insert_groups(session, rows, policy): return stats -# ═══════════════════════════════════════════════════════════════════ -# Main -# ═══════════════════════════════════════════════════════════════════ - - -def main(): - parser = argparse.ArgumentParser( - description="Generate graph_net_sample_groups (v1 + v2)" - ) - parser.add_argument("--db_path", type=str, required=True) - parser.add_argument("--num_dtypes", type=int, default=3) - args = parser.parse_args() - - db = DB(args.db_path) +def generate_groups(db_path, num_dtypes=3): + """Generate sample groups and save to DB.""" + db = DB(db_path) db.connect() - session = get_session(args.db_path) - + session = get_session(db_path) all_stats = defaultdict(lambda: {"records": 0, "groups": set()}) - try: - # V1 buckets = query_bucket_groups(db) print(f"Bucket groups: {len(buckets)}") - v1 = _insert_groups(session, generate_v1_groups(buckets), "bucket_policy_v1") - _merge_stats(all_stats, v1) + v1_stats = _insert_groups( + session, generate_v1_groups(buckets), "bucket_policy_v1" + ) + _merge_stats(all_stats, v1_stats) - # V2 candidates = query_v2_candidates(db) print(f"V2 candidates: {len(candidates)}") if candidates: - v2 = _insert_groups( + v2_stats = _insert_groups( session, - generate_v2_groups(candidates, args.num_dtypes), + generate_v2_groups(candidates, num_dtypes), "bucket_policy_v2", ) - _merge_stats(all_stats, v2) + _merge_stats(all_stats, v2_stats) else: print("No V2 candidates found. 
Skipping.") except Exception: @@ -276,7 +249,17 @@ def main(): print("=" * 60) _print_stats(all_stats) - print("\nDone!") + return all_stats + + +def main(): + parser = argparse.ArgumentParser( + description="Generate graph_net_sample_groups (v1 + v2)" + ) + parser.add_argument("--db_path", type=str, required=True) + parser.add_argument("--num_dtypes", type=int, default=3) + args = parser.parse_args() + generate_groups(args.db_path, args.num_dtypes) if __name__ == "__main__": diff --git a/sqlite/graphsample_insert.py b/sqlite/graphsample_insert.py index afb40816ab..c321c569ce 100755 --- a/sqlite/graphsample_insert.py +++ b/sqlite/graphsample_insert.py @@ -16,7 +16,6 @@ SampleInputTensorMeta, ) from sqlalchemy import delete as sql_delete -from sqlalchemy.exc import IntegrityError # graph_sample insert func @@ -44,76 +43,60 @@ def get_graph_sample_data( return data -def insert_graph_sample(db_path: str, data: dict, model_path_prefix: str): - session = get_session(db_path) - try: - graph_sample = GraphSample(**data) - session.add(graph_sample) - session.commit() - return graph_sample - except IntegrityError as e: - session.rollback() - raise e - finally: - session.close() +def insert_graph_sample(session, data): + graph_sample = GraphSample(**data) + session.add(graph_sample) + return graph_sample # subgraph source insert func def insert_subgraph_source( + session, subgraph_uuid: str, model_path_prefix: str, sample_type: str, relative_model_path: str, - db_path: str, ): - session = get_session(db_path) - try: - parent_relative_path = get_parent_relative_path(relative_model_path) - if sample_type == "fusible_graph" or sample_type == "typical_graph": - parent_parts = parent_relative_path.split("/") - parent_parts = parent_parts[2:] - parent_relative_path = "/".join(parent_parts) - if sample_type == "sole_op_graph": - parent_parts = parent_relative_path.split("/") - parent_parts = parent_parts[1:] - parent_relative_path = "/".join(parent_parts) - - full_graph = ( - 
session.query(GraphSample) - .filter( - GraphSample.relative_model_path == parent_relative_path, - GraphSample.sample_type == "full_graph", - ) - .first() + parent_relative_path = get_parent_relative_path(relative_model_path) + if sample_type == "fusible_graph" or sample_type == "typical_graph": + parent_parts = parent_relative_path.split("/") + parent_parts = parent_parts[2:] + parent_relative_path = "/".join(parent_parts) + if sample_type == "sole_op_graph": + parent_parts = parent_relative_path.split("/") + parent_parts = parent_parts[1:] + parent_relative_path = "/".join(parent_parts) + + full_graph = ( + session.query(GraphSample) + .filter( + GraphSample.relative_model_path == parent_relative_path, + GraphSample.sample_type == "full_graph", ) + .first() + ) - if not full_graph: - raise ValueError(f"Full graph not found for path: {parent_relative_path}") + if not full_graph: + raise ValueError(f"Full graph not found for path: {parent_relative_path}") - range_info = _get_parent_key_and_range(model_path_prefix, relative_model_path) - subgraph_source = SubgraphSource( - subgraph_uuid=subgraph_uuid, - full_graph_uuid=full_graph.uuid, - range_start=range_info["start"], - range_end=range_info["end"], - create_at=datetime.now(), - deleted=False, - delete_at=None, - ) - session.add(subgraph_source) - session.commit() + range_info = _get_parent_key_and_range(model_path_prefix, relative_model_path) + subgraph_source = SubgraphSource( + subgraph_uuid=subgraph_uuid, + full_graph_uuid=full_graph.uuid, + range_start=range_info["start"], + range_end=range_info["end"], + create_at=datetime.now(), + deleted=False, + delete_at=None, + ) + session.add(subgraph_source) - return { - "subgraph_uuid": subgraph_source.subgraph_uuid, - "full_graph_uuid": subgraph_source.full_graph_uuid, - "range_start": subgraph_source.range_start, - "range_end": subgraph_source.range_end, - } - except IntegrityError as e: - session.rollback() - raise e - finally: - session.close() + return { + 
"subgraph_uuid": subgraph_source.subgraph_uuid, + "full_graph_uuid": subgraph_source.full_graph_uuid, + "range_start": subgraph_source.range_start, + "range_end": subgraph_source.range_end, + } def get_parent_relative_path(relative_path: str) -> str: @@ -177,31 +160,23 @@ def _get_create_at() -> datetime: # DimensionGeneralizationSource insert func def insert_dimension_generalization_source( + session, generalized_graph_uuid: str, original_graph_uuid: str, model_path_prefix: str, relative_model_path: str, - db_path: str, ): - session = get_session(db_path) - try: - dimension_source = DimensionGeneralizationSource( - generalized_graph_uuid=generalized_graph_uuid, - original_graph_uuid=original_graph_uuid, - total_element_size=_get_total_element_size( - model_path_prefix, relative_model_path - ), - create_at=datetime.now(), - deleted=False, - delete_at=None, - ) - session.add(dimension_source) - session.commit() - except IntegrityError as e: - session.rollback() - raise e - finally: - session.close() + dimension_source = DimensionGeneralizationSource( + generalized_graph_uuid=generalized_graph_uuid, + original_graph_uuid=original_graph_uuid, + total_element_size=_get_total_element_size( + model_path_prefix, relative_model_path + ), + create_at=datetime.now(), + deleted=False, + delete_at=None, + ) + session.add(dimension_source) def _get_total_element_size(model_path_prefix: str, relative_model_path: str): @@ -233,29 +208,21 @@ def _get_total_element_size(model_path_prefix: str, relative_model_path: str): # DataTypeGeneralizationSource insert func def insert_datatype_generalization_source( + session, generalized_graph_uuid: str, original_graph_uuid: str, model_path_prefix: str, relative_model_path: str, - db_path: str, ): - session = get_session(db_path) - try: - data_type_source = DataTypeGeneralizationSource( - generalized_graph_uuid=generalized_graph_uuid, - original_graph_uuid=original_graph_uuid, - data_type=_get_data_type(model_path_prefix, 
relative_model_path), - create_at=datetime.now(), - deleted=False, - delete_at=None, - ) - session.add(data_type_source) - session.commit() - except IntegrityError as e: - session.rollback() - raise e - finally: - session.close() + data_type_source = DataTypeGeneralizationSource( + generalized_graph_uuid=generalized_graph_uuid, + original_graph_uuid=original_graph_uuid, + data_type=_get_data_type(model_path_prefix, relative_model_path), + create_at=datetime.now(), + deleted=False, + delete_at=None, + ) + session.add(data_type_source) def _get_data_type(model_path_prefix: str, relative_model_path: str): @@ -284,11 +251,11 @@ def _get_parent_key_and_range(model_path_prefix: str, relative_model_path: str) def insert_sample_op_name_list( + session, sample_uuid: str, model_path_prefix: str, op_names_path_prefix: str, relative_model_path: str, - db_path: str, ): if not op_names_path_prefix: print("op_names_path_prefix not provided, skipping insert_sample_op_name_list") @@ -333,54 +300,42 @@ def insert_sample_op_name_list( op_names_json = json.dumps( [{"op_name": name, "op_idx": i} for i, name in enumerate(selected_op_names)] ) - session = get_session(db_path) - try: - session.execute( - sql_delete(SampleOpNameList).where( - SampleOpNameList.sample_uuid == sample_uuid - ) - ) - session.execute( - sql_delete(SampleOpName).where(SampleOpName.sample_uuid == sample_uuid) - ) - sample_op_name_list = SampleOpNameList( + session.execute( + sql_delete(SampleOpNameList).where(SampleOpNameList.sample_uuid == sample_uuid) + ) + session.execute( + sql_delete(SampleOpName).where(SampleOpName.sample_uuid == sample_uuid) + ) + sample_op_name_list = SampleOpNameList( + sample_uuid=sample_uuid, + op_names_json=op_names_json, + create_at=datetime.now(), + deleted=False, + delete_at=None, + ) + session.add(sample_op_name_list) + + for idx, op_name in enumerate(selected_op_names): + sample_op_name = SampleOpName( sample_uuid=sample_uuid, - op_names_json=op_names_json, + op_name=op_name, + 
op_idx=idx, + op_size=op_size, create_at=datetime.now(), deleted=False, delete_at=None, ) - session.add(sample_op_name_list) + session.add(sample_op_name) - for idx, op_name in enumerate(selected_op_names): - sample_op_name = SampleOpName( - sample_uuid=sample_uuid, - op_name=op_name, - op_idx=idx, - op_size=op_size, - create_at=datetime.now(), - deleted=False, - delete_at=None, - ) - session.add(sample_op_name) - - session.commit() - print( - f"Inserted {len(selected_op_names)} op_names for sample_uuid={sample_uuid}" - ) - except IntegrityError as e: - session.rollback() - raise e - finally: - session.close() + print(f"Inserted {len(selected_op_names)} op_names for sample_uuid={sample_uuid}") # SampleInputTensorMeta insert func def insert_sample_input_tensor_meta( + session, sample_uuid: str, model_path_prefix: str, relative_model_path: str, - db_path: str, ): from graph_net.tensor_meta import TensorMeta @@ -408,15 +363,14 @@ def insert_sample_input_tensor_meta( print(f"No tensor meta found in {weight_meta_file}") return - session = get_session(db_path) - try: - session.execute( - sql_delete(SampleInputTensorMeta).where( - SampleInputTensorMeta.sample_uuid == sample_uuid - ) + session.execute( + sql_delete(SampleInputTensorMeta).where( + SampleInputTensorMeta.sample_uuid == sample_uuid ) - for meta in input_tensor_metas: - sample_input_tensor_meta = SampleInputTensorMeta( + ) + for meta in input_tensor_metas: + session.add( + SampleInputTensorMeta( sample_uuid=sample_uuid, input_name=meta["input_name"], input_idx=meta["input_idx"], @@ -426,74 +380,106 @@ def insert_sample_input_tensor_meta( deleted=False, delete_at=None, ) - session.add(sample_input_tensor_meta) - - session.commit() - print( - f"Inserted {len(input_tensor_metas)} input tensor meta(s) for sample_uuid={sample_uuid}" ) - except IntegrityError as e: - session.rollback() - print(f"Error inserting input tensor meta: {e}") - raise e - finally: - session.close() + + print( + f"Inserted 
{len(input_tensor_metas)} input tensor meta(s) for sample_uuid={sample_uuid}" + ) # main func -def main(args): +def insert_one_sample( + model_path_prefix: str, + relative_model_path: str, + repo_uid: str, + sample_type: str, + order_value: int, + db_path: str, + op_names_path_prefix: str = "", +): + model_path_prefix = model_path_prefix.strip() + relative_model_path = relative_model_path.strip() + repo_uid = repo_uid.strip() + sample_type = sample_type.strip() + db_path = db_path.strip() + op_names_path_prefix = op_names_path_prefix.strip() if op_names_path_prefix else "" + data = get_graph_sample_data( - model_path_prefix=args.model_path_prefix, - relative_model_path=args.relative_model_path, - repo_uid=args.repo_uid, - sample_type=args.sample_type, - order_value=args.order_value, + model_path_prefix=model_path_prefix, + relative_model_path=relative_model_path, + repo_uid=repo_uid, + sample_type=sample_type, + order_value=order_value, ) - print(f"\ninsert into database: {args.db_path}") + print(f"\ninsert into database: {db_path=}, {sample_type=}") + successed = True + session = get_session(db_path) try: - insert_graph_sample(args.db_path, data, args.model_path_prefix) + insert_graph_sample(session, data) if data["is_subgraph"]: subgraph_source_data = insert_subgraph_source( + session=session, subgraph_uuid=data["uuid"], - model_path_prefix=args.model_path_prefix, - sample_type=args.sample_type, - relative_model_path=args.relative_model_path, - db_path=args.db_path, + model_path_prefix=model_path_prefix, + sample_type=sample_type, + relative_model_path=relative_model_path, ) insert_sample_op_name_list( + session=session, sample_uuid=data["uuid"], - model_path_prefix=args.model_path_prefix, - op_names_path_prefix=args.op_names_path_prefix, - relative_model_path=args.relative_model_path, - db_path=args.db_path, + model_path_prefix=model_path_prefix, + op_names_path_prefix=op_names_path_prefix, + relative_model_path=relative_model_path, ) 
insert_sample_input_tensor_meta( + session=session, sample_uuid=data["uuid"], - model_path_prefix=args.model_path_prefix, - relative_model_path=args.relative_model_path, - db_path=args.db_path, + model_path_prefix=model_path_prefix, + relative_model_path=relative_model_path, ) - if args.sample_type in ["fusible_graph", "typical_graph"]: + if sample_type in ["fusible_graph", "typical_graph"]: insert_dimension_generalization_source( - subgraph_source_data["subgraph_uuid"], - subgraph_source_data["full_graph_uuid"], - args.model_path_prefix, - args.relative_model_path, - args.db_path, + session=session, + generalized_graph_uuid=subgraph_source_data["subgraph_uuid"], + original_graph_uuid=subgraph_source_data["full_graph_uuid"], + model_path_prefix=model_path_prefix, + relative_model_path=relative_model_path, ) insert_datatype_generalization_source( - subgraph_source_data["subgraph_uuid"], - subgraph_source_data["full_graph_uuid"], - args.model_path_prefix, - args.relative_model_path, - args.db_path, + session=session, + generalized_graph_uuid=subgraph_source_data["subgraph_uuid"], + original_graph_uuid=subgraph_source_data["full_graph_uuid"], + model_path_prefix=model_path_prefix, + relative_model_path=relative_model_path, ) - print(f"success insert: {data['relative_model_path']}") + session.commit() + print(f"insert {sample_type} success: {data['relative_model_path']}") except sqlite3.IntegrityError as e: - print("insert failed: integrity error (possible duplicate uuid or graph_hash)") + session.rollback() + print( + "insert {sample_type} failed: integrity error (possible duplicate uuid or graph_hash)" + ) print(f"error info: {e}") + successed = False except Exception as e: - print(f"insert failed: {e}") + session.rollback() + print(f"insert {sample_type} failed: {e}") + successed = False + finally: + session.close() + return successed + + +def main(args): + insert_one_sample( + model_path_prefix=args.model_path_prefix, + relative_model_path=args.relative_model_path, 
+ repo_uid=args.repo_uid, + sample_type=args.sample_type, + order_value=args.order_value, + db_path=args.db_path, + op_names_path_prefix=args.op_names_path_prefix, + ) if __name__ == "__main__": diff --git a/sqlite/init_db.py b/sqlite/init_db.py index da1d0ab577..b4214c00c0 100755 --- a/sqlite/init_db.py +++ b/sqlite/init_db.py @@ -1,6 +1,7 @@ import sqlite3 import re import argparse +import os from pathlib import Path @@ -12,7 +13,10 @@ def parse_timestamp(filename: str) -> int: return 0 -def migrate(db_path: str = "sqlite/GraphNet.db", migrates_dir: str = "sqlite/migrates"): +def migrate(db_path: str): + script_dir = os.path.dirname(os.path.abspath(__file__)) + migrates_dir = os.path.join(script_dir, "migrates") + db_path_obj = Path(db_path) migrates_path = Path(migrates_dir) @@ -58,8 +62,8 @@ def migrate(db_path: str = "sqlite/GraphNet.db", migrates_dir: str = "sqlite/mig parser.add_argument( "--db_path", type=str, - default="sqlite/GraphNet.db", - help="Database file path (default: sqlite/GraphNet.db)", + default="GraphNet.db", + help="Database file path", ) args = parser.parse_args() migrate(args.db_path) diff --git a/sqlite/upload.py b/sqlite/upload.py deleted file mode 100755 index 72308f28a4..0000000000 --- a/sqlite/upload.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -from datasets import Dataset -from huggingface_hub import HfApi, login - - -HF_TOKEN = "" -REPO_ID = "PaddlePaddle/GraphNet" -REVISION = "20260203" -BASE_DIR = "/work/GraphNet/torch_paddle_samples/subgraph_dataset_20260203" -FOLDERS_TO_PACK = ["full_graph", "fusible_graph", "sole_op_graph", "typical_graph"] -DB_FILE = "GraphNet.db" - - -def is_clean_file(filename, root): - ext = os.path.splitext(filename)[1].lower() - if ext in {".pyc", ".pyo", ".pyd", ".so"}: - return False - if any(x in root for x in ["__pycache__", ".git", ".ipynb_checkpoints"]): - return False - return True - - -def file_generator(): - file_list = [ - (os.path.join(root, f), folder) - for folder in FOLDERS_TO_PACK - if 
os.path.exists(os.path.join(BASE_DIR, folder)) - for root, _, files in os.walk(os.path.join(BASE_DIR, folder)) - for f in files - if is_clean_file(f, root) - and os.path.splitext(f)[1].lower() in {".py", ".json", ".txt", ".yaml", ".md"} - ] - - return ( - { - "path": os.path.relpath(fp, BASE_DIR), - "content": open(fp, "r", encoding="utf-8", errors="ignore").read(), - "source_folder": src, - } - for fp, src in file_list - ) - - -def main(): - login(token=HF_TOKEN) - - ds = Dataset.from_generator(file_generator) - ds.push_to_hub(REPO_ID, split="GraphNet", max_shard_size="500MB", revision=REVISION) - print("Folder data uploaded successfully!") - - api = HfApi() - db_path = os.path.join(BASE_DIR, DB_FILE) - if os.path.exists(db_path): - api.upload_file( - path_or_fileobj=db_path, - path_in_repo=DB_FILE, - repo_id=REPO_ID, - repo_type="dataset", - revision=REVISION, - ) - print(f"{DB_FILE} uploaded successfully!") - - -if __name__ == "__main__": - main() diff --git a/sqlite/upload_dataset.py b/sqlite/upload_dataset.py new file mode 100755 index 0000000000..d2e40756b0 --- /dev/null +++ b/sqlite/upload_dataset.py @@ -0,0 +1,90 @@ +import argparse +import os +from datasets import Dataset +from huggingface_hub import HfApi, login + + +def is_clean_file(filename, root): + ext = os.path.splitext(filename)[1].lower() + if ext in {".pyc", ".pyo", ".pyd", ".so"}: + return False + if any(x in root for x in ["__pycache__", ".git", ".ipynb_checkpoints"]): + return False + return True + + +def file_generator(base_dir, folders): + file_list = [ + (os.path.join(root, f), folder) + for folder in folders + if os.path.exists(os.path.join(base_dir, folder)) + for root, _, files in os.walk(os.path.join(base_dir, folder)) + for f in files + if is_clean_file(f, root) + and os.path.splitext(f)[1].lower() in {".py", ".json", ".txt", ".yaml", ".md"} + ] + + return ( + { + "path": os.path.relpath(fp, base_dir), + "content": open(fp, "r", encoding="utf-8", errors="ignore").read(), + 
"source_folder": src, + } + for fp, src in file_list + ) + + +def main(args): + folders = ["full_graph", "fusible_graph", "sole_op_graph", "typical_graph"] + + login(token=args.hf_token) + + ds = Dataset.from_generator(lambda: file_generator(args.base_dir, folders)) + ds.push_to_hub( + args.repo_id, + split=args.split, + max_shard_size=args.max_shard_size, + revision=args.revision, + ) + print("Folder data uploaded successfully!") + + api = HfApi() + db_path = os.path.join(args.base_dir, args.db_file) + if os.path.exists(db_path): + api.upload_file( + path_or_fileobj=db_path, + path_in_repo=args.db_file, + repo_id=args.repo_id, + repo_type="dataset", + revision=args.revision, + ) + print(f"{args.db_file} uploaded successfully!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Upload dataset and DB to HuggingFace Hub" + ) + parser.add_argument( + "--hf_token", type=str, required=True, help="HuggingFace API token" + ) + parser.add_argument( + "--repo_id", type=str, default="PaddlePaddle/GraphNet", help="HF repo ID" + ) + parser.add_argument( + "--revision", type=str, default="main", help="HF repo revision/branch" + ) + parser.add_argument( + "--base_dir", type=str, required=True, help="Local dataset root directory" + ) + parser.add_argument( + "--db_file", type=str, default="GraphNet.db", help="DB filename in base_dir" + ) + parser.add_argument( + "--split", type=str, default="GraphNet", help="Dataset split name" + ) + parser.add_argument( + "--max_shard_size", type=str, default="500MB", help="Max shard size" + ) + args = parser.parse_args() + main(args)