跳至主要内容

Leetcode 2256. Count Nodes Equal to Average of Subtree

· 閱讀時間約 6 分鐘
Vincent Chi

久違的 Leetcode Daily,今天剛好在參考其它人的解答時想到一些有趣的議題可以分享。

關於題目的分析與解答我就不另外贅述,因為今天要討論的這篇參考解答已經寫得非常精采。

參考解答分析

在參考解答中,vanAmsen 給出了一共 8 種程式語言的解法,儘管思路上大同小異但卻不失為一個很好的參考。

以下就用它的 C++ 版本作為範例:

class Solution {
public:
int averageOfSubtree(TreeNode* root) {
int result = 0;
traverse(root, result);
return result;
}

private:
pair<int, int> traverse(TreeNode* node, int& result) {
if (!node) return {0, 0};

auto [left_sum, left_count] = traverse(node->left, result);
auto [right_sum, right_count] = traverse(node->right, result);

int curr_sum = node->val + left_sum + right_sum;
int curr_count = 1 + left_count + right_count;

if (curr_sum / curr_count == node->val) result++;

return {curr_sum, curr_count};
}
};

得益於 C++ 的 STL pair 容器 與 C++ 17 的 Structured Binding 特性,使得 C++ 可以一次性回傳多個值。

許多現代化的程式語言(例如 Go 或 Python)都允許函式回傳多個值,這很大程度上幫助工程師專注於邏輯。

以 C 實作

我是一個比較傳統的人,像這種題目我一般會偏好用 C 作答。此時便衍生出一個問題:應該如何使 C 回傳多個值?

結構體

眾所周知,用結構體就可以把多個變數包起來,可以藉由定義一個結構體 Pair 來達成這個目標:

typedef struct Pair {
int sum;
int cnt;
} P;

P traverse(struct TreeNode *n, int *max) {
if (!n) return (P){0, 0};

P left = traverse(n->left, max);
P right = traverse(n->right, max);

int curr_sum = n->val + left.sum + right.sum;
int curr_cnt = 1 + left.cnt + right.cnt;

if (curr_sum / curr_cnt == n->val) (*max)++;

return (P){curr_sum, curr_cnt};
}

int averageOfSubtree(struct TreeNode *root){
int ans = 0;
traverse(root, &ans);

return ans;
}

用整數代表兩個(或以上)的值

如果說,能夠把 sumcnt 的資訊包裝在一個整數之中,那是否就能夠避免使用結構體?

|--------- uint64_t ----------|
| sum (32-bit) | cnt (32-bit) |
|-----------------------------|
#define SUM(val) (val >> 32)
#define CNT(val) (val & 0xffffffff)
#define RES(sum, cnt) (((sum) << 32) + (cnt))

uint64_t traverse(struct TreeNode *n, int *max) {
if (!n) return 0;

uint64_t left = traverse(n->left, max),
right = traverse(n->right, max);

uint64_t curr_sum = n->val + SUM(left) + SUM(right),
curr_cnt = 1 + CNT(left) + CNT(right);

if (curr_sum/curr_cnt == n->val) (*max)++;

return RES(curr_sum, curr_cnt);
}

int averageOfSubtree(struct TreeNode *root){
int ans = 0;
traverse(root, &ans);

return ans;
}

uint64_t 似乎有點太浪費了,那有沒有可能使用 uint32_t 代替呢?

|--------- uint32_t ----------|
| sum (16-bit) | cnt (16-bit) |
|-----------------------------|
#define SUM(val) (val >> 16)
#define CNT(val) (val & 0xffff)
#define RES(sum, cnt) (((sum) << 16) + (cnt))

uint32_t traverse(struct TreeNode *n, int *max) {
if (!n) return 0;

uint32_t left = traverse(n->left, max),
right = traverse(n->right, max);

uint32_t curr_sum = n->val + SUM(left) + SUM(right),
curr_cnt = 1 + CNT(left) + CNT(right);

if (curr_sum/curr_cnt == n->val) (*max)++;

return RES(curr_sum, curr_cnt);
}

可惜不行,經過實測似乎有些極端測資會讓 sum > (2 ** 16) 導致 Overflow 而失敗。

重新檢視題目條件有這麼一條:The number of nodes in the tree is in the range [1, 1000].,也就是說整顆樹最多不會超過 1000 個節點。換句話說,cnt 的值絕對不會超過 1000,這麼一來可以重新設計記憶體排佈:

信息

其實還有另一個條件:0 <= Node.val <= 1000

因此 sum 的極值會小於 1,000,000,事實上 20 個 bits 就可以存放(2 ** 20 = 1,048,576

|--------- uint32_t ----------|
| sum (22-bit) | cnt (10-bit) |
|-----------------------------|
#define SUM(val) (val >> 10)
#define CNT(val) (val & 0x3ff)
#define RES(sum, cnt) (((sum) << 10) + (cnt))

uint32_t traverse(struct TreeNode *n, int *max) {
if (!n) return 0;

uint32_t left = traverse(n->left, max),
right = traverse(n->right, max);

uint32_t curr_sum = n->val + SUM(left) + SUM(right),
curr_cnt = 1 + CNT(left) + CNT(right);

if (curr_sum/curr_cnt == n->val) (*max)++;

return RES(curr_sum, curr_cnt);
}

因為朋友提醒,這邊附上 bit field struct 解法:

typedef struct Pair {
int sum:22;
unsigned int cnt:10;
} P;

P traverse(struct TreeNode *n, int *max) {
if (!n) return (P){0, 0};

P left = traverse(n->left, max);
P right = traverse(n->right, max);

int curr_sum = n->val + left.sum + right.sum;
int curr_cnt = 1 + left.cnt + right.cnt;

if (curr_sum / curr_cnt == n->val) (*max)++;

return (P){curr_sum, curr_cnt};
}

int averageOfSubtree(struct TreeNode *root){
int ans = 0;
traverse(root, &ans);

return ans;
}

補充

或許有人發現到了,在 #define RES(sum, cnt) (((sum) << 10) + (cnt)) 中我用了 + 而非 |,這其實是刻意為之。

在 x86-64 GCC 8.2 -O1(Leetcode 中的 C Compiler 與 Optimize Level)中:

.L3:
sal ecx, 10
lea eax, [rcx+rsi]
pop rbx
pop rbp
pop r12
ret

Ref

可以見到,使用 + 運算符會使用 LEA 指令並減少一個 MOV 指令(x86-64 限定),這個技巧來自於 Linux Kernel #b0687c1

參考資料

  1. C++ Pair: https://cplusplus.com/reference/utility/pair/
  2. C++ 17 Structured Binding: https://en.cppreference.com/w/cpp/language/structured_binding
  3. Linux Kernel #b0687c1: https://github.com/torvalds/linux/commit/b0687c1119b4e8c88a651b6e876b7eae28d076e3