diff --git a/README.md b/README.md index 6357b21..829d4c5 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ All 10 endpoints are grouped neatly by resource - **Users** and **Tasks** - with | Method | URL | Description | |--------|-----|-------------| | `POST` | `/users/` | Create a new user | -| `GET` | `/users/` | List all users | +| `GET` | `/users/` | List users with optional `skip` and `limit` query params | | `GET` | `/users/{id}` | Get one user (includes their tasks) | | `PUT` | `/users/{id}` | Update a user's name or email | | `DELETE` | `/users/{id}` | Delete a user | @@ -110,7 +110,7 @@ All 10 endpoints are grouped neatly by resource - **Users** and **Tasks** - with | Method | URL | Description | |--------|-----|-------------| | `POST` | `/tasks/` | Create a new task | -| `GET` | `/tasks/` | List all tasks | +| `GET` | `/tasks/` | List tasks with optional `skip` and `limit` query params | | `GET` | `/tasks/{id}` | Get one task by ID | | `PUT` | `/tasks/{id}` | Update a task | | `DELETE` | `/tasks/{id}` | Delete a task | diff --git a/crud.py b/crud.py index 4de4386..ad5ffba 100644 --- a/crud.py +++ b/crud.py @@ -14,8 +14,8 @@ def get_user(db: Session, user_id: int): return db.query(models.User).filter(models.User.id == user_id).first() -def get_users(db: Session): - return db.query(models.User).all() +def get_users(db: Session, skip: int = 0, limit: int = 10): + return db.query(models.User).offset(skip).limit(limit).all() @@ -29,8 +29,8 @@ def create_task(db:Session, task: schemas.TaskCreate): def get_task(db:Session, task_id: int): return db.query(models.Task).filter(models.Task.id == task_id).first() -def get_tasks(db: Session): - return db.query(models.Task).all() +def get_tasks(db: Session, skip: int = 0, limit: int = 10): + return db.query(models.Task).offset(skip).limit(limit).all() def update_task(db:Session, task_id: int, data: schemas.TaskUpdate): db_task = db.query(models.Task).filter(models.Task.id == task_id).first() diff --git a/routers/tasks.py b/routers/tasks.py index 368791f..11b8199 100644 --- a/routers/tasks.py +++ b/routers/tasks.py @@ -12,8 +12,8 @@ def create_task(task: schemas.TaskCreate, db: Session = Depends(get_db)): @router.get("/", response_model=List[schemas.TaskResponse]) -def get_tasks(db: Session = Depends(get_db)): - return crud.get_tasks(db) +def get_tasks(skip: int = 0, limit: int = 10, db: Session = Depends(get_db)): + return crud.get_tasks(db, skip=skip, limit=limit) @router.get("/{task_id}", response_model=schemas.TaskResponse) diff --git a/routers/users.py b/routers/users.py index 7f738a9..c449950 100644 --- a/routers/users.py +++ b/routers/users.py @@ -12,8 +12,8 @@ def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)): @router.get("/", response_model=List[schemas.UserResponse]) -def get_users(db: Session = Depends(get_db)): - return crud.get_users(db) +def get_users(skip: int = 0, limit: int = 10, db: Session = Depends(get_db)): + return crud.get_users(db, skip=skip, limit=limit) @router.get("/{user_id}", response_model=schemas.UserResponse) def get_user(user_id: int, db: Session = Depends(get_db)): diff --git a/tests/test_pagination.py b/tests/test_pagination.py new file mode 100644 index 0000000..4b02a43 --- /dev/null +++ b/tests/test_pagination.py @@ -0,0 +1,89 @@ +import os +import tempfile +import unittest + +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +import database +import main +import models + + +class PaginationTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + cls.temp_db.close() + + database.engine = create_engine( + f"sqlite:///{cls.temp_db.name}", + connect_args={"check_same_thread": False}, + ) + database.SessionLocal = sessionmaker( + autocommit=False, + autoflush=False, + bind=database.engine, + ) + models.Base.metadata.create_all(bind=database.engine) + cls.client = TestClient(main.app) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "client"): + cls.client.close() + if hasattr(database, "SessionLocal"): + database.SessionLocal.close_all() + if hasattr(database, "engine"): + database.engine.dispose() + if os.path.exists(cls.temp_db.name): + try: + os.remove(cls.temp_db.name) + except PermissionError: + pass + + def setUp(self): + db = database.SessionLocal() + try: + db.query(models.Task).delete() + db.query(models.User).delete() + db.commit() + + for index in range(12): + user = models.User(name=f"User {index}", email=f"user{index}@example.com") + db.add(user) + db.commit() + + users = db.query(models.User).all() + for index, user in enumerate(users): + db.add(models.Task(title=f"Task {index}", description="desc", owner_id=user.id)) + db.commit() + finally: + db.close() + + def test_tasks_pagination(self): + response = self.client.get("/tasks/") + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.json()), 10) + + paged_response = self.client.get("/tasks/", params={"skip": 5, "limit": 5}) + self.assertEqual(paged_response.status_code, 200) + data = paged_response.json() + self.assertEqual(len(data), 5) + self.assertEqual(data[0]["title"], "Task 5") + + def test_users_pagination(self): + response = self.client.get("/users/") + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.json()), 10) + + paged_response = self.client.get("/users/", params={"skip": 5, "limit": 5}) + self.assertEqual(paged_response.status_code, 200) + data = paged_response.json() + self.assertEqual(len(data), 5) + self.assertEqual(data[0]["email"], "user5@example.com") + + +if __name__ == "__main__": + unittest.main()