This commit is contained in:
2022-01-01 19:34:46 +03:00
parent 44cbfc4f59
commit e1389700cc
16 changed files with 920 additions and 69 deletions

View File

@@ -1,18 +1,21 @@
from alembic import context
import sys, os
import os
import sys
from alembic import context
from sqlalchemy.engine import create_engine
from core.db import DATABASE_URL
myPath = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, myPath + '/../../')
sys.path.insert(0, myPath + "/../../")
config = context.config
from app.models import BaseMeta
target_metadata = BaseMeta.metadata

View File

@@ -10,19 +10,21 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '750640043cd4'
down_revision = '7e45f53febe1'
revision = "750640043cd4"
down_revision = "7e45f53febe1"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('user_settings', sa.Column('source', sa.String(length=32), nullable=False))
op.add_column(
"user_settings", sa.Column("source", sa.String(length=32), nullable=False)
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('user_settings', 'source')
op.drop_column("user_settings", "source")
# ### end Alembic commands ###

View File

@@ -10,7 +10,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '7e45f53febe1'
revision = "7e45f53febe1"
down_revision = None
branch_labels = None
depends_on = None
@@ -18,36 +18,51 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('languages',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('label', sa.String(length=16), nullable=False),
sa.Column('code', sa.String(length=4), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('code')
op.create_table(
"languages",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("label", sa.String(length=16), nullable=False),
sa.Column("code", sa.String(length=4), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("code"),
)
op.create_table('user_settings',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.BigInteger(), nullable=False),
sa.Column('last_name', sa.String(length=64), nullable=False),
sa.Column('first_name', sa.String(length=64), nullable=False),
sa.Column('username', sa.String(length=32), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('user_id')
op.create_table(
"user_settings",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("last_name", sa.String(length=64), nullable=False),
sa.Column("first_name", sa.String(length=64), nullable=False),
sa.Column("username", sa.String(length=32), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_id"),
)
op.create_table('users_languages',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('language', sa.Integer(), nullable=True),
sa.Column('user', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['language'], ['languages.id'], name='fk_users_languages_languages_language_id', onupdate='CASCADE', ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user'], ['user_settings.id'], name='fk_users_languages_user_settings_user_id', onupdate='CASCADE', ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
op.create_table(
"users_languages",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("language", sa.Integer(), nullable=True),
sa.Column("user", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["language"],
["languages.id"],
name="fk_users_languages_languages_language_id",
onupdate="CASCADE",
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["user"],
["user_settings.id"],
name="fk_users_languages_user_settings_user_id",
onupdate="CASCADE",
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('users_languages')
op.drop_table('user_settings')
op.drop_table('languages')
op.drop_table("users_languages")
op.drop_table("user_settings")
op.drop_table("languages")
# ### end Alembic commands ###

View File

@@ -6,4 +6,6 @@ from core.config import env_config
async def check_token(api_key: str = Security(default_security)):
if api_key != env_config.API_KEY:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Wrong api key!")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Wrong api key!"
)

View File

@@ -20,7 +20,7 @@ class Language(ormar.Model):
class User(ormar.Model):
class Meta(BaseMeta):
tablename = "user_settings"
id: int = ormar.Integer(primary_key=True) # type: ignore
user_id: int = ormar.BigInteger(unique=True) # type: ignore

View File

@@ -1,3 +1,5 @@
from typing import Optional
from pydantic import BaseModel, constr
@@ -11,6 +13,7 @@ class LanguageDetail(CreateLanguage):
class UserBase(BaseModel):
user_id: int
last_name: constr(max_length=64) # type: ignore
first_name: constr(max_length=64) # type: ignore
username: constr(max_length=32) # type: ignore
@@ -21,6 +24,13 @@ class UserCreateOrUpdate(UserBase):
allowed_langs: list[str]
class UserDetail(BaseModel):
user_id: int
class UserUpdate(BaseModel):
last_name: Optional[constr(max_length=64)] = None # type: ignore
first_name: Optional[constr(max_length=64)] = None # type: ignore
username: Optional[constr(max_length=32)] = None # type: ignore
source: Optional[constr(max_length=32)] = None # type: ignore
allowed_langs: Optional[list[str]] = None
class UserDetail(UserBase):
allowed_langs: list[LanguageDetail]

24
src/app/services.py Normal file
View File

@@ -0,0 +1,24 @@
from typing import cast
from app.models import User, Language
async def update_user_allowed_langs(user: User, new_allowed_langs: list[str]):
user_allowed_langs = cast(list[Language], user.allowed_langs)
exists_langs = set(lang.code for lang in user_allowed_langs)
new_langs = set(new_allowed_langs)
to_delete = exists_langs - new_langs
to_add = new_langs - exists_langs
all_process_langs = list(to_delete) + list(to_add)
langs = await Language.objects.filter(code__in=all_process_langs).all()
for lang in langs:
if lang.code in to_delete:
await user.allowed_langs.remove(lang)
if lang.code in to_add:
await user.allowed_langs.add(lang)

View File

@@ -4,17 +4,22 @@ from fastapi_pagination import Page, Params
from fastapi_pagination.ext.ormar import paginate
from app.depends import check_token
from app.serializers import UserCreateOrUpdate, UserDetail, CreateLanguage, LanguageDetail
from app.models import User, Language
from app.serializers import (
UserCreateOrUpdate,
UserUpdate,
UserDetail,
CreateLanguage,
LanguageDetail,
)
from app.services import update_user_allowed_langs
# TODO: add redis cache
users_router = APIRouter(
prefix="/users",
tags=["users"],
dependencies=[Depends(check_token)]
prefix="/users", tags=["users"], dependencies=[Depends(check_token)]
)
@@ -25,7 +30,9 @@ async def get_users():
@users_router.get("/{user_id}", response_model=UserDetail)
async def get_user(user_id: int):
user_data = await User.objects.select_related("allowd_langs").get_or_none(user_id=user_id)
user_data = await User.objects.select_related("allowd_langs").get_or_none(
user_id=user_id
)
if user_data is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@@ -33,35 +40,64 @@ async def get_user(user_id: int):
return user_data
@users_router.post("/{user_id}", response_model=UserDetail)
async def create_or_update_user(user_id: int, data: UserCreateOrUpdate):
@users_router.post("/", response_model=UserDetail)
async def create_or_update_user(data: UserCreateOrUpdate):
data_dict = data.dict()
user_data = await User.objects.select_related("allowed_langs").get_or_none(user_id=user_id)
user_data = await User.objects.select_related("allowed_langs").get_or_none(
user_id=data_dict["user_id"]
)
allowed_langs = data_dict.pop("allowed_langs")
if user_data is None:
user_data = await User.objects.select_related("allowed_langs").create(**{**data_dict, "user_id": user_id})
user_data = await User.objects.select_related("allowed_langs").create(
**data_dict
)
else:
data_dict.pop("user_id")
user_data.update_from_dict(data_dict)
await user_data.allowed_langs.clear() # type: ignore
await update_user_allowed_langs(user_data, allowed_langs)
langs = await Language.objects.filter(code__in=allowed_langs).all()
return user_data
for lang in langs:
await user_data.allowed_langs.add(lang)
await user_data.update()
@users_router.patch("/{user_id}", response_model=UserDetail)
async def update_user(user_id: int, data: UserUpdate):
user_data = await User.objects.select_related("allowed_langs").get_or_none(
user_id=user_id
)
if user_data is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
data_dict = data.dict()
update_data = {}
for key in data_dict:
if data_dict[key] is not None:
update_data[key] = data_dict[key]
if not update_data:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
allowed_langs = update_data.pop("allowed_langs", None)
if update_data:
user_data.update_from_dict(update_data)
await user_data.update()
if not allowed_langs:
return user_data
await update_user_allowed_langs(user_data, allowed_langs)
return user_data
languages_router = APIRouter(
prefix="/languages",
tags=["languages"],
dependencies=[Depends(check_token)]
prefix="/languages", tags=["languages"], dependencies=[Depends(check_token)]
)
@@ -76,7 +112,7 @@ async def get_language(code: str):
if language is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return language

View File

@@ -2,8 +2,8 @@ from fastapi import FastAPI
from fastapi_pagination import add_pagination
from core.db import database
from app.views import users_router, languages_router
from core.db import database
def start_app() -> FastAPI:
@@ -16,13 +16,13 @@ def start_app() -> FastAPI:
add_pagination(app)
@app.on_event('startup')
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event('shutdown')
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:

View File

@@ -1,6 +1,6 @@
from urllib.parse import quote
from databases import Database
from databases import Database
from sqlalchemy import MetaData
from core.config import env_config