import streamlit as st
import pandas as pd
from surprise import SVD, Dataset, dump, Reader
from surprise.model_selection import GridSearchCV
from collections import defaultdict
import os

def main():
    """Render the page title and sidebar menu, then run the chosen feature.

    The first menu entry runs ``recommend_engine_learning``; any other
    selection runs ``recommend_execution``.
    """
    # NOTE(review): the UI strings in this file look like Japanese text that
    # was decoded with the wrong codec (presumably Shift-JIS) -- confirm and
    # re-save the file as UTF-8 to restore the intended labels.
    st.title("RhWebAv")

    # Both labels end in a backslash. It is written as "\\" because a lone
    # backslash before the closing quote escapes the quote, which left the
    # string unterminated and the whole file unparseable.
    menu = ["RhGWwK@\\", "Rh{@\\"]
    choice = st.sidebar.selectbox("j[", menu)

    # Compare against the list entry so the label text lives in one place.
    if choice == menu[0]:
        recommend_engine_learning()
    else:
        recommend_execution()

def load_data(rating_scale=(1, 5)):
    """Ask the user for a ratings source and return it as a Surprise dataset.

    Args:
        rating_scale: ``(min, max)`` passed to ``Reader`` when the data comes
            from an uploaded CSV. Defaults to ``(1, 5)``.

    Returns:
        The built-in ml-100k dataset, a dataset built from the uploaded CSV,
        or ``None`` when no CSV has been uploaded yet or the CSV lacks a
        required column.
    """
    # The label contains a literal backslash; "\\" keeps the same text while
    # avoiding the invalid escape sequence "\[".
    source_options = ["CSVt@C", "Tvf[^(ml-100k)"]
    data_source = st.radio("f[^\\[X̑I", source_options)

    if data_source == source_options[1]:
        return Dataset.load_builtin('ml-100k')

    uploaded_file = st.file_uploader(
        "CSVt@CIĂ",
        type="csv"
    )
    if uploaded_file is None:
        return None

    df = pd.read_csv(uploaded_file)

    # Report missing columns in the UI instead of crashing with a KeyError.
    required_columns = ["user_id", "item_id", "rating"]
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        st.error(f"CSV is missing required column(s): {', '.join(missing)}")
        return None

    reader = Reader(rating_scale=rating_scale)
    return Dataset.load_from_df(df[required_columns], reader)

def hyperparameter_tuning(data, cv_value):
    """Grid-search SVD hyperparameters with cross-validation.

    Args:
        data: Dataset handed to ``GridSearchCV.fit``.
        cv_value: Value forwarded as the ``cv`` argument of the search.

    Returns:
        The ``best_params`` attribute of the fitted search. The caller in
        this file indexes it with ``'rmse'``.
    """
    search_space = dict(
        n_epochs=[5, 10, 20, 50],
        lr_all=[0.002, 0.005, 0.01],
        reg_all=[0.02, 0.1, 0.2, 0.4, 0.6],
        n_factors=[50, 100, 200],
    )
    searcher = GridSearchCV(SVD, search_space, measures=['rmse', 'mae'], cv=cv_value)
    searcher.fit(data)
    return searcher.best_params

def train_and_save_model(data, cv_value=0):
    """Train an SVD model on the full dataset and dump it to a file.

    Nothing is trained until the button is pressed. Hyperparameters are
    tuned with ``hyperparameter_tuning`` only when ``cv_value`` is greater
    than 1; otherwise a default ``SVD()`` is used.

    Args:
        data: Dataset providing ``build_full_trainset``.
        cv_value: Number of CV folds for tuning; values of 1 or less skip it.
    """
    save_path = st.text_input("ۑt@C͂Ă", "model.pkl")

    if not st.button("wK&ۑ"):
        return

    with st.spinner("wK..."):
        # Use the RMSE-optimal parameter set when tuning was requested.
        algo = (
            SVD(**hyperparameter_tuning(data, cv_value)['rmse'])
            if cv_value > 1
            else SVD()
        )

        full_trainset = data.build_full_trainset()
        algo.fit(full_trainset)

        # Predict every pair returned by build_anti_testset so the dump
        # carries ready-made predictions alongside the model.
        unseen_pairs = full_trainset.build_anti_testset()
        predictions = algo.test(unseen_pairs)

        dump.dump(save_path, predictions=predictions, algo=algo)
        st.success(f"f {save_path} Ƃĕۑ܂I")

