Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions interface_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,17 @@ class Typename(object):
)
).setParseAction(lambda t: Typename(t.namespaces_name, t.instantiations))

def __init__(self, namespaces_name, instantiations=''):
def __init__(self, namespaces_name, instantiations=[]):
self.namespaces = namespaces_name[:-1]
self.name = namespaces_name[-1]

if instantiations:
self.instantiations = instantiations.asList()
if not isinstance(instantiations, list):
self.instantiations = instantiations.asList()
else:
self.instantiations = instantiations
else:
self.instantiations = ''
self.instantiations = []
if self.name in ["Matrix", "Vector"] and not self.namespaces:
self.namespaces = ["gtsam"]

Expand Down
73 changes: 51 additions & 22 deletions matlab_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ class MatlabWrapper(object):
"""Methods that should not be wrapped directly"""
whitelist = ['serializable', 'serialize']
"""Datatypes that do not need to be checked in methods"""
not_check_type = ['int', 'double', 'bool', 'char', 'unsigned char', 'size_t']
not_check_type = []
"""Data types that are primitive types"""
not_ptr_type = ['int', 'double', 'bool', 'char', 'unsigned char', 'size_t']
"""Ignore the namespace for these datatypes"""
ignore_namespace = ['Matrix', 'Vector', 'Point2', 'Point3']
"""The amount of times the wrapper has created a call to
Expand All @@ -75,7 +77,7 @@ def __init__(self, module, module_name, top_module_namespace='', ignore_classes=
self.module_name = module_name
self.top_module_namespace = top_module_namespace
self.ignore_classes = ignore_classes
self.verbose = True
self.verbose = False

def _debug(self, message):
if not self.verbose:
Expand Down Expand Up @@ -138,7 +140,7 @@ def _is_ptr(self, arg_type):
"""Determine if the interface_parser.Type should be treated as a
pointer in the wrapper.
"""
return arg_type.is_ptr or (arg_type.typename.name not in self.not_check_type
return arg_type.is_ptr or (arg_type.typename.name not in self.not_ptr_type
and arg_type.typename.name not in self.ignore_namespace
and arg_type.typename.name != 'string')

Expand Down Expand Up @@ -195,19 +197,32 @@ def _format_type_name(self, type_name, separator='::', include_namespace=True, c

if include_namespace:
for namespace in type_name.namespaces:
if name not in self.ignore_namespace:
if name not in self.ignore_namespace and namespace != '':
formatted_type_name += namespace + separator

#self._debug("formatted_ns: {}, ns: {}".format(formatted_type_name, type_name.namespaces))
if constructor:
formatted_type_name += self.data_type.get(name) or name
elif method:
formatted_type_name += self.data_type_param.get(name) or name
else:
formatted_type_name += name

if len(type_name.instantiations) == 1:
if separator == "::": # C++
formatted_type_name += '<{}>'.format(self._format_type_name(type_name.instantiations[0],
include_namespace=include_namespace,
constructor=constructor, method=method))
else:
formatted_type_name += '{}'.format(self._format_type_name(
type_name.instantiations[0],
separator=separator,
include_namespace=False,
constructor=constructor, method=method
))
return formatted_type_name

def _format_return_type(self, return_type, include_namespace=False):
def _format_return_type(self, return_type, include_namespace=False, separator="::"):
"""Format return_type.

Args:
Expand All @@ -217,11 +232,19 @@ def _format_return_type(self, return_type, include_namespace=False):
return_wrap = ''

if self._return_count(return_type) == 1:
return_wrap = self._format_type_name(return_type.type1.typename, include_namespace=include_namespace)
return_wrap = self._format_type_name(
return_type.type1.typename,
separator=separator,
include_namespace=include_namespace
)
else:
return_wrap = 'pair< {type1}, {type2} >'.format(
type1=self._format_type_name(return_type.type1.typename, include_namespace=include_namespace),
type2=self._format_type_name(return_type.type2.typename, include_namespace=include_namespace))
type1=self._format_type_name(
return_type.type1.typename, separator=separator, include_namespace=include_namespace
),
type2=self._format_type_name(
return_type.type2.typename, separator=separator, include_namespace=include_namespace
))

return return_wrap

Expand All @@ -238,10 +261,10 @@ def _format_class_name(self, instantiated_class, separator=''):
#
# class_name += instantiated_class.name
parentname = "".join([separator + x for x in parent_full_ns]) + separator

class_name = parentname[2 * len(separator):]

if class_name != '':
class_name += instantiated_class.name
class_name += instantiated_class.name

return class_name

Expand Down Expand Up @@ -322,6 +345,9 @@ def _wrap_variable_arguments(self, args, wrap_datatypes=True):

check_type = self.data_type_param.get(name)

if self.data_type.get(check_type):
check_type = self.data_type[check_type]

if check_type is None:
check_type = self._format_type_name(arg.ctype.typename, separator='.', constructor=not wrap_datatypes)

Expand Down Expand Up @@ -379,6 +405,9 @@ def _wrap_method_check_statement(self, args):

check_type = self.data_type_param.get(name)

