STONKS/main.py

152 lines
3.9 KiB
Python

import os
from fastapi import FastAPI, Depends, Query
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from sqlalchemy import create_engine, text
from dotenv import load_dotenv
# Load settings via python-dotenv before the database variables are read.
load_dotenv()

DB_USER, DB_PASS, DB_HOST, DB_NAME = (
    os.getenv(key) for key in ("DB_USER", "DB_PASS", "DB_HOST", "DB_NAME")
)

# Refuse to start when any of the four settings is unset or empty.
if not (DB_USER and DB_PASS and DB_HOST and DB_NAME):
    raise EnvironmentError("Missing DB environment variables")

app = FastAPI()

# NOTE: the credentials are interpolated into the URL verbatim, so characters
# that are special in a URL (e.g. "@", "/", ":") must already be URL-encoded
# in the environment values.
engine = create_engine(
    f"mysql+pymysql://{DB_USER}:{DB_PASS}@{DB_HOST}/{DB_NAME}",
    pool_pre_ping=True,
)

# Files under ./static (relative to the working directory) are served at /static.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)
# -----------------------
# PAGE 1: DEFAULT /
# -----------------------
@app.get("/")
async def index():
    """Serve the landing page stored in ``templates/index.html``."""
    # The file is read on every request; the path is relative to the
    # process working directory.
    with open("templates/index.html", "r", encoding="utf-8") as page:
        html = page.read()
    return HTMLResponse(html)
@app.get("/embed")
async def embed():
    """Serve the embeddable page stored in ``templates/embed.html``.

    The handler is named ``embed`` rather than ``index`` so that it does not
    rebind the module-level ``index`` name used by the ``/`` handler and so
    the two routes get distinct route names. The URL path is unchanged.
    """
    # The file is read on every request; the path is relative to the
    # process working directory.
    with open("templates/embed.html", "r", encoding="utf-8") as f:
        return HTMLResponse(f.read())
# -----------------------
# PAGE 2: /graphs
# -----------------------
@app.get("/graphs")
async def graphs():
    """Serve the graphs page stored in ``templates/graphs.html``."""
    # The file is read on every request; the path is relative to the
    # process working directory.
    with open("templates/graphs.html", "r", encoding="utf-8") as page:
        html = page.read()
    return HTMLResponse(html)
# -----------------------
# PAGE 3: /port
# -----------------------
@app.get("/port")
async def port():
    """Serve the portfolio graph page stored in ``templates/portfoliograph.html``."""
    # The file is read on every request; the path is relative to the
    # process working directory.
    with open("templates/portfoliograph.html", "r", encoding="utf-8") as page:
        html = page.read()
    return HTMLResponse(html)
# -----------------------
# API 1 (used by /)
# -----------------------
@app.get("/api/data")
async def api_data():
    """Return up to 28 ``STOCKMARKET_HISTORY`` rows for ticker ``FAT``.

    The response has the shape ``{"columns": [...], "rows": [...]}``.
    ``rows`` holds one dict per database row, highest ``ChangeID`` first.
    ``columns`` lists the keys of the first row, or is empty when the query
    matched nothing.
    """
    statement = text("SELECT * FROM STOCKMARKET_HISTORY WHERE Ticker = 'FAT' ORDER BY ChangeID DESC LIMIT 28")
    with engine.connect() as connection:
        # Materialize the rows while the connection is still open.
        records = [dict(record._mapping) for record in connection.execute(statement)]

    column_names = list(records[0]) if records else []
    return {"columns": column_names, "rows": records}
# -----------------------
# API 2 (used by /graphs)
# -----------------------
@app.get("/api/graphs")
async def api_graphs(
    ticker: str = Query(...),
    limit: int = Query(28, ge=0),
):
    """Return recent price history for one or more tickers.

    Args:
        ticker: Comma-separated ticker symbols. Each symbol is trimmed and
            upper-cased; empty entries are dropped.
        limit: Base row count per ticker (default 28). The query fetches
            ``limit * 4`` rows. Negative values are rejected by request
            validation instead of reaching the SQL ``LIMIT`` clause.

    Returns:
        A dict mapping each requested ticker to a list of row dicts holding
        the ``STOCKMARKET_HISTORY`` columns plus ``CompanyName``, ordered by
        ``ChangeID`` descending. A ticker with no matching rows maps to an
        empty list.
    """
    # dict.fromkeys removes repeated symbols while keeping first-seen order,
    # so a ticker listed twice is only queried once.
    tickers = list(dict.fromkeys(
        t.strip().upper() for t in ticker.split(",") if t.strip()
    ))
    if not tickers:
        # Nothing to look up; avoid opening a connection.
        return {}

    # The statement does not depend on the ticker, so build it once.
    query = text("""
        SELECT
            h.*,
            s.CompanyName
        FROM STOCKMARKET_HISTORY h
        JOIN STOCKMARKET_STOCKS s
            ON s.Ticker = h.Ticker
        WHERE h.Ticker = :ticker
        ORDER BY h.ChangeID DESC
        LIMIT :limit;
    """)
    # NOTE(review): the row cap is four times `limit`; the reason is not
    # evident from this file -- confirm the intent with the /graphs page.
    row_cap = limit * 4

    results = {}
    with engine.connect() as conn:
        for t in tickers:
            result = conn.execute(query, {"ticker": t, "limit": row_cap})
            results[t] = [dict(r._mapping) for r in result]
    return results
@app.get("/api/portfolio")
async def api_folio(
    ticker: str = Query(...),
    uuid: int = Query(0),
):
    """Return portfolio holdings for one owner across one or more tickers.

    Args:
        ticker: Comma-separated ticker symbols. Each symbol is trimmed and
            upper-cased; empty entries are dropped.
        uuid: Value matched against the ``UUID`` column of
            ``STOCKMARKET_PORTFOLIO`` (default 0).

    Returns:
        A dict mapping each requested ticker to a list of row dicts holding
        the ``STOCKMARKET_PORTFOLIO`` columns plus ``CurPrice`` and
        ``CurrentValue`` (``CurPrice * SharesOwned``). A ticker with no
        matching rows maps to an empty list.
    """
    # dict.fromkeys removes repeated symbols while keeping first-seen order,
    # so a ticker listed twice is only queried once.
    tickers = list(dict.fromkeys(
        t.strip().upper() for t in ticker.split(",") if t.strip()
    ))
    if not tickers:
        # Nothing to look up; avoid opening a connection.
        return {}

    # The statement does not depend on the ticker, so build it once.
    query = text("""
        SELECT
            p.*,
            s.CurPrice,
            (s.CurPrice * p.SharesOwned) AS CurrentValue
        FROM STOCKMARKET_PORTFOLIO p
        JOIN STOCKMARKET_STOCKS s
            ON s.Ticker = p.Ticker
        WHERE p.Ticker = :ticker
            AND p.UUID = :uuid;
    """)

    results = {}
    with engine.connect() as conn:
        for t in tickers:
            result = conn.execute(query, {"ticker": t, "uuid": uuid})
            results[t] = [dict(r._mapping) for r in result]
    return results
# -----------------------
# RUN
# -----------------------
if __name__ == "__main__":
    import uvicorn

    # Development entry point: listen on all interfaces, port 8000, with
    # auto-reload enabled.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **server_options)