111 lines
2.8 KiB
Python
111 lines
2.8 KiB
Python
import os
from urllib.parse import quote_plus

from dotenv import load_dotenv
from fastapi import FastAPI, Depends, Query
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy import create_engine, text
|
|
|
|
# Load variables from a .env file (if present) into the process environment.
load_dotenv()

# Database connection settings, read from the environment.
DB_USER = os.getenv("DB_USER")
DB_PASS = os.getenv("DB_PASS")
DB_HOST = os.getenv("DB_HOST")
DB_NAME = os.getenv("DB_NAME")

# Fail fast at import time rather than on the first request. Name every
# variable that is unset or empty so the operator doesn't have to guess.
_missing_db_vars = [
    name
    for name, value in (
        ("DB_USER", DB_USER),
        ("DB_PASS", DB_PASS),
        ("DB_HOST", DB_HOST),
        ("DB_NAME", DB_NAME),
    )
    if not value
]
if _missing_db_vars:
    raise EnvironmentError(
        "Missing DB environment variables: " + ", ".join(_missing_db_vars)
    )
|
|
|
|
app = FastAPI()

# Percent-encode the password: characters such as "@", ":", "/" or "%" would
# otherwise be misread as URL delimiters and corrupt the connection string.
# DB_HOST is left untouched so a "host:port" value keeps working.
engine = create_engine(
    f"mysql+pymysql://{DB_USER}:{quote_plus(DB_PASS)}@{DB_HOST}/{DB_NAME}",
    # Check each pooled connection is alive before handing it out, so
    # connections dropped by the server are replaced instead of failing.
    pool_pre_ping=True,
)

# Serve the "static" directory (resolved against the process's current
# working directory) under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
|
|
|
# -----------------------
|
|
# PAGE 1: DEFAULT /
|
|
# -----------------------
|
|
@app.get("/")
def index():
    """Serve the landing page from ``templates/index.html``.

    Declared with plain ``def`` rather than ``async def`` because the file
    read is blocking; FastAPI runs sync handlers in a worker threadpool so the
    event loop is not stalled. The path is relative to the process's working
    directory, and the file is re-read on every request.
    """
    with open("templates/index.html", "r", encoding="utf-8") as f:
        return HTMLResponse(f.read())
|
|
|
|
@app.get("/embed")
def embed():
    """Serve the embeddable page from ``templates/embed.html``.

    Named ``embed`` so it does not redefine (and shadow) the ``index``
    handler for ``/``. Declared with plain ``def`` because the file read is
    blocking; FastAPI runs sync handlers in a worker threadpool. The path is
    relative to the process's working directory.
    """
    with open("templates/embed.html", "r", encoding="utf-8") as f:
        return HTMLResponse(f.read())
|
|
|
|
# -----------------------
|
|
# PAGE 2: /graphs
|
|
# -----------------------
|
|
@app.get("/graphs")
def graphs():
    """Serve the graphs page from ``templates/graphs.html``.

    Declared with plain ``def`` rather than ``async def`` because the file
    read is blocking; FastAPI runs sync handlers in a worker threadpool so the
    event loop is not stalled. The path is relative to the process's working
    directory, and the file is re-read on every request.
    """
    with open("templates/graphs.html", "r", encoding="utf-8") as f:
        return HTMLResponse(f.read())
|
|
|
|
|
|
# -----------------------
|
|
# API 1 (used by /)
|
|
# -----------------------
|
|
# The 28 most recent history rows (highest ChangeID first) for ticker 'FAT'.
# Built once at import time since it never varies.
_API_DATA_QUERY = text("SELECT * FROM STOCKMARKET_HISTORY WHERE Ticker = 'FAT' ORDER BY ChangeID DESC LIMIT 28")


@app.get("/api/data")
def api_data():
    """Return recent STOCKMARKET_HISTORY rows for the 'FAT' ticker.

    Declared with plain ``def`` because the SQLAlchemy engine is synchronous;
    FastAPI runs sync handlers in a worker threadpool, so the DB round-trip
    does not block the event loop.

    Returns:
        ``{"columns": [...], "rows": [...]}`` where ``rows`` is a list of
        column-name -> value dicts and ``columns`` lists the column names of
        the first row (empty when no rows match).
    """
    with engine.connect() as conn:
        rows = [dict(r._mapping) for r in conn.execute(_API_DATA_QUERY)]

    return {
        "columns": list(rows[0].keys()) if rows else [],
        "rows": rows,
    }
|
|
|
|
|
|
# -----------------------
|
|
# API 2 (used by /graphs)
|
|
# -----------------------
|
|
from fastapi import Depends, Query
|
|
|
|
# Most recent history rows for one ticker, joined to the company name.
# Built once at import time instead of on every loop iteration.
_GRAPHS_QUERY = text(
    """
    SELECT
        h.*,
        s.CompanyName
    FROM STOCKMARKET_HISTORY h
    JOIN STOCKMARKET_STOCKS s
        ON s.Ticker = h.Ticker
    WHERE h.Ticker = :ticker
    ORDER BY h.ChangeID DESC
    LIMIT :limit;
    """
)


@app.get("/api/graphs")
def api_graphs(
    ticker: str = Query(...),
    limit: int = Query(28, ge=1),
):
    """Return recent history rows for one or more tickers, keyed by ticker.

    Declared with plain ``def`` because the SQLAlchemy engine is synchronous;
    FastAPI runs sync handlers in a worker threadpool, so the DB round-trips
    do not block the event loop.

    Args:
        ticker: Comma-separated ticker symbols, e.g. ``"FAT,abc"``. Each entry
            is stripped and upper-cased, empty entries are skipped, and a
            repeated symbol is queried only once.
        limit: Base row count, must be >= 1 (FastAPI rejects anything smaller
            with a 422 instead of sending a negative LIMIT to MySQL). Up to
            ``limit * 4`` rows are returned per ticker.

    Returns:
        A dict mapping each requested ticker to a list of row dicts (all
        STOCKMARKET_HISTORY columns plus ``CompanyName``), highest ChangeID
        first. A ticker with no joined rows maps to ``[]``.
    """
    # dict.fromkeys drops duplicates while keeping the caller's order.
    tickers = list(
        dict.fromkeys(t.strip().upper() for t in ticker.split(",") if t.strip())
    )

    # NOTE(review): 4x the requested limit is fetched — presumably so the
    # /graphs page has extra points to sample from; confirm with the frontend.
    row_limit = limit * 4

    results = {}
    with engine.connect() as conn:
        for t in tickers:
            result = conn.execute(_GRAPHS_QUERY, {"ticker": t, "limit": row_limit})
            results[t] = [dict(r._mapping) for r in result]

    return results
|
|
|
|
|
|
# -----------------------
|
|
# RUN
|
|
# -----------------------
|
|
if __name__ == "__main__":
    import uvicorn

    # reload=True requires the app as an "module:attr" import string, not the
    # object. Derive the module name from this file instead of assuming it is
    # called main.py, so the script still starts if the file is renamed.
    module_name = os.path.splitext(os.path.basename(__file__))[0]

    # Binds all interfaces with auto-reload enabled (a development setup).
    uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=8000, reload=True)