Skip to content
Snippets Groups Projects
Verified Commit f277465f authored by hanfi's avatar hanfi
Browse files

add auith for customers

parent 28821e0a
No related branches found
No related tags found
1 merge request!5merge from staging
from datetime import datetime, timedelta
from uuid import UUID
from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from itsdangerous.serializer import Serializer
from itsdangerous import BadSignature
from itsdangerous.url_safe import URLSafeTimedSerializer
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
......@@ -36,7 +35,7 @@ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Authentication setup
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
oauth2_tokener = Serializer(settings.signing_key)
oauth2_tokener = URLSafeTimedSerializer(settings.signing_key)
# DB Dependency
......@@ -48,10 +47,11 @@ def get_db():
db.close()
def check_token(token: str):
def check_token(token: str, item_uuid: str):
try:
timestamp = oauth2_tokener.loads(token)
if datetime.fromtimestamp(timestamp) > datetime.now():
auth_data = oauth2_tokener.loads(token, max_age=settings.token_lifetime * 60)
print(auth_data)
if auth_data == "all" or auth_data == item_uuid:
return # success
except BadSignature:
pass
......@@ -63,7 +63,6 @@ def check_token(token: str):
)
# Routes
@app.post("/item/prepare", response_model=schemas.Item)
@limiter.limit("2/minute")
......@@ -81,15 +80,19 @@ def add_item_with_image(
token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db),
):
check_token(token)
check_token(token, None)
print(image.file)
return utils.add_item_with_image(db, image)
@app.post("/item/update/{item_uuid}", response_model=schemas.Item)
def update_item(
item_uuid: str, data: schemas.ItemUpdate, db: Session = Depends(get_db)
item_uuid: str,
data: schemas.ItemUpdate,
token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db),
):
check_token(token, item_uuid)
item = utils.get_item_by_uuid(db, UUID(item_uuid))
if not item:
raise HTTPException(status_code=404, detail="Item not found")
......@@ -97,7 +100,9 @@ def update_item(
@app.get("/item/{item_uuid}", response_model=schemas.Item)
def get_item(item_uuid: str, db: Session = Depends(get_db)):
def get_item(
item_uuid: str, token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
):
item = utils.get_item_by_uuid(db, UUID(item_uuid))
if not item:
raise HTTPException(status_code=404, detail="Item not found")
......@@ -106,13 +111,16 @@ def get_item(item_uuid: str, db: Session = Depends(get_db)):
@app.get("/items", response_model=list[schemas.Item])
def get_items(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
check_token(token)
print(token)
check_token(token, None)
return utils.get_stored_items(db)
@app.get("/tag/{tag}", response_model=schemas.Item)
def get_item_by_tag(tag: str, token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
check_token(token)
def get_item_by_tag(
tag: str, token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
):
check_token(token, None)
item = utils.get_item_by_tag(db, tag)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
......@@ -121,7 +129,7 @@ def get_item_by_tag(tag: str, token: str = Depends(oauth2_scheme), db: Session =
@app.get("/storages", response_model=list[schemas.Storage])
def list_storages(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
check_token(token)
check_token(token, None)
return utils.get_storages(db)
......@@ -131,7 +139,7 @@ def checkin_item_by_uuid(
token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db),
):
check_token(token)
check_token(token, None)
item = utils.get_item_by_uuid(db, UUID(checkin.item_uuid))
if item is None:
raise HTTPException(status_code=404, detail="Item not found")
......@@ -146,13 +154,27 @@ def verify_supporter(form_data: OAuth2PasswordRequestForm = Depends()):
if form_data.password != settings.shared_secret:
raise HTTPException(status_code=400, detail="Incorrect username or password")
return {
"access_token": oauth2_tokener.dumps(
(datetime.now() + timedelta(minutes=settings.token_lifetime)).timestamp()
),
"access_token": oauth2_tokener.dumps("all"),
"token_type": "bearer",
}
@app.post("/login")
def verify_customer(
login_data: schemas.LoginData, db: Session = Depends(get_db)
): # item_uuid: str, signature: str):
print(login_data)
item = utils.get_item_by_uuid(db, UUID(login_data.item_uuid))
if not item:
raise HTTPException(status_code=404, detail="Item not found")
if not utils.verify_signature(item, login_data.signature):
raise HTTPException(status_code=400, detail="Invalid signature")
return {
"access_token": oauth2_tokener.dumps(str(item.uuid)),
"token_type": "bearer",
}
@app.get("/token/check")
def check_token_validity(token: str = Depends(oauth2_scheme)):
check_token(token)
check_token(token, None)
......@@ -32,7 +32,6 @@ class ItemUpdate(BaseModel):
addressee: Union[str, None] = None
team: Union[str, None] = None
amount: Union[int, None] = None
signature: str
class Item(BaseModel):
......@@ -57,6 +56,11 @@ class Item(BaseModel):
orm_mode = True
class LoginData(BaseModel):
item_uuid: str
signature: str
class Storage(BaseModel):
name: str
items: List[Item]
......
......@@ -73,22 +73,24 @@ def add_item_with_image(db: Session, image: SpooledTemporaryFile):
return db_item
def update_item(db: Session, item: schemas.Item, data: schemas.ItemUpdate):
def verify_signature(item: str, signature: str):
public_key = Ed448PublicKey.from_public_bytes(bytes.fromhex(item.verification))
verify = ""
print(str(item.uuid))
print(signature)
try:
public_key.verify(bytes.fromhex(signature), bytes(str(item.uuid), "utf-8"))
except InvalidSignature:
return False
return True
def update_item(db: Session, item: schemas.Item, data: schemas.ItemUpdate):
if data.addressee:
verify += data.addressee
item.addressee = escape(data.addressee)
if data.team:
verify += data.team
item.team = escape(data.team)
if data.amount:
verify += str(data.amount)
item.amount = data.amount
try:
public_key.verify(bytes.fromhex(data.signature), bytes(verify, "utf-8"))
except InvalidSignature:
return None
db.commit()
db.refresh(item)
return item
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment