Skip to content

Commit cf86d27

Browse files
Merge pull request #89 from forcedotcom/predict
add support for einstein predict
2 parents 4b5852e + 643586a commit cf86d27

20 files changed

Lines changed: 1036 additions & 72 deletions
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from abc import ABC, abstractmethod
16+
import os
17+
from typing import Any
18+
19+
from pydantic import (
20+
BaseModel,
21+
ConfigDict,
22+
Field,
23+
)
24+
import yaml
25+
26+
DEFAULT_CONFIG_NAME = "config.yaml"
27+
28+
29+
def default_config_file() -> str:
30+
return os.path.join(os.path.dirname(__file__), DEFAULT_CONFIG_NAME)
31+
32+
33+
class ForceableConfig(BaseModel):
34+
force: bool = Field(
35+
default=False,
36+
description="If True, this takes precedence over parameters passed to the "
37+
"initializer of the client",
38+
)
39+
40+
41+
class BaseObjectConfig(ForceableConfig):
42+
model_config = ConfigDict(validate_default=True, extra="forbid")
43+
type_config_name: str = Field(
44+
description="The config name of the object to create",
45+
)
46+
options: dict[str, Any] = Field(
47+
default_factory=dict,
48+
description="Options passed to the constructor.",
49+
)
50+
51+
52+
class BaseConfig(ABC, BaseModel):
53+
@abstractmethod
54+
def update(self, other: Any) -> "BaseConfig": ...
55+
56+
def load(self, config_path: str) -> "BaseConfig":
57+
"""Load configuration from a YAML file and merge with existing config"""
58+
with open(config_path, "r") as f:
59+
config_data = yaml.safe_load(f)
60+
61+
loaded_config = self.__class__.model_validate(config_data)
62+
self.update(loaded_config)
63+
return self

src/datacustomcode/config.py

Lines changed: 13 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
import os
1817
from typing import (
1918
TYPE_CHECKING,
2019
Any,
@@ -26,12 +25,14 @@
2625
cast,
2726
)
2827

