@@ -805,7 +805,7 @@ def _validate_optimizer(self, optimizer_name, item=None):
805805
806806 def _validate_loss_function (self , loss_name , item = None ):
807807 """Validate loss function is supported."""
808- valid_losses = ["mse" , "cross_entropy" , "binary_cross_entropy" , "mae" , "categorical_cross_entropy" , "sparse_categorical_cross_entropy" ]
808+ valid_losses = ["mse" , "cross_entropy" , "binary_cross_entropy" , "mae" , "categorical_cross_entropy" , "sparse_categorical_cross_entropy" , "categorical_crossentropy" ]
809809 if isinstance (loss_name , str ) and loss_name .lower () not in valid_losses :
810810 self .raise_validation_error (f"Invalid loss function: { loss_name } " , item )
811811
@@ -1700,7 +1700,9 @@ def optimizer(self, items):
17001700 items: [optimizer_name, param_style1?]
17011701 Accepts list/dict forms and merges into a flat params dict.
17021702 """
1703- optimizer_type = str (items [0 ])
1703+ optimizer_type = str (items [0 ]).strip ('"' )
1704+ self ._validate_optimizer (optimizer_type , items [0 ])
1705+
17041706 params = {}
17051707
17061708 # Merge parameters if provided (items[1] may be dict or list)
@@ -2151,27 +2153,6 @@ def lstm(self, items):
21512153 params = param_values
21522154 else :
21532155 # Single positional parameter, e.g., LSTM(64)
2154- params ['units' ] = param_values
2155-
2156- if 'units' not in params :
2157- self .raise_validation_error ("LSTM requires 'units' parameter" , items [0 ])
2158-
2159- units = params ['units' ]
2160- if isinstance (units , dict ) and 'hpo' in units :
2161- pass # HPO handled elsewhere
2162- else :
2163- try :
2164- params ['units' ] = validate_units (units )
2165- except ValidationError as e :
2166- self .raise_validation_error (str (e ), items [0 ], Severity .ERROR )
2167- return
2168-
2169- return {'type' : 'LSTM' , 'params' : params }
2170-
2171- def gru (self , items ):
2172- params = {}
2173- if items and items [0 ] is not None :
2174- param_node = items [0 ]
21752156 param_values = self ._extract_value (param_node )
21762157 if isinstance (param_values , list ):
21772158 for val in param_values :
@@ -2886,20 +2867,6 @@ def network(self, items):
28862867 if optimizer_config :
28872868 network_config ['optimizer' ] = optimizer_config
28882869 # logger.debug(f"Adding optimizer to network_config: {optimizer_config}")
2889-
2890- if training_config :
2891- network_config ['training' ] = training_config
2892-
2893- if execution_config :
2894- network_config ['execution' ] = execution_config
2895-
2896- # logger.debug(f"Final network_config: {network_config}")
2897- return network_config
2898-
2899- #########
2900-
2901- def search_method_param (self , items ):
2902- value = self ._extract_value (items [0 ]) # Extract "bayesian" from STRING token
29032870 return {'search_method' : value }
29042871
29052872 def _extract_value (self , item ):
0 commit comments