Skip to content

Commit 433a987

Browse files
committed
Narrow runtime compatibility fix
1 parent 50141bb commit 433a987

2 files changed

Lines changed: 13 additions & 26 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ Flags:
9696
- `--output_file`: output Python file, default `workflow_api.py`
9797
- `--queue_size`: default execution count in the generated script, default `10`
9898

99-
![Dev Mode Options](images/dev_mode_options.png)
99+
![Dev Mode Options](images/dev_mode_options.PNG)
100100

101101
## Generated Scripts
102102

comfyui_to_python.py

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def generate_workflow(
209209
"""
210210
# Create the necessary data structures to hold imports and generated code
211211
import_statements, executed_variables, special_functions_code, code = (
212-
{"nodes": {"NODE_CLASS_MAPPINGS"}},
212+
set(["NODE_CLASS_MAPPINGS"]),
213213
{},
214214
[],
215215
[],
@@ -245,9 +245,8 @@ def generate_workflow(
245245
)
246246
initialized_objects[class_type] = self.clean_variable_name(class_type)
247247
if class_type in self.base_node_class_mappings.keys():
248-
module_name, import_name = import_statement
249-
import_statements.setdefault(module_name, set()).add(import_name)
250-
if 'NODE_CLASS_MAPPINGS["' in class_code:
248+
import_statements.add(import_statement)
249+
if class_type not in self.base_node_class_mappings.keys():
251250
custom_nodes = True
252251
special_functions_code.append(class_code)
253252

@@ -362,7 +361,7 @@ def format_arg(self, key: str, value: any) -> str:
362361

363362
def assemble_python_code(
364363
self,
365-
import_statements: Dict[str, set],
364+
import_statements: set,
366365
speical_functions_code: List[str],
367366
code: List[str],
368367
queue_size: int,
@@ -371,7 +370,7 @@ def assemble_python_code(
371370
"""Generates the final code string.
372371
373372
Args:
374-
import_statements (Dict[str, set]): Import statements grouped by module.
373+
import_statements (set): A set of unique import statements.
375374
speical_functions_code (List[str]): A list of special functions code strings.
376375
code (List[str]): A list of code strings.
377376
queue_size (int): Number of photos that will be generated by the script.
@@ -409,10 +408,9 @@ def assemble_python_code(
409408
else:
410409
custom_nodes = ""
411410
# Create import statements for node classes
412-
imports_code = []
413-
for module_name in sorted(import_statements.keys()):
414-
class_names = ", ".join(sorted(import_statements[module_name]))
415-
imports_code.append(f"from {module_name} import {class_names}")
411+
imports_code = [
412+
f"from nodes import {', '.join([class_name for class_name in import_statements])}"
413+
]
416414
# Assemble the main function code, including custom nodes if applicable
417415
main_function_code = (
418416
"def main():\n\t"
@@ -432,31 +430,20 @@ def assemble_python_code(
432430

433431
return final_code
434432

435-
def get_class_info(self, class_type: str) -> Tuple[str, Tuple[str, str], str]:
433+
def get_class_info(self, class_type: str) -> Tuple[str, str, str]:
436434
"""Generates and returns necessary information about class type.
437435
438436
Args:
439437
class_type (str): Class type.
440438
441439
Returns:
442-
Tuple[str, Tuple[str, str], str]: Updated class type, import statement and initialization code.
440+
Tuple[str, str, str]: Updated class type, import statement string, class initialization code.
443441
"""
444-
class_obj = self.base_node_class_mappings.get(class_type)
445-
module_name = "nodes"
446-
if class_obj is not None:
447-
module_name = class_obj.__module__
448442
variable_name = self.clean_variable_name(class_type)
449-
is_importable_module = bool(
450-
module_name
451-
and "/" not in module_name
452-
and "\\" not in module_name
453-
and all(part.isidentifier() for part in module_name.split("."))
454-
)
455-
if class_type in self.base_node_class_mappings.keys() and is_importable_module:
456-
import_statement = (module_name, class_type)
443+
import_statement = class_type
444+
if class_type in self.base_node_class_mappings.keys():
457445
class_code = f"{variable_name} = {class_type.strip()}()"
458446
else:
459-
import_statement = ("nodes", "NODE_CLASS_MAPPINGS")
460447
class_code = f'{variable_name} = NODE_CLASS_MAPPINGS["{class_type}"]()'
461448

462449
return class_type, import_statement, class_code

0 commit comments

Comments
 (0)