@@ -11,7 +11,7 @@ def optimizer_handler(optimizer_name: str, params, lr_mult=1.0, **kwargs):
1111 weight_decay = 0 ,
1212 foreach = None ,
1313 maximize = False ,
14- differentiable = False ,
14+ # differentiable=False,
1515 )
1616 elif optimizer_name == "Adagrad" :
1717 return torch .optim .Adagrad (
@@ -23,7 +23,7 @@ def optimizer_handler(optimizer_name: str, params, lr_mult=1.0, **kwargs):
2323 eps = 1e-10 ,
2424 foreach = None ,
2525 maximize = False ,
26- differentiable = False ,
26+ # differentiable=False,
2727 )
2828 elif optimizer_name == "Adam" :
2929 return torch .optim .Adam (
@@ -36,7 +36,7 @@ def optimizer_handler(optimizer_name: str, params, lr_mult=1.0, **kwargs):
3636 foreach = None ,
3737 maximize = False ,
3838 capturable = False ,
39- differentiable = False ,
39+ # differentiable=False,
4040 fused = None ,
4141 )
4242 elif optimizer_name == "AdamW" :
@@ -76,7 +76,7 @@ def optimizer_handler(optimizer_name: str, params, lr_mult=1.0, **kwargs):
7676 weight_decay = 0 ,
7777 foreach = None ,
7878 maximize = False ,
79- differentiable = False ,
79+ # differentiable=False,
8080 )
8181 elif optimizer_name == "LBFGS" :
8282 return torch .optim .LBFGS (
@@ -102,7 +102,13 @@ def optimizer_handler(optimizer_name: str, params, lr_mult=1.0, **kwargs):
102102 )
103103 elif optimizer_name == "RAdam" :
104104 return torch .optim .RAdam (
105- params , lr = 0.001 , betas = (0.9 , 0.999 ), eps = 1e-08 , weight_decay = 0 , foreach = None , differentiable = False
105+ params ,
106+ lr = 0.001 ,
107+ betas = (0.9 , 0.999 ),
108+ eps = 1e-08 ,
109+ weight_decay = 0 ,
110+ foreach = None ,
111+ # differentiable=False
106112 )
107113 elif optimizer_name == "RMSprop" :
108114 return torch .optim .RMSprop (
@@ -115,7 +121,7 @@ def optimizer_handler(optimizer_name: str, params, lr_mult=1.0, **kwargs):
115121 centered = False ,
116122 foreach = None ,
117123 maximize = False ,
118- differentiable = False ,
124+ # differentiable=False,
119125 )
120126 elif optimizer_name == "Rprop" :
121127 return torch .optim .Rprop (
@@ -125,7 +131,7 @@ def optimizer_handler(optimizer_name: str, params, lr_mult=1.0, **kwargs):
125131 step_sizes = (1e-06 , 50 ),
126132 foreach = None ,
127133 maximize = False ,
128- differentiable = False ,
134+ # differentiable=False,
129135 )
130136 elif optimizer_name == "SGD" :
131137 return torch .optim .SGD (
@@ -137,7 +143,7 @@ def optimizer_handler(optimizer_name: str, params, lr_mult=1.0, **kwargs):
137143 nesterov = False ,
138144 maximize = False ,
139145 foreach = None ,
140- differentiable = False ,
146+ # differentiable=False,
141147 )
142148 else :
143149 raise ValueError (f"Optimizer { optimizer_name } not supported" )
0 commit comments