1
// SPDX-License-Identifier: Apache-2.0
2

            
3
use std::env;
4
use std::fs::File;
5
use std::io::{self, Error, ErrorKind, Read, Seek, SeekFrom};
6
use std::path::{Path, PathBuf};
7

            
8
use super::common;
9

            
10
//================================================
11
// Validation
12
//================================================
13

            
14
/// Reads the ELF class byte from the header of the shared library at `path`.
///
/// The first five bytes of the file are read: the four-byte ELF magic
/// (`0x7F`, `E`, `L`, `F`) followed by the class byte, which is returned
/// as-is without further interpretation.
///
/// # Errors
///
/// Returns any I/O error raised while opening or reading the file (including
/// the file being shorter than five bytes), or an [`ErrorKind::InvalidData`]
/// error if the file does not start with the ELF magic.
fn parse_elf_header(path: &Path) -> io::Result<u8> {
    /// The four bytes every ELF file starts with.
    const ELF_MAGIC: [u8; 4] = [0x7F, b'E', b'L', b'F'];

    let mut header = [0u8; 5];
    File::open(path)?.read_exact(&mut header)?;

    if header[..4] != ELF_MAGIC {
        return Err(Error::new(ErrorKind::InvalidData, "invalid ELF header"));
    }

    // The class byte immediately follows the magic.
    Ok(header[4])
}
25

            
26
/// Extracts the magic number and machine type from the PE header in a shared
/// library.
///
/// The little-endian 32-bit value stored at file offset `0x3C` is used as the
/// file offset of the `PE\0\0` signature. The 16-bit machine type is read
/// directly after the signature and the 16-bit magic number is read 20 bytes
/// after the end of the signature.
///
/// Returns `(magic_number, machine_type)`.
///
/// # Errors
///
/// Returns any I/O error raised while opening, seeking in, or reading the file
/// (including the file being too short), or an [`ErrorKind::InvalidData`]
/// error if the stored signature offset is negative or the signature itself is
/// not `PE\0\0`.
fn parse_pe_header(path: &Path) -> io::Result<(u16, u16)> {
    let mut file = File::open(path)?;

    // Extract the header offset.
    let mut buffer = [0; 4];
    file.seek(SeekFrom::Start(0x3C))?;
    file.read_exact(&mut buffer)?;
    let offset = i32::from_le_bytes(buffer);

    // The offset comes straight from the file, so it cannot be trusted. A
    // negative value would otherwise wrap around to an enormous seek target
    // and surface as a platform-specific seek failure.
    if offset < 0 {
        return Err(Error::new(ErrorKind::InvalidData, "invalid PE header"));
    }

    // Check the validity of the header. The cast is lossless because the
    // offset is known to be non-negative here.
    file.seek(SeekFrom::Start(offset as u64))?;
    file.read_exact(&mut buffer)?;
    if buffer != *b"PE\0\0" {
        return Err(Error::new(ErrorKind::InvalidData, "invalid PE header"));
    }

    // Extract the machine type, which immediately follows the signature.
    let mut field = [0; 2];
    file.read_exact(&mut field)?;
    let machine_type = u16::from_le_bytes(field);

    // Extract the magic number, which sits 20 bytes after the signature; two
    // of those bytes (the machine type) have already been consumed.
    file.seek(SeekFrom::Current(18))?;
    file.read_exact(&mut field)?;
    let magic_number = u16::from_le_bytes(field);

    Ok((magic_number, machine_type))
}
58

            
59
/// Checks that a `libclang` shared library matches the target platform.
60
36
fn validate_library(path: &Path) -> Result<(), String> {
61
36
    if target_os!("linux") || target_os!("freebsd") {
62
36
        let class = parse_elf_header(path).map_err(|e| e.to_string())?;
63

            
64
36
        if target_pointer_width!("32") && class != 1 {
65
            return Err("invalid ELF class (64-bit)".into());
66
        }
67

            
68
36
        if target_pointer_width!("64") && class != 2 {
69
            return Err("invalid ELF class (32-bit)".into());
70
        }
71

            
72
36
        Ok(())
73
    } else if target_os!("windows") {
74
        let (magic, machine_type) = parse_pe_header(path).map_err(|e| e.to_string())?;
75

            
76
        if target_pointer_width!("32") && magic != 267 {
77
            return Err("invalid DLL (64-bit)".into());
78
        }
79

            
80
        if target_pointer_width!("64") && magic != 523 {
81
            return Err("invalid DLL (32-bit)".into());
82
        }
83

            
84
        let arch_mismatch = match machine_type {
85
            0x014C if !target_arch!("x86") => Some("x86"),
86
            0x8664 if !target_arch!("x86_64") => Some("x86-64"),
87
            0xAA64 if !target_arch!("aarch64") => Some("ARM64"),
88
            _ => None,
89
        };
90

            
91
        if let Some(arch) = arch_mismatch {
92
            Err(format!("invalid DLL ({arch})"))
93
        } else {
94
            Ok(())
95
        }
96
    } else {
97
        Ok(())
98
    }
99
}
100

            
101
//================================================
102
// Searching
103
//================================================
104

            
105
/// Extracts the version components in a `libclang` shared library filename.
///
/// Two naming schemes are recognized:
///
/// - `libclang.so.<version>` (e.g., `libclang.so.7.0` yields `[7, 0]`)
/// - `libclang-<version>.so` (e.g., `libclang-3.9.so` yields `[3, 9]`); the
///   final three bytes of the filename are dropped without checking that they
///   are actually `.so`
///
/// The remaining text is split on `.` and every component that is not a valid
/// `u32` is recorded as `0`.
///
/// An empty vector is returned when the filename matches neither scheme, or
/// when a `libclang-` filename is too short (or not sliceable at that
/// position) for the three-byte suffix to be removed.
fn parse_version(filename: &str) -> Vec<u32> {
    let version = if let Some(version) = filename.strip_prefix("libclang.so.") {
        version
    } else if let Some(rest) = filename.strip_prefix("libclang-") {
        // `checked_sub` guards against names with fewer than three bytes after
        // the prefix and `get` guards against cutting a multi-byte character
        // in half; plain slicing would panic in both cases.
        match rest.len().checked_sub(3).and_then(|end| rest.get(..end)) {
            Some(version) => version,
            None => return vec![],
        }
    } else {
        return vec![];
    };

    version.split('.').map(|s| s.parse().unwrap_or(0)).collect()
}
117

            
118
/// Finds `libclang` shared libraries and returns the paths to, filenames of,
119
/// and versions of those shared libraries.
120
2
fn search_libclang_directories(runtime: bool) -> Result<Vec<(PathBuf, String, Vec<u32>)>, String> {
121
2
    let mut files = vec![format!(
122
2
        "{}clang{}",
123
        env::consts::DLL_PREFIX,
124
        env::consts::DLL_SUFFIX
125
    )];
126

            
127
2
    if target_os!("linux") {
128
        // Some Linux distributions don't create a `libclang.so` symlink, so we
129
        // need to look for versioned files (e.g., `libclang-3.9.so`).
130
2
        files.push("libclang-*.so".into());
131

            
132
        // Some Linux distributions don't create a `libclang.so` symlink and
133
        // don't have versioned files as described above, so we need to look for
134
        // suffix versioned files (e.g., `libclang.so.1`). However, `ld` cannot
135
        // link to these files, so this will only be included when linking at
136
        // runtime.
137
2
        if runtime {
138
2
            files.push("libclang.so.*".into());
139
2
            files.push("libclang-*.so.*".into());
140
        }
141
    }
142

            
143
2
    if target_os!("freebsd") || target_os!("haiku") || target_os!("netbsd") || target_os!("openbsd") {
144
        // Some BSD distributions don't create a `libclang.so` symlink either,
145
        // but use a different naming scheme for versioned files (e.g.,
146
        // `libclang.so.7.0`).
147
        files.push("libclang.so.*".into());
148
    }
149

            
150
2
    if target_os!("windows") {
151
        // The official LLVM build uses `libclang.dll` on Windows instead of
152
        // `clang.dll`. However, unofficial builds such as MinGW use `clang.dll`.
153
        files.push("libclang.dll".into());
154
    }
155

            
156
    // Find and validate `libclang` shared libraries and collect the versions.
157
2
    let mut valid = vec![];
158
2
    let mut invalid = vec![];
159
36
    for (directory, filename) in common::search_libclang_directories(&files, "LIBCLANG_PATH") {
160
36
        let path = directory.join(&filename);
161
36
        match validate_library(&path) {
162
            Ok(()) => {
163
36
                let version = parse_version(&filename);
164
36
                valid.push((directory, filename, version))
165
            }
166
            Err(message) => invalid.push(format!("({}: {})", path.display(), message)),
167
        }
168
    }
169

            
170
2
    if !valid.is_empty() {
171
2
        return Ok(valid);
172
    }
173

            
174
    let message = format!(
175
        "couldn't find any valid shared libraries matching: [{}], set the \
176
         `LIBCLANG_PATH` environment variable to a path where one of these files \
177
         can be found (invalid: [{}])",
178
        files
179
            .iter()
180
            .map(|f| format!("'{}'", f))
181
            .collect::<Vec<_>>()
182
            .join(", "),
183
        invalid.join(", "),
184
    );
185

            
186
    Err(message)
187
}
188

            
189
/// Finds the "best" `libclang` shared library and returns the directory and
190
/// filename of that library.
191
2
pub fn find(runtime: bool) -> Result<(PathBuf, String), String> {
192
2
    search_libclang_directories(runtime)?
193
2
        .iter()
194
        // We want to find the `libclang` shared library with the highest
195
        // version number, hence `max_by_key` below.
196
        //
197
        // However, in the case where there are multiple such `libclang` shared
198
        // libraries, we want to use the order in which they appeared in the
199
        // list returned by `search_libclang_directories` as a tiebreaker since
200
        // that function returns `libclang` shared libraries in descending order
201
        // of preference by how they were found.
202
        //
203
        // `max_by_key`, perhaps surprisingly, returns the *last* item with the
204
        // maximum key rather than the first which results in the opposite of
205
        // the tiebreaking behavior we want. This is easily fixed by reversing
206
        // the list first.
207
2
        .rev()
208
2
        .max_by_key(|f| &f.2)
209
2
        .cloned()
210
2
        .map(|(path, filename, _)| (path, filename))
211
2
        .ok_or_else(|| "unreachable".into())
212
}
213

            
214
//================================================
215
// Linking
216
//================================================
217

            
218
/// Finds and links to a `libclang` shared library.
219
#[cfg(not(feature = "runtime"))]
220
pub fn link() {
221
    let cep = common::CommandErrorPrinter::default();
222

            
223
    use std::fs;
224

            
225
    let (directory, filename) = find(false).unwrap();
226
    println!("cargo:rustc-link-search={}", directory.display());
227

            
228
    if cfg!(all(target_os = "windows", target_env = "msvc")) {
229
        // Find the `libclang` stub static library required for the MSVC
230
        // toolchain.
231
        let lib = if !directory.ends_with("bin") {
232
            directory
233
        } else {
234
            directory.parent().unwrap().join("lib")
235
        };
236

            
237
        if lib.join("libclang.lib").exists() {
238
            println!("cargo:rustc-link-search={}", lib.display());
239
        } else if lib.join("libclang.dll.a").exists() {
240
            // MSYS and MinGW use `libclang.dll.a` instead of `libclang.lib`.
241
            // It is linkable with the MSVC linker, but Rust doesn't recognize
242
            // the `.a` suffix, so we need to copy it with a different name.
243
            //
244
            // FIXME: Maybe we can just hardlink or symlink it?
245
            let out = env::var("OUT_DIR").unwrap();
246
            fs::copy(
247
                lib.join("libclang.dll.a"),
248
                Path::new(&out).join("libclang.lib"),
249
            )
250
            .unwrap();
251
            println!("cargo:rustc-link-search=native={}", out);
252
        } else {
253
            panic!(
254
                "using '{}', so 'libclang.lib' or 'libclang.dll.a' must be \
255
                 available in {}",
256
                filename,
257
                lib.display(),
258
            );
259
        }
260

            
261
        println!("cargo:rustc-link-lib=dylib=libclang");
262
    } else {
263
        let name = filename.trim_start_matches("lib");
264

            
265
        // Strip extensions and trailing version numbers (e.g., the `.so.7.0` in
266
        // `libclang.so.7.0`).
267
        let name = match name.find(".dylib").or_else(|| name.find(".so")) {
268
            Some(index) => &name[0..index],
269
            None => name,
270
        };
271

            
272
        println!("cargo:rustc-link-lib=dylib={}", name);
273
    }
274

            
275
    cep.discard();
276
}