mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-14 18:58:10 +08:00
Install Black formatter
This commit is contained in:
@ -11,4 +11,4 @@ charset = utf-8
|
|||||||
end_of_line = lf
|
end_of_line = lf
|
||||||
|
|
||||||
[*.yml]
|
[*.yml]
|
||||||
indent_size = 4
|
indent_size = 2
|
||||||
|
6
Pipfile
6
Pipfile
@ -11,10 +11,9 @@ isort = "*"
|
|||||||
databases = {extras = ["sqlite"],version = "*"}
|
databases = {extras = ["sqlite"],version = "*"}
|
||||||
pytest-asyncio = "*"
|
pytest-asyncio = "*"
|
||||||
flake8-docstrings = "*"
|
flake8-docstrings = "*"
|
||||||
flake8-commas = "*"
|
|
||||||
flake8-quotes = "*"
|
|
||||||
mkdocs = "*"
|
mkdocs = "*"
|
||||||
mkdocs-material = "*"
|
mkdocs-material = "*"
|
||||||
|
black = "*"
|
||||||
|
|
||||||
[packages]
|
[packages]
|
||||||
fastapi = "*"
|
fastapi = "*"
|
||||||
@ -27,3 +26,6 @@ pyjwt = "*"
|
|||||||
|
|
||||||
[requires]
|
[requires]
|
||||||
python_version = "3.7"
|
python_version = "3.7"
|
||||||
|
|
||||||
|
[pipenv]
|
||||||
|
allow_prereleases = true
|
||||||
|
43
Pipfile.lock
generated
43
Pipfile.lock
generated
@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"_meta": {
|
"_meta": {
|
||||||
"hash": {
|
"hash": {
|
||||||
"sha256": "bc420b47b9f0daede18f3dbeef52ffe0e66829e32cd21279e6f5c28ab2bd0778"
|
"sha256": "4687ef95ee5576f1882e551641586bdbeda40a663bb6d9b9ff95d4259e4cd023"
|
||||||
},
|
},
|
||||||
"pipfile-spec": 6,
|
"pipfile-spec": 6,
|
||||||
"requires": {
|
"requires": {
|
||||||
@ -177,6 +177,13 @@
|
|||||||
],
|
],
|
||||||
"version": "==0.10.0"
|
"version": "==0.10.0"
|
||||||
},
|
},
|
||||||
|
"appdirs": {
|
||||||
|
"hashes": [
|
||||||
|
"sha256:9e5896d1372858f8dd3344faf4e5014d21849c756c8d5701f78f8a103b372d92",
|
||||||
|
"sha256:d8b24664561d0d34ddfaec54636d502d7cea6e29c3eaf68f3df6180863e2166e"
|
||||||
|
],
|
||||||
|
"version": "==1.4.3"
|
||||||
|
},
|
||||||
"atomicwrites": {
|
"atomicwrites": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4",
|
"sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4",
|
||||||
@ -191,6 +198,14 @@
|
|||||||
],
|
],
|
||||||
"version": "==19.2.0"
|
"version": "==19.2.0"
|
||||||
},
|
},
|
||||||
|
"black": {
|
||||||
|
"hashes": [
|
||||||
|
"sha256:09a9dcb7c46ed496a9850b76e4e825d6049ecd38b611f1224857a79bd985a8cf",
|
||||||
|
"sha256:68950ffd4d9169716bcb8719a56c07a2f4485354fec061cdd5910aa07369731c"
|
||||||
|
],
|
||||||
|
"index": "pypi",
|
||||||
|
"version": "==19.3b0"
|
||||||
|
},
|
||||||
"certifi": {
|
"certifi": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:e4f3620cfea4f83eedc95b24abd9cd56f3c4b146dd0177e83a21b4eb49e21e50",
|
"sha256:e4f3620cfea4f83eedc95b24abd9cd56f3c4b146dd0177e83a21b4eb49e21e50",
|
||||||
@ -234,14 +249,6 @@
|
|||||||
"index": "pypi",
|
"index": "pypi",
|
||||||
"version": "==3.7.8"
|
"version": "==3.7.8"
|
||||||
},
|
},
|
||||||
"flake8-commas": {
|
|
||||||
"hashes": [
|
|
||||||
"sha256:d3005899466f51380387df7151fb59afec666a0f4f4a2c6a8995b975de0f44b7",
|
|
||||||
"sha256:ee2141a3495ef9789a3894ed8802d03eff1eaaf98ce6d8653a7c573ef101935e"
|
|
||||||
],
|
|
||||||
"index": "pypi",
|
|
||||||
"version": "==2.0.0"
|
|
||||||
},
|
|
||||||
"flake8-docstrings": {
|
"flake8-docstrings": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:3d5a31c7ec6b7367ea6506a87ec293b94a0a46c0bce2bb4975b7f1d09b6f3717",
|
"sha256:3d5a31c7ec6b7367ea6506a87ec293b94a0a46c0bce2bb4975b7f1d09b6f3717",
|
||||||
@ -250,17 +257,9 @@
|
|||||||
"index": "pypi",
|
"index": "pypi",
|
||||||
"version": "==1.5.0"
|
"version": "==1.5.0"
|
||||||
},
|
},
|
||||||
"flake8-quotes": {
|
|
||||||
"hashes": [
|
|
||||||
"sha256:5dbaf668887873f28346fb87943d6da2e4b9f77ce9f2169cff21764a0a4934ed"
|
|
||||||
],
|
|
||||||
"index": "pypi",
|
|
||||||
"version": "==2.1.0"
|
|
||||||
},
|
|
||||||
"htmlmin": {
|
"htmlmin": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:50c1ef4630374a5d723900096a961cff426dff46b48f34d194a81bbe14eca178",
|
"sha256:50c1ef4630374a5d723900096a961cff426dff46b48f34d194a81bbe14eca178"
|
||||||
"sha256:815e2530cdf4e8f0410cee6c14164d7b537bf6e4f8967dc5ee9e0124ef7e1324"
|
|
||||||
],
|
],
|
||||||
"version": "==0.1.12"
|
"version": "==0.1.12"
|
||||||
},
|
},
|
||||||
@ -296,7 +295,6 @@
|
|||||||
},
|
},
|
||||||
"jsmin": {
|
"jsmin": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:5c93bcd1210c0513cafb5ecebe3c45583c2de93fb6f78568059e6737e9945768",
|
|
||||||
"sha256:b6df99b2cd1c75d9d342e4335b535789b8da9107ec748212706ef7bbe5c2553b"
|
"sha256:b6df99b2cd1c75d9d342e4335b535789b8da9107ec748212706ef7bbe5c2553b"
|
||||||
],
|
],
|
||||||
"version": "==2.2.2"
|
"version": "==2.2.2"
|
||||||
@ -518,6 +516,13 @@
|
|||||||
"index": "pypi",
|
"index": "pypi",
|
||||||
"version": "==1.3.9"
|
"version": "==1.3.9"
|
||||||
},
|
},
|
||||||
|
"toml": {
|
||||||
|
"hashes": [
|
||||||
|
"sha256:229f81c57791a41d65e399fc06bf0848bab550a9dfd5ed66df18ce5f05e73d5c",
|
||||||
|
"sha256:235682dd292d5899d361a811df37e04a8828a5b1da3115886b73cf81ebc9100e"
|
||||||
|
],
|
||||||
|
"version": "==0.10.0"
|
||||||
|
},
|
||||||
"tornado": {
|
"tornado": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:349884248c36801afa19e342a77cc4458caca694b0eda633f5878e458a44cb2c",
|
"sha256:349884248c36801afa19e342a77cc4458caca694b0eda633f5878e458a44cb2c",
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
@ -10,19 +9,19 @@ from starlette.responses import Response
|
|||||||
from fastapi_users.authentication import BaseAuthentication
|
from fastapi_users.authentication import BaseAuthentication
|
||||||
from fastapi_users.models import UserDB
|
from fastapi_users.models import UserDB
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl='/login')
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
||||||
|
|
||||||
|
|
||||||
def generate_jwt(data: dict, lifetime_seconds: int, secret: str, algorithm: str) -> str:
|
def generate_jwt(data: dict, lifetime_seconds: int, secret: str, algorithm: str) -> str:
|
||||||
payload = data.copy()
|
payload = data.copy()
|
||||||
expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds)
|
expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds)
|
||||||
payload['exp'] = expire
|
payload["exp"] = expire
|
||||||
return jwt.encode(payload, secret, algorithm=algorithm).decode('utf-8')
|
return jwt.encode(payload, secret, algorithm=algorithm).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
class JWTAuthentication(BaseAuthentication):
|
class JWTAuthentication(BaseAuthentication):
|
||||||
|
|
||||||
algorithm: str = 'HS256'
|
algorithm: str = "HS256"
|
||||||
secret: str
|
secret: str
|
||||||
lifetime_seconds: int
|
lifetime_seconds: int
|
||||||
|
|
||||||
@ -32,18 +31,20 @@ class JWTAuthentication(BaseAuthentication):
|
|||||||
self.lifetime_seconds = lifetime_seconds
|
self.lifetime_seconds = lifetime_seconds
|
||||||
|
|
||||||
async def get_login_response(self, user: UserDB, response: Response):
|
async def get_login_response(self, user: UserDB, response: Response):
|
||||||
data = {'user_id': user.id}
|
data = {"user_id": user.id}
|
||||||
token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm)
|
token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm)
|
||||||
|
|
||||||
return {'token': token}
|
return {"token": token}
|
||||||
|
|
||||||
def get_authentication_method(self):
|
def get_authentication_method(self):
|
||||||
async def authentication_method(token: str = Depends(oauth2_scheme)):
|
async def authentication_method(token: str = Depends(oauth2_scheme)):
|
||||||
credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
credentials_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
|
data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
|
||||||
user_id: str = data.get('user_id')
|
user_id: str = data.get("user_id")
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
except jwt.PyJWTError:
|
except jwt.PyJWTError:
|
||||||
|
@ -3,9 +3,7 @@ from typing import List
|
|||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
from fastapi_users.models import UserDB
|
from fastapi_users.models import UserDB
|
||||||
from fastapi_users.password import (
|
from fastapi_users.password import get_password_hash, verify_and_update_password
|
||||||
get_password_hash, verify_and_update_password,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseUserDatabase:
|
class BaseUserDatabase:
|
||||||
@ -39,7 +37,9 @@ class BaseUserDatabase:
|
|||||||
if user is None:
|
if user is None:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
verified, updated_password_hash = verify_and_update_password(credentials.password, user.hashed_password)
|
verified, updated_password_hash = verify_and_update_password(
|
||||||
|
credentials.password, user.hashed_password
|
||||||
|
)
|
||||||
if not verified:
|
if not verified:
|
||||||
return None
|
return None
|
||||||
# Update password hash to a more robust one if needed
|
# Update password hash to a more robust one if needed
|
||||||
|
@ -11,7 +11,7 @@ Base = declarative_base()
|
|||||||
|
|
||||||
|
|
||||||
class BaseUser(Base):
|
class BaseUser(Base):
|
||||||
__tablename__ = 'user'
|
__tablename__ = "user"
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
id = Column(String, primary_key=True)
|
||||||
email = Column(String, unique=True, index=True)
|
email = Column(String, unique=True, index=True)
|
||||||
|
@ -12,7 +12,7 @@ class UserBase(BaseModel):
|
|||||||
is_active: Optional[bool] = True
|
is_active: Optional[bool] = True
|
||||||
is_superuser: Optional[bool] = False
|
is_superuser: Optional[bool] = False
|
||||||
|
|
||||||
@pydantic.validator('id', pre=True, always=True)
|
@pydantic.validator("id", pre=True, always=True)
|
||||||
def default_id(cls, v):
|
def default_id(cls, v):
|
||||||
return v or str(uuid.uuid4())
|
return v or str(uuid.uuid4())
|
||||||
|
|
||||||
|
@ -2,10 +2,12 @@ from typing import Tuple
|
|||||||
|
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
def verify_and_update_password(plain_password: str, hashed_password: str) -> Tuple[bool, str]:
|
def verify_and_update_password(
|
||||||
|
plain_password: str, hashed_password: str
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
return pwd_context.verify_and_update(plain_password, hashed_password)
|
return pwd_context.verify_and_update(plain_password, hashed_password)
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,19 +10,20 @@ from fastapi_users.password import get_password_hash
|
|||||||
|
|
||||||
|
|
||||||
class UserRouter:
|
class UserRouter:
|
||||||
|
|
||||||
def __new__(cls, userDB: BaseUserDatabase, auth: BaseAuthentication) -> APIRouter:
|
def __new__(cls, userDB: BaseUserDatabase, auth: BaseAuthentication) -> APIRouter:
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@router.post('/register', response_model=User)
|
@router.post("/register", response_model=User)
|
||||||
async def register(user: UserCreate):
|
async def register(user: UserCreate):
|
||||||
hashed_password = get_password_hash(user.password)
|
hashed_password = get_password_hash(user.password)
|
||||||
db_user = UserDB(**user.dict(), hashed_password=hashed_password)
|
db_user = UserDB(**user.dict(), hashed_password=hashed_password)
|
||||||
created_user = await userDB.create(db_user)
|
created_user = await userDB.create(db_user)
|
||||||
return created_user
|
return created_user
|
||||||
|
|
||||||
@router.post('/login')
|
@router.post("/login")
|
||||||
async def login(response: Response, credentials: OAuth2PasswordRequestForm = Depends()):
|
async def login(
|
||||||
|
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
|
||||||
|
):
|
||||||
user = await userDB.authenticate(credentials)
|
user = await userDB.authenticate(credentials)
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
|
@ -1,10 +1,15 @@
|
|||||||
[flake8]
|
[flake8]
|
||||||
exclude = docs
|
exclude = docs
|
||||||
max-line-length = 119
|
# Match line length of Black
|
||||||
|
max-line-length = 88
|
||||||
docstring-convention = numpy
|
docstring-convention = numpy
|
||||||
|
# Disable D1* rules which force to have docstring everywhere
|
||||||
|
ignore = D1
|
||||||
|
|
||||||
[isort]
|
[isort]
|
||||||
atomic = true
|
atomic = true
|
||||||
|
# Match line length of Black
|
||||||
|
line_length = 88
|
||||||
multi_line_output = 5
|
multi_line_output = 5
|
||||||
known_standard_library = types
|
known_standard_library = types
|
||||||
known_third_party = pytest,_pytest
|
known_third_party = pytest,_pytest
|
||||||
|
@ -9,16 +9,16 @@ from fastapi_users.models import UserDB
|
|||||||
from fastapi_users.password import get_password_hash
|
from fastapi_users.password import get_password_hash
|
||||||
|
|
||||||
active_user_data = UserDB(
|
active_user_data = UserDB(
|
||||||
id='aaa',
|
id="aaa",
|
||||||
email='king.arthur@camelot.bt',
|
email="king.arthur@camelot.bt",
|
||||||
hashed_password=get_password_hash('guinevere'),
|
hashed_password=get_password_hash("guinevere"),
|
||||||
)
|
)
|
||||||
|
|
||||||
inactive_user_data = UserDB(
|
inactive_user_data = UserDB(
|
||||||
id='bbb',
|
id="bbb",
|
||||||
email='percival@camelot.bt',
|
email="percival@camelot.bt",
|
||||||
hashed_password=get_password_hash('angharad'),
|
hashed_password=get_password_hash("angharad"),
|
||||||
is_active=False
|
is_active=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -33,7 +33,6 @@ def inactive_user() -> UserDB:
|
|||||||
|
|
||||||
|
|
||||||
class MockUserDatabase(BaseUserDatabase):
|
class MockUserDatabase(BaseUserDatabase):
|
||||||
|
|
||||||
async def get(self, id: str) -> UserDB:
|
async def get(self, id: str) -> UserDB:
|
||||||
if id == active_user_data.id:
|
if id == active_user_data.id:
|
||||||
return active_user_data
|
return active_user_data
|
||||||
@ -58,9 +57,8 @@ def mock_user_db() -> MockUserDatabase:
|
|||||||
|
|
||||||
|
|
||||||
class MockAuthentication(BaseAuthentication):
|
class MockAuthentication(BaseAuthentication):
|
||||||
|
|
||||||
async def get_login_response(self, user: UserDB, response: Response):
|
async def get_login_response(self, user: UserDB, response: Response):
|
||||||
return {'token': user.id}
|
return {"token": user.id}
|
||||||
|
|
||||||
async def authenticate(self, token: str) -> UserDB:
|
async def authenticate(self, token: str) -> UserDB:
|
||||||
user = await self.userDB.get(token)
|
user = await self.userDB.get(token)
|
||||||
|
@ -8,8 +8,8 @@ from starlette.testclient import TestClient
|
|||||||
from fastapi_users.authentication.jwt import JWTAuthentication, generate_jwt
|
from fastapi_users.authentication.jwt import JWTAuthentication, generate_jwt
|
||||||
from fastapi_users.models import UserDB
|
from fastapi_users.models import UserDB
|
||||||
|
|
||||||
SECRET = 'SECRET'
|
SECRET = "SECRET"
|
||||||
ALGORITHM = 'HS256'
|
ALGORITHM = "HS256"
|
||||||
LIFETIME = 3600
|
LIFETIME = 3600
|
||||||
|
|
||||||
|
|
||||||
@ -21,8 +21,9 @@ def jwt_authentication(mock_user_db):
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def token():
|
def token():
|
||||||
def _token(user, lifetime=LIFETIME):
|
def _token(user, lifetime=LIFETIME):
|
||||||
data = {'user_id': user.id}
|
data = {"user_id": user.id}
|
||||||
return generate_jwt(data, lifetime, SECRET, ALGORITHM)
|
return generate_jwt(data, lifetime, SECRET, ALGORITHM)
|
||||||
|
|
||||||
return _token
|
return _token
|
||||||
|
|
||||||
|
|
||||||
@ -30,8 +31,10 @@ def token():
|
|||||||
def test_auth_client(jwt_authentication):
|
def test_auth_client(jwt_authentication):
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
@app.get('/test-auth')
|
@app.get("/test-auth")
|
||||||
def test_auth(user: UserDB = Depends(jwt_authentication.get_authentication_method())):
|
def test_auth(
|
||||||
|
user: UserDB = Depends(jwt_authentication.get_authentication_method())
|
||||||
|
):
|
||||||
return user
|
return user
|
||||||
|
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
@ -41,30 +44,35 @@ def test_auth_client(jwt_authentication):
|
|||||||
async def test_get_login_response(jwt_authentication, user):
|
async def test_get_login_response(jwt_authentication, user):
|
||||||
login_response = await jwt_authentication.get_login_response(user, Response())
|
login_response = await jwt_authentication.get_login_response(user, Response())
|
||||||
|
|
||||||
assert 'token' in login_response
|
assert "token" in login_response
|
||||||
|
|
||||||
token = login_response['token']
|
token = login_response["token"]
|
||||||
decoded = jwt.decode(token, SECRET, algorithms=[ALGORITHM])
|
decoded = jwt.decode(token, SECRET, algorithms=[ALGORITHM])
|
||||||
assert decoded['user_id'] == user.id
|
assert decoded["user_id"] == user.id
|
||||||
|
|
||||||
|
|
||||||
class TestGetAuthenticationMethod:
|
class TestGetAuthenticationMethod:
|
||||||
|
|
||||||
def test_missing_token(self, test_auth_client):
|
def test_missing_token(self, test_auth_client):
|
||||||
response = test_auth_client.get('/test-auth')
|
response = test_auth_client.get("/test-auth")
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
def test_invalid_token(self, test_auth_client):
|
def test_invalid_token(self, test_auth_client):
|
||||||
response = test_auth_client.get('/test-auth', headers={'Authorization': 'Bearer foo'})
|
response = test_auth_client.get(
|
||||||
|
"/test-auth", headers={"Authorization": "Bearer foo"}
|
||||||
|
)
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
|
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
|
||||||
response = test_auth_client.get('/test-auth', headers={'Authorization': f'Bearer {token(inactive_user)}'})
|
response = test_auth_client.get(
|
||||||
|
"/test-auth", headers={"Authorization": f"Bearer {token(inactive_user)}"}
|
||||||
|
)
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
def test_valid_token(self, test_auth_client, token, user):
|
def test_valid_token(self, test_auth_client, token, user):
|
||||||
response = test_auth_client.get('/test-auth', headers={'Authorization': f'Bearer {token(user)}'})
|
response = test_auth_client.get(
|
||||||
|
"/test-auth", headers={"Authorization": f"Bearer {token(user)}"}
|
||||||
|
)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
assert response_json['id'] == user.id
|
assert response_json["id"] == user.id
|
||||||
|
@ -5,31 +5,35 @@ from fastapi.security import OAuth2PasswordRequestForm
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def create_oauth2_password_request_form():
|
def create_oauth2_password_request_form():
|
||||||
def _create_oauth2_password_request_form(username, password):
|
def _create_oauth2_password_request_form(username, password):
|
||||||
return OAuth2PasswordRequestForm(
|
return OAuth2PasswordRequestForm(username=username, password=password, scope="")
|
||||||
username=username,
|
|
||||||
password=password,
|
|
||||||
scope='',
|
|
||||||
)
|
|
||||||
return _create_oauth2_password_request_form
|
return _create_oauth2_password_request_form
|
||||||
|
|
||||||
|
|
||||||
class TestAuthenticate:
|
class TestAuthenticate:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unknown_user(self, create_oauth2_password_request_form, mock_user_db):
|
async def test_unknown_user(
|
||||||
form = create_oauth2_password_request_form('lancelot@camelot.bt', 'guinevere')
|
self, create_oauth2_password_request_form, mock_user_db
|
||||||
|
):
|
||||||
|
form = create_oauth2_password_request_form("lancelot@camelot.bt", "guinevere")
|
||||||
user = await mock_user_db.authenticate(form)
|
user = await mock_user_db.authenticate(form)
|
||||||
assert user is None
|
assert user is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_wrong_password(self, create_oauth2_password_request_form, mock_user_db):
|
async def test_wrong_password(
|
||||||
form = create_oauth2_password_request_form('king.arthur@camelot.bt', 'percival')
|
self, create_oauth2_password_request_form, mock_user_db
|
||||||
|
):
|
||||||
|
form = create_oauth2_password_request_form("king.arthur@camelot.bt", "percival")
|
||||||
user = await mock_user_db.authenticate(form)
|
user = await mock_user_db.authenticate(form)
|
||||||
assert user is None
|
assert user is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_valid_credentials(self, create_oauth2_password_request_form, mock_user_db):
|
async def test_valid_credentials(
|
||||||
form = create_oauth2_password_request_form('king.arthur@camelot.bt', 'guinevere')
|
self, create_oauth2_password_request_form, mock_user_db
|
||||||
|
):
|
||||||
|
form = create_oauth2_password_request_form(
|
||||||
|
"king.arthur@camelot.bt", "guinevere"
|
||||||
|
)
|
||||||
user = await mock_user_db.authenticate(form)
|
user = await mock_user_db.authenticate(form)
|
||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.email == 'king.arthur@camelot.bt'
|
assert user.email == "king.arthur@camelot.bt"
|
||||||
|
@ -9,11 +9,11 @@ from fastapi_users.db.sqlalchemy import Base, SQLAlchemyUserDatabase
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def sqlalchemy_user_db() -> SQLAlchemyUserDatabase:
|
async def sqlalchemy_user_db() -> SQLAlchemyUserDatabase:
|
||||||
DATABASE_URL = 'sqlite:///./test.db'
|
DATABASE_URL = "sqlite:///./test.db"
|
||||||
database = Database(DATABASE_URL)
|
database = Database(DATABASE_URL)
|
||||||
|
|
||||||
engine = sqlalchemy.create_engine(
|
engine = sqlalchemy.create_engine(
|
||||||
DATABASE_URL, connect_args={'check_same_thread': False}
|
DATABASE_URL, connect_args={"check_same_thread": False}
|
||||||
)
|
)
|
||||||
Base.metadata.create_all(engine)
|
Base.metadata.create_all(engine)
|
||||||
|
|
||||||
@ -57,5 +57,5 @@ async def test_queries(user, sqlalchemy_user_db):
|
|||||||
await sqlalchemy_user_db.create(user)
|
await sqlalchemy_user_db.create(user)
|
||||||
|
|
||||||
# Unknown user
|
# Unknown user
|
||||||
unknown_user = await sqlalchemy_user_db.get_by_email('lancelot@camelot.bt')
|
unknown_user = await sqlalchemy_user_db.get_by_email("lancelot@camelot.bt")
|
||||||
assert unknown_user is None
|
assert unknown_user is None
|
||||||
|
@ -18,89 +18,63 @@ def test_app_client(mock_user_db, mock_authentication) -> TestClient:
|
|||||||
|
|
||||||
|
|
||||||
class TestRegister:
|
class TestRegister:
|
||||||
|
|
||||||
def test_empty_body(self, test_app_client: TestClient):
|
def test_empty_body(self, test_app_client: TestClient):
|
||||||
response = test_app_client.post('/register', json={})
|
response = test_app_client.post("/register", json={})
|
||||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
def test_missing_password(self, test_app_client: TestClient):
|
def test_missing_password(self, test_app_client: TestClient):
|
||||||
json = {
|
json = {"email": "king.arthur@camelot.bt"}
|
||||||
'email': 'king.arthur@camelot.bt',
|
response = test_app_client.post("/register", json=json)
|
||||||
}
|
|
||||||
response = test_app_client.post('/register', json=json)
|
|
||||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
def test_wrong_email(self, test_app_client: TestClient):
|
def test_wrong_email(self, test_app_client: TestClient):
|
||||||
json = {
|
json = {"email": "king.arthur", "password": "guinevere"}
|
||||||
'email': 'king.arthur',
|
response = test_app_client.post("/register", json=json)
|
||||||
'password': 'guinevere',
|
|
||||||
}
|
|
||||||
response = test_app_client.post('/register', json=json)
|
|
||||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
def test_valid_body(self, test_app_client: TestClient):
|
def test_valid_body(self, test_app_client: TestClient):
|
||||||
json = {
|
json = {"email": "king.arthur@camelot.bt", "password": "guinevere"}
|
||||||
'email': 'king.arthur@camelot.bt',
|
response = test_app_client.post("/register", json=json)
|
||||||
'password': 'guinevere',
|
|
||||||
}
|
|
||||||
response = test_app_client.post('/register', json=json)
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
assert 'hashed_password' not in response_json
|
assert "hashed_password" not in response_json
|
||||||
assert 'password' not in response_json
|
assert "password" not in response_json
|
||||||
assert 'id' in response_json
|
assert "id" in response_json
|
||||||
|
|
||||||
|
|
||||||
class TestLogin:
|
class TestLogin:
|
||||||
|
|
||||||
def test_empty_body(self, test_app_client: TestClient):
|
def test_empty_body(self, test_app_client: TestClient):
|
||||||
response = test_app_client.post('/login', data={})
|
response = test_app_client.post("/login", data={})
|
||||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
def test_missing_username(self, test_app_client: TestClient):
|
def test_missing_username(self, test_app_client: TestClient):
|
||||||
data = {
|
data = {"password": "guinevere"}
|
||||||
'password': 'guinevere',
|
response = test_app_client.post("/login", data=data)
|
||||||
}
|
|
||||||
response = test_app_client.post('/login', data=data)
|
|
||||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
def test_missing_password(self, test_app_client: TestClient):
|
def test_missing_password(self, test_app_client: TestClient):
|
||||||
data = {
|
data = {"username": "king.arthur@camelot.bt"}
|
||||||
'username': 'king.arthur@camelot.bt',
|
response = test_app_client.post("/login", data=data)
|
||||||
}
|
|
||||||
response = test_app_client.post('/login', data=data)
|
|
||||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
def test_not_existing_user(self, test_app_client: TestClient):
|
def test_not_existing_user(self, test_app_client: TestClient):
|
||||||
data = {
|
data = {"username": "lancelot@camelot.bt", "password": "guinevere"}
|
||||||
'username': 'lancelot@camelot.bt',
|
response = test_app_client.post("/login", data=data)
|
||||||
'password': 'guinevere',
|
|
||||||
}
|
|
||||||
response = test_app_client.post('/login', data=data)
|
|
||||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
|
||||||
def test_wrong_password(self, test_app_client: TestClient):
|
def test_wrong_password(self, test_app_client: TestClient):
|
||||||
data = {
|
data = {"username": "king.arthur@camelot.bt", "password": "percival"}
|
||||||
'username': 'king.arthur@camelot.bt',
|
response = test_app_client.post("/login", data=data)
|
||||||
'password': 'percival',
|
|
||||||
}
|
|
||||||
response = test_app_client.post('/login', data=data)
|
|
||||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
|
||||||
def test_valid_credentials(self, test_app_client: TestClient, user: UserDB):
|
def test_valid_credentials(self, test_app_client: TestClient, user: UserDB):
|
||||||
data = {
|
data = {"username": "king.arthur@camelot.bt", "password": "guinevere"}
|
||||||
'username': 'king.arthur@camelot.bt',
|
response = test_app_client.post("/login", data=data)
|
||||||
'password': 'guinevere',
|
|
||||||
}
|
|
||||||
response = test_app_client.post('/login', data=data)
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.json() == {'token': user.id}
|
assert response.json() == {"token": user.id}
|
||||||
|
|
||||||
def test_inactive_user(self, test_app_client: TestClient):
|
def test_inactive_user(self, test_app_client: TestClient):
|
||||||
data = {
|
data = {"username": "percival@camelot.bt", "password": "angharad"}
|
||||||
'username': 'percival@camelot.bt',
|
response = test_app_client.post("/login", data=data)
|
||||||
'password': 'angharad',
|
|
||||||
}
|
|
||||||
response = test_app_client.post('/login', data=data)
|
|
||||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
Reference in New Issue
Block a user