@@ -209,7 +209,7 @@ def generate_workflow(
209209 """
210210 # Create the necessary data structures to hold imports and generated code
211211 import_statements , executed_variables , special_functions_code , code = (
212- { "nodes" : { " NODE_CLASS_MAPPINGS"}} ,
212+ set ([ " NODE_CLASS_MAPPINGS"]) ,
213213 {},
214214 [],
215215 [],
@@ -245,9 +245,8 @@ def generate_workflow(
245245 )
246246 initialized_objects [class_type ] = self .clean_variable_name (class_type )
247247 if class_type in self .base_node_class_mappings .keys ():
248- module_name , import_name = import_statement
249- import_statements .setdefault (module_name , set ()).add (import_name )
250- if 'NODE_CLASS_MAPPINGS["' in class_code :
248+ import_statements .add (import_statement )
249+ if class_type not in self .base_node_class_mappings .keys ():
251250 custom_nodes = True
252251 special_functions_code .append (class_code )
253252
@@ -362,7 +361,7 @@ def format_arg(self, key: str, value: any) -> str:
362361
363362 def assemble_python_code (
364363 self ,
365- import_statements : Dict [ str , set ] ,
364+ import_statements : set ,
366365 speical_functions_code : List [str ],
367366 code : List [str ],
368367 queue_size : int ,
@@ -371,7 +370,7 @@ def assemble_python_code(
371370 """Generates the final code string.
372371
373372 Args:
374- import_statements (Dict[str, set] ): Import statements grouped by module .
373+ import_statements (set): A set of unique import statements .
375374 speical_functions_code (List[str]): A list of special functions code strings.
376375 code (List[str]): A list of code strings.
377376 queue_size (int): Number of photos that will be generated by the script.
@@ -409,10 +408,9 @@ def assemble_python_code(
409408 else :
410409 custom_nodes = ""
411410 # Create import statements for node classes
412- imports_code = []
413- for module_name in sorted (import_statements .keys ()):
414- class_names = ", " .join (sorted (import_statements [module_name ]))
415- imports_code .append (f"from { module_name } import { class_names } " )
411+ imports_code = [
412+ f"from nodes import { ', ' .join ([class_name for class_name in import_statements ])} "
413+ ]
416414 # Assemble the main function code, including custom nodes if applicable
417415 main_function_code = (
418416 "def main():\n \t "
@@ -432,31 +430,20 @@ def assemble_python_code(
432430
433431 return final_code
434432
435- def get_class_info (self , class_type : str ) -> Tuple [str , Tuple [ str , str ] , str ]:
433+ def get_class_info (self , class_type : str ) -> Tuple [str , str , str ]:
436434 """Generates and returns necessary information about class type.
437435
438436 Args:
439437 class_type (str): Class type.
440438
441439 Returns:
442- Tuple[str, Tuple[ str, str], str] : Updated class type, import statement and initialization code.
440+ Tuple[str, str, str]: Updated class type, import statement string, class initialization code.
443441 """
444- class_obj = self .base_node_class_mappings .get (class_type )
445- module_name = "nodes"
446- if class_obj is not None :
447- module_name = class_obj .__module__
448442 variable_name = self .clean_variable_name (class_type )
449- is_importable_module = bool (
450- module_name
451- and "/" not in module_name
452- and "\\ " not in module_name
453- and all (part .isidentifier () for part in module_name .split ("." ))
454- )
455- if class_type in self .base_node_class_mappings .keys () and is_importable_module :
456- import_statement = (module_name , class_type )
443+ import_statement = class_type
444+ if class_type in self .base_node_class_mappings .keys ():
457445 class_code = f"{ variable_name } = { class_type .strip ()} ()"
458446 else :
459- import_statement = ("nodes" , "NODE_CLASS_MAPPINGS" )
460447 class_code = f'{ variable_name } = NODE_CLASS_MAPPINGS["{ class_type } "]()'
461448
462449 return class_type , import_statement , class_code
0 commit comments