# main.py
# FastAPI scaffold with JWT auth stubs and key endpoints
import hmac
import os
from datetime import datetime, timedelta, date
from typing import Optional, List

from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy import create_engine
import jwt

from models import init_db, User, Role, Account, Stock, PriceHistory, OptionsChain, GlobalEnvVar, NightlyJob, PriorityStock, ChangeLog

# === Configuration ===
DB_URL = os.getenv("DATABASE_URL", "sqlite:///./earnalot.db")
# NOTE: the fallback secret is for local development only; always set JWT_SECRET
# in any shared or deployed environment, otherwise anyone can forge tokens.
JWT_SECRET = os.getenv("JWT_SECRET", "unsafe-dev-secret")
JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60

engine = init_db(DB_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

app = FastAPI(title="Earnalot - Algo Trading API (MVP)")

# Bearer-token extractor that points API clients at the login endpoint.
# Currently only referenced by the get_current_user placeholder.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")


# ==================
# Utility functions
# ==================
def get_db():
    """Yield a SQLAlchemy session and close it when the caller is done.

    Always use as ``Depends(get_db)``: FastAPI then drives the generator itself
    and runs the ``finally`` clause after the request has been handled. Wrapping
    it as ``lambda: next(get_db())`` abandons the generator, so the session is
    closed at the wrong time (or not at all).
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def _opt_float(value) -> Optional[float]:
    """Convert a numeric column value (e.g. Decimal) to float, passing None through."""
    return float(value) if value is not None else None


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Return a signed JWT containing ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed. The dict is copied, never mutated.
        expires_delta: Token lifetime. ``None`` means ACCESS_TOKEN_EXPIRE_MINUTES;
            an explicit ``timedelta(0)`` is honoured instead of being treated
            as "not given".
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = data.copy()
    to_encode.update({"exp": datetime.utcnow() + expires_delta})
    return jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)


def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
    """Return the user matching ``username``/``password``, or None.

    Very small stub: ``password_hash`` is currently stored and compared as plain
    text in dev. Replace with a real hash check (e.g. passlib/bcrypt) in prod.
    """
    user = db.query(User).filter(User.username == username).first()
    if user is None or user.password_hash is None:
        return None
    # Constant-time compare so response timing does not reveal how much of the
    # password matched. Compare bytes: compare_digest rejects non-ASCII str.
    if not hmac.compare_digest(user.password_hash.encode("utf-8"), password.encode("utf-8")):
        return None
    return user


def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    """Placeholder dependency that will resolve the caller from a bearer token.

    Not implemented yet: always raises HTTP 501. To implement, decode ``token``
    with ``jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])`` and load
    the user named by the ``sub`` claim (see ``login``).
    """
    raise HTTPException(status_code=501, detail="Use the /auth/login endpoint for JWT token retrieval in this scaffold.")


def require_role(db: Session, user: Optional[User], roles: List[str]) -> None:
    """Raise unless ``user`` holds one of ``roles`` (a plain check, not a decorator).

    ``db`` is accepted but not used yet.

    Raises:
        HTTPException: 401 when ``user`` is None; 403 when the user has no role
            or a role whose name is not in ``roles``.
    """
    if user is None:
        raise HTTPException(status_code=401, detail="Unauthorized")
    if user.role is None or user.role.name not in roles:
        raise HTTPException(status_code=403, detail="Forbidden")


# ==================
# Pydantic schemas
# ==================
class TokenResp(BaseModel):
    """Response body of /auth/login."""

    access_token: str
    token_type: str = "bearer"


class EnvVarIn(BaseModel):
    """Request body for creating/updating a global environment variable."""

    name: str
    value: str
    value_type: Optional[str] = "str"
    description: Optional[str] = None


class OrderIn(BaseModel):
    """Request body for /orders (stock via ``stock_symbol`` or option via ``option_id``)."""

    account_id: int
    stock_symbol: Optional[str] = None
    option_id: Optional[int] = None
    order_type: str
    side: str
    quantity: int
    price: Optional[float] = None


# ==================
# Auth endpoints
# ==================
@app.post("/auth/login", response_model=TokenResp)
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Exchange form-encoded ``username``/``password`` for a JWT access token.

    The token carries ``sub`` (username), ``role`` (role name, or "viewer" when
    the user has no role) and ``uid`` (user id). Bad credentials yield HTTP 400.
    """
    user = authenticate_user(db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(status_code=400, detail="Incorrect username or password")
    token = create_access_token({"sub": user.username, "role": user.role.name if user.role else "viewer", "uid": user.id})
    return {"access_token": token, "token_type": "bearer"}


# ==================
# User / account endpoints (minimal)
# ==================
@app.get("/users/me")
def users_me(token: str = Depends(lambda: None)):
    """Placeholder: always raises HTTP 501 until JWT decoding is implemented."""
    raise HTTPException(status_code=501, detail="Implement JWT decode to return current user info.")


# ==================
# ENV VAR endpoints (admin)
# ==================
@app.post("/env/global")
def set_global_env(var: EnvVarIn, db: Session = Depends(get_db)):
    """Create or update (upsert by name) a global environment variable.

    On update, ``value`` is always overwritten while ``value_type`` and
    ``description`` are only overwritten when a truthy value is supplied.
    TODO: restrict to admins once token/role checks are implemented.
    """
    existing = db.query(GlobalEnvVar).filter(GlobalEnvVar.name == var.name).first()
    if existing:
        existing.value = var.value
        existing.value_type = var.value_type or existing.value_type
        existing.description = var.description or existing.description
    else:
        db.add(GlobalEnvVar(name=var.name, value=var.value, value_type=var.value_type or "str", description=var.description))
    db.commit()
    return {"ok": True, "name": var.name, "value": var.value}


# ==================
# Stocks & price endpoints
# ==================
@app.get("/stocks")
def list_stocks(limit: int = 100, db: Session = Depends(get_db)):
    """Return up to ``limit`` stocks as ``{symbol, name, id}`` dicts."""
    stocks = db.query(Stock).limit(limit).all()
    return [{"symbol": s.symbol, "name": s.name, "id": s.id} for s in stocks]


@app.get("/stocks/{symbol}/price_history")
def stock_price_history(symbol: str, days: int = 60, db: Session = Depends(get_db)):
    """Return price rows for ``symbol`` from the last ``days`` days, oldest first.

    Raises HTTP 404 when the symbol is unknown. NULL prices are returned as None.
    """
    stock = db.query(Stock).filter(Stock.symbol == symbol).first()
    if not stock:
        raise HTTPException(status_code=404, detail="Stock not found")
    cutoff = datetime.utcnow() - timedelta(days=days)
    rows = (
        db.query(PriceHistory)
        .filter(PriceHistory.stock_id == stock.id, PriceHistory.ts >= cutoff)
        .order_by(PriceHistory.ts.asc())
        .all()
    )
    return [
        {
            "ts": r.ts.isoformat(),
            "open": _opt_float(r.open),
            "high": _opt_float(r.high),
            "low": _opt_float(r.low),
            "close": _opt_float(r.close),
            "volume": r.volume,
        }
        for r in rows
    ]


@app.get("/stocks/{symbol}/options")
def stock_options(symbol: str, type: Optional[str] = None, expiry: Optional[date] = None, db: Session = Depends(get_db)):
    """Return up to 200 option contracts for ``symbol``, highest volume first.

    ``type`` (kept as the public query-parameter name even though it shadows the
    builtin) is upper-cased before filtering; ``expiry`` filters on an exact date.
    Raises HTTP 404 when the symbol is unknown. NULL numeric/date fields are
    returned as None instead of crashing the request.
    """
    stock = db.query(Stock).filter(Stock.symbol == symbol).first()
    if not stock:
        raise HTTPException(status_code=404, detail="Stock not found")
    q = db.query(OptionsChain).filter(OptionsChain.stock_id == stock.id)
    if type:
        q = q.filter(OptionsChain.type == type.upper())
    if expiry:
        q = q.filter(OptionsChain.expiry == expiry)
    rows = q.order_by(OptionsChain.volume.desc()).limit(200).all()
    return [
        {
            "option_symbol": r.option_symbol,
            "strike": _opt_float(r.strike),
            "expiry": r.expiry.isoformat() if r.expiry is not None else None,
            "type": r.type,
            "last_price": _opt_float(r.last_price),
            "volume": r.volume,
        }
        for r in rows
    ]


# ==================
# Orders & Positions (stubs)
# ==================
@app.post("/orders")
def place_order(order: OrderIn, db: Session = Depends(get_db)):
    """Validate an order and echo back a simulated PENDING order (nothing is persisted).

    TODO: permission check; connect to a broker adapter / paper-trade engine.

    Raises:
        HTTPException: 404 when the account does not exist; 400 when quantity <= 0.
    """
    account = db.query(Account).filter(Account.id == order.account_id).first()
    if not account:
        raise HTTPException(status_code=404, detail="Account not found")
    if order.quantity <= 0:
        raise HTTPException(status_code=400, detail="Quantity must be > 0")
    new_order = {
        "account_id": order.account_id,
        # Echo the instrument too; previously it was silently dropped.
        "stock_symbol": order.stock_symbol,
        "option_id": order.option_id,
        "order_type": order.order_type,
        "side": order.side,
        "quantity": order.quantity,
        "price": order.price,
        "status": "PENDING",
        "created_at": datetime.utcnow(),
    }
    return {"ok": True, "order": new_order}


@app.get("/positions")
def list_positions(account_id: Optional[int] = None, db: Session = Depends(get_db)):
    """List positions for one account (with ``avg_price``) or a summary of all accounts.

    Raises HTTP 404 when ``account_id`` is given but the account does not exist.
    """
    if account_id is not None:
        acct = db.query(Account).filter(Account.id == account_id).first()
        if not acct:
            raise HTTPException(status_code=404, detail="Account not found")
        return [
            {"id": p.id, "stock_id": p.stock_id, "option_id": p.option_id, "qty": p.quantity, "avg_price": _opt_float(p.avg_price)}
            for p in acct.positions
        ]
    # Summary of every position across all accounts.
    return [
        {"account_id": acct.id, "position_id": p.id, "stock_id": p.stock_id, "option_id": p.option_id, "qty": p.quantity}
        for acct in db.query(Account).all()
        for p in acct.positions
    ]


# ==================
# Nightly job trigger (admin)
# ==================
@app.post("/jobs/nightly")
def trigger_nightly(db: Session = Depends(get_db)):
    """Record a STARTED NightlyJob row for today and return its id.

    In production this would enqueue a background job (Celery / RQ / Cloud Run)
    and update ``status``/``finished_at`` when it completes.
    """
    job = NightlyJob(job_date=date.today(), status="STARTED", started_at=datetime.utcnow())
    db.add(job)
    db.commit()
    return {"ok": True, "job_id": job.id, "job_date": str(job.job_date)}


# ==================
# Priority list
# ==================
@app.get("/priority")
def get_priority(limit: int = 200, db: Session = Depends(get_db)):
    """Return up to ``limit`` priority stocks, highest score first."""
    rows = db.query(PriorityStock).order_by(PriorityStock.score.desc()).limit(limit).all()
    return [
        {
            "symbol": r.stock.symbol if r.stock else None,
            "reason": r.reason,
            "score": _opt_float(r.score),
            "flagged_at": r.flagged_at.isoformat() if r.flagged_at is not None else None,
        }
        for r in rows
    ]


# ==================
# Change log entry (helper)
# ==================
@app.post("/changelog")
def add_changelog(change_tag: str, details: str, db: Session = Depends(get_db)):
    """Insert a ChangeLog row and return its id."""
    row = ChangeLog(change_tag=change_tag, details=details)
    db.add(row)
    db.commit()
    return {"ok": True, "id": row.id}


# ==================
# Root
# ==================
@app.get("/")
def root():
    """Service banner with the current UTC time."""
    return {"service": "Earnalot - Algo Trading API", "version": "mvp", "now": datetime.utcnow().isoformat()}