Skip to content
Snippets Groups Projects
Verified Commit f277465f authored by hanfi's avatar hanfi
Browse files

add auith for customers

parent 28821e0a
No related branches found
No related tags found
1 merge request!5merge from staging
from datetime import datetime, timedelta
from uuid import UUID from uuid import UUID
from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile, status from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from itsdangerous.serializer import Serializer
from itsdangerous import BadSignature from itsdangerous import BadSignature
from itsdangerous.url_safe import URLSafeTimedSerializer
from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
...@@ -36,7 +35,7 @@ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) ...@@ -36,7 +35,7 @@ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Authentication setup # Authentication setup
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
oauth2_tokener = Serializer(settings.signing_key) oauth2_tokener = URLSafeTimedSerializer(settings.signing_key)
# DB Dependency # DB Dependency
...@@ -48,10 +47,11 @@ def get_db(): ...@@ -48,10 +47,11 @@ def get_db():
db.close() db.close()
def check_token(token: str): def check_token(token: str, item_uuid: str):
try: try:
timestamp = oauth2_tokener.loads(token) auth_data = oauth2_tokener.loads(token, max_age=settings.token_lifetime * 60)
if datetime.fromtimestamp(timestamp) > datetime.now(): print(auth_data)
if auth_data == "all" or auth_data == item_uuid:
return # success return # success
except BadSignature: except BadSignature:
pass pass
...@@ -63,7 +63,6 @@ def check_token(token: str): ...@@ -63,7 +63,6 @@ def check_token(token: str):
) )
# Routes # Routes
@app.post("/item/prepare", response_model=schemas.Item) @app.post("/item/prepare", response_model=schemas.Item)
@limiter.limit("2/minute") @limiter.limit("2/minute")
...@@ -81,15 +80,19 @@ def add_item_with_image( ...@@ -81,15 +80,19 @@ def add_item_with_image(
token: str = Depends(oauth2_scheme), token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
check_token(token) check_token(token, None)
print(image.file) print(image.file)
return utils.add_item_with_image(db, image) return utils.add_item_with_image(db, image)
@app.post("/item/update/{item_uuid}", response_model=schemas.Item) @app.post("/item/update/{item_uuid}", response_model=schemas.Item)
def update_item( def update_item(
item_uuid: str, data: schemas.ItemUpdate, db: Session = Depends(get_db) item_uuid: str,
data: schemas.ItemUpdate,
token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db),
): ):
check_token(token, item_uuid)
item = utils.get_item_by_uuid(db, UUID(item_uuid)) item = utils.get_item_by_uuid(db, UUID(item_uuid))
if not item: if not item:
raise HTTPException(status_code=404, detail="Item not found") raise HTTPException(status_code=404, detail="Item not found")
...@@ -97,7 +100,9 @@ def update_item( ...@@ -97,7 +100,9 @@ def update_item(
@app.get("/item/{item_uuid}", response_model=schemas.Item) @app.get("/item/{item_uuid}", response_model=schemas.Item)
def get_item(item_uuid: str, db: Session = Depends(get_db)): def get_item(
item_uuid: str, token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
):
item = utils.get_item_by_uuid(db, UUID(item_uuid)) item = utils.get_item_by_uuid(db, UUID(item_uuid))
if not item: if not item:
raise HTTPException(status_code=404, detail="Item not found") raise HTTPException(status_code=404, detail="Item not found")
...@@ -106,13 +111,16 @@ def get_item(item_uuid: str, db: Session = Depends(get_db)): ...@@ -106,13 +111,16 @@ def get_item(item_uuid: str, db: Session = Depends(get_db)):
@app.get("/items", response_model=list[schemas.Item]) @app.get("/items", response_model=list[schemas.Item])
def get_items(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)): def get_items(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
check_token(token) print(token)
check_token(token, None)
return utils.get_stored_items(db) return utils.get_stored_items(db)
@app.get("/tag/{tag}", response_model=schemas.Item) @app.get("/tag/{tag}", response_model=schemas.Item)
def get_item_by_tag(tag: str, token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)): def get_item_by_tag(
check_token(token) tag: str, token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
):
check_token(token, None)
item = utils.get_item_by_tag(db, tag) item = utils.get_item_by_tag(db, tag)
if not item: if not item:
raise HTTPException(status_code=404, detail="Item not found") raise HTTPException(status_code=404, detail="Item not found")
...@@ -121,7 +129,7 @@ def get_item_by_tag(tag: str, token: str = Depends(oauth2_scheme), db: Session = ...@@ -121,7 +129,7 @@ def get_item_by_tag(tag: str, token: str = Depends(oauth2_scheme), db: Session =
@app.get("/storages", response_model=list[schemas.Storage]) @app.get("/storages", response_model=list[schemas.Storage])
def list_storages(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)): def list_storages(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
check_token(token) check_token(token, None)
return utils.get_storages(db) return utils.get_storages(db)
...@@ -131,7 +139,7 @@ def checkin_item_by_uuid( ...@@ -131,7 +139,7 @@ def checkin_item_by_uuid(
token: str = Depends(oauth2_scheme), token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
check_token(token) check_token(token, None)
item = utils.get_item_by_uuid(db, UUID(checkin.item_uuid)) item = utils.get_item_by_uuid(db, UUID(checkin.item_uuid))
if item is None: if item is None:
raise HTTPException(status_code=404, detail="Item not found") raise HTTPException(status_code=404, detail="Item not found")
...@@ -146,13 +154,27 @@ def verify_supporter(form_data: OAuth2PasswordRequestForm = Depends()): ...@@ -146,13 +154,27 @@ def verify_supporter(form_data: OAuth2PasswordRequestForm = Depends()):
if form_data.password != settings.shared_secret: if form_data.password != settings.shared_secret:
raise HTTPException(status_code=400, detail="Incorrect username or password") raise HTTPException(status_code=400, detail="Incorrect username or password")
return { return {
"access_token": oauth2_tokener.dumps( "access_token": oauth2_tokener.dumps("all"),
(datetime.now() + timedelta(minutes=settings.token_lifetime)).timestamp() "token_type": "bearer",
), }
@app.post("/login")
def verify_customer(
login_data: schemas.LoginData, db: Session = Depends(get_db)
): # item_uuid: str, signature: str):
print(login_data)
item = utils.get_item_by_uuid(db, UUID(login_data.item_uuid))
if not item:
raise HTTPException(status_code=404, detail="Item not found")
if not utils.verify_signature(item, login_data.signature):
raise HTTPException(status_code=400, detail="Invalid signature")
return {
"access_token": oauth2_tokener.dumps(str(item.uuid)),
"token_type": "bearer", "token_type": "bearer",
} }
@app.get("/token/check") @app.get("/token/check")
def check_token_validity(token: str = Depends(oauth2_scheme)): def check_token_validity(token: str = Depends(oauth2_scheme)):
check_token(token) check_token(token, None)
...@@ -32,7 +32,6 @@ class ItemUpdate(BaseModel): ...@@ -32,7 +32,6 @@ class ItemUpdate(BaseModel):
addressee: Union[str, None] = None addressee: Union[str, None] = None
team: Union[str, None] = None team: Union[str, None] = None
amount: Union[int, None] = None amount: Union[int, None] = None
signature: str
class Item(BaseModel): class Item(BaseModel):
...@@ -57,6 +56,11 @@ class Item(BaseModel): ...@@ -57,6 +56,11 @@ class Item(BaseModel):
orm_mode = True orm_mode = True
class LoginData(BaseModel):
item_uuid: str
signature: str
class Storage(BaseModel): class Storage(BaseModel):
name: str name: str
items: List[Item] items: List[Item]
......
...@@ -73,22 +73,24 @@ def add_item_with_image(db: Session, image: SpooledTemporaryFile): ...@@ -73,22 +73,24 @@ def add_item_with_image(db: Session, image: SpooledTemporaryFile):
return db_item return db_item
def update_item(db: Session, item: schemas.Item, data: schemas.ItemUpdate): def verify_signature(item: str, signature: str):
public_key = Ed448PublicKey.from_public_bytes(bytes.fromhex(item.verification)) public_key = Ed448PublicKey.from_public_bytes(bytes.fromhex(item.verification))
verify = "" print(str(item.uuid))
print(signature)
try:
public_key.verify(bytes.fromhex(signature), bytes(str(item.uuid), "utf-8"))
except InvalidSignature:
return False
return True
def update_item(db: Session, item: schemas.Item, data: schemas.ItemUpdate):
if data.addressee: if data.addressee:
verify += data.addressee
item.addressee = escape(data.addressee) item.addressee = escape(data.addressee)
if data.team: if data.team:
verify += data.team
item.team = escape(data.team) item.team = escape(data.team)
if data.amount: if data.amount:
verify += str(data.amount)
item.amount = data.amount item.amount = data.amount
try:
public_key.verify(bytes.fromhex(data.signature), bytes(verify, "utf-8"))
except InvalidSignature:
return None
db.commit() db.commit()
db.refresh(item) db.refresh(item)
return item return item
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment