diff --git a/stripnet/stripnet.py b/stripnet/stripnet.py index 6a27bba..923ce07 100644 --- a/stripnet/stripnet.py +++ b/stripnet/stripnet.py @@ -231,11 +231,9 @@ def fit_transform(self, self.text = text self.remove_isolated_nodes = remove_isolated_nodes - if max_connections: - self.max_connections = max_connections - else: - self.max_connections = utils.calc_max_connections( - len(self.text), 1) + self.max_connections = max_connections or utils.calc_max_connections( + len(self.text), 1 + ) logger.info('========== Step1: Calculating Embeddings ==========') self.embeddings = self.embedding_gen(self.text)