# Tests for the optimization pass cluster WITH_COLUMNS

import pytest

import polars as pl
from polars.exceptions import ColumnNotFoundError
from polars.testing import assert_frame_equal


def test_basic_cwc() -> None:
    """Check that independent ``with_columns`` calls show up as one list.

    All three stacked calls only read column ``a``; the explained plan is
    expected to contain a single expression list holding ``b``, ``c`` and ``d``.
    """
    lf = pl.LazyFrame({"a": [1, 2]})
    lf = lf.with_columns(pl.col("a").alias("b") * 2)
    lf = lf.with_columns(pl.col("a").alias("c") * 3)
    lf = lf.with_columns(pl.col("a").alias("d") * 4)

    merged_exprs = """[[(col("a")) * (2)].alias("b"), [(col("a")) * (3)].alias("c"), [(col("a")) * (4)].alias("d")]"""

    assert merged_exprs in lf.explain()


def test_disable_cwc() -> None:
    """Check that each expression keeps its own list when the pass is off.

    The query is explained with ``cluster_with_columns=False`` and every
    expression is expected to appear as a separate single-element list.
    """
    lf = pl.LazyFrame({"a": [1, 2]})
    for name, factor in (("b", 2), ("c", 3), ("d", 4)):
        lf = lf.with_columns(pl.col("a").alias(name) * factor)

    flags = pl.QueryOptFlags(cluster_with_columns=False)
    plan = lf.explain(optimizations=flags)

    separate_lists = (
        """[[(col("a")) * (2)].alias("b")]""",
        """[[(col("a")) * (3)].alias("c")]""",
        """[[(col("a")) * (4)].alias("d")]""",
    )
    for fragment in separate_lists:
        assert fragment in plan


def test_refuse_with_deps() -> None:
    """Check that a chain of dependent ``with_columns`` calls stays separate.

    ``b`` is derived from ``a``, ``c`` from ``b`` and ``d`` from ``c``; every
    expression is expected to appear as its own single-element list.
    """
    lf = pl.LazyFrame({"a": [1, 2]})
    # Each step reads the column produced by the step before it.
    for source, target, factor in (("a", "b", 2), ("b", "c", 3), ("c", "d", 4)):
        lf = lf.with_columns(pl.col(source).alias(target) * factor)

    plan = lf.explain()

    separate_lists = (
        """[[(col("a")) * (2)].alias("b")]""",
        """[[(col("b")) * (3)].alias("c")]""",
        """[[(col("c")) * (4)].alias("d")]""",
    )
    for fragment in separate_lists:
        assert fragment in plan


def test_partial_deps() -> None:
    """Check clustering when only some expressions depend on an earlier step.

    The middle call mixes expressions reading ``a`` with one reading ``b``.
    The plan is expected to list the ``a`` readers together with the
    definition of ``b``, and the ``b`` readers together in another list.
    """
    base = pl.LazyFrame({"a": [1, 2]})
    first = base.with_columns(pl.col("a").alias("b") * 2)
    second = first.with_columns(
        pl.col("a").alias("c") * 3,
        pl.col("b").alias("d") * 4,
        pl.col("a").alias("e") * 5,
    )
    third = second.with_columns(pl.col("b").alias("f") * 6)

    plan = third.explain()

    reads_b = """[[(col("b")) * (4)].alias("d"), [(col("b")) * (6)].alias("f")]"""
    reads_a = """[[(col("a")) * (2)].alias("b"), [(col("a")) * (3)].alias("c"), [(col("a")) * (5)].alias("e")]"""

    assert reads_b in plan
    assert reads_a in plan


def test_swap_remove() -> None:
    """Check clustering when the single ``a`` reader sits between ``b`` readers.

    The second ``with_columns`` call lists ``f`` (reads ``b``), ``c`` (reads
    ``a``), ``d`` and ``e`` (both read ``b``). The plan is expected to:

    * list ``c`` together with the definition of ``b``,
    * keep ``f``, ``d`` and ``e`` together in their own list, and
    * contain a ``simple π`` node.

    The collected frame must keep the column order in which the user wrote
    the expressions (``a, b, f, c, d, e``).
    """
    df = (
        pl.LazyFrame({"a": [1, 2]})
        .with_columns(pl.col("a").alias("b") * 2)
        .with_columns(
            pl.col("b").alias("f") * 6,
            pl.col("a").alias("c") * 3,
            pl.col("b").alias("d") * 4,
            pl.col("b").alias("e") * 5,
        )
    )

    explain = df.explain()

    # assert_frame_equal (already used elsewhere in this module) reports a
    # readable diff on failure and also verifies dtypes and column order.
    assert_frame_equal(
        df.collect(),
        pl.DataFrame(
            {
                "a": [1, 2],
                "b": [2, 4],
                "f": [12, 24],
                "c": [3, 6],
                "d": [8, 16],
                "e": [10, 20],
            }
        ),
    )

    assert (
        """[[(col("b")) * (6)].alias("f"), [(col("b")) * (4)].alias("d"), [(col("b")) * (5)].alias("e")]"""
        in explain
    )
    assert (
        """[[(col("a")) * (2)].alias("b"), [(col("a")) * (3)].alias("c")]""" in explain
    )
    assert """simple π""" in explain


