oneuptime/Llama/app.py

from transformers import AutoTokenizer
import transformers
import torch
from fastapi import FastAPI
from pydantic import BaseModel
# Declare a Pydantic model for the request body
class Prompt(BaseModel):
    prompt: str
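
# Path to the locally downloaded Llama 2 7B chat weights (loaded offline below)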
model_path = "./Models/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
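
# Build a Hugging Face text-generation pipeline backed by the local model weights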
pipeline = transformers.pipeline(
    "text-generation",
    model=model_path,
    # torch_dtype=torch.float32,  # for CPU
    torch_dtype=torch.float16,  # for GPU
    device_map="auto",
)
app = FastAPI()
@app.post("/prompt/")
async def create_item(prompt: Prompt):
    # Return an error if the prompt text is empty
    if not prompt.prompt:
        return {"error": "Prompt is required"}

    sequences = pipeline(
        prompt.prompt,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=200,
    )

    # Collect the generated text from each returned sequence
    prompt_response_array = []
    for seq in sequences:
        print(f"Result: {seq['generated_text']}")
        prompt_response_array.append(seq["generated_text"])

    # Return the generated text as the prompt response
    return {"response": prompt_response_array}
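
# Usage sketch, assuming the app is served with uvicorn on the default port
# (host, port, and the example prompt below are illustrative, not from the file):
#   uvicorn app:app --host 0.0.0.0 --port 8000
# The endpoint can then be exercised with a simple POST request:
#   curl -X POST http://localhost:8000/prompt/ \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Say hello"}'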