| | import streamlit as st |
| | import numpy as np |
| | from llm import load_llm, response_generator |
| | from sql import csv_to_sqlite, run_sql_query |
| |
|
| |
|
# Hugging Face repository and GGUF weights file of the model that writes SQL.
repo_id = "Qwen/Qwen2.5-Coder-3B-Instruct-GGUF"
filename = "qwen2.5-coder-3b-instruct-q6_k.gguf"


@st.cache_resource
def _cached_llm(model_repo: str, model_file: str):
    """Load the model once and reuse it across Streamlit reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    calling ``load_llm`` directly at module level would reload the weights
    each time. ``st.cache_resource`` keys the cached object on the arguments.

    NOTE(review): if ``load_llm`` in llm.py is already cached this wrapper is
    redundant but harmless. The cached model object is shared by all sessions
    — confirm the backend tolerates concurrent use.
    """
    return load_llm(model_repo, model_file)


llm = _cached_llm(repo_id, filename)
| |
|
st.title("CSV TO SQL")
st.write("To start, Upload your CSV below 👇")
if st.button("Example prompt"):
    # Bundled sample data and the canned question asked about it.
    example_csv = "./data/sales.csv"
    example_db = "sales"
    example_table = "sales"
    example_prompt = "What is the sum, count and average sales?"

    # The chat history is normally initialised further down the script; make
    # sure it exists here too so this branch does not depend on statement order.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Build the SQLite table first so the session only points at the sample
    # data once the table actually exists (a failure leaves state untouched).
    csv_to_sqlite(example_csv, example_db, example_table)
    st.session_state.csv_file = example_csv
    st.session_state.db_name = example_db
    st.session_state.table_name = example_table

    st.session_state.messages.append({"role": "user", "content": example_prompt})
    response_sql = response_generator(
        db_name=st.session_state.db_name,
        table_name=st.session_state.table_name,
        llm=llm,
        messages=st.session_state.messages,
        question=example_prompt,
    )
    # NOTE(review): model-generated SQL is executed without validation; consider
    # running it on a read-only connection inside run_sql_query.
    result = run_sql_query(db_name=st.session_state.db_name, query=response_sql)

    # Store the SQL and its result as two assistant turns; the history replay
    # renders "content" as code and an attached "result" as a dataframe.
    st.session_state.messages.append({"role": "assistant", "content": response_sql})
    st.session_state.messages.append(
        {"role": "assistant", "content": str(result), "result": result}
    )
| |
|
| |
|
with st.expander("Upload CSV"):
    csv_file = st.file_uploader("CSV")
    db_name = st.text_input("DB Name")
    table_name = st.text_input("Table Name")
    if st.button("Save"):
        if csv_file and db_name and table_name:
            # Convert first: if csv_to_sqlite raises, the session must not be
            # left pointing at a table that does not exist (that would enable
            # the chat box and break the sidebar preview on every rerun).
            csv_to_sqlite(csv_file, db_name, table_name)
            st.session_state.csv_file = csv_file
            st.session_state.db_name = db_name
            st.session_state.table_name = table_name
            st.write("Saved ✅")
        else:
            st.write("Please enter all values")
| |
|
| | |
# Chat history (list of {"role", "content"[, "result"]} dicts) kept across reruns.
st.session_state.setdefault("messages", [])
| |
|
| | |
# Replay the stored conversation on every rerun: user turns as markdown,
# assistant turns as code, plus any attached query result as a dataframe.
for entry in st.session_state.messages:
    role = entry["role"]
    with st.chat_message(role):
        if "content" in entry:
            render = st.markdown if role == "user" else st.code
            render(entry["content"])
        if "result" in entry:
            st.dataframe(entry["result"])
| |
|
| | |
# The chat box stays disabled until a database and table have been configured.
chat_disabled = (
    "db_name" not in st.session_state or "table_name" not in st.session_state
)
if prompt := st.chat_input("What is up?", disabled=chat_disabled):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        response_sql = response_generator(
            db_name=st.session_state.db_name,
            table_name=st.session_state.table_name,
            llm=llm,
            messages=st.session_state.messages,
            question=prompt,
        )
        st.code(response_sql)
        # NOTE(review): model-generated SQL is executed without validation;
        # consider running it on a read-only connection inside run_sql_query.
        result = run_sql_query(db_name=st.session_state.db_name, query=response_sql)
        # Render the result the same way the history replay does.
        st.dataframe(result)

    # Persist both the SQL and its result, mirroring the "Example prompt" flow,
    # so the result table is still shown after the next rerun.
    st.session_state.messages.append({"role": "assistant", "content": response_sql})
    st.session_state.messages.append(
        {"role": "assistant", "content": str(result), "result": result}
    )
| |
|
with st.sidebar:
    st.title("Data Previewer")
    st.write("You can see your CSV file content here")
    if all(key in st.session_state for key in ("csv_file", "db_name", "table_name")):
        # table_name comes from a free-text widget, so quote it as an SQLite
        # identifier (doubling embedded double quotes) instead of splicing it
        # raw into the query; this also handles names containing spaces.
        quoted_table = '"' + str(st.session_state.table_name).replace('"', '""') + '"'
        result = run_sql_query(
            db_name=st.session_state.db_name,
            query=f"SELECT * FROM {quoted_table}",
        )
        st.dataframe(result)
| |
|