def test_try_remove_simple_project() -> None:
    """Check when the plan contains a ``simple π`` node after clustering.

    Two variants of the same query are explained. Both expect ``d`` to be
    listed together with the definition of ``b`` while ``c`` stays in its own
    list. When the user wrote ``d`` before ``c`` the plan must not contain
    ``simple π``; when ``c`` was written before ``d`` it must.
    """
    clustered_on_a = """[[(col("a")) * (2)].alias("b"), [(col("a")) * (4)].alias("d")]"""
    reads_b = """[[(col("b")) * (3)].alias("c")]"""
    simple_projection = """simple π"""

    # Variant 1: ``d`` is written before ``c``.
    d_first = pl.LazyFrame({"a": [1, 2]}).with_columns(pl.col("a").alias("b") * 2)
    d_first = d_first.with_columns(
        pl.col("a").alias("d") * 4, pl.col("b").alias("c") * 3
    )

    expected_d_first = pl.DataFrame(
        [
            pl.Series(name, values, dtype=pl.Int64)
            for name, values in (
                ("a", [1, 2]),
                ("b", [2, 4]),
                ("d", [4, 8]),
                ("c", [6, 12]),
            )
        ]
    )
    assert_frame_equal(d_first.collect(), expected_d_first)

    d_first_plan = d_first.explain()
    assert clustered_on_a in d_first_plan
    assert reads_b in d_first_plan
    assert simple_projection not in d_first_plan

    # Variant 2: ``c`` is written before ``d``.
    c_first = pl.LazyFrame({"a": [1, 2]}).with_columns(pl.col("a").alias("b") * 2)
    c_first = c_first.with_columns(
        pl.col("b").alias("c") * 3, pl.col("a").alias("d") * 4
    )

    expected_c_first = pl.DataFrame(
        [
            pl.Series(name, values, dtype=pl.Int64)
            for name, values in (
                ("a", [1, 2]),
                ("b", [2, 4]),
                ("c", [6, 12]),
                ("d", [4, 8]),
            )
        ]
    )
    assert_frame_equal(c_first.collect(), expected_c_first)

    c_first_plan = c_first.explain()
    assert clustered_on_a in c_first_plan
    assert reads_b in c_first_plan
    assert simple_projection in c_first_plan


def test_cwc_with_internal_aliases() -> None:
    """Check that an alias nested inside an expression does not block clustering.

    The first expression carries an inner ``.alias("b")`` but its output is
    named ``c``. The expression reading ``b`` in the next call is expected to
    be listed in the same expression list.
    """
    inner_aliased = (pl.col("a") == 2).alias("b")

    lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]})
    lf = lf.with_columns(pl.any_horizontal(inner_aliased).alias("c"))
    lf = lf.with_columns(pl.col("b").alias("d") * 3)

    plan = lf.explain()
    merged_exprs = """[[(col("a")) == (2)].alias("c"), [(col("b")) * (3)].alias("d")]"""

    assert merged_exprs in plan


def test_read_of_pushed_column_16436() -> None:
    """Smoke test: a column redefined from its own previous value collects.

    ``z`` is created in one ``with_columns`` call and overwritten in the next
    by an expression that reads ``z``. The test only requires that collecting
    the query completes; the result is not inspected.
    """
    frame = pl.DataFrame(
        {
            "x": [1.12, 2.21, 4.2, 3.21],
            "y": [2.11, 3.32, 2.1, 6.12],
        }
    )

    ratio = (pl.col("y") / pl.col("x")).alias("z")
    finite_ratio = (
        pl.when(pl.col("z").is_infinite()).then(0).otherwise(pl.col("z")).alias("z")
    )

    query = frame.lazy().with_columns(ratio).with_columns(finite_ratio).fill_nan(0)
    query.collect()


