from collections import defaultdict
# This solution is inspired by the classic tree-diameter algorithm.
def solution(S, A):
    """Return the length, in edges, of the longest alternating path in a tree.

    A path is alternating when every pair of adjacent nodes on it carries
    different characters in ``S``.

    Args:
        S: ``S[i]`` is the character attached to node ``i``.
        A: parent array; ``A[i]`` is the parent of node ``i`` and ``-1``
            marks a root. If several roots are present (a forest), every
            tree is considered.

    Returns:
        The number of edges on the longest alternating path (0 if none).
    """
    children = defaultdict(list)
    for node, parent in enumerate(A):
        children[parent].append(node)

    # Pre-order walk with an explicit stack: recursion would exceed Python's
    # default recursion limit on deep (chain-like) trees.
    order = []
    stack = list(children.get(-1, ()))
    while stack:
        node = stack.pop()
        order.append(node)
        stack.extend(children.get(node, ()))

    # down[i]: edges on the longest alternating path that starts at node i
    # and goes downward into i's subtree.
    down = [0] * len(S)
    max_length = 0
    # Reversed pre-order visits every child before its parent.
    for node in reversed(order):
        # Two longest alternating branches hanging below `node` (0 if absent).
        first = second = 0
        for child in children.get(node, ()):
            if S[child] != S[node]:
                branch = down[child] + 1
                if branch > first:
                    first, second = branch, first
                elif branch > second:
                    second = branch
        down[node] = first
        # The best path through `node` joins its two best branches; with at
        # most one branch, `second` is 0 and this is just the downward path.
        max_length = max(max_length, first + second)
    return max_length
if __name__ == "__main__":
    # Sample checks; run only when executed as a script, not on import.
    print(solution("abbab", [-1, 0, 0, 0, 2]))  # should return 2
    print(solution("abbbaabaab", [-1, 0, 0, 0, 1, 2, 2, 3, 3, 4]))  # should return 5
    print(solution("aaaaaaaaaa", [-1, 0, 0, 0, 1, 2, 2, 3, 3, 4]))  # should return 0
    print(solution("abbbabaaaa", [-1, 0, 0, 0, 1, 2, 2, 3, 3, 4]))  # should return 4