Question

For a undirected graph with tree characteristics, we can choose any node as the root. The result graph is then a rooted tree. Among all possible rooted trees, those with minimum height are called minimum height trees (MHTs). Given such a graph, write a function to find all the MHTs and return a list of their root labels.

Format The graph contains n nodes which are labeled from 0 to n - 1. You will be given the number n and a list of undirected edges (each edge is a pair of labels).

You can assume that no duplicate edges will appear in edges. Since all edges are undirected, [0, 1] is the same as [1, 0] and thus will not appear together in edges.

Example 1 :

Input: n = 4, edges = [[1, 0], [1, 2], [1, 3]]

        0
        |
        1
       / \
      2   3

Output: [1]

Example 2 :

Input: n = 6, edges = [[0, 3], [1, 3], [2, 3], [4, 3], [5, 4]]

     0  1  2
      \ | /
        3
        |
        4
        |
        5

Output: [3, 4]

Note:

  • According to the definition of tree on Wikipedia: “a tree is an undirected graph in which any two vertices are connected by exactly one path. In other words, any connected graph without simple cycles is a tree.”
  • The height of a rooted tree is the number of edges on the longest downward path between the root and a leaf.

Solution

This problem becomes finding the mid-point(s) in the longest path in this tree. We can use two BFS to find the longest path in the tree. First, find the other end, say e, of the longest path from any node; second, find the longest path from e.

class Solution {
    class Node {
        int id;
        List<Integer> neighbors;
        int dist;
        int prev;

        public Node(int id) {
            this.id = id;
            this.neighbors = new LinkedList<>();
            this.dist = -1;
            this.prev = -1;
        }

        public void addNode(int id) {
            this.neighbors.add(id);
        }

        public boolean visit(int dist, int prev) {
            if(this.dist == -1) {
                this.dist = dist;
                this.prev = prev;
                return true;
            } else return false;
        }

        public void clear() {
            this.dist = -1;
            this.prev = -1;
        }
    }
    public List<Integer> findMinHeightTrees(int n, int[][] edges) {
        if(n == 1) return Arrays.asList(0);

        Node[] nodes = new Node[n];
        for(int i = 0; i < n; i++) {
            nodes[i] = new Node(i);
        }
        for(int i = 0; i < edges.length; i++) {
            int from = edges[i][0], to = edges[i][1];
            nodes[from].addNode(to);
            nodes[to].addNode(from);
        }
        List<Integer> path = longestPath(nodes, 0);
        int end = path.get(0);
        for(int i = 0; i < n; i++) nodes[i].clear();
        path = longestPath(nodes, end);
        if(path.size() % 2 == 1)
            return path.subList(path.size() / 2, path.size() / 2 + 1);
        else
            return path.subList(path.size() / 2 - 1, path.size() / 2 + 1);
    }

    private List<Integer> longestPath(Node[] nodes, int start) {
        int distance = 0, furtherest = -1;
        nodes[start].visit(distance, -1);
        Queue<Integer> queue = new LinkedList<>();
        queue.offer(start);
        while(queue.size() > 0) {
            int node = queue.poll();
            distance++;
            for(int n: nodes[node].neighbors) {
                if(nodes[n].visit(distance, node)) {
                    queue.offer(n);
                    furtherest = n;
                }
            }
        }
        List<Integer> path = new LinkedList<>();
        path.add(furtherest);
        Node node = nodes[furtherest];
        while(node.dist != 0) {
            path.add(node.prev);
            node = nodes[node.prev];
        }
        return path;
    }    
}