diff --git a/examples/thermal_fol/thermal_fol.py b/examples/thermal_fol/thermal_fol.py index 87ce3f3..5d29601 100644 --- a/examples/thermal_fol/thermal_fol.py +++ b/examples/thermal_fol/thermal_fol.py @@ -27,7 +27,7 @@ def main(fol_num_epochs=10,solve_FE=False,clean_dir=False): fe_model = FiniteElementModel("FE_model",model_info) # create thermal loss - thermal_loss_3d = ThermalLoss3DTetra("thermal_loss_3d",fe_model,{"beta":2,"c":4}) + thermal_loss_3d = ThermalLoss3DTetra("thermal_loss_3d",fe_model,{"beta":0,"c":4}) # create Fourier parametrization/control x_freqs = np.array([1,2,3]) @@ -64,7 +64,7 @@ def main(fol_num_epochs=10,solve_FE=False,clean_dir=False): io.mesh_io.write(solution_file) # specify id of the K of interest - eval_id = 5 + eval_id = -1 io.mesh_io.point_data['K'] = np.array(K_matrix[eval_id,:]) # now we need to create, initialize and train fol @@ -81,7 +81,7 @@ def main(fol_num_epochs=10,solve_FE=False,clean_dir=False): # solve FE here if solve_FE: - first_fe_solver = NonLinearSolver("first_fe_solver",thermal_loss_3d,relative_error=1e-5,max_num_itr=20) + first_fe_solver = NonLinearSolver("first_fe_solver",thermal_loss_3d,relative_error=1e-5,max_num_itr=5,load_incr=1) start_time = time.process_time() FE_T = np.array(first_fe_solver.SingleSolve(K_matrix[eval_id],np.zeros(fe_model.GetNumberOfNodes()))) print(f"\n############### FE solve took: {time.process_time() - start_time} s ###############\n") diff --git a/fol/loss_functions/thermal_3D_fe_tetra.py b/fol/loss_functions/thermal_3D_fe_tetra.py index ed29867..19a2cc6 100644 --- a/fol/loss_functions/thermal_3D_fe_tetra.py +++ b/fol/loss_functions/thermal_3D_fe_tetra.py @@ -29,15 +29,16 @@ def ComputeElement(self,xyze,de,te,body_force): @jit def compute_at_gauss_point(xi,eta,zeta,total_weight): Nf = self.shape_function.evaluate(xi,eta,zeta) - conductivity_at_gauss = jnp.dot(Nf, de.squeeze()) * (1 + - self.loss_settings["beta"]*(jnp.dot(Nf,te.squeeze()))**self.loss_settings["c"]) + # conductivity_at_gauss = jnp.dot(Nf, de.squeeze()) * (1 + + # self.loss_settings["beta"]*(jnp.dot(Nf,te.squeeze()))**self.loss_settings["c"]) + conductivity_at_gauss = jnp.ones((1,1)) dN_dxi = self.shape_function.derivatives(xi,eta,zeta) J = jnp.dot(dN_dxi.T, xyze.T) detJ = jnp.linalg.det(J) invJ = jnp.linalg.inv(J) B = jnp.dot(invJ,dN_dxi.T) gp_stiffness = conductivity_at_gauss * jnp.dot(B.T, B) * detJ * total_weight - gp_f = total_weight * detJ * body_force * Nf.reshape(-1,1) + gp_f = total_weight * detJ * jnp.dot(Nf, de.squeeze()) * body_force * Nf.reshape(-1,1) return gp_stiffness,gp_f @jit def vmap_compatible_compute_at_gauss_point(gp_index): @@ -52,19 +53,19 @@ def vmap_compatible_compute_at_gauss_point(gp_index): Se = jnp.sum(k_gps, axis=0) Fe = jnp.sum(f_gps, axis=0) element_residuals = jax.lax.stop_gradient(Se @ te - Fe) - return ((te.T @ element_residuals)[0,0]), 2 * (Se @ te - Fe), 2 * Se + return ((1/2)*(te.T @ element_residuals)[0,0]), (Se @ te - Fe), Se - def ComputeElementEnergy(self,xyze,de,te,body_force=jnp.zeros((1,1))): + def ComputeElementEnergy(self,xyze,de,te,body_force=jnp.ones((1,1))): return self.ComputeElement(xyze,de,te,body_force)[0] - def ComputeElementResidualsAndStiffness(self,xyze,de,te,body_force=jnp.zeros((1,1))): + def ComputeElementResidualsAndStiffness(self,xyze,de,te,body_force=jnp.ones((1,1))): _,re,ke = self.ComputeElement(xyze,de,te,body_force) return re,ke - def ComputeElementResiduals(self,xyze,de,te,body_force=jnp.zeros((1,1))): + def ComputeElementResiduals(self,xyze,de,te,body_force=jnp.ones((1,1))): return self.ComputeElement(xyze,de,te,body_force)[1] - def ComputeElementStiffness(self,xyze,de,te,body_force=jnp.zeros((1,1))): + def ComputeElementStiffness(self,xyze,de,te,body_force=jnp.ones((1,1))): return self.ComputeElement(xyze,de,te,body_force)[2] @partial(jit, static_argnums=(0,))