def recommend_engine_learning():
    """Drive the training page: load data, pick a tuning mode, train and save.

    When the first radio option is selected a slider asks for the number of
    CV folds (2-10); otherwise 0 is passed, which makes
    ``train_and_save_model`` skip tuning.
    """
    st.subheader("wKf[^̃[h")
    data = load_data()
    if not data:
        return

    st.subheader("f̊wK")
    tuning_options = [
        "nCp[p[^`[jO{",
        "nCp[p[^`[jO{Ȃ",
    ]
    tuning_choice = st.radio("nCp[p[^`[jȎI", tuning_options)

    if tuning_choice == tuning_options[0]:
        folds = st.slider(
            "CV̉񐔂IĂ",
            min_value=2, max_value=10, value=2, step=1
        )
    else:
        folds = 0
    train_and_save_model(data, folds)

def recommend_execution():
    """Drive the recommendation page: load a model, then pick the target.

    The target selector is always rendered; a recommendation routine is
    only invoked once ``load_trained_model`` has returned a model.
    """
    st.subheader("wKς݃f̃[h")
    predictions, loaded_algo = load_trained_model()

    st.subheader("RhΏێ")
    existing_label = "[Uɑ΂郌Rh̎{"
    new_label = "VK[Uɑ΂郌Rh̎{"
    choice = st.radio("IvVIĂ", [existing_label, new_label])

    # Nothing to recommend with until a model has been uploaded.
    if not loaded_algo:
        return

    if choice == existing_label:
        recommend_existing_users(predictions, loaded_algo)
    else:
        recommend_new_users(loaded_algo)

def load_trained_model():
    """Let the user upload a dumped model and load it with ``dump.load``.

    Returns:
        A ``(predictions, algo)`` tuple as returned by ``dump.load``, or
        ``(None, None)`` when no file has been uploaded yet.
    """
    # Local import so this function does not depend on the file-level imports.
    import tempfile

    uploaded_file = st.file_uploader(
        "wKς݃fAbv[hĂ",
        type=["pkl"]
    )
    if uploaded_file is None:
        return None, None

    with st.spinner("..."):
        # SECURITY: dump.load unpickles the file, and unpickling can execute
        # arbitrary code. Only load model files from trusted sources.
        #
        # dump.load needs a path, so the upload is written to disk first. A
        # private temporary directory avoids sessions overwriting each other's
        # file and removes it afterwards.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, "model.pkl")
            with open(model_path, "wb") as f:
                # getvalue() returns the whole buffer regardless of the
                # current read position.
                f.write(uploaded_file.getvalue())
            predictions, loaded_algo = dump.load(model_path)
        st.success("I܂I")

    return predictions, loaded_algo

def recommend_existing_users(predictions, loaded_algo):
    """Show the top-10 items per user from stored predictions and offer a CSV save.

    Args:
        predictions: Predictions loaded from the model dump; may be ``None``
            or empty when the dump was written without them.
        loaded_algo: Loaded model. Not used here; kept so the signature
            stays the same for the caller.
    """
    st.subheader("[Uւ̃Rh")

    # Without stored predictions there is nothing to rank, and get_top_n
    # would fail trying to iterate None.
    if not predictions:
        st.warning("The loaded model file contains no stored predictions.")
        return

    with st.spinner("[Uɑ΂郌Rh{..."):
        top_n_df = get_top_n(predictions, n=10)
        st.dataframe(top_n_df, width=600, height=400)

    save_filename = st.text_input(
        "ۑCSVt@C͂Ă",
        "existing_user_recommendations.csv"
    )
    if st.button("ʂCSVt@CƂĕۑ"):
        top_n_df.to_csv(save_filename, index=False)
        st.success(f"ʂ {save_filename} Ƃĕۑ܂I")