29-
from pydantic import (
30-
BaseModel,
31-
ConfigDict,
32-
Field,
28+
from pydantic import Field
29+
30+
from datacustomcode.common_config import (
31+
BaseConfig,
32+
BaseObjectConfig,
33+
ForceableConfig,
34+
default_config_file,
3335
)
34-
import yaml
3536

3637
# This lets all readers and writers to be findable via config
3738
from datacustomcode.io import * # noqa: F403
@@ -42,36 +43,15 @@
4243
from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH002
4344
from datacustomcode.spark.base import BaseSparkSessionProvider
4445

45-
DEFAULT_CONFIG_NAME = "config.yaml"
46-
47-
4846
if TYPE_CHECKING:
4947
from pyspark.sql import SparkSession
5048

5149

52-
class ForceableConfig(BaseModel):
53-
force: bool = Field(
54-
default=False,
55-
description="If True, this takes precedence over parameters passed to the "
56-
"initializer of the client.",
57-
)
58-
59-
6050
_T = TypeVar("_T", bound="BaseDataAccessLayer")
6151

6252

63-
class AccessLayerObjectConfig(ForceableConfig, Generic[_T]):
64-
model_config = ConfigDict(validate_default=True, extra="forbid")
53+
class AccessLayerObjectConfig(BaseObjectConfig, Generic[_T]):
6554
type_base: ClassVar[Type[BaseDataAccessLayer]] = BaseDataAccessLayer
66-
type_config_name: str = Field(
67-
description="The config name of the object to create. "
68-
"For metrics, this would might be 'ipmnormal'. For custom classes, you can "
69-
"assign a name to a class variable `CONFIG_NAME` and reference it here.",
70-
)
71-
options: dict[str, Any] = Field(
72-
default_factory=dict,
73-
description="Options passed to the constructor.",
74-
)
7555

7656
def to_object(self, spark: SparkSession) -> _T:
7757
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
@@ -97,35 +77,25 @@ class SparkConfig(ForceableConfig):
9777
_PX = TypeVar("_PX", bound=BaseProxyAccessLayer)
9878

9979

100-
class ProxyAccessLayerObjectConfig(ForceableConfig, Generic[_PX]):
80+
class ProxyAccessLayerObjectConfig(BaseObjectConfig, Generic[_PX]):
10181
"""Config for proxy clients that take no constructor args (e.g. no spark)."""
10282

103-
model_config = ConfigDict(validate_default=True, extra="forbid")
10483
type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer
105-
type_config_name: str = Field(
106-
description="CONFIG_NAME of the proxy client (e.g. 'LocalProxyClient').",
107-
)
108-
options: dict[str, Any] = Field(default_factory=dict)
10984

11085
def to_object(self) -> _PX:
11186
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
11287
return cast(_PX, type_(**self.options))
11388

11489

115-
class SparkProviderConfig(ForceableConfig, Generic[_P]):
116-
model_config = ConfigDict(validate_default=True, extra="forbid")
90+
class SparkProviderConfig(BaseObjectConfig, Generic[_P]):
11791
type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider
118-
type_config_name: str = Field(
119-
description="CONFIG_NAME of the Spark session provider."
120-
)
121-
options: dict[str, Any] = Field(default_factory=dict)
12292

12393
def to_object(self) -> _P:
12494
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
12595
return cast(_P, type_(**self.options))
12696

12797

128-
class ClientConfig(BaseModel):
98+
class ClientConfig(BaseConfig):
12999
reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
130100
writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
131101
proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None
@@ -163,31 +133,10 @@ def merge(
163133
)
164134
return self
165135

166-
def load(self, config_path: str) -> ClientConfig:
167-
"""Load a config from a file and update this config with it.
168136

169-
Args:
170-
config_path: The path to the config file
171-
172-
Returns:
173-
Self, with updated values from the loaded config.
174-
"""
175-
with open(config_path, "r") as f:
176-
config_data = yaml.safe_load(f)
177-
loaded_config = ClientConfig.model_validate(config_data)
178-
179-
return self.update(loaded_config)
180-
181-
182-
config = ClientConfig()
183137
"""Global config object.
184138
185139
This is the object that makes config accessible globally and globally mutable.
186140
"""
187-
188-
189-
def _defaults() -> str:
190-
return os.path.join(os.path.dirname(__file__), DEFAULT_CONFIG_NAME)
191-
192-
193-
config.load(_defaults())
141+
config = ClientConfig()
142+
config.load(default_config_file())

src/datacustomcode/config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,7 @@ proxy_config:
2323
type_config_name: LocalProxyClientProvider
2424
options:
2525
credentials_profile: default
26+
27+
einstein_predictions_config:
28+
type_config_name: DefaultEinsteinPredictions
29+
options: {}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from datacustomcode.einstein_predictions.base import EinsteinPredictions
17+
from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions
18+
19+
__all__ = [
20+
"EinsteinPredictions",
21+
"DefaultEinsteinPredictions",
22+
]
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from abc import ABC, abstractmethod
17+
18+
from datacustomcode.einstein_predictions.types import (
19+
PredictionRequest,
20+
PredictionResponse,
21+
)
22+
from datacustomcode.mixin import UserExtendableNamedConfigMixin
23+
24+
25+
class EinsteinPredictions(ABC, UserExtendableNamedConfigMixin):
26+
CONFIG_NAME: str
27+
28+
def __init__(self, **kwargs):
29+
pass
30+
31+
@abstractmethod
32+
def predict(self, request: PredictionRequest) -> PredictionResponse: ...
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) 2025, Salesforce, Inc.
2+
# SPDX-License-Identifier: Apache-2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from datacustomcode.einstein_predictions.base import EinsteinPredictions
17+
from datacustomcode.einstein_predictions.types import (
18+
PredictionRequest,
19+
PredictionResponse,
20+
)
21+
22+
23+
class DefaultEinsteinPredictions(EinsteinPredictions):
24+
CONFIG_NAME = "DefaultEinsteinPredictions"
25+
26+
def __init__(self, **kwargs):
27+
super().__init__(**kwargs)
28+
29+
def predict(self, request: PredictionRequest) -> PredictionResponse:
30+
return PredictionResponse(
31+
version="v1",
32+
prediction_type=request.prediction_type,
33+
status_code=200,
34+
data={"results": [{"prediction": {"predictedValue": 1.0}}]},
35+
)

0 commit comments

Comments
 (0)