1414# limitations under the License.
1515from __future__ import annotations
1616
17- import os
1817from typing import (
1918 TYPE_CHECKING ,
2019 Any ,
2625 cast ,
2726)
2827
29- from pydantic import (
30- BaseModel ,
31- ConfigDict ,
32- Field ,
28+ from pydantic import Field
29+
30+ from datacustomcode .common_config import (
31+ BaseConfig ,
32+ BaseObjectConfig ,
33+ ForceableConfig ,
34+ default_config_file ,
3335)
34- import yaml
3536
3637# This lets all readers and writers to be findable via config
3738from datacustomcode .io import * # noqa: F403
4243from datacustomcode .proxy .client .base import BaseProxyClient # noqa: TCH002
4344from datacustomcode .spark .base import BaseSparkSessionProvider
4445
45- DEFAULT_CONFIG_NAME = "config.yaml"
46-
47-
4846if TYPE_CHECKING :
4947 from pyspark .sql import SparkSession
5048
5149
52- class ForceableConfig (BaseModel ):
53- force : bool = Field (
54- default = False ,
55- description = "If True, this takes precedence over parameters passed to the "
56- "initializer of the client." ,
57- )
58-
59-
6050_T = TypeVar ("_T" , bound = "BaseDataAccessLayer" )
6151
6252
63- class AccessLayerObjectConfig (ForceableConfig , Generic [_T ]):
64- model_config = ConfigDict (validate_default = True , extra = "forbid" )
53+ class AccessLayerObjectConfig (BaseObjectConfig , Generic [_T ]):
6554 type_base : ClassVar [Type [BaseDataAccessLayer ]] = BaseDataAccessLayer
66- type_config_name : str = Field (
67- description = "The config name of the object to create. "
68- "For metrics, this would might be 'ipmnormal'. For custom classes, you can "
69- "assign a name to a class variable `CONFIG_NAME` and reference it here." ,
70- )
71- options : dict [str , Any ] = Field (
72- default_factory = dict ,
73- description = "Options passed to the constructor." ,
74- )
7555
7656 def to_object (self , spark : SparkSession ) -> _T :
7757 type_ = self .type_base .subclass_from_config_name (self .type_config_name )
@@ -97,35 +77,25 @@ class SparkConfig(ForceableConfig):
9777_PX = TypeVar ("_PX" , bound = BaseProxyAccessLayer )
9878
9979
100- class ProxyAccessLayerObjectConfig (ForceableConfig , Generic [_PX ]):
80+ class ProxyAccessLayerObjectConfig (BaseObjectConfig , Generic [_PX ]):
10181 """Config for proxy clients that take no constructor args (e.g. no spark)."""
10282
103- model_config = ConfigDict (validate_default = True , extra = "forbid" )
10483 type_base : ClassVar [Type [BaseProxyAccessLayer ]] = BaseProxyAccessLayer
105- type_config_name : str = Field (
106- description = "CONFIG_NAME of the proxy client (e.g. 'LocalProxyClient')." ,
107- )
108- options : dict [str , Any ] = Field (default_factory = dict )
10984
11085 def to_object (self ) -> _PX :
11186 type_ = self .type_base .subclass_from_config_name (self .type_config_name )
11287 return cast (_PX , type_ (** self .options ))
11388
11489
115- class SparkProviderConfig (ForceableConfig , Generic [_P ]):
116- model_config = ConfigDict (validate_default = True , extra = "forbid" )
90+ class SparkProviderConfig (BaseObjectConfig , Generic [_P ]):
11791 type_base : ClassVar [Type [BaseSparkSessionProvider ]] = BaseSparkSessionProvider
118- type_config_name : str = Field (
119- description = "CONFIG_NAME of the Spark session provider."
120- )
121- options : dict [str , Any ] = Field (default_factory = dict )
12292
12393 def to_object (self ) -> _P :
12494 type_ = self .type_base .subclass_from_config_name (self .type_config_name )
12595 return cast (_P , type_ (** self .options ))
12696
12797
128- class ClientConfig (BaseModel ):
98+ class ClientConfig (BaseConfig ):
12999 reader_config : Union [AccessLayerObjectConfig [BaseDataCloudReader ], None ] = None
130100 writer_config : Union [AccessLayerObjectConfig [BaseDataCloudWriter ], None ] = None
131101 proxy_config : Union [ProxyAccessLayerObjectConfig [BaseProxyClient ], None ] = None
@@ -163,31 +133,10 @@ def merge(
163133 )
164134 return self
165135
166- def load (self , config_path : str ) -> ClientConfig :
167- """Load a config from a file and update this config with it.
168136
169- Args:
170- config_path: The path to the config file
171-
172- Returns:
173- Self, with updated values from the loaded config.
174- """
175- with open (config_path , "r" ) as f :
176- config_data = yaml .safe_load (f )
177- loaded_config = ClientConfig .model_validate (config_data )
178-
179- return self .update (loaded_config )
180-
181-
182- config = ClientConfig ()
183137"""Global config object.
184138
185139This is the object that makes config accessible globally and globally mutable.
186140"""
187-
188-
189- def _defaults () -> str :
190- return os .path .join (os .path .dirname (__file__ ), DEFAULT_CONFIG_NAME )
191-
192-
193- config .load (_defaults ())
141+ config = ClientConfig ()
142+ config .load (default_config_file ())
0 commit comments