import math
from collections import deque

import mip


def solve_lp(A, b, c, lb_all, ub_all):
    """Solve the LP relaxation  max c^T x  s.t.  A x <= b,  lb_all <= x <= ub_all.

    Args:
        A: constraint matrix as a list of rows; each row is indexed for every
            variable j in range(len(c)).
        b: right-hand side, one entry per row of A.
        c: objective coefficients; the number of variables is len(c).
        lb_all: lower bound for every variable (a list of length len(c)).
        ub_all: upper bound for every variable (a list of length len(c));
            math.inf means "no upper bound".

    Returns:
        (objective_value, [x_0, ..., x_{n-1}]) when CBC reports an optimal
        solution; otherwise (None, None). Any non-optimal status (infeasible,
        unbounded, ...) is reported the same way.
    """
    n = len(c)
    m = mip.Model(solver_name="CBC")
    m.verbose = 0

    # Put the branching bounds on the variables themselves instead of adding
    # extra constraint rows. This keeps the LP small and never hands the
    # solver a row such as "x <= inf". It also honours negative lower bounds:
    # with a hard-coded lb=0 plus a row "x >= lb", a negative lb was
    # silently clipped to 0.
    x = [
        m.add_var(var_type=mip.CONTINUOUS, lb=lb_all[j], ub=ub_all[j])
        for j in range(n)
    ]

    for i_row, row in enumerate(A):
        m += mip.quicksum(row[j] * x[j] for j in range(n)) <= b[i_row]

    m.objective = mip.maximize(mip.quicksum(c[j] * x[j] for j in range(n)))

    status = m.optimize()

    if status == mip.OptimizationStatus.OPTIMAL:
        return m.objective_value, [var.x for var in x]
    return None, None


def is_integer(x_jedna, tol=1e-5):
    """Return True if x_jedna is strictly closer than tol to its nearest integer.

    Intended for the float values returned by solve_lp, which may carry
    floating-point noise (e.g. 2.9999999 instead of 3), so an exact
    comparison would misclassify integral solutions.

    Args:
        x_jedna: the value to test.
        tol: absolute tolerance; defaults to 1e-5.
    """
    return abs(x_jedna - round(x_jedna)) < tol


def _most_fractional_index(x_lp):
    """Return the index of the most fractional entry of x_lp, or None if all are integral.

    Fractionality is the distance to the nearest integer (at most 0.5).
    Entries accepted by is_integer() are skipped. On ties the lowest index wins.
    """
    best_idx = None
    max_frac = -1
    for j, val in enumerate(x_lp):
        if is_integer(val):
            continue
        frac = abs(val - round(val))
        if frac > max_frac:
            max_frac = frac
            best_idx = j
    return best_idx


def branch_and_bound(A, b, c):
    """Solve  max c^T x  s.t.  A x <= b,  x >= 0 and integer  by LP-based branch and bound.

    Nodes are explored in FIFO order. Each node carries per-variable lower and
    upper bounds plus the LP value of its parent, which bounds anything
    reachable from the node. Nodes that cannot beat the incumbent are pruned.
    Progress is printed to stdout.

    Returns:
        (best_sol, best_obj): the best integer solution found, as the list of
        LP values (integral within is_integer's tolerance), and its objective
        value. Returns (None, -math.inf) if no integer-feasible point was found.
    """
    n = len(c)

    # Node = (lower bounds, upper bounds, upper bound on the node's objective).
    queue = deque([([0] * n, [math.inf] * n, math.inf)])

    best_obj = -math.inf
    best_sol = None

    node_count = 1
    while queue:
        print()
        print(node_count)
        lb_curr, ub_curr, est_ub = queue.popleft()

        # The parent's LP value bounds this whole subtree.
        if est_ub <= best_obj:
            print("Odrezavame, protoze mame lepsi LB")
            print()
            continue

        print(f"Resime s mezemi na promennych LB - {lb_curr}; UB - {ub_curr}")
        obj, x_lp = solve_lp(A, b, c, lb_curr, ub_curr)
        node_count += 1

        if obj is None:
            print("Nepripustne reseni")
            print()
            continue

        print("Vysledky LP relaxace: ")
        print(obj)
        print(x_lp)

        # Are all variables integral? If not, pick the most fractional one.
        branch_idx = _most_fractional_index(x_lp)

        if branch_idx is None:
            print("Nalezeno celociselne reseni")
            if obj > best_obj:
                best_obj = obj
                best_sol = x_lp
                print("... a navic je ZATIM NEJLEPSI!!!!!")
            print()
            continue

        # The node's own LP value is a tighter bound than its parent's. If it
        # cannot strictly beat the incumbent, branching is pointless.
        if obj <= best_obj:
            print("Odrezavame, protoze mame lepsi LB")
            print()
            continue

        val = x_lp[branch_idx]
        print(f"Vetvime na promenne X[{branch_idx}] = {val}")

        # Both children must inherit ALL bounds of this node. Only the
        # branching variable's bound is tightened, so no ancestor constraint
        # is lost. Lists are copied before mutation, so sharing is safe.

        # Lower branch: x[branch_idx] <= floor(val)
        new_ub = ub_curr.copy()
        new_ub[branch_idx] = math.floor(val)
        queue.append((lb_curr, new_ub, obj))

        # Upper branch: x[branch_idx] >= ceil(val)
        new_lb = lb_curr.copy()
        new_lb[branch_idx] = math.ceil(val)
        queue.append((new_lb, ub_curr, obj))

    return best_sol, best_obj


