import asyncio
import os
import sys
import uuid

import aiohttp
import torch
import transformers
from contextlib import asynccontextmanager
from fastapi import FastAPI
from pydantic import BaseModel

# ENV VARS
ONEUPTIME_URL = os.getenv("ONEUPTIME_URL")
HF_MODEL_NAME = os.getenv("HF_MODEL_NAME")
HF_TOKEN = os.getenv("HF_TOKEN")

if not HF_MODEL_NAME:
    HF_MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
    print(f"HF_MODEL_NAME not set. Using default model: {HF_MODEL_NAME}")

if not ONEUPTIME_URL:
    ONEUPTIME_URL = "https://oneuptime.com"

if not HF_TOKEN:
    # Print an error and exit: the token is required to download models.
    print("HF_TOKEN env var is required. This is the Hugging Face API token. You can get it from https://huggingface.co/account/overview. Exiting...")
    sys.exit(1)

# TODO: Store these in Redis down the line.
items_pending = {}
items_processed = {}
errors = {}


async def validateSecretKey(secretKey):
    # If there is no secret key, it cannot be valid.
    if not secretKey:
        return False
    try:
        async with aiohttp.ClientSession() as session:
            print("Validating secret key")
            url = f"{ONEUPTIME_URL}/api/copilot-code-repository/is-valid/{secretKey}"
            async with session.get(url) as response:
                print(response)
                return response.status == 200
    except Exception as e:
        print(repr(e))
        return False


async def job(queue):
    print("Downloading model from Hugging Face: " + HF_MODEL_NAME)

    # Let users know the default model can be overridden.
    if HF_MODEL_NAME == "meta-llama/Meta-Llama-3-8B-Instruct":
        print("If you want to use a different model, please set the HF_MODEL_NAME environment variable.")

    print("This may take a while (minutes or sometimes hours) depending on the model size.")

    model_path = HF_MODEL_NAME

    # Build the pipeline in a worker thread so the event loop (and the
    # health endpoints) stay responsive while the model downloads.
    pipe = await asyncio.to_thread(
        transformers.pipeline,
        "text-generation",
        model=model_path,
        # Use the GPU if one is available; otherwise fall back to the CPU.
        device="cuda" if torch.cuda.is_available() else "cpu",
        # Pass the token so gated models (like Llama 3) can be downloaded.
        token=HF_TOKEN,
        # max_new_tokens=8096
    )

    print("Model downloaded.")

    while True:
        random_id = None
        try:
            # Wait for the next queued item and process it.
            random_id = await queue.get()
            print(f"Processing item {random_id}")

            messages = items_pending[random_id]
            print("Messages:")
            print(messages)

            # Run the blocking generation call in a worker thread so other
            # requests can still be served while the model is busy.
            outputs = await asyncio.to_thread(pipe, messages)

            items_processed[random_id] = outputs
            del items_pending[random_id]
            print(f"Processed item {random_id}")
        except Exception as e:
            print(f"Error processing item {random_id}")
            # Store the error so it shows up in /queue-status/.
            errors[random_id] = repr(e)
            # Remove the failed item from the pending map.
            if random_id in items_pending:
                del items_pending[random_id]
            print(e)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Start the background worker that downloads the model and drains the queue.
    queue = asyncio.Queue()
    app.model_queue = queue
    asyncio.create_task(job(queue))
    yield


# Request body for submitting a prompt.
class Prompt(BaseModel):
    messages: list
    # secretkey: str


# Request body for fetching a prompt result.
class PromptResult(BaseModel):
    id: str
    # secretkey: str


app = FastAPI(lifespan=lifespan)


@app.get("/")
async def root():
    return {"status": "ok"}


@app.get("/status")
async def status():
    return {"status": "ok"}


@app.post("/prompt/")
async def create_item(prompt: Prompt):
    try:
        # If there is no prompt, return a bad request error.
        if not prompt:
            return {"error": "Prompt is required"}

        # Validate the secret key
        # is_valid = await validateSecretKey(prompt.secretkey)
        # if not is_valid:
        #     print("Invalid secret key")
        #     return {"error": "Invalid secret key"}

        # Messages arrive as a JSON array; Pydantic has already parsed
        # them into a Python list.
        messages = prompt.messages

        # Log the prompt to the console.
        print(messages)

        # Generate a UUID to track this request.
        random_id = str(uuid.uuid4())

        # Add the item to the queue.
        items_pending[random_id] = messages
        await app.model_queue.put(random_id)

        # Return the id so the caller can poll /prompt-result/ for the output.
        return {
            "id": random_id,
            "status": "queued"
        }
    except Exception as e:
        print(e)
        return {"error": repr(e)}


# Disable this API in production.
@app.get("/queue-status/")
async def queue_status():
    try:
        return {
            "pending": items_pending,
            "processed": items_processed,
            "queue": app.model_queue.qsize(),
            "errors": errors
        }
    except Exception as e:
        print(e)
        return {"error": repr(e)}


@app.post("/prompt-result/")
async def prompt_status(prompt_status: PromptResult):
    try:
        # Log the prompt status to the console.
        print(prompt_status)

        # Validate the secret key
        # is_valid = await validateSecretKey(prompt_status.secretkey)
        # if not is_valid:
        #     print("Invalid secret key")
        #     return {"error": "Invalid secret key"}

        # If there is no prompt status, return a bad request error.
        if not prompt_status:
            return {"error": "Prompt status is required"}

        # If the item has been processed, return the output and clean up.
        if prompt_status.id in items_processed:
            return_value = {
                "id": prompt_status.id,
                "status": "processed",
                "output": items_processed[prompt_status.id]
            }
            # Delete from items_processed so each result is returned exactly once.
            del items_processed[prompt_status.id]
            return return_value
        else:
            item_status = "not found"
            if prompt_status.id in items_pending:
                item_status = "pending"
            return {
                "id": prompt_status.id,
                "status": item_status
            }
    except Exception as e:
        print(e)
        return {"error": repr(e)}
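

# ---------------------------------------------------------------------------
# Usage sketch. Everything below is illustrative: the module name "app" and
# the port are assumptions, not defined in this file; adjust them to match
# your deployment.
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   # Queue a prompt (chat-style messages for the text-generation pipeline):
#   curl -X POST http://localhost:8000/prompt/ \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello!"}]}'
#
#   # Poll for the output using the id returned by /prompt/:
#   curl -X POST http://localhost:8000/prompt-result/ \
#     -H "Content-Type: application/json" \
#     -d '{"id": "<id returned by /prompt/>"}'
# ---------------------------------------------------------------------------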