Skip to content

Commit 4de8232

Browse files
author
Lambda-Section
committed
feat: Implement Neural DSL parser and grammar using Lark.
1 parent 3b75c97 commit 4de8232

File tree

1 file changed

+4
-37
lines changed

1 file changed

+4
-37
lines changed

neural/parser/parser.py

Lines changed: 4 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ def _validate_optimizer(self, optimizer_name, item=None):
805805

806806
def _validate_loss_function(self, loss_name, item=None):
807807
"""Validate loss function is supported."""
808-
valid_losses = ["mse", "cross_entropy", "binary_cross_entropy", "mae", "categorical_cross_entropy", "sparse_categorical_cross_entropy"]
808+
valid_losses = ["mse", "cross_entropy", "binary_cross_entropy", "mae", "categorical_cross_entropy", "sparse_categorical_cross_entropy", "categorical_crossentropy"]
809809
if isinstance(loss_name, str) and loss_name.lower() not in valid_losses:
810810
self.raise_validation_error(f"Invalid loss function: {loss_name}", item)
811811

@@ -1700,7 +1700,9 @@ def optimizer(self, items):
17001700
items: [optimizer_name, param_style1?]
17011701
Accepts list/dict forms and merges into a flat params dict.
17021702
"""
1703-
optimizer_type = str(items[0])
1703+
optimizer_type = str(items[0]).strip('"')
1704+
self._validate_optimizer(optimizer_type, items[0])
1705+
17041706
params = {}
17051707

17061708
# Merge parameters if provided (items[1] may be dict or list)
@@ -2151,27 +2153,6 @@ def lstm(self, items):
21512153
params = param_values
21522154
else:
21532155
# Single positional parameter, e.g., LSTM(64)
2154-
params['units'] = param_values
2155-
2156-
if 'units' not in params:
2157-
self.raise_validation_error("LSTM requires 'units' parameter", items[0])
2158-
2159-
units = params['units']
2160-
if isinstance(units, dict) and 'hpo' in units:
2161-
pass # HPO handled elsewhere
2162-
else:
2163-
try:
2164-
params['units'] = validate_units(units)
2165-
except ValidationError as e:
2166-
self.raise_validation_error(str(e), items[0], Severity.ERROR)
2167-
return
2168-
2169-
return {'type': 'LSTM', 'params': params}
2170-
2171-
def gru(self, items):
2172-
params = {}
2173-
if items and items[0] is not None:
2174-
param_node = items[0]
21752156
param_values = self._extract_value(param_node)
21762157
if isinstance(param_values, list):
21772158
for val in param_values:
@@ -2886,20 +2867,6 @@ def network(self, items):
28862867
if optimizer_config:
28872868
network_config['optimizer'] = optimizer_config
28882869
# logger.debug(f"Adding optimizer to network_config: {optimizer_config}")
2889-
2890-
if training_config:
2891-
network_config['training'] = training_config
2892-
2893-
if execution_config:
2894-
network_config['execution'] = execution_config
2895-
2896-
# logger.debug(f"Final network_config: {network_config}")
2897-
return network_config
2898-
2899-
#########
2900-
2901-
def search_method_param(self, items):
2902-
value = self._extract_value(items[0]) # Extract "bayesian" from STRING token
29032870
return {'search_method': value}
29042871

29052872
def _extract_value(self, item):

0 commit comments

Comments
 (0)