import polars as pl
import numpy as np
from sknetwork.data import from_edge_list
from sknetwork.hierarchy import Paris
from sknetwork.clustering import Louvain
import altair as alt
# Read the dataset, forcing the "dpt" column to be parsed as a string.
df = pl.read_csv("data/dataset.csv", schema_overrides={"dpt": pl.String})

df = (
    df
    # Drop the placeholder rows whose name is "_PRENOMS_RARES".
    .filter(pl.col("name").ne("_PRENOMS_RARES"))
    # Keep only the names whose count, summed over all their rows, reaches 100.
    .filter(pl.col("count").sum().over("name") >= 100)
    .with_columns(
        # Dense ranks serve as integer identifiers for names and departments.
        name_id=pl.col("name").rank("dense"),
        dpt_id=pl.col("dpt").rank("dense"),
        # Decade the year falls in (e.g. 1983 -> 1980).
        decade=(pl.col("year") // 10) * 10,
    )
    # Discard every row that still contains a null value.
    .drop_nulls()
)

# Lookup tables mapping each identifier back to its original value.
dpt = df.select("dpt_id", "dpt").unique()
name = df.select("name_id", "name").unique()
df
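# Output of the cell above (first and last rows) -- shape: (3_661_218, 8)
#
# sexe | dpt   | year | name    | count | name_id | dpt_id | decade
# str  | str   | i64  | str     | i64   | u32     | u32    | i64
# -----+-------+------+---------+-------+---------+--------+-------
# "M"  | "84"  | 1983 | "AADIL" | 3     | 1       | 84     | 1980
# "M"  | "92"  | 1992 | "AADIL" | 3     | 1       | 92     | 1990
# "M"  | "75"  | 1962 | "AARON" | 3     | 6       | 75     | 1960
# "M"  | "75"  | 1976 | "AARON" | 3     | 6       | 75     | 1970
# "M"  | "75"  | 1982 | "AARON" | 3     | 6       | 75     | 1980
# ...
# "F"  | "974" | 2011 | "ZYA"   | 3     | 13035   | 99     | 2010
# "F"  | "44"  | 2013 | "ZYA"   | 4     | 13035   | 44     | 2010
# "F"  | "59"  | 2013 | "ZYA"   | 3     | 13035   | 59     | 2010
# "F"  | "974" | 2017 | "ZYA"   | 3     | 13035   | 99     | 2010
# "F"  | "59"  | 2018 | "ZYA"   | 3     | 13035   | 59     | 2010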
# Total count per (name, department) pair, with two rankings:
#   - rank_dpt:    dense rank of the pair's count within its department
#                  (1 = largest count in that department);
#   - rank_global: dense rank of the name's total count over all departments.
name_dpt = (
    df.group_by(["name_id", "dpt_id"])
    .agg(pl.col("count").sum())
    .with_columns(
        pl.col("count").rank("dense", descending=True).over("dpt_id").alias("rank_dpt"),
        pl.col("count").sum().over("name_id").rank("dense", descending=True).alias("rank_global"),
    )
    # Drop the names whose overall total ranks in the global top 10.
    .filter(pl.col("rank_global") > 10)
)

# Weighted bipartite graph: one (name_id, dpt_id, count) edge per pair.
# NOTE(review): the label frames below assume scikit-network uses the integer
# ids directly as row/column indices (no reindexing), so that position i of
# labels_row_ / labels_col_ belongs to id i -- confirm against from_edge_list.
biadjacency = from_edge_list(
    name_dpt.select("name_id", "dpt_id", "count").rows(),
    bipartite=True,
)

louvain = Louvain(resolution=1)
louvain.fit(biadjacency, force_bipartite=True)

# One cluster label per matrix column (department) and per matrix row (name).
# The id columns are cast to the dtype of the matching column in name_dpt:
# `range` yields Int64 while the ranked ids are unsigned, and joining on keys
# of different dtypes is rejected by some polars versions.
dpt_labels = pl.DataFrame(
    {"dpt_id": range(len(louvain.labels_col_)), "label_dpt": louvain.labels_col_}
).with_columns(pl.col("dpt_id").cast(name_dpt.schema["dpt_id"]))
name_labels = pl.DataFrame(
    {"name_id": range(len(louvain.labels_row_)), "label_name": louvain.labels_row_}
).with_columns(pl.col("name_id").cast(name_dpt.schema["name_id"]))

# Keep, for every department, the pairs with rank_dpt below 35, then attach
# the cluster labels and the original department codes and names.
name_dpt_cluster = (
    name_dpt
    .filter(pl.col("rank_dpt") < 35)
    .join(dpt_labels, on="dpt_id")
    .join(name_labels, on="name_id")
    .join(dpt, on="dpt_id")
    .join(name, on="name_id")
)
# GeoJSON of the French departments; each feature carries the department code
# under `properties.code`, which is matched against the `dpt` column below.
url_geojson = "https://france-geojson.gregoiredavid.fr/repo/departements.geojson"
geodata = alt.Data(url=url_geojson, format=alt.DataFormat(property="features"))

# Point selection on the department code, driven by mouse clicks.
click = alt.selection_point(on='click', fields=['dpt'], empty=True)

# Base data transformation: attach each row's department geometry as "geo".
data = (alt.Chart(name_dpt_cluster)
    .transform_lookup(
        lookup='dpt',
        from_=alt.LookupData(geodata, 'properties.code'),
        as_="geo"
    )
    # Register the click selection on the base chart. The original also passed
    # `hover`, a name that is never defined, which raised a NameError.
    .add_params(click)
)

# Map visualization: departments coloured by their cluster label; the stroke
# colour and width are conditioned on the click selection.
name_map = data.mark_geoshape().encode(
    color='label_dpt:N',
    shape='geo:G',
    stroke=alt.condition(click, alt.value("black"), alt.value("gray")),
    strokeWidth=alt.condition(click, alt.value(2), alt.value(0.5)),
    tooltip=alt.value("click to show detail"),
)

# Bar chart visualization: counts per name for the rows matching the click
# selection, coloured by the name's cluster label and sorted by descending count.
name_bar = (data
    .transform_filter(click)
    .mark_bar().encode(
        color='label_name:N',
        x='count:Q',
        y=alt.Y("name:N", sort='-x'),
    )
)

# Final combined chart: map on the left, bar chart on the right.
(name_map | name_bar)