diff --git a/examples/thermal_square/fourier_control_dict_N_21.pkl b/examples/thermal_square/fourier_control_dict_N_21.pkl new file mode 100644 index 0000000..369b9e5 Binary files /dev/null and b/examples/thermal_square/fourier_control_dict_N_21.pkl differ diff --git a/examples/thermal_square/fourier_control_dict_N_51.pkl b/examples/thermal_square/fourier_control_dict_N_51.pkl new file mode 100644 index 0000000..bb0cf44 Binary files /dev/null and b/examples/thermal_square/fourier_control_dict_N_51.pkl differ diff --git a/examples/thermal_square/sample_1367_plot_mesh_vec_data.pdf b/examples/thermal_square/sample_1367_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..3793a7d Binary files /dev/null and b/examples/thermal_square/sample_1367_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_square/sample_1641_plot_mesh_vec_data.pdf b/examples/thermal_square/sample_1641_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..eaf0f55 Binary files /dev/null and b/examples/thermal_square/sample_1641_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_square/sample_1893_plot_mesh_vec_data.pdf b/examples/thermal_square/sample_1893_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..92b198b Binary files /dev/null and b/examples/thermal_square/sample_1893_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_square/sample_1989_plot_mesh_vec_data.pdf b/examples/thermal_square/sample_1989_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..0e5bf2a Binary files /dev/null and b/examples/thermal_square/sample_1989_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_square/sample_1_plot_mesh_vec_data.pdf b/examples/thermal_square/sample_1_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..4c513df Binary files /dev/null and b/examples/thermal_square/sample_1_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_square/sample_289_plot_mesh_vec_data.pdf b/examples/thermal_square/sample_289_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..6c0a609 Binary files /dev/null and b/examples/thermal_square/sample_289_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_square/sample_367_plot_mesh_vec_data.pdf b/examples/thermal_square/sample_367_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..257ac66 Binary files /dev/null and b/examples/thermal_square/sample_367_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_square/sample_689_plot_mesh_vec_data.pdf b/examples/thermal_square/sample_689_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..cfcef76 Binary files /dev/null and b/examples/thermal_square/sample_689_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_square/sample_990_plot_mesh_vec_data.pdf b/examples/thermal_square/sample_990_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..e618617 Binary files /dev/null and b/examples/thermal_square/sample_990_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_square/thermal_2D_heat_source.py b/examples/thermal_square/thermal_2D_heat_source.py new file mode 100644 index 0000000..acd506b --- /dev/null +++ b/examples/thermal_square/thermal_2D_heat_source.py @@ -0,0 +1,151 @@ +import sys +import os + +import numpy as np +from fol.computational_models.fe_model import FiniteElementModel +from fol.loss_functions.thermal_2D_fe_quad import ThermalLoss2D +from fol.solvers.fe_solver import FiniteElementSolver +from fol.solvers.nonlinear_solver import NonLinearSolver +from fol.controls.fourier_control import FourierControl +from fol.deep_neural_networks.fe_operator_learning import FiniteElementOperatorLearning +from fol.tools.usefull_functions import * +from fol.tools.logging_functions import Logger +import pickle, time + +def main(fol_num_epochs=10,solve_FE=False,clean_dir=False): + # directory & save handling + working_directory_name = 'thermal_2D_heat_source' + case_dir = os.path.join('.', working_directory_name) + create_clean_directory(working_directory_name) + sys.stdout = Logger(os.path.join(case_dir,working_directory_name+".log")) + + # problem setup + model_settings = {"L":1, + "N":51, + "T_left":1.0,"T_bottom":1.0,"T_right":1.0,"T_top":1.0} + + # model_settings = {"L":1, + # "N":21, + # "T_left":1,"T_right":0.1} + + # creation of the model + model_info = create_2D_square_model_info_thermal_dirichlet(**model_settings) + # model_info = create_2D_square_model_info_thermal(**model_settings) + + # creation of the objects + fe_model = FiniteElementModel("FE_model",model_info) + thermal_loss_2d = ThermalLoss2D("thermal_loss_2d",fe_model,{"num_gp":2}) + + # fourier control + fourier_control_settings = {"x_freqs":np.array([1,2,3]),"y_freqs":np.array([1,2,3]),"z_freqs":np.array([0]), + "beta":10,"min":1e-1,"max":1} + fourier_control = FourierControl("fourier_control",fourier_control_settings,fe_model) + + # create some random coefficients & K for training + create_random_coefficients = False + if create_random_coefficients: + number_of_random_samples = 2000 + coeffs_matrix,K_matrix = create_random_fourier_samples(fourier_control,number_of_random_samples) + export_dict = model_settings.copy() + export_dict["coeffs_matrix"] = coeffs_matrix + export_dict["x_freqs"] = fourier_control.x_freqs + export_dict["y_freqs"] = fourier_control.y_freqs + export_dict["z_freqs"] = fourier_control.z_freqs + with open(f'fourier_control_dict_N_{model_settings["N"]}.pkl', 'wb') as f: + pickle.dump(export_dict,f) + else: + with open(f'fourier_control_dict_N_{model_settings["N"]}.pkl', 'rb') as f: + loaded_dict = pickle.load(f) + + coeffs_matrix = loaded_dict["coeffs_matrix"] + # coeffs_/matrix = np.loadtxt('coeffs_matrix.txt') + + K_matrix = fourier_control.ComputeBatchControlledVariables(coeffs_matrix) + + # specify id of the K of interest + eval_id = 1 + eval_id2 = 289 + eval_id3 = 990 + eval_id4 = 689 + eval_id5 = 367 + eval_id6 = 1989 + eval_id7 = 1367 + eval_id8 = 1641 + eval_id9 = 1893 + train_id = 1000 + + # now we need to create, initialize and train fol + fol = FiniteElementOperatorLearning("first_fol",fourier_control,[thermal_loss_2d],[100,100], + "swish",load_NN_params=False,working_directory=working_directory_name) + fol.Initialize() + + start_time = time.process_time() + fol.Train(loss_functions_weights=[1],X_train=coeffs_matrix[:train_id,:],batch_size=10,num_epochs=fol_num_epochs, + learning_rate=0.001,optimizer="adam",convergence_criterion="total_loss",relative_error=1e-10,absolute_error=1e-10, + plot_list=["avg_res","max_res","total_loss"],plot_rate=1,NN_params_save_file_name="NN_params_"+working_directory_name) + + FOL_T = np.array(fol.Predict(coeffs_matrix[eval_id,:].reshape(-1,1).T)) + + # solve FE here + if solve_FE: + + first_fe_solver = FiniteElementSolver("first_fe_solver", thermal_loss_2d) + start_time = time.process_time() + FE_T = np.array(first_fe_solver.SingleSolve(K_matrix[eval_id],np.zeros(fe_model.GetNumberOfNodes()))) + print(f"\n############### FE solve took: {time.process_time() - start_time} s ###############\n") + + relative_error = abs(FOL_T.reshape(-1,1)- FE_T.reshape(-1,1)) + plot_mesh_vec_data_paper_temp([K_matrix[eval_id,:], FOL_T, FE_T],['Heat Source', '$T$, FOL', '$T$, FEM'],f'sample_{eval_id}') + + eval_list = [eval_id2,eval_id3,eval_id4,eval_id5,eval_id6,eval_id7,eval_id8,eval_id9] + FOL_T = np.zeros((K_matrix[eval_id,:].reshape(-1,1).T).shape) + for i,eval_id in enumerate(eval_list): + FOL_T = np.array(fol.Predict(coeffs_matrix[eval_id].reshape(-1,1).T)) + # FOL_T = np.array(fol.Predict(coeffs_matrix[eval_id,:])) + print(f'eval coeffs: {coeffs_matrix[eval_id,:]}') + # print(f"predicted array: {FOL_T}") + start_time = time.process_time() + FE_T = np.array(first_fe_solver.SingleSolve(K_matrix[eval_id],np.zeros(fe_model.GetNumberOfNodes()))) + print(f"\n############### FE solve took: {time.process_time() - start_time} s ###############\n") + plot_mesh_vec_data_paper_temp([K_matrix[eval_id,:], FOL_T, FE_T],['Heat Source', '$T$, FOL', '$T$, FEM'],f'sample_{eval_id}') + + if clean_dir: + shutil.rmtree(case_dir) + +if __name__ == "__main__": + # Initialize default values + fol_num_epochs = 2000 + solve_FE = True + clean_dir = False + + # Parse the command-line arguments + args = sys.argv[1:] + + # Process the arguments if provided + for arg in args: + if arg.startswith("fol_num_epochs="): + try: + fol_num_epochs = int(arg.split("=")[1]) + except ValueError: + print("fol_num_epochs should be an integer.") + sys.exit(1) + elif arg.startswith("solve_FE="): + value = arg.split("=")[1] + if value.lower() in ['true', 'false']: + solve_FE = value.lower() == 'true' + else: + print("solve_FE should be True or False.") + sys.exit(1) + elif arg.startswith("clean_dir="): + value = arg.split("=")[1] + if value.lower() in ['true', 'false']: + clean_dir = value.lower() == 'true' + else: + print("clean_dir should be True or False.") + sys.exit(1) + else: + print("Usage: python mechanical_2D.py fol_num_epochs=10 solve_FE=False clean_dir=False") + sys.exit(1) + + # Call the main function with the parsed values + main(fol_num_epochs, solve_FE,clean_dir) diff --git a/examples/thermal_transient_square/sample_10_plot_mesh_vec_data.pdf b/examples/thermal_transient_square/sample_10_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..aad64ac Binary files /dev/null and b/examples/thermal_transient_square/sample_10_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_transient_square/sample_1_plot_mesh_vec_data.pdf b/examples/thermal_transient_square/sample_1_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..5e554c7 Binary files /dev/null and b/examples/thermal_transient_square/sample_1_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_transient_square/sample_289_plot_mesh_vec_data.pdf b/examples/thermal_transient_square/sample_289_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..5b4141e Binary files /dev/null and b/examples/thermal_transient_square/sample_289_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_transient_square/sample_2_plot_mesh_vec_data.pdf b/examples/thermal_transient_square/sample_2_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..39eecca Binary files /dev/null and b/examples/thermal_transient_square/sample_2_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_transient_square/sample_367_plot_mesh_vec_data.pdf b/examples/thermal_transient_square/sample_367_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..a9896d9 Binary files /dev/null and b/examples/thermal_transient_square/sample_367_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_transient_square/sample_5_plot_mesh_vec_data.pdf b/examples/thermal_transient_square/sample_5_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..e9efc20 Binary files /dev/null and b/examples/thermal_transient_square/sample_5_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_transient_square/sample_689_plot_mesh_vec_data.pdf b/examples/thermal_transient_square/sample_689_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..70910e0 Binary files /dev/null and b/examples/thermal_transient_square/sample_689_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_transient_square/sample_990_plot_mesh_vec_data.pdf b/examples/thermal_transient_square/sample_990_plot_mesh_vec_data.pdf new file mode 100644 index 0000000..13e9fbb Binary files /dev/null and b/examples/thermal_transient_square/sample_990_plot_mesh_vec_data.pdf differ diff --git a/examples/thermal_transient_square/thermal_2D_heat_transient.py b/examples/thermal_transient_square/thermal_2D_heat_transient.py new file mode 100644 index 0000000..bc5ae4d --- /dev/null +++ b/examples/thermal_transient_square/thermal_2D_heat_transient.py @@ -0,0 +1,172 @@ +import sys +import os + +import numpy as np +from fol.computational_models.fe_model import FiniteElementModel +from fol.loss_functions.thermal_transient_2D_fe_quad import ThermalLoss2D +from fol.solvers.fe_solver import FiniteElementSolver +from fol.solvers.nonlinear_solver import NonLinearSolver +# from fol.controls.fourier_control import FourierControl +from fol.controls.no_control import NoControl +from fol.deep_neural_networks.fe_operator_learning import FiniteElementOperatorLearning +from fol.tools.usefull_functions import * +from fol.tools.logging_functions import Logger +import pickle, time + +def main(fol_num_epochs=10,solve_FE=False,clean_dir=False): + # directory & save handling + working_directory_name = 'thermal_2D_heat_transient' + case_dir = os.path.join('.', working_directory_name) + create_clean_directory(working_directory_name) + sys.stdout = Logger(os.path.join(case_dir,working_directory_name+".log")) + + # problem setup + model_settings = {"L":1, + "N":21, + "T_left":1.0,"T_bottom":1.0,"T_right":0.0,"T_top":0.0} + + # model_settings = {"L":1, + # "N":21, + # "T_left":1,"T_right":0.1} + + # creation of the model + model_info = create_2D_square_model_info_thermal_dirichlet(**model_settings) + # model_info = create_2D_square_model_info_thermal(**model_settings) + + # creation of the objects + fe_model = FiniteElementModel("FE_model",model_info) + thermal_loss_2d = ThermalLoss2D("thermal_loss_2d",fe_model,{"num_gp":2, "rho":1.0, "cp":10.0, "dt":0.05}) + + # fourier control + # fourier_control_settings = {"x_freqs":np.array([1,2,3]),"y_freqs":np.array([1,2,3]),"z_freqs":np.array([0]), + # "beta":10,"min":1e-1,"max":1} + # fourier_control = Control("fourier_control",fourier_control_settings,fe_model) + no_control = NoControl("no_control",fe_model) + + # create some random coefficients & K for training + create_random_coefficients = False + if create_random_coefficients: + pass + # number_of_random_samples = 2000 + # coeffs_matrix,K_matrix = create_random_fourier_samples(fourier_control,number_of_random_samples) + # export_dict = model_settings.copy() + # export_dict["coeffs_matrix"] = coeffs_matrix + # export_dict["x_freqs"] = fourier_control.x_freqs + # export_dict["y_freqs"] = fourier_control.y_freqs + # export_dict["z_freqs"] = fourier_control.z_freqs + # with open(f'fourier_control_dict_N_{model_settings["N"]}.pkl', 'wb') as f: + # pickle.dump(export_dict,f) + else: + # with open(f'fourier_control_dict_N_{model_settings["N"]}.pkl', 'rb') as f: + # loaded_dict = pickle.load(f) + # coeffs_matrix = loaded_dict["coeffs_matrix"] + coeffs_matrix = np.load('gaussian_kernel_50000_N21.npy') + + Ts_c = no_control.ComputeBatchControlledVariables(coeffs_matrix) + + # specify id of the K of interest + eval_id = 1 + eval_id2 = 289 + eval_id3 = 990 + eval_id4 = 689 + eval_id5 = 367 + eval_id6 = 1989 + eval_id7 = 1367 + eval_id8 = 1641 + eval_id9 = 1893 + train_id = 10000 + + # now we need to create, initialize and train fol + fol = FiniteElementOperatorLearning("first_fol",no_control,[thermal_loss_2d],[1000,1000], + "swish",load_NN_params=False,working_directory=working_directory_name) + fol.Initialize() + + start_time = time.process_time() + fol.Train(loss_functions_weights=[1],X_train=coeffs_matrix[:train_id,:],batch_size=100,num_epochs=fol_num_epochs, + learning_rate=0.001,optimizer="adam",convergence_criterion="total_loss",relative_error=1e-10,absolute_error=1e-10, + plot_list=["avg_res","max_res","total_loss"],plot_rate=1,NN_params_save_file_name="NN_params_"+working_directory_name) + + num_steps = 10 + FOL_T_list = [] + T_c = coeffs_matrix[eval_id,:] + FOL_T_list.append(T_c) + for i in range(num_steps): + FOL_T = np.array(fol.Predict(T_c.reshape(-1,1).T)) + FOL_T_list.append(FOL_T) + T_c = FOL_T + + # solve FE here + if solve_FE: + first_fe_solver = FiniteElementSolver("first_fe_solver", thermal_loss_2d) + start_time = time.process_time() + T_c = Ts_c[eval_id] + FE_T_list = [] + FE_T_list.append(T_c) + for i in range(num_steps): + FE_T = np.array(first_fe_solver.SingleSolve(T_c,np.zeros(fe_model.GetNumberOfNodes()))) + FE_T_list.append(FE_T) + T_c = FE_T + print(f"\n############### FE solve took: {time.process_time() - start_time} s ###############\n") + + FOL_T_list = np.array(FOL_T_list) + FE_T_list = np.array(FE_T_list) + + relative_error = abs(FOL_T.reshape(-1,1)- FE_T.reshape(-1,1)) + time_steps = [1,2,5,10] + plot_mesh_vec_data_paper_temp([Ts_c[eval_id,:], FOL_T_list[time_steps[0]], FE_T_list[time_steps[0]]],['Heat Source', '$T$, FOL', '$T$, FEM'],f'sample_{time_steps[0]}') + plot_mesh_vec_data_paper_temp([Ts_c[eval_id,:], FOL_T_list[time_steps[1]], FE_T_list[time_steps[1]]],['Heat Source', '$T$, FOL', '$T$, FEM'],f'sample_{time_steps[1]}') + plot_mesh_vec_data_paper_temp([Ts_c[eval_id,:], FOL_T_list[time_steps[2]], FE_T_list[time_steps[2]]],['Heat Source', '$T$, FOL', '$T$, FEM'],f'sample_{time_steps[2]}') + plot_mesh_vec_data_paper_temp([Ts_c[eval_id,:], FOL_T_list[time_steps[3]], FE_T_list[time_steps[3]]],['Heat Source', '$T$, FOL', '$T$, FEM'],f'sample_{time_steps[3]}') + + # eval_list = [eval_id2,eval_id3,eval_id4,eval_id5]#,eval_id6,eval_id7,eval_id8,eval_id9] + # FOL_T = np.zeros((Ts_c[eval_id,:].reshape(-1,1).T).shape) + # for i,eval_id in enumerate(eval_list): + # FOL_T = np.array(fol.Predict(coeffs_matrix[eval_id].reshape(-1,1).T)) + # # FOL_T = np.array(fol.Predict(coeffs_matrix[eval_id,:])) + # # print(f'eval coeffs: {coeffs_matrix[eval_id,:]}') + # # print(f"predicted array: {FOL_T}") + # start_time = time.process_time() + # FE_T = np.array(first_fe_solver.SingleSolve(Ts_c[eval_id],np.zeros(fe_model.GetNumberOfNodes()))) + # print(f"\n############### FE solve took: {time.process_time() - start_time} s ###############\n") + # plot_mesh_vec_data_paper_temp([Ts_c[eval_id,:], FOL_T, FE_T],['Heat Source', '$T$, FOL', '$T$, FEM'],f'sample_{eval_id}') + + if clean_dir: + shutil.rmtree(case_dir) + +if __name__ == "__main__": + # Initialize default values + fol_num_epochs = 1000 + solve_FE = True + clean_dir = False + + # Parse the command-line arguments + args = sys.argv[1:] + + # Process the arguments if provided + for arg in args: + if arg.startswith("fol_num_epochs="): + try: + fol_num_epochs = int(arg.split("=")[1]) + except ValueError: + print("fol_num_epochs should be an integer.") + sys.exit(1) + elif arg.startswith("solve_FE="): + value = arg.split("=")[1] + if value.lower() in ['true', 'false']: + solve_FE = value.lower() == 'true' + else: + print("solve_FE should be True or False.") + sys.exit(1) + elif arg.startswith("clean_dir="): + value = arg.split("=")[1] + if value.lower() in ['true', 'false']: + clean_dir = value.lower() == 'true' + else: + print("clean_dir should be True or False.") + sys.exit(1) + else: + print("Usage: python mechanical_2D.py fol_num_epochs=10 solve_FE=False clean_dir=False") + sys.exit(1) + + # Call the main function with the parsed values + main(fol_num_epochs, solve_FE,clean_dir) diff --git a/fol/controls/no_control.py b/fol/controls/no_control.py new file mode 100644 index 0000000..2224c31 --- /dev/null +++ b/fol/controls/no_control.py @@ -0,0 +1,39 @@ +""" + Authors: Reza Najian Asl, https://github.com/RezaNajian + Date: April, 2024 + License: FOL/License.txt +""" +from .control import Control +import jax.numpy as jnp +from jax import jit,jacfwd +from functools import partial +from jax.nn import sigmoid +from fol.tools.decoration_functions import * + +class NoControl(Control): + @print_with_timestamp_and_execution_time + def __init__(self,control_name: str,fe_model): + super().__init__(control_name) + self.fe_model = fe_model + self.num_control_vars = self.fe_model.GetNumberOfNodes() + self.num_controlled_vars = self.fe_model.GetNumberOfNodes() + + def GetNumberOfVariables(self): + return self.num_control_vars + + def GetNumberOfControlledVariables(self): + return self.num_controlled_vars + + def Initialize(self) -> None: + pass + + def Finalize(self) -> None: + pass + + @partial(jit, static_argnums=(0,)) + def ComputeControlledVariables(self,variable_vector:jnp.array): + return variable_vector + + @partial(jit, static_argnums=(0,)) + def ComputeJacobian(self,control_vec): + pass \ No newline at end of file diff --git a/fol/deep_neural_networks/fe_operator_learning.py b/fol/deep_neural_networks/fe_operator_learning.py index d6b6d1c..5a566a5 100644 --- a/fol/deep_neural_networks/fe_operator_learning.py +++ b/fol/deep_neural_networks/fe_operator_learning.py @@ -46,8 +46,8 @@ def InitializeParameters(self): for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]): key_w, rng_key = random.split(rng_key) limit = jnp.sqrt(6 / (n_in + n_out)) - # weights = random.uniform(key_w, (n_in, n_out), minval=-limit, maxval=limit) - weights = jnp.zeros((n_in, n_out)) + weights = random.uniform(key_w, (n_in, n_out), minval=-limit, maxval=limit) + # weights = jnp.zeros((n_in, n_out)) biases = jnp.zeros(n_out) self.NN_params.append((weights, biases)) super().InitializeParameters() diff --git a/fol/loss_functions/thermal_transient_2D_fe_quad.py b/fol/loss_functions/thermal_transient_2D_fe_quad.py new file mode 100644 index 0000000..299ce17 --- /dev/null +++ b/fol/loss_functions/thermal_transient_2D_fe_quad.py @@ -0,0 +1,122 @@ +""" + Authors: Yusuke Yamazaki + Date: September, 2024 + License: FOL/License.txt +""" +import jax +import jax.numpy as jnp +from jax import jit +from functools import partial +from .fe_loss import FiniteElementLoss +from fol.tools.fem_utilities import * +from fol.computational_models.fe_model import FiniteElementModel + +class ThermalLoss2D(FiniteElementLoss): + """FE-based 2D Thermal loss + + This is the base class for the loss functions require FE formulation. + + """ + def __init__(self, name: str, fe_model: FiniteElementModel, loss_settings: dict={}): + super().__init__(name,fe_model,["T"],{**loss_settings,"compute_dims":2,"rho":1.0,"cp":10.0, "dt":0.05}) + self.shape_function = QuadShapeFunction() + self.rho = loss_settings["rho"] + self.cp = loss_settings["rho"] + self.dt = loss_settings["dt"] + # self.Ke = loss_settings["Ke"] + + + @partial(jit, static_argnums=(0,)) + def ComputeElement(self,xyze,Te_c,Te_n,body_force): + xye = jnp.array([xyze[::3], xyze[1::3]]) + Te_c = Te_c.reshape(-1,1) + Te_n = Te_n.reshape(-1,1) + @jit + def compute_at_gauss_point(xi,eta,total_weight): + Nf = self.shape_function.evaluate(xi,eta) + # conductivity_at_gauss = jnp.dot(Nf, Ke.squeeze()) + dN_dxi = self.shape_function.derivatives(xi,eta) + J = jnp.dot(dN_dxi.T, xye.T) + detJ = jnp.linalg.det(J) + invJ = jnp.linalg.inv(J) + B = jnp.dot(invJ,dN_dxi.T) + T_at_gauss_n = jnp.dot(Nf, Te_n) + T_at_gauss_c = jnp.dot(Nf, Te_c) + gp_stiffness = jnp.dot(B.T, B) * detJ * total_weight #* conductivity_at_gauss + gp_mass =jnp.outer(Nf, Nf) * detJ * total_weight + gp_f = total_weight * detJ * body_force * Nf.reshape(-1,1) + gp_t = total_weight * detJ *(T_at_gauss_n-T_at_gauss_c)**2 + return gp_stiffness,gp_mass, gp_f, gp_t + @jit + def vmap_compatible_compute_at_gauss_point(gp_index): + return compute_at_gauss_point(self.g_points[self.dim*gp_index], + self.g_points[self.dim*gp_index+1], + self.g_weights[self.dim*gp_index] * self.g_weights[self.dim*gp_index+1]) + + k_gps,m_gps,f_gps,t_gps = jax.vmap(vmap_compatible_compute_at_gauss_point,(0))(jnp.arange(self.num_gp**self.dim)) + Se = jnp.sum(k_gps, axis=0) + Me = jnp.sum(m_gps, axis=0) + Fe = jnp.sum(f_gps, axis=0) + Te = jnp.sum(t_gps) + + return 0.5*Te_n.T @Se@Te_n+0.5*self.rho*self.cp*Te/self.dt, (Me+self.dt*Se)@Te_n - Me@Te_c, (Me+self.dt*Se) + + def ComputeElementEnergy(self,xyze,de,uvwe,body_force=0.0): + return self.ComputeElement(xyze,de,uvwe,body_force)[0] + + def ComputeElementResidualsAndStiffness(self,xyze,de,uvwe,body_force=0.0): + _,re,ke = self.ComputeElement(xyze,de,uvwe,body_force) + return re,ke + + def ComputeElementResiduals(self,xyze,de,uvwe,body_force=0.0): + return self.ComputeElement(xyze,de,uvwe,body_force)[1] + + def ComputeElementStiffness(self,xyze,de,uvwe,body_force=0.0): + return self.ComputeElement(xyze,de,uvwe,body_force)[2] + + @partial(jit, static_argnums=(0,)) + def ComputeElementResidualsVmapCompatible(self,element_id,elements_nodes,X,Y,Z,C,UV): + return self.ComputeElementResiduals(jnp.ravel(jnp.column_stack((X[elements_nodes[element_id]], + Y[elements_nodes[element_id]], + Z[elements_nodes[element_id]]))), + C[elements_nodes[element_id]], + UV[((self.number_dofs_per_node*elements_nodes[element_id])[:, jnp.newaxis] + + jnp.arange(self.number_dofs_per_node))].reshape(-1,1)) + + @partial(jit, static_argnums=(0,)) + def ComputeElementResidualsAndStiffnessVmapCompatible(self,element_id,elements_nodes,X,Y,Z,C,UV): + return self.ComputeElementResidualsAndStiffness(jnp.ravel(jnp.column_stack((X[elements_nodes[element_id]], + Y[elements_nodes[element_id]], + Z[elements_nodes[element_id]]))), + C[elements_nodes[element_id]], + UV[((self.number_dofs_per_node*elements_nodes[element_id])[:, jnp.newaxis] + + jnp.arange(self.number_dofs_per_node))].reshape(-1,1)) + + @partial(jit, static_argnums=(0,)) + def ComputeElementEnergyVmapCompatible(self,element_id,elements_nodes,X,Y,Z,C,UV): + return self.ComputeElementEnergy(jnp.ravel(jnp.column_stack((X[elements_nodes[element_id]], + Y[elements_nodes[element_id]], + Z[elements_nodes[element_id]]))), + C[elements_nodes[element_id]], + UV[((self.number_dofs_per_node*elements_nodes[element_id])[:, jnp.newaxis] + + jnp.arange(self.number_dofs_per_node))].reshape(-1,1)) + + # @partial(jit, static_argnums=(0,)) + # def ComputeSingleLoss(self,full_control_params,unknown_dofs): + # elem_residual = self.ComputeResiduals(full_control_params.reshape(-1,1), + # self.ExtendUnknowDOFsWithBC(unknown_dofs)) + # # some extra calculation for reporting and not traced + # avg_elem_residual = jax.lax.stop_gradient(jnp.mean(elem_residual)) + # max_elem_residual = jax.lax.stop_gradient(jnp.max(elem_residual)) + # min_elem_residual = jax.lax.stop_gradient(jnp.min(elem_residual)) + # return jnp.sum(elem_residual),(0,max_elem_residual,avg_elem_residual) + + @partial(jit, static_argnums=(0,)) + def ComputeSingleLoss(self,full_control_params,unknown_dofs): + elems_energies = self.ComputeElementsEnergies(full_control_params.reshape(-1,1), + self.ExtendUnknowDOFsWithBC(unknown_dofs)) + # some extra calculation for reporting and not traced + avg_elem_energy = jax.lax.stop_gradient(jnp.mean(elems_energies)) + max_elem_energy = jax.lax.stop_gradient(jnp.max(elems_energies)) + min_elem_energy = jax.lax.stop_gradient(jnp.min(elems_energies)) + return jnp.sum(elems_energies),(0,max_elem_energy,avg_elem_energy) diff --git a/fol/tools/usefull_functions.py b/fol/tools/usefull_functions.py index eaebbf9..9474a1f 100644 --- a/fol/tools/usefull_functions.py +++ b/fol/tools/usefull_functions.py @@ -141,6 +141,57 @@ def create_2D_square_model_info_thermal(L,N,T_left,T_right): dofs_dict = {"T":{"non_dirichlet_nodes_ids":non_boundary_nodes,"dirichlet_nodes_ids":boundary_nodes,"dirichlet_nodes_dof_value":boundary_values}} return {"nodes_dict":nodes_dict,"elements_dict":elements_dict,"dofs_dict":dofs_dict} +def create_2D_square_model_info_thermal_dirichlet(L,N,T_left,T_bottom,T_right,T_top): + # FE init starts here + Ne = N - 1 # Number of elements in each direction + nx = Ne + 1 # Number of nodes in the x-direction + ny = Ne + 1 # Number of nodes in the y-direction + ne = Ne * Ne # Total number of elements + # Generate mesh coordinates + x = jnp.linspace(0, L, nx) + y = jnp.linspace(0, L, ny) + X, Y = jnp.meshgrid(x, y) + X = X.flatten() + Y = Y.flatten() + Z = jnp.zeros((Y.shape[-1])) + # Gauss quadrature points and weights (for a 2x2 integration) + # Create a matrix to store element nodal information + elements_nodes = jnp.zeros((ne, 4), dtype=int) + # Fill in the elements_nodes with element and node numbers + for i in range(Ne): + for j in range(Ne): + e = i * Ne + j # Element index + # Define the nodes of the current element + nodes = jnp.array([i * (Ne + 1) + j, i * (Ne + 1) + j + 1, (i + 1) * (Ne + 1) + j + 1, (i + 1) * (Ne + 1) + j]) + # Store element and node numbers in the matrix + elements_nodes = elements_nodes.at[e].set(nodes) # Node numbers + + element_ids = jnp.arange(0,elements_nodes.shape[0]) + + # Identify boundary nodes on the left and right edges + left_boundary_nodes = jnp.arange(0, ny * nx, nx) # Nodes on the left boundary + left_boundary_nodes_values = T_left * jnp.ones(left_boundary_nodes.shape) + right_boundary_nodes = jnp.arange(nx - 1, ny * nx, nx) # Nodes on the right boundary + right_boundary_nodes_values = T_right * jnp.ones(right_boundary_nodes.shape) + # bottom_boundary_nodes = jnp.arange((ny-1)*nx + 1, ny * nx - 1) # Nodes on the bottom boundary + # bottom_boundary_nodes_values = T_bottom * jnp.ones(bottom_boundary_nodes.shape) + # top_boundary_nodes = jnp.arange(1, nx - 1) # Nodes on the top boundary + # top_boundary_nodes_values = T_top * jnp.ones(top_boundary_nodes.shape) + boundary_nodes = jnp.concatenate([left_boundary_nodes, right_boundary_nodes]) + boundary_values = jnp.concatenate([left_boundary_nodes_values, right_boundary_nodes_values]) + # boundary_nodes = jnp.concatenate([left_boundary_nodes, bottom_boundary_nodes, right_boundary_nodes, top_boundary_nodes]) + # boundary_values = jnp.concatenate([left_boundary_nodes_values, bottom_boundary_nodes_values, right_boundary_nodes_values, top_boundary_nodes_values]) + non_boundary_nodes = [] + for i in range(N*N): + if not jnp.any(boundary_nodes == i): + non_boundary_nodes.append(i) + non_boundary_nodes = jnp.array(non_boundary_nodes) + + nodes_dict = {"nodes_ids":jnp.arange(Y.shape[-1]),"X":X,"Y":Y,"Z":Z} + elements_dict = {"elements_ids":element_ids,"elements_nodes":elements_nodes} + dofs_dict = {"T":{"non_dirichlet_nodes_ids":non_boundary_nodes,"dirichlet_nodes_ids":boundary_nodes,"dirichlet_nodes_dof_value":boundary_values}} + return {"nodes_dict":nodes_dict,"elements_dict":elements_dict,"dofs_dict":dofs_dict} + def box_mesh(Nx, Ny, Nz, Lx, Ly, Lz, case_dir): cell_type = 'hexahedron' @@ -484,6 +535,202 @@ def create_clean_directory(case_dir): os.makedirs(case_dir) + +def plot_mesh_vec_data_paper_temp(vectors_list:list,title_list:list,plot_name:str): + + if len(vectors_list) != 3 or len(title_list) != 3: + raise ValueError('vector list and title list should have 3 components') + fontsize = 16 + fig, axs = plt.subplots(1, 4, figsize=(20, 8)) # Adjusted to 4 columns + + # Plot the first entity in the first row + data = vectors_list[0] + N = int((data.reshape(-1, 1).shape[0]) ** 0.5) + # im = axs[0, 0].imshow(data.reshape(N, N), cmap='viridis', aspect='equal') + # axs[0, 0].set_xticks([]) + # axs[0, 0].set_yticks([]) + # axs[0, 0].set_title(title_list[0], fontsize=fontsize) + # cbar = fig.colorbar(im, ax=axs[0, 0], pad=0.02, shrink=0.7) + # cbar.ax.tick_params(labelsize=fontsize) + # cbar.ax.yaxis.labelpad = 5 + # cbar.ax.tick_params(length=5, width=1) + + # # Plot the same entity with mesh grid in the first row, second column + # im = axs[0, 1].imshow(data.reshape(N, N), cmap='bone', aspect='equal') + # axs[0, 1].set_xticks([]) + # axs[0, 1].set_yticks([]) + # axs[0, 1].set_xticklabels([]) # Remove text on x-axis + # axs[0, 1].set_yticklabels([]) # Remove text on y-axis + # axs[0, 1].set_title(r'Mesh Grid: {} $\times {}$'.format(N,N), fontsize=fontsize) + # axs[0, 1].grid(True, color='red', linestyle='-', linewidth=1) # Adding solid grid lines with red color + # axs[0, 1].xaxis.grid(True) + # axs[0, 1].yaxis.grid(True) + + # x_ticks = np.linspace(0, N, N) + # y_ticks = np.linspace(0, N, N) + # axs[0, 1].set_xticks(x_ticks) + # axs[0, 1].set_yticks(y_ticks) + + # cbar = fig.colorbar(im, ax=axs[0, 1], pad=0.02, shrink=0.7) + # cbar.ax.tick_params(labelsize=fontsize) + # cbar.ax.yaxis.labelpad = 5 + # cbar.ax.tick_params(length=5, width=1) + + # # Zoomed-in region + # zoom_region = data.reshape(N, N)[20:40, 20:40] + # im = axs[0, 2].imshow(zoom_region, cmap='bone', aspect='equal') + # axs[0, 2].set_xticks([]) + # axs[0, 2].set_yticks([]) + # axs[0, 2].set_xticklabels([]) # Remove text on x-axis + # axs[0, 2].set_yticklabels([]) # Remove text on y-axis + # axs[0, 2].set_title('Zoomed-in: $x \in [0.4, 0.8], y \in [0.2, 0.6]$', fontsize=fontsize) + # cbar = fig.colorbar(im, ax=axs[0, 2], pad=0.02, shrink=0.7) + # cbar.ax.tick_params(labelsize=fontsize) + # cbar.ax.yaxis.labelpad = 5 + # cbar.ax.tick_params(length=5, width=1) + + # # Plot the mesh grid + # axs[0, 2].xaxis.set_major_locator(plt.LinearLocator(21)) + # axs[0, 2].yaxis.set_major_locator(plt.LinearLocator(21)) + # axs[0, 2].grid(color='red', linestyle='-', linewidth=2) + + # # Plot cross-sections along x-axis at y=0.5 for U (FOL and FEM) in the second row, fourth column + # y_idx = int(N * 0.5) + # U1 = vectors_list[0].reshape(N, N) + # axs[0, 3].plot(np.linspace(0, 1, N), U1[y_idx, :], label=title_list[0], color='black') + # axs[0, 3].set_xlim([0, 1]) + # #axs[0, 3].set_ylim([min(U1[y_idx, :].min()), max(U1[y_idx, :].max())]) + # axs[0, 3].set_aspect(aspect='auto') + # axs[0, 3].set_title('Cross-section of Q at y=0.5', fontsize=fontsize) + # axs[0, 3].legend(fontsize=fontsize) + # axs[0, 3].grid(True) + # axs[0, 3].set_xlabel('x', fontsize=fontsize) + # axs[0, 3].set_ylabel('K', fontsize=fontsize) + + + # Plot the second entity in the second row + data = vectors_list[1] + im = axs[0].imshow(data.reshape(N, N), cmap='jet', aspect='equal') + axs[0].set_xticks([]) + axs[0].set_yticks([]) + axs[0].set_title(title_list[1], fontsize=fontsize) + cbar = fig.colorbar(im, ax=axs[0], pad=0.02, shrink=0.7) + cbar.ax.tick_params(labelsize=fontsize) + cbar.ax.yaxis.labelpad = 5 + cbar.ax.tick_params(length=5, width=1) + + # Plot the fourth entity in the second row + data = vectors_list[2] + im = axs[1].imshow(data.reshape(N, N), cmap='jet', aspect='equal') + axs[1].set_xticks([]) + axs[1].set_yticks([]) + axs[1].set_title(title_list[2], fontsize=fontsize) + cbar = fig.colorbar(im, ax=axs[1], pad=0.02, shrink=0.7) + cbar.ax.tick_params(labelsize=fontsize) + cbar.ax.yaxis.labelpad = 5 + cbar.ax.tick_params(length=5, width=1) + + # Plot the absolute difference between vectors_list[1] and vectors_list[3] in the third row, second column + diff_data_1 = np.abs(vectors_list[1] - vectors_list[2]) + im = axs[2].imshow(diff_data_1.reshape(N, N), cmap='jet', aspect='equal') + axs[2].set_xticks([]) + axs[2].set_yticks([]) + axs[2].set_title('Abs. Difference $T$', fontsize=fontsize) + cbar = fig.colorbar(im, ax=axs[2], pad=0.02, shrink=0.7) + cbar.ax.tick_params(labelsize=fontsize) + cbar.ax.yaxis.labelpad = 5 + cbar.ax.tick_params(length=5, width=1) + + # Plot cross-sections along x-axis at y=0.5 for U (FOL and FEM) in the second row, fourth column + y_idx = int(N * 0.5) + U1 = vectors_list[1].reshape(N, N) + U2 = vectors_list[2].reshape(N, N) + axs[3].plot(np.linspace(0, 1, N), U1[y_idx, :], label='T FOL', color='blue') + axs[3].plot(np.linspace(0, 1, N), U2[y_idx, :], label='T FEM', color='red') + axs[3].set_xlim([0, 1]) + axs[3].set_ylim([min(U1[y_idx, :].min(), U2[y_idx, :].min()), max(U1[y_idx, :].max(), U2[y_idx, :].max())]) + axs[3].set_aspect(aspect='auto') + axs[3].set_title('Cross-section of T at y=0.5', fontsize=fontsize) + axs[3].legend(fontsize=fontsize) + axs[3].grid(True) + axs[3].set_xlabel('x', fontsize=fontsize) + axs[3].set_ylabel('T', fontsize=fontsize) + + plt.tight_layout() + + # Save the figure in multiple formats + plt.savefig(plot_name+'_plot_mesh_vec_data.png', dpi=300) + plt.savefig(plot_name+'_plot_mesh_vec_data.pdf') + + # plt.show() + +def plot_temp_evolution(vectors_list:list,title_list:list,plot_name:str,selected_timesteps:list): + + if len(vectors_list) != 3 or len(title_list) != 3: + raise ValueError('vector list and title list should have 3 components') + fontsize = 16 + fig, axs = plt.subplots(3, 4, figsize=(20, 8)) # Adjusted to 4 columns + + # Plot the first entity in the first row + data = vectors_list[0] + N = int((data.reshape(-1, 1).shape[0]) ** 0.5) + + # Plot the second entity in the second row + data = vectors_list[1] + im = axs[0].imshow(data.reshape(N, N), cmap='jet', aspect='equal') + axs[0].set_xticks([]) + axs[0].set_yticks([]) + axs[0].set_title(title_list[1], fontsize=fontsize) + cbar = fig.colorbar(im, ax=axs[0], pad=0.02, shrink=0.7) + cbar.ax.tick_params(labelsize=fontsize) + cbar.ax.yaxis.labelpad = 5 + cbar.ax.tick_params(length=5, width=1) + + # Plot the fourth entity in the second row + data = vectors_list[2] + im = axs[1].imshow(data.reshape(N, N), cmap='jet', aspect='equal') + axs[1].set_xticks([]) + axs[1].set_yticks([]) + axs[1].set_title(title_list[2], fontsize=fontsize) + cbar = fig.colorbar(im, ax=axs[1], pad=0.02, shrink=0.7) + cbar.ax.tick_params(labelsize=fontsize) + cbar.ax.yaxis.labelpad = 5 + cbar.ax.tick_params(length=5, width=1) + + # Plot the absolute difference between vectors_list[1] and vectors_list[3] in the third row, second column + diff_data_1 = np.abs(vectors_list[1] - vectors_list[2]) + im = axs[2].imshow(diff_data_1.reshape(N, N), cmap='jet', aspect='equal') + axs[2].set_xticks([]) + axs[2].set_yticks([]) + axs[2].set_title('Abs. Difference $T$', fontsize=fontsize) + cbar = fig.colorbar(im, ax=axs[2], pad=0.02, shrink=0.7) + cbar.ax.tick_params(labelsize=fontsize) + cbar.ax.yaxis.labelpad = 5 + cbar.ax.tick_params(length=5, width=1) + + # Plot cross-sections along x-axis at y=0.5 for U (FOL and FEM) in the second row, fourth column + y_idx = int(N * 0.5) + U1 = vectors_list[1].reshape(N, N) + U2 = vectors_list[2].reshape(N, N) + axs[3].plot(np.linspace(0, 1, N), U1[y_idx, :], label='T FOL', color='blue') + axs[3].plot(np.linspace(0, 1, N), U2[y_idx, :], label='T FEM', color='red') + axs[3].set_xlim([0, 1]) + axs[3].set_ylim([min(U1[y_idx, :].min(), U2[y_idx, :].min()), max(U1[y_idx, :].max(), U2[y_idx, :].max())]) + axs[3].set_aspect(aspect='auto') + axs[3].set_title('Cross-section of T at y=0.5', fontsize=fontsize) + axs[3].legend(fontsize=fontsize) + axs[3].grid(True) + axs[3].set_xlabel('x', fontsize=fontsize) + axs[3].set_ylabel('T', fontsize=fontsize) + + plt.tight_layout() + + # Save the figure in multiple formats + plt.savefig(plot_name+'_plot_mesh_vec_data.png', dpi=300) + plt.savefig(plot_name+'_plot_mesh_vec_data.pdf') + + # plt.show() + def TensorToVoigt(tensor): if tensor.size == 4: voigt = jnp.zeros((3,1))