def recommend_new_users(loaded_algo):
    """Recommend the 10 highest-scored unrated items for each uploaded user.

    The uploaded CSV must contain ``user_id`` and ``item_id`` columns. For
    every distinct user, each item of the model's trainset that the user has
    not rated is scored with ``loaded_algo.predict`` and the ten best are
    kept. The combined table is displayed and can be saved as CSV.

    Args:
        loaded_algo: Trained model exposing ``trainset`` and ``predict``.
    """
    # NOTE(review): user ids absent from the trainset are presumably scored
    # without any user-specific factors -- confirm against the Surprise docs.
    uploaded_file = st.file_uploader(
        "VK[U̕]f[^CSVt@CIĂ",
        type="csv"
    )
    if uploaded_file is None:
        return

    new_user_ratings = pd.read_csv(uploaded_file)

    # Report missing columns in the UI instead of crashing with a KeyError.
    missing = [
        col for col in ("user_id", "item_id")
        if col not in new_user_ratings.columns
    ]
    if missing:
        st.error(f"CSV is missing required column(s): {', '.join(missing)}")
        return

    trainset = loaded_algo.trainset
    # Resolve inner ids to raw ids once instead of three times per item per user.
    raw_item_ids = [trainset.to_raw_iid(inner) for inner in trainset.all_items()]

    st.subheader("VK[Uւ̃Rh")

    recommendations = []
    with st.spinner("VK[Uɑ΂郌Rh{..."):
        for user in new_user_ratings['user_id'].unique():
            user_rows = new_user_ratings[new_user_ratings['user_id'] == user]
            # Compare ids by their string form: raw ids in the trainset can be
            # strings (e.g. the built-in dataset) while pandas parses numeric
            # CSV columns as numbers, so a direct comparison would never match
            # and already-rated items would be recommended again.
            rated_keys = {str(item) for item in user_rows['item_id']}

            preds = [
                (user, raw_iid, loaded_algo.predict(user, raw_iid).est)
                for raw_iid in raw_item_ids
                if str(raw_iid) not in rated_keys
            ]

            # Highest estimated rating first.
            preds.sort(key=lambda x: x[2], reverse=True)
            recommendations.append(
                pd.DataFrame(preds[:10],
                             columns=['user_id', 'item_id', 'rating'])
            )

    # pd.concat raises on an empty list (a CSV with headers but no rows).
    if not recommendations:
        st.warning("The uploaded CSV contains no users to recommend for.")
        return

    # ignore_index gives the combined table one continuous row index.
    recommendations = pd.concat(recommendations, ignore_index=True)
    st.dataframe(recommendations, width=600, height=400)

    save_filename = st.text_input(
        "ۑCSVt@C͂Ă",
        "new_user_recommendations.csv"
    )
    if st.button("ʂCSVt@CƂĕۑ"):
        recommendations.to_csv(save_filename, index=False)
        st.success(f"ʂ {save_filename} Ƃĕۑ܂I")

def get_top_n(predictions, n=10):
    """Return each user's ``n`` highest-estimated items as a DataFrame.

    Args:
        predictions: Iterable of 5-tuples ``(uid, iid, true_r, est, details)``;
            only ``uid``, ``iid`` and ``est`` are used.
        n: Maximum number of items kept per user.

    Returns:
        DataFrame with columns ``User``, ``Item``, ``Estimate``. Users appear
        in the order they are first seen, and each user's rows are ordered by
        descending estimate.
    """
    # Group (item, estimate) pairs per user, keeping first-seen user order.
    per_user = {}
    for uid, iid, _true_r, est, _details in predictions:
        per_user.setdefault(uid, []).append((iid, est))

    records = [
        [uid, iid, est]
        for uid, scored in per_user.items()
        for iid, est in sorted(scored, key=lambda pair: pair[1], reverse=True)[:n]
    ]
    return pd.DataFrame(records, columns=['User', 'Item', 'Estimate'])

if __name__ == "__main__":
    # Run the app only when this file is executed as a script, not on import.
    main()