From 1f56b00800805491148acba67cf69c3becd7f471 Mon Sep 17 00:00:00 2001 From: Simon Bruder Date: Sat, 22 Apr 2023 12:53:34 +0200 Subject: [PATCH] u01: Add recursive print with cycle detection --- u01/main.cpp | 8 ++++++++ u01/node.cpp | 18 ++++++++++++++++++ u01/node.h | 3 +++ u01/tests.cpp | 13 +++++++++++++ 4 files changed, 42 insertions(+) diff --git a/u01/main.cpp b/u01/main.cpp index 701b5c6..fd769aa 100644 --- a/u01/main.cpp +++ b/u01/main.cpp @@ -7,6 +7,14 @@ int main(int argc, char **argv) { std::cout << *t; delete t; t = nullptr; + + node *n = new node("foo"); + n->add_child(new node("bar")); + n->get_child(0)->add_child(n); + std::cout << n->print_recursive(); + delete n; + n = nullptr; + node *root = new node("root"); root->add_child(new node("left child")); root->add_child(new node("right child")); diff --git a/u01/node.cpp b/u01/node.cpp index 0525148..347dd78 100644 --- a/u01/node.cpp +++ b/u01/node.cpp @@ -42,6 +42,24 @@ void node::print(std::ostream &str, std::size_t depth) const { } } +std::string node::print_recursive(std::size_t depth, + std::set visited) const { + std::stringstream output; + output << std::string(depth, '\t') << get_name(); + visited.insert(this); + for (const node *child : children) { + // std::set::contains is only implemented in C++20 + if (visited.find(child) == visited.end()) { + output << std::endl << child->print_recursive(depth + 1, visited); + } else { + output << " [↝ " << child->get_name() << "]"; + } + } + if (depth == 0) + output << std::endl; + return output.str(); +} + std::ostream &operator<<(std::ostream &os, node &n) { n.print(os); return os; diff --git a/u01/node.h b/u01/node.h index 35176bb..b30131e 100644 --- a/u01/node.h +++ b/u01/node.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -14,6 +15,8 @@ public: node *get_child(std::size_t i) const; void add_child(node *child); void print(std::ostream &str, std::size_t depth = 0) const; + std::string print_recursive(std::size_t depth = 0, + std::set visited = {}) const; private: std::string name; diff --git a/u01/tests.cpp b/u01/tests.cpp index 96c1adf..add4715 100644 --- a/u01/tests.cpp +++ b/u01/tests.cpp @@ -150,3 +150,16 @@ TEST_CASE("Print stream overload") { REQUIRE(output1.str() == output2.str()); } + +TEST_CASE("Cycle detection") { + node *n = new node("foo"); + n->add_child(new node("bar")); + n->get_child(0)->add_child(n); + + std::stringstream output; + output << n->print_recursive(); + + REQUIRE(output.str() == "foo\n\tbar [↝ foo]\n"); + + delete n; +}