Skip to content

issue in thermal solution for FOL geometry #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/thermal_fol/thermal_fol.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def main(fol_num_epochs=10,solve_FE=False,clean_dir=False):
fe_model = FiniteElementModel("FE_model",model_info)

# create thermal loss
thermal_loss_3d = ThermalLoss3DTetra("thermal_loss_3d",fe_model,{"beta":2,"c":4})
thermal_loss_3d = ThermalLoss3DTetra("thermal_loss_3d",fe_model,{"beta":0,"c":4})

# create Fourier parametrization/control
x_freqs = np.array([1,2,3])
Expand Down Expand Up @@ -64,7 +64,7 @@ def main(fol_num_epochs=10,solve_FE=False,clean_dir=False):
io.mesh_io.write(solution_file)

# specify id of the K of interest
eval_id = 5
eval_id = -1
io.mesh_io.point_data['K'] = np.array(K_matrix[eval_id,:])

# now we need to create, initialize and train fol
Expand All @@ -81,7 +81,7 @@ def main(fol_num_epochs=10,solve_FE=False,clean_dir=False):

# solve FE here
if solve_FE:
first_fe_solver = NonLinearSolver("first_fe_solver",thermal_loss_3d,relative_error=1e-5,max_num_itr=20)
first_fe_solver = NonLinearSolver("first_fe_solver",thermal_loss_3d,relative_error=1e-5,max_num_itr=5,load_incr=1)
start_time = time.process_time()
FE_T = np.array(first_fe_solver.SingleSolve(K_matrix[eval_id],np.zeros(fe_model.GetNumberOfNodes())))
print(f"\n############### FE solve took: {time.process_time() - start_time} s ###############\n")
Expand Down
17 changes: 9 additions & 8 deletions fol/loss_functions/thermal_3D_fe_tetra.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ def ComputeElement(self,xyze,de,te,body_force):
@jit
def compute_at_gauss_point(xi,eta,zeta,total_weight):
Nf = self.shape_function.evaluate(xi,eta,zeta)
conductivity_at_gauss = jnp.dot(Nf, de.squeeze()) * (1 +
self.loss_settings["beta"]*(jnp.dot(Nf,te.squeeze()))**self.loss_settings["c"])
# conductivity_at_gauss = jnp.dot(Nf, de.squeeze()) * (1 +
# self.loss_settings["beta"]*(jnp.dot(Nf,te.squeeze()))**self.loss_settings["c"])
conductivity_at_gauss = jnp.ones((1,1))
dN_dxi = self.shape_function.derivatives(xi,eta,zeta)
J = jnp.dot(dN_dxi.T, xyze.T)
detJ = jnp.linalg.det(J)
invJ = jnp.linalg.inv(J)
B = jnp.dot(invJ,dN_dxi.T)
gp_stiffness = conductivity_at_gauss * jnp.dot(B.T, B) * detJ * total_weight
gp_f = total_weight * detJ * body_force * Nf.reshape(-1,1)
gp_f = total_weight * detJ * jnp.dot(Nf, de.squeeze()) * body_force * Nf.reshape(-1,1)
return gp_stiffness,gp_f
@jit
def vmap_compatible_compute_at_gauss_point(gp_index):
Expand All @@ -52,19 +53,19 @@ def vmap_compatible_compute_at_gauss_point(gp_index):
Se = jnp.sum(k_gps, axis=0)
Fe = jnp.sum(f_gps, axis=0)
element_residuals = jax.lax.stop_gradient(Se @ te - Fe)
return ((te.T @ element_residuals)[0,0]), 2 * (Se @ te - Fe), 2 * Se
return ((1/2)*(te.T @ element_residuals)[0,0]), (Se @ te - Fe), Se

def ComputeElementEnergy(self,xyze,de,te,body_force=jnp.zeros((1,1))):
def ComputeElementEnergy(self,xyze,de,te,body_force=jnp.ones((1,1))):
return self.ComputeElement(xyze,de,te,body_force)[0]

def ComputeElementResidualsAndStiffness(self,xyze,de,te,body_force=jnp.zeros((1,1))):
def ComputeElementResidualsAndStiffness(self,xyze,de,te,body_force=jnp.ones((1,1))):
_,re,ke = self.ComputeElement(xyze,de,te,body_force)
return re,ke

def ComputeElementResiduals(self,xyze,de,te,body_force=jnp.zeros((1,1))):
def ComputeElementResiduals(self,xyze,de,te,body_force=jnp.ones((1,1))):
return self.ComputeElement(xyze,de,te,body_force)[1]

def ComputeElementStiffness(self,xyze,de,te,body_force=jnp.zeros((1,1))):
def ComputeElementStiffness(self,xyze,de,te,body_force=jnp.ones((1,1))):
return self.ComputeElement(xyze,de,te,body_force)[2]

@partial(jit, static_argnums=(0,))
Expand Down