From 32bb5525eb3e88754e29d87d2bb78df7a6ec7727 Mon Sep 17 00:00:00 2001 From: Molly He Date: Thu, 13 Nov 2025 14:14:15 -0800 Subject: [PATCH 01/22] Model customization Init Experience Flow (#290) * model customization init/find model * Adding direct create exp * Model customization Init/Create/Find * Latest model cust changes * init migration done with template validation * Init full experience migrated, CRUDL simple addition in hyp_cli.py, unit tests added, pending nova forge happy case for integ test * remove argcomplete since it is not supported yet * add reset command for dynamic template * fix integ test error for init flow * remove recipe finder and discovery changes --------- Co-authored-by: Amarjeet LNU --- .../hyperpod/cli/commands/inference.py | 8 +- src/sagemaker/hyperpod/cli/commands/init.py | 108 +++- .../cli/commands/training_fine_tuning.py | 500 ++++++++++++++++++ .../hyperpod/cli/constants/init_constants.py | 7 + src/sagemaker/hyperpod/cli/hyp_cli.py | 39 +- src/sagemaker/hyperpod/cli/init_utils.py | 190 ++++++- .../hyperpod/cli/type_handler_utils.py | 24 + .../init/test_init_workflow.py | 12 +- test/integration_tests/init/utils.py | 19 +- test/unit_tests/cli/test_init_dynamic.py | 89 ++++ test/unit_tests/cli/test_init_utils.py | 3 +- .../unit_tests/cli/test_init_utils_dynamic.py | 79 +++ .../cli/test_training_fine_tuning.py | 194 +++++++ 13 files changed, 1211 insertions(+), 61 deletions(-) create mode 100644 src/sagemaker/hyperpod/cli/commands/training_fine_tuning.py create mode 100644 test/unit_tests/cli/test_init_dynamic.py create mode 100644 test/unit_tests/cli/test_init_utils_dynamic.py create mode 100644 test/unit_tests/cli/test_training_fine_tuning.py diff --git a/src/sagemaker/hyperpod/cli/commands/inference.py b/src/sagemaker/hyperpod/cli/commands/inference.py index 49c14fa0..77cc5e67 100644 --- a/src/sagemaker/hyperpod/cli/commands/inference.py +++ b/src/sagemaker/hyperpod/cli/commands/inference.py @@ -134,7 +134,7 @@ def js_list( 
namespace: Optional[str], ): """ - List all Hyperpod Jumpstart model endpoints. + List all HyperPod Jumpstart model endpoints. """ endpoints = HPJumpStartEndpoint.model_construct().list(namespace) data = [ep.model_dump() for ep in endpoints] @@ -177,7 +177,7 @@ def custom_list( namespace: Optional[str], ): """ - List all Hyperpod custom model endpoints. + List all HyperPod custom model endpoints. """ endpoints = HPEndpoint.model_construct().list(namespace) data = [ep.model_dump() for ep in endpoints] @@ -236,7 +236,7 @@ def js_describe( full: bool ): """ - Describe a Hyperpod Jumpstart model endpoint. + Describe a HyperPod Jumpstart model endpoint. """ my_endpoint = HPJumpStartEndpoint.model_construct().get(name, namespace) data = my_endpoint.model_dump() @@ -385,7 +385,7 @@ def custom_describe( full: bool ): """ - Describe a Hyperpod custom model endpoint. + Describe a HyperPod custom model endpoint. """ my_endpoint = HPEndpoint.model_construct().get(name, namespace) data = my_endpoint.model_dump() diff --git a/src/sagemaker/hyperpod/cli/commands/init.py b/src/sagemaker/hyperpod/cli/commands/init.py index c3a54d16..5c388e16 100644 --- a/src/sagemaker/hyperpod/cli/commands/init.py +++ b/src/sagemaker/hyperpod/cli/commands/init.py @@ -1,10 +1,10 @@ import click import yaml import sys +import shutil from pathlib import Path from datetime import datetime from jinja2 import Template -import shutil from sagemaker.hyperpod.cli.constants.init_constants import ( USAGE_GUIDE_TEXT_CFN, USAGE_GUIDE_TEXT_CRD, @@ -24,23 +24,32 @@ build_config_from_schema, save_template, get_default_version_for_template, - create_from_k8s_yaml + create_from_k8s_yaml, + is_dynamic_template ) from sagemaker.hyperpod.common.utils import get_aws_default_region from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( _hyperpod_telemetry_emitter, ) from sagemaker.hyperpod.common.telemetry.constants import Feature +from sagemaker.hyperpod.cli.commands.training_fine_tuning import 
_init_fine_tuning_job, _configure_dynamic_template, _validate_dynamic_template, _create_dynamic_template, _generate_dynamic_config_yaml + @click.command("init") -@click.argument("template", type=click.Choice(list(TEMPLATES.keys()))) +@click.argument("template", type=click.Choice(list(TEMPLATES.keys()) + ["fine-tuning-job"])) @click.argument("directory", type=click.Path(file_okay=False), default=".") @click.option("--version", "-v", default=None, help="Schema version") +@click.option("--model-name", help="Model name from SageMaker Public Hub (for fine-tuning-job)") +@click.option("--technique", help="Customization technique (for fine-tuning-job)") +@click.option("--instance-type", help="Instance type (for fine-tuning-job)") @_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "init_template_cli") def init( template: str, directory: str, version: str, + model_name: str, + technique: str, + instance_type: str, ): """ Initialize a TEMPLATE scaffold in DIRECTORY. @@ -57,6 +66,7 @@ def init( The generated files provide a starting point for configuring and submitting jobs to SageMaker HyperPod clusters orchestrated by Amazon EKS. 
""" + # Original template initialization logic dir_path = Path(directory).resolve() config_file = dir_path / "config.yaml" skip_readme = False @@ -65,8 +75,8 @@ def init( try: if config_file.is_file(): try: - existing = yaml.safe_load(config_file.read_text()) or {} - existing_template = existing.get("template") + # Use load_config to properly read commented template + _, existing_template, _ = load_config(dir_path) except Exception as e: click.echo("Could not parse existing config.yaml: %s", e) existing_template = None @@ -100,6 +110,17 @@ def init( click.secho(f"āŒ Could not create directory {dir_path}: {e}", fg="red") sys.exit(1) + # Handle fine-tuning-job template after validation + if template == "fine-tuning-job": + if not model_name or not technique or not instance_type: + click.secho("āŒ --model-name, --technique, and --instance-type are required for fine-tuning-job", fg="red") + return + + if _init_fine_tuning_job(directory, model_name, technique, instance_type): + click.secho("āœ”ļø Fine-tuning job initialized successfully", fg="green") + click.secho("šŸ“„ Created: config.yaml, k8s.jinja", fg="green") + return + # 3) Build config dict + comment map, then write config.yaml try: # Determine version: use user-provided version or default to latest @@ -162,9 +183,19 @@ def reset(): # 1) Load and validate config data, template, version = load_config(dir_path) - # 2) Build config with default values from schema + # 2) Check if this is a dynamic template + if is_dynamic_template(template, dir_path): + # For dynamic templates, reset using the helper function + try: + _generate_dynamic_config_yaml(dir_path, template, version) + click.secho("āœ”ļø config.yaml reset: all fields set to default values.", fg="green") + except Exception as e: + click.secho(f"šŸ’„ Could not reset config.yaml: {e}", fg="red") + sys.exit(1) + return + + # 3) Standard template reset logic full_cfg, comment_map = build_config_from_schema(template, version) - # 3) Overwrite config.yaml try: 
save_config_yaml( prefill=full_cfg, @@ -185,7 +216,7 @@ def reset(): @generate_click_command() @click.pass_context @_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "init_configure_cli") -def configure(ctx, model_config): +def configure(ctx, option, value, model_config): """ Update any subset of fields in ./config.yaml by passing -- flags. @@ -201,16 +232,27 @@ def configure(ctx, model_config): # Update multiple fields at once hyp configure --stack-name my-stack --create-fsx-stack: False - + # Update complex fields with JSON object hyp configure --availability-zone-ids '["id1", "id2"]' - """ # 1) Load existing config without validation dir_path = Path(".").resolve() data, template, version = load_config(dir_path) - # 2) Determine which fields the user actually provided + # 2) Check if this is a dynamic template (fine-tuning) + if is_dynamic_template(template, dir_path): + # Handle fine-tuning configure logic + _configure_dynamic_template(ctx, option, value, dir_path) + return + + # 3) Handle standard template configure logic + _configure_standard_template(ctx, model_config, dir_path, data, template, version) + + +def _configure_standard_template(ctx, model_config, dir_path, data, template, version): + """Handle configure for standard templates""" + # Determine which fields the user actually provided # Use Click's parameter source tracking to identify command-line provided parameters user_input_fields = set() @@ -223,10 +265,10 @@ def configure(ctx, model_config): user_input_fields.add(param_name) if not user_input_fields: - click.secho("āš ļø No arguments provided to configure.", fg="yellow") - return + click.echo(ctx.get_help()) + ctx.exit(0) - # 3) Build merged config with user input + # Build merged config with user input full_cfg, comment_map = build_config_from_schema( template=template, version=version, @@ -235,7 +277,7 @@ def configure(ctx, model_config): user_provided_fields=user_input_fields ) - # 4) Validate the merged config, but only check 
user-provided fields + # Validate the merged config, but only check user-provided fields all_validation_errors = validate_config_against_model(full_cfg, template, version) user_input_errors = filter_validation_errors_for_user_input(all_validation_errors, user_input_fields) @@ -249,7 +291,7 @@ def configure(ctx, model_config): click.secho("āŒ config.yaml was not updated due to invalid input.", fg="red") sys.exit(1) - # 5) Write out the updated config.yaml (only if user input is valid) + # Write out the updated config.yaml (only if user input is valid) try: save_config_yaml( prefill=full_cfg, @@ -268,7 +310,26 @@ def validate(): Validate this directory's config.yaml against the appropriate schema. """ dir_path = Path(".").resolve() - load_config_and_validate(dir_path) + + try: + # Load config to determine template type + data, template, version = load_config(dir_path) + + # Check if this is a dynamic template + if is_dynamic_template(template, dir_path): + # Validate dynamic template + _validate_dynamic_template(dir_path) + click.secho("āœ”ļø Configuration validated successfully", fg="green") + else: + # Use standard validation + load_config_and_validate(dir_path) + click.secho("āœ”ļø Configuration validated successfully", fg="green") + except (FileNotFoundError, ValueError) as e: + click.secho(f"āŒ {e}", fg="red") + sys.exit(1) + except Exception as e: + click.secho(f"āŒ Validation failed: {e}", fg="red") + sys.exit(1) @click.command(name="_default_create") @@ -310,6 +371,17 @@ def _default_create(region, template_version, debug): # 1) Load config to determine template type data, template, version = load_config_and_validate(dir_path) + # Check if this is a dynamic template (fine-tuning) + if is_dynamic_template(template, dir_path): + _create_dynamic_template(dir_path, data) + return + + # Handle standard templates (existing logic) + _create_standard_template(dir_path, data, template, version, region, template_version) + + +def 
_create_standard_template(dir_path: Path, data: dict, template: str, version: str, region: str, template_version: int): + """Handle create for standard templates""" # Check if region flag is used for non-cluster-stack templates if region and template != "cluster-stack": click.secho(f"āŒ --region flag is only available for cluster-stack template, not for {template}.", fg="red") @@ -324,6 +396,7 @@ def _default_create(region, template_version, debug): jinja_file = dir_path / 'k8s.jinja' # 3) Ensure files exist + config_file = dir_path / 'config.yaml' if not config_file.is_file() or not jinja_file.is_file(): click.secho(f"āŒ Missing config.yaml or {jinja_file.name}. Run `hyp init` first.", fg="red") sys.exit(1) @@ -387,7 +460,6 @@ def _default_create(region, template_version, debug): k8s_file = out_dir / 'k8s.yaml' create_from_k8s_yaml(str(k8s_file), debug=debug) - except Exception as e: click.secho(f"āŒ Failed to submit the command: {e}", fg="red") sys.exit(1) \ No newline at end of file diff --git a/src/sagemaker/hyperpod/cli/commands/training_fine_tuning.py b/src/sagemaker/hyperpod/cli/commands/training_fine_tuning.py new file mode 100644 index 00000000..16e4dcee --- /dev/null +++ b/src/sagemaker/hyperpod/cli/commands/training_fine_tuning.py @@ -0,0 +1,500 @@ +from jinja2 import Template +import boto3 +from datetime import datetime, timezone +from kubernetes import client, config +from kubernetes.client.rest import ApiException +import yaml +import json +import os +import sys +import click +from pathlib import Path +from sagemaker.hyperpod.cli.init_utils import load_dynamic_schema +from sagemaker.hyperpod.common.utils import handle_exception +from sagemaker.hyperpod.cli.type_handler_utils import is_undefined_value +import shutil + +_sagemaker_client = None +_s3_client = None +_k8s_custom_client = None + +def get_sagemaker_client(): + global _sagemaker_client + if _sagemaker_client is None: + _sagemaker_client = boto3.client( + "sagemaker", + 
endpoint_url="https://sagemaker.beta.us-west-2.ml-platform.aws.a2z.com" + ) + return _sagemaker_client + +def get_s3_client(): + global _s3_client + if _s3_client is None: + _s3_client = boto3.client("s3") + return _s3_client + +def get_k8s_custom_client(): + """Get Kubernetes custom objects API client for PyTorchJob resources.""" + global _k8s_custom_client + if _k8s_custom_client is None: + try: + config.load_kube_config() + except config.ConfigException: + try: + config.load_incluster_config() + except config.ConfigException: + raise Exception("Could not configure kubernetes python client") + _k8s_custom_client = client.CustomObjectsApi() + return _k8s_custom_client + + +def _init_fine_tuning_job(directory: str, model_name: str, technique: str, instance_type: str) -> bool: + """Initialize fine-tuning job configuration.""" + try: + # Get clients + sagemaker_client = boto3.client( + "sagemaker", + endpoint_url="https://sagemaker.beta.us-west-2.ml-platform.aws.a2z.com" + ) + s3_client = boto3.client("s3") + + # Fetch recipe + request = { + "HubName": "SageMakerPublicHub", + "HubContentType": "Model", + "HubContentName": model_name + } + + describe_response = sagemaker_client.describe_hub_content(**request) + hub_content_doc = json.loads(describe_response.get('HubContentDocument', '{}')) + recipe_collection = hub_content_doc.get('RecipeCollection', []) + + matching_recipe = None + for recipe in recipe_collection: + if recipe.get('CustomizationTechnique') == technique: + matching_recipe = recipe + break + + if not matching_recipe: + click.secho(f"āŒ No recipe found for technique: {technique}", fg="red") + return False + + if instance_type not in matching_recipe.get('SupportedInstanceTypes', []): + click.secho(f"āŒ Instance type {instance_type} not supported for this model and technique", fg="red") + click.secho(f"Supported instance types: {matching_recipe.get('SupportedInstanceTypes', [])}", fg="red") + return False + + override_params_uri = 
matching_recipe.get('HpEksOverrideParamsS3Uri') + k8s_template_uri = matching_recipe.get('HpEksPayloadTemplateS3Uri') + + if not override_params_uri or not k8s_template_uri: + click.secho("āŒ Missing S3 URIs in recipe", fg="red") + return False + + # Create directory + dir_path = Path(directory).resolve() + dir_path.mkdir(parents=True, exist_ok=True) + + # Download override params + override_bucket = override_params_uri.split('/')[2] + override_key = '/'.join(override_params_uri.split('/')[3:]) + override_obj = s3_client.get_object(Bucket=override_bucket, Key=override_key) + override_data = json.loads(override_obj['Body'].read()) + + # Save override spec + with open(dir_path / '.override_spec.json', 'w') as f: + json.dump(override_data, f, indent=2) + + # Create config.yaml + _generate_dynamic_config_yaml(dir_path, "fine-tuning-job", model_name=model_name, technique=technique, instance_type=instance_type) + + # Download k8s template + k8s_bucket = k8s_template_uri.split('/')[2] + k8s_key = '/'.join(k8s_template_uri.split('/')[3:]) + k8s_obj = s3_client.get_object(Bucket=k8s_bucket, Key=k8s_key) + k8s_content = k8s_obj['Body'].read().decode('utf-8') + + with open(dir_path / 'k8s.jinja', 'w') as f: + f.write(k8s_content) + + return True + + except Exception as e: + click.secho(f"āŒ Error: {e}", fg="red") + return False + +def _configure_dynamic_template(ctx, option, value, dir_path): + """Handle configure for dynamic templates (fine-tuning)""" + config_path = dir_path / "config.yaml" + spec_path = dir_path / ".override_spec.json" + + if not spec_path.exists(): + click.secho(f"āŒ .override_spec.json not found", fg="red") + ctx.exit(1) + + # Load spec + spec = load_dynamic_schema(dir_path) + + # Check if user provided --option flags (only those explicitly provided, not defaults) + provided_options = {} + for param_name, param_value in ctx.params.items(): + if param_name not in ['option', 'value', 'model_config']: + # Check if this parameter was actually provided by 
the user (not a default) + param_source = ctx.get_parameter_source(param_name) + if param_source and param_source.name == 'COMMANDLINE' and param_value is not None: + # Convert back to original key format + original_key = param_name.replace('-', '_') + if original_key in spec: + provided_options[original_key] = param_value + + # If --option flags were used, process them + if provided_options: + for key, value in provided_options.items(): + _update_config_field(config_path, spec, key, value) + + click.secho(f"āœ”ļø config.yaml updated successfully.", fg="green") + return + + # If no arguments, show help like --help does + click.echo(ctx.get_help()) + ctx.exit(0) + + # Validate option exists + if option not in spec: + click.secho(f"āŒ Unknown option: {option}", fg="red") + click.echo(f"\nRun 'hyp configure' to see available options") + ctx.exit(1) + + if value is None: + click.secho(f"āŒ Value required for option: {option}", fg="red") + ctx.exit(1) + + # Validate and convert value + option_spec = spec[option] + value_type = option_spec.get("type", "string") + + try: + if value_type == "integer": + converted_value = int(value) + elif value_type == "float": + converted_value = float(value) + elif value_type == "string": + converted_value = str(value) + else: + converted_value = value + + # Validate constraints + if "min" in option_spec and converted_value < option_spec["min"]: + click.secho(f"āŒ Value {converted_value} is below minimum {option_spec['min']}", fg="red") + ctx.exit(1) + + if "max" in option_spec and converted_value > option_spec["max"]: + click.secho(f"āŒ Value {converted_value} exceeds maximum {option_spec['max']}", fg="red") + ctx.exit(1) + + if "enum" in option_spec and converted_value not in option_spec["enum"]: + click.secho(f"āŒ Value {converted_value} not in allowed values: {option_spec['enum']}", fg="red") + ctx.exit(1) + + except ValueError as e: + click.secho(f"āŒ Invalid value type. 
Expected {value_type}: {e}", fg="red") + ctx.exit(1) + + # Load and update config.yaml + with open(config_path, 'r') as f: + lines = f.readlines() + + updated = False + new_lines = [] + for line in lines: + if line.strip().startswith(f"{option}:"): + # Preserve comments and formatting + indent = len(line) - len(line.lstrip()) + if value_type == "string": + new_lines.append(f"{' ' * indent}{option}: \"{converted_value}\"\n") + else: + new_lines.append(f"{' ' * indent}{option}: {converted_value}\n") + updated = True + else: + new_lines.append(line) + + if not updated: + click.secho(f"āŒ Option {option} not found in config.yaml", fg="red") + ctx.exit(1) + + # Write back to config.yaml + with open(config_path, 'w') as f: + f.writelines(new_lines) + + click.secho(f"āœ… Successfully set {option} = {converted_value}", fg="green") + + +def _validate_dynamic_template(dir_path: Path): + """Validate dynamic template config against .override_spec.json""" + spec_path = dir_path / ".override_spec.json" + if not spec_path.exists(): + raise FileNotFoundError(".override_spec.json not found") + + spec = load_dynamic_schema(dir_path) + config_data = yaml.safe_load((dir_path / "config.yaml").read_text()) or {} + + validation_errors = [] + for key, field_spec in spec.items(): + value = config_data.get(key) + required = field_spec.get("required", False) + field_type = field_spec.get("type", "string") + + if required and (value is None or value == ""): + validation_errors.append(f"{key}: Required field is missing or empty") + continue + + if value is None: + continue + + # Type validation + if field_type == "integer" and not isinstance(value, int): + validation_errors.append(f"{key}: Expected integer, got {type(value).__name__}") + elif field_type == "float" and not isinstance(value, (int, float)): + validation_errors.append(f"{key}: Expected number, got {type(value).__name__}") + elif field_type == "string" and not isinstance(value, str): + validation_errors.append(f"{key}: Expected 
string, got {type(value).__name__}") + + # Constraint validation + if isinstance(value, (int, float)): + if "min" in field_spec and value < field_spec["min"]: + validation_errors.append(f"{key}: Value {value} below minimum {field_spec['min']}") + if "max" in field_spec and value > field_spec["max"]: + validation_errors.append(f"{key}: Value {value} exceeds maximum {field_spec['max']}") + + if "enum" in field_spec and value not in field_spec["enum"]: + validation_errors.append(f"{key}: Value {value} not in allowed values: {field_spec['enum']}") + + if validation_errors: + raise ValueError("Config validation failed:\n" + "\n".join(f" • {error}" for error in validation_errors)) + + return True + + +def _create_dynamic_template(dir_path: Path, config_data: dict): + """Handle create for dynamic templates (fine-tuning)""" + try: + # Validate config first + _validate_dynamic_template(dir_path) + click.secho("āœ”ļø Configuration validated successfully", fg="green") + + k8s_template_file = dir_path / 'k8s.jinja' + if not k8s_template_file.exists(): + raise FileNotFoundError("k8s.jinja template not found") + + # Read and render template + template_content = k8s_template_file.read_text() + template = Template(template_content) + rendered = template.render(**config_data) + + # Create run directory + run_root = dir_path / 'run' + run_root.mkdir(exist_ok=True) + timestamp = datetime.now().strftime('%Y%m%dT%H%M%S') + out_dir = run_root / timestamp + out_dir.mkdir() + + # Save files + shutil.copy(dir_path / 'config.yaml', out_dir / 'config.yaml') + (out_dir / 'k8s.yaml').write_text(rendered) + + relative_out_dir = Path("run") / timestamp + click.secho(f"āœ”ļø Files written to {relative_out_dir}", fg="green") + + # Parse and submit to Kubernetes using Python client + k8s_documents = list(yaml.safe_load_all(rendered)) + custom_api = get_k8s_custom_client() + + for k8s_config in k8s_documents: + if not k8s_config: # Skip empty documents + continue + + # Extract resource details 
from the k8s config + api_version = k8s_config.get('apiVersion', '') + kind = k8s_config.get('kind', '') + metadata = k8s_config.get('metadata', {}) + namespace = metadata.get('namespace', 'default') + + # Handle standard Kubernetes resources vs custom resources + if api_version == 'v1' or api_version.startswith('apps/') or api_version.startswith('extensions/'): + # Standard Kubernetes resource - use CoreV1Api or AppsV1Api + core_api = client.CoreV1Api() + + if kind == 'ConfigMap': + core_api.create_namespaced_config_map(namespace=namespace, body=k8s_config) + elif kind == 'Secret': + core_api.create_namespaced_secret(namespace=namespace, body=k8s_config) + elif kind == 'Service': + core_api.create_namespaced_service(namespace=namespace, body=k8s_config) + else: + # For other standard resources, fall back to custom API + custom_api.create_namespaced_custom_object( + group='', + version=api_version, + namespace=namespace, + plural=kind.lower() + 's', + body=k8s_config, + ) + else: + # Custom resource - use CustomObjectsApi + if '/' in api_version: + group, version = api_version.split('/', 1) + else: + group = '' + version = api_version + + # Convert kind to plural (simple heuristic) + plural = kind.lower() + 's' if not kind.lower().endswith('s') else kind.lower() + + custom_api.create_namespaced_custom_object( + group=group, + version=version, + namespace=namespace, + plural=plural, + body=k8s_config, + ) + + click.secho("āœ”ļø Successfully submitted to Kubernetes", fg="green") + + except (FileNotFoundError, ValueError) as e: + click.secho(f"āŒ {e}", fg="red") + sys.exit(1) + except Exception as e: + # Use existing handle_exception for Kubernetes errors + try: + # Extract resource name from config for better error messages + resource_name = config_data.get('name', 'unknown') + handle_exception(e, resource_name, 'default') + except Exception as handled_e: + click.secho(f"āŒ {handled_e}", fg="red") + sys.exit(1) + + +def _generate_dynamic_config_yaml(dir_path: 
Path, template: str, version: str = None, model_name: str = None, technique: str = None, instance_type: str = None): + """Generate config.yaml for dynamic templates with default values""" + spec = load_dynamic_schema(dir_path) + + # Try to preserve existing metadata from current config + existing_model = model_name + existing_technique = technique + existing_instance_type = instance_type + + config_path = dir_path / 'config.yaml' + if config_path.exists(): + try: + with open(config_path, 'r') as f: + for line in f: + if line.startswith('# model: ') and not existing_model: + existing_model = line.replace('# model: ', '').strip() + elif line.startswith('# fine tune technique: ') and not existing_technique: + existing_technique = line.replace('# fine tune technique: ', '').strip() + elif line.startswith('# instance type: ') and not existing_instance_type: + existing_instance_type = line.replace('# instance type: ', '').strip() + except: + pass # If reading fails, use provided values + + with open(config_path, 'w') as f: + f.write(f"# template: {template}\n") + if existing_model: + f.write(f"# model: {existing_model}\n") + if existing_technique: + f.write(f"# fine tune technique: {existing_technique}\n") + if existing_instance_type: + f.write(f"# instance type: {existing_instance_type}\n") + f.write("\n") + + for key, param_spec in spec.items(): + default_value = param_spec.get('default') + param_type = param_spec.get('type', 'string') + min_val = param_spec.get('min') + max_val = param_spec.get('max') + description = param_spec.get('description', '') + required = param_spec.get('required', False) + + if description: + f.write(f"# {description}\n") + f.write(f"# Type: {param_type}") + if min_val is not None: + f.write(f", Min: {min_val}") + if max_val is not None: + f.write(f", Max: {max_val}") + f.write(f", Required: {required}\n") + + if default_value is None: + f.write(f"{key}: null\n\n") + elif isinstance(default_value, str): + f.write(f"{key}: 
{default_value}\n\n") + elif isinstance(default_value, (list, dict)): + f.write(f"{key}: {json.dumps(default_value)}\n\n") + else: + f.write(f"{key}: {default_value}\n\n") + + +def _update_config_field(config_path, spec, option, value): + """Update a single field in config.yaml for dynamic templates""" + # Validate option exists + if option not in spec: + click.secho(f"āŒ Unknown option: {option}", fg="red") + sys.exit(1) + + if is_undefined_value(value): + click.secho(f"āŒ Value required for option: {option}", fg="red") + sys.exit(1) + + # Validate and convert value + option_spec = spec[option] + value_type = option_spec.get("type", "string") + + try: + if value_type == "integer": + converted_value = int(value) + elif value_type == "float": + converted_value = float(value) + elif value_type == "string": + converted_value = str(value) + else: + converted_value = value + + # Validate constraints + if "min" in option_spec and converted_value < option_spec["min"]: + click.secho(f"āŒ Value {converted_value} is below minimum {option_spec['min']}", fg="red") + sys.exit(1) + + if "max" in option_spec and converted_value > option_spec["max"]: + click.secho(f"āŒ Value {converted_value} exceeds maximum {option_spec['max']}", fg="red") + sys.exit(1) + + if "enum" in option_spec and converted_value not in option_spec["enum"]: + click.secho(f"āŒ Value {converted_value} not in allowed values: {option_spec['enum']}", fg="red") + sys.exit(1) + + except ValueError as e: + click.secho(f"āŒ Invalid value type. 
Expected {value_type}: {e}", fg="red") + sys.exit(1) + + # Load and update config.yaml - preserve existing values + with open(config_path, 'r') as f: + lines = f.readlines() + + updated = False + new_lines = [] + for line in lines: + if line.strip().startswith(f"{option}:"): + # Preserve comments and formatting, no quotes for strings + indent = len(line) - len(line.lstrip()) + new_lines.append(f"{' ' * indent}{option}: {converted_value}\n") + updated = True + else: + new_lines.append(line) + + if not updated: + click.secho(f"āŒ Option {option} not found in config.yaml", fg="red") + sys.exit(1) + + # Write back to config.yaml + with open(config_path, 'w') as f: + f.writelines(new_lines) diff --git a/src/sagemaker/hyperpod/cli/constants/init_constants.py b/src/sagemaker/hyperpod/cli/constants/init_constants.py index 3168484d..f8e3b042 100644 --- a/src/sagemaker/hyperpod/cli/constants/init_constants.py +++ b/src/sagemaker/hyperpod/cli/constants/init_constants.py @@ -38,6 +38,13 @@ "schema_pkg": "hyperpod_cluster_stack_template", "schema_type": CFN, 'type': "jinja" + }, + "fine-tuning-job": { + "registry": {}, + "template_registry": {}, + "schema_pkg": None, + "schema_type": CRD, + 'type': "dynamic" } } diff --git a/src/sagemaker/hyperpod/cli/hyp_cli.py b/src/sagemaker/hyperpod/cli/hyp_cli.py index a33aee29..77b1d3d2 100644 --- a/src/sagemaker/hyperpod/cli/hyp_cli.py +++ b/src/sagemaker/hyperpod/cli/hyp_cli.py @@ -1,11 +1,7 @@ import click -import yaml -import json -import os -import subprocess -from pydantic import BaseModel, ValidationError, Field -from typing import Optional, Union +from typing import Union from importlib.metadata import version, PackageNotFoundError +import copy from sagemaker.hyperpod.cli.commands.cluster import list_cluster, set_cluster_context, get_cluster_context, \ get_monitoring, describe_cluster @@ -62,18 +58,20 @@ from sagemaker.hyperpod.cli.commands.init import ( init, reset, - configure, validate, + configure, _default_create ) + def 
get_package_version(package_name): try: return version(package_name) except PackageNotFoundError: return "Not installed" + def print_version(ctx, param, value): if not value or ctx.resilient_parsing: return @@ -91,7 +89,8 @@ def print_version(ctx, param, value): @click.group(context_settings={'max_content_width': 200}) -@click.option('--version', is_flag=True, callback=print_version, expose_value=False, is_eager=True, help='Show version information') +@click.option('--version', is_flag=True, callback=print_version, expose_value=False, is_eager=True, + help='Show version information') def cli(): pass @@ -142,11 +141,13 @@ def describe(): """Describe endpoints, pytorch jobs or cluster stacks, spaces or space template.""" pass + @cli.group(cls=CLICommand) def update(): """Update an existing HyperPod cluster configuration, space, or space template.""" pass + @cli.group(cls=CLICommand) def delete(): """Delete endpoints, pytorch jobs, space, space access or space template.""" @@ -217,6 +218,9 @@ def exec(): create.add_command(space_access_create) list.add_command(list_jobs) +fine_tuning_list_cmd = copy.copy(list_jobs) +fine_tuning_list_cmd.help = "List all HyperPod fine-tuning jobs" +list.add_command(fine_tuning_list_cmd, name="fine-tuning-job") list.add_command(js_list) list.add_command(custom_list) list.add_command(list_cluster_stacks) @@ -224,6 +228,9 @@ def exec(): list.add_command(space_template_list) describe.add_command(pytorch_describe) +fine_tuning_describe_cmd = copy.copy(pytorch_describe) +fine_tuning_describe_cmd.help = "Describe a HyperPod fine-tuning job." 
+describe.add_command(fine_tuning_describe_cmd, name="fine-tuning-job")
 describe.add_command(js_describe)
 describe.add_command(custom_describe)
 describe.add_command(describe_cluster_stack)
@@ -237,6 +244,9 @@ def exec():
 update.add_command(space_template_update)
 delete.add_command(pytorch_delete)
+fine_tuning_delete_cmd = copy.copy(pytorch_delete)
+fine_tuning_delete_cmd.help = "Delete a HyperPod fine-tuning job."
+delete.add_command(fine_tuning_delete_cmd, name="fine-tuning-job")
 delete.add_command(js_delete)
 delete.add_command(custom_delete)
 delete.add_command(delete_cluster_stack)
@@ -248,10 +258,16 @@ def exec():
 stop.add_command(space_stop)
 list_pods.add_command(pytorch_list_pods)
+fine_tuning_list_pods_cmd = copy.copy(pytorch_list_pods)
+fine_tuning_list_pods_cmd.help = "List all HyperPod PyTorch pods related to the fine-tuning job."
+list_pods.add_command(fine_tuning_list_pods_cmd, name="fine-tuning-job")
 list_pods.add_command(js_list_pods)
 list_pods.add_command(custom_list_pods)
 get_logs.add_command(pytorch_get_logs)
+fine_tuning_get_logs_cmd = copy.copy(pytorch_get_logs)
+fine_tuning_get_logs_cmd.help = "Get specific pod log for HyperPod fine-tuning job."
+get_logs.add_command(fine_tuning_get_logs_cmd, name="fine-tuning-job")
 get_logs.add_command(js_get_logs)
 get_logs.add_command(custom_get_logs)
 get_logs.add_command(space_get_logs)
@@ -259,11 +275,16 @@ def exec():
 portforward.add_command(space_portforward)
 get_operator_logs.add_command(pytorch_get_operator_logs)
+fine_tuning_get_operator_logs_cmd = copy.copy(pytorch_get_operator_logs)
+fine_tuning_get_operator_logs_cmd.help = "Get operator logs for HyperPod fine-tuning jobs."
+get_operator_logs.add_command(fine_tuning_get_operator_logs_cmd, name="fine-tuning-job") get_operator_logs.add_command(js_get_operator_logs) get_operator_logs.add_command(custom_get_operator_logs) invoke.add_command(custom_invoke) -invoke.add_command(custom_invoke, name="hyp-jumpstart-endpoint") +jumpstart_invoke_cmd = copy.copy(custom_invoke) +jumpstart_invoke_cmd.help = "Invoke a jumpstart model endpoint." +invoke.add_command(jumpstart_invoke_cmd, name="hyp-jumpstart-endpoint") cli.add_command(list_cluster) cli.add_command(set_cluster_context) diff --git a/src/sagemaker/hyperpod/cli/init_utils.py b/src/sagemaker/hyperpod/cli/init_utils.py index f36837d7..9ef2338e 100644 --- a/src/sagemaker/hyperpod/cli/init_utils.py +++ b/src/sagemaker/hyperpod/cli/init_utils.py @@ -8,7 +8,7 @@ import yaml import sys from pathlib import Path -from sagemaker.hyperpod.cli.type_handler_utils import convert_cli_value, to_click_type, is_complex_type, DEFAULT_TYPE_HANDLER +from sagemaker.hyperpod.cli.type_handler_utils import convert_cli_value, to_click_type, is_complex_type, DEFAULT_TYPE_HANDLER, create_click_option, is_undefined_value from pydantic import ValidationError from typing import List, Any from sagemaker.hyperpod.cli.constants.init_constants import ( @@ -187,6 +187,7 @@ def generate_click_command() -> Callable: """ Decorator that: - injects -- for every property in the current template's schema (detected from config.yaml) + - supports both standard templates (Pydantic) and dynamic templates (.override_spec.json) - only works for configure command, returns minimal decorator for others """ @@ -204,7 +205,101 @@ def decorator(func: Callable) -> Callable: click.secho("āŒ No config.yaml found. Run 'hyp init