def control_panel(group="2v", index=0):
    """Solve one bundled example problem with branch_and_bound and print the result.

    Every example has the form
        max c^T x   s.t.   A x <= b,   x >= 0 and integer
    and is stored as a tuple (A, b, c, description).

    Args:
        group: "2v" selects the two-variable examples, "3v" the
            three-variable ones.
        index: position of the example within the selected group. The
            default ("2v", 0) is A=[[6, 7], [1, 0], [0, 1]], b=[35, 3, 4],
            c=[5, 6].

    Raises:
        KeyError: unknown group.
        IndexError: index out of range for the group.
    """
    # --- 2 VARIABLE EXAMPLES ---
    examples_2v = [
        ([[6, 7], [1, 0], [0, 1]], [35, 3, 4], [5, 6], "Knapsack-style constraint with box bounds"),
        ([[6, 7]], [35], [5, 6], "Classic knapsack-style constraint"),
        ([[1, 1], [2, 1]], [6, 10], [3, 2], "Two intersecting constraints"),
        ([[5, 2], [1, 4]], [20, 20], [4, 5], "Tight corner relaxation"),
        ([[3, 5], [6, 2]], [15, 18], [10, 8], "Steep objective function"),
        ([[1.5, 2.5]], [10], [4, 7], "Fractional coefficients"),
        ([[1, 0], [0, 1], [1, 1]], [3.5, 3.5, 5], [1, 1], "Simple box constraints"),
        ([[7, 3], [2, 5]], [25, 20], [6, 5], "Asymmetric constraints"),
        ([[10, 12]], [45], [1, 1], "Large coefficients, small objective"),
        ([[0.7, 1.2]], [3], [5, 5], "Sensitivity test"),
        ([[4, 3], [1, 1]], [12, 3.5], [8, 5], "Narrow feasible region"),
    ]

    # --- 3 VARIABLE EXAMPLES ---
    examples_3v = [
        ([[2, 1, 3], [1, 2, 1]], [10, 8], [4, 5, 6], "Basic 3D resource allocation"),
        ([[5, 7, 4], [1, 1, 1]], [25, 6], [10, 12, 7], "Multiple resource bottlenecks"),
        ([[1, 2, 0], [0, 2, 1], [2, 0, 1]], [7, 5, 8], [3, 4, 5], "Cyclic variable dependency"),
        ([[10, 2, 1]], [22], [5, 1, 1], "One dominant variable"),
        ([[1, 1, 1], [10, 1, 1]], [10, 20], [2, 15, 15], "Large penalty for X0"),
        ([[3, 4, 5], [1, 1, 1], [5, 2, 1]], [20, 7, 15], [10, 10, 10], "Dense constraint matrix"),
        ([[1.2, 0.8, 1.5]], [4], [3, 2, 4], "Highly fractional optimal relaxation"),
        ([[8, 2, 3], [1, 9, 2]], [30, 30], [5, 5, 5], "Sparse interaction"),
        ([[1, 1, 0], [0, 1, 1], [1, 0, 1], [1, 1, 1]], [2, 2, 2, 3], [1, 1, 1], "Symmetric box constraints"),
        ([[4, 6, 2], [1, 1, 5]], [15, 12], [8, 7, 9], "Competition between X1 and X2"),
    ]

    examples = {"2v": examples_2v, "3v": examples_3v}
    A, b, c, description = examples[group][index]
    print(description)

    best_sol, best_obj = branch_and_bound(A, b, c)

    print(f"NEJLEPSI RESENI: {best_sol}")
    print(f"UCELOVA FUNKCE: {best_obj}")

if __name__ == "__main__":
    # Run the default example when executed as a script.
    control_panel()
