import os
from pathlib import Path
from urllib.parse import quote_plus

from fastapi import FastAPI, Depends, Query
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from sqlalchemy import create_engine, text
from dotenv import load_dotenv

load_dotenv()

DB_USER = os.getenv("DB_USER")
DB_PASS = os.getenv("DB_PASS")
DB_HOST = os.getenv("DB_HOST")
DB_NAME = os.getenv("DB_NAME")

# Fail fast at import time, naming exactly which settings are absent/empty.
_missing_env = [
    name
    for name, value in (
        ("DB_USER", DB_USER),
        ("DB_PASS", DB_PASS),
        ("DB_HOST", DB_HOST),
        ("DB_NAME", DB_NAME),
    )
    if not value
]
if _missing_env:
    raise EnvironmentError(
        "Missing DB environment variables: " + ", ".join(_missing_env)
    )

# Resolve asset directories relative to this file so the app does not depend
# on the process's current working directory.
BASE_DIR = Path(__file__).resolve().parent
TEMPLATES_DIR = BASE_DIR / "templates"
STATIC_DIR = BASE_DIR / "static"

app = FastAPI()

# Username/password are percent-encoded: characters such as "@", ":" or "/"
# in a password would otherwise corrupt the connection URL. DB_HOST is left
# as-is so a "host:port" value keeps working.
engine = create_engine(
    f"mysql+pymysql://{quote_plus(DB_USER)}:{quote_plus(DB_PASS)}"
    f"@{DB_HOST}/{DB_NAME}",
    pool_pre_ping=True,
)

app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# NOTE: every handler below is a plain `def`, not `async def`. Both the file
# reads and the SQLAlchemy calls are blocking; FastAPI runs `def` handlers in
# a threadpool, so they no longer stall the event loop.


def _render_template(filename: str) -> HTMLResponse:
    """Return ``templates/<filename>`` (read as UTF-8) as an HTML response."""
    return HTMLResponse((TEMPLATES_DIR / filename).read_text(encoding="utf-8"))


def _parse_tickers(ticker: str) -> list:
    """Split a comma-separated ticker string into stripped, upper-cased symbols.

    Empty entries (e.g. from ``"AAA,,BBB"`` or a trailing comma) are dropped.
    """
    return [t.strip().upper() for t in ticker.split(",") if t.strip()]


# -----------------------
# PAGE 1: DEFAULT /
# -----------------------
@app.get("/")
def index():
    """Serve the default landing page."""
    return _render_template("index.html")


@app.get("/embed")
def embed():
    """Serve the embeddable page.

    Named ``embed`` rather than ``index`` so it does not shadow the ``/``
    handler's module-level name.
    """
    return _render_template("embed.html")


# -----------------------
# PAGE 2: /graphs
# -----------------------
@app.get("/graphs")
def graphs():
    """Serve the graphs page."""
    return _render_template("graphs.html")


# -----------------------
# PAGE 3: /port
# -----------------------
@app.get("/port")
def port():
    """Serve the portfolio graph page."""
    return _render_template("portfoliograph.html")


# -----------------------
# API 1 (used by /)
# -----------------------
_LATEST_FAT_HISTORY = text(
    "SELECT * FROM STOCKMARKET_HISTORY WHERE Ticker = 'FAT' "
    "ORDER BY ChangeID DESC LIMIT 28"
)


@app.get("/api/data")
def api_data():
    """Return the 28 highest-ChangeID STOCKMARKET_HISTORY rows for ticker 'FAT'.

    Returns:
        ``{"columns": [...], "rows": [...]}``. Each row is a dict keyed by
        column name; ``columns`` are the keys of the first row, or ``[]``
        when the query returns nothing.
    """
    with engine.connect() as conn:
        rows = [dict(r._mapping) for r in conn.execute(_LATEST_FAT_HISTORY)]
    return {
        "columns": list(rows[0].keys()) if rows else [],
        "rows": rows,
    }


# -----------------------
# API 2 (used by /graphs)
# -----------------------
_HISTORY_FOR_TICKER = text("""
    SELECT h.*, s.CompanyName
    FROM STOCKMARKET_HISTORY h
    JOIN STOCKMARKET_STOCKS s ON s.Ticker = h.Ticker
    WHERE h.Ticker = :ticker
    ORDER BY h.ChangeID DESC
    LIMIT :limit;
""")


@app.get("/api/graphs")
def api_graphs(
    ticker: str = Query(...),
    limit: int = Query(28, ge=0),
):
    """Return recent history rows (joined with CompanyName) per ticker.

    Args:
        ticker: Comma-separated ticker symbols; case and surrounding
            whitespace are normalised.
        limit: Base row count. Must be >= 0; a negative value used to reach
            MySQL and fail with a 500, and is now rejected with a 422.

    Returns:
        ``{TICKER: [row, ...], ...}``, rows ordered by ChangeID descending.
    """
    results = {}
    with engine.connect() as conn:
        # dict.fromkeys de-duplicates while keeping first-seen order, so a
        # repeated ticker is queried once and the output is unchanged.
        for t in dict.fromkeys(_parse_tickers(ticker)):
            # NOTE(review): fetches 4x the requested limit; presumably the
            # /graphs front-end needs several history rows per plotted
            # point. Confirm against the template.
            result = conn.execute(
                _HISTORY_FOR_TICKER, {"ticker": t, "limit": limit * 4}
            )
            results[t] = [dict(r._mapping) for r in result]
    return results


# -----------------------
# API 3: /api/portfolio
# -----------------------
_PORTFOLIO_FOR_TICKER = text("""
    SELECT p.*, s.CurPrice,
        (s.CurPrice * p.SharesOwned) AS CurrentValue
    FROM STOCKMARKET_PORTFOLIO p
    JOIN STOCKMARKET_STOCKS s ON s.Ticker = p.Ticker
    WHERE p.Ticker = :ticker AND p.UUID = :uuid;
""")


@app.get("/api/portfolio")
def api_folio(
    ticker: str = Query(...),
    uuid: int = Query(0),
):
    """Return portfolio rows for ``uuid`` per ticker, with current value.

    Args:
        ticker: Comma-separated ticker symbols; case and surrounding
            whitespace are normalised.
        uuid: Value matched against ``STOCKMARKET_PORTFOLIO.UUID``.

    Returns:
        ``{TICKER: [row, ...], ...}``. Each row is the portfolio columns plus
        ``CurPrice`` and ``CurrentValue`` (= CurPrice * SharesOwned).
    """
    results = {}
    with engine.connect() as conn:
        for t in dict.fromkeys(_parse_tickers(ticker)):
            result = conn.execute(
                _PORTFOLIO_FOR_TICKER, {"ticker": t, "uuid": uuid}
            )
            results[t] = [dict(r._mapping) for r in result]
    return results


# -----------------------
# RUN
# -----------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)