def test_multiple_simple_projections_16435() -> None:
    """Smoke test: a mix of column copies and literals collects.

    Five stacked ``with_columns`` calls alternate between copying an earlier
    column and adding a literal. The test only requires that collecting the
    query completes; the result is not inspected.
    """
    query = pl.DataFrame({"a": [1]}).lazy()

    # One keyword-argument mapping per ``with_columns`` call, applied in order.
    steps = (
        {"b": pl.col("a")},
        {"c": pl.col("b")},
        {"l2a": pl.lit(2)},
        {"l2b": pl.col("l2a")},
        {"m": pl.lit(3)},
    )
    for named_exprs in steps:
        query = query.with_columns(**named_exprs)

    query.collect()


def test_reverse_order() -> None:
    """Smoke test: swapping two columns after self-assignments collects.

    ``a`` and ``b`` are first assigned to themselves, then copied to ``x`` and
    ``y``, and finally swapped. The test only requires that collecting the
    query completes; the result is not inspected.
    """
    col_a = pl.col("a")
    col_b = pl.col("b")

    query = pl.LazyFrame({"a": [1], "b": [2]})
    query = query.with_columns(a=col_a, b=col_b, c=col_a * col_b)
    query = query.with_columns(x=col_a, y=col_b)
    query = query.with_columns(b=col_a, a=col_b)

    query.collect()


def test_realias_of_unread_column_16530() -> None:
    """Check that overwriting a column nobody reads afterwards is clustered.

    ``y`` is first defined as ``~x`` and then overwritten by a literal. The
    plan is expected to contain exactly one ``WITH_COLUMNS`` node, and the
    collected result must hold the final literal values.
    """
    df = (
        pl.LazyFrame({"x": [True]})
        .with_columns(x=pl.lit(False))
        .with_columns(y=~pl.col("x"))
        .with_columns(y=pl.lit(False))
    )

    plan = df.explain()

    assert plan.count("WITH_COLUMNS") == 1
    # assert_frame_equal (already used elsewhere in this module) reports a
    # readable diff on failure and also verifies dtypes and column order.
    assert_frame_equal(df.collect(), pl.DataFrame({"x": [False], "y": [False]}))


def test_realias_with_dependencies() -> None:
    """Check that an overwritten column that is still read blocks clustering.

    The last call overwrites ``y`` and, in the same call, computes ``z`` from
    the previous ``y`` (``~x``, i.e. ``True``). The plan is expected to keep
    three ``WITH_COLUMNS`` nodes and ``z`` must be ``True``.
    """
    df = (
        pl.LazyFrame({"x": [True]})
        .with_columns(x=pl.lit(False))
        .with_columns(y=~pl.col("x"))
        .with_columns(y=pl.lit(False), z=pl.col("y") | True)
    )

    explain = df.explain()

    assert explain.count("WITH_COLUMNS") == 3
    # assert_frame_equal (already used elsewhere in this module) reports a
    # readable diff on failure and also verifies dtypes and column order.
    assert_frame_equal(
        df.collect(),
        pl.DataFrame({"x": [False], "y": [False], "z": [True]}),
    )


def test_refuse_pushdown_with_aliases() -> None:
    """Check clustering when a literal column is overwritten while being read.

    ``y`` is defined as a literal, then the last call overwrites ``y`` and
    computes ``z`` from the previous ``y`` in the same call. The plan is
    expected to contain exactly two ``WITH_COLUMNS`` nodes.
    """
    df = (
        pl.LazyFrame({"x": [True]})
        .with_columns(x=pl.lit(False))
        .with_columns(y=pl.lit(True))
        .with_columns(y=pl.lit(False), z=pl.col("y") | True)
    )

    explain = df.explain()

    assert explain.count("WITH_COLUMNS") == 2
    # assert_frame_equal (already used elsewhere in this module) reports a
    # readable diff on failure and also verifies dtypes and column order.
    assert_frame_equal(
        df.collect(),
        pl.DataFrame({"x": [False], "y": [False], "z": [True]}),
    )


def test_neighbour_live_expr() -> None:
    """Check clustering when a sibling expression reads an overwritten input.

    The last call overwrites ``x`` and, in the same call, computes ``z`` from
    ``x``. ``z`` must see the original ``x`` (``True``), so the expected value
    is ``True | False == True``. The plan is expected to contain exactly one
    ``WITH_COLUMNS`` node.
    """
    df = (
        pl.LazyFrame({"x": [True]})
        .with_columns(y=pl.lit(False))
        .with_columns(x=pl.lit(False), z=pl.col("x") | False)
    )

    explain = df.explain()

    assert explain.count("WITH_COLUMNS") == 1
    # assert_frame_equal (already used elsewhere in this module) reports a
    # readable diff on failure and also verifies dtypes and column order.
    assert_frame_equal(
        df.collect(),
        pl.DataFrame({"x": [False], "y": [False], "z": [True]}),
    )