if self.data_type.get(check_type):
check_type = self.data_type[check_type]

if check_type is None:
check_type = self._format_type_name(arg.ctype.typename, separator='.')

Expand Down Expand Up @@ -717,15 +746,17 @@ def wrap_class_constructors(self, namespace_name, inst_class, parent_name, ctors
if base_obj:
base_obj = '\n' + base_obj

self._debug("class: {}, name: {}".format(inst_class.name, self._format_class_name(inst_class, separator=".")))
methods_wrap += textwrap.indent(textwrap.dedent('''\
else
error('Arguments do not match any overload of {class_name} constructor');
error('Arguments do not match any overload of {class_name_doc} constructor');
end{base_obj}
obj.ptr_{class_name} = my_ptr;
end\n
''').format(namespace=namespace_name,
d='' if namespace_name == '' else '.',
class_name="".join(inst_class.parent.full_namespaces()) + class_name,
class_name_doc=self._format_class_name(inst_class, separator="."),
class_name=self._format_class_name(inst_class, separator=""),
base_obj=base_obj),
prefix=' ')

Expand Down Expand Up @@ -817,7 +848,7 @@ def wrap_class_methods(self, namespace_name, inst_class, methods, serialize=[Fal
prefix=' ')

# Determine format of return and varargout statements
return_type = self._format_return_type(overload.return_type, include_namespace=True)
return_type = self._format_return_type(overload.return_type, include_namespace=True, separator=".")

if self._return_count(overload.return_type) == 1:
varargout = '' \
Expand Down Expand Up @@ -854,7 +885,7 @@ def wrap_class_methods(self, namespace_name, inst_class, methods, serialize=[Fal

final_statement = textwrap.indent(textwrap.dedent("""\
error('Arguments do not match any overload of function {class_name}.{method_name}');
"""),
""".format(class_name=class_name, method_name=method_name)),
prefix=' ')
method_text += final_statement + 'end\n\n'

Expand Down Expand Up @@ -895,7 +926,7 @@ def wrap_static_methods(self, namespace_name, instantiated_class, serialize):
name_upper_case=static_overload.name,
args=self._wrap_args(static_overload.args),
return_type=self._format_return_type(static_overload.return_type,
include_namespace=True),
include_namespace=True, separator="."),
length=len(static_overload.args.args_list),
var_args_list=self._wrap_variable_arguments(static_overload.args),
check_statement=check_statement,
Expand All @@ -909,7 +940,7 @@ def wrap_static_methods(self, namespace_name, instantiated_class, serialize):

method_text += textwrap.indent(textwrap.dedent("""\
error('Arguments do not match any overload of function {class_name}.{method_name}');
"""),
""".format(class_name=class_name, method_name=static_overload.name)),
prefix=' ')
method_text += textwrap.indent(textwrap.dedent('''\
end\n
Expand Down Expand Up @@ -1161,11 +1192,9 @@ def wrap_collector_function_return(self, method):
shared_obj = '{obj},"{method_name_sep}"'.format(obj=obj, method_name_sep=sep_method_name('.'))
else:
self._debug("Non-PTR: {}, {}".format(return_1, type(return_1)))
# FIXME: This is a very dirty hack!!!
if isinstance(return_1, instantiator.InstantiatedClass):
method_name_sep_dot = "".join(return_1.parent.full_namespaces()) + return_1.name
else:
method_name_sep_dot = sep_method_name('.')
self._debug("Inner type is: {}, {}".format(return_1.typename.name, sep_method_name('.')))
self._debug("Inner type instantiations: {}".format(return_1.typename.instantiations))
method_name_sep_dot = sep_method_name('.')
shared_obj = 'boost::make_shared<{method_name_sep_col}>({obj}),"{method_name_sep_dot}"' \
.format(
method_name=return_1.typename.name,
Expand Down Expand Up @@ -1427,7 +1456,7 @@ def generate_wrapper(self, namespace):
bool anyDeleted = false;
''')
rtti_reg_start = textwrap.dedent('''\
void _gtsam_RTTIRegister() {{
void _{module_name}_RTTIRegister() {{
const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_{module_name}_rttiRegistry_created");
if(!alreadyCreated) {{
std::map<std::string, std::string> types;
Expand Down
22 changes: 17 additions & 5 deletions template_instantiator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import interface_parser as parser


def instantiate_type(ctype, template_typenames, instantiations, cpp_typename):
def instantiate_type(ctype, template_typenames, instantiations, cpp_typename, instantiated_class=None):
"""
Instantiate template typename for @p ctype.
@return If ctype's name is in the @p template_typenames, return the
Expand All @@ -20,6 +20,17 @@ def instantiate_type(ctype, template_typenames, instantiations, cpp_typename):
is_basis=ctype.is_basis,
)
elif str_arg_typename == 'This':
# import sys
if instantiated_class:
name = instantiated_class.original.name
namespaces_name = instantiated_class.namespaces()
namespaces_name.append(name)
# print("INST: {}, {}, CPP: {}, CLS: {}".format(
# ctype, instantiations, cpp_typename, instantiated_class.instantiations
# ), file=sys.stderr)
cpp_typename = parser.Typename(
namespaces_name, instantiations=[inst for inst in instantiated_class.instantiations]
)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Future note: we now construct a proper typename instead of using a string

return parser.Type(
typename=cpp_typename,
is_const=ctype.is_const,
Expand Down Expand Up @@ -58,13 +69,13 @@ def instantiate_args_list(args_list, template_typenames, instantiations,


def instantiate_return_type(return_type, template_typenames, instantiations,
cpp_typename):
cpp_typename, instantiated_class=None):
new_type1 = instantiate_type(
return_type.type1, template_typenames, instantiations, cpp_typename)
return_type.type1, template_typenames, instantiations, cpp_typename, instantiated_class=instantiated_class)
if return_type.type2:
new_type2 = instantiate_type(
return_type.type2, template_typenames, instantiations,
cpp_typename)
cpp_typename, instantiated_class=instantiated_class)
else:
new_type2 = ''
return parser.ReturnType(new_type1, new_type2)
Expand Down Expand Up @@ -215,7 +226,7 @@ def instantiate_static_methods(self):
static_method.args.args_list,
self.original.template.typenames,
self.instantiations,
self.cpp_typename(),
self.cpp_typename()
)
instantiated_static_methods.append(
parser.StaticMethod(
Expand All @@ -225,6 +236,7 @@ def instantiate_static_methods(self):
self.original.template.typenames,
self.instantiations,
self.cpp_typename(),
instantiated_class=self
),
args=parser.ArgumentList(instantiated_args),
parent=self,
Expand Down
22 changes: 11 additions & 11 deletions tests/expected-matlab/+gtsam/Point2.m
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
elseif nargin == 2 && isa(varargin{1},'double') && isa(varargin{2},'double')
my_ptr = geometry_wrapper(2, varargin{1}, varargin{2});
else
error('Arguments do not match any overload of gtsamPoint2 constructor');
error('Arguments do not match any overload of gtsam.Point2 constructor');
end
obj.ptr_gtsamPoint2 = my_ptr;
end
Expand All @@ -45,21 +45,21 @@ function delete(obj)
function varargout = argChar(this, varargin)
% ARGCHAR usage: argChar(char a) : returns void
% Doxygen can be found at http://research.cc.gatech.edu/borg/sites/edu.borg/html/index.html
if length(varargin) == 1
if length(varargin) == 1 && isa(varargin{1},'char')
geometry_wrapper(4, this, varargin{:});
return
end
error('Arguments do not match any overload of function {class_name}.{method_name}');
error('Arguments do not match any overload of function gtsam.Point2.argChar');
end

function varargout = argUChar(this, varargin)
% ARGUCHAR usage: argUChar(unsigned char a) : returns void
% Doxygen can be found at http://research.cc.gatech.edu/borg/sites/edu.borg/html/index.html
if length(varargin) == 1
if length(varargin) == 1 && isa(varargin{1},'unsigned char')
geometry_wrapper(5, this, varargin{:});
return
end
error('Arguments do not match any overload of function {class_name}.{method_name}');
error('Arguments do not match any overload of function gtsam.Point2.argUChar');
end

function varargout = dim(this, varargin)
Expand All @@ -69,7 +69,7 @@ function delete(obj)
varargout{1} = geometry_wrapper(6, this, varargin{:});
return
end
error('Arguments do not match any overload of function {class_name}.{method_name}');
error('Arguments do not match any overload of function gtsam.Point2.dim');
end

function varargout = eigenArguments(this, varargin)
Expand All @@ -79,7 +79,7 @@ function delete(obj)
geometry_wrapper(7, this, varargin{:});
return
end
error('Arguments do not match any overload of function {class_name}.{method_name}');
error('Arguments do not match any overload of function gtsam.Point2.eigenArguments');
end

function varargout = returnChar(this, varargin)
Expand All @@ -89,7 +89,7 @@ function delete(obj)
varargout{1} = geometry_wrapper(8, this, varargin{:});
return
end
error('Arguments do not match any overload of function {class_name}.{method_name}');
error('Arguments do not match any overload of function gtsam.Point2.returnChar');
end

function varargout = vectorConfusion(this, varargin)
Expand All @@ -99,7 +99,7 @@ function delete(obj)
varargout{1} = geometry_wrapper(9, this, varargin{:});
return
end
error('Arguments do not match any overload of function {class_name}.{method_name}');
error('Arguments do not match any overload of function gtsam.Point2.vectorConfusion');
end

function varargout = x(this, varargin)
Expand All @@ -109,7 +109,7 @@ function delete(obj)
varargout{1} = geometry_wrapper(10, this, varargin{:});
return
end
error('Arguments do not match any overload of function {class_name}.{method_name}');
error('Arguments do not match any overload of function gtsam.Point2.x');
end

function varargout = y(this, varargin)
Expand All @@ -119,7 +119,7 @@ function delete(obj)
varargout{1} = geometry_wrapper(11, this, varargin{:});
return
end
error('Arguments do not match any overload of function {class_name}.{method_name}');
error('Arguments do not match any overload of function gtsam.Point2.y');
end

end
Expand Down
Loading