diff --git a/drivers/mtd/mtd.c b/drivers/mtd/mtd.c index 030b3f8fba..b1728f5283 100644 --- a/drivers/mtd/mtd.c +++ b/drivers/mtd/mtd.c @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -28,6 +29,23 @@ #include "bitarithm.h" #include "mtd.h" +static bool out_of_bounds(mtd_dev_t *mtd, uint32_t page, uint32_t offset, uint32_t len) +{ + const uint32_t page_shift = bitarithm_msb(mtd->page_size); + const uint32_t pages_numof = mtd->sector_count * mtd->pages_per_sector; + + /* 2 TiB SD cards might be a problem */ + assert(pages_numof >= mtd->sector_count); + + /* read n byte buffer -> last byte will be at n - 1 */ + page += (offset + len - 1) >> page_shift; + if (page >= pages_numof) { + return true; + } + + return false; +} + int mtd_init(mtd_dev_t *mtd) { if (!mtd || !mtd->driver) { @@ -70,6 +88,10 @@ int mtd_read(mtd_dev_t *mtd, void *dest, uint32_t addr, uint32_t count) return -ENODEV; } + if (out_of_bounds(mtd, 0, addr, count)) { + return -EOVERFLOW; + } + if (mtd->driver->read) { return mtd->driver->read(mtd, dest, addr, count); } @@ -88,6 +110,10 @@ int mtd_read_page(mtd_dev_t *mtd, void *dest, uint32_t page, uint32_t offset, return -ENODEV; } + if (out_of_bounds(mtd, page, offset, count)) { + return -EOVERFLOW; + } + if (mtd->driver->read_page == NULL) { /* TODO: remove when all backends implement read_page */ if (mtd->driver->read) { @@ -139,6 +165,10 @@ int mtd_write(mtd_dev_t *mtd, const void *src, uint32_t addr, uint32_t count) return -ENODEV; } + if (out_of_bounds(mtd, 0, addr, count)) { + return -EOVERFLOW; + } + if (mtd->driver->write) { return mtd->driver->write(mtd, src, addr, count); } @@ -214,6 +244,10 @@ int mtd_write_page(mtd_dev_t *mtd, const void *data, uint32_t page, return -ENODEV; } + if (out_of_bounds(mtd, page, offset, len)) { + return -EOVERFLOW; + } + if (mtd->driver->flags & MTD_DRIVER_FLAG_DIRECT_WRITE) { return mtd_write_page_raw(mtd, data, page, offset, len); } @@ -247,6 +281,10 @@ int mtd_write_page_raw(mtd_dev_t *mtd, const void *src, uint32_t page, uint32_t return -ENODEV; } + if (out_of_bounds(mtd, page, offset, count)) { + return -EOVERFLOW; + } + if (mtd->driver->write_page == NULL) { /* TODO: remove when all backends implement write_page */ if (mtd->driver->write) { @@ -321,7 +359,11 @@ int mtd_erase_sector(mtd_dev_t *mtd, uint32_t sector, uint32_t count) return -ENODEV; } - if (sector >= mtd->sector_count) { + if (sector + count > mtd->sector_count) { + return -EOVERFLOW; + } + + if (sector + count < sector) { return -EOVERFLOW; }