def test_cluster_with_columns_collect_all_panic_26092() -> None:
    """Check that ``collect_all`` handles the same query passed twice.

    The query stacks two literal ``with_columns`` calls on an empty
    ``LazyFrame``. Both collected frames must equal the expected single-row
    frame.
    """
    first_literal = pl.lit(1.0).cast(pl.Float64()).alias("numbers1")
    second_literal = pl.lit(2.0).cast(pl.Float64()).alias("numbers2")

    query = pl.LazyFrame()
    query = query.with_columns(first_literal)
    query = query.with_columns(second_literal)

    # Unpacking also requires that exactly two frames come back.
    first_result, second_result = pl.collect_all([query, query])

    for result in (first_result, second_result):
        assert_frame_equal(result, pl.DataFrame({"numbers1": 1.0, "numbers2": 2.0}))


def test_cluster_with_columns_schema_update_26417() -> None:
    """Check the result when a cast column is read by a later ``with_columns``.

    ``x`` and ``y`` are cast from lists to fixed-size arrays in separate
    calls, after which ``y`` is replaced by its first array element. The
    collected frame must have ``x`` as an ``Array`` column and ``y`` as
    ``Float64``.
    """
    source = pl.LazyFrame({"x": [[0.0, 1.0]], "y": [[2.0]]})

    x_as_array = pl.col("x").cast(pl.Array(pl.Float64, shape=2))
    y_as_array = pl.col("y").cast(pl.Array(pl.Float64, shape=1))
    y_first_item = pl.col("y").arr.get(0)

    query = (
        source.with_columns(x_as_array)
        .with_columns(y_as_array)
        .with_columns(y_first_item)
    )

    expected = pl.DataFrame(
        [
            pl.Series("x", [[0.0, 1.0]], dtype=pl.Array(pl.Float64, shape=(2,))),
            pl.Series("y", [2.0], dtype=pl.Float64),
        ]
    )

    assert_frame_equal(query.collect(), expected)


def test_cluster_with_columns_use_existing_names_26456() -> None:
    """Check the result when unaliased expressions overwrite their inputs.

    The second call overwrites ``a`` with ``a + 1`` and ``b`` with ``b + a``.
    Per the expected frame, ``b`` must be computed from the original ``a``,
    not from the incremented one.
    """
    query = pl.LazyFrame({"a": [1, 2, 3]})
    query = query.with_columns(pl.lit(1).alias("b"))
    query = query.with_columns(pl.col("a") + 1, pl.col("b") + pl.col("a"))

    expected = pl.DataFrame(
        [
            pl.Series("a", [2, 3, 4], dtype=pl.Int64),
            pl.Series("b", [2, 3, 4], dtype=pl.Int64),
        ]
    )

    assert_frame_equal(query.collect(), expected)


def test_cluster_with_columns_prune_col() -> None:
    """Check handling of bare ``pl.col`` expressions in ``with_columns``.

    Scenario 1: the second call re-selects ``buzz`` unchanged and doubles
    ``foo``. The plan is expected to contain one ``WITH_COLUMNS`` node and the
    result must match the expected frame.

    Scenario 2: a bare reference to a column that does not exist must still
    raise ``ColumnNotFoundError`` on collect.
    """
    # Scenario 1: bare re-selection of an existing column.
    source = pl.LazyFrame({"foo": [0.5, 1.7, 3.2], "bar": [4.1, 1.5, 9.2]})
    query = source.with_columns(pl.col("foo").alias("buzz"))
    query = query.with_columns(pl.col("buzz"), pl.col("foo") * 2.0)

    node_count = query.explain().count("WITH_COLUMNS")
    assert node_count == 1

    expected = pl.DataFrame(
        [
            pl.Series("foo", [1.0, 3.4, 6.4], dtype=pl.Float64),
            pl.Series("bar", [4.1, 1.5, 9.2], dtype=pl.Float64),
            pl.Series("buzz", [0.5, 1.7, 3.2], dtype=pl.Float64),
        ]
    )
    assert_frame_equal(query.collect(), expected)

    # Scenario 2: bare reference to a missing column.
    missing_column_query = (
        pl.LazyFrame({"a": 1}).with_columns(pl.col("a")).with_columns(pl.col("b"))
    )

    with pytest.raises(ColumnNotFoundError, match='unable to find column "b"'):
        missing_column_query.collect()
