Kth Problems

Solution with priority queue

套路:找第K个的问题,最常用的做法就是用优先队列来实现,根据题意用最大堆或者最小堆把时间复杂度优化到 O(nlogk).

Merge k Sorted Lists

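//Time O(n log k): n = total number of nodes, k = number of lists
//Space O(k): the heap holds at most one node per list; nodes are relinked in place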
//provides greater
struct Cmp {
  bool operator() (ListNode* n1, ListNode* n2) {
    return n1 -> val > n2 -> val;

class Solution {
 public:
  // Merges k individually sorted linked lists into one sorted list by
  // relinking the existing nodes (no node is copied or allocated).
  // Time O(n log k) with n = total node count; extra space O(k) for the heap.
  ListNode* mergeKLists(std::vector<ListNode*>& lists) {
    // std::priority_queue is a max-heap by default, so a "greater"
    // comparator is needed to make it a min-heap.
    // Method 1: a functor such as struct Cmp:
    //   std::priority_queue<ListNode*, std::vector<ListNode*>, Cmp> min_heap;
    // Method 2: a lambda (used here).
    auto cmp = [](const ListNode* n1, const ListNode* n2) { return n1->val > n2->val; };
    std::priority_queue<ListNode*, std::vector<ListNode*>, decltype(cmp)> min_heap(cmp);

    // Seed the heap with only the k list heads. The heap therefore never
    // holds more than k nodes: n pushes/pops of O(log k) each, rather than
    // pushing all n nodes up front.
    for (ListNode* head : lists) {
      if (head != nullptr) {
        min_heap.push(head);
      }
    }

    // Stack-allocated sentinel: nothing to free afterwards.
    ListNode dummy(0);
    ListNode* tail = &dummy;
    while (!min_heap.empty()) {
      ListNode* node = min_heap.top();
      min_heap.pop();
      tail->next = node;
      tail = node;
      // Replace the consumed node with its successor from the same list.
      if (node->next != nullptr) {
        min_heap.push(node->next);
      }
    }
    // The last node popped had next == nullptr (otherwise its successor would
    // still be in the heap), so the merged list is properly terminated.
    return dummy.next;
  }
};

Kth Smallest Number In Sorted Matrix

Solution 1: use priority queue

class Solution {
public:
    // Returns the k-th smallest element (1-based) of a matrix whose rows and
    // columns are each sorted in ascending order (square or rectangular).
    //
    // Best-first search from the top-left corner. Each cell is >= its upper
    // and left neighbours, so the next smallest unvisited cell is always the
    // down- or right-neighbour of a cell already popped. After k - 1 pops,
    // the heap top is the k-th smallest.
    // Time O(k log k); extra space O(rows * cols) for `visited`.
    //
    // NOTE(review): assumes a non-empty matrix and 1 <= k <= rows * cols;
    // otherwise the heap runs empty and top() is undefined.
    int kthSmallest(std::vector<std::vector<int>>& matrix, int k) {
        const int rows = static_cast<int>(matrix.size());
        const int cols = static_cast<int>(matrix[0].size());
        // A "greater" comparison turns std::priority_queue into a min-heap.
        auto cmp = [&matrix](const std::pair<int, int>& p1, const std::pair<int, int>& p2) {
            return matrix[p1.first][p1.second] > matrix[p2.first][p2.second];
        };
        std::priority_queue<std::pair<int, int>, std::vector<std::pair<int, int>>, decltype(cmp)> min_heap(cmp);
        // vector<char> instead of vector<bool> avoids the packed-bit proxy.
        std::vector<std::vector<char>> visited(rows, std::vector<char>(cols, 0));

        min_heap.emplace(0, 0);
        visited[0][0] = 1;
        for (int popped = 1; popped < k; ++popped) {
            const std::pair<int, int> cell = min_heap.top();
            min_heap.pop();
            const int x = cell.first;
            const int y = cell.second;
            if (x + 1 < rows && !visited[x + 1][y]) {
                visited[x + 1][y] = 1;
                min_heap.emplace(x + 1, y);
            }
            if (y + 1 < cols && !visited[x][y + 1]) {
                visited[x][y + 1] = 1;
                min_heap.emplace(x, y + 1);
            }
        }
        const std::pair<int, int>& top = min_heap.top();
        return matrix[top.first][top.second];
    }
};

Solution 2: use binary search

int kthSmallest(vector<vector<int>>& matrix, int k) {
        return helper(matrix, k, matrix.front().front(), matrix.back().back());   
    // Binary search over the VALUE range [left, right], not over indices.
    // Returns the smallest value v in [left, right] such that at least k
    // elements of `matrix` are <= v. With left/right set to the matrix minimum
    // and maximum, that value is the k-th smallest element. The count only
    // changes at matrix values, so the result always occurs in the matrix.
    // Cost: O(rows * log(cols) * log(right - left)).
    // Requires every row to be sorted ascending (needed by upper_bound).
    int helper(const std::vector<std::vector<int>>& matrix, int k, int left, int right) {
        while (left < right) {
            // Widen before subtracting so that right - left cannot overflow
            // int (e.g. left = INT_MIN, right = INT_MAX).
            const int mid = static_cast<int>(left + (static_cast<long long>(right) - left) / 2);

            // Count the elements <= mid. The basic idea is a linear scan:
            //   for (int j = 0; j < row.size(); j++) if (row[j] <= mid) count++;
            // Since each row is sorted, upper_bound finds the first element
            // > mid in O(log cols), and its offset is the per-row count.
            long long count = 0;
            for (const std::vector<int>& row : matrix) {
                count += std::upper_bound(row.begin(), row.end(), mid) - row.begin();
            }

            if (count < k) {
                // Fewer than k elements are <= mid: the answer is above mid.
                left = mid + 1;
            } else {
                // At least k elements are <= mid: mid itself may be the answer.
                right = mid;
            }
        }
        return left;
    }

Kth Smallest Sum In Two Sorted Arrays

// A candidate pair (i, j) with its precomputed value sum = a[i] + b[j].
// Ordered by sum only; the indices do not take part in comparisons.
class Cell {
 public:
  int i;
  int j;
  int sum;

  Cell(int _i, int _j, int _sum) : i(_i), j(_j), sum(_sum) {}

  // Must be a strict weak ordering (irreflexive): "<=" would make a < a true,
  // which breaks the Compare requirement of std containers and algorithms.
  bool operator<(const Cell& c) const {
    return sum < c.sum;
  }

  // Used through std::greater<Cell> to build a min-heap.
  bool operator>(const Cell& c) const {
    return sum > c.sum;
  }
};

class Solution {
  int kthSum(vector<int> a, vector<int> b, int k) {
    // Write your solution here
    priority_queue<Cell, vector<Cell>, greater<Cell>> min_heap;
    vector<vector<bool>> visited(a.size(), vector<bool>(b.size(), false));
    visited[0][0] = true;
    min_heap.emplace(Cell(0, 0, a[0] + b[0]));
    for (int i = 0; i < k - 1; i++) {
      Cell cur = min_heap.top();
      if (cur.i + 1 < a.size() && !visited[cur.i + 1][cur.j]) {
        int sum = a[cur.i + 1] + b[cur.j];
        min_heap.emplace(Cell(cur.i + 1, cur.j, sum));
        visited[cur.i + 1][cur.j] = true;

      if (cur.j + 1 < b.size() && !visited[cur.i][cur.j + 1]) {
        int sum = a[cur.i] + b[cur.j + 1];
        min_heap.emplace(Cell(cur.i, cur.j + 1, sum));
        visited[cur.i][cur.j + 1] = true;
    return min_heap.top().sum;

Kth Smallest With Only 3, 5, 7 As Factors

  // Returns the k-th smallest (1-based) number of the form 3^x * 5^y * 7^z
  // with x, y, z >= 1, so kth(1) == 3 * 5 * 7 == 105.
  //
  // Min-heap expansion: each popped value v spawns 3v, 5v and 7v. The
  // `visited` set ensures every value is pushed only once; for example, 1575
  // is reachable both as 315 * 5 and as 525 * 3. Time O(k log k), space O(k).
  //
  // NOTE(review): assumes k >= 1. `long` is only 32 bits on some platforms
  // (e.g. Windows), so large k can overflow.
  long kth(int k) {
    const long kFactors[] = {3, 5, 7};
    std::priority_queue<long, std::vector<long>, std::greater<long>> min_heap;
    std::set<long> visited;

    const long smallest = 3L * 5 * 7;
    min_heap.push(smallest);
    visited.insert(smallest);
    for (int popped = 0; popped < k - 1; ++popped) {
      const long cur = min_heap.top();
      min_heap.pop();
      for (const long factor : kFactors) {
        const long next = cur * factor;
        // insert().second is true only the first time a value is seen.
        if (visited.insert(next).second) {
          min_heap.push(next);
        }
      }
    }
    return min_heap.top();
  }

Kth Closest Point

// A candidate index triple (x, y, z) with its precomputed distance `dis`.
// Ordered by dis only; the indices do not take part in comparisons.
class Point {
public:
    int x;
    int y;
    int z;
    double dis;

    Point(int _x, int _y, int _z, double _dis) : x(_x), y(_y), z(_z), dis(_dis) {}

    // Must be a strict weak ordering (irreflexive): "<=" would make p < p
    // true, which breaks the Compare requirement of std containers/algorithms.
    bool operator<(const Point& p1) const {
        return dis < p1.dis;
    }

    // Used through std::greater<Point> to build a min-heap.
    bool operator>(const Point& p1) const {
        return dis > p1.dis;
    }
};


class Solution {
    vector<int> closest(vector<int> a, vector<int> b, vector<int> c, int k) {
        priority_queue<Point, vector<Point>, greater<Point>> min_heap;
        set<vector<int>> visited;
        double d = sqrt(a[0] * a[0] + b[0] * b[0] + c[0] * c[0] + 0.0);
        Point* start = new Point(0,0,0,d);
        for (int i = 0; i < k - 1; i++) {
            Point p = min_heap.top();
            if (p.x + 1 < a.size()) {
                double d = sqrt(a[p.x + 1] * a[p.x + 1] + b[p.y] * b[p.y] + c[p.z] * c[p.z] + 0.0);
                Point* temp = new Point(p.x + 1,p.y,p.z,d);
                if (visited.find({p.x + 1,p.y,p.z}) == visited.end()) {
                    vector<int> v = {p.x + 1, p.y, p.z};
            if (p.y + 1 < b.size()) {
                double d = sqrt(a[p.x] * a[p.x] + b[p.y + 1] * b[p.y + 1] + c[p.z] * c[p.z] + 0.0);
                Point* temp = new Point(p.x,p.y + 1,p.z,d);
                if (visited.find({p.x,p.y + 1,p.z}) == visited.end()) {
                    vector<int> v = {p.x,p.y + 1,p.z};

            if (p.z + 1 < c.size()) {
                double d = sqrt(a[p.x] * a[p.x] + b[p.y] * b[p.y] + c[p.z + 1] * c[p.z + 1] + 0.0);
                Point* temp = new Point(p.x,p.y,p.z + 1,d);
                if (visited.find({p.x,p.y,p.z + 1}) == visited.end()) {
                    vector<int> v = {p.x,p.y,p.z + 1};

        Point rst = min_heap.top();
        return {a[rst.x], b[rst.y], c[rst.z]};


Solution with quick-sort partition

Kth Largest Element in an Array

Find the kth largest element in an unsorted array. Note that it is the kth largest element in the sorted order, not the kth distinct element.

class Solution {
    // Partitions nums[left..right] (inclusive) in DESCENDING order around a
    // randomly chosen pivot and returns the pivot's final index p, so that
    //   nums[left..p-1] >= nums[p] >= nums[p+1..right].
    // Requires left <= right; otherwise the modulo below divides by zero.
    // The pivot comes from rand(), so it depends on the global rand() state.
    int partition(std::vector<int>& nums, int left, int right) {
        const int pivot_index = left + std::rand() % (right - left + 1);
        const int pivot = nums[pivot_index];
        // Park the pivot at the front while the rest is partitioned.
        std::swap(nums[left], nums[pivot_index]);
        int left_bound = left + 1;
        int right_bound = right;
        while (left_bound <= right_bound) {
            if (nums[right_bound] <= pivot) {
                // Already on the "small" (right) side.
                --right_bound;
            } else if (nums[left_bound] >= pivot) {
                // Already on the "large" (left) side.
                ++left_bound;
            } else {
                // nums[left_bound] < pivot < nums[right_bound]: both misplaced.
                std::swap(nums[left_bound++], nums[right_bound--]);
            }
        }
        // right_bound is now the last index holding a value >= pivot (or
        // `left` itself); moving the pivot there puts it in final position.
        std::swap(nums[left], nums[right_bound]);
        return right_bound;
    }
    int findKthLargest(vector<int>& nums, int k) {
        int left = 0;
        int right = nums.size() - 1;
        while (true) {
            int pos = partition(nums, left, right); 
            if (pos == k - 1) {
                return nums[pos];
            else if (pos > k - 1) {
                right = pos - 1;
            else {
                left = pos + 1;