Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 31 additions & 22 deletions src/gramforge/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,41 +57,50 @@ def wrapper(*args, **kwargs):

def Substitution(template, lang=None):
def replace_template(template, a):
# Make numbers formattable 0 -> {0}
wrap = lambda s: (re.sub(r'(\d+)', r'{\1}', s) if isinstance(s, str) else s)

# Function to safely replace only unescaped '?'
def replace_match(m):
slot_idx = int(m.group(1))
replacement = m.group(2)
# The children args 'a' might be FastProduction nodes, so render them if needed
arg_to_sub = a[slot_idx]
if not isinstance(arg_to_sub, str):
arg_to_sub = arg_to_sub.render(lang)
return re.sub(r'(?<!\\)\?', replacement, arg_to_sub)

inner_replaced = re.sub(r"(\d+)\[\?←(.+?)\]", replace_match, template)

# Render the remaining children before formatting
# Pre-render all args once (render() is cached, so repeated calls are cheap).
rendered_args = [arg.render(lang) if not isinstance(arg, str) else arg for arg in a]
output = wrap(inner_replaced).format(*rendered_args)

# Convert escaped "\?" back to "?" after processing
# Split the template at every N[?←X] boundary.
# re.split with two capturing groups yields a flat list:
# [text, N, X, text, N, X, ..., text]
# Indices mod 3: 0=text 1=slot-digit 2=replacement-string
parts = re.split(r"(\d+)\[\?←(.+?)\]", template)

out = []
slot_idx = None
for i, part in enumerate(parts):
cycle = i % 3
if cycle == 0:
# Original template text: wrap bare digits as format slots.
out.append(re.sub(r'(\d+)', r'{\1}', part))
elif cycle == 1:
# N digit: index of the arg to substitute into.
slot_idx = int(part)
else:
# X replacement string: wrap its digits first (e.g. '1' → '{1}'),
# then replace every unescaped '?' in the rendered arg with it.
# Crucially, digits that come from the rendered arg itself are NOT
# wrapped here, so 'pred5(?)' stays 'pred5(…)' rather than 'pred{5}(…)'.
wrapped_x = re.sub(r'(\d+)', r'{\1}', part)
out.append(re.sub(r'(?<!\\)\?', wrapped_x, rendered_args[slot_idx]))

output = ''.join(out).format(*rendered_args)
return output.replace(r'\?', '?')

# The original Substitution returned a function that expected rendered strings.
# The fix here is to make it robust enough to handle node objects too.
def sub(*a, **ka):
return replace_template(template, a)

return sub

# Pre-compile regex patterns
# Pre-compile regex patterns (available for external use)
NUMBER_PATTERN = re.compile(r'(\d+)')
SUBSTITUTION_PATTERN = re.compile(r"(\d+)\[\?←(.+?)\]")


default_preprocess_template = lambda s: (re.sub(r'(\d+)', r'{\1}', s) if type(s)==str and '←' not in s else s)
default_preprocess_template = lambda s: (
re.sub(r'(?<!\{)(\d+)(?!\})', r'{\1}', s)
if type(s) == str and '←' not in s else s
)

def init_grammar(
langs,
Expand Down
Loading